Skip to content

Commit b8ff665

Browse files
wyfEmmaaahladc
authored and committed
Revise the script to:
1. configure adaptive y-limits; 2. use the mean and standard deviation of the metrics instead of averaging; 3. use a solid line for v1 and a dashed line for v2
1 parent a5f4878 commit b8ff665

19 files changed

Lines changed: 192 additions & 71 deletions

logs/curve_plotting/plot_schedule_free.py

Lines changed: 192 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,51 @@
22
import glob
33
import json
44
import pandas as pd
5+
import numpy as np
56
import matplotlib.pyplot as plt
67
from pathlib import Path
78

8-
# Define submissions and their styles
9+
# Configure styling params
10+
plt.rcParams.update({
11+
'font.family': 'sans-serif',
12+
'font.sans-serif': ['Helvetica', 'Arial', 'DejaVu Sans'],
13+
'font.size': 12,
14+
'axes.labelsize': 13,
15+
'axes.titlesize': 14,
16+
'xtick.labelsize': 11,
17+
'ytick.labelsize': 11,
18+
'legend.fontsize': 11,
19+
'figure.titlesize': 16,
20+
'pdf.fonttype': 42,
21+
'ps.fonttype': 42
22+
})
23+
24+
# Define submissions and distinct styles
925
submissions = {
10-
'schedule_free_adamw': {'color': 'skyblue', 'label': 'PyTorch v1', 'alpha': 0.8},
11-
'schedule_free_adamw_v2': {'color': 'darkblue', 'label': 'PyTorch v2', 'alpha': 0.8},
12-
'schedule_free_adamw_jax': {'color': 'wheat', 'label': 'JAX v1', 'alpha': 0.8},
13-
'schedule_free_adamw_jax_v2': {'color': 'darkorange', 'label': 'JAX v2', 'alpha': 0.8}
26+
'schedule_free_adamw': {
27+
'color': '#1F77B4', # Classic Blue
28+
'linestyle': '-', # Solid
29+
'label': 'PyTorch v1',
30+
'alpha': 0.9
31+
},
32+
'schedule_free_adamw_v2': {
33+
'color': '#0B3C5D', # Deep Navy
34+
'linestyle': '--', # Dashed
35+
'label': 'PyTorch v2',
36+
'alpha': 0.9
37+
},
38+
'schedule_free_adamw_jax': {
39+
'color': '#FF7F0E', # Safety Orange
40+
'linestyle': '-', # Solid
41+
'label': 'JAX v1',
42+
'alpha': 0.9
43+
},
44+
'schedule_free_adamw_jax_v2': {
45+
'color': '#D9531E', # Vibrant Rust
46+
'linestyle': '--', # Dashed
47+
'label': 'JAX v2',
48+
'alpha': 0.9
49+
}
1450
}
1551

1652
base_log_dir = Path('~/submissions_algorithms/logs/self_tuning').expanduser()
@@ -36,8 +72,8 @@
3672
target_metric = None
3773
target_value = None
3874

39-
for sub in submissions.keys():
40-
pattern = os.path.join(base_log_dir, sub, 'study_*', f"{workload}*", 'trial_*', 'meta_data_0.json')
75+
for sub in submissions.items():
76+
pattern = os.path.join(base_log_dir, sub[0], 'study_*', f"{workload}*", 'trial_*', 'meta_data_0.json')
4177
files = glob.glob(pattern)
4278
if files:
4379
try:
@@ -57,101 +93,186 @@
5793

5894
csv_col_name = f"validation/{target_metric}"
5995

96+
# Check if metric is "higher is better" (Accuracy, BLEU, SSIM, AUC, MAP)
97+
higher_is_better = any(x in target_metric.lower() for x in ['accuracy', 'auc', 'map', 'bleu', 'ssim', 'precision', 'score'])
98+
6099
# Prepare plots
61-
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
62-
fig.suptitle(f"Workload: {workload} (Target: {target_metric})")
100+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5.5))
101+
fig.suptitle(f"Workload: {workload} (Metric: {target_metric})", fontweight='bold', y=0.98)
63102

64103
has_data = False
104+
all_metric_values = []
65105

106+
# Step 1: Gather and inspect data curves for bounds checking
107+
workload_curves = {}
66108
for sub, style in submissions.items():
67109
pattern = os.path.join(base_log_dir, sub, 'study_*', f"{workload}*", 'trial_*', 'eval_measurements.csv')
68110
files = glob.glob(pattern)
69111

70112
if not files:
71-
print(f" No data for {sub}")
72113
continue
73114

74-
print(f" Found {len(files)} trials for {sub}")
75-
76-
all_dfs = []
115+
dfs = []
77116
for f in files:
78117
try:
79118
df = pd.read_csv(f)
80119
if csv_col_name in df.columns:
81-
all_dfs.append(df)
82-
else:
83-
print(f" Column {csv_col_name} not found in {f}")
120+
df = df.dropna(subset=[csv_col_name, 'accumulated_submission_time', 'global_step'])
121+
if not df.empty:
122+
dfs.append(df)
123+
all_metric_values.extend(df[csv_col_name].tolist())
84124
except Exception as e:
85-
print(f" Error reading {f}: {e}")
125+
pass
86126

