@@ -91,13 +91,24 @@ def close(self):
9191class WandbTracker (Tracker ):
9292 """Tracks experiments using Weights & Biases."""
9393
94+ @staticmethod
95+ def _default_wandb_dir () -> str :
96+ # specforge/tracker.py -> project root is one level up
97+ return os .path .normpath (os .path .join (os .path .dirname (__file__ ), ".." , "wandb" ))
98+
9499 @classmethod
95100 def validate_args (cls , parser , args ):
96101 if wandb is None :
97102 parser .error (
98103 "To use --report-to wandb, you must install wandb: 'pip install wandb'"
99104 )
100105
106+ if args .wandb_dir is None :
107+ args .wandb_dir = cls ._default_wandb_dir ()
108+
109+ if args .wandb_offline :
110+ return
111+
101112 if args .wandb_key is not None :
102113 return
103114
@@ -128,10 +139,24 @@ def validate_args(cls, parser, args):
128139 def __init__ (self , args , output_dir : str ):
129140 super ().__init__ (args , output_dir )
130141 if self .rank == 0 :
131- wandb .login (key = args .wandb_key )
132- wandb .init (
133- project = args .wandb_project , name = args .wandb_name , config = vars (args )
134- )
142+ if args .wandb_dir is None :
143+ args .wandb_dir = self ._default_wandb_dir ()
144+ os .makedirs (args .wandb_dir , exist_ok = True )
145+ if args .wandb_offline :
146+ os .environ ["WANDB_MODE" ] = "offline"
147+ os .environ ["WANDB_DIR" ] = args .wandb_dir
148+
149+ if not args .wandb_offline :
150+ wandb .login (key = args .wandb_key )
151+ init_kwargs = {
152+ "project" : args .wandb_project ,
153+ "name" : args .wandb_name ,
154+ "config" : vars (args ),
155+ "dir" : args .wandb_dir ,
156+ }
157+ if args .wandb_offline :
158+ init_kwargs ["mode" ] = "offline"
159+ wandb .init (** init_kwargs )
135160 self .is_initialized = True
136161
137162 def log (self , log_dict : Dict [str , Any ], step : Optional [int ] = None ):
0 commit comments