1- """Evaluation script for the encoder line classifier .
1+ """Evaluation script for encoder and pooled line classifiers.
22
33Runs inference on an eval set and computes the same metrics as the generative
44model's evaluate.py for direct comparison:
99- ROUGE-L
1010- Compression ratio
1111
12+ Supports both classifier types (auto-detected from model config):
13+ - token: SqueezEncoderForLineClassification (token-level with [LINE_SEP])
14+ - pooled: PooledLineClassifier (line-level mean-pool with [LINE_SEP])
15+
1216Usage:
1317 python -m squeez.encoder.evaluate \
1418 --model-path output/squeez_encoder \
1519 --eval-file data/encoder_test.jsonl
20+
21+ python -m squeez.encoder.evaluate \
22+ --model-path output/squeez_pooled \
23+ --eval-file data/encoder_test.jsonl
1624"""
1725
1826from __future__ import annotations
2129import json
2230import logging
2331import statistics
32+ from pathlib import Path
2433
2534logger = logging .getLogger (__name__ )
2635
2736
def _load_model_and_tokenizer(model_path: str):
    """Load a line-classifier model and tokenizer from *model_path*.

    The classifier variant is auto-detected from the checkpoint's
    ``config.json``: ``model_type == "squeez-pooled"`` selects the pooled
    line classifier; anything else (or a missing config) falls back to the
    token-level encoder classifier.

    Args:
        model_path: Directory containing the trained model and tokenizer.

    Returns:
        Tuple ``(model, tokenizer, model_type)`` where ``model_type`` is
        ``"pooled"`` or ``"encoder"``. The model is moved to CUDA when
        available and put in eval mode.
    """
    # Heavy dependencies are imported lazily so importing this module stays
    # cheap; ``json`` and ``Path`` come from the module-level imports.
    import torch
    from transformers import AutoTokenizer

    from squeez.encoder.model import LINE_SEP_TOKEN

    config_path = Path(model_path) / "config.json"
    model_type = "encoder"
    if config_path.exists():
        with open(config_path) as f:
            cfg = json.load(f)
        if cfg.get("model_type") == "squeez-pooled":
            model_type = "pooled"

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # Fallback: make sure [LINE_SEP] is a known token. A tokenizer saved
    # alongside a trained checkpoint should already contain it.
    # NOTE(review): if the token really were missing, the model's embedding
    # matrix is NOT resized here — assumes saved checkpoints already include
    # the token; confirm against the training script.
    if tokenizer.convert_tokens_to_ids(LINE_SEP_TOKEN) == tokenizer.unk_token_id:
        tokenizer.add_special_tokens({"additional_special_tokens": [LINE_SEP_TOKEN]})

    if model_type == "pooled":
        from squeez.encoder.sentence import PooledLineClassifier

        model = PooledLineClassifier.from_pretrained(model_path, trust_remote_code=True)
    else:
        from squeez.encoder.model import SqueezEncoderForLineClassification

        model = SqueezEncoderForLineClassification.from_pretrained(
            model_path, trust_remote_code=True
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()

    return model, tokenizer, model_type
74+
75+
2876def evaluate_encoder (
2977 model_path : str ,
3078 eval_file : str ,
3179 max_samples : int | None = None ,
3280 threshold : float = 0.5 ,
3381 examples_output : str | None = None ,
3482) -> dict :
35- """Evaluate the encoder model on an eval set.
83+ """Evaluate an encoder or pooled model on an eval set.
84+
85+ Auto-detects model type from config.json (squeez-encoder vs squeez-pooled).
3686
3787 Args:
38- model_path: Path to trained encoder model
88+ model_path: Path to trained model
3989 eval_file: Path to encoder-format JSONL
4090 max_samples: Maximum samples to evaluate
4191 threshold: Relevance score threshold
4292
4393 Returns:
4494 Dict with aggregate metrics (same format as generative evaluate.py)
4595 """
46- import torch
47- from transformers import AutoTokenizer
48-
49- from squeez .encoder .model import LINE_SEP_TOKEN , SqueezEncoderForLineClassification
5096 from squeez .training .evaluate import (
5197 compute_compression_ratio ,
5298 compute_empty_accuracy ,
@@ -56,17 +102,8 @@ def evaluate_encoder(
56102 compute_span_metrics ,
57103 )
58104
59- logger .info (f"Loading encoder model from { model_path } " )
60- tokenizer = AutoTokenizer .from_pretrained (model_path , trust_remote_code = True )
61-
62- # Ensure LINE_SEP is in tokenizer
63- if tokenizer .convert_tokens_to_ids (LINE_SEP_TOKEN ) == tokenizer .unk_token_id :
64- tokenizer .add_special_tokens ({"additional_special_tokens" : [LINE_SEP_TOKEN ]})
65-
66- model = SqueezEncoderForLineClassification .from_pretrained (model_path , trust_remote_code = True )
67- device = "cuda" if torch .cuda .is_available () else "cpu"
68- model = model .to (device )
69- model .eval ()
105+ logger .info (f"Loading model from { model_path } " )
106+ model , tokenizer , model_type = _load_model_and_tokenizer (model_path )
70107
71108 # Load eval data
72109 samples = []
@@ -189,7 +226,7 @@ def evaluate_encoder(
189226
190227 results ["empty_confusion" ] = empty_confusion
191228 results ["num_samples" ] = len (samples )
192- results ["model_type" ] = "encoder"
229+ results ["model_type" ] = model_type
193230 results ["threshold" ] = threshold
194231
195232 if examples_output :
@@ -198,7 +235,7 @@ def evaluate_encoder(
198235 logger .info (f"Saved per-sample examples to { examples_output } " )
199236
200237 logger .info ("=" * 60 )
201- logger .info ("ENCODER EVALUATION RESULTS" )
238+ logger .info (f" EVALUATION RESULTS ( { model_type } ) " )
202239 logger .info ("=" * 60 )
203240 for key , stats in results .items ():
204241 if isinstance (stats , dict ) and "mean" in stats :
0 commit comments