Skip to content

Commit 4576f6a

Browse files
authored
Merge pull request #185 from chairc/dev
feat(logger): Logger support distributed training.
2 parents d5b5d50 + 29287fd commit 4576f6a

5 files changed

Lines changed: 44 additions & 37 deletions

File tree

iddm/autoencoder/train.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
image_format_choices
3434
from iddm.config.version import get_version_banner
3535
from iddm.model.trainers.autoencoder import AutoencoderTrainer
36-
from iddm.utils.logger import init_logger, get_logger
36+
from iddm.utils.logger import get_logger
3737

3838
logger = get_logger(name=__name__)
3939

@@ -44,11 +44,6 @@ def main(args):
4444
:param args: Input parameters
4545
:return: None
4646
"""
47-
# Init logger
48-
init_logger(
49-
is_save_log=True,
50-
log_path=os.path.join(str(args.result_path), str(args.run_name))
51-
)
5247
if args.distributed:
5348
gpus = torch.cuda.device_count()
5449
mp.spawn(AutoencoderTrainer(args=args).train, nprocs=gpus)

iddm/model/trainers/base.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from iddm.utils.initializer import device_initializer, lr_initializer
3535

3636
from iddm.utils.utils import setup_logging, save_train_logging
37-
from iddm.utils.logger import get_logger
37+
from iddm.utils.logger import get_logger, init_logger
3838

3939
logger = get_logger(name=__name__)
4040

@@ -151,14 +151,21 @@ def check_args_and_kwargs(self, kwarg, default=None):
151151
logger.info(msg=f"[Note]: args.{kwarg} already set => {value}")
152152
return value
153153

154-
def train(self, rank=None):
154+
def train(self, rank=0):
155155
"""
156156
Training method
157157
:param rank: Device id
158158
"""
159159
# Init rank
160160
self.rank = rank
161161

162+
# Logger initializer
163+
init_logger(
164+
is_save_log=True,
165+
log_path=os.path.join(str(self.result_path), str(self.run_name)),
166+
rank=self.rank
167+
)
168+
162169
# Training
163170
self.before_train()
164171
self.train_in_epochs()

iddm/sr/train.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from iddm.config.choices import sr_network_choices, optim_choices, sr_loss_func_choices, image_format_choices
3333
from iddm.config.version import get_version_banner
3434
from iddm.model.trainers.sr import SRTrainer
35-
from iddm.utils.logger import init_logger, get_logger
35+
from iddm.utils.logger import get_logger
3636

3737
logger = get_logger(name=__name__)
3838

@@ -43,11 +43,6 @@ def main(args):
4343
:param args: Input parameters
4444
:return: None
4545
"""
46-
# Init logger
47-
init_logger(
48-
is_save_log=True,
49-
log_path=os.path.join(str(args.result_path), str(args.run_name))
50-
)
5146
if args.distributed:
5247
gpus = torch.cuda.device_count()
5348
mp.spawn(SRTrainer(args=args).train, nprocs=gpus)

iddm/tools/train.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from iddm.config.version import get_version_banner
3535
from iddm.model.trainers import DMTrainer
3636
from iddm.utils.check import check_parse_image_size_type
37-
from iddm.utils.logger import init_logger, get_logger
37+
from iddm.utils.logger import get_logger
3838

3939
logger = get_logger(name=__name__)
4040

@@ -45,11 +45,6 @@ def main(args):
4545
:param args: Input parameters
4646
:return: None
4747
"""
48-
# Init logger
49-
init_logger(
50-
is_save_log=True,
51-
log_path=os.path.join(str(args.result_path), str(args.run_name))
52-
)
5348
if args.distributed:
5449
gpus = torch.cuda.device_count()
5550
mp.spawn(DMTrainer(args=args).train, nprocs=gpus)

iddm/utils/logger.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434
"level": None,
3535
"is_save_log": None,
3636
"log_path": None,
37-
"initialized": False # Whether the tag is initialized
37+
"initialized": False, # Whether the tag is initialized
38+
"rank": 0
3839
}
3940

4041

@@ -48,10 +49,12 @@ def __init__(
4849
name: str,
4950
level: Union[int, str] = logging.INFO,
5051
is_save_log: bool = False,
51-
log_path: Optional[str] = None
52+
log_path: Optional[str] = None,
53+
rank: int = 0
5254
):
5355
super().__init__(name, level)
5456
self.is_save_log = is_save_log
57+
self.rank = rank
5558
self.log_path = log_path
5659
self._init_handlers()
5760
self._setup_colored_logs()
@@ -75,7 +78,7 @@ def _add_console_handler(self) -> None:
7578
Add a console handler
7679
"""
7780
console_handler = logging.StreamHandler()
78-
formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s %(message)s')
81+
formatter = logging.Formatter('%(asctime)s %(name)s Process-%(process)d %(levelname)s %(message)s')
7982
console_handler.setFormatter(formatter)
8083
self.addHandler(console_handler)
8184

@@ -89,10 +92,10 @@ def _add_file_handler(self) -> None:
8992

