|
49 | 49 | convert_stable_cascade_unet_single_file_to_diffusers, |
50 | 50 | convert_wan_transformer_to_diffusers, |
51 | 51 | convert_wan_vae_to_diffusers, |
| 52 | + convert_z_image_controlnet_checkpoint_to_diffusers, |
52 | 53 | convert_z_image_transformer_checkpoint_to_diffusers, |
53 | 54 | create_controlnet_diffusers_config_from_ldm, |
54 | 55 | create_unet_diffusers_config_from_ldm, |
|
174 | 175 | "default_subfolder": "transformer", |
175 | 176 | }, |
176 | 177 | "ZImageControlNetModel": { |
177 | | - "checkpoint_mapping_fn": lambda x: x, |
| 178 | + "checkpoint_mapping_fn": convert_z_image_controlnet_checkpoint_to_diffusers, |
178 | 179 | "config_create_fn": create_z_image_controlnet_config, |
179 | 180 | }, |
180 | 181 | } |
181 | 182 |
|
182 | 183 |
|
183 | 184 | def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict): |
184 | | - return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys())) |
| 185 | + model_state_dict_keys = set(model_state_dict.keys()) |
| 186 | + checkpoint_state_dict_keys = set(checkpoint_state_dict.keys()) |
| 187 | + is_subset = model_state_dict_keys.issubset(checkpoint_state_dict_keys) |
| 188 | + is_match = model_state_dict_keys == checkpoint_state_dict_keys |
| 189 | + return not (is_subset and is_match) |
185 | 190 |
|
186 | 191 |
|
187 | 192 | def _get_single_file_loadable_mapping_class(cls): |
|
0 commit comments