Skip to content

Adding support for LMK pooling#3642

Open
meetdoshi90 wants to merge 9 commits into
huggingface:mainfrom
meetdoshi90:lmk_pr
Open

Adding support for LMK pooling#3642
meetdoshi90 wants to merge 9 commits into
huggingface:mainfrom
meetdoshi90:lmk_pr

Conversation

@meetdoshi90
Copy link
Copy Markdown

Hi,
This PR adds support for LMK pooling, as described in our work here: (https://arxiv.org/pdf/2601.21525). This pooling mechanism is quite simple and improves long-context embedding performance with minimal overhead. Kindly review.

meet@ibm.com;005JQU744;Meet Doshi added 3 commits January 30, 2026 07:28
@tomaarsen
Copy link
Copy Markdown
Member

tomaarsen commented Feb 2, 2026

Hello!

This is very cool work, and very intuitive too! My main initial concerns are regarding the new tokenize method: I would rather use super().tokenize(texts=texts, padding=padding) and then update the input_ids/attention_mask from there. This likely gives more freedom eventually when #3554 is merged and multimodality is introduced, which replaces the tokenize method with a preprocess. If the tokenize in LandmarkTransformer can simply update the input_ids etc. from the superclass, then it should be simpler for that refactor.

I'm also wary of e.g. always adding a CLS token 'manually', instead of relying on the special tokens from the tokenizer. Then it's very possible the upstream model was also trained with a EOS token, but that's not preserved here. But sadly it's not simple to update the entire input_ids tensor due to the mixed granularity per row. Have you ablated whether that helps/hurts?

I also ran a training script for experimentation:

import random
import logging
from datasets import load_dataset, Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.models import LandmarkTransformer, Pooling, Normalize

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)

# 1. Load a model to finetune with 2. (Optional) model card data
landmark_transformer = LandmarkTransformer(
    "microsoft/mpnet-base",
    config_args={"splitter_type": "variable", "splitter_granularity": [32, 64, 128, 256]},
)
pooling = Pooling(landmark_transformer.get_word_embedding_dimension(), "lmk")
normalize = Normalize()
model = SentenceTransformer(
    modules=[landmark_transformer, pooling, normalize],
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="MPNet base with Landmark Pooling trained on GooAQ triplets using CachedMultipleNegativesRankingLoss with GradCache",
    ),
)

# 3. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
dataset = dataset.add_column("id", range(len(dataset)))
dataset_dict = dataset.train_test_split(test_size=10_000, seed=12)
train_dataset: Dataset = dataset_dict["train"]

eval_dataset: Dataset = dataset_dict["test"]

# 4. Define a loss function
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=64)

# 5. (Optional) Specify training arguments
run_name = "mpnet-base-gooaq-cmnrl-1024bs-lmk"
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=1024,
    per_device_eval_batch_size=1024,
    learning_rate=2e-5 * 4,
    warmup_ratio=0.1,
    fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=True,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # CachedMultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=0.1,
    save_strategy="steps",
    save_steps=0.1,
    save_total_limit=2,
    logging_steps=0.05,
    logging_first_step=True,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator & evaluate the base model
# The full corpus, but only the evaluation queries
corpus = dict(zip(dataset["id"], dataset["answer"]))
random.seed(12)
queries = dict(zip(eval_dataset["id"], eval_dataset["question"]))
corpus = (
    {qid: dataset[qid]["answer"] for qid in queries}
    # {qid: dataset[qid]["answer"] for qid in queries} |
    # {qid: dataset[qid]["answer"] for qid in random.sample(range(len(dataset)), 20_000)}
)
relevant_docs = {qid: {qid} for qid in eval_dataset["id"]}
dev_evaluator = InformationRetrievalEvaluator(
    corpus=corpus,
    queries=queries,
    relevant_docs=relevant_docs,
    show_progress_bar=True,
    name="gooaq-dev",
)
dev_evaluator(model)

# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset.remove_columns("id"),
    eval_dataset=eval_dataset.remove_columns("id"),
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# (Optional) Evaluate the trained model on the evaluator after training
dev_evaluator(model)

# 8. Save the trained model
model.save_pretrained(f"models/{run_name}/final")

# 9. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name)

I'll have to look more into this in the coming days!

  • Tom Aarsen

@meetdoshi90
Copy link
Copy Markdown
Author