9093
# Different types of logs use different file prefixes
9194
prefix = "webui" if isinstance(self, WebUILogger) else "app"
92-
log_file = os.path.join(log_save_path, f"{prefix}_{create_time}.log")
95+
log_file = os.path.join(log_save_path, f"{prefix}_rank{self.rank}_{create_time}.log")
9396

9497
file_handler = logging.FileHandler(log_file, mode="a", encoding="utf-8")
95-
formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s")
98+
formatter = logging.Formatter("%(asctime)s %(name)s Process-%(process)d %(levelname)s %(message)s")
9699
file_handler.setFormatter(formatter)
97100
self.addHandler(file_handler)
98101

@@ -102,13 +105,14 @@ def _setup_colored_logs(self) -> None:
102105
"""
103106
coloredlogs.install(level=self.level, logger=self)
104107

105-
def refresh_config(self, level: Union[int, str], is_save_log: bool, log_path: Optional[str]):
108+
def refresh_config(self, level: Union[int, str], is_save_log: bool, log_path: Optional[str], rank: int) -> None:
106109
"""
107110
Refresh log instance configuration (used to update existing instances after init_logger)
108111
"""
109112
self.level = level
110113
self.is_save_log = is_save_log
111114
self.log_path = log_path
115+
self.rank = rank
112116
self._init_handlers()
113117
self._setup_colored_logs()
114118

@@ -123,9 +127,10 @@ def __init__(
123127
name: str,
124128
level: Union[int, str] = logging.INFO,
125129
is_save_log: bool = False,
126-
log_path: Optional[str] = None
130+
log_path: Optional[str] = None,
131+
rank: int = 0
127132
):
128-
super().__init__(name, level, is_save_log, log_path)
133+
super().__init__(name, level, is_save_log, log_path, rank)
129134
# Cumulative text for WebUI display
130135
self.webui_text = ""
131136

@@ -182,21 +187,24 @@ def __init__(
182187
name: str,
183188
level: Union[int, str] = logging.INFO,
184189
is_save_log: bool = False,
185-
log_path: Optional[str] = None
190+
log_path: Optional[str] = None,
191+
rank: int = 0
186192
):
187-
super().__init__(name, level, is_save_log, log_path)
193+
super().__init__(name, level, is_save_log, log_path, rank)
188194

189195

190196
def init_logger(
191197
level: Union[int, str] = logging.INFO,
192198
is_save_log: bool = False,
193-
log_path: Optional[str] = None
199+
log_path: Optional[str] = None,
200+
rank: int = 0
194201
) -> None:
195202
"""
196203
Initialize the global log configuration (only the first call takes effect)
197204
:param level: Log level
198205
:param is_save_log: Whether to save the log to a file
199206
:param log_path: Log saving path
207+
:param rank: Distributed training rank
200208
:return: None
201209
"""
202210
global _global_log_config
@@ -209,15 +217,17 @@ def init_logger(
209217
"level": level,
210218
"is_save_log": is_save_log,
211219
"log_path": log_path,
212-
"initialized": True
220+
"initialized": True,
221+
"rank": rank
213222
})
214223

215224
# Refresh existing log instances (apply new configuration)
216225
for logger in _global_loggers.values():
217226
logger.refresh_config(
218227
level=level,
219228
is_save_log=is_save_log,
220-
log_path=log_path
229+
log_path=log_path,
230+
rank=rank
221231
)
222232

223233

@@ -226,7 +236,8 @@ def get_logger(
226236
logger_type: str = "app",
227237
level: Union[int, str] = logging.INFO,
228238
is_save_log: bool = False,
229-
log_path: Optional[str] = None
239+
log_path: Optional[str] = None,
240+
rank: int = 0
230241
) -> Union[AppLogger, WebUILogger]:
231242
"""
232243
Get a global log instance (singleton mode)
@@ -235,6 +246,7 @@ def get_logger(
235246
:param level: Log level
236247
:param is_save_log: Whether to save the log to a file
237248
:param log_path: Log saving path
249+
:param rank: Distributed training rank
238250
:return: Log instance
239251
"""
240252

@@ -244,24 +256,27 @@ def get_logger(
244256
level = level or _global_log_config.get("level", logging.INFO)
245257
is_save_log = is_save_log if is_save_log is not None else _global_log_config.get("is_save_log", False)
246258
log_path = log_path or _global_log_config.get("log_path")
259+
rank = rank or _global_log_config.get("rank", 0)
247260

248261
# Use (name + type + rank) as the unique key to ensure isolation
249-
key = f"{name}_{logger_type}"
262+
key = f"{name}_{logger_type}_{rank}"
250263

251264
if key not in _global_loggers:
252265
if logger_type == "webui":
253266
_global_loggers[key] = WebUILogger(
254267
name=name,
255268
level=level,
256269
is_save_log=is_save_log,
257-
log_path=log_path
270+
log_path=log_path,
271+
rank=rank
258272
)
259273
else:
260274
_global_loggers[key] = AppLogger(
261275
name=name,
262276
level=level,
263277
is_save_log=is_save_log,
264-
log_path=log_path
278+
log_path=log_path,
279+
rank=rank
265280
)
266281

267282
return _global_loggers[key]

0 commit comments

Comments
 (0)