Skip to content

Commit bc597d4

Browse files
committed
set seed for mlm convergence
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 3412006 commit bc597d4

1 file changed

Lines changed: 7 additions & 4 deletions

File tree

recipes/esm2_native_te_nvfsdp_thd/test_thd_format.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,7 @@ def test_thd_format():
5252
DataCollatorForLanguageModeling(
5353
tokenizer=tokenizer,
5454
mlm_probability=0.15,
55+
seed=42,
5556
),
5657
DataCollatorWithFlattening(
5758
return_flash_attn_kwargs=True,
@@ -137,7 +138,7 @@ def test_thd_format_with_different_batch_sizes():
137138
single_batch = [{"input_ids": [0, 5, 10, 15, 1], "attention_mask": [1, 1, 1, 1, 1]}]
138139

139140
data_collator = MLMDataCollatorWithFlattening(
140-
DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15),
141+
DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15, seed=42),
141142
DataCollatorWithFlattening(return_flash_attn_kwargs=True),
142143
)
143144

@@ -156,7 +157,9 @@ def test_thd_format_sequence_lengths():
156157
original_lengths = [len(seq["input_ids"]) for seq in batch]
157158

158159
data_collator = MLMDataCollatorWithFlattening(
159-
DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.0), # No masking for length test
160+
DataCollatorForLanguageModeling(
161+
tokenizer=tokenizer, mlm_probability=0.0, seed=42
162+
), # No masking for length test
160163
DataCollatorWithFlattening(return_flash_attn_kwargs=True),
161164
)
162165

@@ -174,7 +177,7 @@ def test_thd_format_tensor_types():
174177
batch = create_test_batch()
175178

176179
data_collator = MLMDataCollatorWithFlattening(
177-
DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15),
180+
DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15, seed=42),
178181
DataCollatorWithFlattening(return_flash_attn_kwargs=True),
179182
)
180183

@@ -208,7 +211,7 @@ def test_mlm_data_collator_integration():
208211
# Test with different MLM probabilities
209212
for mlm_prob in [0.0, 0.15, 0.3]:
210213
data_collator = MLMDataCollatorWithFlattening(
211-
DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=mlm_prob),
214+
DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=mlm_prob, seed=42),
212215
DataCollatorWithFlattening(return_flash_attn_kwargs=True),
213216
)
214217

0 commit comments

Comments
 (0)