@@ -77,18 +77,22 @@ def _module_from_path(path: str) -> str | None:
7777 return None
7878
7979
80- def _resolve_or_infer_config (argv : list [str ], ** kwargs ) -> tuple [str , list [str ]]:
80+ def _resolve_or_infer_config (argv : list [str ] | None = None , ** kwargs ) -> tuple [str , list [str ]]:
8181 """Resolves or infers config file path from module."""
82+ if argv is None :
83+ argv = ["" ]
8284 if len (argv ) >= 2 and argv [1 ].endswith (".yml" ):
8385 return resolve_config_path (argv [1 ]), argv [2 :]
84- module = _module_from_path (argv [0 ])
86+ module = _module_from_path (argv [0 ]) if len ( argv ) > 0 else None
8587 if module not in _CONFIG_FILE_MAPPING :
86- raise ValueError (
87- f"No config file provided and no default config found for module '{ module } '"
88+ config_path = os .path .join (MAXTEXT_CONFIGS_DIR , "base.yml" )
89+ logger .warning (
90+ "No config file provided and no default config found for module '%s', using base.yml" , module
8891 )
89- config_path = os .path .join (MAXTEXT_CONFIGS_DIR , _CONFIG_FILE_MAPPING [module ])
90- logger .warning ("No config file provided, using default config mapping: %s" , config_path )
91- remaining_argv = argv [1 :]
92+ else :
93+ config_path = os .path .join (MAXTEXT_CONFIGS_DIR , _CONFIG_FILE_MAPPING [module ])
94+ logger .warning ("No config file provided, using default config mapping: %s" , config_path )
95+ remaining_argv = argv [1 :] if len (argv ) > 1 else []
9296
9397 return config_path , remaining_argv
9498
@@ -299,14 +303,14 @@ def get_keys(self) -> dict[str, Any]:
299303 return self ._flat_config
300304
301305
302- def initialize (argv : list [str ], ** kwargs ) -> HyperParameters :
306+ def initialize (argv : list [str ] | None = None , ** kwargs ) -> HyperParameters :
303307 """Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides."""
304308 pydantic_config = initialize_pydantic (argv , ** kwargs )
305309 config = HyperParameters (pydantic_config )
306310 return config
307311
308312
309- def initialize_pydantic (argv : list [str ], ** kwargs ) -> MaxTextConfig :
313+ def initialize_pydantic (argv : list [str ] | None = None , ** kwargs ) -> MaxTextConfig :
310314 """Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides.
311315 Returns pydantic MaxTextConfig class whereas `initialize` returns the og `HyperParameters`
312316 """
@@ -446,3 +450,10 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
446450# Shim for backward compatibility with pyconfig_deprecated_test.py
447451validate_and_update_keys = pyconfig_deprecated .validate_and_update_keys
448452__all__ = ["initialize" , "initialize_pydantic" ]
453+
454+ class _CallablePyconfigModule (sys .modules [__name__ ].__class__ ):
455+ """Allows calling the module directly as mt.pyconfig()."""
456+ def __call__ (self , argv : list [str ] | None = None , ** kwargs ) -> HyperParameters :
457+ return initialize (argv , ** kwargs )
458+
459+ sys .modules [__name__ ].__class__ = _CallablePyconfigModule
0 commit comments