@@ -77,16 +77,48 @@ def _module_from_path(path: str) -> str | None:
7777 return None
7878
7979
80- def _resolve_or_infer_config (argv : list [str ]) -> tuple [str , list [str ]]:
80+ def _resolve_or_infer_config (argv : list [str ] | None = None , ** kwargs ) -> tuple [str , list [str ]]:
8181 """Resolves or infers config file path from module."""
82+ if argv is None :
83+ argv = ["" ]
84+
85+ if kwargs .get ("base_config" ):
86+ logger .info ("Using config : %s" , kwargs ["base_config" ])
87+ return resolve_config_path (kwargs ["base_config" ]), argv [1 :]
88+
89+ # if passing at least two arguments via list (no kwargs), then we have to specify
90+ # first one as either "" or python script like train_rl.py or train.py
91+ # the second argument is the yaml file
8292 if len (argv ) >= 2 and argv [1 ].endswith (".yml" ):
8393 return resolve_config_path (argv [1 ]), argv [2 :]
84- module = _module_from_path (argv [0 ])
94+ module = _module_from_path (argv [0 ]) if len ( argv ) > 0 else None
8595 if module not in _CONFIG_FILE_MAPPING :
86- raise ValueError (f"No config file provided and no default config found for module '{ module } '" )
87- config_path = os .path .join (MAXTEXT_CONFIGS_DIR , _CONFIG_FILE_MAPPING [module ])
88- logger .warning ("No config file provided, using default config mapping: %s" , config_path )
89- return config_path , argv [1 :]
96+ config_path = os .path .join (MAXTEXT_CONFIGS_DIR , "base.yml" )
97+ logger .warning ("No config file provided and no default config found for module '%s', using base.yml" , module )
98+ else :
99+ config_path = os .path .join (MAXTEXT_CONFIGS_DIR , _CONFIG_FILE_MAPPING [module ])
100+ logger .warning ("No config file provided, using default config mapping: %s" , config_path )
101+ remaining_argv = argv [1 :]
102+
103+ return config_path , remaining_argv
104+
105+
106+ def _resolve_or_infer_addl_config (** kwargs ):
107+ """Resolves or infers more configs from module."""
108+ inferred_kwargs = {}
109+ # if base_output_directory key is not seen
110+ if not kwargs .get ("base_output_directory" ):
111+ max_logging .warning ("base_output_directory is not provided; Using local directory called maxtext_output" )
112+ base_output_directory = os .path .abspath ("maxtext_output" )
113+ inferred_kwargs ["base_output_directory" ] = base_output_directory
114+
115+ # if hf_access_token key is not seen
116+ if not kwargs .get ("hf_access_token" ):
117+ hf_access_token = os .environ .get ("HF_TOKEN" )
118+ if hf_access_token :
119+ inferred_kwargs ["hf_access_token" ] = hf_access_token
120+
121+ return inferred_kwargs
90122
91123
92124def yaml_key_to_env_key (s : str ) -> str :
@@ -289,28 +321,35 @@ def get_keys(self) -> dict[str, Any]:
289321 return self ._flat_config
290322
291323
292- def initialize (argv : list [str ], ** kwargs ) -> HyperParameters :
324+ def initialize (argv : list [str ] | None = None , ** kwargs ) -> HyperParameters :
293325 """Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides."""
294326 pydantic_config = initialize_pydantic (argv , ** kwargs )
295327 config = HyperParameters (pydantic_config )
296328 return config
297329
298330
299- def initialize_pydantic (argv : list [str ], ** kwargs ) -> MaxTextConfig :
331+ def initialize_pydantic (argv : list [str ] | None = None , ** kwargs ) -> MaxTextConfig :
300332 """Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides.
301333 Returns pydantic MaxTextConfig class whereas `initialize` returns the og `HyperParameters`
302334 """
303335 # 1. Load base and inherited configs from file(s)
304- config_path , cli_args = _resolve_or_infer_config (argv )
336+ config_path , cli_args = _resolve_or_infer_config (argv , ** kwargs )
305337 base_yml_config = _load_config (config_path )
306338
307339 # 2. Get overrides from CLI and kwargs
308340 cli_cfg = omegaconf .OmegaConf .from_cli (cli_args )
309341 kwargs_cfg = omegaconf .OmegaConf .create (kwargs )
310342 overrides_cfg = omegaconf .OmegaConf .merge (cli_cfg , kwargs_cfg )
311343
312- # 3. Handle model-specific config
344+ temp_cfg1 = omegaconf .OmegaConf .merge (base_yml_config , overrides_cfg )
345+ # 3.1. infer more configs if possible
346+ temp_cfg1 = _resolve_or_infer_addl_config (** temp_cfg1 )
347+ # update overrides_cfg with temp_cfg1
348+ overrides_cfg = omegaconf .OmegaConf .merge (overrides_cfg , temp_cfg1 )
313349 temp_cfg = omegaconf .OmegaConf .merge (base_yml_config , overrides_cfg )
350+
351+ # 3.2. Handle model-specific config
352+
314353 model_name = temp_cfg .get ("model_name" , "default" )
315354 # The architecture for -Instruct v/s base models are the same, so for identifying the
316355 # architecture we replace "-Instruct" from the model_name and get the base model name
@@ -437,3 +476,13 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
437476# Shim for backward compatibility with pyconfig_deprecated_test.py
438477validate_and_update_keys = pyconfig_deprecated .validate_and_update_keys
439478__all__ = ["initialize" , "initialize_pydantic" ]
479+
480+
481+ class _CallablePyconfigModule (sys .modules [__name__ ].__class__ ):
482+ """Allows calling the module directly as mt.pyconfig()."""
483+
484+ def __call__ (self , argv : list [str ] | None = None , ** kwargs ) -> HyperParameters :
485+ return initialize (argv , ** kwargs )
486+
487+
488+ sys .modules [__name__ ].__class__ = _CallablePyconfigModule
0 commit comments