|
21 | 21 | import torch.nn.functional as F |
22 | 22 |
|
23 | 23 | from ...configuration_utils import ConfigMixin, register_to_config |
| 24 | +from ...image_processor import IPAdapterMaskProcessor |
24 | 25 | from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin |
25 | 26 | from ...utils import apply_lora_scale, logging |
26 | 27 | from ...utils.torch_utils import maybe_allow_in_graph |
@@ -244,28 +245,100 @@ def __call__( |
244 | 245 | # IP-adapter |
245 | 246 | ip_attn_output = torch.zeros_like(hidden_states) |
246 | 247 |
|
247 | | - for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( |
248 | | - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip |
| 248 | + if ip_adapter_masks is not None: |
| 249 | + if not isinstance(ip_adapter_masks, list): |
| 250 | + ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) |
| 251 | + if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)): |
| 252 | + raise ValueError( |
| 253 | + f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " |
| 254 | + f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " |
| 255 | + f"({len(ip_hidden_states)})" |
| 256 | + ) |
| 257 | + for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): |
| 258 | + if mask is None: |
| 259 | + continue |
| 260 | + if not isinstance(mask, torch.Tensor) or mask.ndim != 4: |
| 261 | + raise ValueError( |
| 262 | + "Each element of the ip_adapter_masks array should be a tensor with shape " |
| 263 | + "[1, num_images_for_ip_adapter, height, width]." |
| 264 | + " Please use `IPAdapterMaskProcessor` to preprocess your mask" |
| 265 | + ) |
| 266 | + num_ip_images = 1 if ip_state.ndim == 3 else ip_state.shape[1] |
| 267 | + if mask.shape[1] != num_ip_images: |
| 268 | + raise ValueError( |
| 269 | + f"Number of masks ({mask.shape[1]}) does not match " |
| 270 | + f"number of ip images ({num_ip_images}) at index {index}" |
| 271 | + ) |
| 272 | + if isinstance(scale, list) and not len(scale) == mask.shape[1]: |
| 273 | + raise ValueError( |
| 274 | + f"Number of masks ({mask.shape[1]}) does not match " |
| 275 | + f"number of scales ({len(scale)}) at index {index}" |
| 276 | + ) |
| 277 | + else: |
| 278 | + ip_adapter_masks = [None] * len(self.scale) |
| 279 | + |
| 280 | + for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( |
| 281 | + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks |
249 | 282 | ): |
250 | | - ip_key = to_k_ip(current_ip_hidden_states) |
251 | | - ip_value = to_v_ip(current_ip_hidden_states) |
252 | | - |
253 | | - ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim) |
254 | | - ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim) |
255 | | - |
256 | | - current_ip_hidden_states = dispatch_attention_fn( |
257 | | - ip_query, |
258 | | - ip_key, |
259 | | - ip_value, |
260 | | - attn_mask=None, |
261 | | - dropout_p=0.0, |
262 | | - is_causal=False, |
263 | | - backend=self._attention_backend, |
264 | | - parallel_config=self._parallel_config, |
265 | | - ) |
266 | | - current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim) |
267 | | - current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) |
268 | | - ip_attn_output += scale * current_ip_hidden_states |
| 283 | + if mask is not None: |
| 284 | + if current_ip_hidden_states.ndim == 3: |
| 285 | + current_ip_hidden_states = current_ip_hidden_states[:, None, :, :] |
| 286 | + if not isinstance(scale, list): |
| 287 | + scale = [scale] * mask.shape[1] |
| 288 | + |
| 289 | + for i in range(mask.shape[1]): |
| 290 | + ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) |
| 291 | + ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) |
| 292 | + |
| 293 | + ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim) |
| 294 | + ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim) |
| 295 | + |
| 296 | + _current_ip_hidden_states = dispatch_attention_fn( |
| 297 | + ip_query, |
| 298 | + ip_key, |
| 299 | + ip_value, |
| 300 | + attn_mask=None, |
| 301 | + dropout_p=0.0, |
| 302 | + is_causal=False, |
| 303 | + backend=self._attention_backend, |
| 304 | + parallel_config=self._parallel_config, |
| 305 | + ) |
| 306 | + _current_ip_hidden_states = _current_ip_hidden_states.reshape( |
| 307 | + batch_size, -1, attn.heads * attn.head_dim |
| 308 | + ) |
| 309 | + _current_ip_hidden_states = _current_ip_hidden_states.to(ip_query.dtype) |
| 310 | + |
| 311 | + mask_downsample = IPAdapterMaskProcessor.downsample( |
| 312 | + mask[:, i, :, :], |
| 313 | + batch_size, |
| 314 | + _current_ip_hidden_states.shape[1], |
| 315 | + _current_ip_hidden_states.shape[2], |
| 316 | + ) |
| 317 | + mask_downsample = mask_downsample.to(dtype=ip_query.dtype, device=ip_query.device) |
| 318 | + |
| 319 | + ip_attn_output += scale[i] * (_current_ip_hidden_states * mask_downsample) |
| 320 | + else: |
| 321 | + ip_key = to_k_ip(current_ip_hidden_states) |
| 322 | + ip_value = to_v_ip(current_ip_hidden_states) |
| 323 | + |
| 324 | + ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim) |
| 325 | + ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim) |
| 326 | + |
| 327 | + current_ip_hidden_states = dispatch_attention_fn( |
| 328 | + ip_query, |
| 329 | + ip_key, |
| 330 | + ip_value, |
| 331 | + attn_mask=None, |
| 332 | + dropout_p=0.0, |
| 333 | + is_causal=False, |
| 334 | + backend=self._attention_backend, |
| 335 | + parallel_config=self._parallel_config, |
| 336 | + ) |
| 337 | + current_ip_hidden_states = current_ip_hidden_states.reshape( |
| 338 | + batch_size, -1, attn.heads * attn.head_dim |
| 339 | + ) |
| 340 | + current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) |
| 341 | + ip_attn_output += scale * current_ip_hidden_states |
269 | 342 |
|
270 | 343 | return hidden_states, encoder_hidden_states, ip_attn_output |
271 | 344 | else: |
|
0 commit comments