Skip to content

Commit 22ac6fa

Browse files
[GLM-Image] Add batch support for GlmImagePipeline (#13007)
* init Signed-off-by: JaredforReal <w13431838023@gmail.com> * change from right padding to left padding Signed-off-by: JaredforReal <w13431838023@gmail.com> * try i2i batch Signed-off-by: JaredforReal <w13431838023@gmail.com> * fix: revert i2i prior_token_image_ids to original 1D tensor format * refactor KVCache for per prompt batching Signed-off-by: JaredforReal <w13431838023@gmail.com> * fix KVCache Signed-off-by: JaredforReal <w13431838023@gmail.com> * fix shape error Signed-off-by: JaredforReal <w13431838023@gmail.com> * refactor pipeline Signed-off-by: JaredforReal <w13431838023@gmail.com> * fix for left padding Signed-off-by: JaredforReal <w13431838023@gmail.com> * insert seed to AR model Signed-off-by: JaredforReal <w13431838023@gmail.com> * delete generator, use torch manual_seed Signed-off-by: JaredforReal <w13431838023@gmail.com> * add batch processing unit tests for GlmImagePipeline Signed-off-by: JaredforReal <w13431838023@gmail.com> * simplify normalize images method Signed-off-by: JaredforReal <w13431838023@gmail.com> * fix grids_per_sample Signed-off-by: JaredforReal <w13431838023@gmail.com> * fix t2i Signed-off-by: JaredforReal <w13431838023@gmail.com> * delete comments, simplify condition statement Signed-off-by: JaredforReal <w13431838023@gmail.com> * chage generate_prior_tokens outputs Signed-off-by: JaredforReal <w13431838023@gmail.com> * simplify if logic Signed-off-by: JaredforReal <w13431838023@gmail.com> * support user provided prior_token_ids directly Signed-off-by: JaredforReal <w13431838023@gmail.com> * remove blank lines Signed-off-by: JaredforReal <w13431838023@gmail.com> * align with transformers Signed-off-by: JaredforReal <w13431838023@gmail.com> * Apply style fixes --------- Signed-off-by: JaredforReal <w13431838023@gmail.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 71a865b commit 22ac6fa

File tree

3 files changed

+516
-150
lines changed

3 files changed

+516
-150
lines changed

src/diffusers/models/transformers/transformer_glm_image.py

Lines changed: 70 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -143,41 +143,86 @@ def forward(
143143

144144

145145
class GlmImageLayerKVCache:
    """Per-layer key/value cache for the GlmImage model.

    Supports per-sample caching for batch processing where each sample may
    have different condition images: every batch sample owns its own cache
    entry, selected by ``current_sample_idx`` during the write phase.
    """

    def __init__(self):
        # One cached tensor per batch sample; entries are appended lazily by
        # store(). Entries are never None, so the element type is plain Tensor.
        self.k_caches: List[torch.Tensor] = []
        self.v_caches: List[torch.Tensor] = []
        self.mode: Optional[str] = None  # "write", "read", "skip"
        self.current_sample_idx: int = 0  # Current sample index for writing

    def store(self, k: torch.Tensor, v: torch.Tensor):
        """Store KV cache for the current sample.

        Args:
            k: Key tensor for one sample. Assumed shape
                (1, seq_len, num_heads, head_dim) — the leading dim must be 1
                for the ``expand`` in :meth:`get` to work. TODO confirm with callers.
            v: Value tensor, same shape as ``k``.
        """
        idx = self.current_sample_idx
        if len(self.k_caches) <= idx:
            # First time storing for this sample: open a new cache entry.
            self.k_caches.append(k)
            self.v_caches.append(v)
        else:
            # Subsequent stores for this sample (multiple condition images):
            # extend along the sequence dimension.
            self.k_caches[idx] = torch.cat([self.k_caches[idx], k], dim=1)
            self.v_caches[idx] = torch.cat([self.v_caches[idx], v], dim=1)

    def get(self, k: torch.Tensor, v: torch.Tensor):
        """Get combined KV cache for all samples in the batch.

        Args:
            k: Current key tensor, shape (batch_size, seq_len, num_heads, head_dim)
            v: Current value tensor, shape (batch_size, seq_len, num_heads, head_dim)

        Returns:
            Combined key and value tensors with cached values prepended along
            the sequence dimension. If nothing is cached, ``k``/``v`` are
            returned unchanged.

        Raises:
            ValueError: If the batch size is neither 1 cached sample, equal to
                the number of cached samples, nor a multiple of it.
        """
        batch_size = k.shape[0]
        num_cached_samples = len(self.k_caches)
        if num_cached_samples == 0:
            # Nothing cached — pass through untouched.
            return k, v
        if num_cached_samples == 1:
            # Single cache, expand for all batch samples (shared condition images).
            k_cache_expanded = self.k_caches[0].expand(batch_size, -1, -1, -1)
            v_cache_expanded = self.v_caches[0].expand(batch_size, -1, -1, -1)
        elif num_cached_samples == batch_size:
            # Per-sample cache, concatenate along batch dimension.
            k_cache_expanded = torch.cat(self.k_caches, dim=0)
            v_cache_expanded = torch.cat(self.v_caches, dim=0)
        else:
            # Mismatch: handle cases like num_images_per_prompt > 1 by
            # repeating each sample's cache for a contiguous group of batch
            # entries. Validate divisibility before deriving the factor.
            if batch_size % num_cached_samples != 0:
                raise ValueError(
                    f"Cannot match {num_cached_samples} cached samples to batch size {batch_size}. "
                    f"Batch size must be a multiple of the number of cached samples."
                )
            repeat_factor = batch_size // num_cached_samples
            k_cache_expanded = torch.cat(
                [kc.expand(repeat_factor, -1, -1, -1) for kc in self.k_caches], dim=0
            )
            v_cache_expanded = torch.cat(
                [vc.expand(repeat_factor, -1, -1, -1) for vc in self.v_caches], dim=0
            )

        k_combined = torch.cat([k_cache_expanded, k], dim=1)
        v_combined = torch.cat([v_cache_expanded, v], dim=1)
        return k_combined, v_combined

    def clear(self):
        """Reset the cache to its initial empty state."""
        self.k_caches = []
        self.v_caches = []
        self.mode = None
        self.current_sample_idx = 0

    def next_sample(self):
        """Move to the next sample for writing."""
        self.current_sample_idx += 1

class GlmImageKVCache:
180-
"""Container for all layers' KV caches."""
223+
"""Container for all layers' KV caches.
224+
Supports per-sample caching for batch processing where each sample may have different condition images.
225+
"""
181226

182227
def __init__(self, num_layers: int):
183228
self.num_layers = num_layers
def set_mode(self, mode: Optional[str]):
    """Propagate a cache mode ("write", "read", "skip", or None) to every layer cache."""
    for layer_cache in self.caches:
        layer_cache.mode = mode

def next_sample(self):
    """Advance every layer cache to the next batch sample.

    Invoke once all condition images belonging to the current batch sample
    have been processed.
    """
    for layer_cache in self.caches:
        layer_cache.next_sample()
def clear(self):
    """Reset every layer cache by delegating to its own clear()."""
    for layer_cache in self.caches:
        layer_cache.clear()

0 commit comments

Comments
 (0)