Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 106 additions & 51 deletions examples/models/llama/source_transformation/custom_kv_cache.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -44,6 +44,7 @@
cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
use_custom_update_cache_op: bool = False,
return_float_values: bool = True,
is_seq_at_dim_2: bool = False,
):
super().__init__()
if cache_type not in (
Expand All @@ -55,13 +56,21 @@
)

# For now supporting int8 only
self.is_seq_at_dim_2 = is_seq_at_dim_2
self.use_custom_update_cache_op = use_custom_update_cache_op
self.quantized_cache_dtype = torch.int8
self.cache_fp_type = torch.float32
self.return_float_values = return_float_values
self.max_context_length = max_context_length
cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
scale_shape = (max_batch_size, max_context_length, n_heads, 1)
self.max_batch_size = max_batch_size
self.n_heads = n_heads
self.head_dim = head_dim
if not self.is_seq_at_dim_2:
cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
scale_shape = (max_batch_size, max_context_length, n_heads, 1)
else:
cache_shape = (max_batch_size, n_heads, max_context_length, head_dim)
scale_shape = (max_batch_size, n_heads, max_context_length, 1)
self.register_buffer(
"k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
)
Expand Down Expand Up @@ -113,52 +122,60 @@
start_pos = input_pos[0].item()
if indices is not None:
_ = torch.ops.llama.update_cache_with_indices(
quantized_k_val, self.k_cache, start_pos, indices
quantized_k_val, self.k_cache, start_pos, indices, self.is_seq_at_dim_2
)
_ = torch.ops.llama.update_cache_with_indices(
k_scales, self.k_cache_scales, start_pos, indices
k_scales, self.k_cache_scales, start_pos, indices, self.is_seq_at_dim_2
)
_ = torch.ops.llama.update_cache_with_indices(
k_zero_points, self.k_cache_zero_points, start_pos, indices
k_zero_points, self.k_cache_zero_points, start_pos, indices, self.is_seq_at_dim_2
)
_ = torch.ops.llama.update_cache_with_indices(
quantized_v_val, self.v_cache, start_pos, indices
quantized_v_val, self.v_cache, start_pos, indices, self.is_seq_at_dim_2
)
_ = torch.ops.llama.update_cache_with_indices(
v_scales, self.v_cache_scales, start_pos, indices
v_scales, self.v_cache_scales, start_pos, indices, self.is_seq_at_dim_2
)
_ = torch.ops.llama.update_cache_with_indices(
v_zero_points, self.v_cache_zero_points, start_pos, indices
v_zero_points, self.v_cache_zero_points, start_pos, indices, self.is_seq_at_dim_2
)
else:
_ = torch.ops.llama.update_cache(
quantized_k_val, self.k_cache, start_pos
quantized_k_val, self.k_cache, start_pos, self.is_seq_at_dim_2
)
_ = torch.ops.llama.update_cache(
k_scales, self.k_cache_scales, start_pos
k_scales, self.k_cache_scales, start_pos, self.is_seq_at_dim_2
)
_ = torch.ops.llama.update_cache(
k_zero_points, self.k_cache_zero_points, start_pos
k_zero_points, self.k_cache_zero_points, start_pos, self.is_seq_at_dim_2
)
_ = torch.ops.llama.update_cache(
quantized_v_val, self.v_cache, start_pos
quantized_v_val, self.v_cache, start_pos, self.is_seq_at_dim_2
)
_ = torch.ops.llama.update_cache(
v_scales, self.v_cache_scales, start_pos
v_scales, self.v_cache_scales, start_pos, self.is_seq_at_dim_2
)
_ = torch.ops.llama.update_cache(
v_zero_points, self.v_cache_zero_points, start_pos
v_zero_points, self.v_cache_zero_points, start_pos, self.is_seq_at_dim_2
)
else:
assert indices is None, "Indices not supported for this path"
# Following is also broken because in prefill input_pos = [0]
# but we need to update some slice of cache
self.k_cache[:, input_pos] = quantized_k_val
self.k_cache_scales[:, input_pos] = k_scales
self.k_cache_zero_points[:, input_pos] = k_zero_points
self.v_cache[:, input_pos] = quantized_v_val
self.v_cache_scales[:, input_pos] = v_scales
self.v_cache_zero_points[:, input_pos] = v_zero_points
if self.is_seq_at_dim_2:
self.k_cache[:, :, input_pos] = quantized_k_val
self.k_cache_scales[:, :, input_pos] = k_scales
self.k_cache_zero_points[:, :, input_pos] = k_zero_points
self.v_cache[:, :, input_pos] = quantized_v_val
self.v_cache_scales[:, :, input_pos] = v_scales
self.v_cache_zero_points[:, :, input_pos] = v_zero_points
else:
self.k_cache[:, input_pos] = quantized_k_val
self.k_cache_scales[:, input_pos] = k_scales
self.k_cache_zero_points[:, input_pos] = k_zero_points
self.v_cache[:, input_pos] = quantized_v_val
self.v_cache_scales[:, input_pos] = v_scales
self.v_cache_zero_points[:, input_pos] = v_zero_points

