Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions src/forge/observability/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,17 +906,32 @@ class WandbBackend(LoggerBackend):
per_rank_share_run (bool, default False): For per-rank modes, whether to share run ID across ranks.
If true, a single wandb run is created and all ranks log to it. Particularly useful for
logging with no_reduce to capture time-based streams. Not recommended if reducing values.
**kwargs: Any argument accepted by wandb.init() (e.g., project, group, name, tags, notes, etc.)

Example:
**kwargs: Any argument accepted by wandb.init() (e.g., project, group, name, tags, notes,
entity, mode, dir, ...) is forwarded verbatim. Anything you place under
``metric_logging.wandb:`` in your YAML config that isn't ``logging_mode`` or
``per_rank_share_run`` lands in this **kwargs and reaches ``wandb.init`` unchanged.

YAML example:
metric_logging:
wandb:
logging_mode: global_reduce
project: my_project
group: exp_group
name: my_experiment
tags: ["rl", "v2"]
notes: "Testing new reward"
entity: my_team # forwarded to wandb.init
mode: offline # forwarded to wandb.init

Python example:
WandbBackend(
logging_mode=LoggingMode.PER_RANK_REDUCE,
per_rank_share_run=False,
project="my_project",
group="exp_group",
name="my_experiment",
tags=["rl", "v2"],
notes="Testing new reward"
notes="Testing new reward",
)
"""

Expand All @@ -929,6 +944,12 @@ def __init__(
self.run = None
self.process_name = None
self._tables: dict[str, "wandb.Table"] = {}
if kwargs:
logger.info(
"WandbBackend: forwarding %d extra kwarg(s) to wandb.init: %s",
len(kwargs),
sorted(kwargs.keys()),
)

async def init(
self,
Expand Down
33 changes: 33 additions & 0 deletions tests/unit_tests/observability/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,39 @@ def test_wandb_backend_creation(self):
metadata = backend.get_metadata_for_secondary_ranks()
assert metadata == {} # Should be empty when no run

def test_wandb_backend_logs_extra_kwargs(self, caplog):
"""Visibility check for #594: extra kwargs forwarded to wandb.init
should be surfaced at INFO so users see what the YAML actually sent."""
import logging as _logging

with caplog.at_level(_logging.INFO, logger="forge.observability.metrics"):
WandbBackend(
logging_mode=LoggingMode.GLOBAL_REDUCE,
project="p",
entity="my_team",
mode="offline",
)

forwarded_logs = [
r for r in caplog.records if "forwarding" in r.getMessage()
]
assert forwarded_logs, "expected an INFO log listing forwarded wandb kwargs"
msg = forwarded_logs[0].getMessage()
for expected in ("project", "entity", "mode"):
assert expected in msg, f"kwarg {expected!r} not surfaced in log: {msg}"

def test_wandb_backend_no_log_when_no_extra_kwargs(self, caplog):
"""No-op when only the named args are passed."""
import logging as _logging

with caplog.at_level(_logging.INFO, logger="forge.observability.metrics"):
WandbBackend(logging_mode=LoggingMode.GLOBAL_REDUCE)

forwarded_logs = [
r for r in caplog.records if "forwarding" in r.getMessage()
]
assert not forwarded_logs, "should not log when no extra kwargs provided"

@pytest.mark.asyncio
async def test_console_backend(self):
"""Test ConsoleBackend basic operations."""
Expand Down