|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | | -from typing import Dict, Optional, Tuple, Union |
| 14 | +from typing import Optional, Tuple, Union |
15 | 15 |
|
16 | 16 | import torch |
17 | 17 | import torch.nn as nn |
|
21 | 21 | from ...loaders.single_file_model import FromOriginalModelMixin |
22 | 22 | from ...utils import deprecate |
23 | 23 | from ...utils.accelerate_utils import apply_forward_hook |
| 24 | +from ..attention import AttentionMixin |
24 | 25 | from ..attention_processor import ( |
25 | 26 | ADDED_KV_ATTENTION_PROCESSORS, |
26 | 27 | CROSS_ATTENTION_PROCESSORS, |
27 | 28 | Attention, |
28 | | - AttentionProcessor, |
29 | 29 | AttnAddedKVProcessor, |
30 | 30 | AttnProcessor, |
31 | 31 | FusedAttnProcessor2_0, |
|
35 | 35 | from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder |
36 | 36 |
|
37 | 37 |
|
38 | | -class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): |
| 38 | +class AutoencoderKL( |
| 39 | + ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin |
| 40 | +): |
39 | 41 | r""" |
40 | 42 | A VAE model with KL loss for encoding images into latents and decoding latent representations into images. |
41 | 43 |
|
@@ -138,66 +140,6 @@ def __init__( |
138 | 140 | self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) |
139 | 141 | self.tile_overlap_factor = 0.25 |
140 | 142 |
|
141 | | - @property |
142 | | - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors |
143 | | - def attn_processors(self) -> Dict[str, AttentionProcessor]: |
144 | | - r""" |
145 | | - Returns: |
146 | | - `dict` of attention processors: A dictionary containing all attention processors used in the model with |
147 | | - indexed by its weight name. |
148 | | - """ |
149 | | - # set recursively |
150 | | - processors = {} |
151 | | - |
152 | | - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): |
153 | | - if hasattr(module, "get_processor"): |
154 | | - processors[f"{name}.processor"] = module.get_processor() |
155 | | - |
156 | | - for sub_name, child in module.named_children(): |
157 | | - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) |
158 | | - |
159 | | - return processors |
160 | | - |
161 | | - for name, module in self.named_children(): |
162 | | - fn_recursive_add_processors(name, module, processors) |
163 | | - |
164 | | - return processors |
165 | | - |
166 | | - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor |
167 | | - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): |
168 | | - r""" |
169 | | - Sets the attention processor to use to compute attention. |
170 | | -
|
171 | | - Parameters: |
172 | | - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): |
173 | | - The instantiated processor class or a dictionary of processor classes that will be set as the processor |
174 | | - for **all** `Attention` layers. |
175 | | -
|
176 | | - If `processor` is a dict, the key needs to define the path to the corresponding cross attention |
177 | | - processor. This is strongly recommended when setting trainable attention processors. |
178 | | -
|
179 | | - """ |
180 | | - count = len(self.attn_processors.keys()) |
181 | | - |
182 | | - if isinstance(processor, dict) and len(processor) != count: |
183 | | - raise ValueError( |
184 | | - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" |
185 | | - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." |
186 | | - ) |
187 | | - |
188 | | - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): |
189 | | - if hasattr(module, "set_processor"): |
190 | | - if not isinstance(processor, dict): |
191 | | - module.set_processor(processor) |
192 | | - else: |
193 | | - module.set_processor(processor.pop(f"{name}.processor")) |
194 | | - |
195 | | - for sub_name, child in module.named_children(): |
196 | | - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) |
197 | | - |
198 | | - for name, module in self.named_children(): |
199 | | - fn_recursive_attn_processor(name, module, processor) |
200 | | - |
201 | 143 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor |
202 | 144 | def set_default_attn_processor(self): |
203 | 145 | """ |
|
0 commit comments