Skip to content

Commit 10c251d

Browse files
Enhance the plot_loss_landscape_3d function with device support and handling of all model parameters (not just the first); ensure the output directory is created before files are saved.
This fixes crash bugs.
1 parent a464e0d commit 10c251d

File tree

1 file changed

+62
-20
lines changed

1 file changed

+62
-20
lines changed

tools/Study_Models.py

Lines changed: 62 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
from sklearn.manifold import TSNE
1212
from torch.utils.data import DataLoader, Dataset
1313
from torchviz import make_dot
14+
from tqdm import tqdm
1415

1516
from data import test_texts, test_labels
1617

@@ -205,35 +206,69 @@ def visualize_feature_importance(input_dim_, filename="Feature_Importance.svg"):
205206
plt.close()
206207

207208

208-
def plot_loss_landscape_3d(model_, dataloader_, criterion_, grid_size=30, epsilon=0.01,
209-
filename="Loss_Landscape_3D.html"):
209+
def plot_loss_landscape_3d(model_, dataloader_, criterion_, grid_size=None, epsilon=0.01,
210+
filename="Loss_Landscape_3D.html", device="cpu"):
210211
model_.eval()
211-
param = next(model_.parameters())
212-
param_flat = param.view(-1)
213-
u = epsilon * torch.randn_like(param_flat).view(param.shape)
214-
v = epsilon * torch.randn_like(param_flat).view(param.shape)
212+
model_.to(device)
213+
214+
# Flatten all parameters into a single vector
215+
params = torch.cat([p.view(-1) for p in model_.parameters()])
216+
217+
# Create random directions u, v in parameter space
218+
u = epsilon * torch.randn_like(params)
219+
v = epsilon * torch.randn_like(params)
215220
u /= torch.norm(u)
216221
v /= torch.norm(v)
222+
223+
if grid_size is None:
224+
param_count = sum(p.numel() for p in model_.parameters())
225+
grid_size = max(10, min(50, param_count // 10))
217226
x = np.linspace(-1, 1, grid_size)
218227
y = np.linspace(-1, 1, grid_size)
219228
loss_values = np.zeros((grid_size, grid_size))
220229

221-
for i, dx in enumerate(x):
222-
for j, dy in enumerate(y):
223-
param.data += dx * u.to(DEVICE) + dy * v.to(DEVICE)
224-
total_loss = 0
225-
for X, yb in dataloader_:
226-
X, yb = X.to(DEVICE), yb.to(DEVICE)
227-
yb = yb.float().view(-1, 1) # reshape to match output
228-
out = model_(X)
229-
total_loss += criterion_(out, yb).item()
230-
loss_values[i, j] = total_loss
231-
param.data -= dx * u.to(DEVICE) + dy * v.to(DEVICE)
230+
# Store original parameters
231+
orig_params = params.clone()
232232

233+
with torch.no_grad():
234+
for i, dx in enumerate(tqdm(x, desc="dx")):
235+
for j, dy in enumerate(y):
236+
# Perturbed parameter vector
237+
new_params = orig_params + dx * u + dy * v
238+
239+
# Load new parameters into the model temporarily
240+
idx = 0
241+
for p in model_.parameters():
242+
numel = p.numel()
243+
p.copy_(new_params[idx:idx + numel].view_as(p))
244+
idx += numel
245+
246+
# Compute loss
247+
total_loss = 0
248+
for X, yb in dataloader_:
249+
X, yb = X.to(device), yb.to(device)
250+
yb = yb.float().view(-1, 1)
251+
out = model_(X)
252+
total_loss += criterion_(out, yb).item()
253+
loss_values[i, j] = total_loss
254+
255+
# Restore original parameters
256+
idx = 0
257+
for p in model_.parameters():
258+
numel = p.numel()
259+
p.copy_(orig_params[idx:idx + numel].view_as(p))
260+
idx += numel
261+
262+
# Plot
233263
X_grid, Y_grid = np.meshgrid(x, y)
234264
fig = go.Figure(data=[go.Surface(z=loss_values, x=X_grid, y=Y_grid, colorscale="Viridis")])
235265
fig.update_layout(title="Loss Landscape", scene=dict(xaxis_title="u", yaxis_title="v", zaxis_title="Loss"))
236-
fig.write_html(os.path.join(OUTPUT_DIR, filename))
266+
267+
dir_name = os.path.dirname(filename)
268+
if dir_name:
269+
os.makedirs(dir_name, exist_ok=True)
270+
fig.write_html(filename)
271+
print(f"3D loss landscape saved to {filename}")
237272

238273

239274
def save_model_state_dict(model_, filename="Model_State_Dict.txt"):
@@ -311,6 +346,13 @@ def save_model_summary(model_, filename="Model_Summary.txt"):
311346
print("Saving model summary...")
312347
save_model_summary(model)
313348
print("Running plot_loss_landscape_3d...")
314-
model_cpu = model.to("cpu")
315-
plot_loss_landscape_3d(model_cpu, dataloader, criterion)
349+
DEVICE = "cpu" # or "cuda" if you want GPU
350+
model_cpu = model.to(DEVICE)
351+
plot_loss_landscape_3d(
352+
model_=model_cpu,
353+
dataloader_=dataloader,
354+
criterion_=criterion,
355+
filename="Loss_Landscape_3D.html",
356+
device=DEVICE
357+
)
316358
print("All visualizations completed. Files saved in 'data/' directory.")

0 commit comments

Comments
 (0)