From 39c622fcf3ff45654ece9c0b9ef342ed22cbf82b Mon Sep 17 00:00:00 2001 From: cheny Date: Thu, 25 Dec 2025 14:21:07 +0800 Subject: [PATCH] feat(logger): Logger support distributed training. --- iddm/autoencoder/train.py | 7 +----- iddm/model/trainers/base.py | 11 +++++++-- iddm/sr/train.py | 7 +----- iddm/tools/train.py | 7 +----- iddm/utils/logger.py | 49 ++++++++++++++++++++++++------------- 5 files changed, 44 insertions(+), 37 deletions(-) diff --git a/iddm/autoencoder/train.py b/iddm/autoencoder/train.py index c9a4eed..37b67c5 100644 --- a/iddm/autoencoder/train.py +++ b/iddm/autoencoder/train.py @@ -33,7 +33,7 @@ image_format_choices from iddm.config.version import get_version_banner from iddm.model.trainers.autoencoder import AutoencoderTrainer -from iddm.utils.logger import init_logger, get_logger +from iddm.utils.logger import get_logger logger = get_logger(name=__name__) @@ -44,11 +44,6 @@ def main(args): :param args: Input parameters :return: None """ - # Init logger - init_logger( - is_save_log=True, - log_path=os.path.join(str(args.result_path), str(args.run_name)) - ) if args.distributed: gpus = torch.cuda.device_count() mp.spawn(AutoencoderTrainer(args=args).train, nprocs=gpus) diff --git a/iddm/model/trainers/base.py b/iddm/model/trainers/base.py index 0383473..cb4d2ca 100644 --- a/iddm/model/trainers/base.py +++ b/iddm/model/trainers/base.py @@ -34,7 +34,7 @@ from iddm.utils.initializer import device_initializer, lr_initializer from iddm.utils.utils import setup_logging, save_train_logging -from iddm.utils.logger import get_logger +from iddm.utils.logger import get_logger, init_logger logger = get_logger(name=__name__) @@ -151,7 +151,7 @@ def check_args_and_kwargs(self, kwarg, default=None): logger.info(msg=f"[Note]: args.{kwarg} already set => {value}") return value - def train(self, rank=None): + def train(self, rank=0): """ Training method :param rank: Device id @@ -159,6 +159,13 @@ def train(self, rank=None): # Init rank self.rank = rank + # Logger initializer + init_logger( + is_save_log=True, + log_path=os.path.join(str(self.result_path), str(self.run_name)), + rank=self.rank + ) + # Training self.before_train() self.train_in_epochs() diff --git a/iddm/sr/train.py b/iddm/sr/train.py index 53410e6..e68631d 100644 --- a/iddm/sr/train.py +++ b/iddm/sr/train.py @@ -32,7 +32,7 @@ from iddm.config.choices import sr_network_choices, optim_choices, sr_loss_func_choices, image_format_choices from iddm.config.version import get_version_banner from iddm.model.trainers.sr import SRTrainer -from iddm.utils.logger import init_logger, get_logger +from iddm.utils.logger import get_logger logger = get_logger(name=__name__) @@ -43,11 +43,6 @@ def main(args): :param args: Input parameters :return: None """ - # Init logger - init_logger( - is_save_log=True, - log_path=os.path.join(str(args.result_path), str(args.run_name)) - ) if args.distributed: gpus = torch.cuda.device_count() mp.spawn(SRTrainer(args=args).train, nprocs=gpus) diff --git a/iddm/tools/train.py b/iddm/tools/train.py index 832ae5e..4bdc10d 100644 --- a/iddm/tools/train.py +++ b/iddm/tools/train.py @@ -34,7 +34,7 @@ from iddm.config.version import get_version_banner from iddm.model.trainers import DMTrainer from iddm.utils.check import check_parse_image_size_type -from iddm.utils.logger import init_logger, get_logger +from iddm.utils.logger import get_logger logger = get_logger(name=__name__) @@ -45,11 +45,6 @@ def main(args): :param args: Input parameters :return: None """ - # Init logger - init_logger( - is_save_log=True, - log_path=os.path.join(str(args.result_path), str(args.run_name)) - ) if args.distributed: gpus = torch.cuda.device_count() mp.spawn(DMTrainer(args=args).train, nprocs=gpus) diff --git a/iddm/utils/logger.py b/iddm/utils/logger.py index 5a17383..fda351c 100644 --- a/iddm/utils/logger.py +++ b/iddm/utils/logger.py @@ -34,7 +34,8 @@ "level": None, "is_save_log": None, "log_path": None, - "initialized": False # Whether the tag is initialized + "initialized": False, # Whether the tag is initialized + "rank": 0 } @@ -48,10 +49,12 @@ def __init__( name: str, level: Union[int, str] = logging.INFO, is_save_log: bool = False, - log_path: Optional[str] = None + log_path: Optional[str] = None, + rank: int = 0 ): super().__init__(name, level) self.is_save_log = is_save_log + self.rank = rank self.log_path = log_path self._init_handlers() self._setup_colored_logs() @@ -75,7 +78,7 @@ def _add_console_handler(self) -> None: Add a console processor """ console_handler = logging.StreamHandler() - formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s %(message)s') + formatter = logging.Formatter('%(asctime)s %(name)s Process-%(process)d %(levelname)s %(message)s') console_handler.setFormatter(formatter) self.addHandler(console_handler) @@ -89,10 +92,10 @@ def _add_file_handler(self) -> None: # Different types of logs use different file prefixes prefix = "webui" if isinstance(self, WebUILogger) else "app" - log_file = os.path.join(log_save_path, f"{prefix}_{create_time}.log") + log_file = os.path.join(log_save_path, f"{prefix}_rank{self.rank}_{create_time}.log") file_handler = logging.FileHandler(log_file, mode="a", encoding="utf-8") - formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s") + formatter = logging.Formatter("%(asctime)s %(name)s Process-%(process)d %(levelname)s %(message)s") file_handler.setFormatter(formatter) self.addHandler(file_handler) @@ -102,13 +105,14 @@ def _setup_colored_logs(self) -> None: """ coloredlogs.install(level=self.level, logger=self) - def refresh_config(self, level: Union[int, str], is_save_log: bool, log_path: Optional[str]): + def refresh_config(self, level: Union[int, str], is_save_log: bool, log_path: Optional[str], rank: int) -> None: """ Refresh log instance configuration (used to update existing instances after init_logger) """ self.level = level self.is_save_log = is_save_log self.log_path = log_path + self.rank = rank self._init_handlers() self._setup_colored_logs() @@ -123,9 +127,10 @@ def __init__( name: str, level: Union[int, str] = logging.INFO, is_save_log: bool = False, - log_path: Optional[str] = None + log_path: Optional[str] = None, + rank: int = 0 ): - super().__init__(name, level, is_save_log, log_path) + super().__init__(name, level, is_save_log, log_path, rank) # Cumulative text for WebUI display self.webui_text = "" @@ -182,21 +187,24 @@ def __init__( name: str, level: Union[int, str] = logging.INFO, is_save_log: bool = False, - log_path: Optional[str] = None + log_path: Optional[str] = None, + rank: int = 0 ): - super().__init__(name, level, is_save_log, log_path) + super().__init__(name, level, is_save_log, log_path, rank) def init_logger( level: Union[int, str] = logging.INFO, is_save_log: bool = False, - log_path: Optional[str] = None + log_path: Optional[str] = None, + rank: int = 0 ) -> None: """ Initialize the global log configuration (only the first call takes effect) :param level: Log level :param is_save_log: Whether to save the log to a file :param log_path: Log saving path + :param rank: Distributed training rank :return: None """ global _global_log_config @@ -209,7 +217,8 @@ def init_logger( "level": level, "is_save_log": is_save_log, "log_path": log_path, - "initialized": True + "initialized": True, + "rank": rank }) # Refresh existing log instances (apply new configuration) @@ -217,7 +226,8 @@ def init_logger( logger.refresh_config( level=level, is_save_log=is_save_log, - log_path=log_path + log_path=log_path, + rank=rank ) @@ -226,7 +236,8 @@ def get_logger( logger_type: str = "app", level: Union[int, str] = logging.INFO, is_save_log: bool = False, - log_path: Optional[str] = None + log_path: Optional[str] = None, + rank: int = 0 ) -> Union[AppLogger, WebUILogger]: """ Get a global log instance (singleton mode) @@ -235,6 +246,7 @@ def get_logger( :param level: Log level :param is_save_log: Whether to save the log to a file :param log_path: Log saving path + :param rank: Distributed training rank :return: Log instance """ @@ -244,9 +256,10 @@ def get_logger( level = level or _global_log_config.get("level", logging.INFO) is_save_log = is_save_log if is_save_log is not None else _global_log_config.get("is_save_log", False) log_path = log_path or _global_log_config.get("log_path") + rank = rank or _global_log_config.get("rank", 0) # Use (name + type) as the unique key to ensure isolation - key = f"{name}_{logger_type}" + key = f"{name}_{logger_type}_{rank}" if key not in _global_loggers: if logger_type == "webui": @@ -254,14 +267,16 @@ def get_logger( name=name, level=level, is_save_log=is_save_log, - log_path=log_path + log_path=log_path, + rank=rank ) else: _global_loggers[key] = AppLogger( name=name, level=level, is_save_log=is_save_log, - log_path=log_path + log_path=log_path, + rank=rank ) return _global_loggers[key]