Thank you for taking the time to try this out.

One important detail that may have been missed in your earlier code is passing lmk_token_id to the pooling mechanism. Without this, the pooling step does not know which tokens to aggregate and therefore falls back to mean pooling. I agree that manually passing the LMK token ID is somewhat clumsy at the moment, and I'm very open to suggestions on how this could be improved.

I will also update the code based on your suggestion. In addition, I have included an ablation comparing per-row variable LMK granularity with per-batch variable LMK granularity. Both approaches work, but per-row granularity performs better overall.

Finally, I extended your code by switching to the MS MARCO training set and using ModernBERT to support evaluation with sequences longer than 8k tokens. This should allow you to verify the results on long-context benchmarks such as LongEmbed.

Below are the results after fine-tuning on a single GPU:

Dataset Metric CLS Mean LMK-variable-row LMK-variable-batch
MSMarco (10k samples, dev) NDCG@10 90.8 90.3 90.4 89.5
NeedleRetrieval Precision@1 31.0 54.8 65.0 59.3
PasskeyRetrieval Precision@1 47.5 58.3 85.3 78.5
QMSumm NDCG@10 17.8 31.6 35.2 35.1
SumScreenFD NDCG@10 54.6 83.0 88.7 88.3
WikiMQA NDCG@10 45.6 49.8 54.3 55.5
NarrativeQA NDCG@10 19.0 24.3 39.2 39.7
import random
import logging
from datasets import load_dataset, Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss,MultipleNegativesRankingLoss 
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.evaluation import TripletEvaluator
from sentence_transformers.models import Transformer, LandmarkTransformer, Pooling, Normalize
import mteb
import torch
import sys

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)

POOLING = sys.argv[1] #'cls' or 'mean' or 'lmk'
print(POOLING)

# Step 1/2 Load model and pooling mechanism
if POOLING=='lmk':
    landmark_transformer = LandmarkTransformer(
        "answerdotai/ModernBERT-base",
        config_args={"splitter_type": "variable", "splitter_granularity": [32, 64, 128, 256]},
        # config_args={"splitter_type": "fixed", "splitter_granularity": 32},
    )
    assert landmark_transformer.tokenizer.sep_token_id is not None
    print(landmark_transformer.max_seq_length)
    pooling = Pooling(landmark_transformer.get_word_embedding_dimension(), pooling_mode="lmk", lmk_token_id=landmark_transformer.lmk_token_id, include_prompt=False)
    normalize = Normalize()
    model = SentenceTransformer(
        modules=[landmark_transformer, pooling, normalize],
    )
else:
    transformer = Transformer(
        "answerdotai/ModernBERT-base",
    )
    print(transformer.max_seq_length)
    if POOLING=='cls':
        pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="cls", include_prompt=False)
    elif POOLING=='mean':
        pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean", include_prompt=False)
    normalize = Normalize()
    model = SentenceTransformer(
        modules=[transformer, pooling, normalize],
    )


# 3. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/msmarco-bm25", 'triplet', split="train").select(range(500_000))
dataset = dataset.add_column("id", range(len(dataset)))
dataset_dict = dataset.train_test_split(test_size=10_000, seed=12)
train_dataset: Dataset = dataset_dict["train"]

eval_dataset: Dataset = dataset_dict["test"]

# 4. Define a loss function
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=64)

# 5. (Optional) Specify training arguments
run_name = f"modernbert-base-msmarco-cmnrl-1024bs-{POOLING}"
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    max_steps=1000, # approx. 2 epochs 500k/1024
    per_device_train_batch_size=1024,
    per_device_eval_batch_size=1024,
    learning_rate=1e-5,
    warmup_steps=100,
    fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=True,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # CachedMultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=10,
    logging_steps=5,
    logging_first_step=True,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator & evaluate the base model
# The full corpus, but only the evaluation queries
corpus = dict(zip(dataset["id"], dataset["positive"]))
random.seed(12)
queries = dict(zip(eval_dataset["id"], eval_dataset["query"]))
corpus = (
    {qid: dataset[qid]["positive"] for qid in queries}
    # {qid: dataset[qid]["answer"] for qid in queries} |
    # {qid: dataset[qid]["answer"] for qid in random.sample(range(len(dataset)), 20_000)}
)
relevant_docs = {qid: {qid} for qid in eval_dataset["id"]}
dev_evaluator = InformationRetrievalEvaluator(
    corpus=corpus,
    queries=queries,
    relevant_docs=relevant_docs,
    show_progress_bar=True,
    name="msmarco_custom_dev",
)
dev_evaluator(model)

# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset.remove_columns("id"),
    eval_dataset=eval_dataset.remove_columns("id"),
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# ============================
# Post-training evaluations
# ============================

# Change evaluation granularity to fixed for evaluation to remove any randomness
if POOLING=='lmk':
    for module in model.modules():
        if isinstance(module, LandmarkTransformer):
            module.splitter_type = "fixed"
            module.splitter_granularity = 32
            logging.info(
                "Switched LandmarkTransformer to fixed splitter with granularity=32 for evaluation"
            )

# MSMarco custom dev (already defined)
dev_evaluator(model)

# Long Embed evaluation
# increase MSL to 32k 
model.max_seq_length = 32768
TASK_NAMES=['LEMBNeedleRetrieval', 'LEMBPasskeyRetrieval','LEMBQMSumRetrieval','LEMBSummScreenFDRetrieval','LEMBWikimQARetrieval','LEMBNarrativeQARetrieval'] 
for TASK_NAME in TASK_NAMES:
    task = mteb.get_task(TASK_NAME)
    results = mteb.evaluate(
                    model=model,
                    tasks=[task],
                    overwrite_strategy='always',
                    cache=None,
                    encode_kwargs={"batch_size": 2}
                )
    print(f'-'*10)
    print(f'{TASK_NAME} Results:')
    scores = results.task_results[0].scores
    total = 0
    count = 0
    for k in scores:
        total += scores[k][0]['main_score']
        count += 1
    # average over test splits
    print(total/count)

# 8. Save the trained model
model.save_pretrained(f"models/{run_name}/final")

@tomaarsen
Copy link
Copy Markdown
Member

Thanks for preparing and running that script. The per-row granularity is indeed clearly better, we should stick with that. Furthermore, I noticed that splitter_type might not need to be set by the user: the __init__ can simply infer it based on the splitter_granularity: if list: then variable, if int, then fixed.

Another downside that will still exist is that evaluation during training will be done with the same settings as training. A potentially interesting option is to instead let users specify splitter_granularity and an optional train_splitter_granularity defaulting to None. If not None, the latter is used only during training, i.e. if model.training, and the former is used in all other settings. Then, your evaluation and normal inference after training will run with whatever granularity you define there in splitter_granularity. The best setting from your paper would be splitter_granularity=32, train_splitter_granularity=[32, 64, 128, 256]. It might be more confusing though, I'm open to your thoughts.

I think I'm okay with some extra complexity (specifically: having to pass the LMK token to the Pooling) for the training user: we can expect a bit more from them, and we'll have to have some documentation akin to e.g. https://sbert.net/examples/sentence_transformer/training/matryoshka/README.html or https://sbert.net/examples/sentence_transformer/training/prompts/README.html that goes over the differences versus "regular training" either way.

Alternatively, we can add some post-initialization checks:

  1. If exactly one of LandmarkTransformer and Pooling with "lmk" is used, instead of 0 or 2, then we give a warning that there's likely misuse
  2. If Pooling has no lmk_token_id, use the tokenizer.sep_token_id or tokenizer.eos_token_id as a default if it exists, and warn users that this token ID will be used.

I also think we should have the splitter_granularity etc. kwargs as actual kwargs in the LandmarkTransformer __init__ instead of as parameters from the config_kwargs. The config_kwargs options are designed to be passed down to the transformers AutoConfig, which won't recognize them. Also, the saved options in config.json should automatically be passed to the __init__ when loading the module. This currently (presumably) doesn't work.

  • Tom Aarsen

@meetdoshi90
Copy link
Copy Markdown
Author

Hey Tom — you raised a couple of good points that are worth discussing.

  1. I tried your suggestion of calling super().tokenize() and then manually deconstructing and again inserting the LMK tokens. However, this ends up constructing the same tensor twice, which feels suboptimal. If you think this additional overhead is necessary to ensure consistent behaviour across models, I can go this route in the next commit to ensure the existing special tokens are preserved correctly.

  2. I agree with your point about inferring splitter_type automatically. I have updated LandmarkTransformer to accept two parameters: eval_splitter_granularity and train_splitter_granularity (defaulting to None). The model now selects the granularity based on its mode: in training mode, it uses train_splitter_granularity if provided (or loaded from the config), and otherwise falls back to eval_splitter_granularity. We have to clearly note that users must explicitly switch the model to evaluation mode using model.eval(), or use the encode() API, when running inference.

  3. I agree it will be a bug when reloading the model. The splitter configuration needs to be explicitly passed during initialization; otherwise, it is not restored correctly from config_args. I have fixed this to align with the new initialization arguments.

