@@ -170,22 +170,22 @@ def __init__(
170170
171171 self .key_pages = nnx .Cache (
172172 jnp .zeros (self .kv_pages_shape , dtype = self .dtype ),
173- sharding = self .kv_pages_axis_names ,
173+ out_sharding = self .kv_pages_axis_names ,
174174 )
175175 self .value_pages = nnx .Cache (
176176 jnp .zeros (self .kv_pages_shape , dtype = self .dtype ),
177- sharding = self .kv_pages_axis_names ,
177+ out_sharding = self .kv_pages_axis_names ,
178178 )
179179
180180 def _maybe_materialize_cache (self , cache : nnx .Cache ) -> nnx .Cache :
181181 """Materializes the cache if it's currently a ShapeDtypeStruct."""
182- if isinstance (cache .value , jax .ShapeDtypeStruct ):
182+ if isinstance (cache .get_value () , jax .ShapeDtypeStruct ):
183183 # This is needed because the Linen bridge lazily creates this state. We
184184 # need to ensure the cache state is accessible at runtime.
185185 # TODO: Delete this function when the to_linen bridge is no longer needed.
186186 return nnx .Cache (
187187 jnp .zeros (self .kv_pages_shape , dtype = self .dtype ),
188- sharding = cache .sharding ,
188+ out_sharding = cache .get_metadata ( "out_sharding" ) ,
189189 )
190190 return cache
191191
@@ -204,8 +204,8 @@ def get_kv_pages(self):
204204 self .key_pages = self ._maybe_materialize_cache (self .key_pages )
205205 self .value_pages = self ._maybe_materialize_cache (self .value_pages )
206206
207- self .key_pages .value = nn .with_logical_constraint (self .key_pages .value , self .kv_pages_axis_names )
208- self .value_pages .value = nn .with_logical_constraint (self .value_pages .value , self .kv_pages_axis_names )
207+ self .key_pages .set_value ( nn .with_logical_constraint (self .key_pages .get_value () , self .kv_pages_axis_names ) )
208+ self .value_pages .set_value ( nn .with_logical_constraint (self .value_pages .get_value () , self .kv_pages_axis_names ) )
209209 return self .key_pages , self .value_pages
210210
211211 def pad_qkv (self , * qkv ):
@@ -264,9 +264,9 @@ def paged_attention_v2_prefill(
264264 is the batch_size is only 1
265265 """
266266 assert query .shape [0 ] == 1 # ensure the batch size is 0
267- # shape of key_pages_cache.value is [num_kv_heads, num_pages, tokens_per_page, head_dim]
268- k_p = jnp .permute_dims (key_pages_cache .value , (1 , 2 , 0 , 3 ))
269- v_p = jnp .permute_dims (value_pages_cache .value , (1 , 2 , 0 , 3 ))
267+ # shape of key_pages_cache.get_value() is [num_kv_heads, num_pages, tokens_per_page, head_dim]
268+ k_p = jnp .permute_dims (key_pages_cache .get_value () , (1 , 2 , 0 , 3 ))
269+ v_p = jnp .permute_dims (value_pages_cache .get_value () , (1 , 2 , 0 , 3 ))
270270 c_q_l = jnp .array ([0 , page_state .sequence_lengths [0 ]]) # [0, prefill_true_length]
271271 num_seqs = jnp .array ([1 ])
272272 query = query [0 ] # [batch_size, max_num_tokens, num_kv_heads, head_dim] to [max_num_tokens, num_kv_heads, head_dim]
@@ -294,8 +294,8 @@ def paged_attention_v2_decode(
294294 """Apply ragged input Paged Attention in decode only."""
295295 batch_size = query .shape [0 ]
296296 query = jnp .squeeze (query , axis = 1 ) # [batch_size, seq_len, n_kv_head, head_dim] to [batch_size, n_kv_head, head_dim]
297- k_p = jnp .permute_dims (key_pages_cache .value , (1 , 2 , 0 , 3 ))
298- v_p = jnp .permute_dims (value_pages_cache .value , (1 , 2 , 0 , 3 ))
297+ k_p = jnp .permute_dims (key_pages_cache .get_value () , (1 , 2 , 0 , 3 ))
298+ v_p = jnp .permute_dims (value_pages_cache .get_value () , (1 , 2 , 0 , 3 ))
299299 c_q_l = jnp .arange (batch_size + 1 ) # one token per sequence
300300 num_seqs = jnp .array ([batch_size ]) # real number of requests, set it to batch_size
301301 result = paged_attention_kernel_v2 .ragged_paged_attention (
@@ -352,8 +352,8 @@ def wrap_paged_attention(q, k_pages, v_pages, lengths, page_indices, pages_per_c
352352
353353 return wrap_paged_attention (
354354 query ,
355- key_pages_cache .value ,
356- value_pages_cache .value ,
355+ key_pages_cache .get_value () ,
356+ value_pages_cache .get_value () ,
357357 page_state .sequence_lengths ,
358358 page_state .page_map ,
359359 self .pages_per_compute_block ,
@@ -441,12 +441,12 @@ def update_prefill_step_pages(
441441 ), f"prefill_step key/value should have the same shape, but getting { key .shape = } and { value .shape = } instead"
442442 batch_size , seq_len , n_kv_head , head_dim = key .shape
443443 assert seq_len % self .tokens_per_page == 0 , f"seq_length { seq_len } and tokens_per_page { self .tokens_per_page } "
444- assert key_pages_cache .value .shape == value_pages_cache .value .shape , (
444+ assert key_pages_cache .get_value () .shape == value_pages_cache .get_value () .shape , (
445445 f"prefill_step key/value_pages_cache should have the same shape, but "
446446 f"getting { key_pages_cache .shape = } and { value_pages_cache .shape = } instead"
447447 )
448448
449- v_n_kv , _ , v_p , v_d = key_pages_cache .value .shape
449+ v_n_kv , _ , v_p , v_d = key_pages_cache .get_value () .shape
450450 assert v_n_kv == n_kv_head , f"{ v_n_kv = } { n_kv_head = } "
451451 assert v_p == self .tokens_per_page , f"{ v_p = } { self .tokens_per_page = } "
452452 assert v_d == head_dim , f"{ v_d = } { head_dim = } "
@@ -485,13 +485,13 @@ def update_prefill_step_pages(
485485 ),
486486 )
487487
488- key_pages_cache .value = nn .with_logical_constraint (key , self .kv_pages_axis_names )
489- value_pages_cache .value = nn .with_logical_constraint (value , self .kv_pages_axis_names )
488+ key_pages_cache .set_value ( nn .with_logical_constraint (key , self .kv_pages_axis_names ) )
489+ value_pages_cache .set_value ( nn .with_logical_constraint (value , self .kv_pages_axis_names ) )
490490
491491 def update_decode_step_pages (self , key_pages_cache , value_pages_cache , key , value , page_state ):
492492 """Update decode-step pages"""
493- key_pages = key_pages_cache .value
494- value_pages = value_pages_cache .value
493+ key_pages = key_pages_cache .get_value ()
494+ value_pages = value_pages_cache .get_value ()
495495
496496 batch_size , _ , kv_heads , head_dim = key .shape
497497 kv_heads , _ , _ , head_dim = key_pages .shape
@@ -511,6 +511,6 @@ def update_decode_step_pages(self, key_pages_cache, value_pages_cache, key, valu
511511 key_pages_updated = key_pages .at [kv_indices , broadcast_pages , broadcast_pos ].set (new_key )
512512 value_pages_updated = value_pages .at [kv_indices , broadcast_pages , broadcast_pos ].set (new_value )
513513
514- key_pages_cache .value = key_pages_updated
515- value_pages_cache .value = value_pages_updated
514+ key_pages_cache .set_value ( key_pages_updated )
515+ value_pages_cache .set_value ( value_pages_updated )
516516 return key_pages_cache , value_pages_cache
0 commit comments