Skip to content

Commit 4097326

Browse files
committed
Add pad/crop for mismatched tensors and interpolation mode for weight redistribution across rank
1 parent 31eee60 commit 4097326

2 files changed

Lines changed: 147 additions & 39 deletions

File tree

nodes/merger.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def execute_merge(model_names, calc_mode, all_modes, recipe_params, model_type):
9393
mismatch_mode_str = recipe_params.get('mismatch_mode', 'skip')
9494
mismatch_mode = MissingTensorBehavior(mismatch_mode_str)
9595
recipe_params['mismatch_mode'] = mismatch_mode
96+
alignment_mode = recipe_params.get('alignment_mode', 'pad/crop')
9697

9798
# Compile filter patterns
9899
exclude_patterns = _compile_patterns(recipe_params.get('exclude_patterns', ''))
@@ -168,6 +169,15 @@ def execute_merge(model_names, calc_mode, all_modes, recipe_params, model_type):
168169
for r_key, r_tensor in result.items():
169170
merged_state_dict[r_key] = r_tensor.to(save_torch_dtype).cpu().clone()
170171
else:
172+
# Ensure compatibility with Model A's architecture.
173+
# If alignment_mode is 'pad/crop', we crop results that were padded.
174+
# If alignment_mode is 'interpolate', resizing happened during operators.
175+
if alignment_mode == 'pad/crop':
176+
target_shape = recipe_params['_tensor_a_shape']
177+
if result.shape != target_shape:
178+
slices = tuple(slice(0, min(res_s, tgt_s)) for res_s, tgt_s in zip(result.shape, target_shape))
179+
result = result[slices]
180+
171181
merged_state_dict[key] = result.to(save_torch_dtype).cpu().clone()
172182

173183
# Clean up tensor_a reference to allow GC
@@ -219,6 +229,7 @@ def define_schema(cls):
219229
io.Combo.Input("model_b", options=["None"] + folder_paths.get_filename_list("checkpoints")),
220230
io.Combo.Input("calc_mode", options=[m.name for m in TWO_MODEL_MODES]),
221231
io.Combo.Input("mismatch_mode", options=["skip", "zeros", "error"], default="skip"),
232+
io.Combo.Input("alignment_mode", options=["pad/crop", "interpolate"], default="pad/crop"),
222233
io.Float.Input("alpha", default=0.5, min=-2.0, max=3.0, step=0.01),
223234
io.Float.Input("beta", default=0.5, min=-2.0, max=3.0, step=0.01),
224235
io.Float.Input("gamma", default=0.99, min=0.0, max=1.0, step=0.001),
@@ -247,7 +258,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str,
247258