meet@ibm.com;005JQU744;Meet Doshi added 2 commits February 5, 2026 21:35
@tomaarsen
Copy link
Copy Markdown
Member

tomaarsen commented Feb 19, 2026

I made some local refactors, and reran (roughly) the script from #3642 (comment), resulting in: https://huggingface.co/tomaarsen/mpnet-base-gooaq-cmnrl-1024bs-lmk-v2#training-logs (0.8537 NDCG@10)

Previously, I trained:

If you notice the before-training results, you'll see that the previous 2 models both score about 0.2172, compared to 0.0016 with the new model. I discovered the reason: previously the lmk_token_id was kept at -1, and so the previous run didn't use any landmark attention at all. I'm finding that if I try a harder problem, e.g. https://huggingface.co/jhu-clsp/ettin-encoder-150m finetuned on https://huggingface.co/datasets/tomaarsen/miriad-4.4M-split, then I'm not able to escape from the initial evaluation performance of ~0.0050 NDCG@10. Did you also encounter this with your paper?

Granted, this might also be due to a bug on my implementation's side.

Edit: I now updated my implementation to not hardcode the cls token but just to preserve the original tokens by the tokenizer. This results in: https://huggingface.co/tomaarsen/mpnet-base-gooaq-cmnrl-1024bs-lmk-v3#training-logs (0.8601 NDCG@10).

I still can't escape from the performance of ~0.0050 NDCG@10 on the MIRIAD task, though.

  • Tom Aarsen

@meetdoshi90
Copy link
Copy Markdown
Author

Yes, the pooling initialization requires a valid lmk_token_id to function correctly; otherwise, it silently falls back to mean pooling. I have now added an explicit warning for this case, as you suggested.

Regarding the initial evaluation performance: If you mean before training, low scores are expected unless the model has been trained for retrieval. The behavior largely depends on how the base model was pretrained and how much signal the special tokens received during pretraining.

For example, with google-bert/bert-base-cased, both LMK and mean pooling outperform CLS before training:

base_model = google-bert/bert-base-cased

Dev GooAQ performance before training:
CLS   : NDCG@10 = 0.0203
LMK   : NDCG@10 = 0.1592
Mean  : NDCG@10 = 0.1641

My intuition is that this happens because BERT’s SEP token was trained with the NSP objective and therefore carries more importance. As a result, for bert, SEP based LMK pooling shows better performance than CLS before retrieval fine-tuning.

In contrast, with answerdotai/ModernBERT-base, neither CLS nor LMK performs well before training:

base_model = answerdotai/ModernBERT-base

Dev GooAQ performance before training:
CLS   : NDCG@10 = 0.0037
LMK   : NDCG@10 = 0.0021
Mean  : NDCG@10 = 0.1651

Because, for ModernBERT, special tokens were not trained with any special objectives. Mean pooling performs significantly better because it aggregates information from all tokens.

For your case with microsoft/mpnet-base, the choice of LMK token also matters:

base_model = microsoft/mpnet-base
LMK = eos (token id 2)

Dev GooAQ performance before training:
CLS                                 : NDCG@10 = 0.0671
LMK (lmk_token_id = 2, </s> token)  : NDCG@10 = 0.0016
LMK (lmk_token_id = 0, <s> token)   : NDCG@10 = 0.0210
Mean                                : NDCG@10 = 0.2155

Switching the LMK token from </s> to <s> improves performance from 0.0016 to 0.0210, which suggests that this effect is model dependent rather than an inherent issue with LMK pooling itself.

After training, the differences largely disappear:

Dev GooAQ performance after training:
CLS                                 : NDCG@10 = 0.8549
LMK (lmk_token_id = 2, </s> token)  : NDCG@10 = 0.8520
LMK (lmk_token_id = 0, <s> token)   : NDCG@10 = 0.8543
Mean                                : NDCG@10 = 0.8645

