|
| 1 | +"""Evaluate raw vs compressed-prompt GLiNER on knowledgator/biomed_NER.""" |
| 2 | + |
| 3 | +import argparse |
| 4 | +import random |
| 5 | +import time |
| 6 | + |
| 7 | +import torch |
| 8 | +from datasets import load_dataset |
| 9 | + |
| 10 | +from gliner import GLiNER |
| 11 | + |
| 12 | + |
| 13 | +def predictions_to_ner(text, preds): |
| 14 | + """Map char-offset predictions from model.inference to word-level ner tuples.""" |
| 15 | + ent_dicts = [{"start": p["start"], "end": p["end"], "class": p["label"]} for p in preds] |
| 16 | + return char_to_word_sample(text, ent_dicts) |
| 17 | + |
| 18 | + |
| 19 | +def distill_finetune(model, distill_data, *, epochs, lr, batch_size, output_dir): |
| 20 | + """Fine-tune `model` on pseudo-labeled `distill_data` via GLiNER.train_model.""" |
| 21 | + # Attach the full label set so the collator uses it with prepare_labels=True. |
| 22 | + model.train_model( |
| 23 | + train_dataset=distill_data, |
| 24 | + eval_dataset=None, |
| 25 | + output_dir=output_dir, |
| 26 | + num_train_epochs=epochs, |
| 27 | + max_steps=-1, # override create_training_args' default (10000) so num_train_epochs wins |
| 28 | + per_device_train_batch_size=batch_size, |
| 29 | + learning_rate=lr, |
| 30 | + save_strategy="no", |
| 31 | + report_to="none", |
| 32 | + logging_steps=10, |
| 33 | + remove_unused_columns=False, |
| 34 | + ) |
| 35 | + model.eval() |
| 36 | + |
| 37 | + |
| 38 | +def timed_evaluate(model, eval_data, *, warmup, repeats, device, **eval_kwargs): |
| 39 | + """Run model.evaluate once for metrics and `repeats` times for timing.""" |
| 40 | + if device.startswith("cuda"): |
| 41 | + torch.cuda.synchronize() |
| 42 | + out, f1 = model.evaluate(eval_data, **eval_kwargs) |
| 43 | + |
| 44 | + for _ in range(warmup): |
| 45 | + model.evaluate(eval_data, **eval_kwargs) |
| 46 | + |
| 47 | + if device.startswith("cuda"): |
| 48 | + torch.cuda.synchronize() |
| 49 | + times = [] |
| 50 | + for _ in range(repeats): |
| 51 | + t0 = time.perf_counter() |
| 52 | + model.evaluate(eval_data, **eval_kwargs) |
| 53 | + if device.startswith("cuda"): |
| 54 | + torch.cuda.synchronize() |
| 55 | + times.append(time.perf_counter() - t0) |
| 56 | + |
| 57 | + mean = sum(times) / len(times) |
| 58 | + return out, f1, mean, min(times) |
| 59 | + |
| 60 | + |
| 61 | +def char_to_word_sample(text, entities): |
| 62 | + """Convert {text, entities:[{class,start,end}]} to {tokenized_text, ner}. |
| 63 | +
|
| 64 | + Uses whitespace tokenization and aligns char offsets to word indices. |
| 65 | + Entities that don't align to word boundaries are dropped. |
| 66 | + """ |
| 67 | + words = text.split() |
| 68 | + # Build char-start index for each word (assuming single-space separation of split()). |
| 69 | + char_starts, char_ends = [], [] |
| 70 | + cursor = 0 |
| 71 | + remaining = text |
| 72 | + for w in words: |
| 73 | + idx = remaining.find(w) |
| 74 | + abs_start = cursor + idx |
| 75 | + char_starts.append(abs_start) |
| 76 | + char_ends.append(abs_start + len(w)) |
| 77 | + cursor = abs_start + len(w) |
| 78 | + remaining = text[cursor:] |
| 79 | + |
| 80 | + start_to_widx = {s: i for i, s in enumerate(char_starts)} |
| 81 | + end_to_widx = {e: i for i, e in enumerate(char_ends)} |
| 82 | + |
| 83 | + ner = [] |
| 84 | + for ent in entities: |
| 85 | + s, e, cls = ent["start"], ent["end"], ent["class"].lower() |
| 86 | + # Tolerate leading/trailing whitespace inside span |
| 87 | + span_text = text[s:e] |
| 88 | + ls = len(span_text) - len(span_text.lstrip()) |
| 89 | + le = len(span_text) - len(span_text.rstrip()) |
| 90 | + s2, e2 = s + ls, e - le |
| 91 | + if s2 in start_to_widx and e2 in end_to_widx: |
| 92 | + ner.append((start_to_widx[s2], end_to_widx[e2], cls)) |
| 93 | + return {"tokenized_text": words, "ner": ner} |
| 94 | + |
| 95 | + |
| 96 | +def main(): |
| 97 | + parser = argparse.ArgumentParser() |
| 98 | + parser.add_argument("--model", default="gliner-community/gliner_small-v2.5") |
| 99 | + parser.add_argument("--dataset", default="knowledgator/biomed_NER") |
| 100 | + parser.add_argument("--split", default="train") |
| 101 | + parser.add_argument("--eval_size", type=int, default=3000) |
| 102 | + parser.add_argument("--compress_size", type=int, default=1000) |
| 103 | + parser.add_argument("--batch_size", type=int, default=4) |
| 104 | + parser.add_argument("--threshold", type=float, default=0.5) |
| 105 | + parser.add_argument("--device", default="cuda:0" if torch.cuda.is_available() else "cpu") |
| 106 | + parser.add_argument("--seed", type=int, default=42) |
| 107 | + parser.add_argument("--bench_warmup", type=int, default=1) |
| 108 | + parser.add_argument("--bench_repeats", type=int, default=1) |
| 109 | + parser.add_argument("--distill", action="store_true", |
| 110 | + help="Fine-tune the compressed model on raw-model pseudo-labels.") |
| 111 | + parser.add_argument("--distill_size", type=int, default=1000, |
| 112 | + help="Number of texts to use for distillation (drawn after compress slice).") |
| 113 | + parser.add_argument("--distill_epochs", type=int, default=3) |
| 114 | + parser.add_argument("--distill_lr", type=float, default=1e-5) |
| 115 | + parser.add_argument("--distill_threshold", type=float, default=0.3) |
| 116 | + parser.add_argument("--distill_output_dir", type=str, default="./distill_ckpt") |
| 117 | + args = parser.parse_args() |
| 118 | + |
| 119 | + random.seed(args.seed) |
| 120 | + |
| 121 | + print(f"Loading dataset {args.dataset} [{args.split}]...") |
| 122 | + ds = load_dataset(args.dataset, split=args.split) |
| 123 | + |
| 124 | + processed = [char_to_word_sample(r["text"], r["entities"]) for r in ds] |
| 125 | + processed = [p for p in processed if p["ner"]] # drop empties |
| 126 | + |
| 127 | + labels = sorted({t for p in processed for _, _, t in p["ner"]}) |
| 128 | + print(f"{len(processed)} samples, {len(labels)} labels: {labels}") |
| 129 | + |
| 130 | + random.shuffle(processed) |
| 131 | + # Pin the full label set on every sample so raw and compressed evaluations |
| 132 | + # share an identical label space. Without this, raw eval would derive |
| 133 | + # labels per-sample (only the positives present) and be unfairly easier |
| 134 | + # than the compressed path, which always classifies over all labels. |
| 135 | + for p in processed: |
| 136 | + p["ner_labels"] = labels |
| 137 | + eval_data = processed[: args.eval_size] |
| 138 | + compress_slice = processed[args.eval_size : args.eval_size + args.compress_size] |
| 139 | + if not compress_slice: |
| 140 | + compress_slice = processed[: args.compress_size] |
| 141 | + compress_texts = [" ".join(p["tokenized_text"]) for p in compress_slice] |
| 142 | + |
| 143 | + distill_start = args.eval_size + args.compress_size |
| 144 | + distill_slice = processed[distill_start : distill_start + args.distill_size] if args.distill else [] |
| 145 | + |
| 146 | + print(f"Loading model {args.model}...") |
| 147 | + model = GLiNER.from_pretrained(args.model).to(args.device) |
| 148 | + |
| 149 | + eval_kwargs = dict(flat_ner=True, threshold=args.threshold, batch_size=args.batch_size) |
| 150 | + n = len(eval_data) |
| 151 | + |
| 152 | + print("=== Raw GLiNER evaluation ===") |
| 153 | + raw_out, raw_f1, raw_mean, raw_best = timed_evaluate( |
| 154 | + model, eval_data, warmup=args.bench_warmup, repeats=args.bench_repeats, |
| 155 | + device=args.device, **eval_kwargs, |
| 156 | + ) |
| 157 | + print(raw_out) |
| 158 | + print(f"Raw F1: {raw_f1:.4f}") |
| 159 | + print(f"Raw timing (n={n}, bs={args.batch_size}, repeats={args.bench_repeats}): " |
| 160 | + f"mean {raw_mean:.3f}s | best {raw_best:.3f}s | " |
| 161 | + f"{n / raw_mean:.1f} samples/s") |
| 162 | + |
| 163 | + distill_data = None |
| 164 | + if args.distill and distill_slice: |
| 165 | + print(f"Generating pseudo-labels from raw model on {len(distill_slice)} distillation texts...") |
| 166 | + distill_texts = [" ".join(p["tokenized_text"]) for p in distill_slice] |
| 167 | + preds = model.inference( |
| 168 | + distill_texts, labels, flat_ner=True, |
| 169 | + threshold=args.distill_threshold, batch_size=args.batch_size, |
| 170 | + ) |
| 171 | + distill_data = [predictions_to_ner(t, p) for t, p in zip(distill_texts, preds)] |
| 172 | + kept = sum(1 for d in distill_data if d["ner"]) |
| 173 | + print(f" {kept}/{len(distill_data)} samples carry at least one pseudo-label") |
| 174 | + |
| 175 | + print(f"Compressing prompt embeddings over {len(compress_texts)} texts...") |
| 176 | + model.compress_prompt_embeddings( |
| 177 | + texts=compress_texts, labels=labels, batch_size=args.batch_size |
| 178 | + ) |
| 179 | + model.config.precomputed_prompts_mode = True |
| 180 | + |
| 181 | + if distill_data: |
| 182 | + print(f"Fine-tuning compressed model on pseudo-labels " |
| 183 | + f"(epochs={args.distill_epochs}, lr={args.distill_lr})...") |
| 184 | + distill_finetune( |
| 185 | + model, distill_data, |
| 186 | + epochs=args.distill_epochs, lr=args.distill_lr, |
| 187 | + batch_size=args.batch_size, output_dir=args.distill_output_dir, |
| 188 | + ) |
| 189 | + |
| 190 | + print("=== Compressed GLiNER evaluation ===") |
| 191 | + comp_out, comp_f1, comp_mean, comp_best = timed_evaluate( |
| 192 | + model, eval_data, warmup=args.bench_warmup, repeats=args.bench_repeats, |
| 193 | + device=args.device, **eval_kwargs, |
| 194 | + ) |
| 195 | + print(comp_out) |
| 196 | + print(f"Compressed F1: {comp_f1:.4f}") |
| 197 | + print(f"Compressed timing (n={n}, bs={args.batch_size}, repeats={args.bench_repeats}): " |
| 198 | + f"mean {comp_mean:.3f}s | best {comp_best:.3f}s | " |
| 199 | + f"{n / comp_mean:.1f} samples/s") |
| 200 | + |
| 201 | + print("\n=== Summary ===") |
| 202 | + print(f"Raw F1: {raw_f1:.4f} | mean {raw_mean:.3f}s | {n / raw_mean:.1f} samples/s") |
| 203 | + print(f"Compressed F1: {comp_f1:.4f} | mean {comp_mean:.3f}s | {n / comp_mean:.1f} samples/s") |
| 204 | + print(f"Delta F1 : {comp_f1 - raw_f1:+.4f}") |
| 205 | + print(f"Speedup : {raw_mean / comp_mean:.2f}x") |
| 206 | + |
| 207 | + |
| 208 | +if __name__ == "__main__": |
| 209 | + main() |
0 commit comments