Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions iddm/autoencoder/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
image_format_choices
from iddm.config.version import get_version_banner
from iddm.model.trainers.autoencoder import AutoencoderTrainer
from iddm.utils.logger import init_logger, get_logger
from iddm.utils.logger import get_logger

logger = get_logger(name=__name__)

Expand All @@ -44,11 +44,6 @@ def main(args):
:param args: Input parameters
:return: None
"""
# Init logger
init_logger(
is_save_log=True,
log_path=os.path.join(str(args.result_path), str(args.run_name))
)
if args.distributed:
gpus = torch.cuda.device_count()
mp.spawn(AutoencoderTrainer(args=args).train, nprocs=gpus)
Expand Down
11 changes: 9 additions & 2 deletions iddm/model/trainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from iddm.utils.initializer import device_initializer, lr_initializer

from iddm.utils.utils import setup_logging, save_train_logging
from iddm.utils.logger import get_logger
from iddm.utils.logger import get_logger, init_logger

logger = get_logger(name=__name__)

Expand Down Expand Up @@ -151,14 +151,21 @@ def check_args_and_kwargs(self, kwarg, default=None):
logger.info(msg=f"[Note]: args.{kwarg} already set => {value}")
return value

def train(self, rank=None):
def train(self, rank=0):
"""
Training method
:param rank: Device id
"""
# Init rank
self.rank = rank

# Logger initializer
init_logger(
is_save_log=True,
log_path=os.path.join(str(self.result_path), str(self.run_name)),
rank=self.rank
)

# Training
self.before_train()
self.train_in_epochs()
Expand Down
7 changes: 1 addition & 6 deletions iddm/sr/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from iddm.config.choices import sr_network_choices, optim_choices, sr_loss_func_choices, image_format_choices
from iddm.config.version import get_version_banner
from iddm.model.trainers.sr import SRTrainer
from iddm.utils.logger import init_logger, get_logger
from iddm.utils.logger import get_logger

logger = get_logger(name=__name__)

Expand All @@ -43,11 +43,6 @@ def main(args):
:param args: Input parameters
:return: None
"""
# Init logger
init_logger(
is_save_log=True,
log_path=os.path.join(str(args.result_path), str(args.run_name))
)
if args.distributed:
gpus = torch.cuda.device_count()
mp.spawn(SRTrainer(args=args).train, nprocs=gpus)
Expand Down
2 changes: 1 addition & 1 deletion iddm/tools/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def init_generate_args():
# Input image size (required)
# [Warn] Compatible with older versions
# [Warn] Version <= 1.1.1 need to be equal to model's image size, version > 1.1.1 can set whatever you want
parser.add_argument("--image_size", "-i", type=parse_image_size_type, default=64)
parser.add_argument("--image_size", "-i", type=check_parse_image_size_type, default=64)
# Set the use GPU in generate (required)
parser.add_argument("--use_gpu", type=int, default=0)

Expand Down
7 changes: 1 addition & 6 deletions iddm/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from iddm.config.version import get_version_banner
from iddm.model.trainers import DMTrainer
from iddm.utils.check import check_parse_image_size_type
from iddm.utils.logger import init_logger, get_logger
from iddm.utils.logger import get_logger

logger = get_logger(name=__name__)

Expand All @@ -45,11 +45,6 @@ def main(args):
:param args: Input parameters
:return: None
"""
# Init logger
init_logger(
is_save_log=True,
log_path=os.path.join(str(args.result_path), str(args.run_name))
)
if args.distributed:
gpus = torch.cuda.device_count()
mp.spawn(DMTrainer(args=args).train, nprocs=gpus)
Expand Down
49 changes: 32 additions & 17 deletions iddm/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
"level": None,
"is_save_log": None,
"log_path": None,
"initialized": False # Whether the tag is initialized
"initialized": False, # Whether the tag is initialized
"rank": 0
}


Expand All @@ -48,10 +49,12 @@ def __init__(
name: str,
level: Union[int, str] = logging.INFO,
is_save_log: bool = False,
log_path: Optional[str] = None
log_path: Optional[str] = None,
rank: int = 0
):
super().__init__(name, level)
self.is_save_log = is_save_log
self.rank = rank
self.log_path = log_path
self._init_handlers()
self._setup_colored_logs()
Expand All @@ -75,7 +78,7 @@ def _add_console_handler(self) -> None:
Add a console processor
"""
console_handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s %(message)s')
formatter = logging.Formatter('%(asctime)s %(name)s Process-%(process)d %(levelname)s %(message)s')
console_handler.setFormatter(formatter)
self.addHandler(console_handler)