Once trained for retrieval, LMK pooling works as expected.

Overall, I would not worry about initial performance differences, as they are most likely driven by the base model’s pretraining objectives and how special tokens were treated, rather than by the pooling mechanism itself.

@tomaarsen
Copy link
Copy Markdown
Member

tomaarsen commented Feb 20, 2026

I've pushed my changes as described in my previous comment. We can always revert them if we don't want to go in that direction, and I'm open to your feedback. The new implementation uses super().tokenize and then inserts LMK tokens per-line.

I understand that the mix of adding new tokens that the model isn't (always) familiar with + only pooling on those tokens means that it's much trickier for the model to get started with training, but I'm still struggling with the setup where the model doesn't train at all. It's possible that for some models, the SEP token was never used in training at all, and so it still has random embedding weights. Then, the model will perform very poorly, and sometimes never recover. The Ettin + MIRIAD training script that I pushed seems to have this happen. Maybe then we need to (re-)initialize the embedding weights of the LMK token? It's a bit risky either way, because this will definitely sometimes hurt performance.

Alternatively, maybe it can be helped with a lower learning rate or something. I'll try Ettin + GooAQ pairs to see if the issue is with Ettin or MIRIAD (or the combination perhaps).

Definitely let me know your thoughts!

Edit: Looks like Ettin doesn't work well compared to mpnet-base. If I use it with the GooAQ script, the loss stays high and the evaluation performance stays at 0, also with a lowered learning rate.

  • Tom Aarsen

@meetdoshi90
Copy link
Copy Markdown
Author

meetdoshi90 commented Feb 20, 2026

This is a bit strange. I am also seeing that training gets stuck when using the SEP token (50282) on Ettin with the GooAQ data. The loss consistently collapses to 6.93, which is exactly ln(1024) (the batch size). That might indicate representation collapse, meaning all LMK embeddings are converging to the same vector.

To debug this for Ettin, I tried changing the LMK token to a more commonly used token like [CLS] (50281), [MASK] (50284), and also to a relatively unused token like [unused0] (50285). In all the cases, the loss decreases normally, and retrieval performance improves, instead of getting stuck at that collapse point.

I am not sure why this happens specifically with Ettin’s SEP token. The same SEP token setup works without issues for BERT, RoBERTa, ModernBERT-base, and GTE-en-mlm-base.

I need to look into this more. If this turns out to be specific to certain models, we already allow users to specify the LMK token ID manually. So as a workaround, if the loss collapses, they can switch to a different token or initialize a fresh embedding for the LMK token.

Edit: I increased the number of epochs to 5 for Ettin with GooAQ using the SEP token as the landmark token, and the performance starts to improve again. However, representation collapse should not have happened in the first place. I suspect this is very model-specific, since reducing the batch size to 256 with 1 epoch also resolved the issue.

@tomaarsen
Copy link
Copy Markdown
Member

tomaarsen commented Feb 23, 2026

To debug this for Ettin, I tried changing the LMK token to a more commonly used token like [CLS] (50281), [MASK] (50284), and also to a relatively unused token like [unused0] (50285). In all the cases, the loss decreases normally, and retrieval performance improves, instead of getting stuck at that collapse point.

Perhaps Ettin its SEP token is completely untrained? I checked out its std/mean/max and it seemed in line with the rest of the embeddings at a glance. I indeed think it's likely model-specific. I'll try a bit more.

Thank you for testing this out with me!

  • Tom Aarsen

@meetdoshi90
Copy link
Copy Markdown
Author

Hi Tom,

Apologies for the delayed response. To confirm, I ran experiments across a set of training models and datasets comparing CLS and LMK pooling. I am attaching the script and results below.

It appears to be an Ettin-specific issue and occurs across all Ettin models except the 32M model (where it degrades performance). I also tested replacing the LMK token with a token other than SEP for Ettin, and it works correctly in that case.

However, it would be useful to mention this in the documentation, suggesting that the special token be replaced if representation collapse occurs. Please let me know if you have any suggestions.

Results (NDCG@10):

