@@ -68,14 +68,14 @@ def test_gpu_forward(self, operator: ParallelHyenaOperator):
         g = operator.num_groups
         dg = operator.group_dim
 
-        x1 = torch.ones((batch_size, seq_len, g, dg), device=device)
-        x2 = torch.ones((batch_size, seq_len, g, dg), device=device)
-        v = torch.ones((batch_size, seq_len, g, dg), device=device)
+        x1 = torch.ones((batch_size, (g * dg), seq_len), device=device)
+        x2 = torch.ones((batch_size, (g * dg), seq_len), device=device)
+        v = torch.ones((batch_size, (g * dg), seq_len), device=device)
 
         output = operator(x1, x2, v)
         assert output.shape[0] == batch_size
-        assert output.shape[1] == seq_len
-        assert output.shape[2] == operator.hidden_size
+        assert output.shape[1] == operator.hidden_size
+        assert output.shape[2] == seq_len
 
 
 class TestParallelShortHyenaOperator:
@@ -108,14 +108,14 @@ def test_gpu_forward(self, operator: ParallelShortHyenaOperator):
         g = operator.num_groups
         dg = operator.group_dim
 
-        x1 = torch.ones((batch_size, seq_len, g, dg), device=device)
-        x2 = torch.ones((batch_size, seq_len, g, dg), device=device)
-        v = torch.ones((batch_size, seq_len, g, dg), device=device)
+        x1 = torch.ones((batch_size, (g * dg), seq_len), device=device)
+        x2 = torch.ones((batch_size, (g * dg), seq_len), device=device)
+        v = torch.ones((batch_size, (g * dg), seq_len), device=device)
 
         output = operator(x1, x2, v)
         assert output.shape[0] == batch_size
-        assert output.shape[1] == seq_len
-        assert output.shape[2] == operator.hidden_size
+        assert output.shape[1] == operator.hidden_size
+        assert output.shape[2] == seq_len
 
 
 class TestParallelShortHyenaOperatorWithConvBias:
@@ -148,14 +148,14 @@ def test_gpu_forward(self, operator: ParallelShortHyenaOperator):
         g = operator.num_groups
         dg = operator.group_dim
 
-        x1 = torch.ones((batch_size, seq_len, g, dg), device=device)
-        x2 = torch.ones((batch_size, seq_len, g, dg), device=device)
-        v = torch.ones((batch_size, seq_len, g, dg), device=device)
+        x1 = torch.ones((batch_size, (g * dg), seq_len), device=device)
+        x2 = torch.ones((batch_size, (g * dg), seq_len), device=device)
+        v = torch.ones((batch_size, (g * dg), seq_len), device=device)
 
         output = operator(x1, x2, v)
         assert output.shape[0] == batch_size
-        assert output.shape[1] == seq_len
-        assert output.shape[2] == operator.hidden_size
+        assert output.shape[1] == operator.hidden_size
+        assert output.shape[2] == seq_len
 
 
 class TestParallelCausalDepthwiseConv1d:
0 commit comments