Skip to content

Commit b615af1

Browse files
authored
Add support for small flux.2 decoder (#13314)
1 parent 40862c0 commit b615af1

2 files changed

Lines changed: 11 additions & 3 deletions

File tree

comfy/ldm/models/autoencoder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,15 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
155155
def __init__(self, embed_dim: int, **kwargs):
156156
self.max_batch_size = kwargs.pop("max_batch_size", None)
157157
ddconfig = kwargs.pop("ddconfig")
158+
decoder_ddconfig = kwargs.pop("decoder_ddconfig", ddconfig)
158159
super().__init__(
159160
encoder_config={
160161
"target": "comfy.ldm.modules.diffusionmodules.model.Encoder",
161162
"params": ddconfig,
162163
},
163164
decoder_config={
164165
"target": "comfy.ldm.modules.diffusionmodules.model.Decoder",
165-
"params": ddconfig,
166+
"params": decoder_ddconfig,
166167
},
167168
**kwargs,
168169
)

comfy/sd.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -556,12 +556,19 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
556556
old_memory_used_decode = self.memory_used_decode
557557
self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0
558558

559+
decoder_ch = sd['decoder.conv_in.weight'].shape[0] // ddconfig['ch_mult'][-1]
560+
if decoder_ch != ddconfig['ch']:
561+
decoder_ddconfig = ddconfig.copy()
562+
decoder_ddconfig['ch'] = decoder_ch
563+
else:
564+
decoder_ddconfig = None
565+
559566
if 'post_quant_conv.weight' in sd:
560-
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
567+
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1], **({"decoder_ddconfig": decoder_ddconfig} if decoder_ddconfig is not None else {}))
561568
else:
562569
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
563570
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
564-
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
571+
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': decoder_ddconfig if decoder_ddconfig is not None else ddconfig})
565572
elif "decoder.layers.1.layers.0.beta" in sd:
566573
config = {}
567574
param_key = None

0 commit comments

Comments
 (0)