|
11 | 11 | from sklearn.manifold import TSNE |
12 | 12 | from torch.utils.data import DataLoader, Dataset |
13 | 13 | from torchviz import make_dot |
| 14 | +from tqdm import tqdm |
14 | 15 |
|
15 | 16 | from data import test_texts, test_labels |
16 | 17 |
|
@@ -205,35 +206,69 @@ def visualize_feature_importance(input_dim_, filename="Feature_Importance.svg"): |
205 | 206 | plt.close() |
206 | 207 |
|
207 | 208 |
|
208 | | -def plot_loss_landscape_3d(model_, dataloader_, criterion_, grid_size=30, epsilon=0.01, |
209 | | - filename="Loss_Landscape_3D.html"): |
| 209 | +def plot_loss_landscape_3d(model_, dataloader_, criterion_, grid_size=None, epsilon=0.01, |
| 210 | + filename="Loss_Landscape_3D.html", device="cpu"): |
210 | 211 | model_.eval() |
211 | | - param = next(model_.parameters()) |
212 | | - param_flat = param.view(-1) |
213 | | - u = epsilon * torch.randn_like(param_flat).view(param.shape) |
214 | | - v = epsilon * torch.randn_like(param_flat).view(param.shape) |
| 212 | + model_.to(device) |
| 213 | + |
| 214 | + # Flatten all parameters into a single vector |
| 215 | + params = torch.cat([p.view(-1) for p in model_.parameters()]) |
| 216 | + |
| 217 | + # Create random directions u, v in parameter space |
| 218 | + u = epsilon * torch.randn_like(params) |
| 219 | + v = epsilon * torch.randn_like(params) |
215 | 220 | u /= torch.norm(u) |
216 | 221 | v /= torch.norm(v) |
| 222 | + |
| 223 | + if grid_size is None: |
| 224 | + param_count = sum(p.numel() for p in model_.parameters()) |
| 225 | + grid_size = max(10, min(50, param_count // 10)) |
217 | 226 | x = np.linspace(-1, 1, grid_size) |
218 | 227 | y = np.linspace(-1, 1, grid_size) |
219 | 228 | loss_values = np.zeros((grid_size, grid_size)) |
220 | 229 |
|
221 | | - for i, dx in enumerate(x): |
222 | | - for j, dy in enumerate(y): |
223 | | - param.data += dx * u.to(DEVICE) + dy * v.to(DEVICE) |
224 | | - total_loss = 0 |
225 | | - for X, yb in dataloader_: |
226 | | - X, yb = X.to(DEVICE), yb.to(DEVICE) |
227 | | - yb = yb.float().view(-1, 1) # reshape to match output |
228 | | - out = model_(X) |
229 | | - total_loss += criterion_(out, yb).item() |
230 | | - loss_values[i, j] = total_loss |
231 | | - param.data -= dx * u.to(DEVICE) + dy * v.to(DEVICE) |
| 230 | + # Store original parameters |
| 231 | + orig_params = params.clone() |
232 | 232 |
|
| 233 | + with torch.no_grad(): |
| 234 | + for i, dx in enumerate(tqdm(x, desc="dx")): |
| 235 | + for j, dy in enumerate(y): |
| 236 | + # Perturbed parameter vector |
| 237 | + new_params = orig_params + dx * u + dy * v |
| 238 | + |
| 239 | + # Load new parameters into the model temporarily |
| 240 | + idx = 0 |
| 241 | + for p in model_.parameters(): |
| 242 | + numel = p.numel() |
| 243 | + p.copy_(new_params[idx:idx + numel].view_as(p)) |
| 244 | + idx += numel |
| 245 | + |
| 246 | + # Compute loss |
| 247 | + total_loss = 0 |
| 248 | + for X, yb in dataloader_: |
| 249 | + X, yb = X.to(device), yb.to(device) |
| 250 | + yb = yb.float().view(-1, 1) |
| 251 | + out = model_(X) |
| 252 | + total_loss += criterion_(out, yb).item() |
| 253 | + loss_values[i, j] = total_loss |
| 254 | + |
| 255 | + # Restore original parameters |
| 256 | + idx = 0 |
| 257 | + for p in model_.parameters(): |
| 258 | + numel = p.numel() |
| 259 | + p.copy_(orig_params[idx:idx + numel].view_as(p)) |
| 260 | + idx += numel |
| 261 | + |
| 262 | + # Plot |
233 | 263 | X_grid, Y_grid = np.meshgrid(x, y) |
234 | 264 | fig = go.Figure(data=[go.Surface(z=loss_values, x=X_grid, y=Y_grid, colorscale="Viridis")]) |
235 | 265 | fig.update_layout(title="Loss Landscape", scene=dict(xaxis_title="u", yaxis_title="v", zaxis_title="Loss")) |
236 | | - fig.write_html(os.path.join(OUTPUT_DIR, filename)) |
| 266 | + |
| 267 | + dir_name = os.path.dirname(filename) |
| 268 | + if dir_name: |
| 269 | + os.makedirs(dir_name, exist_ok=True) |
| 270 | + fig.write_html(filename) |
| 271 | +    print(f"3D loss landscape saved to {filename}") |
237 | 272 |
|
238 | 273 |
|
239 | 274 | def save_model_state_dict(model_, filename="Model_State_Dict.txt"): |
@@ -311,6 +346,13 @@ def save_model_summary(model_, filename="Model_Summary.txt"): |
311 | 346 | print("Saving model summary...") |
312 | 347 | save_model_summary(model) |
313 | 348 | print("Running plot_loss_landscape_3d...") |
314 | | -model_cpu = model.to("cpu") |
315 | | -plot_loss_landscape_3d(model_cpu, dataloader, criterion) |
| 349 | +DEVICE = "cpu" # or "cuda" if you want GPU |
| 350 | +model_cpu = model.to(DEVICE) |
| 351 | +plot_loss_landscape_3d( |
| 352 | + model_=model_cpu, |
| 353 | + dataloader_=dataloader, |
| 354 | + criterion_=criterion, |
| 355 | + filename="../Model_SenseMacro.4n1_Data_Visualization/Loss_Landscape_3D.html", |
| 356 | + device=DEVICE |
| 357 | +) |
316 | 358 | print("All visualizations completed. Files saved in 'data/' directory.") |
0 commit comments