Skip to content

Commit b1ec55f

Browse files
committed
improve and correct merger operations based on papers
1 parent 8deb4fd commit b1ec55f

4 files changed

Lines changed: 213 additions & 22 deletions

File tree

docs/merger_2_model_modes.md

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,47 @@
2323
---
2424

2525
## Power-Up (DARE)
26-
> Adds the unique capabilities of Model B to Model A using the Drop and Rescale (DARE) technique, which often preserves the knowledge of the base model better than simple additions.
26+
> Adds the unique capabilities of Model B to Model A using the Drop and Rescale (DARE) technique. This implementation handles shape mismatches between models by padding and uses a randomized dropout mask.
2727
2828
**Models Used:** A, B
2929
**Parameters:**
30-
- **Alpha:** The dropout rate. This is the proportion of unique weights from Model B that are *dropped* before merging. A higher value means less of B is merged.
30+
- **Alpha:** The dropout rate ($p$). This is the proportion of delta parameters from Model B that are randomly set to zero.
3131
- **Beta:** A final multiplier for the rescaled difference before it's added to Model A.
32+
- **Rescaling Logic:** Remaining weights are automatically rescaled by $1/(1-p)$ as per the DARE paper to approximate the original embeddings.
33+
34+
---
35+
36+
## Enhanced Man Interp
37+
> Sophisticated interpolation between values from A and B depending on their difference relative to other values, with manual threshold control.
38+
39+
**Models Used:** A, B
40+
**Parameters:**
41+
- **Alpha:** Interpolation strength.
42+
- **Beta:** Lower mean threshold for filtering differences.
43+
- **Gamma:** Upper mean threshold for filtering differences.
44+
- **Delta:** Smoothness factor (mix between randomized mask and powered differences).
45+
46+
---
47+
48+
## Enhanced Auto Interp
49+
> Automated version of the enhanced interpolation mode that dynamically calculates thresholds based on mean differences.
50+
51+
**Models Used:** A, B
52+
**Parameters:**
53+
- **Alpha:** Interpolation strength.
54+
- **Beta:** Threshold adjustment factor.
55+
- **Gamma:** Smoothness factor.
56+
57+
---
58+
59+
## Weight-Sum Cutoff
60+
> A linear interpolation mode that only applies the merge to weights whose differences fall within a specific threshold range.
61+
62+
**Models Used:** A, B
63+
**Parameters:**
64+
- **Alpha:** Interpolation weight (multiplier for the difference).
65+
- **Beta:** Upper threshold for the difference cutoff.
66+
- **Gamma:** Lower threshold for the difference cutoff.
3267

3368
---
3469

@@ -84,4 +119,4 @@ Layers matching any pattern will be **removed entirely** from the output.
84119
| `text_model` | All text encoder layers |
85120
| `\.norm` | All normalization layers |
86121
| `attn\.(q\|k\|v)` | Query, key, value attention weights |
87-
| `block\.[0-9]\.` | Blocks 0-9 |
122+
| `block\.[0-9]\.` | Blocks 0-9 |

docs/merger_3_model_modes.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,19 @@
2222
---
2323

2424
## Extract-Features
25-
> A powerful mode that identifies features present in both `(B - A)` and `(C - A)` and adds them to A. Allows for fine-grained control over combining aspects based on their similarity.
25+
> A powerful mode that identifies features present in both `(B - A)` and `(C - A)` and adds them to A. It uses per-vector cosine similarity to decide how much of each feature to keep, allowing for fine-grained control over combining aspects.
2626
2727
**Models Used:** A, B, C
2828
**Parameters:**
2929
- **Alpha:** Weights the merge between Model B (`0.0`) and Model C (`1.0`).
3030
- **Beta:** Controls the focus on similarity (`0.0`) versus dissimilarity (`1.0`).
31-
- **Gamma:** A bias exponent for similarity. Higher values increase the bias.
31+
- **Gamma:** A bias exponent for similarity calculation.
3232
- **Delta:** A final multiplier for the extracted features before they are added to Model A.
3333

3434
---
3535

3636
## Add-Dissimilarities
37-
> Identifies features that are dissimilar between Model B and Model C and adds them to Model A. Useful for combining unique aspects of two different models.
37+
> Identifies features that are dissimilar between Model B and Model C (relative to A) and adds them to Model A. Useful for combining unique aspects of two different models.
3838
3939
**Models Used:** A, B, C
4040
**Parameters:**
@@ -80,4 +80,4 @@ Layers matching any pattern will be **removed entirely** from the output.
8080
**Pattern format:**
8181
- Whitespace-separated regex patterns
8282
- Patterns use **substring matching** (not full match)
83-
- Example: `text_model lora` matches any key containing "text_model" OR "lora"
83+
- Example: `text_model lora` matches any key containing "text_model" OR "lora"

