@@ -93,6 +93,7 @@ def execute_merge(model_names, calc_mode, all_modes, recipe_params, model_type):
9393 mismatch_mode_str = recipe_params .get ('mismatch_mode' , 'skip' )
9494 mismatch_mode = MissingTensorBehavior (mismatch_mode_str )
9595 recipe_params ['mismatch_mode' ] = mismatch_mode
96+ alignment_mode = recipe_params .get ('alignment_mode' , 'pad/crop' )
9697
9798 # Compile filter patterns
9899 exclude_patterns = _compile_patterns (recipe_params .get ('exclude_patterns' , '' ))
@@ -168,6 +169,15 @@ def execute_merge(model_names, calc_mode, all_modes, recipe_params, model_type):
168169 for r_key , r_tensor in result .items ():
169170 merged_state_dict [r_key ] = r_tensor .to (save_torch_dtype ).cpu ().clone ()
170171 else :
172+ # Ensure compatibility with Model A's architecture.
173+ # If alignment_mode is 'pad/crop', we crop results that were padded.
174+ # If alignment_mode is 'interpolate', resizing happened during operators.
175+ if alignment_mode == 'pad/crop' :
176+ target_shape = recipe_params ['_tensor_a_shape' ]
177+ if result .shape != target_shape :
178+ slices = tuple (slice (0 , min (res_s , tgt_s )) for res_s , tgt_s in zip (result .shape , target_shape ))
179+ result = result [slices ]
180+
171181 merged_state_dict [key ] = result .to (save_torch_dtype ).cpu ().clone ()
172182
173183 # Clean up tensor_a reference to allow GC
@@ -219,6 +229,7 @@ def define_schema(cls):
219229 io .Combo .Input ("model_b" , options = ["None" ] + folder_paths .get_filename_list ("checkpoints" )),
220230 io .Combo .Input ("calc_mode" , options = [m .name for m in TWO_MODEL_MODES ]),
221231 io .Combo .Input ("mismatch_mode" , options = ["skip" , "zeros" , "error" ], default = "skip" ),
232+ io .Combo .Input ("alignment_mode" , options = ["pad/crop" , "interpolate" ], default = "pad/crop" ),
222233 io .Float .Input ("alpha" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
223234 io .Float .Input ("beta" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
224235 io .Float .Input ("gamma" , default = 0.99 , min = 0.0 , max = 1.0 , step = 0.001 ),
@@ -247,7 +258,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str,
247258
248259 recipe_params = {
249260 "model_a" : model_a , "model_b" : model_b , "calc_mode" : calc_mode ,
250- "mismatch_mode" : mismatch_mode ,
261+ "mismatch_mode" : mismatch_mode , "alignment_mode" : recipe_params . get ( 'alignment_mode' , 'pad/crop' ),
251262 "alpha" : alpha , "beta" : beta , "gamma" : gamma , "delta" : delta , "seed" : seed ,
252263 "output_filename" : output_filename , "save_dtype" : save_dtype ,
253264 "device" : process_device , "dtype" : torch .float32 ,
@@ -273,6 +284,7 @@ def define_schema(cls):
273284 io .Combo .Input ("model_b" , options = ["None" ] + folder_paths .get_filename_list ("diffusion_models" )),
274285 io .Combo .Input ("calc_mode" , options = [m .name for m in TWO_MODEL_MODES ]),
275286 io .Combo .Input ("mismatch_mode" , options = ["skip" , "zeros" , "error" ], default = "skip" ),
287+ io .Combo .Input ("alignment_mode" , options = ["pad/crop" , "interpolate" ], default = "pad/crop" ),
276288 io .Float .Input ("alpha" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
277289 io .Float .Input ("beta" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
278290 io .Float .Input ("gamma" , default = 0.99 , min = 0.0 , max = 1.0 , step = 0.001 ),
@@ -301,7 +313,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str,
301313
302314 recipe_params = {
303315 "model_a" : model_a , "model_b" : model_b , "calc_mode" : calc_mode ,
304- "mismatch_mode" : mismatch_mode ,
316+ "mismatch_mode" : mismatch_mode , "alignment_mode" : recipe_params . get ( 'alignment_mode' , 'pad/crop' ),
305317 "alpha" : alpha , "beta" : beta , "gamma" : gamma , "delta" : delta , "seed" : seed ,
306318 "output_filename" : output_filename , "save_dtype" : save_dtype ,
307319 "device" : process_device , "dtype" : torch .float32 ,
@@ -327,6 +339,7 @@ def define_schema(cls):
327339 io .Combo .Input ("model_b" , options = ["None" ] + folder_paths .get_filename_list ("text_encoders" )),
328340 io .Combo .Input ("calc_mode" , options = [m .name for m in TWO_MODEL_MODES ]),
329341 io .Combo .Input ("mismatch_mode" , options = ["skip" , "zeros" , "error" ], default = "skip" ),
342+ io .Combo .Input ("alignment_mode" , options = ["pad/crop" , "interpolate" ], default = "pad/crop" ),
330343 io .Float .Input ("alpha" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
331344 io .Float .Input ("beta" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
332345 io .Float .Input ("gamma" , default = 0.99 , min = 0.0 , max = 1.0 , step = 0.001 ),
@@ -355,7 +368,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str,
355368
356369 recipe_params = {
357370 "model_a" : model_a , "model_b" : model_b , "calc_mode" : calc_mode ,
358- "mismatch_mode" : mismatch_mode ,
371+ "mismatch_mode" : mismatch_mode , "alignment_mode" : recipe_params . get ( 'alignment_mode' , 'pad/crop' ),
359372 "alpha" : alpha , "beta" : beta , "gamma" : gamma , "delta" : delta , "seed" : seed ,
360373 "output_filename" : output_filename , "save_dtype" : save_dtype ,
361374 "device" : process_device , "dtype" : torch .float32 ,
@@ -381,6 +394,7 @@ def define_schema(cls):
381394 io .Combo .Input ("model_b" , options = ["None" ] + folder_paths .get_filename_list ("loras" )),
382395 io .Combo .Input ("calc_mode" , options = [m .name for m in TWO_MODEL_MODES ]),
383396 io .Combo .Input ("mismatch_mode" , options = ["skip" , "zeros" , "error" ], default = "skip" ),
397+ io .Combo .Input ("alignment_mode" , options = ["pad/crop" , "interpolate" ], default = "pad/crop" ),
384398 io .Float .Input ("alpha" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
385399 io .Float .Input ("beta" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
386400 io .Float .Input ("gamma" , default = 0.99 , min = 0.0 , max = 1.0 , step = 0.001 ),
@@ -409,7 +423,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str,
409423
410424 recipe_params = {
411425 "model_a" : model_a , "model_b" : model_b , "calc_mode" : calc_mode ,
412- "mismatch_mode" : mismatch_mode ,
426+ "mismatch_mode" : mismatch_mode , "alignment_mode" : recipe_params . get ( 'alignment_mode' , 'pad/crop' ),
413427 "alpha" : alpha , "beta" : beta , "gamma" : gamma , "delta" : delta , "seed" : seed ,
414428 "output_filename" : output_filename , "save_dtype" : save_dtype ,
415429 "device" : process_device , "dtype" : torch .float32 ,
@@ -435,6 +449,7 @@ def define_schema(cls):
435449 io .Combo .Input ("model_b" , options = ["None" ] + folder_paths .get_filename_list ("embeddings" )),
436450 io .Combo .Input ("calc_mode" , options = [m .name for m in TWO_MODEL_MODES ]),
437451 io .Combo .Input ("mismatch_mode" , options = ["skip" , "zeros" , "error" ], default = "skip" ),
452+ io .Combo .Input ("alignment_mode" , options = ["pad/crop" , "interpolate" ], default = "pad/crop" ),
438453 io .Float .Input ("alpha" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
439454 io .Float .Input ("beta" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
440455 io .Float .Input ("gamma" , default = 0.99 , min = 0.0 , max = 1.0 , step = 0.001 ),
@@ -463,7 +478,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str,
463478
464479 recipe_params = {
465480 "model_a" : model_a , "model_b" : model_b , "calc_mode" : calc_mode ,
466- "mismatch_mode" : mismatch_mode ,
481+ "mismatch_mode" : mismatch_mode , "alignment_mode" : recipe_params . get ( 'alignment_mode' , 'pad/crop' ),
467482 "alpha" : alpha , "beta" : beta , "gamma" : gamma , "delta" : delta , "seed" : seed ,
468483 "output_filename" : output_filename , "save_dtype" : save_dtype ,
469484 "device" : process_device , "dtype" : torch .float32 ,
@@ -492,6 +507,7 @@ def define_schema(cls):
492507 io .Combo .Input ("model_c" , options = ["None" ] + folder_paths .get_filename_list ("checkpoints" )),
493508 io .Combo .Input ("calc_mode" , options = [m .name for m in THREE_MODEL_MODES ]),
494509 io .Combo .Input ("mismatch_mode" , options = ["skip" , "zeros" , "error" ], default = "skip" ),
510+ io .Combo .Input ("alignment_mode" , options = ["pad/crop" , "interpolate" ], default = "pad/crop" ),
495511 io .Float .Input ("alpha" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
496512 io .Float .Input ("beta" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
497513 io .Float .Input ("gamma" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
@@ -520,7 +536,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str, model_c: str,
520536
521537 recipe_params = {
522538 "model_a" : model_a , "model_b" : model_b , "model_c" : model_c , "calc_mode" : calc_mode ,
523- "mismatch_mode" : mismatch_mode ,
539+ "mismatch_mode" : mismatch_mode , "alignment_mode" : recipe_params . get ( 'alignment_mode' , 'pad/crop' ),
524540 "alpha" : alpha , "beta" : beta , "gamma" : gamma , "delta" : delta , "seed" : seed ,
525541 "output_filename" : output_filename , "save_dtype" : save_dtype ,
526542 "device" : process_device , "dtype" : torch .float32 ,
@@ -547,6 +563,7 @@ def define_schema(cls):
547563 io .Combo .Input ("model_c" , options = ["None" ] + folder_paths .get_filename_list ("diffusion_models" )),
548564 io .Combo .Input ("calc_mode" , options = [m .name for m in THREE_MODEL_MODES ]),
549565 io .Combo .Input ("mismatch_mode" , options = ["skip" , "zeros" , "error" ], default = "skip" ),
566+ io .Combo .Input ("alignment_mode" , options = ["pad/crop" , "interpolate" ], default = "pad/crop" ),
550567 io .Float .Input ("alpha" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
551568 io .Float .Input ("beta" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
552569 io .Float .Input ("gamma" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
@@ -575,7 +592,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str, model_c: str,
575592
576593 recipe_params = {
577594 "model_a" : model_a , "model_b" : model_b , "model_c" : model_c , "calc_mode" : calc_mode ,
578- "mismatch_mode" : mismatch_mode ,
595+ "mismatch_mode" : mismatch_mode , "alignment_mode" : recipe_params . get ( 'alignment_mode' , 'pad/crop' ),
579596 "alpha" : alpha , "beta" : beta , "gamma" : gamma , "delta" : delta , "seed" : seed ,
580597 "output_filename" : output_filename , "save_dtype" : save_dtype ,
581598 "device" : process_device , "dtype" : torch .float32 ,
@@ -602,6 +619,7 @@ def define_schema(cls):
602619 io .Combo .Input ("model_c" , options = ["None" ] + folder_paths .get_filename_list ("text_encoders" )),
603620 io .Combo .Input ("calc_mode" , options = [m .name for m in THREE_MODEL_MODES ]),
604621 io .Combo .Input ("mismatch_mode" , options = ["skip" , "zeros" , "error" ], default = "skip" ),
622+ io .Combo .Input ("alignment_mode" , options = ["pad/crop" , "interpolate" ], default = "pad/crop" ),
605623 io .Float .Input ("alpha" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
606624 io .Float .Input ("beta" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
607625 io .Float .Input ("gamma" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
@@ -630,7 +648,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str, model_c: str,
630648
631649 recipe_params = {
632650 "model_a" : model_a , "model_b" : model_b , "model_c" : model_c , "calc_mode" : calc_mode ,
633- "mismatch_mode" : mismatch_mode ,
651+ "mismatch_mode" : mismatch_mode , "alignment_mode" : recipe_params . get ( 'alignment_mode' , 'pad/crop' ),
634652 "alpha" : alpha , "beta" : beta , "gamma" : gamma , "delta" : delta , "seed" : seed ,
635653 "output_filename" : output_filename , "save_dtype" : save_dtype ,
636654 "device" : process_device , "dtype" : torch .float32 ,
@@ -657,6 +675,7 @@ def define_schema(cls):
657675 io .Combo .Input ("model_c" , options = ["None" ] + folder_paths .get_filename_list ("loras" )),
658676 io .Combo .Input ("calc_mode" , options = [m .name for m in THREE_MODEL_MODES ]),
659677 io .Combo .Input ("mismatch_mode" , options = ["skip" , "zeros" , "error" ], default = "skip" ),
678+ io .Combo .Input ("alignment_mode" , options = ["pad/crop" , "interpolate" ], default = "pad/crop" ),
660679 io .Float .Input ("alpha" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
661680 io .Float .Input ("beta" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
662681 io .Float .Input ("gamma" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
@@ -685,7 +704,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str, model_c: str,
685704
686705 recipe_params = {
687706 "model_a" : model_a , "model_b" : model_b , "model_c" : model_c , "calc_mode" : calc_mode ,
688- "mismatch_mode" : mismatch_mode ,
707+ "mismatch_mode" : mismatch_mode , "alignment_mode" : recipe_params . get ( 'alignment_mode' , 'pad/crop' ),
689708 "alpha" : alpha , "beta" : beta , "gamma" : gamma , "delta" : delta , "seed" : seed ,
690709 "output_filename" : output_filename , "save_dtype" : save_dtype ,
691710 "device" : process_device , "dtype" : torch .float32 ,
@@ -712,6 +731,7 @@ def define_schema(cls):
712731 io .Combo .Input ("model_c" , options = ["None" ] + folder_paths .get_filename_list ("embeddings" )),
713732 io .Combo .Input ("calc_mode" , options = [m .name for m in THREE_MODEL_MODES ]),
714733 io .Combo .Input ("mismatch_mode" , options = ["skip" , "zeros" , "error" ], default = "skip" ),
734+ io .Combo .Input ("alignment_mode" , options = ["pad/crop" , "interpolate" ], default = "pad/crop" ),
715735 io .Float .Input ("alpha" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
716736 io .Float .Input ("beta" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
717737 io .Float .Input ("gamma" , default = 0.5 , min = - 2.0 , max = 3.0 , step = 0.01 ),
@@ -740,7 +760,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str, model_c: str,
740760
741761 recipe_params = {
742762 "model_a" : model_a , "model_b" : model_b , "model_c" : model_c , "calc_mode" : calc_mode ,
743- "mismatch_mode" : mismatch_mode ,
763+ "mismatch_mode" : mismatch_mode , "alignment_mode" : recipe_params . get ( 'alignment_mode' , 'pad/crop' ),
744764 "alpha" : alpha , "beta" : beta , "gamma" : gamma , "delta" : delta , "seed" : seed ,
745765 "output_filename" : output_filename , "save_dtype" : save_dtype ,
746766 "device" : process_device , "dtype" : torch .float32 ,
0 commit comments