diff --git a/examples/jagged_dense_add.py b/examples/jagged_dense_add.py index 53cd0293e1..fad59035fc 100644 --- a/examples/jagged_dense_add.py +++ b/examples/jagged_dense_add.py @@ -17,6 +17,7 @@ import helion from helion._testing import DEVICE +from helion._testing import LONG_INT_TYPE from helion._testing import run_example import helion.language as hl @@ -143,10 +144,15 @@ def random_jagged_2d( - x_offsets: (num_rows+1) tensor with offsets for each row """ # random positive K_i for each row - lengths = torch.randint(1, max_cols + 1, (num_rows,), device=device) + lengths = torch.randint( + 1, max_cols + 1, (num_rows,), dtype=LONG_INT_TYPE, device=device + ) # prefix-sum -> offsets x_offsets = torch.cat( - [torch.zeros(1, dtype=torch.long, device=device), torch.cumsum(lengths, dim=0)] + [ + torch.zeros(1, dtype=LONG_INT_TYPE, device=device), + torch.cumsum(lengths, dim=0, dtype=LONG_INT_TYPE), + ] ) # total nnz nnz = int(x_offsets[-1]) diff --git a/examples/jagged_dense_bmm.py b/examples/jagged_dense_bmm.py index ce7fe4162a..78f02dcca1 100644 --- a/examples/jagged_dense_bmm.py +++ b/examples/jagged_dense_bmm.py @@ -35,6 +35,7 @@ import helion from helion._testing import DEVICE +from helion._testing import LONG_INT_TYPE from helion._testing import run_example import helion.language as hl @@ -127,9 +128,11 @@ def random_input( max_seq_len: int = 3, dtype: torch.dtype = torch.float32, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - lengths = torch.randint(max_seq_len + 1, size=(batch_size,), device=DEVICE) - seq_offsets = torch.zeros((batch_size + 1,), dtype=torch.int64, device=DEVICE) - seq_offsets[1:] = torch.cumsum(lengths, dim=0) + lengths = torch.randint( + max_seq_len + 1, size=(batch_size,), dtype=LONG_INT_TYPE, device=DEVICE + ) + seq_offsets = torch.zeros((batch_size + 1,), dtype=LONG_INT_TYPE, device=DEVICE) + seq_offsets[1:] = torch.cumsum(lengths, dim=0, dtype=LONG_INT_TYPE) jagged_size = int(seq_offsets[-1].item()) jagged = ( torch.empty((jagged_size, D), dtype=dtype, device=DEVICE) diff --git a/examples/jagged_layer_norm.py b/examples/jagged_layer_norm.py index aaa38bb304..f000025077 100644 --- a/examples/jagged_layer_norm.py +++ b/examples/jagged_layer_norm.py @@ -25,6 +25,7 @@ import helion from helion._testing import DEVICE +from helion._testing import LONG_INT_TYPE from helion._testing import run_example import helion.language as hl @@ -191,13 +192,15 @@ def create_test_jagged_tensor( """Create test jagged tensor data.""" # Generate random sequence lengths - seq_lengths = torch.randint(1, max_seqlen + 1, (B,), device=device) + seq_lengths = torch.randint( + 1, max_seqlen + 1, (B,), dtype=LONG_INT_TYPE, device=device + ) # Create offsets x_offsets = torch.cat( [ - torch.zeros(1, dtype=torch.long, device=device), - torch.cumsum(seq_lengths, dim=0), + torch.zeros(1, dtype=LONG_INT_TYPE, device=device), + torch.cumsum(seq_lengths, dim=0, dtype=LONG_INT_TYPE), ] ) diff --git a/examples/jagged_mean.py b/examples/jagged_mean.py index 3eae5b87df..3d5c0eb3dc 100644 --- a/examples/jagged_mean.py +++ b/examples/jagged_mean.py @@ -19,6 +19,7 @@ import helion from helion._testing import DEVICE +from helion._testing import LONG_INT_TYPE from helion._testing import run_example import helion.language as hl @@ -171,9 +172,14 @@ def main() -> None: num_rows, max_cols = 32, 64 device = DEVICE - lengths = torch.randint(1, max_cols + 1, (num_rows,), device=device) + lengths = torch.randint( + 1, max_cols + 1, (num_rows,), dtype=LONG_INT_TYPE, device=device + ) x_offsets = torch.cat( - [torch.zeros(1, dtype=torch.long, device=device), torch.cumsum(lengths, dim=0)] + [ + torch.zeros(1, dtype=LONG_INT_TYPE, device=device), + torch.cumsum(lengths, dim=0, dtype=LONG_INT_TYPE), + ] ) nnz = int(x_offsets[-1]) M = 8 # number of features diff --git a/examples/jagged_softmax.py b/examples/jagged_softmax.py index abdb773898..421b5f30f4 100644 --- a/examples/jagged_softmax.py +++ b/examples/jagged_softmax.py @@ -19,6 +19,7 @@ import helion from helion._testing import DEVICE +from helion._testing import LONG_INT_TYPE from helion._testing import run_example import helion.language as hl @@ -154,9 +155,14 @@ def main() -> None: num_rows, max_cols = 512, 64 device = DEVICE - lengths = torch.randint(1, max_cols + 1, (num_rows,), device=device) + lengths = torch.randint( + 1, max_cols + 1, (num_rows,), dtype=LONG_INT_TYPE, device=device + ) x_offsets = torch.cat( - [torch.zeros(1, dtype=torch.long, device=device), torch.cumsum(lengths, dim=0)] + [ + torch.zeros(1, dtype=LONG_INT_TYPE, device=device), + torch.cumsum(lengths, dim=0, dtype=LONG_INT_TYPE), + ] ) nnz = int(x_offsets[-1]) M = 128 # number of features diff --git a/examples/jagged_sum.py b/examples/jagged_sum.py index 7798938c4c..0347f79f3c 100644 --- a/examples/jagged_sum.py +++ b/examples/jagged_sum.py @@ -19,6 +19,7 @@ import helion from helion._testing import DEVICE +from helion._testing import LONG_INT_TYPE from helion._testing import run_example import helion.language as hl @@ -151,13 +152,15 @@ def create_test_jagged_tensor( """Create test jagged tensor data.""" # Generate random sequence lengths - seq_lengths = torch.randint(1, max_seqlen + 1, (B,), device=device) + seq_lengths = torch.randint( + 1, max_seqlen + 1, (B,), dtype=LONG_INT_TYPE, device=device + ) # Create offsets x_offsets = torch.cat( [ - torch.zeros(1, dtype=torch.long, device=device), - torch.cumsum(seq_lengths, dim=0), + torch.zeros(1, dtype=LONG_INT_TYPE, device=device), + torch.cumsum(seq_lengths, dim=0, dtype=LONG_INT_TYPE), ] ) diff --git a/test/test_examples.py b/test/test_examples.py index 01aad5ce01..d7f9153ffb 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -983,11 +983,13 @@ def test_long_sum_manual_non_divisible(self): def test_jagged_mean(self): num_rows, max_cols = 32, 64 M = 8 # number of features - lengths = torch.randint(1, max_cols + 1, (num_rows,), device=DEVICE) + lengths = torch.randint( + 1, max_cols + 1, (num_rows,), dtype=LONG_INT_TYPE, device=DEVICE + ) x_offsets = torch.cat( [ - torch.zeros(1, dtype=torch.long, device=DEVICE), - torch.cumsum(lengths, dim=0), + torch.zeros(1, dtype=LONG_INT_TYPE, device=DEVICE), + torch.cumsum(lengths, dim=0, dtype=LONG_INT_TYPE), ] ) nnz = int(x_offsets[-1]) @@ -1284,11 +1286,13 @@ def test_layernorm_without_bias(self): def test_jagged_softmax(self): num_rows, max_cols = 128, 64 M = 8 # number of features - lengths = torch.randint(1, max_cols + 1, (num_rows,), device=DEVICE) + lengths = torch.randint( + 1, max_cols + 1, (num_rows,), dtype=LONG_INT_TYPE, device=DEVICE + ) x_offsets = torch.cat( [ - torch.zeros(1, dtype=torch.long, device=DEVICE), - torch.cumsum(lengths, dim=0), + torch.zeros(1, dtype=LONG_INT_TYPE, device=DEVICE), + torch.cumsum(lengths, dim=0, dtype=LONG_INT_TYPE), ] ) nnz = int(x_offsets[-1]) @@ -1328,7 +1332,7 @@ def test_jagged_hstu_attn(self): seq_offsets = torch.cat( [ torch.tensor([0], dtype=torch.int32, device=DEVICE), - torch.cumsum(seq_lengths, dim=0), + torch.cumsum(seq_lengths, dim=0, dtype=torch.int32), ] ) total_seq_len = int(seq_offsets[-1].item()) @@ -1650,11 +1654,13 @@ def test_nvfp4_gemm(self): def test_jagged_sum(self): num_rows, max_cols = 128, 64 M = 8 # number of features - lengths = torch.randint(1, max_cols + 1, (num_rows,), device=DEVICE) + lengths = torch.randint( + 1, max_cols + 1, (num_rows,), dtype=LONG_INT_TYPE, device=DEVICE + ) x_offsets = torch.cat( [ - torch.zeros(1, dtype=torch.long, device=DEVICE), - torch.cumsum(lengths, dim=0), + torch.zeros(1, dtype=LONG_INT_TYPE, device=DEVICE), + torch.cumsum(lengths, dim=0, dtype=LONG_INT_TYPE), ] ) nnz = int(x_offsets[-1]) @@ -1722,11 +1728,13 @@ def test_fused_linear_jsd(self): def test_jagged_layer_norm(self): num_rows, max_cols = 128, 64 M = 8 # number of features - lengths = torch.randint(1, max_cols + 1, (num_rows,), device=DEVICE) + lengths = torch.randint( + 1, max_cols + 1, (num_rows,), dtype=LONG_INT_TYPE, device=DEVICE + ) x_offsets = torch.cat( [ - torch.zeros(1, dtype=torch.long, device=DEVICE), - torch.cumsum(lengths, dim=0), + torch.zeros(1, dtype=LONG_INT_TYPE, device=DEVICE), + torch.cumsum(lengths, dim=0, dtype=LONG_INT_TYPE), ] ) nnz = int(x_offsets[-1])