Skip to content

Commit 5fb1c6c

Browse files
committed
updated tests to add gradient accumulation flags in config
1 parent f124f10 commit 5fb1c6c

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

tests/post_training/unit/train_distill_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,8 @@ def test_main_offline_mode_skips_teacher_loading(
709709
mock_student_cfg.checkpoint_period = 10
710710
mock_student_cfg.gradient_clipping_threshold = 0.0
711711
mock_student_cfg.eval_interval = -1
712+
mock_student_cfg.gradient_accumulation_steps = 1
713+
mock_student_cfg.global_batch_size = 8
712714

713715
# Add dummy numbers for strategy math/logic
714716
mock_student_cfg.distill_temperature = 1.0
@@ -786,6 +788,8 @@ def test_main_online_mode_loads_teacher(
786788
mock_student_cfg.checkpoint_period = 10
787789
mock_student_cfg.gradient_clipping_threshold = 0.0
788790
mock_student_cfg.eval_interval = -1
791+
mock_student_cfg.gradient_accumulation_steps = 1
792+
mock_student_cfg.global_batch_size = 8
789793

790794
# Add dummy numbers for strategy math/logic
791795
mock_student_cfg.distill_temperature = 1.0

0 commit comments

Comments
 (0)