Skip to content

Commit 7d87fd4

Browse files
authored
Merge pull request #345 from urchade/feature/prompt_compression
Feature/prompt compression
2 parents 180ce84 + 94894e1 commit 7d87fd4

8 files changed

Lines changed: 680 additions & 19 deletions

File tree

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
"""Evaluate raw vs compressed-prompt GLiNER on knowledgator/biomed_NER."""
2+
3+
import argparse
4+
import random
5+
import time
6+
7+
import torch
8+
from datasets import load_dataset
9+
10+
from gliner import GLiNER
11+
12+
13+
def predictions_to_ner(text, preds):
    """Turn character-offset predictions into a word-level NER sample.

    Each prediction dict (keys "start", "end", "label") is re-keyed into the
    entity format consumed by ``char_to_word_sample`` ("label" becomes
    "class"), and the result of aligning those entities against ``text`` is
    returned unchanged.
    """
    entities = []
    for pred in preds:
        entities.append(
            {"start": pred["start"], "end": pred["end"], "class": pred["label"]}
        )
    return char_to_word_sample(text, entities)
17+
18+
19+
def distill_finetune(model, distill_data, *, epochs, lr, batch_size, output_dir):
    """Fine-tune ``model`` in place on pseudo-labeled samples, then set eval mode.

    Args:
        model: Object exposing ``train_model(**kwargs)`` and ``eval()``.
        distill_data: Training samples, forwarded as ``train_dataset``.
        epochs: Forwarded as ``num_train_epochs``.
        lr: Forwarded as ``learning_rate``.
        batch_size: Forwarded as ``per_device_train_batch_size``.
        output_dir: Forwarded as ``output_dir``.

    Returns:
        None. No evaluation set is used and checkpoint saving / external
        reporting are switched off via the forwarded options.
    """
    train_options = {
        "train_dataset": distill_data,
        "eval_dataset": None,
        "output_dir": output_dir,
        "num_train_epochs": epochs,
        # override create_training_args' default (10000) so num_train_epochs wins
        "max_steps": -1,
        "per_device_train_batch_size": batch_size,
        "learning_rate": lr,
        "save_strategy": "no",
        "report_to": "none",
        "logging_steps": 10,
        "remove_unused_columns": False,
    }
    model.train_model(**train_options)
    # Training is finished; leave the model in inference mode for the caller.
    model.eval()
36+
37+
38+
def timed_evaluate(model, eval_data, *, warmup, repeats, device, **eval_kwargs):
    """Run ``model.evaluate`` once for metrics and ``repeats`` times for timing.

    Args:
        model: Object exposing ``evaluate(eval_data, **eval_kwargs)`` that
            returns an ``(output, f1)`` pair.
        eval_data: Evaluation samples, passed through to ``model.evaluate``.
        warmup: Number of untimed evaluation passes run before timing starts.
        repeats: Number of timed evaluation passes; must be at least 1.
        device: Device string (or ``torch.device``); when it names a CUDA
            device, that device is synchronized around the timed regions.
        **eval_kwargs: Extra keyword arguments forwarded to ``model.evaluate``.

    Returns:
        ``(out, f1, mean_seconds, best_seconds)`` where ``out`` and ``f1`` come
        from the first (untimed) evaluation pass.

    Raises:
        ValueError: If ``repeats`` is smaller than 1. Raised before any
            evaluation runs, instead of failing on the mean/min computation
            after all the expensive passes have already been executed.
    """
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1 to compute timings, got {repeats}")

    # str() lets callers pass either a device string or a torch.device.
    device_name = str(device)
    on_cuda = device_name.startswith("cuda")

    def _sync():
        # Synchronize the device the model runs on (not merely the current
        # CUDA device) so pending kernels are included in the measurement.
        if on_cuda:
            torch.cuda.synchronize(torch.device(device_name))

    _sync()
    out, f1 = model.evaluate(eval_data, **eval_kwargs)

    for _ in range(warmup):
        model.evaluate(eval_data, **eval_kwargs)

    _sync()
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        model.evaluate(eval_data, **eval_kwargs)
        _sync()
        times.append(time.perf_counter() - t0)

    mean = sum(times) / len(times)
    return out, f1, mean, min(times)
