Commit bab3440

Add support for transposed kv cache in all variants of custom kv cache

Differential Revision: [D93870395](https://our.internmc.facebook.com/intern/diff/D93870395/)

[ghstack-poisoned]

1 parent: 804ebb5
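
The change hinges on a single new flag, `is_seq_at_dim_2`, which selects where the sequence dimension lives in the cache buffers. A minimal illustrative sketch of the two layouts (plain PyTorch, hypothetical sizes, not code from this commit):

import torch

# Illustrative only: hypothetical sizes, mirroring the layout choice added in this commit.
max_batch_size, max_context_length, n_heads, head_dim = 1, 128, 8, 64

# Default layout (is_seq_at_dim_2=False): sequence at dim 1 -> [B, S, H, D]
k_cache_default = torch.zeros(max_batch_size, max_context_length, n_heads, head_dim)

# Transposed layout (is_seq_at_dim_2=True): sequence at dim 2 -> [B, H, S, D]
k_cache_transposed = torch.zeros(max_batch_size, n_heads, max_context_length, head_dim)

print(k_cache_default.shape, k_cache_transposed.shape)

The flag defaults to False everywhere, so callers that never pass it keep the existing [B, S, H, D] behavior.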

2 files changed: 126 additions & 60 deletions

File tree

examples/models/llama/source_transformation/custom_kv_cache.py (106 additions & 51 deletions)
@@ -44,6 +44,7 @@ def __init__(
         cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
         use_custom_update_cache_op: bool = False,
         return_float_values: bool = True,
+        is_seq_at_dim_2: bool = False,
     ):
         super().__init__()
         if cache_type not in (
@@ -55,13 +56,21 @@ def __init__(
             )
 
         # For now supporting int8 only
+        self.is_seq_at_dim_2 = is_seq_at_dim_2
         self.use_custom_update_cache_op = use_custom_update_cache_op
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
         self.return_float_values = return_float_values
         self.max_context_length = max_context_length
-        cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
-        scale_shape = (max_batch_size, max_context_length, n_heads, 1)
+        self.max_batch_size = max_batch_size
+        self.n_heads = n_heads
+        self.head_dim = head_dim
+        if not self.is_seq_at_dim_2:
+            cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
+            scale_shape = (max_batch_size, max_context_length, n_heads, 1)
+        else:
+            cache_shape = (max_batch_size, n_heads, max_context_length, head_dim)
+            scale_shape = (max_batch_size, n_heads, max_context_length, 1)
         self.register_buffer(
             "k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
         )
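
For reference, a hedged standalone sketch of the shape selection the constructor now performs; `pick_cache_shapes` is a hypothetical helper, not part of the file:

# Hypothetical helper (not in the file) reproducing the shape selection shown above.
def pick_cache_shapes(max_batch_size, max_context_length, n_heads, head_dim, is_seq_at_dim_2=False):
    if not is_seq_at_dim_2:
        cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)  # [B, S, H, D]
        scale_shape = (max_batch_size, max_context_length, n_heads, 1)
    else:
        cache_shape = (max_batch_size, n_heads, max_context_length, head_dim)  # [B, H, S, D]
        scale_shape = (max_batch_size, n_heads, max_context_length, 1)
    return cache_shape, scale_shape

assert pick_cache_shapes(1, 128, 8, 64, is_seq_at_dim_2=True)[0] == (1, 8, 128, 64)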
@@ -113,52 +122,60 @@ def _quantize_and_update(self, input_pos, k_val, v_val, indices=None):
             start_pos = input_pos[0].item()
             if indices is not None:
                 _ = torch.ops.llama.update_cache_with_indices(
-                    quantized_k_val, self.k_cache, start_pos, indices
+                    quantized_k_val, self.k_cache, start_pos, indices, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache_with_indices(
-                    k_scales, self.k_cache_scales, start_pos, indices
+                    k_scales, self.k_cache_scales, start_pos, indices, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache_with_indices(
-                    k_zero_points, self.k_cache_zero_points, start_pos, indices
+                    k_zero_points, self.k_cache_zero_points, start_pos, indices, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache_with_indices(
-                    quantized_v_val, self.v_cache, start_pos, indices
+                    quantized_v_val, self.v_cache, start_pos, indices, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache_with_indices(
-                    v_scales, self.v_cache_scales, start_pos, indices
+                    v_scales, self.v_cache_scales, start_pos, indices, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache_with_indices(
-                    v_zero_points, self.v_cache_zero_points, start_pos, indices
+                    v_zero_points, self.v_cache_zero_points, start_pos, indices, self.is_seq_at_dim_2
                 )
             else:
                 _ = torch.ops.llama.update_cache(
-                    quantized_k_val, self.k_cache, start_pos
+                    quantized_k_val, self.k_cache, start_pos, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache(
-                    k_scales, self.k_cache_scales, start_pos
+                    k_scales, self.k_cache_scales, start_pos, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache(
-                    k_zero_points, self.k_cache_zero_points, start_pos
+                    k_zero_points, self.k_cache_zero_points, start_pos, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache(
-                    quantized_v_val, self.v_cache, start_pos
+                    quantized_v_val, self.v_cache, start_pos, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache(
-                    v_scales, self.v_cache_scales, start_pos
+                    v_scales, self.v_cache_scales, start_pos, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache(
-                    v_zero_points, self.v_cache_zero_points, start_pos
+                    v_zero_points, self.v_cache_zero_points, start_pos, self.is_seq_at_dim_2
                 )
         else:
             assert indices is None, "Indices not supported for this path"
             # Following is also broken because in prefill input_pos = [0]
             # but we need to update some slice of cache
-            self.k_cache[:, input_pos] = quantized_k_val
-            self.k_cache_scales[:, input_pos] = k_scales
-            self.k_cache_zero_points[:, input_pos] = k_zero_points
-            self.v_cache[:, input_pos] = quantized_v_val
-            self.v_cache_scales[:, input_pos] = v_scales
-            self.v_cache_zero_points[:, input_pos] = v_zero_points
+            if self.is_seq_at_dim_2:
+                self.k_cache[:, :, input_pos] = quantized_k_val
+                self.k_cache_scales[:, :, input_pos] = k_scales
+                self.k_cache_zero_points[:, :, input_pos] = k_zero_points
+                self.v_cache[:, :, input_pos] = quantized_v_val
+                self.v_cache_scales[:, :, input_pos] = v_scales
+                self.v_cache_zero_points[:, :, input_pos] = v_zero_points
+            else:
+                self.k_cache[:, input_pos] = quantized_k_val
+                self.k_cache_scales[:, input_pos] = k_scales
+                self.k_cache_zero_points[:, input_pos] = k_zero_points
+                self.v_cache[:, input_pos] = quantized_v_val
+                self.v_cache_scales[:, input_pos] = v_scales
+                self.v_cache_zero_points[:, input_pos] = v_zero_points
 
     def _update_and_return_float_values(self, input_pos, k_val, v_val, indices=None):
         self._quantize_and_update(input_pos, k_val, v_val, indices)
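
When the custom update op is not used, the fallback path writes into the cache with plain tensor indexing, and the indexed dimension follows the layout. An illustrative plain-PyTorch equivalent (hypothetical sizes, not the module itself):

import torch

# Illustrative fallback-style update with plain indexing (hypothetical sizes).
B, S_max, H, D, S_new = 1, 16, 2, 4, 3
input_pos = torch.arange(S_new)               # write positions 0..2

k_val_bshd = torch.randn(B, S_new, H, D)      # incoming keys, seq at dim 1
k_val_bhsd = k_val_bshd.transpose(1, 2)       # same keys, seq at dim 2

k_cache_bshd = torch.zeros(B, S_max, H, D)
k_cache_bhsd = torch.zeros(B, H, S_max, D)

k_cache_bshd[:, input_pos] = k_val_bshd       # is_seq_at_dim_2 == False: index dim 1
k_cache_bhsd[:, :, input_pos] = k_val_bhsd    # is_seq_at_dim_2 == True: index dim 2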
@@ -188,17 +205,21 @@ def _update_and_return_float_values(self, input_pos, k_val, v_val, indices=None)
         if self.use_custom_update_cache_op:
             if indices is not None:
                 _ = torch.ops.llama.update_cache_with_indices(
-                    k_val, k_out, start_pos, indices
+                    k_val, k_out, start_pos, indices, self.is_seq_at_dim_2
                 )
                 _ = torch.ops.llama.update_cache_with_indices(
-                    v_val, v_out, start_pos, indices
+                    v_val, v_out, start_pos, indices, self.is_seq_at_dim_2
                 )
             else:
-                _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
-                _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
+                _ = torch.ops.llama.update_cache(k_val, k_out, start_pos, self.is_seq_at_dim_2)
+                _ = torch.ops.llama.update_cache(v_val, v_out, start_pos, self.is_seq_at_dim_2)
         else:
-            k_out[:, input_pos] = k_val
-            v_out[:, input_pos] = v_val
+            if self.is_seq_at_dim_2:
+                k_out[:, :, input_pos] = k_val
+                v_out[:, :, input_pos] = v_val
+            else:
+                k_out[:, input_pos] = k_val
+                v_out[:, input_pos] = v_val
 
         return k_out, v_out
 
@@ -217,8 +238,9 @@ def update(self, input_pos, k_val, v_val, indices=None):
         This shall be removed by subsequent post-export graph pass
         """
 
-        k_val = k_val.transpose(1, 2)
-        v_val = v_val.transpose(1, 2)
+        if not self.is_seq_at_dim_2:
+            k_val = k_val.transpose(1, 2)
+            v_val = v_val.transpose(1, 2)
 
         if self.return_float_values:
             k_out, v_out = self._update_and_return_float_values(
@@ -228,28 +250,34 @@ def update(self, input_pos, k_val, v_val, indices=None):
             k_out, v_out = self._update_and_return_quantized_values(
                 input_pos, k_val, v_val, indices
             )
-        return k_out.transpose(1, 2), v_out.transpose(1, 2)
+        if not self.is_seq_at_dim_2:
+            return k_out.transpose(1, 2), v_out.transpose(1, 2)
+        else:
+            return k_out, v_out
 
     @classmethod
     def from_float(
         cls,
         kv_cache,
         cache_type: QuantizedCacheType,
         use_custom_update_cache_op: bool = False,
+        is_seq_at_dim_2: bool = False,
     ):
         max_batch_size, n_heads, max_context_length, head_dim = kv_cache.k_cache.shape
         if isinstance(kv_cache, CustomKVCache):
             # If replacing custom kv cache, then the shape is [B, S, H, D]
-            max_batch_size, max_context_length, n_heads, head_dim = (
-                kv_cache.k_cache.shape
-            )
+            max_batch_size = kv_cache.max_batch_size
+            n_heads = kv_cache.n_heads
+            max_context_length = kv_cache.max_context_length
+            head_dim = kv_cache.head_dim
         return cls(
             max_batch_size,
             max_context_length,
             n_heads,
             head_dim,
             cache_type,
             use_custom_update_cache_op,
+            is_seq_at_dim_2=is_seq_at_dim_2,
         )
 
 
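The payoff of the transposed layout shows up in `update`: with `is_seq_at_dim_2=True` the cache already matches the [B, H, S, D] shape that attention kernels such as `torch.nn.functional.scaled_dot_product_attention` consume, so the entry and exit transposes in the hunks above become no-ops. A small illustrative sketch (hypothetical sizes, not code from this commit):

import torch

# Illustrative: the transposed cache already matches the [B, H, S, D] shape attention expects.
B, H, S, D = 1, 2, 5, 4
k_out_bshd = torch.randn(B, S, H, D)          # default cache layout
k_out_bhsd = torch.randn(B, H, S, D)          # transposed cache layout

k_for_attention_default = k_out_bshd.transpose(1, 2)  # extra transpose required
k_for_attention_transposed = k_out_bhsd               # already [B, H, S, D]
assert k_for_attention_default.shape == k_for_attention_transposed.shape == (B, H, S, D)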
@@ -312,10 +340,15 @@ def __init__(
         n_heads: int,
         head_dim: int,
         dtype=torch.float32,
+        is_seq_at_dim_2: bool = False,
     ):
+        self.is_seq_at_dim_2 = is_seq_at_dim_2
         super().__init__()
         self.max_context_length = max_context_length
-        cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
+        if self.is_seq_at_dim_2:
+            cache_shape = (max_batch_size, n_heads, max_context_length, head_dim)
+        else:
+            cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
 
         self.max_batch_size = max_batch_size
         self.n_heads = n_heads
@@ -335,25 +368,26 @@ def update(
         indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D]
-        k_val = k_val.transpose(1, 2)
-        v_val = v_val.transpose(1, 2)
+        if not self.is_seq_at_dim_2:
+            k_val = k_val.transpose(1, 2)
+            v_val = v_val.transpose(1, 2)
         start_pos = input_pos[0].item()
 
         if indices is not None:
             _ = torch.ops.llama.update_cache_with_indices(
-                k_val, self.k_cache, start_pos, indices
+                k_val, self.k_cache, start_pos, indices, self.is_seq_at_dim_2
             )
             _ = torch.ops.llama.update_cache_with_indices(
-                v_val, self.v_cache, start_pos, indices
+                v_val, self.v_cache, start_pos, indices, self.is_seq_at_dim_2
             )
         else:
-            _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
-            _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
+            _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos, self.is_seq_at_dim_2)
+            _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos, self.is_seq_at_dim_2)
 
-        return (
-            self.k_cache.transpose(1, 2),
-            self.v_cache.transpose(1, 2),
-        )
+        if not self.is_seq_at_dim_2:
+            return (k_val.transpose(1, 2), v_val.transpose(1, 2))
+        else:
+            return (self.k_cache, self.v_cache)
 
 
 def replace_kv_cache_with_custom_kv_cache(module):
@@ -373,9 +407,11 @@ def replace_kv_cache_with_custom_kv_cache(module):
 def _replace_kv_cache_with_custom_kv_cache(module):
     for name, child in module.named_children():
         if isinstance(child, KVCache):
-            cache_shape = child.k_cache.shape
             cache_dtype = child.k_cache.dtype
-            max_batch_size, n_heads, max_context_length, head_dim = cache_shape
+            max_batch_size = child.max_batch_size
+            n_heads = child.n_heads
+            max_context_length = child.max_context_length
+            head_dim = child.head_dim
             setattr(
                 module,
                 name,
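
Several hunks in this commit, including the one above, stop unpacking dimensions from `k_cache.shape` and read them from attributes instead; with two possible layouts, positional unpacking of the buffer shape is ambiguous. A hedged sketch of why (hypothetical `TinyCache` class, not the real module):

import torch

# Hypothetical TinyCache illustrating why dims are read from attributes rather than
# unpacked from the buffer shape: with two layouts, positional unpacking is ambiguous.
class TinyCache:
    def __init__(self, max_batch_size, max_context_length, n_heads, head_dim, is_seq_at_dim_2=False):
        self.max_batch_size = max_batch_size
        self.max_context_length = max_context_length
        self.n_heads = n_heads
        self.head_dim = head_dim
        shape = (
            (max_batch_size, n_heads, max_context_length, head_dim)
            if is_seq_at_dim_2
            else (max_batch_size, max_context_length, n_heads, head_dim)
        )
        self.k_cache = torch.zeros(shape)

cache = TinyCache(1, 128, 8, 64, is_seq_at_dim_2=True)
_, s_from_shape, h_from_shape, _ = cache.k_cache.shape   # 8 and 128: swapped for this layout
s_ok, h_ok = cache.max_context_length, cache.n_heads     # 128 and 8: correct for either layout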
@@ -402,6 +438,7 @@ def __init__(
         cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
         use_custom_update_cache_op: bool = False,
         return_float_values: bool = True,
+        is_seq_at_dim_2: bool = False,
     ):
         # Look at attention.py for explanation on why max_context_length * 2
         super().__init__(
super().__init__(
@@ -412,9 +449,11 @@ def __init__(
412449
cache_type,
413450
use_custom_update_cache_op,
414451
return_float_values,
452+
is_seq_at_dim_2,
415453
)
416454
self.cache_positions_manager = CachePositionsManager(self.max_context_length)
417455
self.is_ring_buffer = True
456+
self.is_seq_at_dim_2 = is_seq_at_dim_2
418457
self.window_size = max_context_length
419458

420459
def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
@@ -434,7 +473,10 @@ def update(self, input_pos, k_val, v_val):
434473
# 1. kv cache is stored as [B, S, H, D]
435474
# 2. If seq_len = k_val.size(2), we wont be able be able to optimize
436475
# away transpose at the output of k, v projection
437-
seq_len = k_val.transpose(1, 2).size(1)
476+
if not self.is_seq_at_dim_2:
477+
seq_len = k_val.transpose(1, 2).size(1)
478+
else:
479+
seq_len = k_val.size(2)
438480
assert seq_len <= self.k_cache.size(
439481
1
440482
), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})"
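
The ring-buffer update derives the incoming sequence length from a different dimension depending on the layout; both branches yield the same number for the same logical input. An illustrative check (plain PyTorch, hypothetical sizes):

import torch

# Illustrative: both branches recover the same incoming sequence length.
B, H, S, D = 1, 2, 5, 4
k_val = torch.randn(B, H, S, D)                      # k/v projection output, seq at dim 2

seq_len_default = k_val.transpose(1, 2).size(1)      # cache stored as [B, S, H, D]
seq_len_transposed = k_val.size(2)                   # cache stored as [B, H, S, D]
assert seq_len_default == seq_len_transposed == S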
@@ -454,7 +496,9 @@ def from_quantized_kv_cache(
         assert isinstance(
             kv_cache, QuantizedKVCache
         ), "For QuantizedRingKVCache expect QuantizedKVCache as input kv_cache"
-        max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
+        max_batch_size = kv_cache.max_batch_size
+        n_heads = kv_cache.n_heads
+        head_dim = kv_cache.head_dim
         return cls(
             max_batch_size,
             sliding_window_size,
@@ -463,6 +507,8 @@ def from_quantized_kv_cache(
             kv_cache.cache_type,
             kv_cache.use_custom_update_cache_op,
             kv_cache.return_float_values,
+            kv_cache.is_seq_at_dim_2,
+            is_seq_at_dim_2=kv_cache.is_seq_at_dim_2,
         )
 
 
@@ -474,10 +520,11 @@ def __init__(
         n_heads,
         head_dim,
         dtype=torch.float32,
+        is_seq_at_dim_2: bool = False,
     ):
         # Look at attention.py for explanation on why max_context_length * 2
         super().__init__(
-            max_batch_size, max_context_length * 2, n_heads, head_dim, dtype
+            max_batch_size, max_context_length * 2, n_heads, head_dim, dtype, is_seq_at_dim_2
        )
         self.cache_positions_manager = CachePositionsManager(self.max_context_length)
         self.is_ring_buffer = True
@@ -500,7 +547,10 @@ def update(self, input_pos, k_val, v_val):
         # 1. kv cache is stored as [B, S, H, D]
         # 2. If seq_len = k_val.size(2), we wont be able be able to optimize
         # away transpose at the output of k, v projection
-        seq_len = k_val.transpose(1, 2).size(1)
+        if not self.is_seq_at_dim_2:
+            seq_len = k_val.transpose(1, 2).size(1)
+        else:
+            seq_len = k_val.size(2)
         assert seq_len <= self.k_cache.size(
             1
         ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})"
@@ -517,16 +567,21 @@ def from_custom_kv_cache(
         kv_cache,
         sliding_window_size,
     ):
-        max_batch_size, n_heads, _, head_dim = kv_cache.k_cache.shape
+        max_batch_size = kv_cache.max_batch_size
+        n_heads = kv_cache.n_heads
+        head_dim = kv_cache.head_dim
         if isinstance(kv_cache, CustomKVCache):
             # If replacing custom kv cache, then the shape is [B, S, H, D]
-            max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
+            max_batch_size = kv_cache.max_batch_size
+            n_heads = kv_cache.n_heads
+            head_dim = kv_cache.head_dim
         return cls(
             max_batch_size,
             sliding_window_size,
             n_heads,
             head_dim,
             dtype=kv_cache.k_cache.dtype,
+            is_seq_at_dim_2=kv_cache.is_seq_at_dim_2,
         )
 
 
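The `from_*` converters now carry the layout flag from the source cache to the new one. A hypothetical, self-contained sketch of that propagation pattern (class names are illustrative, not the ones in this file):

# Hypothetical sketch of the flag-propagation pattern used by the from_* helpers above.
class SourceCache:
    def __init__(self):
        self.max_batch_size, self.n_heads, self.head_dim = 1, 8, 64
        self.is_seq_at_dim_2 = True

class RingCacheLike:
    def __init__(self, max_batch_size, sliding_window_size, n_heads, head_dim, is_seq_at_dim_2=False):
        self.sliding_window_size = sliding_window_size
        self.is_seq_at_dim_2 = is_seq_at_dim_2

    @classmethod
    def from_cache(cls, kv_cache, sliding_window_size):
        return cls(
            kv_cache.max_batch_size,
            sliding_window_size,
            kv_cache.n_heads,
            kv_cache.head_dim,
            is_seq_at_dim_2=kv_cache.is_seq_at_dim_2,
        )

ring = RingCacheLike.from_cache(SourceCache(), sliding_window_size=256)
assert ring.is_seq_at_dim_2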