Skip to content

Commit f6038c4

Browse files
committed
exclude is_los_nan from metrics
1 parent 4612204 commit f6038c4

2 files changed

Lines changed: 12 additions & 4 deletions

File tree

pufferlib/pufferl.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,12 +187,13 @@ def _train_worker(args):
187187

188188
backend.close(pufferl)
189189

190-
def _downsample_logs(all_logs, n):
190+
def _downsample_logs(all_logs, n, exclude_keys=()):
191191
if not all_logs:
192192
raise ValueError('Cannot downsample empty logs')
193193

194194
expected_keys = set(all_logs[0])
195-
metrics = {k: [[]] for k in all_logs[0]}
195+
exclude_keys = set(exclude_keys)
196+
metrics = {k: [[]] for k in all_logs[0] if k not in exclude_keys}
196197
logged_timesteps = all_logs[-1]['agent_steps']
197198
next_bin = logged_timesteps / (n - 1) if n > 1 else np.inf
198199
for idx, log in enumerate(all_logs):
@@ -206,7 +207,8 @@ def _downsample_logs(all_logs, n):
206207
)
207208

208209
for k, v in log.items():
209-
metrics[k][-1].append(v)
210+
if k in metrics:
211+
metrics[k][-1].append(v)
210212

211213
if log['agent_steps'] < next_bin:
212214
continue
@@ -365,7 +367,12 @@ def _train(env_name, args, result_queue=None, verbose=False, sweep_early_stop=No
365367
# This version has the training perf logs and eval env logs
366368
all_logs.append(flat_logs)
367369

368-
metrics = _downsample_logs(all_logs, args['sweep']['downsample'])
370+
exclude_keys = ()
371+
if sweep_early_stop is not None:
372+
exclude_keys = pufferlib.sweep.SWEEP_NON_METRIC_LOG_KEYS
373+
374+
metrics = _downsample_logs(
375+
all_logs, args['sweep']['downsample'], exclude_keys=exclude_keys)
369376

370377
# Match-mode: single observation at final-training cost. Protein's curve
371378
# fit collapses to one point — we only trust the match winrate, not any

pufferlib/sweep.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
'is_loss_nan': False,
3030
'early_stop_threshold': EARLY_STOP_THRESHOLD_FLOOR,
3131
}
32+
SWEEP_NON_METRIC_LOG_KEYS = frozenset(('is_loss_nan',))
3233

3334
def apply_early_stop_log_defaults(logs):
3435
logs.update(SWEEP_EARLY_STOP_LOG_DEFAULTS)

0 commit comments

Comments
 (0)