59+
60+
61+
def char_to_word_sample(text, entities):
    """Convert {text, entities:[{class,start,end}]} to {tokenized_text, ner}.

    Uses whitespace tokenization (``str.split()``) and aligns character
    offsets to word indices.

    Args:
        text: Raw text to tokenize.
        entities: Iterable of dicts with "start" / "end" character offsets
            (``text[start:end]`` is the span) and a "class" label. ``None``
            is treated as an empty list.

    Returns:
        A dict with "tokenized_text" (the list of words) and "ner", a list of
        ``(start_word_idx, end_word_idx, label)`` tuples. The end index is the
        index of the last word of the span and labels are lower-cased.
        Leading/trailing whitespace inside a span is tolerated. Spans that are
        empty, whitespace-only, or not aligned to word boundaries are dropped.
    """
    words = text.split()

    # Map each word's start offset / end offset (exclusive) to its index.
    # Searching from `cursor` locates every word in order without re-slicing
    # the text for each word, and works for any amount of whitespace.
    start_to_widx, end_to_widx = {}, {}
    cursor = 0
    for widx, word in enumerate(words):
        begin = text.find(word, cursor)
        cursor = begin + len(word)
        start_to_widx[begin] = widx
        end_to_widx[cursor] = widx

    ner = []
    for ent in entities or ():
        s, e, cls = ent["start"], ent["end"], ent["class"].lower()
        span_text = text[s:e]
        # An empty or whitespace-only span has no words. Without this guard
        # the trimmed offsets cross over (s2 > e2) and can match the start of
        # a later word and the end of an earlier one, yielding an inverted
        # (start > end) word span.
        if not span_text.strip():
            continue
        # Tolerate leading/trailing whitespace inside the span.
        ls = len(span_text) - len(span_text.lstrip())
        le = len(span_text) - len(span_text.rstrip())
        s2, e2 = s + ls, e - le
        if s2 in start_to_widx and e2 in end_to_widx:
            ner.append((start_to_widx[s2], end_to_widx[e2], cls))
    return {"tokenized_text": words, "ner": ner}
94+
95+
96+
def _build_arg_parser():
    """Create the command-line parser for the raw-vs-compressed benchmark."""
    default_device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # (flag, add_argument options), registered in this order.
    option_specs = [
        ("--model", {"default": "gliner-community/gliner_small-v2.5"}),
        ("--dataset", {"default": "knowledgator/biomed_NER"}),
        ("--split", {"default": "train"}),
        ("--eval_size", {"type": int, "default": 3000}),
        ("--compress_size", {"type": int, "default": 1000}),
        ("--batch_size", {"type": int, "default": 4}),
        ("--threshold", {"type": float, "default": 0.5}),
        ("--device", {"default": default_device}),
        ("--seed", {"type": int, "default": 42}),
        ("--bench_warmup", {"type": int, "default": 1}),
        ("--bench_repeats", {"type": int, "default": 1}),
        ("--distill", {
            "action": "store_true",
            "help": "Fine-tune the compressed model on raw-model pseudo-labels.",
        }),
        ("--distill_size", {
            "type": int,
            "default": 1000,
            "help": "Number of texts to use for distillation (drawn after compress slice).",
        }),
        ("--distill_epochs", {"type": int, "default": 3}),
        ("--distill_lr", {"type": float, "default": 1e-5}),
        ("--distill_threshold", {"type": float, "default": 0.3}),
        ("--distill_output_dir", {"type": str, "default": "./distill_ckpt"}),
    ]
    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser


