@@ -122,10 +122,10 @@ def run_single(
122122 hooks:
123123 The hooks to run after the model is run.
124124 """
125- job_id , repeat_id = key
125+ job_id , run_id = key
126126 model = model_cls (
127127 parameters = cfg ,
128- run_id = repeat_id ,
128+ run_id = run_id ,
129129 seed = seed ,
130130 ** kwargs ,
131131 )
@@ -135,7 +135,7 @@ def run_single(
135135 for hook_name , hook_func in hooks .items ():
136136 logger .info (f"Running hook { hook_name } ." )
137137 _call_hook_with_optional_args (
138- hook_func , model , job_id = job_id , repeat_id = repeat_id
138+ hook_func , model , job_id = job_id , run_id = run_id
139139 )
140140 return key , seed , results
141141
@@ -370,45 +370,17 @@ def _load_hydra_cfg(
370370
371371 return cfg
372372
373- # def _get_logging_mode(self, repeat_id: Optional[int] = None) -> str | bool:
374- # log_mode = self.exp_config.get("logging", "once")
375- # if log_mode == "once":
376- # if repeat_id == 1:
377- # logging: bool | str = self.name
378- # else:
379- # return False
380- # elif bool(log_mode):
381- # logging = f"{self.name}_{repeat_id}"
382- # else:
383- # logging = False
384- # return logging
385-
386- # def _update_log_config(
387- # self, config, repeat_id: Optional[int] = None
388- # ) -> bool:
389- # """Update the log configuration."""
390- # if isinstance(config, dict):
391- # config = DictConfig(config)
392- # OmegaConf.set_struct(config, False)
393- # log_name = self._get_logging_mode(repeat_id=repeat_id)
394- # if not log_name:
395- # config["log"] = False
396- # return config
397- # logging_cfg = OmegaConf.create({"log": {"name": log_name}})
398- # config = OmegaConf.merge(config, logging_cfg)
399- # return config
400-
401- def _get_seed (self , repeat_id : int , job_id : Optional [int ] = None ) -> Optional [int ]:
373+ def _get_seed (self , run_id : int , job_id : Optional [int ] = None ) -> Optional [int ]:
402374 """获取每次运行的随机种子
403375
404376 使用基础种子初始化随机数生成器,为每次运行生成唯一的随机种子。
405377 这样可以保证:
406378 1. 如果基础种子相同,生成的种子序列也相同
407- 2. 不同的 job_id 和 repeat_id 组合会得到不同的种子
379+ 2. 不同的 job_id 和 run_id 组合会得到不同的种子
408380 3. 种子序列具有更好的随机性
409381
410382 Args:
411- repeat_id : 重复实验的ID
383+ run_id : 重复实验的ID
412384
413385 Returns:
414386 如果没有设置基础种子则返回 None,否则返回生成的随机种子
@@ -419,7 +391,7 @@ def _get_seed(self, repeat_id: int, job_id: Optional[int] = None) -> Optional[in
419391 if job_id is None :
420392 job_id = self .job_id
421393 # 使用基础种子和 job_id 创建随机数生成器
422- r = random .Random (self ._base_seed + job_id * 1000 + repeat_id )
394+ r = random .Random (self ._base_seed + job_id * 1000 + run_id )
423395 return r .randrange (2 ** 32 )
424396
425397 def _get_logging_mode (self ) -> str :
@@ -431,13 +403,13 @@ def _get_logging_mode(self) -> str:
431403 return get_log_mode (self ._cfg )
432404
433405 def _get_log_file_path (
434- self , log_name : str , repeat_id : int , logging_mode : str
406+ self , log_name : str , run_id : int , logging_mode : str
435407 ) -> Optional [Path ]:
436408 """Get log file path for a specific repeat.
437409
438410 Args:
439411 log_name: Base log file name.
440- repeat_id : Repeat ID (1-indexed).
412+ run_id : Repeat ID (1-indexed).
441413 logging_mode: Logging mode.
442414
443415 Returns:
@@ -449,7 +421,7 @@ def _get_log_file_path(
449421 outpath = self .outpath ,
450422 log_name = log_name ,
451423 logging_mode = logging_mode ,
452- repeat_id = repeat_id ,
424+ run_id = run_id ,
453425 )
454426
455427 def _log_experiment_info (
@@ -518,18 +490,18 @@ def _batch_run_repeats(
518490 if self ._is_hydra_parallel () or number_process == 1 :
519491 # Hydra 并行或指定单进程时,顺序执行
520492 disable = repeats == 1 or not display_progress
521- for repeat_id in tqdm (
493+ for run_id in tqdm (
522494 range (1 , repeats + 1 ),
523495 disable = disable ,
524496 desc = f"Job { self .job_id } repeats { repeats } times." ,
525497 ):
526498 # Log separator for merge mode
527- if logging_mode == "merge" and repeat_id > 1 :
499+ if logging_mode == "merge" and run_id > 1 :
528500 # Note: Separator will be logged in model setup
529501 pass
530502
531503 # Get log file path for this repeat
532- log_path = self ._get_log_file_path (log_name , repeat_id , logging_mode )
504+ log_path = self ._get_log_file_path (log_name , run_id , logging_mode )
533505
534506 # Display log file location for separate mode
535507 # This should only go to stdout, not to model run log files
@@ -539,14 +511,14 @@ def _batch_run_repeats(
539511 and log_path is not None
540512 ):
541513 # Use print instead of logger to avoid writing to model run log files
542- print (f"Repeat { repeat_id } : Logging to { log_path } " )
514+ print (f"Repeat { run_id } : Logging to { log_path } " )
543515
544516 run_single (
545517 model_cls = self .model_cls ,
546518 cfg = cfg ,
547- key = (self .job_id , repeat_id ),
519+ key = (self .job_id , run_id ),
548520 outpath = self .outpath ,
549- seed = self ._get_seed (repeat_id ),
521+ seed = self ._get_seed (run_id ),
550522 hooks = self ._manager .hooks ,
551523 ** self ._extra_kwargs ,
552524 )
@@ -564,13 +536,13 @@ def _batch_run_repeats(
564536 delayed (run_single )(
565537 model_cls = self .model_cls ,
566538 cfg = cfg ,
567- key = (self .job_id , repeat_id ),
539+ key = (self .job_id , run_id ),
568540 outpath = self .outpath ,
569- seed = self ._get_seed (repeat_id ),
541+ seed = self ._get_seed (run_id ),
570542 hooks = self ._manager .hooks ,
571543 ** self ._extra_kwargs ,
572544 )
573- for repeat_id in tqdm (
545+ for run_id in tqdm (
574546 range (1 , repeats + 1 ),
575547 disable = not display_progress ,
576548 desc = f"Job { self .job_id } repeats { repeats } times, with { number_process } processes." ,
@@ -644,22 +616,22 @@ def _call_hook_with_optional_args(
644616 hook_func : Callable ,
645617 model : MainModelProtocol ,
646618 job_id : Optional [int ] = None ,
647- repeat_id : Optional [int ] = None ,
619+ run_id : Optional [int ] = None ,
648620) -> Any :
649621 """根据钩子函数的参数签名动态调用函数
650622
651623 Args:
652624 hook_func: 要调用的钩子函数
653625 model: 模型实例
654626 job_id: 可选的任务ID
655- repeat_id : 可选的重复实验ID
627+ run_id : 可选的重复实验ID
656628 """
657629 sig = inspect .signature (hook_func )
658630 hook_args = {}
659631
660632 if "job_id" in sig .parameters :
661633 hook_args ["job_id" ] = job_id
662- if "repeat_id " in sig .parameters :
663- hook_args ["repeat_id " ] = repeat_id
634+ if "run_id " in sig .parameters :
635+ hook_args ["run_id " ] = run_id
664636
665637 return hook_func (model , ** hook_args )
0 commit comments