In Puzzle 7: Long Sum, the problem statement never mentions `N1`, yet `sum_kernel` takes both `N1` and `T` as inputs. I think this is a typo and `N1` can be safely removed. I tested my solution with `N1` removed from the kernel signature and from the `test` call, and it seemed to work.
For context, here is the problem statement:
## Puzzle 7: Long Sum
Sum of a batch of numbers.
Uses one program block axis. Block size `B0` represents a range of batches of `x` of length `N0`.
Each element is of length `T`. Process it `B1 < T` elements at a time.
And here is the suggested change:
```python
import triton
import triton.language as tl
from jaxtyping import Float32
from torch import Tensor

def sum_spec(x: Float32[Tensor, "4 200"]) -> Float32[Tensor, "4"]:
    return x.sum(1)

# before: def sum_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
@triton.jit
def sum_kernel(x_ptr, z_ptr, N0, T, B0: tl.constexpr, B1: tl.constexpr):
    # TODO: add implementation
    return

# `test` is the helper provided by the puzzles notebook.
# before: test(sum_kernel, sum_spec, B={"B0": 1, "B1": 32}, nelem={"N0": 4, "N1": 32, "T": 200})
test(sum_kernel, sum_spec, B={"B0": 1, "B1": 32}, nelem={"N0": 4, "T": 200})
```
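As a rough sanity check on why `N1` has no role, here is a plain-Python emulation of the blocked loop the statement describes. It is only a sketch, not a Triton kernel and not the puzzle's reference solution: `blocked_row_sum` is a hypothetical name, and it assumes `x` is a flat, row-major `N0 x T` buffer. The only sizes it needs are `N0`, `T`, `B0` and `B1`.

```python
def blocked_row_sum(x, N0, T, B0, B1):
    """Sum each row of a flat, row-major N0 x T buffer `x` (hypothetical helper).

    Each "program" handles B0 rows; within a row, elements are read B1 at a
    time. Out-of-range positions are skipped, which is equivalent to a masked
    load that fills them with 0.0.
    """
    z = [0.0] * N0
    num_programs = (N0 + B0 - 1) // B0           # ceil(N0 / B0) programs
    for pid in range(num_programs):
        for row in range(pid * B0, pid * B0 + B0):
            if row >= N0:                         # mask rows past the end of the batch
                continue
            acc = 0.0
            for start in range(0, T, B1):         # walk the row B1 elements at a time
                for j in range(start, start + B1):
                    if j < T:                     # mask the ragged last chunk (200 % 32 != 0)
                        acc += x[row * T + j]
            z[row] = acc
    return z


# Same sizes as the puzzle's test call: N0 = 4, T = 200, B0 = 1, B1 = 32.
x = [float(i % 7) for i in range(4 * 200)]
z = blocked_row_sum(x, N0=4, T=200, B0=1, B1=32)
```

Nothing in the loop bounds or indexing refers to a second length, which matches the observation that the kernel works with `N1` dropped.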