def main():
    """Benchmark a GLiNER model before and after prompt compression.

    Steps, in order: parse CLI arguments, load the dataset and convert it to
    word-level samples, evaluate and time the raw model, optionally generate
    pseudo-labels with the raw model, compress the prompt embeddings,
    optionally fine-tune on the pseudo-labels, evaluate and time the
    compressed model, and print a summary comparing F1 and throughput.
    """
    args = _build_arg_parser().parse_args()

    random.seed(args.seed)

    print(f"Loading dataset {args.dataset} [{args.split}]...")
    ds = load_dataset(args.dataset, split=args.split)

    # Keep only samples with at least one entity aligned to word boundaries.
    processed = []
    for row in ds:
        sample = char_to_word_sample(row["text"], row["entities"])
        if sample["ner"]:
            processed.append(sample)

    label_set = set()
    for sample in processed:
        label_set.update(tag for _, _, tag in sample["ner"])
    labels = sorted(label_set)
    print(f"{len(processed)} samples, {len(labels)} labels: {labels}")

    random.shuffle(processed)
    # Every sample carries the complete label list so that the raw and the
    # compressed evaluations classify over the same label space. Otherwise the
    # raw evaluation would derive labels per sample (only the positives
    # present) and be unfairly easier than the compressed path, which always
    # classifies over all labels.
    for sample in processed:
        sample["ner_labels"] = labels

    eval_end = args.eval_size
    compress_end = eval_end + args.compress_size
    eval_data = processed[:eval_end]
    # Fallback when nothing is left after the eval slice: reuse the head of
    # the list. Note that this overlaps with eval_data.
    compress_slice = processed[eval_end:compress_end] or processed[: args.compress_size]
    compress_texts = [" ".join(sample["tokenized_text"]) for sample in compress_slice]

    if args.distill:
        distill_slice = processed[compress_end : compress_end + args.distill_size]
    else:
        distill_slice = []

    print(f"Loading model {args.model}...")
    model = GLiNER.from_pretrained(args.model).to(args.device)

    eval_kwargs = {
        "flat_ner": True,
        "threshold": args.threshold,
        "batch_size": args.batch_size,
    }
    n = len(eval_data)

    print("=== Raw GLiNER evaluation ===")
    raw_out, raw_f1, raw_mean, raw_best = timed_evaluate(
        model,
        eval_data,
        warmup=args.bench_warmup,
        repeats=args.bench_repeats,
        device=args.device,
        **eval_kwargs,
    )
    print(raw_out)
    print(f"Raw F1: {raw_f1:.4f}")
    print(f"Raw timing (n={n}, bs={args.batch_size}, repeats={args.bench_repeats}): "
          f"mean {raw_mean:.3f}s | best {raw_best:.3f}s | "
          f"{n / raw_mean:.1f} samples/s")

    # Pseudo-labels must come from the raw model, so they are generated
    # before the prompts are compressed below.
    distill_data = None
    if args.distill and distill_slice:
        print(f"Generating pseudo-labels from raw model on {len(distill_slice)} distillation texts...")
        distill_texts = [" ".join(sample["tokenized_text"]) for sample in distill_slice]
        preds = model.inference(
            distill_texts,
            labels,
            flat_ner=True,
            threshold=args.distill_threshold,
            batch_size=args.batch_size,
        )
        # NOTE(review): these samples only have "tokenized_text" and "ner";
        # unlike the eval samples they carry no "ner_labels". Confirm that the
        # training collator handles the label set as intended.
        distill_data = [
            predictions_to_ner(doc, doc_preds)
            for doc, doc_preds in zip(distill_texts, preds)
        ]
        kept = sum(1 for d in distill_data if d["ner"])
        print(f"  {kept}/{len(distill_data)} samples carry at least one pseudo-label")

    print(f"Compressing prompt embeddings over {len(compress_texts)} texts...")
    model.compress_prompt_embeddings(
        texts=compress_texts, labels=labels, batch_size=args.batch_size
    )
    model.config.precomputed_prompts_mode = True

    if distill_data:
        print(f"Fine-tuning compressed model on pseudo-labels "
              f"(epochs={args.distill_epochs}, lr={args.distill_lr})...")
        distill_finetune(
            model,
            distill_data,
            epochs=args.distill_epochs,
            lr=args.distill_lr,
            batch_size=args.batch_size,
            output_dir=args.distill_output_dir,
        )

    print("=== Compressed GLiNER evaluation ===")
    comp_out, comp_f1, comp_mean, comp_best = timed_evaluate(
        model,
        eval_data,
        warmup=args.bench_warmup,
        repeats=args.bench_repeats,
        device=args.device,
        **eval_kwargs,
    )
    print(comp_out)
    print(f"Compressed F1: {comp_f1:.4f}")
    print(f"Compressed timing (n={n}, bs={args.batch_size}, repeats={args.bench_repeats}): "
          f"mean {comp_mean:.3f}s | best {comp_best:.3f}s | "
          f"{n / comp_mean:.1f} samples/s")

    print("\n=== Summary ===")
    print(f"Raw F1: {raw_f1:.4f} | mean {raw_mean:.3f}s | {n / raw_mean:.1f} samples/s")
    print(f"Compressed F1: {comp_f1:.4f} | mean {comp_mean:.3f}s | {n / comp_mean:.1f} samples/s")
    print(f"Delta F1 : {comp_f1 - raw_f1:+.4f}")
    print(f"Speedup : {raw_mean / comp_mean:.2f}x")