Datasets -> eli5 eli5 gooaq gooaq hotpotqa hotpotqa msmarco msmarco nq nq trivia trivia
Model cls lmk cls lmk cls lmk cls lmk cls lmk cls lmk
ModernBERT-base 71.3 67.8 86.2 83.9 91.2 89.1 92.2 90.7 87.3 86.4 78.0 76.7
bert-base-cased 61.8 59.3 76.6 75.8 89.4 89.1 81.4 80.9 79.6 78.4 75.6 74.3
ettin-encoder-150m 70.8 0.2 ❌ 85.3 1.0 ❌ 90.2 2.1 ❌ 91.4 0.7 ❌ 86.1 1.8 ❌ 77.0 0.0 ❌
ettin-encoder-17m 47.3 0.6 ❌ 65.3 0.8 ❌ 83.4 1.1 ❌ 74.3 0.6 ❌ 67.2 0.4 ❌ 63.2 0.2 ❌
ettin-encoder-32m 57.2 53.8 76.5 73.2 88.0 86.5 84.7 82.3 79.1 76.5 72.1 69.2
ettin-encoder-400m 83.8 45.2 92.2 84.8 95.0 3.7 ❌ 95.4 2.8 ❌ 92.3 75.0 85.8 5.2 ❌
gte-en-mlm-base 68.9 67.8 86.7 86.2 90.7 90.5 92.4 92.1 88.5 87.9 79.4 79.5
nomic-bert-2048 66.4 66.0 84.3 84.3 90.7 90.7 91.7 91.7 88.9 89.0 79.3 79.4
roberta-base 68.5 68.2 83.5 83.4 90.0 89.9 89.5 89.6 85.8 86.0 77.4 77.5
xlm-roberta-base 25.5 44.1 40.6 62.2 66.1 82.2 51.9 77.8 40.7 70.2 48.7 64.4

Script:

import random
import logging
import sys
from datasets import load_dataset, Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.models import Transformer, LandmarkTransformer, Pooling, Normalize

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)

# ── CLI args ──────────────────────────────────────────────────────────────────
if len(sys.argv) != 4:
    print("Usage: python train_lmk.py <pooling> <base_model> <dataset>")
    print("  pooling   : cls | mean | lmk")
    print("  base_model: e.g. jhu-clsp/ettin-encoder-150m")
    print("  dataset   : gooaq | msmarco | nq | eli5 | trivia | hotpotqa | miracl | mldr")
    sys.exit(1)

POOLING    = sys.argv[1]   # cls | mean | lmk
BASE_MODEL = sys.argv[2]   # HF model id
DATASET    = sys.argv[3]   # short name



# ── Dataset registry ──────────────────────────────────────────────────────────
# Each entry: (hf_path, subset, split, query_col, doc_col, n_train, n_eval)
DATASET_CONFIG = {
    "gooaq":    ("sentence-transformers/gooaq",         None,          "train", "question", "answer",   90_000, 10_000),
    "msmarco":  ("sentence-transformers/msmarco-bm25",  "triplet",     "train", "query",    "positive", 90_000, 10_000),
    "nq":       ("sentence-transformers/natural-questions", None,      "train", "query",    "answer",   90_000, 10_000),
    "eli5":     ("sentence-transformers/eli5",          None,          "train", "question", "answer",   90_000, 10_000),
    "trivia":   ("sentence-transformers/trivia-qa",     None,          "train", "query",    "answer",   90_000, 10_000),
    "hotpotqa": ("sentence-transformers/hotpotqa",      "triplet",     "train", "anchor",   "positive", 90_000, 10_000),
    "miracl":   ("sentence-transformers/miracl",        "en-triplet",  "train", "anchor",   "positive", 90_000, 10_000),
    "mldr":     ("sentence-transformers/mldr",          "en-triplet",  "train", "anchor",   "positive", 90_000, 10_000),
}

if DATASET not in DATASET_CONFIG:
    raise ValueError(f"Unknown dataset '{DATASET}'. Choose from: {list(DATASET_CONFIG)}")

hf_path, subset, split, query_col, doc_col, n_train, n_eval = DATASET_CONFIG[DATASET]

# ── Load & prepare dataset (streaming) ───────────────────────────────────────
logging.info(f"Loading dataset (streaming): {hf_path} (subset={subset})")
raw = (
    load_dataset(hf_path, subset, split=split, streaming=True)
    if subset
    else load_dataset(hf_path, split=split, streaming=True)
)

# Shuffle with a fixed seed before splitting
raw = raw.shuffle(seed=12, buffer_size=10_000)

