forked from N8python/mlx-pretrain
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_training.py
More file actions
executable file
·197 lines (158 loc) · 6.82 KB
/
Copy pathplot_training.py
File metadata and controls
executable file
·197 lines (158 loc) · 6.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#!/usr/bin/env python
import matplotlib.pyplot as plt
import numpy as np
import argparse
import re
import time
from pathlib import Path
def extract_metrics_from_log(log_file_path):
    """Parse training and validation losses out of a training log file.

    Two kinds of lines are recognised:

    * Training lines start with ``"Step"``, do not contain ``"validation:"``,
      and carry ``Step <int>`` plus ``loss=<number>``.
    * Validation lines contain ``"validation:"`` and carry ``Step <int>``
      plus ``val_loss=<number>``.

    Lines matching neither shape, or whose loss is not a complete number
    (for example a trailing line that is still being written), are skipped.

    Args:
        log_file_path: Path to the log file (``str`` or path-like).

    Returns:
        A tuple ``(steps, losses, val_steps, val_losses)`` of lists in file
        order; steps are ``int`` and losses are ``float``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    step_re = re.compile(r"Step (\d+)")
    # A complete decimal / scientific-notation number. Unlike the old
    # [0-9.e-]+ pattern this accepts "1.5E-03" and "1e+00", and the
    # lookahead refuses to take a prefix of a truncated value like "1.2e-"
    # (which previously made float() raise ValueError).
    number = r"([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)(?![\d.eE+-])"
    loss_re = re.compile(r"loss=" + number)
    val_loss_re = re.compile(r"val_loss=" + number)

    steps, losses = [], []
    val_steps, val_losses = [], []

    # errors="replace": a multi-byte character cut off mid-write must not
    # raise UnicodeDecodeError while the trainer is still appending.
    with open(log_file_path, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            if "validation:" in line:
                step_match = step_re.search(line)
                value_match = val_loss_re.search(line)
                if step_match and value_match:
                    val_steps.append(int(step_match.group(1)))
                    val_losses.append(float(value_match.group(1)))
            elif line.startswith("Step"):
                step_match = step_re.search(line)
                value_match = loss_re.search(line)
                if step_match and value_match:
                    steps.append(int(step_match.group(1)))
                    losses.append(float(value_match.group(1)))

    return steps, losses, val_steps, val_losses
def plot_training_metrics(log_file_path, output_path=None, interval=60, *, smoothing=0.9):
    """Continuously re-plot training metrics from a log file.

    Every ``interval`` seconds the whole log is re-parsed, the training loss
    is smoothed with an exponential moving average, and the figure is saved
    to ``output_path`` (or ``training_plot.png`` in the current directory when
    ``output_path`` is falsy). The figure is only written to disk, never shown
    interactively. This function loops forever; stop it with Ctrl+C
    (``KeyboardInterrupt`` propagates to the caller).

    Args:
        log_file_path: Path to the training log to parse.
        output_path: Image path to write; defaults to ``training_plot.png``.
        interval: Seconds to sleep between refreshes.
        smoothing: EMA weight on the previous smoothed value; larger values
            give a smoother curve. Defaults to 0.9.
    """
    target = output_path or "training_plot.png"
    plt.figure(figsize=(14, 8))

    while True:
        try:
            steps, losses, val_steps, val_losses = extract_metrics_from_log(log_file_path)
        except OSError as err:
            # The log may be briefly missing or unreadable (e.g. if it is
            # moved or recreated); keep watching instead of crashing.
            print(f"Could not read log file {log_file_path}: {err}")
            time.sleep(interval)
            continue

        if not steps:
            print(f"No training data found in log file: {log_file_path}")
            time.sleep(interval)
            continue

        # Exponential moving average of the training loss. steps and losses
        # are appended together by the parser, so losses is non-empty here.
        smoothed_losses = [losses[0]]
        for loss in losses[1:]:
            smoothed_losses.append(smoothing * smoothed_losses[-1] + (1 - smoothing) * loss)

        # Redraw from scratch on the same figure each cycle.
        plt.clf()
        plt.plot(steps, smoothed_losses, label="Training Loss (EMA)", color='blue')
        # Raw per-step loss, faint, behind the smoothed curve (no legend entry).
        plt.plot(steps, losses, alpha=0.3, color='lightblue')
        if val_steps:
            plt.plot(val_steps, val_losses, 'o-', label="Validation Loss", color='red')

        plt.xlabel("Step")
        plt.ylabel("Loss")
        plt.title(f"Training Progress - Last step: {steps[-1]}")
        plt.grid(True, alpha=0.3)
        plt.legend()

        plt.savefig(target)
        print(f"Updated plot saved to {target}")

        time.sleep(interval)
def find_latest_log_file(name="Muon-200M"):
    """Locate the most relevant training log for model ``name``.

    Search order, relative to the current working directory:

    1. ``runs/<name>/log.txt`` if it is a regular file.
    2. The most recently modified ``logs/train_<name.lower()>*.log``.
    3. The most recently modified ``logs/train_*.log`` for any model.

    Args:
        name: Model name used to build the run directory and log pattern.

    Returns:
        The matching path as a ``str``, or ``None`` if nothing was found.
    """
    run_log_path = Path(f"runs/{name}/log.txt")
    # is_file() rather than exists(): a directory at this path would
    # otherwise be returned and then fail when opened for reading.
    if run_log_path.is_file():
        return str(run_log_path)

    logs_dir = Path("logs")
    if not logs_dir.is_dir():
        return None

    # Prefer this model's logs, then fall back to any training log.
    for pattern in (f"train_{name.lower()}*.log", "train_*.log"):
        candidates = [p for p in logs_dir.glob(pattern) if p.is_file()]
        latest = max(candidates, key=lambda p: p.stat().st_mtime, default=None)
        if latest is not None:
            return str(latest)

    return None
def main():
    """Command-line entry point: find a training log and plot its losses.

    The log comes from ``--log`` or, if omitted, from
    ``find_latest_log_file(--model)``. With ``--no-watch`` the plot is
    rendered once; otherwise ``plot_training_metrics`` re-renders it every
    ``--interval`` seconds until interrupted with Ctrl+C.

    Exits with status 2 (via argparse) when ``--interval`` is not positive,
    and with status 1 when no log file can be found.
    """
    import sys

    parser = argparse.ArgumentParser(description="Plot training metrics")
    parser.add_argument("--log", type=str, help="Path to log file", default=None)
    parser.add_argument("--model", type=str, help="Model name to find log for", default="Muon-200M")
    parser.add_argument("--output", type=str, help="Output image path", default=None)
    parser.add_argument("--interval", type=int, help="Update interval in seconds", default=60)
    parser.add_argument("--no-watch", action="store_true", help="Generate plot once and exit")
    args = parser.parse_args()

    # time.sleep() raises ValueError for negative values, which would crash
    # watch mode after the first refresh; 0 would re-plot in a busy loop.
    if args.interval <= 0:
        parser.error("--interval must be a positive number of seconds")

    log_file_path = args.log or find_latest_log_file(args.model)
    # is_file(): a directory path would pass exists() and then fail on open.
    if not log_file_path or not Path(log_file_path).is_file():
        print("Error: Log file not found. Please specify a valid log file path.", file=sys.stderr)
        sys.exit(1)

    print(f"Using log file: {log_file_path}")

    if args.no_watch:
        _plot_once(log_file_path, args.output)
        return

    print(f"Watching log file with update interval {args.interval} seconds. Press Ctrl+C to stop.")
    try:
        plot_training_metrics(log_file_path, args.output, args.interval)
    except KeyboardInterrupt:
        print("Monitoring stopped by user.")


def _plot_once(log_file_path, output_path):
    """Render one loss plot for ``log_file_path`` and save it to disk.

    Writes to ``output_path`` if truthy, else ``training_plot.png``. Prints a
    message and writes nothing when the log contains no training steps.
    """
    steps, losses, val_steps, val_losses = extract_metrics_from_log(log_file_path)
    if not steps:
        print(f"No training data found in log file: {log_file_path}")
        return

    # Exponential moving average (decay 0.9) of the training loss. steps and
    # losses are appended together by the parser, so losses is non-empty.
    ema = 0.9
    smoothed_losses = [losses[0]]
    for loss in losses[1:]:
        smoothed_losses.append(ema * smoothed_losses[-1] + (1 - ema) * loss)

    fig = plt.figure(figsize=(14, 8))
    plt.plot(steps, smoothed_losses, label="Training Loss (EMA)", color='blue')
    # Raw per-step loss, faint, behind the smoothed curve (no legend entry).
    plt.plot(steps, losses, alpha=0.3, color='lightblue')
    if val_steps:
        plt.plot(val_steps, val_losses, 'o-', label="Validation Loss", color='red')

    plt.xlabel("Step")
    plt.ylabel("Loss")
    plt.title(f"Training Progress - Last step: {steps[-1]}")
    plt.grid(True, alpha=0.3)
    plt.legend()

    target = output_path or "training_plot.png"
    plt.savefig(target)
    # Release the figure; it is not reused after saving.
    plt.close(fig)
    print(f"Plot saved to {target}")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()