Skip to content

Commit cae320a

Browse files
committed
revert dtype calls to transformers 4.55.0
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent af09587 commit cae320a

2 files changed

Lines changed: 8 additions & 4 deletions

File tree

models/esm2/src/esm/modeling_esm_te.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def __init__(self, config: NVEsmConfig):
128128
micro_batch_size=config.micro_batch_size,
129129
num_gqa_groups=config.num_attention_heads,
130130
fuse_qkv_params=config.fuse_qkv_params,
131-
params_dtype=config.dtype,
131+
params_dtype=config.torch_dtype,
132132
window_size=(-1, -1),
133133
)
134134
for i in range(config.num_hidden_layers)

models/esm2/tests/test_thd_inputs.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ def te_model_checkpoint(tmp_path):
3030

3131

3232
def test_thd_from_collator_output(te_model_checkpoint, input_data_thd):
33-
model_thd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, attn_input_format="thd", dtype=torch.bfloat16)
33+
model_thd = NVEsmForMaskedLM.from_pretrained(
34+
te_model_checkpoint, attn_input_format="thd", torch_dtype=torch.bfloat16
35+
)
3436
model_thd.to("cuda")
3537
input_data_thd = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()}
3638
with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
@@ -76,8 +78,10 @@ def test_thd_values_match(te_model_checkpoint, tokenizer):
7678
input_data_bhsd = bhsd_collator(sequences)
7779
input_data_thd = thd_collator(sequences)
7880

79-
model_bshd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16)
80-
model_thd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, attn_input_format="thd", dtype=torch.bfloat16)
81+
model_bshd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, torch_dtype=torch.bfloat16)
82+
model_thd = NVEsmForMaskedLM.from_pretrained(
83+
te_model_checkpoint, attn_input_format="thd", torch_dtype=torch.bfloat16
84+
)
8185
model_bshd.to("cuda")
8286
model_thd.to("cuda")
8387

0 commit comments

Comments
 (0)