
Commit 4eca564 ("working for llama")
Parent: 6176367

9 files changed: 126 additions & 17 deletions

src/maxtext/inference/vllm_decode.py

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,8 @@
 from vllm import LLM
 from vllm.sampling_params import SamplingParams
 from maxtext.configs import pyconfig
+import maxtext.integration.vllm.maxtext_vllm_adapter as adapter
+adapter.register()

 os.environ["SKIP_JAX_PRECOMPILE"] = "1"
 os.environ["NEW_MODEL_DESIGN"] = "1"
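
These two lines hook MaxText into vLLM before the `LLM` engine is constructed. The adapter module itself is not part of this commit, so the following is only a hypothetical sketch of what its `register()` could do, built on vLLM's public out-of-tree model hook `ModelRegistry.register_model`; the architecture string and the `MaxTextForCausalLM` class are placeholders, not MaxText's actual API.

# Hypothetical sketch of maxtext_vllm_adapter.register(); names are assumptions.
from vllm import ModelRegistry

class MaxTextForCausalLM:  # placeholder: the real adapter class is assumed
  ...

def register():
  # Map the HF `architectures` entry to the MaxText-backed implementation so
  # that vLLM resolves it at model-load time.
  ModelRegistry.register_model("MaxTextForCausalLM", MaxTextForCausalLM)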

src/maxtext/layers/decoders.py

Lines changed: 7 additions & 3 deletions
@@ -986,9 +986,6 @@ def __call__(
 current_broadcast_args = list(broadcast_args)
 current_in_axes_tuple = list(in_axes_tuple)

-current_broadcast_args.append(attention_metadata)
-current_in_axes_tuple.append(nn.broadcast)
-
 if kv_caches is not None:
   # Stack kv_caches for scan: [num_layers, ...]
   stacked_kv_cache = jnp.stack(kv_caches, axis=0)

@@ -998,6 +995,13 @@ def __call__(

 # We don't pass kv_cache as a scanned argument anymore

+# Pass None for previous_chunk, slot, page_state, kv_cache to align with __call__ signature
+current_broadcast_args.extend([None, None, None, None, attention_metadata])
+current_in_axes_tuple.extend([nn.broadcast] * 5)
+
+max_logging.info(f"DEBUG: len(current_broadcast_args)={len(current_broadcast_args)}")
+max_logging.info(f"DEBUG: current_broadcast_args={[type(a) for a in current_broadcast_args]}")
+
 final_carry, _ = self.scan_decoder_layers(
     cfg,
     RemattedBlockLayer,
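
The change replaces the single `attention_metadata` broadcast argument with five entries, so the broadcast list matches the layer `__call__` signature positionally (the two `DEBUG` log lines are temporary instrumentation). In `flax.linen.scan`, each `in_axes` entry maps to one positional argument after the carry, and `nn.broadcast` hands the same value to every scanned layer, so the two lists must stay in lockstep. A minimal sketch of that contract, with a simplified stand-in layer rather than MaxText's real decoder layer:

# Toy illustration of the broadcast-argument contract in flax.linen.scan.
import jax
import jax.numpy as jnp
import flax.linen as nn

class ToyLayer(nn.Module):
  @nn.compact
  def __call__(self, carry, previous_chunk, slot, page_state, kv_cache, attention_metadata):
    # Broadcast args arrive unchanged at every step; unused slots are None.
    new_carry = carry + nn.Dense(carry.shape[-1])(carry)
    return new_carry, None

class ToyStack(nn.Module):
  num_layers: int = 4

  @nn.compact
  def __call__(self, x, attention_metadata):
    scanned = nn.scan(
        ToyLayer,
        variable_axes={"params": 0},   # stack each layer's params on axis 0
        split_rngs={"params": True},
        in_axes=(nn.broadcast,) * 5,   # one entry per non-carry argument
        length=self.num_layers,
    )
    # Positions must line up with ToyLayer.__call__; None fills unused slots.
    y, _ = scanned()(x, None, None, None, None, attention_metadata)
    return y

out, variables = ToyStack().init_with_output(jax.random.PRNGKey(0), jnp.ones((2, 8)), None)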

src/maxtext/models/gemma3.py

Lines changed: 16 additions & 2 deletions
@@ -194,7 +194,13 @@ def __call__(
 ):
   cfg = self.config
   # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
-  if isinstance(inputs, tuple):
+  is_scan_carry = False
+  if isinstance(inputs, tuple) and len(inputs) == 3:
+    hidden_states, stacked_kv_cache, layer_idx = inputs
+    kv_cache = stacked_kv_cache[layer_idx]
+    inputs = hidden_states
+    is_scan_carry = True
+  elif isinstance(inputs, tuple):
     inputs = inputs[0]
   inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
   inputs = checkpoint_name(inputs, "decoder_layer_input")

@@ -244,7 +250,15 @@ def __call__(
       jnp.sum(layer_output == 0) / jnp.size(layer_output),
   )

-  if cfg.scan_layers:
+  if is_scan_carry:
+    def update_cache(cache, val):
+      if jnp.size(val) > 0:
+        return cache.at[layer_idx].set(val)
+      return cache
+
+    stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache)
+    return (layer_output, stacked_kv_cache, layer_idx + 1), None
+  elif cfg.scan_layers:
     return layer_output, None
   else:
     return layer_output, kv_cache
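
The same pattern recurs in every model file below: when the layer is driven as a scan body, the carry is the triple (hidden_states, stacked_kv_cache, layer_idx); each step reads its slice of the stacked cache, runs the layer, writes the updated slice back functionally, and increments the index. A self-contained sketch of that carry threading with `jax.lax.scan`; the shapes and the "attention" body are illustrative assumptions, not the real Gemma3 layer (the commit applies the update across a cache pytree via `jax.tree_util.tree_map`, shown separately after the llama2 diff):

# Minimal sketch of the scan-carry KV-cache pattern added in this commit.
import jax
import jax.numpy as jnp

num_layers, batch, seq, embed = 4, 2, 8, 16

def layer_step(carry, _):
  hidden, stacked_kv, layer_idx = carry
  kv = stacked_kv[layer_idx]             # this layer's cache slice
  hidden = hidden + jnp.tanh(hidden)     # stand-in for attention + MLP
  new_kv = kv + hidden.mean()            # stand-in for the refreshed cache
  # Functional write-back of the slice, as in update_cache above.
  stacked_kv = stacked_kv.at[layer_idx].set(new_kv)
  return (hidden, stacked_kv, layer_idx + 1), None

init_carry = (
    jnp.ones((batch, seq, embed)),
    jnp.zeros((num_layers, batch, seq, embed)),  # stacked KV cache
    jnp.array(0, dtype=jnp.int32),               # layer_idx
)
(final_hidden, final_kv, _), _ = jax.lax.scan(layer_step, init_carry, None, length=num_layers)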

src/maxtext/models/gemma4.py

Lines changed: 16 additions & 2 deletions
@@ -322,7 +322,13 @@ def __call__(
 ):
   cfg = self.config
   # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
-  if isinstance(inputs, tuple):
+  is_scan_carry = False
+  if isinstance(inputs, tuple) and len(inputs) == 3:
+    hidden_states, stacked_kv_cache, layer_idx = inputs
+    kv_cache = stacked_kv_cache[layer_idx]
+    inputs = hidden_states
+    is_scan_carry = True
+  elif isinstance(inputs, tuple):
     inputs = inputs[0]
   inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
   inputs = checkpoint_name(inputs, "decoder_layer_input")

@@ -383,7 +389,15 @@ def __call__(
       jnp.sum(layer_output == 0) / jnp.size(layer_output),
   )

-  if cfg.scan_layers:
+  if is_scan_carry:
+    def update_cache(cache, val):
+      if jnp.size(val) > 0:
+        return cache.at[layer_idx].set(val)
+      return cache
+
+    stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache)
+    return (layer_output, stacked_kv_cache, layer_idx + 1), None
+  elif cfg.scan_layers:
     return layer_output, None
   else:
     return layer_output, kv_cache

src/maxtext/models/gpt_oss.py

Lines changed: 17 additions & 2 deletions
@@ -23,6 +23,7 @@
 from flax import linen as nn
 from flax import nnx
 from jax.ad_checkpoint import checkpoint_name
+import jax
 import jax.numpy as jnp
 from jax.sharding import Mesh
 from maxtext.common.common_types import AttentionType, Config

@@ -146,7 +147,13 @@ def __call__(
 ):
   cfg = self.config
   # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
-  if isinstance(inputs, tuple):
+  is_scan_carry = False
+  if isinstance(inputs, tuple) and len(inputs) == 3:
+    hidden_states, stacked_kv_cache, layer_idx = inputs
+    kv_cache = stacked_kv_cache[layer_idx]
+    inputs = hidden_states
+    is_scan_carry = True
+  elif isinstance(inputs, tuple):
     inputs = inputs[0]

   inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))

@@ -201,7 +208,15 @@ def __call__(
       jnp.sum(layer_output == 0) / jnp.size(layer_output),
   )

-  if cfg.scan_layers:
+  if is_scan_carry:
+    def update_cache(cache, val):
+      if jnp.size(val) > 0:
+        return cache.at[layer_idx].set(val)
+      return cache
+
+    stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache)
+    return (layer_output, stacked_kv_cache, layer_idx + 1), None
+  elif cfg.scan_layers:
     return layer_output, None
   else:
     return layer_output, kv_cache

src/maxtext/models/llama2.py

Lines changed: 17 additions & 2 deletions
@@ -19,6 +19,7 @@
 import functools
 from flax import nnx
 from jax.ad_checkpoint import checkpoint_name
+import jax
 import jax.numpy as jnp
 from jax.sharding import Mesh
 from maxtext.common.common_types import Config

@@ -152,7 +153,13 @@ def __call__(
   cfg = self.config

   # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
-  if isinstance(inputs, tuple):
+  is_scan_carry = False
+  if isinstance(inputs, tuple) and len(inputs) == 3:
+    hidden_states, stacked_kv_cache, layer_idx = inputs
+    kv_cache = stacked_kv_cache[layer_idx]
+    inputs = hidden_states
+    is_scan_carry = True
+  elif isinstance(inputs, tuple):
     inputs = inputs[0]
   inputs = self._maybe_shard_with_logical(inputs, self.activation_axis_names)
   inputs = checkpoint_name(inputs, "decoder_layer_input")

@@ -206,7 +213,15 @@ def __call__(
       jnp.sum(layer_output == 0) / jnp.size(layer_output),
   )

-  if cfg.scan_layers:
+  if is_scan_carry:
+    def update_cache(cache, val):
+      if jnp.size(val) > 0:
+        return cache.at[layer_idx].set(val)
+      return cache
+
+    stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache)
+    return (layer_output, stacked_kv_cache, layer_idx + 1), None
+  elif cfg.scan_layers:
     return layer_output, None
   else:
     return layer_output, kv_cache
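
The `update_cache` helper repeated in these files guards against empty leaves: `jnp.size(val) > 0` is a static, shape-based check, so zero-element leaves pass through untouched while every other leaf gets a functional `.at[layer_idx].set(...)` update. A small sketch with an assumed pytree layout; the real cache structure is model-specific:

# Sketch of the empty-leaf guard over a pytree cache; layout is an assumption.
import jax
import jax.numpy as jnp

layer_idx = 1
stacked_cache = {                      # leaves lead with a num_layers axis
    "key": jnp.zeros((4, 2, 8, 16)),
    "value": jnp.zeros((4, 2, 8, 16)),
    "meta": jnp.zeros((4, 0)),         # empty leaf: left untouched
}
layer_cache = {
    "key": jnp.ones((2, 8, 16)),
    "value": jnp.ones((2, 8, 16)),
    "meta": jnp.zeros((0,)),
}

def update_cache(cache, val):
  if jnp.size(val) > 0:
    return cache.at[layer_idx].set(val)
  return cache

stacked_cache = jax.tree_util.tree_map(update_cache, stacked_cache, layer_cache)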

src/maxtext/models/llama4.py

Lines changed: 17 additions & 2 deletions
@@ -19,6 +19,7 @@

 from flax import linen as nn
 from flax import nnx
+import jax
 from jax import lax
 from jax.ad_checkpoint import checkpoint_name
 import jax.numpy as jnp

@@ -452,7 +453,13 @@ def __call__(
   assert cfg.num_experts >= 1, "Expected the Llama4 config to have `num_experts > 1`."

   # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
-  if isinstance(inputs, tuple):
+  is_scan_carry = False
+  if isinstance(inputs, tuple) and len(inputs) == 3:
+    hidden_states, stacked_kv_cache, layer_idx = inputs
+    kv_cache = stacked_kv_cache[layer_idx]
+    inputs = hidden_states
+    is_scan_carry = True
+  elif isinstance(inputs, tuple):
     inputs = inputs[0]
   inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
   inputs = checkpoint_name(inputs, "decoder_layer_input")

@@ -504,7 +511,15 @@ def __call__(
       jnp.sum(layer_output == 0) / jnp.size(layer_output),
   )

-  if cfg.scan_layers:
+  if is_scan_carry:
+    def update_cache(cache, val):
+      if jnp.size(val) > 0:
+        return cache.at[layer_idx].set(val)
+      return cache
+
+    stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache)
+    return (layer_output, stacked_kv_cache, layer_idx + 1), None
+  elif cfg.scan_layers:
     return layer_output, None
   else:
     return layer_output, kv_cache

src/maxtext/models/mistral.py

Lines changed: 17 additions & 2 deletions
@@ -19,6 +19,7 @@
 from flax import linen as nn
 from flax import nnx
 from jax.ad_checkpoint import checkpoint_name
+import jax
 import jax.numpy as jnp
 from jax.sharding import Mesh
 from maxtext.common.common_types import Config

@@ -136,7 +137,13 @@ def __call__(
   cfg = self.config

   # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
-  if isinstance(inputs, tuple):
+  is_scan_carry = False
+  if isinstance(inputs, tuple) and len(inputs) == 3:
+    hidden_states, stacked_kv_cache, layer_idx = inputs
+    kv_cache = stacked_kv_cache[layer_idx]
+    inputs = hidden_states
+    is_scan_carry = True
+  elif isinstance(inputs, tuple):
     inputs = inputs[0]
   inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
   inputs = checkpoint_name(inputs, "decoder_layer_input")

@@ -181,7 +188,15 @@ def __call__(
       jnp.sum(layer_output == 0) / jnp.size(layer_output),
   )

-  if cfg.scan_layers:
+  if is_scan_carry:
+    def update_cache(cache, val):
+      if jnp.size(val) > 0:
+        return cache.at[layer_idx].set(val)
+      return cache
+
+    stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache)
+    return (layer_output, stacked_kv_cache, layer_idx + 1), None
+  elif cfg.scan_layers:
     return layer_output, None
   else:
     return layer_output, kv_cache

src/maxtext/models/olmo3.py

Lines changed: 17 additions & 2 deletions
@@ -24,6 +24,7 @@
 from flax import linen as nn
 from flax import nnx
 from jax.ad_checkpoint import checkpoint_name
+import jax
 import jax.numpy as jnp
 from jax.sharding import Mesh
 from maxtext.common.common_types import AttentionType, Config

@@ -155,7 +156,13 @@ def __call__(
 ):
   cfg = self.config
   # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
-  if isinstance(inputs, tuple):
+  is_scan_carry = False
+  if isinstance(inputs, tuple) and len(inputs) == 3:
+    hidden_states, stacked_kv_cache, layer_idx = inputs
+    kv_cache = stacked_kv_cache[layer_idx]
+    inputs = hidden_states
+    is_scan_carry = True
+  elif isinstance(inputs, tuple):
     inputs = inputs[0]

   inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))

@@ -209,7 +216,15 @@ def __call__(
       jnp.sum(layer_output == 0) / jnp.size(layer_output),
   )

-  if cfg.scan_layers:
+  if is_scan_carry:
+    def update_cache(cache, val):
+      if jnp.size(val) > 0:
+        return cache.at[layer_idx].set(val)
+      return cache
+
+    stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache)
+    return (layer_output, stacked_kv_cache, layer_idx + 1), None
+  elif cfg.scan_layers:
     return layer_output, None
   else:
     return layer_output, kv_cache
