|
| 1 | +# Copyright 2025 The HuggingFace Team. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from dataclasses import dataclass |
| 16 | + |
| 17 | +import torch |
| 18 | + |
| 19 | +from .hooks import BaseState, HookRegistry, ModelHook, StateManager |
| 20 | + |
| 21 | + |
| 22 | +_TEXT_KV_CACHE_TRANSFORMER_HOOK = "text_kv_cache_transformer" |
| 23 | +_TEXT_KV_CACHE_BLOCK_HOOK = "text_kv_cache_block" |
| 24 | + |
| 25 | + |
| 26 | +@dataclass |
| 27 | +class TextKVCacheConfig: |
| 28 | + """Enable exact (lossless) text K/V caching for transformer models. |
| 29 | +
|
| 30 | + Pre-computes per-block text key and value projections once before the denoising loop and reuses them across all |
| 31 | + steps. Positive and negative prompts are distinguished via a stable cache key captured by a transformer-level hook |
| 32 | + before any intermediate tensor allocations. |
| 33 | + """ |
| 34 | + |
| 35 | + pass |
| 36 | + |
| 37 | + |
| 38 | +class TextKVCacheState(BaseState): |
| 39 | + """Shared state between the transformer-level and block-level hooks. |
| 40 | +
|
| 41 | + The transformer hook writes the stable ``encoder_hidden_states`` ``data_ptr()`` (captured *before* ``txt_norm``) so |
| 42 | + that block hooks can use it as a reliable cache key across denoising steps. |
| 43 | + """ |
| 44 | + |
| 45 | + def __init__(self): |
| 46 | + self.key: int | None = None |
| 47 | + |
| 48 | + def reset(self): |
| 49 | + self.key = None |
| 50 | + |
| 51 | + |
| 52 | +class TextKVCacheBlockState(BaseState): |
| 53 | + """Per-block state holding cached text key/value projections.""" |
| 54 | + |
| 55 | + def __init__(self): |
| 56 | + self.kv_cache: dict[int, tuple[torch.Tensor, torch.Tensor]] = {} |
| 57 | + |
| 58 | + def reset(self): |
| 59 | + self.kv_cache.clear() |
| 60 | + |
| 61 | + |
| 62 | +class TextKVCacheTransformerHook(ModelHook): |
| 63 | + """Captures ``encoder_hidden_states.data_ptr()`` before ``txt_norm`` |
| 64 | + and writes it to shared state for the block hooks to read.""" |
| 65 | + |
| 66 | + _is_stateful = True |
| 67 | + |
| 68 | + def __init__(self, state_manager: StateManager): |
| 69 | + super().__init__() |
| 70 | + self.state_manager = state_manager |
| 71 | + |
| 72 | + def new_forward(self, module: torch.nn.Module, *args, **kwargs): |
| 73 | + if self.state_manager._current_context is None: |
| 74 | + self.state_manager.set_context("inference") |
| 75 | + |
| 76 | + encoder_hidden_states = kwargs.get("encoder_hidden_states") |
| 77 | + if encoder_hidden_states is not None: |
| 78 | + state: TextKVCacheState = self.state_manager.get_state() |
| 79 | + state.key = encoder_hidden_states.data_ptr() |
| 80 | + return self.fn_ref.original_forward(*args, **kwargs) |
| 81 | + |
| 82 | + def reset_state(self, module: torch.nn.Module): |
| 83 | + self.state_manager.reset() |
| 84 | + return module |
| 85 | + |
| 86 | + |
| 87 | +class TextKVCacheBlockHook(ModelHook): |
| 88 | + """Caches ``(txt_key, txt_value)`` per block per unique prompt using |
| 89 | + the stable cache key from the shared state.""" |
| 90 | + |
| 91 | + _is_stateful = True |
| 92 | + |
| 93 | + def __init__(self, state_manager: StateManager, block_state_manager: StateManager): |
| 94 | + super().__init__() |
| 95 | + self.state_manager = state_manager |
| 96 | + self.block_state_manager = block_state_manager |
| 97 | + |
| 98 | + def new_forward(self, module: torch.nn.Module, *args, **kwargs): |
| 99 | + from ..models.transformers.transformer_nucleusmoe_image import _apply_rotary_emb_nucleus |
| 100 | + |
| 101 | + if self.state_manager._current_context is None: |
| 102 | + self.state_manager.set_context("inference") |
| 103 | + |
| 104 | + if self.block_state_manager._current_context is None: |
| 105 | + self.block_state_manager.set_context("inference") |
| 106 | + |
| 107 | + if "encoder_hidden_states" in kwargs: |
| 108 | + encoder_hidden_states = kwargs["encoder_hidden_states"] |
| 109 | + else: |
| 110 | + encoder_hidden_states = args[1] |
| 111 | + |
| 112 | + if "image_rotary_emb" in kwargs: |
| 113 | + image_rotary_emb = kwargs["image_rotary_emb"] |
| 114 | + elif len(args) > 3: |
| 115 | + image_rotary_emb = args[3] |
| 116 | + else: |
| 117 | + image_rotary_emb = None |
| 118 | + |
| 119 | + state: TextKVCacheState = self.state_manager.get_state() |
| 120 | + cache_key = state.key |
| 121 | + |
| 122 | + block_state: TextKVCacheBlockState = self.block_state_manager.get_state() |
| 123 | + |
| 124 | + if cache_key not in block_state.kv_cache: |
| 125 | + context = module.encoder_proj(encoder_hidden_states) |
| 126 | + |
| 127 | + attn = module.attn |
| 128 | + head_dim = attn.inner_dim // attn.heads |
| 129 | + num_kv_heads = attn.inner_kv_dim // head_dim |
| 130 | + |
| 131 | + txt_key = attn.add_k_proj(context).unflatten(-1, (num_kv_heads, -1)) |
| 132 | + txt_value = attn.add_v_proj(context).unflatten(-1, (num_kv_heads, -1)) |
| 133 | + |
| 134 | + if attn.norm_added_k is not None: |
| 135 | + txt_key = attn.norm_added_k(txt_key) |
| 136 | + |
| 137 | + if image_rotary_emb is not None: |
| 138 | + _, txt_freqs = image_rotary_emb |
| 139 | + txt_key = _apply_rotary_emb_nucleus(txt_key, txt_freqs, use_real=False) |
| 140 | + |
| 141 | + block_state.kv_cache[cache_key] = (txt_key, txt_value) |
| 142 | + |
| 143 | + txt_key, txt_value = block_state.kv_cache[cache_key] |
| 144 | + |
| 145 | + attn_kwargs = kwargs.get("attention_kwargs") or {} |
| 146 | + attn_kwargs["cached_txt_key"] = txt_key |
| 147 | + attn_kwargs["cached_txt_value"] = txt_value |
| 148 | + kwargs["attention_kwargs"] = attn_kwargs |
| 149 | + |
| 150 | + return self.fn_ref.original_forward(*args, **kwargs) |
| 151 | + |
| 152 | + def reset_state(self, module: torch.nn.Module): |
| 153 | + self.block_state_manager.reset() |
| 154 | + return module |
| 155 | + |
| 156 | + |
| 157 | +def apply_text_kv_cache(module: torch.nn.Module, config: TextKVCacheConfig) -> None: |
| 158 | + from ..models.transformers.transformer_nucleusmoe_image import NucleusMoEImageTransformerBlock |
| 159 | + |
| 160 | + HookRegistry.check_if_exists_or_initialize(module) |
| 161 | + |
| 162 | + state_manager = StateManager(TextKVCacheState) |
| 163 | + |
| 164 | + transformer_hook = TextKVCacheTransformerHook(state_manager) |
| 165 | + registry = HookRegistry.check_if_exists_or_initialize(module) |
| 166 | + registry.register_hook(transformer_hook, _TEXT_KV_CACHE_TRANSFORMER_HOOK) |
| 167 | + |
| 168 | + for _, submodule in module.named_modules(): |
| 169 | + if isinstance(submodule, NucleusMoEImageTransformerBlock): |
| 170 | + block_state_manager = StateManager(TextKVCacheBlockState) |
| 171 | + hook = TextKVCacheBlockHook(state_manager, block_state_manager) |
| 172 | + block_registry = HookRegistry.check_if_exists_or_initialize(submodule) |
| 173 | + block_registry.register_hook(hook, _TEXT_KV_CACHE_BLOCK_HOOK) |
0 commit comments