87-
if not all_dfs:
88-
continue
89-
90-
has_data = True
127+
if dfs:
128+
workload_curves[sub] = (dfs, style)
129+
has_data = True
91130

92-
# 1. Plot individual trial runs faintly to show raw trajectories
93-
for df in all_dfs:
94-
ax1.plot(df['accumulated_submission_time'], df[csv_col_name],
95-
color=style['color'], alpha=0.15, linewidth=1)
96-
ax2.plot(df['global_step'], df[csv_col_name],
97-
color=style['color'], alpha=0.15, linewidth=1)
98-
99-
# 2. Align metrics and calculate Mean + Standard Deviation across trials
100-
combined_df = pd.concat(all_dfs)
101-
stats_df = combined_df.groupby('global_step')[csv_col_name].agg(['mean', 'std']).reset_index()
102-
103-
# Average the time per step to align the time-series plot
104-
time_df = combined_df.groupby('global_step')['accumulated_submission_time'].mean().reset_index()
105-
stats_df = pd.merge(stats_df, time_df, on='global_step')
106-
107-
# Fill missing std calculations (e.g., if step counts differ slightly between runs)
108-
stats_df['std'] = stats_df['std'].fillna(0)
109-
110-
# 3. Plot the bold average line
111-
ax1.plot(stats_df['accumulated_submission_time'], stats_df['mean'],
112-
color=style['color'], label=style['label'], alpha=1.0, linewidth=2.5)
113-
114-
ax2.plot(stats_df['global_step'], stats_df['mean'],
115-
color=style['color'], label=style['label'], alpha=1.0, linewidth=2.5)
116-
117-
# 4. Fill the shaded area representing +/- 1 standard deviation
118-
ax1.fill_between(stats_df['accumulated_submission_time'],
119-
stats_df['mean'] - stats_df['std'],
120-
stats_df['mean'] + stats_df['std'],
121-
color=style['color'], alpha=0.10)
122-
123-
ax2.fill_between(stats_df['global_step'],
124-
stats_df['mean'] - stats_df['std'],
125-
stats_df['mean'] + stats_df['std'],
126-
color=style['color'], alpha=0.10)
127-
128131
if not has_data:
129132
plt.close(fig)
130133
continue
131134

135+
# Calculate y-limits using percentiles
136+
if all_metric_values:
137+
sorted_vals = sorted(all_metric_values)
138+
n = len(sorted_vals)
139+
140+
if higher_is_better:
141+
pct_5 = sorted_vals[int(n * 0.05)]
142+
ymin = max(0.0, pct_5 * 0.95) if pct_5 > 0.1 else 0.0
143+
144+
ymax = sorted_vals[-1]
145+
if target_value is not None:
146+
ymax = max(ymax, target_value)
147+
ymax = ymax * 1.05
148+
149+
# If it's a fractional metric (all values <= 1.0), cap at 1.0
150+
if all(v <= 1.0 for v in all_metric_values):
151+
ymax = min(1.0, ymax)
152+
else:
153+
min_val = sorted_vals[0]
154+
ymin = min_val * 0.95
155+
if target_value is not None:
156+
ymin = min(ymin, target_value * 0.9)
157+
ymin = max(0.0, ymin)
158+
159+
pct_90 = sorted_vals[int(n * 0.90)]
160+
ymax = pct_90
161+
if target_value is not None:
162+
ymax = max(ymax, target_value * 1.5)
163+
164+
if ymax <= ymin:
165+
ymax = ymin * 2.0 if ymin > 0 else 1.0
166+
else:
167+
ymin, ymax = 0.0, 1.0
168+
169+
# Second pass: interpolate and plot
170+
for sub, (dfs, style) in workload_curves.items():
171+
# --- Time-based Interpolation ---
172+
# Find global time range for this submission
173+
all_times = []
174+
for df in dfs:
175+
all_times.extend(df['accumulated_submission_time'].tolist())
176+
177+
if all_times:
178+
min_time = min(all_times)
179+
max_time = max(all_times)
180+
181+
# Create a uniform grid of 150 points for smooth rendering
182+
grid_times = np.linspace(min_time, max_time, 150)
183+
184+
interpolated_metrics = []
185+
for df in dfs:
186+
# Interpolate this trial's metrics to the unified time grid.
187+
# Use the last metric value for times beyond the duration of the trial to flatline.
188+
interp_val = np.interp(grid_times, df['accumulated_submission_time'], df[csv_col_name], right=df[csv_col_name].iloc[-1])
189+
interpolated_metrics.append(interp_val)
190+
191+
# Compute mean and standard deviation ignoring NaNs
192+
mean_time_curve = np.nanmean(interpolated_metrics, axis=0)
193+
std_time_curve = np.nanstd(interpolated_metrics, axis=0)
194+
std_time_curve = np.nan_to_num(std_time_curve, nan=0.0)
195+
196+
# Plot Time Curves (in hours)
197+
time_hours = grid_times / 3600.0
198+
ax1.plot(time_hours, mean_time_curve,
199+
color=style['color'], linestyle=style['linestyle'],
200+
label=style['label'], alpha=style['alpha'], linewidth=2.5)
201+
202+
ax1.fill_between(time_hours,
203+
mean_time_curve - std_time_curve,
204+
mean_time_curve + std_time_curve,
205+
color=style['color'], alpha=0.10, edgecolor='none')
206+
207+
# --- Step-based Interpolation ---
208+
all_steps = []
209+
for df in dfs:
210+
all_steps.extend(df['global_step'].tolist())
211+
212+
if all_steps:
213+
min_step = min(all_steps)
214+
max_step = max(all_steps)
215+
216+
grid_steps = np.linspace(min_step, max_step, 150)
217+
218+
interpolated_steps = []
219+
for df in dfs:
220+
interp_val = np.interp(grid_steps, df['global_step'], df[csv_col_name], right=df[csv_col_name].iloc[-1])
221+
interpolated_steps.append(interp_val)
222+
223+
mean_step_curve = np.nanmean(interpolated_steps, axis=0)
224+
std_step_curve = np.nanstd(interpolated_steps, axis=0)
225+
std_step_curve = np.nan_to_num(std_step_curve, nan=0.0)
226+
227+
steps_k = grid_steps / 1000.0
228+
ax2.plot(steps_k, mean_step_curve,
229+
color=style['color'], linestyle=style['linestyle'],
230+
label=style['label'], alpha=style['alpha'], linewidth=2.5)
231+
232+
ax2.fill_between(steps_k,
233+
mean_step_curve - std_step_curve,
234+
mean_step_curve + std_step_curve,
235+
color=style['color'], alpha=0.10, edgecolor='none')
236+
132237
# Configure axes
133238
for ax in [ax1, ax2]:
134-
ax.set_yscale('log')
239+
if not higher_is_better and any(x in target_metric.lower() for x in ['loss', 'perplexity']) and (ymax / (ymin + 1e-8) > 10):
240+
ax.set_yscale('log')
241+
else:
242+
ax.set_yscale('linear')
243+
244+
ax.set_ylim(ymin, ymax)
245+
135246
if target_value is not None:
136-
ax.axhline(y=target_value, color='r', linestyle='--', label=f'Target ({target_value})')
137-
ax.legend()
138-
ax.grid(True, which="both", ls="-", alpha=0.2)
247+
ax.axhline(y=target_value, color='#D0021B', linestyle=':', linewidth=1.5, label=f'Target ({target_value})')
248+
249+
ax.legend(frameon=True, facecolor='white', framealpha=0.9, edgecolor='#e5e5e5')
250+
ax.grid(True, which="major", color="#e8e8e8", linestyle="-", linewidth=0.8)
139251

140-
ax1.set_xlabel('Accumulated Submission Time (s)')
141-
ax1.set_ylabel(csv_col_name)
142-
ax1.set_title(f'{csv_col_name} vs Time')
252+
ax.spines['top'].set_visible(False)
253+
ax.spines['right'].set_visible(False)
254+
ax.spines['left'].set_color('#cccccc')
255+
ax.spines['bottom'].set_color('#cccccc')
256+
257+
ax1.set_xlabel('Accumulated Time (hours)', color='#333333', fontweight='semibold')
258+
ax1.set_ylabel(f'Validation {target_metric.upper()}', color='#333333', fontweight='semibold')
143259

144-
ax2.set_xlabel('Global Step')
145-
ax2.set_ylabel(csv_col_name)
146-
ax2.set_title(f'{csv_col_name} vs Step')
260+
ax2.set_xlabel('Global Steps (x10³)', color='#333333', fontweight='semibold')
261+
ax2.set_ylabel(f'Validation {target_metric.upper()}', color='#333333', fontweight='semibold')
147262

148263
plt.tight_layout()
149264

150-
# Save plot
265+
# Save plots
151266
save_dir.mkdir(exist_ok=True, parents=True)
152-
out_path = save_dir / f'{workload}_curves.png'
153-
plt.savefig(out_path)
267+
268+
pdf_path = save_dir / f'{workload}_curves.pdf'
269+
plt.savefig(pdf_path, bbox_inches='tight')
270+
271+
png_path = save_dir / f'{workload}_curves.png'
272+
plt.savefig(png_path, dpi=300, bbox_inches='tight')
273+
154274
plt.close(fig)
155-
print(f"Saved plot to {out_path}")
275+
print(f"Saved PDF to {pdf_path}")
276+
print(f"Saved PNG to {png_path}")
156277

157278
print("Done.")
44.4 KB
Binary file not shown.
171 KB
Loading
45.1 KB
Binary file not shown.
211 KB
Loading
47.8 KB
Binary file not shown.
428 KB
Loading
49.1 KB
Binary file not shown.
348 KB
Loading
49.7 KB
Binary file not shown.

0 commit comments

Comments
 (0)