We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent df70edf, commit cc3dc4b. Copy full SHA for cc3dc4b.
1 file changed
plugins/online-data-mixing/artifacts/custom_loop_usage.py
@@ -95,14 +95,14 @@ class State:
95
if step_idx % 1 == 0:
96
if torch.isnan(loss):
97
loss = torch.tensor([10]) # nan -> very high loss
98
- print(f"Step {step_idx} ||| Loss: {loss.item():.4f}")
+ if accelerator.is_main_process:
99
+ print(f"Step {step_idx} ||| Loss: {loss.item():.4f}")
100
state.log_history.append(
101
{"loss": loss.item() if not torch.isnan(loss) else 1e100}
102
)
103
if step_idx % update_interval == 0:
104
dataloader.dataset.update_sampling_weights(model, accelerator, state)
- max_steps -= 1
105
- if max_steps == 0:
+ if step_idx > max_steps:
106
break
107
108
print("training completed!")
0 commit comments