if __name__ == "__main__":
    main()

docs/usage.md

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -990,6 +990,112 @@ print(f"- Products: {[e['text'] for e in entities if e['label'] == 'product']}")
990990
print(f"- Timeline: {[e['text'] for e in entities if e['label'] == 'date']}")
991991
```
992992

993+
## ⚡ Prompt Compression (Precomputed Prompt Embeddings)
994+
995+
For uni-encoder models (span, token, and relation-extraction variants) you can
996+
precompute the prompt embeddings for a **fixed** label set and reuse them at
997+
inference time. In precomputed mode the encoder receives only the text
998+
(no `<<ENT>>label1<<ENT>>...<<SEP>>` prefix), which shortens the input sequence,
999+
reduces attention cost, and can noticeably speed up inference — at a small
1000+
accuracy trade-off versus re-encoding the prompts on every call.
1001+
1002+
### How it works
1003+
1004+
`BaseGLiNER.compress_prompt_embeddings(texts, labels, rel_labels=None, batch_size=8, distill=False, distill_threshold=0.3, distill_epochs=3, distill_lr=1e-5, distill_batch_size=None, distill_output_dir="./distill_ckpt", distill_train_kwargs=None)`:
1005+
1006+
1. Runs the normal forward pass over `(texts, labels)` pairs.
1007+
2. Extracts the per-label prompt embedding (the `<<ENT>>` token representation,
1008+
pre-projection) from each example.
1009+
3. Averages across all examples to produce an `(L, D)` matrix stored as a
1010+
non-trainable parameter on the underlying model (`model.precomputed_prompts`).
1011+
4. Sets `config.precomputed_prompts_mode = True` and writes
1012+
`config.id_to_classes`, so subsequent `predict_entities` / `forward` calls
1013+
skip prompt-prepending and look up the stored embeddings instead.
1014+
1015+
The stored embeddings travel with `state_dict`, so `save_pretrained` /
1016+
`from_pretrained` round-trip them automatically. Training can continue after
1017+
compression — the stored matrix is frozen but everything else keeps training.
1018+
1019+
### Basic usage (entity extraction)
1020+
1021+
```python
1022+
from gliner import GLiNER
1023+
1024+
model = GLiNER.from_pretrained("urchade/gliner_small-v2.1")
1025+
1026+
# Representative texts from your target domain. They do not need labels;
1027+
# they are only used as contexts while averaging the prompt representations.
1028+
calibration_texts = [
1029+
"Barack Obama was born in Honolulu, Hawaii.",
1030+
"Apple announced a new iPhone at their Cupertino headquarters.",
1031+
# ... ideally 100–1000 diverse sentences from your domain
1032+
]
1033+
1034+
labels = ["person", "organization", "location", "date"]
1035+
1036+
# One-time compression step
1037+
model.compress_prompt_embeddings(calibration_texts, labels, batch_size=16)
1038+
1039+
# Inference now uses the precomputed prompts — labels are still passed, but are no longer encoded as a prompt prefix
1040+
entities = model.predict_entities(
1041+
"Tim Cook visited Berlin last Tuesday.",
1042+
labels, # must match (order-insensitive) the compressed set
1043+
threshold=0.5,
1044+
)
1045+
1046+
# Persist the compressed model
1047+
model.save_pretrained("./gliner-compressed")
1048+
```
1049+
1050+
### Relation extraction
1051+
1052+
For relex models (`UniEncoderSpanRelexModel` / `UniEncoderTokenRelexModel`),
1053+
pass `rel_labels` so the `<<REL>>` prompt embeddings are compressed as well:
1054+
1055+
```python
1056+
model.compress_prompt_embeddings(
1057+
texts=calibration_texts,
1058+
labels=["person", "organization", "location"],
1059+
rel_labels=["works_for", "located_in", "founder_of"],
1060+
batch_size=8,
1061+
)
1062+
```
1063+
1064+
### End-to-end distillation
1065+
1066+
Compression alone can dip quality because averaged prompt embeddings drop
1067+
context-specific signal. Pass `distill=True` to recover it in a single call:
1068+
the raw (pre-compression) model first generates pseudo-labels over `texts`,
1069+
prompts are then compressed, and the compressed model is fine-tuned on those
1070+
pseudo-labels — no separate script required.
1071+
1072+
```python
1073+
model.compress_prompt_embeddings(
1074+
texts=calibration_texts, # also used as the distillation corpus
1075+
labels=labels,
1076+
batch_size=16,
1077+
distill=True,
1078+
distill_threshold=0.3, # pseudo-label confidence cutoff
1079+
distill_epochs=3,
1080+
distill_lr=1e-5,
1081+
distill_output_dir="./distill_ckpt",
1082+
)
1083+
```
1084+
1085+
Relevant knobs:
1086+
1087+
- `distill_threshold`: confidence cutoff used when the raw model produces
1088+
pseudo-labels. Lower values widen the training signal but add noise.
1089+
- `distill_epochs`, `distill_lr`: fine-tuning schedule.
1090+
- `distill_batch_size`: defaults to `batch_size` if omitted.
1091+
- `distill_output_dir`: forwarded to `train_model`.
1092+
- `distill_train_kwargs`: dict of extra kwargs merged into the underlying
1093+
`train_model` call (e.g. to override `save_strategy`, `logging_steps`, etc.).
1094+
1095+
Pseudo-labels are generated from the same `texts` used for compression, so one
1096+
diverse in-domain corpus serves both roles.
1097+
1098+
9931099
## Tips and Best Practices
9941100

9951101
1. **Choose the right model architecture**:

gliner/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ def __init__(
3939
span_loss_coef: float = 1.0,
4040
represent_spans: bool = False,
4141
neg_spans_ratio: float = 1.0,
42+
precomputed_prompts_mode: Optional[bool] = None,
43+
id_to_classes: Optional[dict] = None,
4244
**kwargs,
4345
):
4446
"""Initialize BaseGLiNERConfig.
@@ -72,6 +74,8 @@ def __init__(
7274
span_loss_coef (float, optional): Span loss coefficient. Defaults to 1.0.
7375
represent_spans (bool, optional): Whether to represent spans. Defaults to False.
7476
neg_spans_ratio (float, optional): Ratio of negative spans. Defaults to 1.0.
77+
precomputed_prompts_mode (Optional[bool]): Whether to use precomputed prompts. Defaults to None.
78+
id_to_classes (Optional[dict]): Mapping from class IDs to class names. Defaults to None.
7579
**kwargs: Additional keyword arguments passed to parent class.
7680
"""
7781
super().__init__(**kwargs)
@@ -108,6 +112,8 @@ def __init__(
108112
self.span_loss_coef = span_loss_coef
109113
self.represent_spans = represent_spans
110114
self.neg_spans_ratio = neg_spans_ratio
115+
self.precomputed_prompts_mode = precomputed_prompts_mode
116+
self.id_to_classes = id_to_classes
111117

112118

113119
class UniEncoderConfig(BaseGLiNERConfig):
@@ -201,6 +207,7 @@ def __init__(
201207
augment_ent_drop_prob=(0.0, 1.0),
202208
augment_rel_drop_prob=(0.0, 0.3),
203209
augment_add_other_prob=0.5,
210+
rel_id_to_classes: Optional[dict] = None,
204211
**kwargs,
205212
):
206213
"""Initialize UniEncoderRelexConfig.
@@ -223,6 +230,7 @@ def __init__(
223230
the per-type entity drop probability. Defaults to (0.0, 0.4).
224231
augment_rel_drop_prob (tuple, optional): Range (min, max) from which to sample
225232
the per-type relation drop probability. Defaults to (0.0, 0.4).
233+
rel_id_to_classes (Optional[dict]): Mapping from relation class IDs to class names. Defaults to None.
226234
**kwargs: Additional keyword arguments passed to UniEncoderConfig.
227235
228236
Raises:
@@ -241,6 +249,7 @@ def __init__(
241249
self.augment_ent_drop_prob = tuple(augment_ent_drop_prob)
242250
self.augment_rel_drop_prob = tuple(augment_rel_drop_prob)
243251
self.augment_add_other_prob = augment_add_other_prob
252+
self.rel_id_to_classes = rel_id_to_classes
244253

245254

246255
class UniEncoderSpanRelexConfig(UniEncoderRelexConfig):

0 commit comments

Comments
 (0)