diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 8434f1f94..5a516473e 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -906,9 +906,24 @@ class WandbBackend(LoggerBackend): per_rank_share_run (bool, default False): For per-rank modes, whether to share run ID across ranks. If true, a single wandb run is created and all ranks log to it. Particularly useful for logging with no_reduce to capture time-based streams. Not recommended if reducing values. - **kwargs: Any argument accepted by wandb.init() (e.g., project, group, name, tags, notes, etc.) - - Example: + **kwargs: Any argument accepted by wandb.init() (e.g., project, group, name, tags, notes, + entity, mode, dir, ...) is forwarded verbatim. Anything you place under + ``metric_logging.wandb:`` in your YAML config that isn't ``logging_mode`` or + ``per_rank_share_run`` lands in this **kwargs and reaches ``wandb.init`` unchanged. + + YAML example: + metric_logging: + wandb: + logging_mode: global_reduce + project: my_project + group: exp_group + name: my_experiment + tags: ["rl", "v2"] + notes: "Testing new reward" + entity: my_team # forwarded to wandb.init + mode: offline # forwarded to wandb.init + + Python example: WandbBackend( logging_mode=LoggingMode.PER_RANK_REDUCE, per_rank_share_run=False, @@ -916,7 +931,7 @@ class WandbBackend(LoggerBackend): group="exp_group", name="my_experiment", tags=["rl", "v2"], - notes="Testing new reward" + notes="Testing new reward", ) """ @@ -929,6 +944,12 @@ def __init__( self.run = None self.process_name = None self._tables: dict[str, "wandb.Table"] = {} + if kwargs: + logger.info( + "WandbBackend: forwarding %d extra kwarg(s) to wandb.init: %s", + len(kwargs), + sorted(kwargs.keys()), + ) async def init( self, diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py index 325921895..a80df2107 100644 --- a/tests/unit_tests/observability/test_metrics.py +++ b/tests/unit_tests/observability/test_metrics.py @@ -346,6 +346,39 @@ def test_wandb_backend_creation(self): metadata = backend.get_metadata_for_secondary_ranks() assert metadata == {} # Should be empty when no run + def test_wandb_backend_logs_extra_kwargs(self, caplog): + """Visibility check for #594: extra kwargs forwarded to wandb.init + should be surfaced at INFO so users see what the YAML actually sent.""" + import logging as _logging + + with caplog.at_level(_logging.INFO, logger="forge.observability.metrics"): + WandbBackend( + logging_mode=LoggingMode.GLOBAL_REDUCE, + project="p", + entity="my_team", + mode="offline", + ) + + forwarded_logs = [ + r for r in caplog.records if "forwarding" in r.getMessage() + ] + assert forwarded_logs, "expected an INFO log listing forwarded wandb kwargs" + msg = forwarded_logs[0].getMessage() + for expected in ("project", "entity", "mode"): + assert expected in msg, f"kwarg {expected!r} not surfaced in log: {msg}" + + def test_wandb_backend_no_log_when_no_extra_kwargs(self, caplog): + """No-op when only the named args are passed.""" + import logging as _logging + + with caplog.at_level(_logging.INFO, logger="forge.observability.metrics"): + WandbBackend(logging_mode=LoggingMode.GLOBAL_REDUCE) + + forwarded_logs = [ + r for r in caplog.records if "forwarding" in r.getMessage() + ] + assert not forwarded_logs, "should not log when no extra kwargs provided" + @pytest.mark.asyncio async def test_console_backend(self): """Test ConsoleBackend basic operations."""