Skip to content

Commit 105ff96

Browse files
committed
use esm collator
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 5fad7fa commit 105ff96

1 file changed

Lines changed: 4 additions & 6 deletions

File tree

models/esm2/tests/test_thd_inputs.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,9 @@
1515

1616
import pytest
1717
import torch
18-
from transformers import AutoModelForMaskedLM, DataCollatorForTokenClassification, DataCollatorWithFlattening
18+
from transformers import AutoModelForMaskedLM, DataCollatorForTokenClassification
1919

20+
from esm.collator import DataCollatorWithFlattening
2021
from esm.convert import convert_esm_hf_to_te
2122
from esm.modeling_esm_te import NVEsmForMaskedLM
2223

@@ -44,11 +45,6 @@ def test_thd_from_collator_output(te_model_checkpoint, input_data_thd):
4445
def test_thd_values_match(te_model_checkpoint, tokenizer, monkeypatch):
4546
# Manually masked input tokens so that both BSHD and THD models have the same mask pattern
4647

47-
# We know that the THD model is using Flash Attention, so use the same kernel for the BSHD model to ensure the
48-
# values are as close as possible.
49-
monkeypatch.setenv("NVTE_FLASH_ATTN", "1")
50-
monkeypatch.setenv("NVTE_FUSED_ATTN", "0")
51-
5248
proteins = [
5349
"MLSATEKLSDYISSLFASVSIINSISTEDLFFLKLTCQTFSKDSEEYKAAYRILRGVQRGKVQIIEEALVS",
5450
"MFVFFAGTLVNQDTLNFRDQLNINVVGTVRGIAQDASKYLEYAIDSV",
@@ -102,3 +98,5 @@ def test_thd_values_match(te_model_checkpoint, tokenizer, monkeypatch):
10298
print("bshd_outputs.loss", bshd_outputs.loss)
10399
print("thd_outputs.loss", thd_outputs.loss)
104100
torch.testing.assert_close(bshd_outputs.loss, thd_outputs.loss)
101+
102+
breakpoint()

0 commit comments

Comments
 (0)