2323from jax .sharding import Mesh
2424from MaxText import pyconfig
2525from MaxText .common_types import MODEL_MODE_AUTOREGRESSIVE
26- from MaxText .globals import MAXTEXT_PKG_DIR
26+ from MaxText .globals import MAXTEXT_CONFIGS_DIR
2727from maxtext .utils import max_logging
2828from maxtext .utils import model_creation_utils
2929
@@ -73,7 +73,7 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters
7373 raise ValueError ("hf_config_path must be provided when using MaxTextForCausalLM." )
7474
7575 # Add base config path to positional args
76- base_config_path = os .path .join (MAXTEXT_PKG_DIR , "configs " , "vllm.yml" )
76+ base_config_path = os .path .join (MAXTEXT_CONFIGS_DIR , "inference " , "vllm.yml" )
7777 argv_list = ["" , str (base_config_path )]
7878
7979 maxtext_config = pyconfig .initialize (argv_list , ** overrides )
@@ -151,7 +151,7 @@ def __call__(
151151
152152 with self .mesh , nn .logical_axis_rules (self .maxtext_config .logical_axis_rules ):
153153 aux_hidden_states = []
154- hidden , updated_kv_caches = self .model (
154+ hidden , kv_caches = self .model (
155155 decoder_input_tokens = input_ids ,
156156 decoder_positions = input_positions ,
157157 kv_caches = kv_caches ,
@@ -163,7 +163,7 @@ def __call__(
163163 # To be compatible with vLLM, we reshape to (batch * seq, dim).
164164 hidden = hidden .reshape ((- 1 , hidden .shape [- 1 ]))
165165
166- return updated_kv_caches , hidden , aux_hidden_states
166+ return kv_caches , hidden , aux_hidden_states
167167
168168 def forward (self , * args , ** kwargs ):
169169 """Alias for __call__ for compatibility.
0 commit comments