|
15 | 15 |
|
16 | 16 | import pytest |
17 | 17 | import torch |
18 | | -from transformers import AutoModelForMaskedLM, DataCollatorForTokenClassification, DataCollatorWithFlattening |
| 18 | +from transformers import AutoModelForMaskedLM, DataCollatorForTokenClassification |
19 | 19 |
|
| 20 | +from esm.collator import DataCollatorWithFlattening |
20 | 21 | from esm.convert import convert_esm_hf_to_te |
21 | 22 | from esm.modeling_esm_te import NVEsmForMaskedLM |
22 | 23 |
|
@@ -44,11 +45,6 @@ def test_thd_from_collator_output(te_model_checkpoint, input_data_thd): |
44 | 45 | def test_thd_values_match(te_model_checkpoint, tokenizer, monkeypatch): |
45 | 46 | # Manually masked input tokens so that both BSHD and THD models have the same mask pattern |
46 | 47 |
|
47 | | - # We know that the THD model is using Flash Attention, so use the same kernel for the BSHD model to ensure the |
48 | | - # values are as close as possible. |
49 | | - monkeypatch.setenv("NVTE_FLASH_ATTN", "1") |
50 | | - monkeypatch.setenv("NVTE_FUSED_ATTN", "0") |
51 | | - |
52 | 48 | proteins = [ |
53 | 49 | "MLSATEKLSDYISSLFASVSIINSISTEDLFFLKLTCQTFSKDSEEYKAAYRILRGVQRGKVQIIEEALVS", |
54 | 50 | "MFVFFAGTLVNQDTLNFRDQLNINVVGTVRGIAQDASKYLEYAIDSV", |
@@ -102,3 +98,5 @@ def test_thd_values_match(te_model_checkpoint, tokenizer, monkeypatch): |
102 | 98 | print("bshd_outputs.loss", bshd_outputs.loss) |
103 | 99 | print("thd_outputs.loss", thd_outputs.loss) |
104 | 100 | torch.testing.assert_close(bshd_outputs.loss, thd_outputs.loss) |
| 101 | + |
| 102 | + breakpoint() |
0 commit comments