Skip to content

Commit c14f433

Browse files
committed
update the in/out shapes
Signed-off-by: Farhad Ramezanghorbani <farhadr@nvidia.com>
1 parent 261ca02 commit c14f433

1 file changed

Lines changed: 15 additions & 15 deletions

File tree

sub-packages/bionemo-evo2/tests/bionemo/evo2/test_hyena_operators.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,14 @@ def test_gpu_forward(self, operator: ParallelHyenaOperator):
6868
g = operator.num_groups
6969
dg = operator.group_dim
7070

71-
x1 = torch.ones((batch_size, seq_len, g, dg), device=device)
72-
x2 = torch.ones((batch_size, seq_len, g, dg), device=device)
73-
v = torch.ones((batch_size, seq_len, g, dg), device=device)
71+
x1 = torch.ones((batch_size, (g * dg), seq_len), device=device)
72+
x2 = torch.ones((batch_size, (g * dg), seq_len), device=device)
73+
v = torch.ones((batch_size, (g * dg), seq_len), device=device)
7474

7575
output = operator(x1, x2, v)
7676
assert output.shape[0] == batch_size
77-
assert output.shape[1] == seq_len
78-
assert output.shape[2] == operator.hidden_size
77+
assert output.shape[1] == operator.hidden_size
78+
assert output.shape[2] == seq_len
7979

8080

8181
class TestParallelShortHyenaOperator:
@@ -108,14 +108,14 @@ def test_gpu_forward(self, operator: ParallelShortHyenaOperator):
108108
g = operator.num_groups
109109
dg = operator.group_dim
110110

111-
x1 = torch.ones((batch_size, seq_len, g, dg), device=device)
112-
x2 = torch.ones((batch_size, seq_len, g, dg), device=device)
113-
v = torch.ones((batch_size, seq_len, g, dg), device=device)
111+
x1 = torch.ones((batch_size, (g * dg), seq_len), device=device)
112+
x2 = torch.ones((batch_size, (g * dg), seq_len), device=device)
113+
v = torch.ones((batch_size, (g * dg), seq_len), device=device)
114114

115115
output = operator(x1, x2, v)
116116
assert output.shape[0] == batch_size
117-
assert output.shape[1] == seq_len
118-
assert output.shape[2] == operator.hidden_size
117+
assert output.shape[1] == operator.hidden_size
118+
assert output.shape[2] == seq_len
119119

120120

121121
class TestParallelShortHyenaOperatorWithConvBias:
@@ -148,14 +148,14 @@ def test_gpu_forward(self, operator: ParallelShortHyenaOperator):
148148
g = operator.num_groups
149149
dg = operator.group_dim
150150

151-
x1 = torch.ones((batch_size, seq_len, g, dg), device=device)
152-
x2 = torch.ones((batch_size, seq_len, g, dg), device=device)
153-
v = torch.ones((batch_size, seq_len, g, dg), device=device)
151+
x1 = torch.ones((batch_size, (g * dg), seq_len), device=device)
152+
x2 = torch.ones((batch_size, (g * dg), seq_len), device=device)
153+
v = torch.ones((batch_size, (g * dg), seq_len), device=device)
154154

155155
output = operator(x1, x2, v)
156156
assert output.shape[0] == batch_size
157-
assert output.shape[1] == seq_len
158-
assert output.shape[2] == operator.hidden_size
157+
assert output.shape[1] == operator.hidden_size
158+
assert output.shape[2] == seq_len
159159

160160

161161
class TestParallelCausalDepthwiseConv1d:

0 commit comments

Comments
 (0)