Skip to content

Commit 78a6689

Browse files
authored
Fix float32 indices in test_batched_export_with_backprop
Differential Revision: D101547370 Pull Request resolved: #18992
1 parent 8e5ec80 commit 78a6689

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

examples/models/llama/tests/test_static_attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def test_batched_export_with_backprop(self):
325325
static_config, input_len, cache_len, batch_size=batch_size
326326
)
327327
example_inputs = (
328-
torch.zeros(batch_size, input_len),
328+
torch.zeros(batch_size, input_len, dtype=torch.long),
329329
{
330330
"masks": mgr.masks,
331331
"freqs_cos_override": mgr.freqs_cos[:input_len],
@@ -350,7 +350,7 @@ def test_batched_export_with_backprop(self):
350350
static_config, input_len, cache_len, batch_size=1
351351
)
352352
example_inputs = (
353-
torch.zeros(1, input_len),
353+
torch.zeros(1, input_len, dtype=torch.long),
354354
{
355355
"masks": mgr.masks,
356356
"freqs_cos_override": mgr.freqs_cos[:input_len],

0 commit comments

Comments
 (0)