Skip to content

Commit bf56519

Browse files
committed
ensure data is equal
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent d06f118 commit bf56519

1 file changed

Lines changed: 15 additions & 5 deletions

File tree

models/esm2/tests/test_thd_inputs.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,25 +73,35 @@ def test_thd_values_match(te_model_checkpoint, tokenizer, monkeypatch):
7373
for seq in sequences
7474
]
7575

76-
bhsd_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, padding=True)
76+
bshd_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, padding=True)
7777
thd_collator = DataCollatorWithFlattening(return_flash_attn_kwargs=True)
7878

79-
input_data_bhsd = bhsd_collator(sequences)
79+
input_data_bshd = bshd_collator(sequences)
8080
input_data_thd = thd_collator(sequences)
8181

82+
torch.testing.assert_close(
83+
input_data_bshd["input_ids"][input_data_bshd["attention_mask"].to(bool)],
84+
input_data_thd["input_ids"].flatten(0),
85+
)
86+
87+
torch.testing.assert_close(
88+
input_data_bshd["labels"][input_data_bshd["attention_mask"].to(bool)],
89+
input_data_thd["labels"].flatten(0),
90+
)
91+
8292
model_bshd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, torch_dtype=torch.bfloat16)
8393
model_thd = NVEsmForMaskedLM.from_pretrained(
8494
te_model_checkpoint, attn_input_format="thd", torch_dtype=torch.bfloat16
8595
)
8696
model_bshd.to("cuda")
8797
model_thd.to("cuda")
8898

89-
input_data_bhsd = {k: v.to("cuda") for k, v in input_data_bhsd.items()}
99+
input_data_bshd = {k: v.to("cuda") for k, v in input_data_bshd.items()}
90100
input_data_thd = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()}
91101

92-
bshd_outputs = model_bshd(**input_data_bhsd)
102+
bshd_outputs = model_bshd(**input_data_bshd)
93103
thd_outputs = model_thd(**input_data_thd)
94104

95-
bhsd_logits = bshd_outputs.logits[input_data_bhsd["attention_mask"].to(bool)]
105+
bhsd_logits = bshd_outputs.logits[input_data_bshd["attention_mask"].to(bool)]
96106
torch.testing.assert_close(bhsd_logits, thd_outputs.logits)
97107
torch.testing.assert_close(bshd_outputs.loss, thd_outputs.loss)

0 commit comments

Comments
 (0)