Skip to content

Commit c60111a

Browse files
committed
feat: add wandb offline mode and custom wandb directory support
1 parent 2c4893f commit c60111a

2 files changed

Lines changed: 42 additions & 4 deletions

File tree

specforge/args.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ class TrackerArgs:
1111
wandb_project: str = None
1212
wandb_name: str = None
1313
wandb_key: str = None
14+
wandb_offline: bool = False
15+
wandb_dir: str = None
1416
swanlab_project: str = None
1517
swanlab_name: str = None
1618
swanlab_key: str = None
@@ -33,6 +35,17 @@ def add_args(parser: argparse.ArgumentParser) -> None:
3335
parser.add_argument("--wandb-project", type=str, default=None)
3436
parser.add_argument("--wandb-name", type=str, default=None)
3537
parser.add_argument("--wandb-key", type=str, default=None, help="W&B API key.")
38+
parser.add_argument(
39+
"--wandb-offline",
40+
action="store_true",
41+
help="Enable W&B offline mode and store logs locally.",
42+
)
43+
parser.add_argument(
44+
"--wandb-dir",
45+
type=str,
46+
default=None,
47+
help="Directory to store W&B files. Defaults to './wandb' under the project root when using W&B.",
48+
)
3649
# swanlab-specific args
3750
parser.add_argument(
3851
"--swanlab-project",

specforge/tracker.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,24 @@ def close(self):
9191
class WandbTracker(Tracker):
9292
"""Tracks experiments using Weights & Biases."""
9393

94+
@staticmethod
95+
def _default_wandb_dir() -> str:
96+
# specforge/tracker.py -> project root is one level up
97+
return os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "wandb"))
98+
9499
@classmethod
95100
def validate_args(cls, parser, args):
96101
if wandb is None:
97102
parser.error(
98103
"To use --report-to wandb, you must install wandb: 'pip install wandb'"
99104
)
100105

106+
if args.wandb_dir is None:
107+
args.wandb_dir = cls._default_wandb_dir()
108+
109+
if args.wandb_offline:
110+
return
111+
101112
if args.wandb_key is not None:
102113
return
103114

@@ -128,10 +139,24 @@ def validate_args(cls, parser, args):
128139
def __init__(self, args, output_dir: str):
129140
super().__init__(args, output_dir)
130141
if self.rank == 0:
131-
wandb.login(key=args.wandb_key)
132-
wandb.init(
133-
project=args.wandb_project, name=args.wandb_name, config=vars(args)
134-
)
142+
if args.wandb_dir is None:
143+
args.wandb_dir = self._default_wandb_dir()
144+
os.makedirs(args.wandb_dir, exist_ok=True)
145+
if args.wandb_offline:
146+
os.environ["WANDB_MODE"] = "offline"
147+
os.environ["WANDB_DIR"] = args.wandb_dir
148+
149+
if not args.wandb_offline:
150+
wandb.login(key=args.wandb_key)
151+
init_kwargs = {
152+
"project": args.wandb_project,
153+
"name": args.wandb_name,
154+
"config": vars(args),
155+
"dir": args.wandb_dir,
156+
}
157+
if args.wandb_offline:
158+
init_kwargs["mode"] = "offline"
159+
wandb.init(**init_kwargs)
135160
self.is_initialized = True
136161

137162
def log(self, log_dict: Dict[str, Any], step: Optional[int] = None):

0 commit comments

Comments
 (0)