|
23 | 23 | from flax import linen as nn |
24 | 24 | from flax import nnx |
25 | 25 | from jax.ad_checkpoint import checkpoint_name |
| 26 | +import jax |
26 | 27 | import jax.numpy as jnp |
27 | 28 | from jax.sharding import Mesh |
28 | 29 | from maxtext.common.common_types import AttentionType, Config |
@@ -146,7 +147,13 @@ def __call__( |
146 | 147 | ): |
147 | 148 | cfg = self.config |
148 | 149 | # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)) |
149 | | - if isinstance(inputs, tuple): |
| 150 | + is_scan_carry = False |
| 151 | + if isinstance(inputs, tuple) and len(inputs) == 3: |
| 152 | + hidden_states, stacked_kv_cache, layer_idx = inputs |
| 153 | + kv_cache = stacked_kv_cache[layer_idx] |
| 154 | + inputs = hidden_states |
| 155 | + is_scan_carry = True |
| 156 | + elif isinstance(inputs, tuple): |
150 | 157 | inputs = inputs[0] |
151 | 158 |
|
152 | 159 | inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed")) |
@@ -201,7 +208,15 @@ def __call__( |
201 | 208 | jnp.sum(layer_output == 0) / jnp.size(layer_output), |
202 | 209 | ) |
203 | 210 |
|
204 | | - if cfg.scan_layers: |
| 211 | + if is_scan_carry: |
| 212 | + def update_cache(cache, val): |
| 213 | + if jnp.size(val) > 0: |
| 214 | + return cache.at[layer_idx].set(val) |
| 215 | + return cache |
| 216 | + |
| 217 | + stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache) |
| 218 | + return (layer_output, stacked_kv_cache, layer_idx + 1), None |
| 219 | + elif cfg.scan_layers: |
205 | 220 | return layer_output, None |
206 | 221 | else: |
207 | 222 | return layer_output, kv_cache |
|
0 commit comments