def _update_and_return_float_values(self, input_pos, k_val, v_val, indices=None):
self._quantize_and_update(input_pos, k_val, v_val, indices)
Expand Down Expand Up @@ -188,17 +205,21 @@
if self.use_custom_update_cache_op:
if indices is not None:
_ = torch.ops.llama.update_cache_with_indices(
k_val, k_out, start_pos, indices
k_val, k_out, start_pos, indices, self.is_seq_at_dim_2
)
_ = torch.ops.llama.update_cache_with_indices(
v_val, v_out, start_pos, indices
v_val, v_out, start_pos, indices, self.is_seq_at_dim_2
)
else:
_ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
_ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
_ = torch.ops.llama.update_cache(k_val, k_out, start_pos, self.is_seq_at_dim_2)
_ = torch.ops.llama.update_cache(v_val, v_out, start_pos, self.is_seq_at_dim_2)
else:
k_out[:, input_pos] = k_val
v_out[:, input_pos] = v_val
if self.is_seq_at_dim_2:
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val
else:
k_out[:, input_pos] = k_val
v_out[:, input_pos] = v_val

return k_out, v_out

Expand All @@ -217,8 +238,9 @@
This shall be removed by subsequent post-export graph pass
"""

k_val = k_val.transpose(1, 2)
v_val = v_val.transpose(1, 2)
if not self.is_seq_at_dim_2:
k_val = k_val.transpose(1, 2)
v_val = v_val.transpose(1, 2)

if self.return_float_values:
k_out, v_out = self._update_and_return_float_values(
Expand All @@ -228,28 +250,34 @@
k_out, v_out = self._update_and_return_quantized_values(
input_pos, k_val, v_val, indices
)
return k_out.transpose(1, 2), v_out.transpose(1, 2)
if not self.is_seq_at_dim_2:
return k_out.transpose(1, 2), v_out.transpose(1, 2)
else:
return k_out, v_out

@classmethod
def from_float(
cls,
kv_cache,
cache_type: QuantizedCacheType,
use_custom_update_cache_op: bool = False,
is_seq_at_dim_2: bool = False,
):
max_batch_size, n_heads, max_context_length, head_dim = kv_cache.k_cache.shape
if isinstance(kv_cache, CustomKVCache):
# If replacing custom kv cache, then the shape is [B, S, H, D]
max_batch_size, max_context_length, n_heads, head_dim = (
kv_cache.k_cache.shape
)
max_batch_size = kv_cache.max_batch_size
n_heads = kv_cache.n_heads
max_context_length = kv_cache.max_context_length
head_dim = kv_cache.head_dim
return cls(
max_batch_size,
max_context_length,
n_heads,
head_dim,
cache_type,
use_custom_update_cache_op,
is_seq_at_dim_2=is_seq_at_dim_2,
)


Expand Down Expand Up @@ -312,10 +340,15 @@
n_heads: int,
head_dim: int,
dtype=torch.float32,
is_seq_at_dim_2: bool = False,
):
self.is_seq_at_dim_2 = is_seq_at_dim_2
super().__init__()
self.max_context_length = max_context_length
cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
if self.is_seq_at_dim_2:
cache_shape = (max_batch_size, n_heads, max_context_length, head_dim)
else:
cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)

self.max_batch_size = max_batch_size
self.n_heads = n_heads
Expand All @@ -335,25 +368,26 @@
indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [S], k_val: [B, H, S, D]
k_val = k_val.transpose(1, 2)
v_val = v_val.transpose(1, 2)
if not self.is_seq_at_dim_2:
k_val = k_val.transpose(1, 2)
v_val = v_val.transpose(1, 2)
start_pos = input_pos[0].item()

if indices is not None:
_ = torch.ops.llama.update_cache_with_indices(
k_val, self.k_cache, start_pos, indices
k_val, self.k_cache, start_pos, indices, self.is_seq_at_dim_2
)
_ = torch.ops.llama.update_cache_with_indices(
v_val, self.v_cache, start_pos, indices
v_val, self.v_cache, start_pos, indices, self.is_seq_at_dim_2
)
else:
_ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
_ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
_ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos, self.is_seq_at_dim_2)
_ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos, self.is_seq_at_dim_2)

return (
self.k_cache.transpose(1, 2),
self.v_cache.transpose(1, 2),
)
if not self.is_seq_at_dim_2:
return (k_val.transpose(1, 2), v_val.transpose(1, 2))
else:
return (self.k_cache, self.v_cache)


def replace_kv_cache_with_custom_kv_cache(module):
Expand All @@ -373,9 +407,11 @@
def _replace_kv_cache_with_custom_kv_cache(module):
for name, child in module.named_children():
if isinstance(child, KVCache):
cache_shape = child.k_cache.shape
cache_dtype = child.k_cache.dtype
max_batch_size, n_heads, max_context_length, head_dim = cache_shape
max_batch_size = child.max_batch_size
n_heads = child.n_heads
max_context_length = child.max_context_length
head_dim = child.head_dim
setattr(
module,
name,
Expand All @@ -402,6 +438,7 @@
cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
use_custom_update_cache_op: bool = False,
return_float_values: bool = True,
is_seq_at_dim_2: bool = False,
):
# Look at attention.py for explanation on why max_context_length * 2
super().__init__(
Expand All @@ -412,9 +449,11 @@
cache_type,
use_custom_update_cache_op,
return_float_values,
is_seq_at_dim_2,
)
self.cache_positions_manager = CachePositionsManager(self.max_context_length)
self.is_ring_buffer = True
self.is_seq_at_dim_2 = is_seq_at_dim_2
self.window_size = max_context_length

def create_causal_mask_for_ring_buffer(self, start_pos, seq_len):
Expand All @@ -434,7 +473,10 @@
# 1. kv cache is stored as [B, S, H, D]
# 2. If seq_len = k_val.size(2), we won't be able to optimize
# away transpose at the output of k, v projection
seq_len = k_val.transpose(1, 2).size(1)
if not self.is_seq_at_dim_2:
seq_len = k_val.transpose(1, 2).size(1)
else:
seq_len = k_val.size(2)
assert seq_len <= self.k_cache.size(
1
), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})"
Expand All @@ -454,7 +496,9 @@
assert isinstance(
kv_cache, QuantizedKVCache
), "For QuantizedRingKVCache expect QuantizedKVCache as input kv_cache"
max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
max_batch_size = kv_cache.max_batch_size
n_heads = kv_cache.n_heads
head_dim = kv_cache.head_dim
return cls(
max_batch_size,
sliding_window_size,
Expand All @@ -463,6 +507,8 @@
kv_cache.cache_type,
kv_cache.use_custom_update_cache_op,
kv_cache.return_float_values,
kv_cache.is_seq_at_dim_2,
is_seq_at_dim_2=kv_cache.is_seq_at_dim_2,
)


Expand All @@ -474,10 +520,11 @@
n_heads,
head_dim,
dtype=torch.float32,
is_seq_at_dim_2: bool = False,
):
# Look at attention.py for explanation on why max_context_length * 2
super().__init__(
max_batch_size, max_context_length * 2, n_heads, head_dim, dtype
max_batch_size, max_context_length * 2, n_heads, head_dim, dtype, is_seq_at_dim_2
)
self.cache_positions_manager = CachePositionsManager(self.max_context_length)
self.is_ring_buffer = True
Expand All @@ -500,7 +547,10 @@
# 1. kv cache is stored as [B, S, H, D]
# 2. If seq_len = k_val.size(2), we won't be able to optimize
# away transpose at the output of k, v projection
seq_len = k_val.transpose(1, 2).size(1)
if not self.is_seq_at_dim_2:
seq_len = k_val.transpose(1, 2).size(1)
else:
seq_len = k_val.size(2)
assert seq_len <= self.k_cache.size(
1
), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})"
Expand All @@ -517,16 +567,21 @@
kv_cache,
sliding_window_size,
):
max_batch_size, n_heads, _, head_dim = kv_cache.k_cache.shape
max_batch_size = kv_cache.max_batch_size
n_heads = kv_cache.n_heads
head_dim = kv_cache.head_dim
if isinstance(kv_cache, CustomKVCache):
# If replacing custom kv cache, then the shape is [B, S, H, D]
max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape
max_batch_size = kv_cache.max_batch_size
n_heads = kv_cache.n_heads
head_dim = kv_cache.head_dim
return cls(
max_batch_size,
sliding_window_size,
n_heads,
head_dim,
dtype=kv_cache.k_cache.dtype,
is_seq_at_dim_2=kv_cache.is_seq_at_dim_2,
)


Expand Down
Loading
Loading