Adding support for LMK pooling#3642
Conversation
|
Hello! This is very cool work, and very intuitive too! My main initial concerns are regarding the new I'm also wary of e.g. always adding a CLS token 'manually', instead of relying on the special tokens from the tokenizer. Then it's very possible the upstream model was also trained with a EOS token, but that's not preserved here. But sadly it's not simple to update the entire I also ran a training script for experimentation: import random
import logging
from datasets import load_dataset, Dataset
from sentence_transformers import (
SentenceTransformer,
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
SentenceTransformerModelCardData,
)
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.models import LandmarkTransformer, Pooling, Normalize
logging.basicConfig(
format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
# 1. Load a model to finetune with 2. (Optional) model card data
landmark_transformer = LandmarkTransformer(
"microsoft/mpnet-base",
config_args={"splitter_type": "variable", "splitter_granularity": [32, 64, 128, 256]},
)
pooling = Pooling(landmark_transformer.get_word_embedding_dimension(), "lmk")
normalize = Normalize()
model = SentenceTransformer(
modules=[landmark_transformer, pooling, normalize],
model_card_data=SentenceTransformerModelCardData(
language="en",
license="apache-2.0",
model_name="MPNet base with Landmark Pooling trained on GooAQ triplets using CachedMultipleNegativesRankingLoss with GradCache",
),
)
# 3. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
dataset = dataset.add_column("id", range(len(dataset)))
dataset_dict = dataset.train_test_split(test_size=10_000, seed=12)
train_dataset: Dataset = dataset_dict["train"]
eval_dataset: Dataset = dataset_dict["test"]
# 4. Define a loss function
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=64)
# 5. (Optional) Specify training arguments
run_name = "mpnet-base-gooaq-cmnrl-1024bs-lmk"
args = SentenceTransformerTrainingArguments(
# Required parameter:
output_dir=f"models/{run_name}",
# Optional training parameters:
num_train_epochs=1,
per_device_train_batch_size=1024,
per_device_eval_batch_size=1024,
learning_rate=2e-5 * 4,
warmup_ratio=0.1,
fp16=False, # Set to False if you get an error that your GPU can't run on FP16
bf16=True, # Set to True if you have a GPU that supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # CachedMultipleNegativesRankingLoss benefits from no duplicate samples in a batch
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=0.1,
save_strategy="steps",
save_steps=0.1,
save_total_limit=2,
logging_steps=0.05,
logging_first_step=True,
run_name=run_name, # Will be used in W&B if `wandb` is installed
)
# 6. (Optional) Create an evaluator & evaluate the base model
# The full corpus, but only the evaluation queries
corpus = dict(zip(dataset["id"], dataset["answer"]))
random.seed(12)
queries = dict(zip(eval_dataset["id"], eval_dataset["question"]))
corpus = (
{qid: dataset[qid]["answer"] for qid in queries}
# {qid: dataset[qid]["answer"] for qid in queries} |
# {qid: dataset[qid]["answer"] for qid in random.sample(range(len(dataset)), 20_000)}
)
relevant_docs = {qid: {qid} for qid in eval_dataset["id"]}
dev_evaluator = InformationRetrievalEvaluator(
corpus=corpus,
queries=queries,
relevant_docs=relevant_docs,
show_progress_bar=True,
name="gooaq-dev",
)
dev_evaluator(model)
# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_dataset.remove_columns("id"),
eval_dataset=eval_dataset.remove_columns("id"),
loss=loss,
evaluator=dev_evaluator,
)
trainer.train()
# (Optional) Evaluate the trained model on the evaluator after training
dev_evaluator(model)
# 8. Save the trained model
model.save_pretrained(f"models/{run_name}/final")
# 9. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name)
I'll have to look more into this in the coming days!
|
|
Thank you for taking the time to try this out. One important detail that may have been missed in your earlier code is passing I will also update the code based on your suggestion. In addition, I have included an ablation comparing per-row variable LMK granularity with per-batch variable LMK granularity. Both approaches work, but per-row granularity performs better overall. Finally, I extended your code by switching to the MS MARCO training set and using ModernBERT to support evaluation with sequences longer than 8k tokens. This should allow you to verify the results on long-context benchmarks such as LongEmbed. Below are the results after fine-tuning on a single GPU:
import random
import logging
from datasets import load_dataset, Dataset
from sentence_transformers import (
SentenceTransformer,
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
SentenceTransformerModelCardData,
)
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss,MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.evaluation import TripletEvaluator
from sentence_transformers.models import Transformer, LandmarkTransformer, Pooling, Normalize
import mteb
import torch
import sys
logging.basicConfig(
format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
POOLING = sys.argv[1] #'cls' or 'mean' or 'lmk'
print(POOLING)
# Step 1/2 Load model and pooling mechanism
if POOLING=='lmk':
landmark_transformer = LandmarkTransformer(
"answerdotai/ModernBERT-base",
config_args={"splitter_type": "variable", "splitter_granularity": [32, 64, 128, 256]},
# config_args={"splitter_type": "fixed", "splitter_granularity": 32},
)
assert landmark_transformer.tokenizer.sep_token_id is not None
print(landmark_transformer.max_seq_length)
pooling = Pooling(landmark_transformer.get_word_embedding_dimension(), pooling_mode="lmk", lmk_token_id=landmark_transformer.lmk_token_id, include_prompt=False)
normalize = Normalize()
model = SentenceTransformer(
modules=[landmark_transformer, pooling, normalize],
)
else:
transformer = Transformer(
"answerdotai/ModernBERT-base",
)
print(transformer.max_seq_length)
if POOLING=='cls':
pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="cls", include_prompt=False)
elif POOLING=='mean':
pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean", include_prompt=False)
normalize = Normalize()
model = SentenceTransformer(
modules=[transformer, pooling, normalize],
)
# 3. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/msmarco-bm25", 'triplet', split="train").select(range(500_000))
dataset = dataset.add_column("id", range(len(dataset)))
dataset_dict = dataset.train_test_split(test_size=10_000, seed=12)
train_dataset: Dataset = dataset_dict["train"]
eval_dataset: Dataset = dataset_dict["test"]
# 4. Define a loss function
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=64)
# 5. (Optional) Specify training arguments
run_name = f"modernbert-base-msmarco-cmnrl-1024bs-{POOLING}"
args = SentenceTransformerTrainingArguments(
# Required parameter:
output_dir=f"models/{run_name}",
# Optional training parameters:
max_steps=1000, # approx. 2 epochs 500k/1024
per_device_train_batch_size=1024,
per_device_eval_batch_size=1024,
learning_rate=1e-5,
warmup_steps=100,
fp16=False, # Set to False if you get an error that your GPU can't run on FP16
bf16=True, # Set to True if you have a GPU that supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # CachedMultipleNegativesRankingLoss benefits from no duplicate samples in a batch
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=500,
save_strategy="steps",
save_steps=500,
save_total_limit=10,
logging_steps=5,
logging_first_step=True,
run_name=run_name, # Will be used in W&B if `wandb` is installed
)
# 6. (Optional) Create an evaluator & evaluate the base model
# The full corpus, but only the evaluation queries
corpus = dict(zip(dataset["id"], dataset["positive"]))
random.seed(12)
queries = dict(zip(eval_dataset["id"], eval_dataset["query"]))
corpus = (
{qid: dataset[qid]["positive"] for qid in queries}
# {qid: dataset[qid]["answer"] for qid in queries} |
# {qid: dataset[qid]["answer"] for qid in random.sample(range(len(dataset)), 20_000)}
)
relevant_docs = {qid: {qid} for qid in eval_dataset["id"]}
dev_evaluator = InformationRetrievalEvaluator(
corpus=corpus,
queries=queries,
relevant_docs=relevant_docs,
show_progress_bar=True,
name="msmarco_custom_dev",
)
dev_evaluator(model)
# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_dataset.remove_columns("id"),
eval_dataset=eval_dataset.remove_columns("id"),
loss=loss,
evaluator=dev_evaluator,
)
trainer.train()
# ============================
# Post-training evaluations
# ============================
# Change evaluation granularity to fixed for evaluation to remove any randomness
if POOLING=='lmk':
for module in model.modules():
if isinstance(module, LandmarkTransformer):
module.splitter_type = "fixed"
module.splitter_granularity = 32
logging.info(
"Switched LandmarkTransformer to fixed splitter with granularity=32 for evaluation"
)
# MSMarco custom dev (already defined)
dev_evaluator(model)
# Long Embed evaluation
# increase MSL to 32k
model.max_seq_length = 32768
TASK_NAMES=['LEMBNeedleRetrieval', 'LEMBPasskeyRetrieval','LEMBQMSumRetrieval','LEMBSummScreenFDRetrieval','LEMBWikimQARetrieval','LEMBNarrativeQARetrieval']
for TASK_NAME in TASK_NAMES:
task = mteb.get_task(TASK_NAME)
results = mteb.evaluate(
model=model,
tasks=[task],
overwrite_strategy='always',
cache=None,
encode_kwargs={"batch_size": 2}
)
print(f'-'*10)
print(f'{TASK_NAME} Results:')
scores = results.task_results[0].scores
total = 0
count = 0
for k in scores:
total += scores[k][0]['main_score']
count += 1
# average over test splits
print(total/count)
# 8. Save the trained model
model.save_pretrained(f"models/{run_name}/final") |
|
Thanks for preparing and running that script. The per-row granularity is indeed clearly better, we should stick with that. Furthermore, I noticed that Another downside that will still exist is that evaluation during training will be done with the same settings as training. A potentially interesting option is to instead let users specify I think I'm okay with some extra complexity (specifically: having to pass the LMK token to the Pooling) for the training user: we can expect a bit more from them, and we'll have to have some documentation akin to e.g. https://sbert.net/examples/sentence_transformer/training/matryoshka/README.html or https://sbert.net/examples/sentence_transformer/training/prompts/README.html that goes over the differences versus "regular training" either way. Alternatively, we can add some post-initialization checks:
I also think we should have the
|
|
Hey Tom — you raised a couple of good points that are worth discussing.
|
|
I made some local refactors, and reran (roughly) the script from #3642 (comment), resulting in: https://huggingface.co/tomaarsen/mpnet-base-gooaq-cmnrl-1024bs-lmk-v2#training-logs (0.8537 NDCG@10) Previously, I trained:
If you notice the before-training results, you'll see that the previous 2 models both score about Granted, this might also be due to a bug on my implementation's side. Edit: I now updated my implementation to not hardcode the I still can't escape from the performance of ~0.0050 NDCG@10 on the MIRIAD task, though.
|
|
Yes, the pooling initialization requires a valid Regarding the initial evaluation performance: If you mean before training, low scores are expected unless the model has been trained for retrieval. The behavior largely depends on how the base model was pretrained and how much signal the special tokens received during pretraining. For example, with My intuition is that this happens because BERT’s SEP token was trained with the NSP objective and therefore carries more importance. As a result, for bert, SEP based LMK pooling shows better performance than CLS before retrieval fine-tuning. In contrast, with Because, for ModernBERT, special tokens were not trained with any special objectives. Mean pooling performs significantly better because it aggregates information from all tokens. For your case with Switching the LMK token from After training, the differences largely disappear: Once trained for retrieval, LMK pooling works as expected. Overall, I would not worry about initial performance differences, as they are most likely driven by the base model’s pretraining objectives and how special tokens were treated, rather than by the pooling mechanism itself. |
…ith insertions Also simplify ST's validation and add training scripts
|
I've pushed my changes as described in my previous comment. We can always revert them if we don't want to go in that direction, and I'm open to your feedback. The new implementation uses I understand that the mix of adding new tokens that the model isn't (always) familiar with + only pooling on those tokens means that it's much trickier for the model to get started with training, but I'm still struggling with the setup where the model doesn't train at all. It's possible that for some models, the SEP token was never used in training at all, and so it still has random embedding weights. Then, the model will perform very poorly, and sometimes never recover. The Ettin + MIRIAD training script that I pushed seems to have this happen. Maybe then we need to (re-)initialize the embedding weights of the LMK token? It's a bit risky either way, because this will definitely sometimes hurt performance. Alternatively, maybe it can be helped with a lower learning rate or something. I'll try Ettin + GooAQ pairs to see if the issue is with Ettin or MIRIAD (or the combination perhaps). Definitely let me know your thoughts! Edit: Looks like Ettin doesn't work well compared to mpnet-base. If I use it with the GooAQ script, the loss stays high and the evaluation performance stays at 0, also with a lowered learning rate.
|
|
This is a bit strange. I am also seeing that training gets stuck when using the SEP token (50282) on Ettin with the GooAQ data. The loss consistently collapses to 6.93, which is exactly To debug this for Ettin, I tried changing the LMK token to a more commonly used token like I am not sure why this happens specifically with Ettin’s SEP token. The same SEP token setup works without issues for BERT, RoBERTa, ModernBERT-base, and GTE-en-mlm-base. I need to look into this more. If this turns out to be specific to certain models, we already allow users to specify the LMK token ID manually. So as a workaround, if the loss collapses, they can switch to a different token or initialize a fresh embedding for the LMK token. Edit: I increased the number of epochs to 5 for Ettin with GooAQ using the SEP token as the landmark token, and the performance starts to improve again. However, representation collapse should not have happened in the first place. I suspect this is very model-specific, since reducing the batch size to 256 with 1 epoch also resolved the issue. |
Perhaps Ettin its SEP token is completely untrained? I checked out its std/mean/max and it seemed in line with the rest of the embeddings at a glance. I indeed think it's likely model-specific. I'll try a bit more. Thank you for testing this out with me!
|
|
Hi Tom, Apologies for the delayed response. To confirm, I ran experiments across a set of training models and datasets comparing CLS and LMK pooling. I am attaching the script and results below. It appears to be an Ettin-specific issue and occurs across all Ettin models except the 32M model (where it degrades performance). I also tested replacing the LMK token with a token other than SEP for Ettin, and it works correctly in that case. However, it would be useful to mention this in the documentation, suggesting that the special token be replaced if representation collapse occurs. Please let me know if you have any suggestions. Results (NDCG@10):
Script: import random
import logging
import sys
from datasets import load_dataset, Dataset
from sentence_transformers import (
SentenceTransformer,
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
SentenceTransformerModelCardData,
)
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.models import Transformer, LandmarkTransformer, Pooling, Normalize
logging.basicConfig(
format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
# ── CLI args ──────────────────────────────────────────────────────────────────
if len(sys.argv) != 4:
print("Usage: python train_lmk.py <pooling> <base_model> <dataset>")
print(" pooling : cls | mean | lmk")
print(" base_model: e.g. jhu-clsp/ettin-encoder-150m")
print(" dataset : gooaq | msmarco | nq | eli5 | trivia | hotpotqa | miracl | mldr")
sys.exit(1)
POOLING = sys.argv[1] # cls | mean | lmk
BASE_MODEL = sys.argv[2] # HF model id
DATASET = sys.argv[3] # short name
# ── Dataset registry ──────────────────────────────────────────────────────────
# Each entry: (hf_path, subset, split, query_col, doc_col, n_train, n_eval)
DATASET_CONFIG = {
"gooaq": ("sentence-transformers/gooaq", None, "train", "question", "answer", 90_000, 10_000),
"msmarco": ("sentence-transformers/msmarco-bm25", "triplet", "train", "query", "positive", 90_000, 10_000),
"nq": ("sentence-transformers/natural-questions", None, "train", "query", "answer", 90_000, 10_000),
"eli5": ("sentence-transformers/eli5", None, "train", "question", "answer", 90_000, 10_000),
"trivia": ("sentence-transformers/trivia-qa", None, "train", "query", "answer", 90_000, 10_000),
"hotpotqa": ("sentence-transformers/hotpotqa", "triplet", "train", "anchor", "positive", 90_000, 10_000),
"miracl": ("sentence-transformers/miracl", "en-triplet", "train", "anchor", "positive", 90_000, 10_000),
"mldr": ("sentence-transformers/mldr", "en-triplet", "train", "anchor", "positive", 90_000, 10_000),
}
if DATASET not in DATASET_CONFIG:
raise ValueError(f"Unknown dataset '{DATASET}'. Choose from: {list(DATASET_CONFIG)}")
hf_path, subset, split, query_col, doc_col, n_train, n_eval = DATASET_CONFIG[DATASET]
# ── Load & prepare dataset (streaming) ───────────────────────────────────────
logging.info(f"Loading dataset (streaming): {hf_path} (subset={subset})")
raw = (
load_dataset(hf_path, subset, split=split, streaming=True)
if subset
else load_dataset(hf_path, split=split, streaming=True)
)
# Shuffle with a fixed seed before splitting
raw = raw.shuffle(seed=12, buffer_size=10_000)
# Rename columns to canonical names so the trainer sees "anchor"/"positive"
if query_col != "anchor":
raw = raw.rename_column(query_col, "anchor")
if doc_col != "positive":
raw = raw.rename_column(doc_col, "positive")
raw = raw.select_columns(["anchor", "positive"])
# Split: first n_eval rows → eval, next n_train rows → train
eval_iterable = raw.take(n_eval)
train_dataset = raw.skip(n_eval).take(n_train) # IterableDataset for trainer
logging.info(f"Materialising eval set ({n_eval} rows) …")
eval_rows = list(eval_iterable)
eval_corpus = {i: row["positive"] for i, row in enumerate(eval_rows)}
eval_queries = {i: row["anchor"] for i, row in enumerate(eval_rows)}
eval_dataset = Dataset.from_list(eval_rows)
# ── Build model ───────────────────────────────────────────────────────────────
MAX_SEQ_LENGTH = 512
logging.info(f"Building model: {BASE_MODEL} pooling={POOLING}")
if POOLING == "lmk":
backbone = LandmarkTransformer(
BASE_MODEL,
max_seq_length=MAX_SEQ_LENGTH,
eval_splitter_granularity=32,
train_splitter_granularity=[32, 64, 128, 256],
model_args={"trust_remote_code": True},
config_args={"trust_remote_code": True},
)
pooling = Pooling(
backbone.get_word_embedding_dimension(),
pooling_mode="lmk",
lmk_token_id=backbone.lmk_token_id,
include_prompt=False,
)
normalize = Normalize()
model = SentenceTransformer(modules=[backbone, pooling, normalize])
else:
backbone = Transformer(
BASE_MODEL,
max_seq_length=MAX_SEQ_LENGTH,
model_args={"trust_remote_code": True},
config_args={"trust_remote_code": True},
)
pooling = Pooling(
backbone.get_word_embedding_dimension(),
pooling_mode=POOLING, # "cls" or "mean"
include_prompt=False,
)
normalize = Normalize()
model = SentenceTransformer(modules=[backbone, pooling, normalize])
# ── Loss ──────────────────────────────────────────────────────────────────────
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=64)
# ── Training args ─────────────────────────────────────────────────────────────
model_short = BASE_MODEL.split("/")[-1]
run_name = f"{model_short}-{DATASET}-cmnrl-{POOLING}"
NUM_EPOCHS = 5
BATCH_SIZE = 1024
steps_per_epoch = n_train // BATCH_SIZE
max_steps = steps_per_epoch * NUM_EPOCHS
args = SentenceTransformerTrainingArguments(
output_dir=f"lmk_stress_test/{run_name}",
max_steps=max_steps,
per_device_train_batch_size=BATCH_SIZE,
per_device_eval_batch_size=BATCH_SIZE,
learning_rate=1e-5,
warmup_ratio=0.1,
fp16=False,
bf16=True,
batch_sampler=BatchSamplers.NO_DUPLICATES,
eval_strategy="steps",
eval_steps=steps_per_epoch // 5, # ~5 evals per epoch
save_strategy="steps",
save_steps=steps_per_epoch // 5,
save_total_limit=1,
logging_steps=10,
logging_first_step=True,
run_name=run_name,
)
# ── Evaluator ─────────────────────────────────────────────────────────────────
relevant_docs = {qid: {qid} for qid in eval_queries}
dev_evaluator = InformationRetrievalEvaluator(
corpus=eval_corpus,
queries=eval_queries,
relevant_docs=relevant_docs,
show_progress_bar=True,
name=f"{DATASET}-dev",
)
logging.info("Evaluating base model …")
dev_evaluator(model)
# ── Train ─────────────────────────────────────────────────────────────────────
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_dataset, # IterableDataset — no id column to drop
eval_dataset=eval_dataset, # materialised Dataset for per-step eval loss
loss=loss,
evaluator=dev_evaluator,
)
trainer.train()
logging.info("Evaluating fine-tuned model …")
dev_evaluator(model)
# ── Save ──────────────────────────────────────────────────────────────────────
model.save_pretrained(f"lmk_stress_test/{run_name}/final")
logging.info(f"Saved to lmk_stress_test/{run_name}/final") |
|
Would it be possible to control/define the "special tokens" list to attend to (in terms of the final mean pooling, NOT in terms of the addition of the LMK/ tokens)? |
Hi,
This PR adds support for LMK pooling, as described in our work here: (https://arxiv.org/pdf/2601.21525). This pooling mechanism is quite simple and improves long-context embedding performance with minimal overhead. Kindly review.