
Commit f266946
Author: Donglai Wei

minimize log.out

1 parent 14c8966

3 files changed: 55 additions & 15 deletions

connectomics/training/lit/model.py
Lines changed: 30 additions & 4 deletions

@@ -634,8 +634,15 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STEP_
         print(f"[D1 Step {self.global_step}] TARGET: min={target_min:.3f}, max={target_max:.3f}, "
               f"mean={target_mean:.3f}, >0: {target_positive_frac:.1f}%")

-        # Log losses (sync across GPUs for distributed training)
-        self.log_dict(loss_dict, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+        # Keep full training curves in TensorBoard while avoiding console spam.
+        self.log_dict(
+            loss_dict,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=False,
+            logger=True,
+            sync_dist=False,
+        )

         return total_loss

@@ -728,8 +735,27 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STE
         self.val_accuracy(preds, targets)
         self.log('val_accuracy', self.val_accuracy, on_step=False, on_epoch=True, prog_bar=True)

-        # Log losses (sync across GPUs for distributed training)
-        self.log_dict(loss_dict, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
+        # Show only the validation total loss on the progress bar.
+        if "val_loss_total" in loss_dict:
+            self.log(
+                "val_loss",
+                loss_dict["val_loss_total"],
+                on_step=False,
+                on_epoch=True,
+                prog_bar=True,
+                logger=False,
+                sync_dist=True,
+            )
+
+        # Log full validation losses to the logger at epoch granularity.
+        self.log_dict(
+            loss_dict,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )

         return total_loss
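For context: in Lightning, each self.log / self.log_dict call routes its metric independently. prog_bar controls the console progress bar, logger controls the experiment logger (TensorBoard here), and sync_dist decides whether the value is all-reduced across ranks before logging. The asymmetry above is presumably deliberate: training uses sync_dist=False to skip a per-step all-reduce, while validation keeps sync_dist=True so epoch-level aggregates, which can drive checkpointing and early stopping, agree across ranks. A minimal sketch of the same routing, using a hypothetical TinyModule rather than this repo's model:

# Hypothetical module illustrating the metric routing above; not from this repo.
import torch
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # logger=True, prog_bar=False: the full curve goes to TensorBoard and
        # nothing is printed. sync_dist=False skips a per-step all-reduce;
        # rank-local values are fine for monitoring curves.
        self.log("train_loss", loss, on_step=True, on_epoch=True,
                 prog_bar=False, logger=True, sync_dist=False)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # One bar-only summary; sync_dist=True all-reduces it so every rank
        # agrees on the epoch aggregate.
        self.log("val_loss", loss, on_step=False, on_epoch=True,
                 prog_bar=True, logger=False, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)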

connectomics/training/lit/trainer.py
Lines changed: 5 additions & 11 deletions

@@ -19,7 +19,6 @@
     ModelCheckpoint,
     EarlyStopping,
     LearningRateMonitor,
-    RichProgressBar,
 )
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.plugins.environments import LightningEnvironment

@@ -106,7 +105,7 @@ def create_trainer(
         save_top_k=cfg.monitor.checkpoint.save_top_k,
         save_last=cfg.monitor.checkpoint.save_last,
         every_n_epochs=cfg.monitor.checkpoint.save_every_n_epochs,
-        verbose=True,
+        verbose=False,
         save_on_train_epoch_end=True,  # Save based on training metrics
     )
     callbacks.append(checkpoint_callback)

@@ -132,7 +131,7 @@ def create_trainer(
         patience=cfg.monitor.early_stopping.patience,
         mode=cfg.monitor.early_stopping.mode,
         min_delta=cfg.monitor.early_stopping.min_delta,
-        verbose=True,
+        verbose=False,
         check_on_train_epoch_end=True,  # Check at end of train epoch (not validation)
         check_finite=cfg.monitor.early_stopping.check_finite,  # Stop on NaN/inf
         stopping_threshold=cfg.monitor.early_stopping.threshold,

@@ -184,19 +183,13 @@ def create_trainer(
     # Previous fix in val_dataloader() only ran once during setup
     validation_reseeding_callback = ValidationReseedingCallback(
         base_seed=cfg.system.seed,
-        log_fingerprint=True,
+        log_fingerprint=False,
         log_all_ranks=False,
-        verbose=True,
+        verbose=False,
     )
     callbacks.append(validation_reseeding_callback)
     print(f"  Validation Reseeding: Enabled (base_seed={cfg.system.seed})")

-    # Progress bar (optional - requires rich package)
-    try:
-        callbacks.append(RichProgressBar())
-    except (ImportError, ModuleNotFoundError):
-        pass  # Use default progress bar
-
     # Setup logger (training only - in run_dir/logs/)
     # Always create a logger for training to avoid warnings about missing logger
     logger = None

@@ -322,6 +315,7 @@ def create_trainer(
         benchmark=cfg.optimization.benchmark,
         fast_dev_run=bool(fast_dev_run),
         detect_anomaly=detect_anomaly,
+        enable_progress_bar=False,
         plugins=plugins,
     )
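The trainer-side edits all serve the same goal: per-save and per-check callback messages are silenced with verbose=False, the optional RichProgressBar is dropped, and the built-in bar is turned off with enable_progress_bar=False, which matters because a progress bar redirected into log.out rewrites a line every step. A standalone sketch of this quiet configuration (placeholder values, not this repo's config schema):

# Quiet Trainer sketch with placeholder settings; not this repo's create_trainer().
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

callbacks = [
    # verbose=False drops the per-save / per-check console messages;
    # checkpoints are still written and early stopping still fires.
    ModelCheckpoint(monitor="val_loss", save_top_k=2, verbose=False),
    EarlyStopping(monitor="val_loss", patience=10, verbose=False),
]

trainer = pl.Trainer(
    max_epochs=100,
    callbacks=callbacks,
    enable_progress_bar=False,  # no per-step bar updates in captured stdout
)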

scripts/main.py
Lines changed: 20 additions & 0 deletions

@@ -26,6 +26,7 @@
     python scripts/main.py --config tutorials/mito_lucchi++.yaml --checkpoint path/to/ckpt.ckpt --reset-max-epochs 500
 """

+import os
 import sys
 from pathlib import Path

@@ -68,6 +69,23 @@
 # Setup seed_everything with version fallback
 seed_everything = setup_seed_everything()

+_RANK_STDOUT_REDIRECT = None
+
+
+def suppress_nonzero_rank_stdout() -> None:
+    """Reduce duplicate stdout spam from DDP subprocesses.
+
+    In local multi-GPU spawn, each subprocess executes this script and prints
+    the same setup logs. Keep rank 0 stdout visible and silence stdout on
+    non-zero ranks. stderr is untouched for error visibility.
+    """
+    global _RANK_STDOUT_REDIRECT
+    local_rank = os.environ.get("LOCAL_RANK")
+    if local_rank is None or local_rank == "0":
+        return
+    _RANK_STDOUT_REDIRECT = open(os.devnull, "w")
+    sys.stdout = _RANK_STDOUT_REDIRECT
+

 def configure_matmul_precision(cfg: Config) -> None:
     """Enable Tensor Core matmul precision when supported by available CUDA devices."""

@@ -164,6 +182,8 @@ def extract_step_from_checkpoint(checkpoint_path: str) -> str:

 def main():
     """Main training function."""
+    suppress_nonzero_rank_stdout()
+
     # Parse arguments
     args = parse_args()
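For reference, LOCAL_RANK is the environment variable that Lightning's DDP launcher sets in each worker process ("0", "1", ... per local GPU), so the guard is a no-op in single-process runs. A condensed, self-contained restatement of the same check (the function name and the simulated rank are for demonstration only):

# Condensed restatement of suppress_nonzero_rank_stdout() for a quick check.
import os
import sys


def quiet_nonzero_ranks() -> None:
    # "0" or unset means rank 0 / single process: keep stdout as-is.
    if os.environ.get("LOCAL_RANK") in (None, "0"):
        return
    sys.stdout = open(os.devnull, "w")  # sys.stdout keeps the handle alive


os.environ["LOCAL_RANK"] = "1"                 # simulate a non-zero-rank worker
quiet_nonzero_ranks()
print("duplicate setup log")                   # swallowed: stdout -> os.devnull
print("errors stay visible", file=sys.stderr)  # stderr untouched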