Expand All @@ -89,10 +92,10 @@ def _add_file_handler(self) -> None:

# Different types of logs use different file prefixes
prefix = "webui" if isinstance(self, WebUILogger) else "app"
log_file = os.path.join(log_save_path, f"{prefix}_{create_time}.log")
log_file = os.path.join(log_save_path, f"{prefix}_rank{self.rank}_{create_time}.log")

file_handler = logging.FileHandler(log_file, mode="a", encoding="utf-8")
formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s")
formatter = logging.Formatter("%(asctime)s %(name)s Process-%(process)d %(levelname)s %(message)s")
file_handler.setFormatter(formatter)
self.addHandler(file_handler)

Expand All @@ -102,13 +105,14 @@ def _setup_colored_logs(self) -> None:
"""
coloredlogs.install(level=self.level, logger=self)

def refresh_config(self, level: Union[int, str], is_save_log: bool, log_path: Optional[str]):
def refresh_config(self, level: Union[int, str], is_save_log: bool, log_path: Optional[str], rank: int) -> None:
"""
Refresh log instance configuration (used to update existing instances after init_logger)
"""
self.level = level
self.is_save_log = is_save_log
self.log_path = log_path
self.rank = rank
self._init_handlers()
self._setup_colored_logs()

Expand All @@ -123,9 +127,10 @@ def __init__(
name: str,
level: Union[int, str] = logging.INFO,
is_save_log: bool = False,
log_path: Optional[str] = None
log_path: Optional[str] = None,
rank: int = 0
):
super().__init__(name, level, is_save_log, log_path)
super().__init__(name, level, is_save_log, log_path, rank)
# Cumulative text for WebUI display
self.webui_text = ""

Expand Down Expand Up @@ -182,21 +187,24 @@ def __init__(
name: str,
level: Union[int, str] = logging.INFO,
is_save_log: bool = False,
log_path: Optional[str] = None
log_path: Optional[str] = None,
rank: int = 0
):
super().__init__(name, level, is_save_log, log_path)
super().__init__(name, level, is_save_log, log_path, rank)


def init_logger(
level: Union[int, str] = logging.INFO,
is_save_log: bool = False,
log_path: Optional[str] = None
log_path: Optional[str] = None,
rank: int = 0
) -> None:
"""
Initialize the global log configuration (only the first call takes effect)
:param level: Log level
:param is_save_log: Whether to save the log to a file
:param log_path: Log saving path
:param rank: Distributed training rank
:return: None
"""
global _global_log_config
Expand All @@ -209,15 +217,17 @@ def init_logger(
"level": level,
"is_save_log": is_save_log,
"log_path": log_path,
"initialized": True
"initialized": True,
"rank": rank
})

# Refresh existing log instances (apply new configuration)
for logger in _global_loggers.values():
logger.refresh_config(
level=level,
is_save_log=is_save_log,
log_path=log_path
log_path=log_path,
rank=rank
)


Expand All @@ -226,7 +236,8 @@ def get_logger(
logger_type: str = "app",
level: Union[int, str] = logging.INFO,
is_save_log: bool = False,
log_path: Optional[str] = None
log_path: Optional[str] = None,
rank: int = 0
) -> Union[AppLogger, WebUILogger]:
"""
Get a global log instance (singleton mode)
Expand All @@ -235,6 +246,7 @@ def get_logger(
:param level: Log level
:param is_save_log: Whether to save the log to a file
:param log_path: Log saving path
:param rank: Distributed training rank
:return: Log instance
"""

Expand All @@ -244,24 +256,27 @@ def get_logger(
level = level or _global_log_config.get("level", logging.INFO)
is_save_log = is_save_log if is_save_log is not None else _global_log_config.get("is_save_log", False)
log_path = log_path or _global_log_config.get("log_path")
rank = rank or _global_log_config.get("rank", 0)

# Use (name + type) as the unique key to ensure isolation
key = f"{name}_{logger_type}"
key = f"{name}_{logger_type}_{rank}"

if key not in _global_loggers:
if logger_type == "webui":
_global_loggers[key] = WebUILogger(
name=name,
level=level,
is_save_log=is_save_log,
log_path=log_path
log_path=log_path,
rank=rank
)
else:
_global_loggers[key] = AppLogger(
name=name,
level=level,
is_save_log=is_save_log,
log_path=log_path
log_path=log_path,
rank=rank
)

return _global_loggers[key]
Expand Down