"""
Smart results saving module that compares configurations and manages model checkpoints.
"""
import json
import os
import pathlib
import time
from typing import Dict, Any, Optional, Tuple

import torch

import config as cfg


def get_relevant_config(config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Extract only the relevant config parameters for comparison."""
    if config is None:
        return {
            'DATASET': cfg.DATASET.value,
            'ENCODER_ARCH': cfg.ENCODER_ARCH.value,
            'TOKENIZER_TYPE': cfg.TOKENIZER_TYPE.value,
        }
    else:
        return {
            'DATASET': config['DATASET'],
            'ENCODER_ARCH': config['ENCODER_ARCH'],
            'TOKENIZER_TYPE': config['TOKENIZER_TYPE'],
        }


def load_config_from_file(config_path: str) -> Dict[str, Any]:
    """Load configuration from a saved config file."""
    try:
        return cfg.import_config(config_path, override=False)
    except Exception as e:
        print(f"Error loading config from {config_path}: {e}")
        return {}


def load_results_from_file(results_path: str) -> Dict[str, Any]:
    """Load results from a saved results file."""
    try:
        with open(results_path, 'r') as f:
            return json.load(f)
    except Exception as e:
        print(f"Error loading results from {results_path}: {e}")
        return {}


def configs_match(config1: Dict[str, Any], config2: Dict[str, Any], verbose: bool = True) -> bool:
    """Compare only the relevant config parameters."""
    relevant_keys = ['DATASET', 'ENCODER_ARCH', 'TOKENIZER_TYPE']
    for key in relevant_keys:
        # Handle cases where values might be enum members, enum strings, or plain strings
        val1 = str(config1.get(key, ''))
        val2 = str(config2.get(key, ''))
        # Extract the enum value if it contains a period (e.g., "Dataset.COCO" -> "coco")
        if '.' in val1:
            val1 = val1.split('.')[-1].lower()
        if '.' in val2:
            val2 = val2.split('.')[-1].lower()
        if verbose:
            print(f"  Comparing {key}: '{val1}' vs '{val2}'")
        if val1.lower() != val2.lower():
            if verbose:
                print(f"  -> Mismatch found in {key}")
            return False
    return True
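
# Illustrative note (not part of the original code): configs_match() normalizes
# enum-style strings before comparing, so e.g.
#   configs_match({'DATASET': 'Dataset.COCO'}, {'DATASET': 'coco'})
# reduces both sides to 'coco' and treats them as a match.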


def find_matching_config_folder(
    results_root: pathlib.Path,
    current_config: Dict[str, Any],
    verbose: bool = True
) -> Optional[pathlib.Path]:
    """
    Search for an existing subfolder with matching configuration.

    Returns the path to the matching folder, or None if no match is found.
    """
    if not results_root.exists():
        return None
    if verbose:
        print("Searching for matching configuration folder...")
    for subfolder in results_root.iterdir():
        if not subfolder.is_dir():
            continue
        if verbose:
            print(f"  - Checking folder: {subfolder.name}")
        config_file = subfolder / 'config.json'
        if not config_file.exists():
            if verbose:
                print("    -> No config file found, skipping.")
            continue
        saved_config = load_config_from_file(str(config_file))
        if configs_match(current_config, saved_config, verbose=verbose):
            if verbose:
                print(f"  -> Match found: {subfolder.name}")
            return subfolder
    return None


def extract_test_loss(results: Dict[str, Any]) -> Optional[float]:
    """Extract test loss from results dictionary."""
    key = 'test_loss'
    if key in results:
        value = results[key]
        if isinstance(value, (int, float)):
            return float(value)
        elif isinstance(value, list) and len(value) > 0:
            return float(value[-1])  # Return the last element if it's a list
    return None
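
# Example (illustrative): extract_test_loss({'test_loss': [2.31, 1.87, 1.52]})
# returns 1.52, i.e. the final entry of a per-epoch list; a scalar value is
# returned as a float directly, and a missing key yields None.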


def save_results_smart(
    model: torch.nn.Module,
    results: Dict[str, Any],
    current_config: Dict[str, Any],
    verbose: bool = True
) -> Tuple[bool, str, str]:
    """
    Smart saving mechanism that compares configurations and manages model checkpoints.

    Args:
        model: The model to save (its state_dict will be extracted).
        results: Dictionary containing training results (must include the test loss).
        current_config: Current configuration dictionary (must include CONFIG_ROOT).
        verbose: Whether to print progress information.

    Returns:
        Tuple of (success: bool, message: str, target_folder: str).
    """
    config_root = pathlib.Path(current_config["CONFIG_ROOT"])
    results_root = config_root / 'results'
    results_root.mkdir(parents=True, exist_ok=True)

    current_config = get_relevant_config(config=current_config)
    current_test_loss = extract_test_loss(results)
    if verbose:
        print(f"Current test loss: {current_test_loss}")
    if current_test_loss is None:
        return False, "Error: Could not extract test loss from results", ""

    # Look for a matching config folder
    matching_folder = find_matching_config_folder(results_root, current_config, verbose=verbose)

    if matching_folder is not None:
        # Found a matching config - compare test losses
        existing_results_file = matching_folder / 'training_results.json'
        if existing_results_file.exists():
            existing_results = load_results_from_file(str(existing_results_file))
            existing_test_loss = extract_test_loss(existing_results)
            if verbose:
                print(f"Existing test loss: {existing_test_loss}")
            if existing_test_loss is not None and current_test_loss >= existing_test_loss:
                # Current test loss is not better
                return False, (
                    f"Current test loss ({current_test_loss:.6f}) is not better than "
                    f"existing test loss ({existing_test_loss:.6f}). Not overwriting."
                ), str(matching_folder)
        # Current model is better or no existing results - overwrite
        target_folder = matching_folder
        message = f"Overwriting results in matching config folder: {target_folder.name}"
    else:
        # No matching config found - create a new folder
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        target_folder = results_root / f"config_{timestamp}"
        target_folder.mkdir(parents=True, exist_ok=True)
        message = f"Created new config folder: {target_folder.name}"

    # Save model checkpoint
    try:
        model_filename = 'cptr_model.pth'
        model_path = target_folder / model_filename
        torch.save(model.state_dict(), model_path)
        message += f"\n  - Model saved: {model_filename}"
    except Exception as e:
        return False, f"Error saving model: {e}", str(target_folder)

    # Save results
    try:
        results_path = target_folder / 'training_results.json'
        with open(results_path, 'w') as f:
            json.dump(results, f, indent=4)
        message += "\n  - Results saved: training_results.json"
    except Exception as e:
        return False, f"Error saving results: {e}", str(target_folder)

    # Save config
    try:
        config_path = target_folder / 'config.json'
        cfg.export_config(str(config_path))
        message += "\n  - Config saved: config.json"
    except Exception as e:
        return False, f"Error saving config: {e}", str(target_folder)

    return True, message, str(target_folder)


def list_saved_configs(current_config: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """
    List all saved configurations in the results directory.

    Returns:
        Dictionary mapping folder names to their configuration info.
    """
    config_root = pathlib.Path(current_config["CONFIG_ROOT"])
    results_root = config_root / 'results'
    configs_info = {}
    if not results_root.exists():
        return configs_info
    for subfolder in sorted(results_root.iterdir()):
        if not subfolder.is_dir():
            continue
        config_file = subfolder / 'config.json'
        results_file = subfolder / 'training_results.json'
        info = {
            'path': str(subfolder),
            'config': None,
            'test_loss': None,
        }
        if config_file.exists():
            info['config'] = load_config_from_file(str(config_file))
        if results_file.exists():
            results = load_results_from_file(str(results_file))
            info['test_loss'] = extract_test_loss(results)
        configs_info[subfolder.name] = info
    return configs_info
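

# Minimal usage sketch (illustrative, not part of the original module): the model,
# results dictionary, and CONFIG_ROOT below are placeholders standing in for
# whatever the actual training script produces.
if __name__ == "__main__":
    example_model = torch.nn.Linear(4, 2)          # placeholder for the real model
    example_results = {'test_loss': [2.31, 1.87]}  # must contain a 'test_loss' entry
    example_config = {
        'CONFIG_ROOT': '.',                        # hypothetical output root
        'DATASET': cfg.DATASET,
        'ENCODER_ARCH': cfg.ENCODER_ARCH,
        'TOKENIZER_TYPE': cfg.TOKENIZER_TYPE,
    }

    success, message, folder = save_results_smart(
        example_model, example_results, example_config, verbose=True
    )
    print(message)
    if success:
        print(f"Artifacts written to: {folder}")

    # Inspect everything saved so far under the same results root.
    for name, info in list_saved_configs(example_config).items():
        print(f"{name}: test_loss={info['test_loss']}")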