248259
recipe_params = {
249260
"model_a": model_a, "model_b": model_b, "calc_mode": calc_mode,
250-
"mismatch_mode": mismatch_mode,
261+
"mismatch_mode": mismatch_mode, "alignment_mode": recipe_params.get('alignment_mode', 'pad/crop'),
251262
"alpha": alpha, "beta": beta, "gamma": gamma, "delta": delta, "seed": seed,
252263
"output_filename": output_filename, "save_dtype": save_dtype,
253264
"device": process_device, "dtype": torch.float32,
@@ -273,6 +284,7 @@ def define_schema(cls):
273284
io.Combo.Input("model_b", options=["None"] + folder_paths.get_filename_list("diffusion_models")),
274285
io.Combo.Input("calc_mode", options=[m.name for m in TWO_MODEL_MODES]),
275286
io.Combo.Input("mismatch_mode", options=["skip", "zeros", "error"], default="skip"),
287+
io.Combo.Input("alignment_mode", options=["pad/crop", "interpolate"], default="pad/crop"),
276288
io.Float.Input("alpha", default=0.5, min=-2.0, max=3.0, step=0.01),
277289
io.Float.Input("beta", default=0.5, min=-2.0, max=3.0, step=0.01),
278290
io.Float.Input("gamma", default=0.99, min=0.0, max=1.0, step=0.001),
@@ -301,7 +313,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str,
301313

302314
recipe_params = {
303315
"model_a": model_a, "model_b": model_b, "calc_mode": calc_mode,
304-
"mismatch_mode": mismatch_mode,
316+
"mismatch_mode": mismatch_mode, "alignment_mode": recipe_params.get('alignment_mode', 'pad/crop'),
305317
"alpha": alpha, "beta": beta, "gamma": gamma, "delta": delta, "seed": seed,
306318
"output_filename": output_filename, "save_dtype": save_dtype,
307319
"device": process_device, "dtype": torch.float32,
@@ -327,6 +339,7 @@ def define_schema(cls):
327339
io.Combo.Input("model_b", options=["None"] + folder_paths.get_filename_list("text_encoders")),
328340
io.Combo.Input("calc_mode", options=[m.name for m in TWO_MODEL_MODES]),
329341
io.Combo.Input("mismatch_mode", options=["skip", "zeros", "error"], default="skip"),
342+
io.Combo.Input("alignment_mode", options=["pad/crop", "interpolate"], default="pad/crop"),
330343
io.Float.Input("alpha", default=0.5, min=-2.0, max=3.0, step=0.01),
331344
io.Float.Input("beta", default=0.5, min=-2.0, max=3.0, step=0.01),
332345
io.Float.Input("gamma", default=0.99, min=0.0, max=1.0, step=0.001),
@@ -355,7 +368,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str,
355368

356369
recipe_params = {
357370
"model_a": model_a, "model_b": model_b, "calc_mode": calc_mode,
358-
"mismatch_mode": mismatch_mode,
371+
"mismatch_mode": mismatch_mode, "alignment_mode": recipe_params.get('alignment_mode', 'pad/crop'),
359372
"alpha": alpha, "beta": beta, "gamma": gamma, "delta": delta, "seed": seed,
360373
"output_filename": output_filename, "save_dtype": save_dtype,
361374
"device": process_device, "dtype": torch.float32,
@@ -381,6 +394,7 @@ def define_schema(cls):
381394
io.Combo.Input("model_b", options=["None"] + folder_paths.get_filename_list("loras")),
382395
io.Combo.Input("calc_mode", options=[m.name for m in TWO_MODEL_MODES]),
383396
io.Combo.Input("mismatch_mode", options=["skip", "zeros", "error"], default="skip"),
397+
io.Combo.Input("alignment_mode", options=["pad/crop", "interpolate"], default="pad/crop"),
384398
io.Float.Input("alpha", default=0.5, min=-2.0, max=3.0, step=0.01),
385399
io.Float.Input("beta", default=0.5, min=-2.0, max=3.0, step=0.01),
386400
io.Float.Input("gamma", default=0.99, min=0.0, max=1.0, step=0.001),
@@ -409,7 +423,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str,
409423

410424
recipe_params = {
411425
"model_a": model_a, "model_b": model_b, "calc_mode": calc_mode,
412-
"mismatch_mode": mismatch_mode,
426+
"mismatch_mode": mismatch_mode, "alignment_mode": recipe_params.get('alignment_mode', 'pad/crop'),
413427
"alpha": alpha, "beta": beta, "gamma": gamma, "delta": delta, "seed": seed,
414428
"output_filename": output_filename, "save_dtype": save_dtype,
415429
"device": process_device, "dtype": torch.float32,
@@ -435,6 +449,7 @@ def define_schema(cls):
435449
io.Combo.Input("model_b", options=["None"] + folder_paths.get_filename_list("embeddings")),
436450
io.Combo.Input("calc_mode", options=[m.name for m in TWO_MODEL_MODES]),
437451
io.Combo.Input("mismatch_mode", options=["skip", "zeros", "error"], default="skip"),
452+
io.Combo.Input("alignment_mode", options=["pad/crop", "interpolate"], default="pad/crop"),
438453
io.Float.Input("alpha", default=0.5, min=-2.0, max=3.0, step=0.01),
439454
io.Float.Input("beta", default=0.5, min=-2.0, max=3.0, step=0.01),
440455
io.Float.Input("gamma", default=0.99, min=0.0, max=1.0, step=0.001),
@@ -463,7 +478,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str,
463478

464479
recipe_params = {
465480
"model_a": model_a, "model_b": model_b, "calc_mode": calc_mode,
466-
"mismatch_mode": mismatch_mode,
481+
"mismatch_mode": mismatch_mode, "alignment_mode": recipe_params.get('alignment_mode', 'pad/crop'),
467482
"alpha": alpha, "beta": beta, "gamma": gamma, "delta": delta, "seed": seed,
468483
"output_filename": output_filename, "save_dtype": save_dtype,
469484
"device": process_device, "dtype": torch.float32,
@@ -492,6 +507,7 @@ def define_schema(cls):
492507
io.Combo.Input("model_c", options=["None"] + folder_paths.get_filename_list("checkpoints")),
493508
io.Combo.Input("calc_mode", options=[m.name for m in THREE_MODEL_MODES]),
494509
io.Combo.Input("mismatch_mode", options=["skip", "zeros", "error"], default="skip"),
510+
io.Combo.Input("alignment_mode", options=["pad/crop", "interpolate"], default="pad/crop"),
495511
io.Float.Input("alpha", default=0.5, min=-2.0, max=3.0, step=0.01),
496512
io.Float.Input("beta", default=0.5, min=-2.0, max=3.0, step=0.01),
497513
io.Float.Input("gamma", default=0.5, min=-2.0, max=3.0, step=0.01),
@@ -520,7 +536,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str, model_c: str,
520536

521537
recipe_params = {
522538
"model_a": model_a, "model_b": model_b, "model_c": model_c, "calc_mode": calc_mode,
523-
"mismatch_mode": mismatch_mode,
539+
"mismatch_mode": mismatch_mode, "alignment_mode": recipe_params.get('alignment_mode', 'pad/crop'),
524540
"alpha": alpha, "beta": beta, "gamma": gamma, "delta": delta, "seed": seed,
525541
"output_filename": output_filename, "save_dtype": save_dtype,
526542
"device": process_device, "dtype": torch.float32,
@@ -547,6 +563,7 @@ def define_schema(cls):
547563
io.Combo.Input("model_c", options=["None"] + folder_paths.get_filename_list("diffusion_models")),
548564
io.Combo.Input("calc_mode", options=[m.name for m in THREE_MODEL_MODES]),
549565
io.Combo.Input("mismatch_mode", options=["skip", "zeros", "error"], default="skip"),
566+
io.Combo.Input("alignment_mode", options=["pad/crop", "interpolate"], default="pad/crop"),
550567
io.Float.Input("alpha", default=0.5, min=-2.0, max=3.0, step=0.01),
551568
io.Float.Input("beta", default=0.5, min=-2.0, max=3.0, step=0.01),
552569
io.Float.Input("gamma", default=0.5, min=-2.0, max=3.0, step=0.01),
@@ -575,7 +592,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str, model_c: str,
575592

576593
recipe_params = {
577594
"model_a": model_a, "model_b": model_b, "model_c": model_c, "calc_mode": calc_mode,
578-
"mismatch_mode": mismatch_mode,
595+
"mismatch_mode": mismatch_mode, "alignment_mode": recipe_params.get('alignment_mode', 'pad/crop'),
579596
"alpha": alpha, "beta": beta, "gamma": gamma, "delta": delta, "seed": seed,
580597
"output_filename": output_filename, "save_dtype": save_dtype,
581598
"device": process_device, "dtype": torch.float32,
@@ -602,6 +619,7 @@ def define_schema(cls):
602619
io.Combo.Input("model_c", options=["None"] + folder_paths.get_filename_list("text_encoders")),
603620
io.Combo.Input("calc_mode", options=[m.name for m in THREE_MODEL_MODES]),
604621
io.Combo.Input("mismatch_mode", options=["skip", "zeros", "error"], default="skip"),
622+
io.Combo.Input("alignment_mode", options=["pad/crop", "interpolate"], default="pad/crop"),
605623
io.Float.Input("alpha", default=0.5, min=-2.0, max=3.0, step=0.01),
606624
io.Float.Input("beta", default=0.5, min=-2.0, max=3.0, step=0.01),
607625
io.Float.Input("gamma", default=0.5, min=-2.0, max=3.0, step=0.01),
@@ -630,7 +648,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str, model_c: str,
630648

631649
recipe_params = {
632650
"model_a": model_a, "model_b": model_b, "model_c": model_c, "calc_mode": calc_mode,
633-
"mismatch_mode": mismatch_mode,
651+
"mismatch_mode": mismatch_mode, "alignment_mode": recipe_params.get('alignment_mode', 'pad/crop'),
634652
"alpha": alpha, "beta": beta, "gamma": gamma, "delta": delta, "seed": seed,
635653
"output_filename": output_filename, "save_dtype": save_dtype,
636654
"device": process_device, "dtype": torch.float32,
@@ -657,6 +675,7 @@ def define_schema(cls):
657675
io.Combo.Input("model_c", options=["None"] + folder_paths.get_filename_list("loras")),
658676
io.Combo.Input("calc_mode", options=[m.name for m in THREE_MODEL_MODES]),
659677
io.Combo.Input("mismatch_mode", options=["skip", "zeros", "error"], default="skip"),
678+
io.Combo.Input("alignment_mode", options=["pad/crop", "interpolate"], default="pad/crop"),
660679
io.Float.Input("alpha", default=0.5, min=-2.0, max=3.0, step=0.01),
661680
io.Float.Input("beta", default=0.5, min=-2.0, max=3.0, step=0.01),
662681
io.Float.Input("gamma", default=0.5, min=-2.0, max=3.0, step=0.01),
@@ -685,7 +704,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str, model_c: str,
685704

686705
recipe_params = {
687706
"model_a": model_a, "model_b": model_b, "model_c": model_c, "calc_mode": calc_mode,
688-
"mismatch_mode": mismatch_mode,
707+
"mismatch_mode": mismatch_mode, "alignment_mode": recipe_params.get('alignment_mode', 'pad/crop'),
689708
"alpha": alpha, "beta": beta, "gamma": gamma, "delta": delta, "seed": seed,
690709
"output_filename": output_filename, "save_dtype": save_dtype,
691710
"device": process_device, "dtype": torch.float32,
@@ -712,6 +731,7 @@ def define_schema(cls):
712731
io.Combo.Input("model_c", options=["None"] + folder_paths.get_filename_list("embeddings")),
713732
io.Combo.Input("calc_mode", options=[m.name for m in THREE_MODEL_MODES]),
714733
io.Combo.Input("mismatch_mode", options=["skip", "zeros", "error"], default="skip"),
734+
io.Combo.Input("alignment_mode", options=["pad/crop", "interpolate"], default="pad/crop"),
715735
io.Float.Input("alpha", default=0.5, min=-2.0, max=3.0, step=0.01),
716736
io.Float.Input("beta", default=0.5, min=-2.0, max=3.0, step=0.01),
717737
io.Float.Input("gamma", default=0.5, min=-2.0, max=3.0, step=0.01),
@@ -740,7 +760,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str, model_c: str,
740760

741761
recipe_params = {
742762
"model_a": model_a, "model_b": model_b, "model_c": model_c, "calc_mode": calc_mode,
743-
"mismatch_mode": mismatch_mode,
763+
"mismatch_mode": mismatch_mode, "alignment_mode": recipe_params.get('alignment_mode', 'pad/crop'),
744764
"alpha": alpha, "beta": beta, "gamma": gamma, "delta": delta, "seed": seed,
745765
"output_filename": output_filename, "save_dtype": save_dtype,
746766
"device": process_device, "dtype": torch.float32,

0 commit comments

Comments
 (0)