Skip to content

Commit 5f2abc4

Browse files
committed
Better extractor and eval
1 parent aa18836 commit 5f2abc4

7 files changed

Lines changed: 947 additions & 65 deletions

File tree

scripts/merge_lora.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,33 @@ def main(argv: list[str] | None = None) -> int:
6868
tokenizer,
6969
save_method="merged_16bit",
7070
)
71+
72+
# Copy VL processor files from base model if needed (e.g. preprocessor_config.json).
73+
# Unsloth saves the VL architecture in config.json but _prepare_text_tokenizer
74+
# strips the processor, so these files are missing from the merged output.
75+
config_path = Path(args.output) / "config.json"
76+
with open(config_path) as f:
77+
saved_config = json.load(f)
78+
archs = saved_config.get("architectures", [])
79+
is_vl = any("ConditionalGeneration" in a or "VL" in a for a in archs)
80+
81+
if is_vl:
82+
import shutil
83+
84+
from huggingface_hub import hf_hub_download
85+
86+
vl_files = ["preprocessor_config.json", "chat_template.json"]
87+
for filename in vl_files:
88+
dest = Path(args.output) / filename
89+
if dest.exists():
90+
continue
91+
try:
92+
src = hf_hub_download(base_model_name, filename)
93+
shutil.copy(src, dest)
94+
logger.info(f"Copied {filename} from {base_model_name}")
95+
except Exception:
96+
pass # File may not exist for all models
97+
7198
logger.info(f"Done. Merged model saved to {args.output}")
7299
return 0
73100

squeez/encoder/__init__.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,27 @@
11
"""Encoder-based line classifier for tool output extraction."""
22

3-
__all__ = ["SqueezEncoderConfig", "SqueezEncoderForLineClassification"]
3+
__all__ = [
4+
"SqueezEncoderConfig",
5+
"SqueezEncoderForLineClassification",
6+
"PooledLineConfig",
7+
"PooledLineClassifier",
8+
]
49

510

611
def __getattr__(name: str):
712
"""Lazily import encoder model classes so lightweight helpers stay optional."""
8-
if name in __all__:
13+
if name in ("SqueezEncoderConfig", "SqueezEncoderForLineClassification"):
914
from squeez.encoder.model import SqueezEncoderConfig, SqueezEncoderForLineClassification
1015

1116
return {
1217
"SqueezEncoderConfig": SqueezEncoderConfig,
1318
"SqueezEncoderForLineClassification": SqueezEncoderForLineClassification,
1419
}[name]
20+
if name in ("PooledLineConfig", "PooledLineClassifier"):
21+
from squeez.encoder.sentence import PooledLineClassifier, PooledLineConfig
22+
23+
return {
24+
"PooledLineConfig": PooledLineConfig,
25+
"PooledLineClassifier": PooledLineClassifier,
26+
}[name]
1527
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

squeez/encoder/evaluate.py

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Evaluation script for the encoder line classifier.
1+
"""Evaluation script for encoder and pooled line classifiers.
22
33
Runs inference on an eval set and computes the same metrics as the generative
44
model's evaluate.py for direct comparison:
@@ -9,10 +9,18 @@
99
- ROUGE-L
1010
- Compression ratio
1111
12+
Supports both classifier types (auto-detected from model config):
13+
- token: SqueezEncoderForLineClassification (token-level with [LINE_SEP])
14+
- pooled: PooledLineClassifier (line-level mean-pool with [LINE_SEP])
15+
1216
Usage:
1317
python -m squeez.encoder.evaluate \
1418
--model-path output/squeez_encoder \
1519
--eval-file data/encoder_test.jsonl
20+
21+
python -m squeez.encoder.evaluate \
22+
--model-path output/squeez_pooled \
23+
--eval-file data/encoder_test.jsonl
1624
"""
1725

1826
from __future__ import annotations
@@ -21,32 +29,70 @@
2129
import json
2230
import logging
2331
import statistics
32+
from pathlib import Path
2433

2534
logger = logging.getLogger(__name__)
2635

2736

