Commit 496ed40

Merge pull request #3129 from AI-Hypercomputer:nicogrande/duplicate-kv-cache
PiperOrigin-RevId: 869738345
2 parents: 1238d84 + 37d547d

6 files changed: 10 additions & 13 deletions

src/MaxText/__init__.py (0 additions & 1 deletion)

@@ -34,7 +34,6 @@
 from maxtext.trainers.post_train.dpo import dpo_utils
 from maxtext.utils import maxtext_utils
 from maxtext.utils import model_creation_utils
-from maxtext.utils.model_creation_utils import from_config
 
 Transformer = models.Transformer
 transformer_as_linen = models.transformer_as_linen

src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py (4 additions & 4 deletions)

@@ -23,7 +23,7 @@
 from jax.sharding import Mesh
 from MaxText import pyconfig
 from MaxText.common_types import MODEL_MODE_AUTOREGRESSIVE
-from MaxText.globals import MAXTEXT_PKG_DIR
+from MaxText.globals import MAXTEXT_CONFIGS_DIR
 from maxtext.utils import max_logging
 from maxtext.utils import model_creation_utils
 
@@ -73,7 +73,7 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters
     raise ValueError("hf_config_path must be provided when using MaxTextForCausalLM.")
 
   # Add base config path to positional args
-  base_config_path = os.path.join(MAXTEXT_PKG_DIR, "configs", "vllm.yml")
+  base_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
   argv_list = ["", str(base_config_path)]
 
   maxtext_config = pyconfig.initialize(argv_list, **overrides)
@@ -151,7 +151,7 @@ def __call__(
 
     with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
       aux_hidden_states = []
-      hidden, updated_kv_caches = self.model(
+      hidden, kv_caches = self.model(
         decoder_input_tokens=input_ids,
         decoder_positions=input_positions,
         kv_caches=kv_caches,
@@ -163,7 +163,7 @@ def __call__(
     # To be compatible with vLLM, we reshape to (batch * seq, dim).
     hidden = hidden.reshape((-1, hidden.shape[-1]))
 
-    return updated_kv_caches, hidden, aux_hidden_states
+    return kv_caches, hidden, aux_hidden_states
 
   def forward(self, *args, **kwargs):
     """Alias for __call__ for compatibility.

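The `__call__` hunks rebind the model's output to the same `kv_caches` name instead of introducing `updated_kv_caches`; given the branch name `duplicate-kv-cache`, the point is presumably to avoid keeping the stale and the updated cache alive at once. A minimal sketch of the underlying JAX pattern, using hypothetical shapes and a hypothetical `append_kv` helper rather than MaxText's actual API:

```python
from functools import partial

import jax
import jax.numpy as jnp


# Donating the cache argument lets XLA alias the input and output
# buffers, so the update need not materialize a second full-size cache.
@partial(jax.jit, donate_argnums=(0,))
def append_kv(kv_cache, new_kv, pos):
  # Write one decode step of keys/values into the sequence axis at `pos`.
  return jax.lax.dynamic_update_index_in_dim(kv_cache, new_kv, pos, axis=1)


cache = jnp.zeros((2, 128, 8, 64))  # (batch, seq, kv_heads, head_dim)
step = jnp.ones((2, 1, 8, 64))
cache = append_kv(cache, step, 5)   # rebind: no stale reference survives
```

Even without donation the rebinding matters: holding both names keeps two full caches reachable until the old array is garbage-collected.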
src/MaxText/rl/train_rl.py (2 additions & 4 deletions)

@@ -73,7 +73,7 @@
 os.environ["SKIP_JAX_PRECOMPILE"] = "1"
 
 from MaxText import pyconfig
-from MaxText.globals import MAXTEXT_PKG_DIR
+from MaxText.globals import MAXTEXT_CONFIGS_DIR
 from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
 from MaxText.rl.evaluate_rl import evaluate
 from MaxText.rl import utils_rl
@@ -370,7 +370,6 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
   max_logging.log("Creating policy model with same config as reference model on trainer mesh")
   actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)
 
-
   if trainer_config.debug.rl:
     max_logging.log("Policy Model initialized successfully")
     nnx.display(actor_model)
@@ -495,8 +494,7 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
       "enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics."
     )
 
-  configs_dir = os.environ.get("MAXTEXT_CONFIGS_DIR", os.path.join(MAXTEXT_PKG_DIR, "configs"))
-  vllm_config_path = epath.Path(configs_dir) / "vllm.yml"
+  vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
   argv_list = ["", str(vllm_config_path), "log_config=False"]
   vllm_config = pyconfig.initialize(argv_list)
 
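Across this file, `adapter.py` above, and `vllm_decode.py` below, per-call-site `MAXTEXT_PKG_DIR` fallbacks (and, here, a `MAXTEXT_CONFIGS_DIR` environment-variable lookup) are replaced by a single `MAXTEXT_CONFIGS_DIR` constant imported from `MaxText.globals`, and `vllm.yml` moves under an `inference/` subdirectory. The constant's definition is not part of this diff; a plausible sketch, assuming it keeps the old environment-variable override:

```python
# MaxText/globals.py (hypothetical sketch, not shown in this commit):
# one overridable constant replaces per-call-site env-var fallbacks.
import os

MAXTEXT_PKG_DIR = os.path.dirname(os.path.abspath(__file__))
MAXTEXT_CONFIGS_DIR = os.environ.get(
    "MAXTEXT_CONFIGS_DIR", os.path.join(MAXTEXT_PKG_DIR, "configs")
)
```

Every caller then builds the path the same way: `os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")`.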

src/maxtext/vllm_decode.py (2 additions & 2 deletions)

@@ -47,7 +47,7 @@
 from maxtext.utils import model_creation_utils
 from MaxText import pyconfig
 from MaxText.common_types import Config
-from MaxText.globals import MAXTEXT_PKG_DIR
+from MaxText.globals import MAXTEXT_CONFIGS_DIR
 from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
 from tunix.rl.rollout import base_rollout
 from tunix.rl.rollout.vllm_rollout import VllmRollout
@@ -185,7 +185,7 @@ def decode_with_vllm(
     f"and EP={ici_expert_parallelism if enable_expert_parallel else 0}..."
   )
 
-  vllm_config_path = os.path.join(MAXTEXT_PKG_DIR, "configs", "vllm.yml")
+  vllm_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "inference", "vllm.yml")
   argv_list = ["", str(vllm_config_path), "log_config=False"]
 vllm_config = pyconfig.initialize(argv_list)

tools/gcs_benchmarks/standalone_checkpointer.py (2 additions & 2 deletions)

@@ -31,14 +31,14 @@
 
 from flax.linen import partitioning as nn_partitioning
 
-import MaxText as mt
 from MaxText import pyconfig
 from MaxText.train import get_first_step
 from MaxText.layers import models
 from maxtext.common import checkpointing
 from maxtext.utils import max_logging
 from maxtext.utils import maxtext_utils
 from maxtext.utils import train_utils
+from maxtext.utils.model_creation_utils import from_config
 
 Transformer = models.transformer_as_linen
 
@@ -52,7 +52,7 @@ def checkpoint_loop(config, state=None):
     ckpt_path:
   Returns:
   """
-  model = mt.from_config(config)
+  model = from_config(config)
   mesh = model.mesh
   init_rng, checkpoint_manager, _, tx = train_utils.create_training_tools(config, model, mesh)
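This pairs with the `src/MaxText/__init__.py` hunk at the top of the diff: with the package-level `from_config` re-export gone, the benchmark imports the helper from its defining module. The before/after call site, as the two hunks imply:

```python
# Before: via the top-level re-export (removed in src/MaxText/__init__.py)
import MaxText as mt
model = mt.from_config(config)

# After: import from_config from its defining module
from maxtext.utils.model_creation_utils import from_config
model = from_config(config)
```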
