Skip to content

Commit a0a8092

Browse files
committed
simplify invocations
1 parent 1a51440 commit a0a8092

2 files changed

Lines changed: 35 additions & 17 deletions

File tree

src/maxtext/configs/pyconfig.py

Lines changed: 20 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -77,18 +77,22 @@ def _module_from_path(path: str) -> str | None:
7777
return None
7878

7979

80-
def _resolve_or_infer_config(argv: list[str], **kwargs) -> tuple[str, list[str]]:
80+
def _resolve_or_infer_config(argv: list[str] | None = None, **kwargs) -> tuple[str, list[str]]:
8181
"""Resolves or infers config file path from module."""
82+
if argv is None:
83+
argv = [""]
8284
if len(argv) >= 2 and argv[1].endswith(".yml"):
8385
return resolve_config_path(argv[1]), argv[2:]
84-
module = _module_from_path(argv[0])
86+
module = _module_from_path(argv[0]) if len(argv) > 0 else None
8587
if module not in _CONFIG_FILE_MAPPING:
86-
raise ValueError(
87-
f"No config file provided and no default config found for module '{module}'"
88+
config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml")
89+
logger.warning(
90+
"No config file provided and no default config found for module '%s', using base.yml", module
8891
)
89-
config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
90-
logger.warning("No config file provided, using default config mapping: %s", config_path)
91-
remaining_argv = argv[1:]
92+
else:
93+
config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
94+
logger.warning("No config file provided, using default config mapping: %s", config_path)
95+
remaining_argv = argv[1:] if len(argv) > 1 else []
9296

9397
return config_path, remaining_argv
9498

@@ -299,14 +303,14 @@ def get_keys(self) -> dict[str, Any]:
299303
return self._flat_config
300304

301305

302-
def initialize(argv: list[str], **kwargs) -> HyperParameters:
306+
def initialize(argv: list[str] | None = None, **kwargs) -> HyperParameters:
303307
"""Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides."""
304308
pydantic_config = initialize_pydantic(argv, **kwargs)
305309
config = HyperParameters(pydantic_config)
306310
return config
307311

308312

309-
def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
313+
def initialize_pydantic(argv: list[str] | None = None, **kwargs) -> MaxTextConfig:
310314
"""Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides.
311315
Returns pydantic MaxTextConfig class whereas `initialize` returns the og `HyperParameters`
312316
"""
@@ -446,3 +450,10 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
446450
# Shim for backward compatibility with pyconfig_deprecated_test.py
447451
validate_and_update_keys = pyconfig_deprecated.validate_and_update_keys
448452
__all__ = ["initialize", "initialize_pydantic"]
453+
454+
class _CallablePyconfigModule(sys.modules[__name__].__class__):
455+
"""Allows calling the module directly as mt.pyconfig()."""
456+
def __call__(self, argv: list[str] | None = None, **kwargs) -> HyperParameters:
457+
return initialize(argv, **kwargs)
458+
459+
sys.modules[__name__].__class__ = _CallablePyconfigModule

src/maxtext/utils/model_creation_utils.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,14 @@ def create_model(config, mesh, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rng
115115
return model
116116

117117

118-
def setup_configs_and_devices(argv: list[str], kwargs):
118+
def setup_configs_and_devices(argv: list[str] | None = None, kwargs: dict | None = None, **extra_kwargs):
119119
"""Setup device allocation and configs for training and inference."""
120-
config = pyconfig.initialize_pydantic(argv, **kwargs)
120+
if argv is None:
121+
argv = [""]
122+
123+
combined_kwargs = dict(kwargs) if kwargs else {}
124+
combined_kwargs.update(extra_kwargs)
125+
config = pyconfig.initialize_pydantic(argv, **combined_kwargs)
121126
devices = jax.devices()
122127
if config.num_trainer_slices == -1 and config.num_samplers_slices == -1:
123128
max_logging.log("Running on a single slice")
@@ -172,22 +177,24 @@ def setup_configs_and_devices(argv: list[str], kwargs):
172177
)
173178
trainer_fsdp = trainer_devices_per_slice // tp
174179

175-
trainer_update = {
180+
trainer_kwargs = dict(combined_kwargs)
181+
trainer_kwargs.update({
176182
"num_slices": config.num_trainer_slices,
177183
"ici_fsdp_parallelism": trainer_fsdp,
178184
"ici_tensor_parallelism": tp,
179185
"dcn_data_parallelism": config.num_trainer_slices,
180-
}
186+
})
181187

182-
sampler_update = {
188+
sampler_kwargs = dict(combined_kwargs)
189+
sampler_kwargs.update({
183190
"num_slices": config.num_samplers_slices,
184191
"ici_fsdp_parallelism": len(sampler_devices) // config.num_samplers_slices,
185192
"ici_tensor_parallelism": -1,
186193
"dcn_data_parallelism": config.num_samplers_slices,
187-
}
194+
})
188195

189-
trainer_config = pyconfig.initialize_pydantic(argv, **trainer_update)
190-
sampler_config = pyconfig.initialize_pydantic(argv, **sampler_update)
196+
trainer_config = pyconfig.initialize_pydantic(argv, **trainer_kwargs)
197+
sampler_config = pyconfig.initialize_pydantic(argv, **sampler_kwargs)
191198

192199
else:
193200
raise ValueError("num_trainer_slices and num_samplers_slices should be both -1 or positive")

0 commit comments

Comments
 (0)