nodes/merger.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ def define_schema(cls):
222222
io.Float.Input("alpha", default=0.5, min=-2.0, max=3.0, step=0.01),
223223
io.Float.Input("beta", default=0.5, min=-2.0, max=3.0, step=0.01),
224224
io.Float.Input("gamma", default=0.99, min=0.0, max=1.0, step=0.001),
225+
io.Float.Input("delta", default=0.5, min=-2.0, max=3.0, step=0.01),
225226
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
226227
io.String.Input("output_filename", default="merged_2_checkpoint"),
227228
io.Combo.Input("save_dtype", options=["fp32", "fp16", "bf16"]),
@@ -238,7 +239,7 @@ def define_schema(cls):
238239
@classmethod
239240
def execute(cls, execution_mode: str, model_a: str, model_b: str,
240241
calc_mode: str, mismatch_mode: str, alpha: float, beta: float,
241-
gamma: float, seed: int, output_filename: str, save_dtype: str,
242+
gamma: float, delta: float, seed: int, output_filename: str, save_dtype: str,
242243
process_device: str, exclude_patterns: str, discard_patterns: str) -> io.NodeOutput:
243244
doc = load_documentation_from_file('merger_2_model_modes.md')
244245
if execution_mode == "DOCUMENTATION ONLY":
@@ -247,7 +248,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str,
247248
recipe_params = {
248249
"model_a": model_a, "model_b": model_b, "calc_mode": calc_mode,
249250
"mismatch_mode": mismatch_mode,
250-
"alpha": alpha, "beta": beta, "gamma": gamma, "seed": seed,
251+
"alpha": alpha, "beta": beta, "gamma": gamma, "delta": delta, "seed": seed,
251252
"output_filename": output_filename, "save_dtype": save_dtype,
252253
"device": process_device, "dtype": torch.float32,
253254
"exclude_patterns": exclude_patterns, "discard_patterns": discard_patterns,
@@ -275,6 +276,7 @@ def define_schema(cls):
275276
io.Float.Input("alpha", default=0.5, min=-2.0, max=3.0, step=0.01),
276277
io.Float.Input("beta", default=0.5, min=-2.0, max=3.0, step=0.01),
277278
io.Float.Input("gamma", default=0.99, min=0.0, max=1.0, step=0.001),
279+
io.Float.Input("delta", default=0.5, min=-2.0, max=3.0, step=0.01),
278280
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
279281
io.String.Input("output_filename", default="merged_2_model"),
280282
io.Combo.Input("save_dtype", options=["fp32", "fp16", "bf16"]),
@@ -291,7 +293,7 @@ def define_schema(cls):
291293
@classmethod
292294
def execute(cls, execution_mode: str, model_a: str, model_b: str,
293295
calc_mode: str, mismatch_mode: str, alpha: float, beta: float,
294-
gamma: float, seed: int, output_filename: str, save_dtype: str,
296+
gamma: float, delta: float, seed: int, output_filename: str, save_dtype: str,
295297
process_device: str, exclude_patterns: str, discard_patterns: str) -> io.NodeOutput:
296298
doc = load_documentation_from_file('merger_2_model_modes.md')
297299
if execution_mode == "DOCUMENTATION ONLY":
@@ -300,7 +302,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str,
300302
recipe_params = {
301303
"model_a": model_a, "model_b": model_b, "calc_mode": calc_mode,
302304
"mismatch_mode": mismatch_mode,
303-
"alpha": alpha, "beta": beta, "gamma": gamma, "seed": seed,
305+
"alpha": alpha, "beta": beta, "gamma": gamma, "delta": delta, "seed": seed,
304306
"output_filename": output_filename, "save_dtype": save_dtype,
305307
"device": process_device, "dtype": torch.float32,
306308
"exclude_patterns": exclude_patterns, "discard_patterns": discard_patterns,
@@ -328,6 +330,7 @@ def define_schema(cls):
328330
io.Float.Input("alpha", default=0.5, min=-2.0, max=3.0, step=0.01),
329331
io.Float.Input("beta", default=0.5, min=-2.0, max=3.0, step=0.01),
330332
io.Float.Input("gamma", default=0.99, min=0.0, max=1.0, step=0.001),
333+
io.Float.Input("delta", default=0.5, min=-2.0, max=3.0, step=0.01),
331334
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
332335
io.String.Input("output_filename", default="merged_2_textencoder"),
333336
io.Combo.Input("save_dtype", options=["fp32", "fp16", "bf16"]),
@@ -344,7 +347,7 @@ def define_schema(cls):
344347
@classmethod
345348
def execute(cls, execution_mode: str, model_a: str, model_b: str,
346349
calc_mode: str, mismatch_mode: str, alpha: float, beta: float,
347-
gamma: float, seed: int, output_filename: str, save_dtype: str,
350+
gamma: float, delta: float, seed: int, output_filename: str, save_dtype: str,
348351
process_device: str, exclude_patterns: str, discard_patterns: str) -> io.NodeOutput:
349352
doc = load_documentation_from_file('merger_2_model_modes.md')
350353
if execution_mode == "DOCUMENTATION ONLY":
@@ -353,7 +356,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str,
353356
recipe_params = {
354357
"model_a": model_a, "model_b": model_b, "calc_mode": calc_mode,
355358
"mismatch_mode": mismatch_mode,
356-
"alpha": alpha, "beta": beta, "gamma": gamma, "seed": seed,
359+
"alpha": alpha, "beta": beta, "gamma": gamma, "delta": delta, "seed": seed,
357360
"output_filename": output_filename, "save_dtype": save_dtype,
358361
"device": process_device, "dtype": torch.float32,
359362
"exclude_patterns": exclude_patterns, "discard_patterns": discard_patterns,
@@ -381,6 +384,7 @@ def define_schema(cls):
381384
io.Float.Input("alpha", default=0.5, min=-2.0, max=3.0, step=0.01),
382385
io.Float.Input("beta", default=0.5, min=-2.0, max=3.0, step=0.01),
383386
io.Float.Input("gamma", default=0.99, min=0.0, max=1.0, step=0.001),
387+
io.Float.Input("delta", default=0.5, min=-2.0, max=3.0, step=0.01),
384388
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
385389
io.String.Input("output_filename", default="merged_2_lora"),
386390
io.Combo.Input("save_dtype", options=["fp32", "fp16", "bf16"]),
@@ -397,7 +401,7 @@ def define_schema(cls):
397401
@classmethod
398402
def execute(cls, execution_mode: str, model_a: str, model_b: str,
399403
calc_mode: str, mismatch_mode: str, alpha: float, beta: float,
400-
gamma: float, seed: int, output_filename: str, save_dtype: str,
404+
gamma: float, delta: float, seed: int, output_filename: str, save_dtype: str,
401405
process_device: str, exclude_patterns: str, discard_patterns: str) -> io.NodeOutput:
402406
doc = load_documentation_from_file('merger_2_model_modes.md')
403407
if execution_mode == "DOCUMENTATION ONLY":
@@ -406,7 +410,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str,
406410
recipe_params = {
407411
"model_a": model_a, "model_b": model_b, "calc_mode": calc_mode,
408412
"mismatch_mode": mismatch_mode,
409-
"alpha": alpha, "beta": beta, "gamma": gamma, "seed": seed,
413+
"alpha": alpha, "beta": beta, "gamma": gamma, "delta": delta, "seed": seed,
410414
"output_filename": output_filename, "save_dtype": save_dtype,
411415
"device": process_device, "dtype": torch.float32,
412416
"exclude_patterns": exclude_patterns, "discard_patterns": discard_patterns,
@@ -434,6 +438,7 @@ def define_schema(cls):
434438
io.Float.Input("alpha", default=0.5, min=-2.0, max=3.0, step=0.01),
435439
io.Float.Input("beta", default=0.5, min=-2.0, max=3.0, step=0.01),
436440
io.Float.Input("gamma", default=0.99, min=0.0, max=1.0, step=0.001),
441+
io.Float.Input("delta", default=0.5, min=-2.0, max=3.0, step=0.01),
437442
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
438443
io.String.Input("output_filename", default="merged_2_embedding"),
439444
io.Combo.Input("save_dtype", options=["fp32", "fp16", "bf16"]),
@@ -450,7 +455,7 @@ def define_schema(cls):
450455
@classmethod
451456
def execute(cls, execution_mode: str, model_a: str, model_b: str,
452457
calc_mode: str, mismatch_mode: str, alpha: float, beta: float,
453-
gamma: float, seed: int, output_filename: str, save_dtype: str,
458+
gamma: float, delta: float, seed: int, output_filename: str, save_dtype: str,
454459
process_device: str, exclude_patterns: str, discard_patterns: str) -> io.NodeOutput:
455460
doc = load_documentation_from_file('merger_2_model_modes.md')
456461
if execution_mode == "DOCUMENTATION ONLY":
@@ -459,7 +464,7 @@ def execute(cls, execution_mode: str, model_a: str, model_b: str,
459464
recipe_params = {
460465
"model_a": model_a, "model_b": model_b, "calc_mode": calc_mode,
461466
"mismatch_mode": mismatch_mode,
462-
"alpha": alpha, "beta": beta, "gamma": gamma, "seed": seed,
467+
"alpha": alpha, "beta": beta, "gamma": gamma, "delta": delta, "seed": seed,
463468
"output_filename": output_filename, "save_dtype": save_dtype,
464469
"device": process_device, "dtype": torch.float32,
465470
"exclude_patterns": exclude_patterns, "discard_patterns": discard_patterns,

0 commit comments

Comments
 (0)