@@ -52,6 +52,7 @@ def test_thd_format():
5252 DataCollatorForLanguageModeling (
5353 tokenizer = tokenizer ,
5454 mlm_probability = 0.15 ,
55+ seed = 42 ,
5556 ),
5657 DataCollatorWithFlattening (
5758 return_flash_attn_kwargs = True ,
@@ -137,7 +138,7 @@ def test_thd_format_with_different_batch_sizes():
137138 single_batch = [{"input_ids" : [0 , 5 , 10 , 15 , 1 ], "attention_mask" : [1 , 1 , 1 , 1 , 1 ]}]
138139
139140 data_collator = MLMDataCollatorWithFlattening (
140- DataCollatorForLanguageModeling (tokenizer = tokenizer , mlm_probability = 0.15 ),
141+ DataCollatorForLanguageModeling (tokenizer = tokenizer , mlm_probability = 0.15 , seed = 42 ),
141142 DataCollatorWithFlattening (return_flash_attn_kwargs = True ),
142143 )
143144
@@ -156,7 +157,9 @@ def test_thd_format_sequence_lengths():
156157 original_lengths = [len (seq ["input_ids" ]) for seq in batch ]
157158
158159 data_collator = MLMDataCollatorWithFlattening (
159- DataCollatorForLanguageModeling (tokenizer = tokenizer , mlm_probability = 0.0 ), # No masking for length test
160+ DataCollatorForLanguageModeling (
161+ tokenizer = tokenizer , mlm_probability = 0.0 , seed = 42
162+ ), # No masking for length test
160163 DataCollatorWithFlattening (return_flash_attn_kwargs = True ),
161164 )
162165
@@ -174,7 +177,7 @@ def test_thd_format_tensor_types():
174177 batch = create_test_batch ()
175178
176179 data_collator = MLMDataCollatorWithFlattening (
177- DataCollatorForLanguageModeling (tokenizer = tokenizer , mlm_probability = 0.15 ),
180+ DataCollatorForLanguageModeling (tokenizer = tokenizer , mlm_probability = 0.15 , seed = 42 ),
178181 DataCollatorWithFlattening (return_flash_attn_kwargs = True ),
179182 )
180183
@@ -208,7 +211,7 @@ def test_mlm_data_collator_integration():
208211 # Test with different MLM probabilities
209212 for mlm_prob in [0.0 , 0.15 , 0.3 ]:
210213 data_collator = MLMDataCollatorWithFlattening (
211- DataCollatorForLanguageModeling (tokenizer = tokenizer , mlm_probability = mlm_prob ),
214+ DataCollatorForLanguageModeling (tokenizer = tokenizer , mlm_probability = mlm_prob , seed = 42 ),
212215 DataCollatorWithFlattening (return_flash_attn_kwargs = True ),
213216 )
214217
0 commit comments