37+
def _load_model_and_tokenizer(model_path: str):
    """Load an encoder or pooled line classifier and its tokenizer.

    The classifier variant is auto-detected from the saved ``config.json``:
    a ``model_type`` of ``"squeez-pooled"`` selects ``PooledLineClassifier``;
    anything else (including a missing config) falls back to
    ``SqueezEncoderForLineClassification``.

    Returns:
        Tuple of ``(model, tokenizer, model_type)`` where the model is moved
        to CUDA when available (CPU otherwise) and put in eval mode, and
        ``model_type`` is ``"encoder"`` or ``"pooled"``.
    """
    import json

    import torch
    from transformers import AutoTokenizer

    from squeez.encoder.model import LINE_SEP_TOKEN

    # Detect which classifier variant was saved; default to the encoder.
    model_type = "encoder"
    config_path = Path(model_path) / "config.json"
    if config_path.exists():
        with open(config_path) as f:
            if json.load(f).get("model_type") == "squeez-pooled":
                model_type = "pooled"

    # Ensure the [LINE_SEP] marker is a real token rather than mapping to UNK.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if tokenizer.convert_tokens_to_ids(LINE_SEP_TOKEN) == tokenizer.unk_token_id:
        tokenizer.add_special_tokens({"additional_special_tokens": [LINE_SEP_TOKEN]})

    if model_type == "pooled":
        from squeez.encoder.sentence import PooledLineClassifier as _model_cls
    else:
        from squeez.encoder.model import SqueezEncoderForLineClassification as _model_cls

    model = _model_cls.from_pretrained(model_path, trust_remote_code=True)
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    return model, tokenizer, model_type
75+
2876
def evaluate_encoder(
2977
model_path: str,
3078
eval_file: str,
3179
max_samples: int | None = None,
3280
threshold: float = 0.5,
3381
examples_output: str | None = None,
3482
) -> dict:
35-
"""Evaluate the encoder model on an eval set.
83+
"""Evaluate an encoder or pooled model on an eval set.
84+
85+
Auto-detects model type from config.json (squeez-encoder vs squeez-pooled).
3686
3787
Args:
38-
model_path: Path to trained encoder model
88+
model_path: Path to trained model
3989
eval_file: Path to encoder-format JSONL
4090
max_samples: Maximum samples to evaluate
4191
threshold: Relevance score threshold
4292
4393
Returns:
4494
Dict with aggregate metrics (same format as generative evaluate.py)
4595
"""
46-
import torch
47-
from transformers import AutoTokenizer
48-
49-
from squeez.encoder.model import LINE_SEP_TOKEN, SqueezEncoderForLineClassification
5096
from squeez.training.evaluate import (
5197
compute_compression_ratio,
5298
compute_empty_accuracy,
@@ -56,17 +102,8 @@ def evaluate_encoder(
56102
compute_span_metrics,
57103
)
58104

59-
logger.info(f"Loading encoder model from {model_path}")
60-
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
61-
62-
# Ensure LINE_SEP is in tokenizer
63-
if tokenizer.convert_tokens_to_ids(LINE_SEP_TOKEN) == tokenizer.unk_token_id:
64-
tokenizer.add_special_tokens({"additional_special_tokens": [LINE_SEP_TOKEN]})
65-
66-
model = SqueezEncoderForLineClassification.from_pretrained(model_path, trust_remote_code=True)
67-
device = "cuda" if torch.cuda.is_available() else "cpu"
68-
model = model.to(device)
69-
model.eval()
105+
logger.info(f"Loading model from {model_path}")
106+
model, tokenizer, model_type = _load_model_and_tokenizer(model_path)
70107

71108
# Load eval data
72109
samples = []
@@ -189,7 +226,7 @@ def evaluate_encoder(
189226

190227
results["empty_confusion"] = empty_confusion
191228
results["num_samples"] = len(samples)
192-
results["model_type"] = "encoder"
229+
results["model_type"] = model_type
193230
results["threshold"] = threshold
194231

195232
if examples_output:
@@ -198,7 +235,7 @@ def evaluate_encoder(
198235
logger.info(f"Saved per-sample examples to {examples_output}")
199236

200237
logger.info("=" * 60)
201-
logger.info("ENCODER EVALUATION RESULTS")
238+
logger.info(f"EVALUATION RESULTS ({model_type})")
202239
logger.info("=" * 60)
203240
for key, stats in results.items():
204241
if isinstance(stats, dict) and "mean" in stats:

0 commit comments

Comments
 (0)