diff --git a/examples/models/llama/source_transformation/custom_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py index 8d4d37e0e93..d66b9e2bb94 100644 --- a/examples/models/llama/source_transformation/custom_kv_cache.py +++ b/examples/models/llama/source_transformation/custom_kv_cache.py @@ -44,6 +44,7 @@ def __init__( cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric, use_custom_update_cache_op: bool = False, return_float_values: bool = True, + is_seq_at_dim_2: bool = False, ): super().__init__() if cache_type not in ( @@ -55,13 +56,21 @@ def __init__( ) # For now supporting int8 only + self.is_seq_at_dim_2 = is_seq_at_dim_2 self.use_custom_update_cache_op = use_custom_update_cache_op self.quantized_cache_dtype = torch.int8 self.cache_fp_type = torch.float32 self.return_float_values = return_float_values self.max_context_length = max_context_length - cache_shape = (max_batch_size, max_context_length, n_heads, head_dim) - scale_shape = (max_batch_size, max_context_length, n_heads, 1) + self.max_batch_size = max_batch_size + self.n_heads = n_heads + self.head_dim = head_dim + if not self.is_seq_at_dim_2: + cache_shape = (max_batch_size, max_context_length, n_heads, head_dim) + scale_shape = (max_batch_size, max_context_length, n_heads, 1) + else: + cache_shape = (max_batch_size, n_heads, max_context_length, head_dim) + scale_shape = (max_batch_size, n_heads, max_context_length, 1) self.register_buffer( "k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype) ) @@ -113,52 +122,60 @@ def _quantize_and_update(self, input_pos, k_val, v_val, indices=None): start_pos = input_pos[0].item() if indices is not None: _ = torch.ops.llama.update_cache_with_indices( - quantized_k_val, self.k_cache, start_pos, indices + quantized_k_val, self.k_cache, start_pos, indices, self.is_seq_at_dim_2 ) _ = torch.ops.llama.update_cache_with_indices( - k_scales, self.k_cache_scales, start_pos, indices + k_scales, self.k_cache_scales, 
start_pos, indices, self.is_seq_at_dim_2 ) _ = torch.ops.llama.update_cache_with_indices( - k_zero_points, self.k_cache_zero_points, start_pos, indices + k_zero_points, self.k_cache_zero_points, start_pos, indices, self.is_seq_at_dim_2 ) _ = torch.ops.llama.update_cache_with_indices( - quantized_v_val, self.v_cache, start_pos, indices + quantized_v_val, self.v_cache, start_pos, indices, self.is_seq_at_dim_2 ) _ = torch.ops.llama.update_cache_with_indices( - v_scales, self.v_cache_scales, start_pos, indices + v_scales, self.v_cache_scales, start_pos, indices, self.is_seq_at_dim_2 ) _ = torch.ops.llama.update_cache_with_indices( - v_zero_points, self.v_cache_zero_points, start_pos, indices + v_zero_points, self.v_cache_zero_points, start_pos, indices, self.is_seq_at_dim_2 ) else: _ = torch.ops.llama.update_cache( - quantized_k_val, self.k_cache, start_pos + quantized_k_val, self.k_cache, start_pos, self.is_seq_at_dim_2 ) _ = torch.ops.llama.update_cache( - k_scales, self.k_cache_scales, start_pos + k_scales, self.k_cache_scales, start_pos, self.is_seq_at_dim_2 ) _ = torch.ops.llama.update_cache( - k_zero_points, self.k_cache_zero_points, start_pos + k_zero_points, self.k_cache_zero_points, start_pos, self.is_seq_at_dim_2 ) _ = torch.ops.llama.update_cache( - quantized_v_val, self.v_cache, start_pos + quantized_v_val, self.v_cache, start_pos, self.is_seq_at_dim_2 ) _ = torch.ops.llama.update_cache( - v_scales, self.v_cache_scales, start_pos + v_scales, self.v_cache_scales, start_pos, self.is_seq_at_dim_2 ) _ = torch.ops.llama.update_cache( - v_zero_points, self.v_cache_zero_points, start_pos + v_zero_points, self.v_cache_zero_points, start_pos, self.is_seq_at_dim_2 ) else: assert indices is None, "Indices not supported for this path" # Following is also broken because in prefill input_pos = [0] # but we need to update some slice of cache - self.k_cache[:, input_pos] = quantized_k_val - self.k_cache_scales[:, input_pos] = k_scales - self.k_cache_zero_points[:, 
input_pos] = k_zero_points - self.v_cache[:, input_pos] = quantized_v_val - self.v_cache_scales[:, input_pos] = v_scales - self.v_cache_zero_points[:, input_pos] = v_zero_points + if self.is_seq_at_dim_2: + self.k_cache[:, :, input_pos] = quantized_k_val + self.k_cache_scales[:, :, input_pos] = k_scales + self.k_cache_zero_points[:, :, input_pos] = k_zero_points + self.v_cache[:, :, input_pos] = quantized_v_val + self.v_cache_scales[:, :, input_pos] = v_scales + self.v_cache_zero_points[:, :, input_pos] = v_zero_points + else: + self.k_cache[:, input_pos] = quantized_k_val + self.k_cache_scales[:, input_pos] = k_scales + self.k_cache_zero_points[:, input_pos] = k_zero_points + self.v_cache[:, input_pos] = quantized_v_val + self.v_cache_scales[:, input_pos] = v_scales + self.v_cache_zero_points[:, input_pos] = v_zero_points def _update_and_return_float_values(self, input_pos, k_val, v_val, indices=None): self._quantize_and_update(input_pos, k_val, v_val, indices) @@ -188,17 +205,21 @@ def _update_and_return_float_values(self, input_pos, k_val, v_val, indices=None) if self.use_custom_update_cache_op: if indices is not None: _ = torch.ops.llama.update_cache_with_indices( - k_val, k_out, start_pos, indices + k_val, k_out, start_pos, indices, self.is_seq_at_dim_2 ) _ = torch.ops.llama.update_cache_with_indices( - v_val, v_out, start_pos, indices + v_val, v_out, start_pos, indices, self.is_seq_at_dim_2 ) else: - _ = torch.ops.llama.update_cache(k_val, k_out, start_pos) - _ = torch.ops.llama.update_cache(v_val, v_out, start_pos) + _ = torch.ops.llama.update_cache(k_val, k_out, start_pos, self.is_seq_at_dim_2) + _ = torch.ops.llama.update_cache(v_val, v_out, start_pos, self.is_seq_at_dim_2) else: - k_out[:, input_pos] = k_val - v_out[:, input_pos] = v_val + if self.is_seq_at_dim_2: + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + else: + k_out[:, input_pos] = k_val + v_out[:, input_pos] = v_val return k_out, v_out @@ -217,8 +238,9 @@ def update(self, 
input_pos, k_val, v_val, indices=None): This shall be removed by subsequent post-export graph pass """ - k_val = k_val.transpose(1, 2) - v_val = v_val.transpose(1, 2) + if not self.is_seq_at_dim_2: + k_val = k_val.transpose(1, 2) + v_val = v_val.transpose(1, 2) if self.return_float_values: k_out, v_out = self._update_and_return_float_values( @@ -228,7 +250,10 @@ def update(self, input_pos, k_val, v_val, indices=None): k_out, v_out = self._update_and_return_quantized_values( input_pos, k_val, v_val, indices ) - return k_out.transpose(1, 2), v_out.transpose(1, 2) + if not self.is_seq_at_dim_2: + return k_out.transpose(1, 2), v_out.transpose(1, 2) + else: + return k_out, v_out @classmethod def from_float( @@ -236,13 +261,15 @@ def from_float( kv_cache, cache_type: QuantizedCacheType, use_custom_update_cache_op: bool = False, + is_seq_at_dim_2: bool = False, ): max_batch_size, n_heads, max_context_length, head_dim = kv_cache.k_cache.shape if isinstance(kv_cache, CustomKVCache): # If replacing custom kv cache, then the shape is [B, S, H, D] - max_batch_size, max_context_length, n_heads, head_dim = ( - kv_cache.k_cache.shape - ) + max_batch_size = kv_cache.max_batch_size + n_heads = kv_cache.n_heads + max_context_length = kv_cache.max_context_length + head_dim = kv_cache.head_dim return cls( max_batch_size, max_context_length, @@ -250,6 +277,7 @@ def from_float( head_dim, cache_type, use_custom_update_cache_op, + is_seq_at_dim_2=is_seq_at_dim_2, ) @@ -312,10 +340,15 @@ def __init__( n_heads: int, head_dim: int, dtype=torch.float32, + is_seq_at_dim_2: bool = False, ): + self.is_seq_at_dim_2 = is_seq_at_dim_2 super().__init__() self.max_context_length = max_context_length - cache_shape = (max_batch_size, max_context_length, n_heads, head_dim) + if self.is_seq_at_dim_2: + cache_shape = (max_batch_size, n_heads, max_context_length, head_dim) + else: + cache_shape = (max_batch_size, max_context_length, n_heads, head_dim) self.max_batch_size = max_batch_size self.n_heads = 
n_heads @@ -335,25 +368,26 @@ def update( indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # input_pos: [S], k_val: [B, H, S, D] - k_val = k_val.transpose(1, 2) - v_val = v_val.transpose(1, 2) + if not self.is_seq_at_dim_2: + k_val = k_val.transpose(1, 2) + v_val = v_val.transpose(1, 2) start_pos = input_pos[0].item() if indices is not None: _ = torch.ops.llama.update_cache_with_indices( - k_val, self.k_cache, start_pos, indices + k_val, self.k_cache, start_pos, indices, self.is_seq_at_dim_2 ) _ = torch.ops.llama.update_cache_with_indices( - v_val, self.v_cache, start_pos, indices + v_val, self.v_cache, start_pos, indices, self.is_seq_at_dim_2 ) else: - _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos) - _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos) + _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos, self.is_seq_at_dim_2) + _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos, self.is_seq_at_dim_2) - return ( - self.k_cache.transpose(1, 2), - self.v_cache.transpose(1, 2), - ) + if not self.is_seq_at_dim_2: + return (self.k_cache.transpose(1, 2), self.v_cache.transpose(1, 2)) + else: + return (self.k_cache, self.v_cache) def replace_kv_cache_with_custom_kv_cache(module): @@ -373,9 +407,11 @@ def replace_kv_cache_with_custom_kv_cache(module): def _replace_kv_cache_with_custom_kv_cache(module): for name, child in module.named_children(): if isinstance(child, KVCache): - cache_shape = child.k_cache.shape cache_dtype = child.k_cache.dtype - max_batch_size, n_heads, max_context_length, head_dim = cache_shape + max_batch_size = child.max_batch_size + n_heads = child.n_heads + max_context_length = child.max_context_length + head_dim = child.head_dim setattr( module, name, @@ -402,6 +438,7 @@ def __init__( cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric, use_custom_update_cache_op: bool = False, return_float_values: bool = True, + is_seq_at_dim_2: bool = False, ): # Look at 
attention.py for explanation on why max_context_length * 2 super().__init__( @@ -412,9 +449,11 @@ def __init__( cache_type, use_custom_update_cache_op, return_float_values, + is_seq_at_dim_2, ) self.cache_positions_manager = CachePositionsManager(self.max_context_length) self.is_ring_buffer = True + self.is_seq_at_dim_2 = is_seq_at_dim_2 self.window_size = max_context_length def create_causal_mask_for_ring_buffer(self, start_pos, seq_len): @@ -434,7 +473,10 @@ def update(self, input_pos, k_val, v_val): # 1. kv cache is stored as [B, S, H, D] # 2. If seq_len = k_val.size(2), we wont be able be able to optimize # away transpose at the output of k, v projection - seq_len = k_val.transpose(1, 2).size(1) + if not self.is_seq_at_dim_2: + seq_len = k_val.transpose(1, 2).size(1) + else: + seq_len = k_val.size(2) assert seq_len <= self.k_cache.size( 1 ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})" @@ -454,7 +496,9 @@ def from_quantized_kv_cache( assert isinstance( kv_cache, QuantizedKVCache ), "For QuantizedRingKVCache expect QuantizedKVCache as input kv_cache" - max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape + max_batch_size = kv_cache.max_batch_size + n_heads = kv_cache.n_heads + head_dim = kv_cache.head_dim return cls( max_batch_size, sliding_window_size, @@ -463,6 +507,7 @@ def from_quantized_kv_cache( kv_cache.cache_type, kv_cache.use_custom_update_cache_op, kv_cache.return_float_values, + is_seq_at_dim_2=kv_cache.is_seq_at_dim_2, ) @@ -474,10 +519,11 @@ def __init__( n_heads, head_dim, dtype=torch.float32, + is_seq_at_dim_2: bool = False, ): # Look at attention.py for explanation on why max_context_length * 2 super().__init__( - max_batch_size, max_context_length * 2, n_heads, head_dim, dtype + max_batch_size, max_context_length * 2, n_heads, head_dim, dtype, is_seq_at_dim_2 ) self.cache_positions_manager = CachePositionsManager(self.max_context_length) 
self.is_ring_buffer = True @@ -500,7 +547,10 @@ def update(self, input_pos, k_val, v_val): # 1. kv cache is stored as [B, S, H, D] # 2. If seq_len = k_val.size(2), we wont be able be able to optimize # away transpose at the output of k, v projection - seq_len = k_val.transpose(1, 2).size(1) + if not self.is_seq_at_dim_2: + seq_len = k_val.transpose(1, 2).size(1) + else: + seq_len = k_val.size(2) assert seq_len <= self.k_cache.size( 1 ), f"Update sequence length({seq_len}) for kv cache must be smaller than the cache size({self.k_cache.size(2)})" @@ -517,16 +567,21 @@ def from_custom_kv_cache( kv_cache, sliding_window_size, ): - max_batch_size, n_heads, _, head_dim = kv_cache.k_cache.shape + max_batch_size = kv_cache.max_batch_size + n_heads = kv_cache.n_heads + head_dim = kv_cache.head_dim if isinstance(kv_cache, CustomKVCache): # If replacing custom kv cache, then the shape is [B, S, H, D] - max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape + max_batch_size = kv_cache.max_batch_size + n_heads = kv_cache.n_heads + head_dim = kv_cache.head_dim return cls( max_batch_size, sliding_window_size, n_heads, head_dim, dtype=kv_cache.k_cache.dtype, + is_seq_at_dim_2=kv_cache.is_seq_at_dim_2, ) diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index 2e108b2ec19..5217a103ce4 100644 --- a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -23,10 +23,12 @@ def __init__( self, dim: int, use_attention_mask: bool = False, + is_seq_at_dim_2: bool = False, ): super().__init__() self.dim = dim self.use_attention_mask = use_attention_mask + self.is_seq_at_dim_2 = is_seq_at_dim_2 def forward( self, @@ -38,9 +40,10 @@ def forward( seqlen, mask, ): - q = q.transpose(1, 2) # (bs, seqlen, n_local_heads, head_dim) - k = k.transpose(1, 2) - v = v.transpose(1, 2) + if not self.is_seq_at_dim_2: + q = q.transpose(1, 2) # (bs, seqlen, n_local_heads, 
head_dim) + k = k.transpose(1, 2) + v = v.transpose(1, 2) # Custom op only supports float32 currently. Converting to/from float32 is # faster than not having the op. @@ -58,6 +61,8 @@ def forward( mask, # Attention mask 0, # dropout probability. Ignored by the code False, # is_causal + scale=None, + is_seq_dim_2=self.is_seq_at_dim_2, ) else: output = torch.ops.llama.custom_sdpa( @@ -68,6 +73,8 @@ def forward( None, # Attention mask 0, # dropout probability. Ignored by the code True, # is_causal + scale=None, + is_seq_dim_2=self.is_seq_at_dim_2, ) if self.is_seq_at_dim_2: output = output.transpose(1, 2).contiguous() @@ -120,7 +127,7 @@ class QuantizedSDPA(torch.nn.Module): """ def __init__( - self, dim: int, kv_cache: QuantizedKVCache, use_attention_mask: bool = False + self, dim: int, kv_cache: QuantizedKVCache, use_attention_mask: bool = False, is_seq_at_dim_2: bool = False ): super().__init__() self.dim = dim @@ -128,6 +135,7 @@ def __init__( self.float_dtype = torch.float32 self.kv_cache = kv_cache self.use_attention_mask = use_attention_mask + self.is_seq_at_dim_2 = is_seq_at_dim_2 def forward( self, @@ -139,9 +147,10 @@ def forward( seqlen, mask, ): - q = q.transpose(1, 2) # (bs, seqlen, n_local_heads, head_dim) - k_quantized = k_quantized.transpose(1, 2) - v_quantized = v_quantized.transpose(1, 2) + if not self.is_seq_at_dim_2: + q = q.transpose(1, 2) # (bs, seqlen, n_local_heads, head_dim) + k_quantized = k_quantized.transpose(1, 2) + v_quantized = v_quantized.transpose(1, 2) q_scale, q_zero_point = ( torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default( @@ -181,6 +190,7 @@ def forward( k_scale_fp32, v_zero_point_int8, v_scale_fp32, + self.is_seq_at_dim_2, ) else: output = torch.ops.llama.custom_quantized_sdpa( @@ -198,6 +208,7 @@ def forward( k_scale_fp32, v_zero_point_int8, v_scale_fp32, + self.is_seq_at_dim_2, ) if self.is_seq_at_dim_2: @@ -210,9 +221,9 @@ def _update_attention_module_with_quantized_sdpa( ): sdpa = getattr(module, 
"SDPA", None) assert sdpa is not None + assert isinstance(sdpa, SDPACustom) # TODO: add support for SDPA with attention mask - # pyre-ignore - setattr(module, "SDPA", QuantizedSDPA(sdpa.dim, kv_cache)) # noqa: B010 + setattr(module, "SDPA", QuantizedSDPA(sdpa.dim, kv_cache, is_seq_at_dim_2=sdpa.is_seq_at_dim_2)) # noqa: B010 def _replace_sdpa_with_quantized_sdpa(module: torch.nn.Module):