@@ -44,6 +44,7 @@ def __init__(
         cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
         use_custom_update_cache_op: bool = False,
         return_float_values: bool = True,
+        is_seq_at_dim_2: bool = False,
     ):
         super().__init__()
         if cache_type not in (
@@ -55,13 +56,21 @@ def __init__(
             )
 
         # For now supporting int8 only
+        self.is_seq_at_dim_2 = is_seq_at_dim_2
         self.use_custom_update_cache_op = use_custom_update_cache_op
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
         self.return_float_values = return_float_values
         self.max_context_length = max_context_length
-        cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
-        scale_shape = (max_batch_size, max_context_length, n_heads, 1)
+        self.max_batch_size = max_batch_size
+        self.n_heads = n_heads
+        self.head_dim = head_dim
+        if not self.is_seq_at_dim_2:
+            cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
+            scale_shape = (max_batch_size, max_context_length, n_heads, 1)
+        else:
+            cache_shape = (max_batch_size, n_heads, max_context_length, head_dim)
+            scale_shape = (max_batch_size, n_heads, max_context_length, 1)
         self.register_buffer(
             "k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
         )
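Note (not part of the diff): a minimal sketch of the two buffer layouts the new `is_seq_at_dim_2` flag selects, using made-up sizes. The default keeps the sequence at dim 1 ([B, S, H, D]); the new option keeps it at dim 2 ([B, H, S, D]).

```python
import torch

# Hypothetical sizes for illustration only.
max_batch_size, max_context_length, n_heads, head_dim = 1, 128, 8, 64

# Default layout: sequence at dim 1 -> [B, S, H, D]
k_cache_seq_dim_1 = torch.zeros(
    (max_batch_size, max_context_length, n_heads, head_dim), dtype=torch.int8
)

# New layout when is_seq_at_dim_2=True: sequence at dim 2 -> [B, H, S, D]
k_cache_seq_dim_2 = torch.zeros(
    (max_batch_size, n_heads, max_context_length, head_dim), dtype=torch.int8
)

# Per-position quantization params keep the same leading dims, with a trailing 1.
scales_seq_dim_2 = torch.zeros(
    (max_batch_size, n_heads, max_context_length, 1), dtype=torch.float32
)
```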
@@ -113,52 +122,60 @@ def _quantize_and_update(self, input_pos, k_val, v_val, indices=None):
             start_pos = input_pos[0].item()
             if indices is not None:
                 _ = torch.ops.llama.update_cache_with_indices(
-                    quantized_k_val, self.k_cache, start_pos, indices
+                    quantized_k_val, self.k_cache, start_pos, indices, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache_with_indices(
-                    k_scales, self.k_cache_scales, start_pos, indices
+                    k_scales, self.k_cache_scales, start_pos, indices, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache_with_indices(
-                    k_zero_points, self.k_cache_zero_points, start_pos, indices
+                    k_zero_points, self.k_cache_zero_points, start_pos, indices, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache_with_indices(
-                    quantized_v_val, self.v_cache, start_pos, indices
+                    quantized_v_val, self.v_cache, start_pos, indices, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache_with_indices(
-                    v_scales, self.v_cache_scales, start_pos, indices
+                    v_scales, self.v_cache_scales, start_pos, indices, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache_with_indices(
-                    v_zero_points, self.v_cache_zero_points, start_pos, indices
+                    v_zero_points, self.v_cache_zero_points, start_pos, indices, self.is_seq_at_dim_2
                 )
             else:
                 _ = torch.ops.llama.update_cache(
-                    quantized_k_val, self.k_cache, start_pos
+                    quantized_k_val, self.k_cache, start_pos, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache(
-                    k_scales, self.k_cache_scales, start_pos
+                    k_scales, self.k_cache_scales, start_pos, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache(
-                    k_zero_points, self.k_cache_zero_points, start_pos
+                    k_zero_points, self.k_cache_zero_points, start_pos, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache(
-                    quantized_v_val, self.v_cache, start_pos
+                    quantized_v_val, self.v_cache, start_pos, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache(
-                    v_scales, self.v_cache_scales, start_pos
+                    v_scales, self.v_cache_scales, start_pos, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache(
-                    v_zero_points, self.v_cache_zero_points, start_pos
+                    v_zero_points, self.v_cache_zero_points, start_pos, self.is_seq_at_dim_2
                 )
         else:
             assert indices is None, "Indices not supported for this path"
             # Following is also broken because in prefill input_pos = [0]
             # but we need to update some slice of cache
-            self.k_cache[:, input_pos] = quantized_k_val
-            self.k_cache_scales[:, input_pos] = k_scales
-            self.k_cache_zero_points[:, input_pos] = k_zero_points
-            self.v_cache[:, input_pos] = quantized_v_val
-            self.v_cache_scales[:, input_pos] = v_scales
-            self.v_cache_zero_points[:, input_pos] = v_zero_points
+            if self.is_seq_at_dim_2:
+                self.k_cache[:, :, input_pos] = quantized_k_val
+                self.k_cache_scales[:, :, input_pos] = k_scales
+                self.k_cache_zero_points[:, :, input_pos] = k_zero_points
+                self.v_cache[:, :, input_pos] = quantized_v_val
+                self.v_cache_scales[:, :, input_pos] = v_scales
+                self.v_cache_zero_points[:, :, input_pos] = v_zero_points
+            else:
+                self.k_cache[:, input_pos] = quantized_k_val
+                self.k_cache_scales[:, input_pos] = k_scales
+                self.k_cache_zero_points[:, input_pos] = k_zero_points
+                self.v_cache[:, input_pos] = quantized_v_val
+                self.v_cache_scales[:, input_pos] = v_scales
+                self.v_cache_zero_points[:, input_pos] = v_zero_points
 
     def _update_and_return_float_values(self, input_pos, k_val, v_val, indices=None):
         self._quantize_and_update(input_pos, k_val, v_val, indices)
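For the eager (non custom-op) branch above, the only difference is which dimension `input_pos` indexes. An illustrative standalone sketch with hypothetical shapes:

```python
import torch

input_pos = torch.tensor([3, 4, 5])  # positions being written
new_k = torch.randint(-128, 127, (1, 3, 8, 64), dtype=torch.int8)  # [B, S, H, D]

# Layout [B, S, H, D]: sequence is dim 1, so index right after the batch dim.
cache_bshd = torch.zeros(1, 128, 8, 64, dtype=torch.int8)
cache_bshd[:, input_pos] = new_k

# Layout [B, H, S, D]: sequence is dim 2, so batch and head dims stay whole.
cache_bhsd = torch.zeros(1, 8, 128, 64, dtype=torch.int8)
cache_bhsd[:, :, input_pos] = new_k.transpose(1, 2)  # new_k viewed as [B, H, S, D]
```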
@@ -188,17 +205,21 @@ def _update_and_return_float_values(self, input_pos, k_val, v_val, indices=None)
         if self.use_custom_update_cache_op:
             if indices is not None:
                 _ = torch.ops.llama.update_cache_with_indices(
-                    k_val, k_out, start_pos, indices
+                    k_val, k_out, start_pos, indices, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache_with_indices(
-                    v_val, v_out, start_pos, indices
+                    v_val, v_out, start_pos, indices, self.is_seq_at_dim_2
                 )
             else:
-                _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
-                _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
+                _ = torch.ops.llama.update_cache(k_val, k_out, start_pos, self.is_seq_at_dim_2)
+                _ = torch.ops.llama.update_cache(v_val, v_out, start_pos, self.is_seq_at_dim_2)
         else:
-            k_out[:, input_pos] = k_val
-            v_out[:, input_pos] = v_val
+            if self.is_seq_at_dim_2:
+                k_out[:, :, input_pos] = k_val
+                v_out[:, :, input_pos] = v_val
+            else:
+                k_out[:, input_pos] = k_val
+                v_out[:, input_pos] = v_val
 
         return k_out, v_out
 
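The `torch.ops.llama.update_cache*` calls now pass a trailing boolean telling the kernel which dimension holds the sequence. A rough pure-PyTorch stand-in for that behaviour, under the assumption that the op writes `value` into `cache` starting at `start_pos`; the real op lives in ExecuTorch's custom-ops library and may differ in details:

```python
import torch

def update_cache_reference(value, cache, start_pos, is_seq_at_dim_2=False):
    """Illustrative stand-in: copy `value` into `cache` along the sequence dim."""
    seq_dim = 2 if is_seq_at_dim_2 else 1
    seq_len = value.size(seq_dim)
    positions = torch.arange(start_pos, start_pos + seq_len)
    cache.index_copy_(seq_dim, positions, value)
    return cache
```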
@@ -217,8 +238,9 @@ def update(self, input_pos, k_val, v_val, indices=None):
         This shall be removed by subsequent post-export graph pass
         """
 
-        k_val = k_val.transpose(1, 2)
-        v_val = v_val.transpose(1, 2)
+        if not self.is_seq_at_dim_2:
+            k_val = k_val.transpose(1, 2)
+            v_val = v_val.transpose(1, 2)
 
         if self.return_float_values:
             k_out, v_out = self._update_and_return_float_values(
@@ -228,28 +250,34 @@ def update(self, input_pos, k_val, v_val, indices=None):
             k_out, v_out = self._update_and_return_quantized_values(
                 input_pos, k_val, v_val, indices
             )
-        return k_out.transpose(1, 2), v_out.transpose(1, 2)
+        if not self.is_seq_at_dim_2:
+            return k_out.transpose(1, 2), v_out.transpose(1, 2)
+        else:
+            return k_out, v_out
 
     @classmethod
     def from_float(
         cls,
         kv_cache,
         cache_type: QuantizedCacheType,
         use_custom_update_cache_op: bool = False,
+        is_seq_at_dim_2: bool = False,
     ):
         max_batch_size, n_heads, max_context_length, head_dim = kv_cache.k_cache.shape
         if isinstance(kv_cache, CustomKVCache):
             # If replacing custom kv cache, then the shape is [B, S, H, D]
-            max_batch_size, max_context_length, n_heads, head_dim = (
-                kv_cache.k_cache.shape
-            )
+            max_batch_size = kv_cache.max_batch_size
+            n_heads = kv_cache.n_heads
+            max_context_length = kv_cache.max_context_length
+            head_dim = kv_cache.head_dim
         return cls(
            max_batch_size,
            max_context_length,
            n_heads,
            head_dim,
            cache_type,
            use_custom_update_cache_op,
+            is_seq_at_dim_2=is_seq_at_dim_2,
        )
 
 
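A hypothetical call to the updated `from_float`, assuming the classes are imported from the module this diff edits and that `CustomKVCache` takes `(max_batch_size, max_context_length, n_heads, head_dim, dtype, is_seq_at_dim_2)`, as the hunks here suggest:

```python
import torch

# Assumed import from the module this diff edits:
# from ... import CustomKVCache, QuantizedKVCache, QuantizedCacheType

# Float cache already stored as [B, H, S, D] (hypothetical sizes).
float_cache = CustomKVCache(1, 128, 8, 64, dtype=torch.float32, is_seq_at_dim_2=True)

quantized = QuantizedKVCache.from_float(
    float_cache,
    QuantizedCacheType.AffineSymmetric,
    use_custom_update_cache_op=True,
    is_seq_at_dim_2=True,  # preserve the [B, H, S, D] layout in the quantized cache
)
```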
@@ -312,10 +340,15 @@ def __init__(
         n_heads: int,
         head_dim: int,
         dtype=torch.float32,
+        is_seq_at_dim_2: bool = False,
     ):
+        self.is_seq_at_dim_2 = is_seq_at_dim_2
         super().__init__()
         self.max_context_length = max_context_length
-        cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
+        if self.is_seq_at_dim_2:
+            cache_shape = (max_batch_size, n_heads, max_context_length, head_dim)
+        else:
+            cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
 
         self.max_batch_size = max_batch_size
         self.n_heads = n_heads
@@ -335,25 +368,26 @@ def update(
         indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D]
-        k_val = k_val.transpose(1, 2)
-        v_val = v_val.transpose(1, 2)
+        if not self.is_seq_at_dim_2:
+            k_val = k_val.transpose(1, 2)
+            v_val = v_val.transpose(1, 2)
         start_pos = input_pos[0].item()
 
         if indices is not None:
             _ = torch.ops.llama.update_cache_with_indices(
-                k_val, self.k_cache, start_pos, indices
+                k_val, self.k_cache, start_pos, indices, self.is_seq_at_dim_2
             )
             _ = torch.ops.llama.update_cache_with_indices(
-                v_val, self.v_cache, start_pos, indices
+                v_val, self.v_cache, start_pos, indices, self.is_seq_at_dim_2
             )
         else:
-            _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
-            _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
+            _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos, self.is_seq_at_dim_2)
+            _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos, self.is_seq_at_dim_2)
 
-        return (
-            self.k_cache.transpose(1, 2),
-            self.v_cache.transpose(1, 2),
-        )
+        if not self.is_seq_at_dim_2:
+            return (k_val.transpose(1, 2), v_val.transpose(1, 2))
+        else:
+            return (self.k_cache, self.v_cache)
 
 
 def replace_kv_cache_with_custom_kv_cache(module):
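Calling-convention sketch for `CustomKVCache.update` with the new layout: projection outputs already arrive as [B, H, S, D], which now matches the cache storage, so no transpose is needed on either side. The update itself needs `torch.ops.llama.update_cache` from ExecuTorch's custom-ops library, so the call is guarded here; sizes and the constructor order are assumptions taken from the hunks above.

```python
import torch

# Cache stored as [B, H, S, D]; constructor order assumed from this diff.
cache = CustomKVCache(1, 128, 8, 64, dtype=torch.float32, is_seq_at_dim_2=True)

# Single decode step: projections come out as [B, H, S, D].
k_val = torch.randn(1, 8, 1, 64)
v_val = torch.randn(1, 8, 1, 64)
input_pos = torch.tensor([10])

# Guarded because torch.ops.llama.update_cache only exists once the
# ExecuTorch custom-ops library has been loaded.
if hasattr(torch.ops.llama, "update_cache"):
    k, v = cache.update(input_pos, k_val, v_val)
```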
@@ -373,9 +407,11 @@ def replace_kv_cache_with_custom_kv_cache(module):
 def _replace_kv_cache_with_custom_kv_cache(module):
     for name, child in module.named_children():
         if isinstance(child, KVCache):
-            cache_shape = child.k_cache.shape
             cache_dtype = child.k_cache.dtype
-            max_batch_size, n_heads, max_context_length, head_dim = cache_shape
+            max_batch_size = child.max_batch_size
+            n_heads = child.n_heads
+            max_context_length = child.max_context_length
+            head_dim = child.head_dim
             setattr(
                 module,
                 name,
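Why the replacement pass now reads named attributes instead of unpacking `k_cache.shape`: with two possible layouts, positional unpacking silently swaps the head and sequence dimensions. A small illustration with hypothetical sizes:

```python
import torch

k_cache = torch.zeros(1, 8, 128, 64)  # stored as [B, H, S, D]

# Unpacking as if the layout were [B, S, H, D] mixes up the dims:
max_context_length_wrong, n_heads_wrong = k_cache.shape[1], k_cache.shape[2]
assert (max_context_length_wrong, n_heads_wrong) == (8, 128)  # 8 is really n_heads

# Reading named attributes (max_batch_size, n_heads, max_context_length,
# head_dim), which this diff adds to the cache modules, is layout-independent.
```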
@@ -402,6 +438,7 @@ def __init__(
         cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
         use_custom_update_cache_op: bool = False,
         return_float_values: bool = True,
+        is_seq_at_dim_2: bool = False,
     ):
         # Look at attention.py for explanation on why max_context_length * 2
         super().__init__(
@@ -412,9 +449,11 @@ def __init__(
             cache_type,
             use_custom_update_cache_op,
             return_float_values,
+            is_seq_at_dim_2,
         )
         self.cache_positions_manager = CachePositionsManager(self.max_context_length)
         self.is_ring_buffer = True
+        self.is_seq_at_dim_2 = is_seq_at_dim_2
         self.window_size = max_context_length
 
     def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
@@ -434,7 +473,10 @@ def update(self, input_pos, k_val, v_val):
         # 1. kv cache is stored as [B, S, H, D]
         # 2. If seq_len = k_val.size(2), we wont be able be able to optimize
         # away transpose at the output of k, v projection
-        seq_len = k_val.transpose(1, 2).size(1)
+        if not self.is_seq_at_dim_2:
+            seq_len = k_val.transpose(1, 2).size(1)
+        else:
+            seq_len = k_val.size(2)
         assert seq_len <= self.k_cache.size(
             1
         ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})"
@@ -454,7 +496,9 @@ def from_quantized_kv_cache(
         assert isinstance(
             kv_cache, QuantizedKVCache
         ), "For QuantizedRingKVCache expect QuantizedKVCache as input kv_cache"
-        max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
+        max_batch_size = kv_cache.max_batch_size
+        n_heads = kv_cache.n_heads
+        head_dim = kv_cache.head_dim
         return cls(
             max_batch_size,
             sliding_window_size,
@@ -463,6 +507,7 @@ def from_quantized_kv_cache(
             kv_cache.cache_type,
             kv_cache.use_custom_update_cache_op,
             kv_cache.return_float_values,
+            is_seq_at_dim_2=kv_cache.is_seq_at_dim_2,
         )
 
 
@@ -474,10 +520,11 @@ def __init__(
         n_heads,
         head_dim,
         dtype=torch.float32,
+        is_seq_at_dim_2: bool = False,
     ):
         # Look at attention.py for explanation on why max_context_length * 2
         super().__init__(
-            max_batch_size, max_context_length * 2, n_heads, head_dim, dtype
+            max_batch_size, max_context_length * 2, n_heads, head_dim, dtype, is_seq_at_dim_2
         )
         self.cache_positions_manager = CachePositionsManager(self.max_context_length)
         self.is_ring_buffer = True
@@ -500,7 +547,10 @@ def update(self, input_pos, k_val, v_val):
         # 1. kv cache is stored as [B, S, H, D]
         # 2. If seq_len = k_val.size(2), we wont be able be able to optimize
         # away transpose at the output of k, v projection
-        seq_len = k_val.transpose(1, 2).size(1)
+        if not self.is_seq_at_dim_2:
+            seq_len = k_val.transpose(1, 2).size(1)
+        else:
+            seq_len = k_val.size(2)
         assert seq_len <= self.k_cache.size(
             1
         ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})"
@@ -517,16 +567,21 @@ def from_custom_kv_cache(
         kv_cache,
         sliding_window_size,
     ):
-        max_batch_size, n_heads, _, head_dim = kv_cache.k_cache.shape
+        max_batch_size = kv_cache.max_batch_size
+        n_heads = kv_cache.n_heads
+        head_dim = kv_cache.head_dim
         if isinstance(kv_cache, CustomKVCache):
             # If replacing custom kv cache, then the shape is [B, S, H, D]
-            max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
+            max_batch_size = kv_cache.max_batch_size
+            n_heads = kv_cache.n_heads
+            head_dim = kv_cache.head_dim
         return cls(
             max_batch_size,
             sliding_window_size,
             n_heads,
             head_dim,
             dtype=kv_cache.k_cache.dtype,
+            is_seq_at_dim_2=kv_cache.is_seq_at_dim_2,
         )
 
 
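Finally, an end-to-end sketch of building a quantized ring-buffer cache that preserves the new layout. Class names and signatures are taken from this diff; the sizes are hypothetical and the classes are assumed to be importable from the module being edited.

```python
import torch

# Quantized source cache stored as [B, H, S, D] (hypothetical sizes:
# batch=1, context=64, heads=8, head_dim=64).
q = QuantizedKVCache(
    1, 64, 8, 64,
    cache_type=QuantizedCacheType.AffineSymmetric,
    use_custom_update_cache_op=True,
    is_seq_at_dim_2=True,
)

# Ring-buffer view over a 64-token sliding window; head count, head_dim and
# the layout flag are read off the source cache's attributes.
q_ring = QuantizedRingKVCache.from_quantized_kv_cache(q, 64)
assert q_ring.is_ring_buffer and q_ring.is_seq_at_dim_2
```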