# Rename columns to canonical names so the trainer sees "anchor"/"positive"
if query_col != "anchor":
    raw = raw.rename_column(query_col, "anchor")
if doc_col != "positive":
    raw = raw.rename_column(doc_col, "positive")

raw = raw.select_columns(["anchor", "positive"])

# Split: first n_eval rows → eval, next n_train rows → train
eval_iterable  = raw.take(n_eval)
train_dataset  = raw.skip(n_eval).take(n_train)   # IterableDataset for trainer

logging.info(f"Materialising eval set ({n_eval} rows) …")
eval_rows     = list(eval_iterable)
eval_corpus   = {i: row["positive"] for i, row in enumerate(eval_rows)}
eval_queries  = {i: row["anchor"]   for i, row in enumerate(eval_rows)}

eval_dataset  = Dataset.from_list(eval_rows)

# ── Build model ───────────────────────────────────────────────────────────────
MAX_SEQ_LENGTH = 512
logging.info(f"Building model: {BASE_MODEL}  pooling={POOLING}")

if POOLING == "lmk":
    backbone = LandmarkTransformer(
        BASE_MODEL,
        max_seq_length=MAX_SEQ_LENGTH,
        eval_splitter_granularity=32,
        train_splitter_granularity=[32, 64, 128, 256],
        model_args={"trust_remote_code": True},
        config_args={"trust_remote_code": True},
    )
    pooling = Pooling(
        backbone.get_word_embedding_dimension(),
        pooling_mode="lmk",
        lmk_token_id=backbone.lmk_token_id,
        include_prompt=False,
    )
    normalize = Normalize()
    model = SentenceTransformer(modules=[backbone, pooling, normalize])
else:
    backbone = Transformer(
        BASE_MODEL,
        max_seq_length=MAX_SEQ_LENGTH,
        model_args={"trust_remote_code": True},
        config_args={"trust_remote_code": True},
    )
    pooling = Pooling(
        backbone.get_word_embedding_dimension(),
        pooling_mode=POOLING,       # "cls" or "mean"
        include_prompt=False,
    )
    normalize = Normalize()
    model = SentenceTransformer(modules=[backbone, pooling, normalize])

# ── Loss ──────────────────────────────────────────────────────────────────────
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=64)

# ── Training args ─────────────────────────────────────────────────────────────
model_short = BASE_MODEL.split("/")[-1]
run_name = f"{model_short}-{DATASET}-cmnrl-{POOLING}"

NUM_EPOCHS = 5
BATCH_SIZE = 1024
steps_per_epoch = n_train // BATCH_SIZE
max_steps = steps_per_epoch * NUM_EPOCHS

args = SentenceTransformerTrainingArguments(
    output_dir=f"lmk_stress_test/{run_name}",
    max_steps=max_steps,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=1e-5,
    warmup_ratio=0.1,
    fp16=False,
    bf16=True,
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    eval_strategy="steps",
    eval_steps=steps_per_epoch // 5,       # ~5 evals per epoch
    save_strategy="steps",
    save_steps=steps_per_epoch // 5,
    save_total_limit=1,
    logging_steps=10,
    logging_first_step=True,
    run_name=run_name,
)

# ── Evaluator ─────────────────────────────────────────────────────────────────
relevant_docs = {qid: {qid} for qid in eval_queries}

dev_evaluator = InformationRetrievalEvaluator(
    corpus=eval_corpus,
    queries=eval_queries,
    relevant_docs=relevant_docs,
    show_progress_bar=True,
    name=f"{DATASET}-dev",
)

logging.info("Evaluating base model …")
dev_evaluator(model)

# ── Train ─────────────────────────────────────────────────────────────────────
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,   # IterableDataset — no id column to drop
    eval_dataset=eval_dataset,     # materialised Dataset for per-step eval loss
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

logging.info("Evaluating fine-tuned model …")
dev_evaluator(model)

# ── Save ──────────────────────────────────────────────────────────────────────
model.save_pretrained(f"lmk_stress_test/{run_name}/final")
logging.info(f"Saved to lmk_stress_test/{run_name}/final")

@ddofer
Copy link
Copy Markdown
Contributor

ddofer commented Mar 11, 2026

Would it be possible to control/define the "special tokens" list to attend to (in terms of the final mean pooling, NOT in terms of the addition of the LMK/ tokens)?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants