|
2 | 2 | import glob |
3 | 3 | import json |
4 | 4 | import pandas as pd |
| 5 | +import numpy as np |
5 | 6 | import matplotlib.pyplot as plt |
6 | 7 | from pathlib import Path |
7 | 8 |
|
8 | | -# Define submissions and their styles |
| 9 | +# Configure styling params |
| 10 | +plt.rcParams.update({ |
| 11 | + 'font.family': 'sans-serif', |
| 12 | + 'font.sans-serif': ['Helvetica', 'Arial', 'DejaVu Sans'], |
| 13 | + 'font.size': 12, |
| 14 | + 'axes.labelsize': 13, |
| 15 | + 'axes.titlesize': 14, |
| 16 | + 'xtick.labelsize': 11, |
| 17 | + 'ytick.labelsize': 11, |
| 18 | + 'legend.fontsize': 11, |
| 19 | + 'figure.titlesize': 16, |
| 20 | + 'pdf.fonttype': 42, |
| 21 | + 'ps.fonttype': 42 |
| 22 | +}) |
| 23 | + |
| 24 | +# Define submissions and distinct styles |
9 | 25 | submissions = { |
10 | | - 'schedule_free_adamw': {'color': 'skyblue', 'label': 'PyTorch v1', 'alpha': 0.8}, |
11 | | - 'schedule_free_adamw_v2': {'color': 'darkblue', 'label': 'PyTorch v2', 'alpha': 0.8}, |
12 | | - 'schedule_free_adamw_jax': {'color': 'wheat', 'label': 'JAX v1', 'alpha': 0.8}, |
13 | | - 'schedule_free_adamw_jax_v2': {'color': 'darkorange', 'label': 'JAX v2', 'alpha': 0.8} |
| 26 | + 'schedule_free_adamw': { |
| 27 | + 'color': '#1F77B4', # Classic Blue |
| 28 | + 'linestyle': '-', # Solid |
| 29 | + 'label': 'PyTorch v1', |
| 30 | + 'alpha': 0.9 |
| 31 | + }, |
| 32 | + 'schedule_free_adamw_v2': { |
| 33 | + 'color': '#0B3C5D', # Deep Navy |
| 34 | + 'linestyle': '--', # Dashed |
| 35 | + 'label': 'PyTorch v2', |
| 36 | + 'alpha': 0.9 |
| 37 | + }, |
| 38 | + 'schedule_free_adamw_jax': { |
| 39 | + 'color': '#FF7F0E', # Safety Orange |
| 40 | + 'linestyle': '-', # Solid |
| 41 | + 'label': 'JAX v1', |
| 42 | + 'alpha': 0.9 |
| 43 | + }, |
| 44 | + 'schedule_free_adamw_jax_v2': { |
| 45 | + 'color': '#D9531E', # Vibrant Rust |
| 46 | + 'linestyle': '--', # Dashed |
| 47 | + 'label': 'JAX v2', |
| 48 | + 'alpha': 0.9 |
| 49 | + } |
14 | 50 | } |
15 | 51 |
|
16 | 52 | base_log_dir = Path('~/submissions_algorithms/logs/self_tuning').expanduser() |
|
36 | 72 | target_metric = None |
37 | 73 | target_value = None |
38 | 74 |
|
39 | | - for sub in submissions.keys(): |
40 | | - pattern = os.path.join(base_log_dir, sub, 'study_*', f"{workload}*", 'trial_*', 'meta_data_0.json') |
| 75 | + for sub in submissions.items(): |
| 76 | + pattern = os.path.join(base_log_dir, sub[0], 'study_*', f"{workload}*", 'trial_*', 'meta_data_0.json') |
41 | 77 | files = glob.glob(pattern) |
42 | 78 | if files: |
43 | 79 | try: |
|
57 | 93 |
|
58 | 94 | csv_col_name = f"validation/{target_metric}" |
59 | 95 |
|
| 96 | + # Check if metric is "higher is better" (Accuracy, BLEU, SSIM, AUC, MAP) |
| 97 | + higher_is_better = any(x in target_metric.lower() for x in ['accuracy', 'auc', 'map', 'bleu', 'ssim', 'precision', 'score']) |
| 98 | + |
60 | 99 | # Prepare plots |
61 | | - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6)) |
62 | | - fig.suptitle(f"Workload: {workload} (Target: {target_metric})") |
| 100 | + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5.5)) |
| 101 | + fig.suptitle(f"Workload: {workload} (Metric: {target_metric})", fontweight='bold', y=0.98) |
63 | 102 |
|
64 | 103 | has_data = False |
| 104 | + all_metric_values = [] |
65 | 105 |
|
| 106 | + # Step 1: Gather and inspect data curves for bounds checking |
| 107 | + workload_curves = {} |
66 | 108 | for sub, style in submissions.items(): |
67 | 109 | pattern = os.path.join(base_log_dir, sub, 'study_*', f"{workload}*", 'trial_*', 'eval_measurements.csv') |
68 | 110 | files = glob.glob(pattern) |
69 | 111 |
|
70 | 112 | if not files: |
71 | | - print(f" No data for {sub}") |
72 | 113 | continue |
73 | 114 |
|
74 | | - print(f" Found {len(files)} trials for {sub}") |
75 | | - |
76 | | - all_dfs = [] |
| 115 | + dfs = [] |
77 | 116 | for f in files: |
78 | 117 | try: |
79 | 118 | df = pd.read_csv(f) |
80 | 119 | if csv_col_name in df.columns: |
81 | | - all_dfs.append(df) |
82 | | - else: |
83 | | - print(f" Column {csv_col_name} not found in {f}") |
| 120 | + df = df.dropna(subset=[csv_col_name, 'accumulated_submission_time', 'global_step']) |
| 121 | + if not df.empty: |
| 122 | + dfs.append(df) |
| 123 | + all_metric_values.extend(df[csv_col_name].tolist()) |
84 | 124 | except Exception as e: |
85 | | - print(f" Error reading {f}: {e}") |
| 125 | + pass |
86 | 126 |
|
87 | | - if not all_dfs: |
88 | | - continue |
89 | | - |
90 | | - has_data = True |
| 127 | + if dfs: |
| 128 | + workload_curves[sub] = (dfs, style) |
| 129 | + has_data = True |
91 | 130 |
|
92 | | - # 1. Plot individual trial runs faintly to show raw trajectories |
93 | | - for df in all_dfs: |
94 | | - ax1.plot(df['accumulated_submission_time'], df[csv_col_name], |
95 | | - color=style['color'], alpha=0.15, linewidth=1) |
96 | | - ax2.plot(df['global_step'], df[csv_col_name], |
97 | | - color=style['color'], alpha=0.15, linewidth=1) |
98 | | - |
99 | | - # 2. Align metrics and calculate Mean + Standard Deviation across trials |
100 | | - combined_df = pd.concat(all_dfs) |
101 | | - stats_df = combined_df.groupby('global_step')[csv_col_name].agg(['mean', 'std']).reset_index() |
102 | | - |
103 | | - # Average the time per step to align the time-series plot |
104 | | - time_df = combined_df.groupby('global_step')['accumulated_submission_time'].mean().reset_index() |
105 | | - stats_df = pd.merge(stats_df, time_df, on='global_step') |
106 | | - |
107 | | - # Fill missing std calculations (e.g., if step counts differ slightly between runs) |
108 | | - stats_df['std'] = stats_df['std'].fillna(0) |
109 | | - |
110 | | - # 3. Plot the bold average line |
111 | | - ax1.plot(stats_df['accumulated_submission_time'], stats_df['mean'], |
112 | | - color=style['color'], label=style['label'], alpha=1.0, linewidth=2.5) |
113 | | - |
114 | | - ax2.plot(stats_df['global_step'], stats_df['mean'], |
115 | | - color=style['color'], label=style['label'], alpha=1.0, linewidth=2.5) |
116 | | - |
117 | | - # 4. Fill the shaded area representing +/- 1 standard deviation |
118 | | - ax1.fill_between(stats_df['accumulated_submission_time'], |
119 | | - stats_df['mean'] - stats_df['std'], |
120 | | - stats_df['mean'] + stats_df['std'], |
121 | | - color=style['color'], alpha=0.10) |
122 | | - |
123 | | - ax2.fill_between(stats_df['global_step'], |
124 | | - stats_df['mean'] - stats_df['std'], |
125 | | - stats_df['mean'] + stats_df['std'], |
126 | | - color=style['color'], alpha=0.10) |
127 | | - |
128 | 131 | if not has_data: |
129 | 132 | plt.close(fig) |
130 | 133 | continue |
131 | 134 |
|
| 135 | + # Calculate y-limits using percentiles |
| 136 | + if all_metric_values: |
| 137 | + sorted_vals = sorted(all_metric_values) |
| 138 | + n = len(sorted_vals) |
| 139 | + |
| 140 | + if higher_is_better: |
| 141 | + pct_5 = sorted_vals[int(n * 0.05)] |
| 142 | + ymin = max(0.0, pct_5 * 0.95) if pct_5 > 0.1 else 0.0 |
| 143 | + |
| 144 | + ymax = sorted_vals[-1] |
| 145 | + if target_value is not None: |
| 146 | + ymax = max(ymax, target_value) |
| 147 | + ymax = ymax * 1.05 |
| 148 | + |
| 149 | + # If it's a fractional metric (all values <= 1.0), cap at 1.0 |
| 150 | + if all(v <= 1.0 for v in all_metric_values): |
| 151 | + ymax = min(1.0, ymax) |
| 152 | + else: |
| 153 | + min_val = sorted_vals[0] |
| 154 | + ymin = min_val * 0.95 |
| 155 | + if target_value is not None: |
| 156 | + ymin = min(ymin, target_value * 0.9) |
| 157 | + ymin = max(0.0, ymin) |
| 158 | + |
| 159 | + pct_90 = sorted_vals[int(n * 0.90)] |
| 160 | + ymax = pct_90 |
| 161 | + if target_value is not None: |
| 162 | + ymax = max(ymax, target_value * 1.5) |
| 163 | + |
| 164 | + if ymax <= ymin: |
| 165 | + ymax = ymin * 2.0 if ymin > 0 else 1.0 |
| 166 | + else: |
| 167 | + ymin, ymax = 0.0, 1.0 |
| 168 | + |
| 169 | + # Second pass: interpolate and plot |
| 170 | + for sub, (dfs, style) in workload_curves.items(): |
| 171 | + # --- Time-based Interpolation --- |
| 172 | + # Find global time range for this submission |
| 173 | + all_times = [] |
| 174 | + for df in dfs: |
| 175 | + all_times.extend(df['accumulated_submission_time'].tolist()) |
| 176 | + |
| 177 | + if all_times: |
| 178 | + min_time = min(all_times) |
| 179 | + max_time = max(all_times) |
| 180 | + |
| 181 | + # Create a uniform grid of 150 points for smooth rendering |
| 182 | + grid_times = np.linspace(min_time, max_time, 150) |
| 183 | + |
| 184 | + interpolated_metrics = [] |
| 185 | + for df in dfs: |
| 186 | + # Interpolate this trial's metrics to the unified time grid. |
| 187 | + # Use the last metric value for times beyond the duration of the trial to flatline. |
| 188 | + interp_val = np.interp(grid_times, df['accumulated_submission_time'], df[csv_col_name], right=df[csv_col_name].iloc[-1]) |
| 189 | + interpolated_metrics.append(interp_val) |
| 190 | + |
| 191 | + # Compute mean and standard deviation ignoring NaNs |
| 192 | + mean_time_curve = np.nanmean(interpolated_metrics, axis=0) |
| 193 | + std_time_curve = np.nanstd(interpolated_metrics, axis=0) |
| 194 | + std_time_curve = np.nan_to_num(std_time_curve, nan=0.0) |
| 195 | + |
| 196 | + # Plot Time Curves (in hours) |
| 197 | + time_hours = grid_times / 3600.0 |
| 198 | + ax1.plot(time_hours, mean_time_curve, |
| 199 | + color=style['color'], linestyle=style['linestyle'], |
| 200 | + label=style['label'], alpha=style['alpha'], linewidth=2.5) |
| 201 | + |
| 202 | + ax1.fill_between(time_hours, |
| 203 | + mean_time_curve - std_time_curve, |
| 204 | + mean_time_curve + std_time_curve, |
| 205 | + color=style['color'], alpha=0.10, edgecolor='none') |
| 206 | + |
| 207 | + # --- Step-based Interpolation --- |
| 208 | + all_steps = [] |
| 209 | + for df in dfs: |
| 210 | + all_steps.extend(df['global_step'].tolist()) |
| 211 | + |
| 212 | + if all_steps: |
| 213 | + min_step = min(all_steps) |
| 214 | + max_step = max(all_steps) |
| 215 | + |
| 216 | + grid_steps = np.linspace(min_step, max_step, 150) |
| 217 | + |
| 218 | + interpolated_steps = [] |
| 219 | + for df in dfs: |
| 220 | + interp_val = np.interp(grid_steps, df['global_step'], df[csv_col_name], right=df[csv_col_name].iloc[-1]) |
| 221 | + interpolated_steps.append(interp_val) |
| 222 | + |
| 223 | + mean_step_curve = np.nanmean(interpolated_steps, axis=0) |
| 224 | + std_step_curve = np.nanstd(interpolated_steps, axis=0) |
| 225 | + std_step_curve = np.nan_to_num(std_step_curve, nan=0.0) |
| 226 | + |
| 227 | + steps_k = grid_steps / 1000.0 |
| 228 | + ax2.plot(steps_k, mean_step_curve, |
| 229 | + color=style['color'], linestyle=style['linestyle'], |
| 230 | + label=style['label'], alpha=style['alpha'], linewidth=2.5) |
| 231 | + |
| 232 | + ax2.fill_between(steps_k, |
| 233 | + mean_step_curve - std_step_curve, |
| 234 | + mean_step_curve + std_step_curve, |
| 235 | + color=style['color'], alpha=0.10, edgecolor='none') |
| 236 | + |
132 | 237 | # Configure axes |
133 | 238 | for ax in [ax1, ax2]: |
134 | | - ax.set_yscale('log') |
| 239 | + if not higher_is_better and any(x in target_metric.lower() for x in ['loss', 'perplexity']) and (ymax / (ymin + 1e-8) > 10): |
| 240 | + ax.set_yscale('log') |
| 241 | + else: |
| 242 | + ax.set_yscale('linear') |
| 243 | + |
| 244 | + ax.set_ylim(ymin, ymax) |
| 245 | + |
135 | 246 | if target_value is not None: |
136 | | - ax.axhline(y=target_value, color='r', linestyle='--', label=f'Target ({target_value})') |
137 | | - ax.legend() |
138 | | - ax.grid(True, which="both", ls="-", alpha=0.2) |
| 247 | + ax.axhline(y=target_value, color='#D0021B', linestyle=':', linewidth=1.5, label=f'Target ({target_value})') |
| 248 | + |
| 249 | + ax.legend(frameon=True, facecolor='white', framealpha=0.9, edgecolor='#e5e5e5') |
| 250 | + ax.grid(True, which="major", color="#e8e8e8", linestyle="-", linewidth=0.8) |
139 | 251 |
|
140 | | - ax1.set_xlabel('Accumulated Submission Time (s)') |
141 | | - ax1.set_ylabel(csv_col_name) |
142 | | - ax1.set_title(f'{csv_col_name} vs Time') |
| 252 | + ax.spines['top'].set_visible(False) |
| 253 | + ax.spines['right'].set_visible(False) |
| 254 | + ax.spines['left'].set_color('#cccccc') |
| 255 | + ax.spines['bottom'].set_color('#cccccc') |
| 256 | + |
| 257 | + ax1.set_xlabel('Accumulated Time (hours)', color='#333333', fontweight='semibold') |
| 258 | + ax1.set_ylabel(f'Validation {target_metric.upper()}', color='#333333', fontweight='semibold') |
143 | 259 |
|
144 | | - ax2.set_xlabel('Global Step') |
145 | | - ax2.set_ylabel(csv_col_name) |
146 | | - ax2.set_title(f'{csv_col_name} vs Step') |
| 260 | + ax2.set_xlabel('Global Steps (x10³)', color='#333333', fontweight='semibold') |
| 261 | + ax2.set_ylabel(f'Validation {target_metric.upper()}', color='#333333', fontweight='semibold') |
147 | 262 |
|
148 | 263 | plt.tight_layout() |
149 | 264 |
|
150 | | - # Save plot |
| 265 | + # Save plots |
151 | 266 | save_dir.mkdir(exist_ok=True, parents=True) |
152 | | - out_path = save_dir / f'{workload}_curves.png' |
153 | | - plt.savefig(out_path) |
| 267 | + |
| 268 | + pdf_path = save_dir / f'{workload}_curves.pdf' |
| 269 | + plt.savefig(pdf_path, bbox_inches='tight') |
| 270 | + |
| 271 | + png_path = save_dir / f'{workload}_curves.png' |
| 272 | + plt.savefig(png_path, dpi=300, bbox_inches='tight') |
| 273 | + |
154 | 274 | plt.close(fig) |
155 | | - print(f"Saved plot to {out_path}") |
| 275 | + print(f"Saved PDF to {pdf_path}") |
| 276 | + print(f"Saved PNG to {png_path}") |
156 | 277 |
|
157 | 278 | print("Done.") |
0 commit comments