Skip to content

Commit f26c529

Browse files
authored
Merge pull request #141 from SongshGeoLab/dev
fix(experiment): 🐛 Update repeat_id to run_id for consistency in experiment logging
2 parents a07c0fc + aa04e97 commit f26c529

9 files changed

Lines changed: 286 additions & 94 deletions

File tree

abses/core/experiment.py

Lines changed: 23 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,10 @@ def run_single(
122122
hooks:
123123
The hooks to run after the model is run.
124124
"""
125-
job_id, repeat_id = key
125+
job_id, run_id = key
126126
model = model_cls(
127127
parameters=cfg,
128-
run_id=repeat_id,
128+
run_id=run_id,
129129
seed=seed,
130130
**kwargs,
131131
)
@@ -135,7 +135,7 @@ def run_single(
135135
for hook_name, hook_func in hooks.items():
136136
logger.info(f"Running hook {hook_name}.")
137137
_call_hook_with_optional_args(
138-
hook_func, model, job_id=job_id, repeat_id=repeat_id
138+
hook_func, model, job_id=job_id, run_id=run_id
139139
)
140140
return key, seed, results
141141

@@ -370,45 +370,17 @@ def _load_hydra_cfg(
370370

371371
return cfg
372372

373-
# def _get_logging_mode(self, repeat_id: Optional[int] = None) -> str | bool:
374-
# log_mode = self.exp_config.get("logging", "once")
375-
# if log_mode == "once":
376-
# if repeat_id == 1:
377-
# logging: bool | str = self.name
378-
# else:
379-
# return False
380-
# elif bool(log_mode):
381-
# logging = f"{self.name}_{repeat_id}"
382-
# else:
383-
# logging = False
384-
# return logging
385-
386-
# def _update_log_config(
387-
# self, config, repeat_id: Optional[int] = None
388-
# ) -> bool:
389-
# """Update the log configuration."""
390-
# if isinstance(config, dict):
391-
# config = DictConfig(config)
392-
# OmegaConf.set_struct(config, False)
393-
# log_name = self._get_logging_mode(repeat_id=repeat_id)
394-
# if not log_name:
395-
# config["log"] = False
396-
# return config
397-
# logging_cfg = OmegaConf.create({"log": {"name": log_name}})
398-
# config = OmegaConf.merge(config, logging_cfg)
399-
# return config
400-
401-
def _get_seed(self, repeat_id: int, job_id: Optional[int] = None) -> Optional[int]:
373+
def _get_seed(self, run_id: int, job_id: Optional[int] = None) -> Optional[int]:
402374
"""获取每次运行的随机种子
403375
404376
使用基础种子初始化随机数生成器,为每次运行生成唯一的随机种子。
405377
这样可以保证:
406378
1. 如果基础种子相同,生成的种子序列也相同
407-
2. 不同的 job_id 和 repeat_id 组合会得到不同的种子
379+
2. 不同的 job_id 和 run_id 组合会得到不同的种子
408380
3. 种子序列具有更好的随机性
409381
410382
Args:
411-
repeat_id: 重复实验的ID
383+
run_id: 重复实验的ID
412384
413385
Returns:
414386
如果没有设置基础种子则返回 None,否则返回生成的随机种子
@@ -419,7 +391,7 @@ def _get_seed(self, repeat_id: int, job_id: Optional[int] = None) -> Optional[in
419391
if job_id is None:
420392
job_id = self.job_id
421393
# 使用基础种子和 job_id 创建随机数生成器
422-
r = random.Random(self._base_seed + job_id * 1000 + repeat_id)
394+
r = random.Random(self._base_seed + job_id * 1000 + run_id)
423395
return r.randrange(2**32)
424396

425397
def _get_logging_mode(self) -> str:
@@ -431,13 +403,13 @@ def _get_logging_mode(self) -> str:
431403
return get_log_mode(self._cfg)
432404

433405
def _get_log_file_path(
434-
self, log_name: str, repeat_id: int, logging_mode: str
406+
self, log_name: str, run_id: int, logging_mode: str
435407
) -> Optional[Path]:
436408
"""Get log file path for a specific repeat.
437409
438410
Args:
439411
log_name: Base log file name.
440-
repeat_id: Repeat ID (1-indexed).
412+
run_id: Repeat ID (1-indexed).
441413
logging_mode: Logging mode.
442414
443415
Returns:
@@ -449,7 +421,7 @@ def _get_log_file_path(
449421
outpath=self.outpath,
450422
log_name=log_name,
451423
logging_mode=logging_mode,
452-
repeat_id=repeat_id,
424+
run_id=run_id,
453425
)
454426

455427
def _log_experiment_info(
@@ -518,18 +490,18 @@ def _batch_run_repeats(
518490
if self._is_hydra_parallel() or number_process == 1:
519491
# Hydra 并行或指定单进程时,顺序执行
520492
disable = repeats == 1 or not display_progress
521-
for repeat_id in tqdm(
493+
for run_id in tqdm(
522494
range(1, repeats + 1),
523495
disable=disable,
524496
desc=f"Job {self.job_id} repeats {repeats} times.",
525497
):
526498
# Log separator for merge mode
527-
if logging_mode == "merge" and repeat_id > 1:
499+
if logging_mode == "merge" and run_id > 1:
528500
# Note: Separator will be logged in model setup
529501
pass
530502

531503
# Get log file path for this repeat
532-
log_path = self._get_log_file_path(log_name, repeat_id, logging_mode)
504+
log_path = self._get_log_file_path(log_name, run_id, logging_mode)
533505

534506
# Display log file location for separate mode
535507
# This should only go to stdout, not to model run log files
@@ -539,14 +511,14 @@ def _batch_run_repeats(
539511
and log_path is not None
540512
):
541513
# Use print instead of logger to avoid writing to model run log files
542-
print(f"Repeat {repeat_id}: Logging to {log_path}")
514+
print(f"Repeat {run_id}: Logging to {log_path}")
543515

544516
run_single(
545517
model_cls=self.model_cls,
546518
cfg=cfg,
547-
key=(self.job_id, repeat_id),
519+
key=(self.job_id, run_id),
548520
outpath=self.outpath,
549-
seed=self._get_seed(repeat_id),
521+
seed=self._get_seed(run_id),
550522
hooks=self._manager.hooks,
551523
**self._extra_kwargs,
552524
)
@@ -564,13 +536,13 @@ def _batch_run_repeats(
564536
delayed(run_single)(
565537
model_cls=self.model_cls,
566538
cfg=cfg,
567-
key=(self.job_id, repeat_id),
539+
key=(self.job_id, run_id),
568540
outpath=self.outpath,
569-
seed=self._get_seed(repeat_id),
541+
seed=self._get_seed(run_id),
570542
hooks=self._manager.hooks,
571543
**self._extra_kwargs,
572544
)
573-
for repeat_id in tqdm(
545+
for run_id in tqdm(
574546
range(1, repeats + 1),
575547
disable=not display_progress,
576548
desc=f"Job {self.job_id} repeats {repeats} times, with {number_process} processes.",
@@ -644,22 +616,22 @@ def _call_hook_with_optional_args(
644616
hook_func: Callable,
645617
model: MainModelProtocol,
646618
job_id: Optional[int] = None,
647-
repeat_id: Optional[int] = None,
619+
run_id: Optional[int] = None,
648620
) -> Any:
649621
"""根据钩子函数的参数签名动态调用函数
650622
651623
Args:
652624
hook_func: 要调用的钩子函数
653625
model: 模型实例
654626
job_id: 可选的任务ID
655-
repeat_id: 可选的重复实验ID
627+
run_id: 可选的重复实验ID
656628
"""
657629
sig = inspect.signature(hook_func)
658630
hook_args = {}
659631

660632
if "job_id" in sig.parameters:
661633
hook_args["job_id"] = job_id
662-
if "repeat_id" in sig.parameters:
663-
hook_args["repeat_id"] = repeat_id
634+
if "run_id" in sig.parameters:
635+
hook_args["run_id"] = run_id
664636

665637
return hook_func(model, **hook_args)

abses/core/job_manager.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from __future__ import annotations
99

10+
import warnings
1011
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type
1112

1213
import pandas as pd
@@ -102,7 +103,7 @@ def update_result(
102103
"""更新实验结果
103104
104105
Args:
105-
key: (job_id, repeat_id) tuple
106+
key: (job_id, run_id) tuple
106107
overrides: Configuration overrides for this run
107108
datasets: Row-like mapping of metrics/values to store
108109
seed: Random seed used for this run
@@ -115,25 +116,44 @@ def dict_to_df(self, results: dict) -> pd.DataFrame:
115116
"""将嵌套字典转换为 DataFrame
116117
117118
Args:
118-
results: 形如 {(job_id, repeat_id): {'metric': value}} 的字典
119+
results: 形如 {(job_id, run_id): {'metric': value}} 的字典
119120
120121
Returns:
121-
包含 job_id, repeat_id 和指标值的 DataFrame
122+
包含 job_id, run_id 和指标值的 DataFrame
122123
"""
123124
return pd.DataFrame(results.values(), index=self.index)
124125

125126
def get_datasets(
126127
self,
127128
seed: bool = True,
128129
) -> pd.DataFrame:
129-
"""获取所有实验结果的 DataFrame"""
130+
"""获取所有实验结果的 DataFrame
131+
132+
Note:
133+
The ``repeat_id`` column is **deprecated** and will be removed in a
134+
future version. Please use the ``run_id`` column instead.
135+
"""
130136
to_concat = []
131137
to_concat.append(self.dict_to_df(self._overrides))
132138
if seed:
133139
seed = pd.Series(self._seeds, name="seed", index=self.index)
134140
to_concat.append(seed)
135141
to_concat.append(self.dict_to_df(self._datasets))
136-
return pd.concat(to_concat, axis=1).reset_index()
142+
df = pd.concat(to_concat, axis=1).reset_index()
143+
144+
# Backward compatibility: if legacy results contain a `repeat_id` column
145+
# (e.g. from older versions or custom datasets), mirror it into `run_id`
146+
# and emit a deprecation warning. New code should only rely on `run_id`.
147+
if "repeat_id" in df.columns and "run_id" not in df.columns:
148+
warnings.warn(
149+
"Column 'repeat_id' is deprecated and will be removed in a future "
150+
"version. Please use 'run_id' instead.",
151+
DeprecationWarning,
152+
stacklevel=2,
153+
)
154+
df["run_id"] = df["repeat_id"]
155+
156+
return df
137157

138158
def add_a_hook(
139159
self,

abses/core/model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,9 @@ def __init__(
140140
tracker_backend = create_tracker(tracker_cfg, model=self)
141141
collector_cfg = prepare_collector_config(tracker_cfg)
142142
self.datacollector: ABSESpyDataCollector = ABSESpyDataCollector(
143-
reports=collector_cfg, tracker=tracker_backend
143+
reports=collector_cfg,
144+
tracker=tracker_backend,
145+
run_id=run_id,
144146
)
145147

146148
# Setup logging BEFORE initialize() so user logs in initialize() are captured
@@ -347,7 +349,7 @@ def _setup_logger(self, log_cfg: Dict[str, Any]) -> None:
347349
rotation=rotation,
348350
retention=retention,
349351
logging_mode=logging_mode,
350-
repeat_id=self.run_id,
352+
run_id=self.run_id,
351353
file_level=file_level,
352354
file_format=file_format,
353355
file_datefmt=file_datefmt,

abses/utils/datacollector.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -116,15 +116,18 @@ def __init__(
116116
self,
117117
reports: Dict[ReportType, Dict[str, Reporter]] | None = None,
118118
tracker: Optional[TrackerProtocol] = None,
119+
run_id: Optional[int] = None,
119120
):
120121
"""Initialize data collector.
121122
122123
Args:
123124
reports: Reporters configuration.
124125
tracker: Optional tracker backend.
126+
run_id: Optional run id.
125127
"""
126128
reports = reports or {}
127129
self.tracker = tracker
130+
self.run_id = run_id
128131
self.model_reporters: Dict[str, Reporter] = {}
129132
self.final_reporters: Dict[str, Reporter] = {}
130133
self.agent_reporters: Dict[str, Dict[str, Reporter]] = {}
@@ -162,6 +165,13 @@ def add_reporters(
162165
for name, reporter in reporters.items():
163166
self._new_agent_reporter(breed=item, name=name, reporter=reporter)
164167

168+
def _add_run_id_to_data(
169+
self, data: pd.DataFrame | Dict[str, Any]
170+
) -> pd.DataFrame | Dict[str, Any]:
171+
if self.run_id is not None:
172+
data["run_id"] = self.run_id
173+
return data
174+
165175
def _new_model_reporter(self, name: str, reporter: Reporter) -> None:
166176
"""Add a new model-level reporter to collect data.
167177
@@ -216,8 +226,9 @@ def get_model_vars_dataframe(self):
216226
logger.warning(
217227
"No model reporters have been definedreturning empty DataFrame."
218228
)
219-
220-
return pd.DataFrame(self.model_vars)
229+
df = pd.DataFrame(self.model_vars)
230+
df = self._add_run_id_to_data(df)
231+
return df
221232

222233
def get_agent_vars_dataframe(self, breed: Optional[str] = None) -> pd.DataFrame:
223234
"""获取某种 Agents 的 DataFrame"""
@@ -229,8 +240,12 @@ def get_agent_vars_dataframe(self, breed: Optional[str] = None) -> pd.DataFrame:
229240
if not self.agent_reporters:
230241
logger.warning("No agent reporters have been defined in the DataCollector.")
231242
if results := self._agent_records.get(breed):
232-
return pd.concat([pd.DataFrame(res) for res in results])
233-
return pd.DataFrame()
243+
df = pd.concat([pd.DataFrame(res) for res in results])
244+
else:
245+
logger.warning(f"No agent records found for breed {breed}.")
246+
df = pd.DataFrame()
247+
df = self._add_run_id_to_data(data=df)
248+
return df
234249

235250
def get_final_vars_report(self, model: MainModel) -> Dict[str, Any]:
236251
"""Report at the end of this model.
@@ -239,11 +254,10 @@ def get_final_vars_report(self, model: MainModel) -> Dict[str, Any]:
239254
A dictionary mapping variable names to their computed values.
240255
"""
241256
if not self.final_reporters:
242-
logger.warning(
243-
"No final reporters have been defined, returning empty dict."
244-
)
257+
logger.info("No final reporters have been defined.")
245258
return {}
246259
results = {var: func(model) for var, func in self.final_reporters.items()}
260+
self._add_run_id_to_data(results)
247261
if self.tracker is not None:
248262
self.tracker.log_final_metrics(results)
249263
return results

abses/utils/log_config.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -229,15 +229,15 @@ def determine_log_file_path(
229229
outpath: Optional[Path],
230230
log_name: str,
231231
logging_mode: str = "once",
232-
repeat_id: Optional[int] = None,
232+
run_id: Optional[int] = None,
233233
) -> Optional[Path]:
234234
"""Determine log file path based on logging mode.
235235
236236
Args:
237237
outpath: Output directory for log files.
238238
log_name: Base log file name (without extension).
239239
logging_mode: Logging mode - 'once', 'separate', or 'merge'.
240-
repeat_id: Repeat ID for the current run (1-indexed).
240+
run_id: Run ID for the current run (1-indexed).
241241
242242
Returns:
243243
Path to log file, or None if logging should be disabled.
@@ -250,21 +250,21 @@ def determine_log_file_path(
250250

251251
if logging_mode == "once":
252252
# Only log the first repeat
253-
if repeat_id is None or repeat_id == 1:
253+
if run_id is None or run_id == 1:
254254
return outpath / f"{log_name}.log"
255255
return None
256256
elif logging_mode == "separate":
257257
# Each repeat gets its own file
258-
# In separate mode, repeat_id must be provided
259-
if repeat_id is None:
258+
# In separate mode, run_id must be provided
259+
if run_id is None:
260260
return None # Don't create default file in separate mode
261-
return outpath / f"{log_name}_{repeat_id}.log"
261+
return outpath / f"{log_name}_{run_id}.log"
262262
elif logging_mode == "merge":
263263
# All repeats go to the same file
264264
return outpath / f"{log_name}.log"
265265
else:
266266
# Unknown mode, default to once behavior
267-
if repeat_id is None or repeat_id == 1:
267+
if run_id is None or run_id == 1:
268268
return outpath / f"{log_name}.log"
269269
return None
270270

0 commit comments

Comments
 (0)