Skip to content

Commit 04eea5a

Browse files
Update model loading to support checkpoint structure in Study_Models
1 parent e89d247 commit 04eea5a

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tools/Study_Models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,8 @@ def save_model_summary(model_, filename="Model_Summary.txt"):
275275
# Load model
276276
model = SimpleNN(input_dim).to(DEVICE)
277277
if os.path.exists(MODEL_PATH):
278-
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
278+
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
279+
model.load_state_dict(checkpoint["model_state_dict"])
279280
model.eval()
280281

281282
# Sample input for activation visualization

0 commit comments

Comments
 (0)