[feat] Add model merging via mergekit (SentenceTransformer / SparseEncoder / CrossEncoder)#3776

Open
yjoonjang wants to merge 2 commits into huggingface:main from yjoonjang:feat/merge

Conversation

@yjoonjang
Contributor

Hello!

Pull Request overview

  • Add BaseModel.merge(...), a classmethod that combines several fine-tuned checkpoints into one model. The transformer body goes through mergekit; the surrounding ST modules (Pooling, Normalize, Dense, LayerNorm, ...) are handled in-tree so the result is a usable SentenceTransformer directory.
  • Addresses the merge feature request in [feature request] Pre-training, GPL, Model merging, LLM2Vec #2689 and follows the direction @tomaarsen suggested in arcee-ai/mergekit#286 — a mergekit + Sentence Transformers integration that preserves the Pooling and other modules.

Details

BaseModel.merge works for all three model classes (SentenceTransformer, SparseEncoder, CrossEncoder) since they all inherit from BaseModel. The transformer body is delegated to mergekit, so every method it supports (linear, slerp, task_arithmetic, ties, dare_ties, dare_linear, breadcrumbs, della, model_stock, sce, ...) is exposed for free. Around the body, we handle the ST module stack ourselves: validate that the inputs match, merge weight-bearing modules at the state-dict level, copy the stateless ones, and write the ST-specific top-level files mergekit doesn't know about (modules.json, config_sentence_transformers.json, sentence_bert_config.json). The result reloads cleanly with cls(output_path) and the per-class task head is preserved.

Example usage

from sentence_transformers import SentenceTransformer

SentenceTransformer.merge(
    models=[
        "sentence-transformers/all-MiniLM-L6-v2",
        "sentence-transformers/paraphrase-MiniLM-L6-v2",
    ],
    weights=[0.6, 0.4],
    method="linear",                # or "slerp", "ties", "task_arithmetic", ...
    base_model=None,                # required for delta-based methods
    output_path="merged-minilm/",
    dtype="float16",
    device="cpu",
)

Same signature for SparseEncoder.merge(...) and CrossEncoder.merge(...).

How it works

The merge runs in three phases:

| Phase | What happens |
| --- | --- |
| 1. Validate | All inputs must agree on modules.json: same classes, same order, same path. Per-module config.jsons are compared after canonicalizing through the Module class (load_config → constructor → get_config_dict), so older saves that predate newer pooling keys (pooling_mode_weightedmean_tokens, include_prompt, ...) compare equal to newer saves with the same effective config. If one target model has modules.json and others don't, the merge is rejected with a clear error; if all target models have (or all don't have) modules.json, the merge proceeds normally. Asym / Router and a save_in_root=False first module raise NotImplementedError up front; other non-standard module types (e.g. BoW, StaticEmbedding, CLIPModel) aren't explicitly rejected here but will fail later inside mergekit / ST loading. |
| 2. Body | Build a mergekit MergeConfiguration from the user's args and run_merge it to a temp directory; then move the outputs into output_path. |
| 3. Other modules | Weight-bearing modules (Dense, LayerNorm, WeightedLayerPooling) are merged state-dict-wise; stateless ones (Pooling, Normalize, Dropout, SpladePooling) are validated and copied from the first input. A subdirectory that's missing on disk (HF Hub doesn't track empty dirs, which is common for Normalize) or present-but-empty is treated as a stateless empty module, so we just mkdir the target and continue. Finally, any top-level sidecar from the first input that mergekit didn't write is copied over: the ST configs (modules.json, config_sentence_transformers.json, sentence_bert_config.json) and multimodal extras (preprocessor_config.json, processor_config.json, chat_template.json, ...). |
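
The structural part of phase 1 can be sketched in a few lines. This is a minimal illustration, not the code in merging.py: the function names check_module_lists and configs_match are hypothetical, and applying the None wildcard to every config key (rather than only to lazily initialized fields such as embedding_dimension) is an assumption made to keep the sketch short.

```python
def check_module_lists(module_lists):
    """Raise ValueError unless every input lists the same modules in the same order.

    `module_lists` holds one parsed modules.json (a list of dicts) per input model.
    """
    reference = [(m["type"], m["path"]) for m in module_lists[0]]
    for i, modules in enumerate(module_lists[1:], start=1):
        current = [(m["type"], m["path"]) for m in modules]
        if len(current) != len(reference):
            raise ValueError(
                f"input {i}: {len(current)} modules, expected {len(reference)}"
            )
        for pos, (ref, cur) in enumerate(zip(reference, current)):
            if ref != cur:
                raise ValueError(f"input {i}: module {pos} is {cur}, expected {ref}")


def configs_match(a, b):
    """Compare two canonicalized module configs.

    Assumption: a value of None means "not initialized yet" and matches anything.
    """
    if a.keys() != b.keys():
        return False
    return all(a[k] is None or b[k] is None or a[k] == b[k] for k in a)
```

Both helpers take already-parsed dictionaries, so the canonicalization step (load_config → constructor → get_config_dict) would run before configs_match is called.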

Module dispatch table

| Module class | Has weights? | Merge strategy |
| --- | --- | --- |
| Transformer / MLMTransformer (HF body) | yes | mergekit, any supported method |
| Pooling / Normalize / Dropout / SpladePooling | no | config-equality check + copy from first input |
| Dense / LayerNorm / WeightedLayerPooling | yes | state-dict: linear weighted average, or task_arithmetic deltas when base_model is given |
| Asym / Router | varies | rejected up front with a clear error; would need recursive merging of child modules |
| BoW / StaticEmbedding / CLIPModel and other non-standard module types | varies | not explicitly handled; merge will surface an error from mergekit / module loading |

Default weights

The natural "uniform" default depends on the method:

| Method family | weights=None defaults to | Why |
| --- | --- | --- |
| Blend (linear, slerp) | [1/n] * n | weights form a weighted average |
| Delta (task_arithmetic, ties, dare_*, breadcrumbs, della) | [1.0] * n | weights scale per-model deltas; [0.5, 0.5] would silently halve every delta |

The mergekit parameters.normalize flag is set per-method so user-supplied scaling carries through end-to-end for delta methods too.
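
As a rough sketch of the defaulting rule described above (the helper name default_weights and the two method sets are illustrative, and methods outside the two listed families are simply rejected here, which may differ from the real behavior):

```python
BLEND_METHODS = {"linear", "slerp"}
DELTA_METHODS = {
    "task_arithmetic", "ties", "dare_ties", "dare_linear", "breadcrumbs", "della",
}


def default_weights(method, n_models):
    """Return the uniform default weights for `method` when weights=None."""
    if n_models < 2:
        raise ValueError("merging needs at least two models")
    if method in BLEND_METHODS:
        # Blend methods average the inputs, so uniform weights sum to 1.
        return [1.0 / n_models] * n_models
    if method in DELTA_METHODS:
        # Delta methods scale each model's delta; 1.0 keeps every delta at full size.
        return [1.0] * n_models
    raise ValueError(f"no default weights defined for method {method!r} in this sketch")
```

For example, default_weights("linear", 4) gives four weights of 0.25, while default_weights("ties", 3) gives three weights of 1.0.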

Problems & Actions

ST-side modules merge linearly

For weight-bearing ST modules (Dense, LayerNorm, WeightedLayerPooling), only linear and task_arithmetic have state-dict implementations. task_arithmetic further requires the base_model to also have the matching module on disk so we can read its tensors as the base; otherwise we fall back to linear here too. For ties/dare_*/slerp the body uses that method but these ST-side tensors are merged with a plain linear weighted average and a WARNING is logged. Most ST checkpoints don't have a Dense head (the common stack is body + Pooling + Normalize), so the practical footprint is small, but worth flagging.
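
To make the two ST-side strategies concrete, here is a dependency-free sketch that uses plain lists of floats in place of tensors. The function names are hypothetical, and normalizing the linear weights by their sum is an assumption about what "weighted average" means here; the real implementation works on torch state dicts.

```python
def merge_linear(state_dicts, weights):
    """Weighted average of matching tensors; weights are normalized by their sum."""
    total = sum(weights)
    merged = {}
    for key in state_dicts[0]:
        # Each column holds the values at one position across all input models.
        columns = zip(*(sd[key] for sd in state_dicts))
        merged[key] = [
            sum(w * v for w, v in zip(weights, col)) / total for col in columns
        ]
    return merged


def merge_task_arithmetic(base, state_dicts, weights):
    """Elementwise base + sum_i w_i * (model_i - base)."""
    merged = {}
    for key, base_values in base.items():
        columns = zip(*(sd[key] for sd in state_dicts))
        merged[key] = [
            b + sum(w * (v - b) for w, v in zip(weights, col))
            for b, col in zip(base_values, columns)
        ]
    return merged
```

The task-arithmetic variant needs the base tensors for the same module, which is why a base_model without the matching module on disk forces the linear fallback.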

mergekit 0.1.4 bugs hit during integration

Two upstream bugs are worked around in _require_mergekit():

  1. Pydantic forward refs not resolved on ConfiguredModuleArchitecture / ConfiguredModelArchitecture. First call raises PydanticUserError: ... is not fully defined; ... call ConfiguredModuleArchitecture.model_rebuild(). Reported as arcee-ai/mergekit#681 — still open, no maintainer reply. The fix is to import torch and call model_rebuild() once before invoking mergekit.

  2. mergekit/plan.py:172 tuple/list concatenation TypeError. WeightInfo.aliases is Tuple[str, ...] (the pydantic model needs to stay hashable), but plan.py does [w_in.name] + (w_in.aliases or []) and crashes. Only fires for architectures whose JSON uses aliases — basically BERT variants (LayerNorm.beta/gamma legacy aliases). I couldn't find related reports for this anywhere. The workaround pulls plan_tensor's source via inspect.getsource, swaps + for * unpacking, and execs it back into mergekit. Sentinel-gated and logs a WARNING so it's visible.

Both workarounds can be deleted once upstream is fixed.
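
The second bug is easy to reproduce outside mergekit. The sketch below mirrors the expression quoted above with standalone functions (names_buggy and names_fixed are illustrative names, not mergekit functions, and the tensor names are only examples):

```python
def names_buggy(name, aliases):
    # Same shape as the expression in plan.py: list + tuple raises TypeError.
    return [name] + (aliases or [])


def names_fixed(name, aliases):
    # Unpacking works for tuples, lists, and None alike.
    return [name, *(aliases or ())]


aliases = ("LayerNorm.gamma",)  # a tuple, as on the hashable pydantic model

try:
    names_buggy("LayerNorm.weight", aliases)
    buggy_failed = False
except TypeError:
    buggy_failed = True

fixed = names_fixed("LayerNorm.weight", aliases)
```

With an empty or missing aliases value the buggy form still works, because `aliases or []` evaluates to a list; that matches the observation that only architectures with aliases trigger the crash.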

Reconciling mergekit's arch registry with the input checkpoints

mergekit ships a set of hand-crafted architecture JSON files (e.g. bert-masked-lm.json) that list which weights to merge. These can disagree with what's actually in the input checkpoints in two ways:

  1. Required weights that aren't always saved. bert-masked-lm.json requires bert.pooler.dense.{weight,bias}, but BertForMaskedLM checkpoints commonly don't save the pooler. Mergekit aborts with Tensor X required but not present.
  2. Weights present in the checkpoint but missing from the arch JSON. The same bert-masked-lm.json doesn't list cls.predictions.transform.*, so those tensors silently get dropped from the merged output.

Before calling run_merge, we scan the input safetensors keys and compare against the arch's listed weights:

  • If any non-layer-templated input key isn't covered by the arch, we delete the arch entry from NAME_TO_ARCH, which makes mergekit fall back to its auto-inference path. Auto-inference lists every input tensor and merges them all, which is the safer default whenever the hand-crafted arch is incomplete.
  • Otherwise (arch fully covers the inputs), we only mark required-but-absent weights as optional and keep the hand-crafted arch's tied-weights / aliases semantics.
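
The decision between the two branches can be sketched with set arithmetic. This is a simplified illustration under stated assumptions: the function name reconcile, the regular expression used to detect layer-templated keys, and the flat sets standing in for mergekit's architecture definition are all hypothetical.

```python
import re

# Assumption: layer-templated keys contain ".layer.<n>." or ".layers.<n>.".
LAYER_KEY = re.compile(r"\.layers?\.\d+\.")


def reconcile(input_keys, arch_required, arch_optional=frozenset()):
    """Decide how to treat a hand-crafted arch for the given checkpoint keys.

    Returns ("auto", set()) when the arch misses some input tensor, otherwise
    ("arch", weights_to_mark_optional).
    """
    listed = set(arch_required) | set(arch_optional)
    non_layer = {k for k in input_keys if not LAYER_KEY.search(k)}
    if non_layer - listed:
        # The arch would silently drop tensors: fall back to auto-inference.
        return "auto", set()
    # The arch covers the inputs: relax required weights the checkpoint lacks.
    return "arch", set(arch_required) - set(input_keys)
```

In the BERT case above, a checkpoint containing cls.predictions.transform.* would take the "auto" branch, while a checkpoint that only lacks the pooler would keep the arch and mark the pooler weights optional.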

Compatibility / refusal cases (tested)

| Refusal | Trigger |
| --- | --- |
| Module count mismatch | inputs differ in length of modules.json |
| Module class mismatch | inputs have different class at the same index |
| Module config mismatch | e.g. one model has pooling_mode="cls", another "mean" |
| Inputs disagree on modules.json presence | one input has it, another doesn't |
| Unsupported method | typo in method= |
| Missing base_model | delta method without base_model |
| Two-model method with 3+ models | wrong arity for slerp/nuslerp/nearswap |
| Two-model method with duplicate model paths | len(set(models)) < 2 |
| slerp with base_model not in models | ambiguous |
| Single model | len(models) < 2 |
| output_path=None | required argument |
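
The argument-level refusals in the table could be checked roughly as follows. This is a sketch only: validate_merge_args is a hypothetical name, the method sets cover just the methods named in this description, and raising ValueError for every case is an assumption about the exception type.

```python
DELTA_METHODS = {
    "task_arithmetic", "ties", "dare_ties", "dare_linear", "breadcrumbs", "della",
}
TWO_MODEL_METHODS = {"slerp", "nuslerp", "nearswap"}
SUPPORTED_METHODS = DELTA_METHODS | TWO_MODEL_METHODS | {"linear"}  # subset only


def validate_merge_args(models, method, base_model=None, output_path=None):
    """Raise ValueError for the argument combinations listed in the table."""
    if output_path is None:
        raise ValueError("output_path is required")
    if len(models) < 2:
        raise ValueError("need at least two models to merge")
    if method not in SUPPORTED_METHODS:
        raise ValueError(f"unsupported method {method!r}")
    if method in DELTA_METHODS and base_model is None:
        raise ValueError(f"{method} requires base_model")
    if method in TWO_MODEL_METHODS:
        if len(models) != 2:
            raise ValueError(f"{method} takes exactly two models")
        if len(set(models)) < 2:
            raise ValueError(f"{method} needs two distinct models")
    if method == "slerp" and base_model is not None and base_model not in models:
        raise ValueError("slerp base_model must be one of the input models")
```

The module-level refusals (count, class, and config mismatches) depend on files on disk and are therefore left out of this sketch.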

Validation

I ran a linear merge on a few embedding/reranker model pairs:

| Pair | Result | Notes |
| --- | --- | --- |
| intfloat/multilingual-e5-large + nlpai-lab/KoE5 | | Older e5 has fewer pooling-config keys than KoE5 (pooling_mode_weightedmean_tokens, pooling_mode_lasttoken, include_prompt). multilingual-e5-large is also missing the 2_Normalize/ directory entirely (empty dir not tracked on the Hub); this is handled by the empty-stateless-module logic. |
| BAAI/bge-m3 + nlpai-lab/KURE-v1 | | Same older/newer pooling asymmetry, same missing 2_Normalize/. |
| yjoonjang/splade-ko-v1 + telepix/PIXIE-Splade-v1.5 | | Both are full bert-base ModernBertForMaskedLM with pooler weights present. End-to-end SparseEncoder merge works. |
| naver/splade-cocondenser-ensembledistil + tomaarsen/splade-cocondenser-ensembledistil-sts | | Surfaced two issues: (a) one checkpoint has embedding_dimension=None in its SpladePooling config while the other has 30522. That field is lazy-initialized at first forward pass, so the canonical-config validator now treats None as a wildcard; (b) mergekit's bert-masked-lm.json doesn't list cls.predictions.transform.*, so without intervention those weights would silently drop from the merged output. This is fixed by falling back to auto-inference when the arch's listed weights don't cover the inputs. |
| BAAI/bge-reranker-v2-m3 + dragonkue/bge-reranker-v2-m3-ko | | Plain HF AutoModelForSequenceClassification, with no modules.json on either. Handled by treating modules.json-less inputs as body-only. |
| Qwen/Qwen3-VL-Embedding-2B + whybe-choi/Qwen3-VL-Embedding-2B-ko-vdr-preview-v0.2 | | Multimodal (vision-language) embedding, ~2B params. Surfaced that mergekit only copies the tokenizer; multimodal sidecars (preprocessor_config.json, processor_config.json, chat_template.{json,jinja}, ...) need to come from the first input separately, or the merged dir won't reload. The copy step now picks up every top-level sidecar from the first input that mergekit didn't already write, while explicitly skipping weight files so the merged shards aren't shadowed. |
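
The sidecar copy step from the last row can be sketched with the standard library alone. The function name copy_missing_sidecars and the exact list of weight-file patterns are assumptions for illustration; the real step may recognize more file types.

```python
import shutil
from pathlib import Path

# Assumption: these suffixes and index names identify weight files to skip.
WEIGHT_SUFFIXES = (".safetensors", ".bin", ".pt")
WEIGHT_INDEX_NAMES = ("model.safetensors.index.json", "pytorch_model.bin.index.json")


def copy_missing_sidecars(first_input, output_dir):
    """Copy top-level non-weight files that the merge did not already write."""
    src, dst = Path(first_input), Path(output_dir)
    copied = []
    for item in sorted(src.iterdir()):
        if not item.is_file():
            continue  # subdirectories (ST modules) are handled elsewhere
        if item.name.endswith(WEIGHT_SUFFIXES) or item.name in WEIGHT_INDEX_NAMES:
            continue  # never shadow the merged weight shards
        target = dst / item.name
        if target.exists():
            continue  # keep whatever the merge step produced
        shutil.copy2(item, target)
        copied.append(item.name)
    return copied
```

Files that already exist in the output directory are never overwritten, so the merged config and tokenizer take precedence over the first input's copies.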

Companion examples live under examples/sentence_transformer/training/other/model_merging.py, examples/cross_encoder/training/other/model_merging.py, and examples/sparse_encoder/training/other/model_merging.py.

Files changed

| File | Change |
| --- | --- |
| sentence_transformers/base/merging.py | new: core merge logic |
| sentence_transformers/base/model.py | added BaseModel.merge classmethod |
| pyproject.toml | added merge = ["mergekit>=0.0.5"] optional extra |
| tests/test_merging.py | new tests (per-method parametrize + refusal cases) |
| examples/sentence_transformer/training/other/model_merging.py | new example |
| examples/cross_encoder/training/other/model_merging.py | new example |
| examples/sparse_encoder/training/other/model_merging.py | new example |

TODO / open questions

  • File an upstream issue + 1-line PR for the plan.py:172 tuple/list bug in mergekit; both workarounds in this PR can be dropped once mergekit ships a fix.
  • Implement ties/dare_*/slerp for ST-side weight modules (Dense/LayerNorm/...) instead of always falling back to linear.
  • Worth supporting mergekit's slice-based "frankenmerge" YAML pass-through (SentenceTransformer.from_merge_config(yaml_path))?
  • Cross-reference arcee-ai/mergekit#286 once this lands.

Note: Claude Code helped with parts of the implementation and drafted the test cases; everything has been manually reviewed and verified end-to-end. There may still be things I missed, so feel free to add, remove, or update features.

Thank you!

  • Youngjoon Jang

@tomaarsen
Member

Hello!

Oh wow, I've tried some merging stuff myself, but never considered adding it as a method on the BaseModel, I think that's very clean. When I tried merging manually earlier, I could never get it to work very well, but I'm very curious to try it here with this.

I'll read through your PR in more detail in a bit.

  • Tom Aarsen

@tomaarsen added the enhancement (New feature or request) label on May 13, 2026
@whybe-choi
Contributor

Hello, @yjoonjang and @tomaarsen!

I think the lower bound for the merge extra should be derived from the mergekit interfaces and methods this PR exposes. I checked the currently published mergekit versions on PyPI: 0.0.5, 0.0.5.1, 0.0.6, and 0.1.4.

A few relevant differences:

  1. Up to mergekit==0.0.6, the architecture code is still exposed as mergekit/architecture.py. The newer package layout referenced in this PR only exists in 0.1.4, e.g. mergekit/architecture/base.py and mergekit/architecture/json_definitions.py.

A quick repro with mergekit==0.0.6:

from mergekit.architecture.base import ConfiguredModelArchitecture
from mergekit.architecture.json_definitions import NAME_TO_ARCH

raises:

ModuleNotFoundError: No module named 'mergekit.architecture.base'; 'mergekit.architecture' is not a package
ModuleNotFoundError: No module named 'mergekit.architecture.json_definitions'; 'mergekit.architecture' is not a package
  2. More importantly, this PR includes arcee_fusion and karcher in SUPPORTED_METHODS, but those methods are not implemented in mergekit==0.0.6.

The get() helper resolves methods from REGISTERED_MERGE_METHODS and raises RuntimeError if the method name is missing.

In 0.0.6, the registry includes methods such as LinearMerge, SlerpMerge, NuSlerpMerge, SCEMerge, and NearSwapMerge, but not ArceeFusionMerge or KarcherMerge.

In 0.1.4, both are imported and registered.

Repro with mergekit==0.0.6:

from mergekit.merge_methods import get

get("arcee_fusion")
get("karcher")

raises:

RuntimeError: Unimplemented merge method arcee_fusion
RuntimeError: Unimplemented merge method karcher

Both methods are available in mergekit==0.1.4.

So if we want to keep the current SUPPORTED_METHODS list as-is, I think the extra should be:

merge = ["mergekit>=0.1.4"]
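
For illustration, the effect of that bound on the four published versions can be checked with a small comparison helper. This is a sketch with hypothetical function names; it only handles purely numeric dotted versions, and real code would more likely rely on a proper version parser.

```python
def version_tuple(version):
    """Turn a dotted numeric version such as "0.0.5.1" into a tuple of ints."""
    return tuple(int(part) for part in version.split("."))


def meets_lower_bound(installed, minimum="0.1.4"):
    """Return True when `installed` satisfies the >= `minimum` constraint."""
    return version_tuple(installed) >= version_tuple(minimum)


published = ["0.0.5", "0.0.5.1", "0.0.6", "0.1.4"]
accepted = [v for v in published if meets_lower_bound(v)]
```

Of the four releases listed above, only 0.1.4 passes the check.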

Comment thread pyproject.toml Outdated
Co-authored-by: Yongbin Choi <whybe.choi@gmail.com>
@yjoonjang
Contributor Author

Good catch, thanks. The package layout and arcee_fusion/karcher registration only appeared in 0.1.4, so it definitely should be >=0.1.4.
