diff --git a/src/maxtext/configs/pyconfig.py b/src/maxtext/configs/pyconfig.py index 0dfb76f29c..fb3235379a 100644 --- a/src/maxtext/configs/pyconfig.py +++ b/src/maxtext/configs/pyconfig.py @@ -503,6 +503,14 @@ def _initialize_pydantic(argv: list[str] | None = None, **kwargs) -> MaxTextConf pydantic_kwargs = _prepare_for_pydantic(raw_keys_dict) + # Resolve relative tokenizer_path against the config directory (fileset root on Borg) + if pydantic_kwargs.get("tokenizer_path"): + fileset_root = os.path.dirname(config_path) + candidate_path = os.path.join(fileset_root, pydantic_kwargs["tokenizer_path"]) + if os.path.exists(candidate_path): + logger.info("Resolved tokenizer_path %s to %s under fileset root", pydantic_kwargs["tokenizer_path"], candidate_path) + pydantic_kwargs["tokenizer_path"] = candidate_path + if pydantic_kwargs.get("use_tokamax_splash") and pydantic_kwargs.get("use_jax_splash"): raise ValueError("At most one of `use_tokamax_splash` and `use_jax_splash` can be set to True.") diff --git a/src/maxtext/inference/decode.py b/src/maxtext/inference/decode.py index bc60d51932..32e2d2a22a 100644 --- a/src/maxtext/inference/decode.py +++ b/src/maxtext/inference/decode.py @@ -176,7 +176,7 @@ def main(argv: Sequence[str]) -> None: # Prefill rng, rng_prefill = jax.random.split(rng) # Split RNG before calling prefill for i in range(_NUM_STREAMS): - with jax.profiler.StepTraceAnnotation("prefill", stream=i): + with jax.profiler.StepTraceAnnotation("prefill", step_num=i): prefill_result, first_token = engine.prefill( params=params, padded_tokens=tokens, @@ -206,7 +206,7 @@ def main(argv: Sequence[str]) -> None: sampled_tokens_list.append(_batch_first_result_token(first_token_list, batch_size)) for i in steps: rng, rng_generate = jax.random.split(rng) - with jax.profiler.StepTraceAnnotation("generate", step=i): + with jax.profiler.StepTraceAnnotation("generate", step_num=i): decode_state, sampled_tokens = engine.generate(params, decode_state, rng=rng_generate) # Automatically deactivate profiler after profiler_steps steps