@@ -73,25 +73,35 @@ def test_thd_values_match(te_model_checkpoint, tokenizer, monkeypatch):
7373 for seq in sequences
7474 ]
7575
76- bhsd_collator = DataCollatorForTokenClassification (tokenizer = tokenizer , padding = True )
76+ bshd_collator = DataCollatorForTokenClassification (tokenizer = tokenizer , padding = True )
7777 thd_collator = DataCollatorWithFlattening (return_flash_attn_kwargs = True )
7878
79- input_data_bhsd = bhsd_collator (sequences )
79+ input_data_bshd = bshd_collator (sequences )
8080 input_data_thd = thd_collator (sequences )
8181
82+ torch .testing .assert_close (
83+ input_data_bshd ["input_ids" ][input_data_bshd ["attention_mask" ].to (bool )],
84+ input_data_thd ["input_ids" ].flatten (0 ),
85+ )
86+
87+ torch .testing .assert_close (
88+ input_data_bshd ["labels" ][input_data_bshd ["attention_mask" ].to (bool )],
89+ input_data_thd ["labels" ].flatten (0 ),
90+ )
91+
8292 model_bshd = NVEsmForMaskedLM .from_pretrained (te_model_checkpoint , torch_dtype = torch .bfloat16 )
8393 model_thd = NVEsmForMaskedLM .from_pretrained (
8494 te_model_checkpoint , attn_input_format = "thd" , torch_dtype = torch .bfloat16
8595 )
8696 model_bshd .to ("cuda" )
8797 model_thd .to ("cuda" )
8898
89- input_data_bhsd = {k : v .to ("cuda" ) for k , v in input_data_bhsd .items ()}
99+ input_data_bshd = {k : v .to ("cuda" ) for k , v in input_data_bshd .items ()}
90100 input_data_thd = {k : v .to ("cuda" ) if isinstance (v , torch .Tensor ) else v for k , v in input_data_thd .items ()}
91101
92- bshd_outputs = model_bshd (** input_data_bhsd )
102+ bshd_outputs = model_bshd (** input_data_bshd )
93103 thd_outputs = model_thd (** input_data_thd )
94104
95- bhsd_logits = bshd_outputs .logits [input_data_bhsd ["attention_mask" ].to (bool )]
105+ bhsd_logits = bshd_outputs .logits [input_data_bshd ["attention_mask" ].to (bool )]
96106 torch .testing .assert_close (bhsd_logits , thd_outputs .logits )
97107 torch .testing .assert_close (bshd_outputs .loss , thd_outputs .loss )
0 commit comments