[feat] Add model merging via mergekit (SentenceTransformer / SparseEncoder / CrossEncoder)#3776
yjoonjang wants to merge 2 commits
|
Hello! Oh wow, I've tried some merging stuff myself, but never considered adding it as a method on the BaseModel; I think that's very clean. When I tried merging manually earlier, I could never get it to work very well, but I'm very curious to try it with this. I'll read through your PR in more detail in a bit.
|
|
Hello, @yjoonjang and @tomaarsen! I think the lower bound for the A few relevant differences:
A quick repro with

    from mergekit.architecture.base import ConfiguredModelArchitecture
    from mergekit.architecture.json_definitions import NAME_TO_ARCH

raises:
The In In
Repro with

    from mergekit.merge_methods import get
    get("arcee_fusion")
    get("karcher")

raises: Both methods are available in So if we want to keep the current

    merge = ["mergekit>=0.1.4"]
|
Co-authored-by: Yongbin Choi <whybe.choi@gmail.com>
|
Good catch, thanks. The package layout and |
Hello!
Pull Request overview
- BaseModel.merge(...), a classmethod that combines several fine-tuned checkpoints into one model. The transformer body goes through mergekit; the surrounding ST modules (Pooling, Normalize, Dense, LayerNorm, ...) are handled in-tree so the result is a usable SentenceTransformer directory.
- mergekit + Sentence Transformers integration that preserves the Pooling and other modules.

Details
BaseModel.merge works for all three model classes (SentenceTransformer, SparseEncoder, CrossEncoder) since they all inherit from BaseModel. The transformer body is delegated to mergekit, so every method it supports (linear, slerp, task_arithmetic, ties, dare_ties, dare_linear, breadcrumbs, della, model_stock, sce, ...) is exposed for free. Around the body, we handle the ST module stack ourselves: validate that the inputs match, merge weight-bearing modules at the state-dict level, copy the stateless ones, and write the ST-specific top-level files mergekit doesn't know about (modules.json, config_sentence_transformers.json, sentence_bert_config.json). The result reloads cleanly with cls(output_path) and the per-class task head is preserved.

Example usage
Same signature for SparseEncoder.merge(...) and CrossEncoder.merge(...).

How it works
The merge runs in three phases:
1. Validate modules.json: same classes, same order, same path. Per-module config.json files are compared after canonicalizing through the Module class (load_config → constructor → get_config_dict), so older saves that predate newer pooling keys (pooling_mode_weightedmean_tokens, include_prompt, ...) compare equal to newer saves with the same effective config. If one target model has modules.json and others don't, the merge is rejected with a clear error; if all target models have (or all don't have) modules.json, the merge proceeds normally. Asym/Router and a save_in_root=False first module raise NotImplementedError up front; other non-standard module types (e.g. BoW, StaticEmbedding, CLIPModel) aren't explicitly rejected here but will fail later inside mergekit / ST loading.
2. Merge the body: build a MergeConfiguration from the user's args and run_merge it to a temp directory; then move the outputs into output_path.
3. Handle the ST modules: weight-bearing ones (Dense, LayerNorm, WeightedLayerPooling) are merged state-dict-wise; stateless ones (Pooling, Normalize, Dropout, SpladePooling) are validated and copied from the first input. A subdirectory that's missing on disk (HF Hub doesn't track empty dirs, which is common for Normalize) or present-but-empty is treated as a stateless empty module, so we just mkdir the target and continue. Finally, any top-level sidecar from the first input that mergekit didn't write is copied over: the ST configs (modules.json, config_sentence_transformers.json, sentence_bert_config.json) and multimodal extras (preprocessor_config.json, processor_config.json, chat_template.json, ...).

Module dispatch table
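- Transformer / MLMTransformer (HF body): merged by mergekit.
- Pooling / Normalize / Dropout / SpladePooling: validated and copied from the first input.
- Dense / LayerNorm / WeightedLayerPooling: merged state-dict-wise (linear, or task_arithmetic deltas when base_model is given).
- Asym / Router: NotImplementedError up front.
- BoW / StaticEmbedding / CLIPModel and other non-standard module types: not explicitly rejected; they fail later inside mergekit / ST loading.

Default weights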
The natural "uniform" default depends on the method:
- linear, slerp: weights=None defaults to [1/n] * n.
- task_arithmetic, ties, dare_*, breadcrumbs, della: weights=None defaults to [1.0] * n, because [0.5, 0.5] would silently halve every delta.

The mergekit parameters.normalize flag is set per-method so user-supplied scaling carries through end-to-end for delta methods too.

Problem & Actions
ST-side modules merge linearly
For weight-bearing ST modules (Dense, LayerNorm, WeightedLayerPooling), only linear and task_arithmetic have state-dict implementations. task_arithmetic further requires the base_model to also have the matching module on disk so we can read its tensors as the base; otherwise we fall back to linear here too. For ties / dare_* / slerp the body uses that method, but these ST-side tensors are merged with a plain linear weighted average and a WARNING is logged. Most ST checkpoints don't have a Dense head (the common stack is body + Pooling + Normalize), so the practical footprint is small, but it is worth flagging.

mergekit 0.1.4 bugs hit during integration
Two upstream bugs are worked around in _require_mergekit():

- Pydantic forward refs not resolved on ConfiguredModuleArchitecture / ConfiguredModelArchitecture. First call raises PydanticUserError: ... is not fully defined; ... call ConfiguredModuleArchitecture.model_rebuild(). Reported as arcee-ai/mergekit#681, still open with no maintainer reply. The fix is to import torch and call model_rebuild() once before invoking mergekit.
- mergekit/plan.py:172 tuple/list concatenation TypeError. WeightInfo.aliases is Tuple[str, ...] (the pydantic model needs to stay hashable), but plan.py does [w_in.name] + (w_in.aliases or []) and crashes. Only fires for architectures whose JSON uses aliases, which is basically BERT variants (LayerNorm.beta/gamma legacy aliases). I couldn't find related reports for this anywhere. The workaround pulls plan_tensor's source via inspect.getsource, swaps + for * unpacking, and execs it back into mergekit. Sentinel-gated and logs a WARNING so it's visible.

Both workarounds can be deleted once upstream is fixed.
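To make the second bug concrete, here is a minimal standalone sketch of the failure mode and the unpacking fix. It does not import mergekit; the helper names and the example weight name and alias are hypothetical, chosen only to mirror the expression quoted above.

```python
from typing import List, Optional, Tuple


def names_buggy(name: str, aliases: Optional[Tuple[str, ...]]) -> List[str]:
    # Same shape as the quoted expression: list + (tuple or list).
    # Works when aliases is None, raises TypeError when it is a tuple.
    return [name] + (aliases or [])


def names_fixed(name: str, aliases: Optional[Tuple[str, ...]]) -> List[str]:
    # Star-unpacking accepts any iterable, so tuple aliases are fine.
    return [name, *(aliases or ())]


# Hypothetical weight name and legacy alias, for illustration only.
name = "LayerNorm.weight"
aliases = ("LayerNorm.gamma",)

try:
    names_buggy(name, aliases)
    buggy_error = None
except TypeError as exc:
    buggy_error = exc

print(type(buggy_error).__name__)
print(names_fixed(name, aliases))
print(names_fixed(name, None))
```

The buggy form only fails when aliases is non-empty, which matches the observation that the crash shows up only for architectures whose JSON uses aliases.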
Reconciling mergekit's arch registry with the input checkpoints
mergekit ships a set of hand-crafted architecture JSON files (e.g. bert-masked-lm.json) that list which weights to merge. These can disagree with what's actually in the input checkpoints in two ways:

- bert-masked-lm.json requires bert.pooler.dense.{weight,bias}, but BertForMaskedLM checkpoints commonly don't save the pooler. Mergekit aborts with Tensor X required but not present.
- bert-masked-lm.json doesn't list cls.predictions.transform.*, so those tensors silently get dropped from the merged output.

Before calling run_merge, we scan the input safetensors keys and compare against the arch's listed weights. When the two disagree, the arch is dropped from NAME_TO_ARCH, which makes mergekit fall back to its auto-inference path. Auto-inference lists every input tensor and merges them all, which is the safer default whenever the hand-crafted arch is incomplete.

Compatibility / refusal cases (tested)
- modules.json
- pooling_mode="cls", another "mean"
- modules.json presence
- method= base_model base_model
- slerp / nuslerp / nearswap
- len(set(models)) < 2
- slerp with base_model not in models
- len(models) < 2
- output_path=None

Validation
I ran a linear merge on a few embedding/reranker model pairs:

- intfloat/multilingual-e5-large ⊕ nlpai-lab/KoE5: pooling_mode_weightedmean_tokens, pooling_mode_lasttoken, include_prompt. multilingual-e5-large is also missing the 2_Normalize/ directory entirely (empty dir not tracked on the Hub), which is handled by the empty-stateless-module logic.
- BAAI/bge-m3 ⊕ nlpai-lab/KURE-v1: 2_Normalize/.
- yjoonjang/splade-ko-v1 ⊕ telepix/PIXIE-Splade-v1.5: ModernBertForMaskedLM with pooler weights present. End-to-end SparseEncoder merge works.
- naver/splade-cocondenser-ensembledistil ⊕ tomaarsen/splade-cocondenser-ensembledistil-sts: (a) one model has embedding_dimension=None in its SpladePooling config while the other has 30522. That field is lazy-initialized at first forward pass, so the canonical-config validator now treats None as a wildcard; (b) mergekit's bert-masked-lm.json doesn't list cls.predictions.transform.*, so without intervention those weights would silently drop from the merged output. This is fixed by falling back to auto-inference when the arch's listed weights don't cover the inputs.
- BAAI/bge-reranker-v2-m3 ⊕ dragonkue/bge-reranker-v2-m3-ko: AutoModelForSequenceClassification, no modules.json on either. Handled by treating modules.json-less inputs as body-only.
- Qwen/Qwen3-VL-Embedding-2B ⊕ whybe-choi/Qwen3-VL-Embedding-2B-ko-vdr-preview-v0.2: the multimodal extras (preprocessor_config.json, processor_config.json, chat_template.{json,jinja}, ...) need to come from the first input separately, or the merged dir won't reload. The copy step now picks up every top-level sidecar from the first input that mergekit didn't already write, while explicitly skipping weight files so the merged shards aren't shadowed.

Companion examples live under examples/sentence_transformer/training/other/model_merging.py, examples/cross_encoder/training/other/model_merging.py, and examples/sparse_encoder/training/other/model_merging.py.

Files changed
- sentence_transformers/base/merging.py
- sentence_transformers/base/model.py: BaseModel.merge classmethod
- pyproject.toml: merge = ["mergekit>=0.0.5"] optional extra
- tests/test_merging.py
- examples/sentence_transformer/training/other/model_merging.py
- examples/cross_encoder/training/other/model_merging.py
- examples/sparse_encoder/training/other/model_merging.py

TODO / open questions
- plan.py:172 tuple/list bug in mergekit; both workarounds in this PR can be dropped once mergekit ships a fix.
- ties / dare_* / slerp for ST-side weight modules (Dense / LayerNorm / ...) instead of always falling back to linear.
- SentenceTransformer.from_merge_config(yaml_path))?

Note: Claude Code helped with parts of the implementation and drafted the test cases; everything's been manually reviewed and verified end-to-end. But there may be something I missed, so feel free to add/delete/update features.
Thank you!