Skip to content

Commit e23b104

Browse files
authored
Merge pull request #44 from rhnfzl/develop
Release v0.6.1: bug fixes, functional tests, GLiNER ONNX fallback
2 parents 33a37d6 + 7eaf728 commit e23b104

12 files changed

Lines changed: 1050 additions & 128 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ SPECS.md
173173
docs/plans/
174174
docs/GLINER_GAP_ANALYSIS.md
175175
docs/V060_PLAN.md
176+
docs/FUNCTIONAL_TEST_REPORT.md
176177
squeakycleantext-explorer.html
177178
ralph-loop-prompt.md
178179
.claude/

sct/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sct.utils.anonymization_map import AnonymizationMap, MapEntry
66
from sct.utils.process_result import ProcessResult
77

8-
__version__ = "0.6.0"
8+
__version__ = "0.6.1"
99
__all__ = [
1010
"TextCleaner", "TextCleanerConfig",
1111
"PII_LABELS", "PII_LABEL_MAP",

sct/config.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
VALID_NER_BACKENDS = frozenset({
2626
'onnx', 'torch', 'gliner', 'ensemble_onnx', 'ensemble_torch', 'presidio_gliner',
2727
})
28+
GLINER_BACKENDS = frozenset({'gliner', 'ensemble_onnx', 'ensemble_torch', 'presidio_gliner'})
2829

2930
DEFAULT_NER_MODELS: dict[str, str] = {
3031
'ENGLISH': 'rhnfzl/xlm-roberta-large-conll03-english-onnx',
@@ -320,9 +321,7 @@ def __post_init__(self):
320321
)
321322

322323
# GLiNER fields required for gliner/ensemble backends
323-
needs_gliner = self.ner_backend in (
324-
'gliner', 'ensemble_onnx', 'ensemble_torch', 'presidio_gliner',
325-
)
324+
needs_gliner = self.ner_backend in GLINER_BACKENDS
326325
if needs_gliner:
327326
if not self.gliner_model:
328327
raise ValueError(

sct/sct.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from concurrent.futures import ThreadPoolExecutor
88
from typing import Any, Callable, Dict, List, Optional
99

10-
from sct.config import TextCleanerConfig, _config_from_module_globals
10+
from sct.config import TextCleanerConfig, GLINER_BACKENDS, _config_from_module_globals
1111
from sct.utils import constants, contact, datetime, ner, normtext, resources, special, stopwords
1212
from sct.utils.anonymization_map import AnonymizationMap
1313
from sct.utils.process_result import ProcessResult
@@ -71,9 +71,7 @@ def __init__(self, cfg: Optional[TextCleanerConfig] = None):
7171
if self.cfg.check_ner_process:
7272
# Build GLiNER config dict (if needed)
7373
gliner_config = None
74-
needs_gliner = self.cfg.ner_backend in (
75-
'gliner', 'ensemble_onnx', 'ensemble_torch', 'presidio_gliner',
76-
)
74+
needs_gliner = self.cfg.ner_backend in GLINER_BACKENDS
7775
if needs_gliner:
7876
gliner_config = {
7977
'model': self.cfg.gliner_model,
@@ -109,10 +107,8 @@ def __init__(self, cfg: Optional[TextCleanerConfig] = None):
109107
else None
110108
),
111109
replacement_mode=self.cfg.replacement_mode,
110+
synthetic_replacer=self._synthetic_replacer,
112111
)
113-
else:
114-
pass # self.GeneralNER already initialized to None above
115-
116112
# GLiClass document-level pre-classification (optional, lazy-loaded)
117113
self._gliclass: Any = None
118114
if self.cfg.check_classify_document:
@@ -126,7 +122,6 @@ def __init__(self, cfg: Optional[TextCleanerConfig] = None):
126122
onnx=self.cfg.gliclass_onnx,
127123
)
128124

129-
self.batch_size = 8
130125
self._pipeline: List[Callable[[str], str]] = []
131126
self._post_fuzzy_pipeline: List[Callable[[str], str]] = []
132127
self._init_pipeline()
@@ -227,9 +222,6 @@ def _process_single(self, text: str) -> ProcessResult:
227222
# Detect language (pure function, thread-safe)
228223
language = self._detect_language(text)
229224

230-
# Pass language explicitly through pipeline context dict
231-
ctx = {"language": language}
232-
233225
current_text = text
234226

235227
# Pre-fuzzy pipeline steps (unicode fix → html → urls → emails → dates)
@@ -239,12 +231,11 @@ def _process_single(self, text: str) -> ProcessResult:
239231
# Fuzzy date replacement — requires language context, called explicitly
240232
# to avoid thread-local; positioned between replace_dates and replace_years.
241233
if self.cfg.check_fuzzy_replace_dates:
242-
lang = ctx.get("language")
243234
current_text = self.ProcessDateTime.fuzzy_replace_dates(
244235
current_text,
245236
replace_with=self.cfg.replace_with_dates,
246237
score_cutoff=self.cfg.fuzzy_date_score_cutoff,
247-
language=lang,
238+
language=language,
248239
)
249240

250241
# Post-fuzzy pipeline steps (years → phones → numbers → symbols → whitespace)
@@ -255,7 +246,7 @@ def _process_single(self, text: str) -> ProcessResult:
255246
if self.cfg.check_ner_process and self.GeneralNER is not None:
256247
current_text = self.GeneralNER.ner_process(
257248
current_text,
258-
positional_tags=list(self.cfg.positional_tags),
249+
positional_tags=self.cfg.positional_tags,
259250
ner_confidence_threshold=self.cfg.ner_confidence_threshold,
260251
language=language,
261252
anon_map=anon_map,

sct/utils/gliclass_adapter.py

Lines changed: 16 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,21 @@ def __init__(
3232
self._onnx = onnx
3333
self._pipeline = None
3434

35-
if onnx:
36-
self._init_onnx(model_id)
37-
else:
38-
self._init_pytorch(model_id)
35+
self._init_model(model_id)
3936

4037
logger.info(
4138
"Loaded GLiClass model: %s (onnx=%s, labels=%d)",
4239
model_id, onnx, len(self.labels),
4340
)
4441

45-
def _init_pytorch(self, model_id: str) -> None:
46-
"""Load GLiClass model via PyTorch (gliclass package)."""
42+
def _init_model(self, model_id: str) -> None:
43+
"""Load GLiClass model via gliclass package.
44+
45+
Note: gliclass does not yet expose a native ONNX loader, so the
46+
``onnx`` flag is recorded but both paths use the same PyTorch-backed
47+
``GLiClassModel.from_pretrained``. When gliclass adds ONNX support,
48+
this method should branch on ``self._onnx``.
49+
"""
4750
try:
4851
from gliclass import GLiClassModel, ZeroShotClassificationPipeline # noqa: S404
4952
from transformers import AutoTokenizer
@@ -62,26 +65,6 @@ def _init_pytorch(self, model_id: str) -> None:
6265
device='cpu',
6366
)
6467

65-
def _init_onnx(self, model_id: str) -> None:
66-
"""Load GLiClass model via ONNX Runtime (torch-free)."""
67-
try:
68-
from gliclass import GLiClassModel, ZeroShotClassificationPipeline # noqa: S404
69-
from transformers import AutoTokenizer
70-
except ImportError:
71-
raise ImportError(
72-
"gliclass + onnxruntime are required for ONNX GLiClass backend. "
73-
"Install with: pip install squeakycleantext[classify] squeakycleantext[classify-onnx]"
74-
)
75-
76-
model = GLiClassModel.from_pretrained(model_id)
77-
tokenizer = AutoTokenizer.from_pretrained(model_id)
78-
self._pipeline = ZeroShotClassificationPipeline(
79-
model=model,
80-
tokenizer=tokenizer,
81-
classification_type=self.classification_type,
82-
device='cpu',
83-
)
84-
8568
def classify(self, text: str) -> List[Dict[str, float]]:
8669
"""Classify text against configured labels.
8770
@@ -94,16 +77,16 @@ def classify(self, text: str) -> List[Dict[str, float]]:
9477

9578
result = self._pipeline(
9679
text,
97-
candidate_labels=self.labels,
80+
labels=self.labels,
9881
)
9982

100-
# Pipeline returns {"sequence": ..., "labels": [...], "scores": [...]}
83+
# Pipeline returns list[list[dict]] — one list per input text,
84+
# each containing {"label": str, "score": float} dicts.
10185
classifications = []
102-
labels = result.get('labels', [])
103-
scores = result.get('scores', [])
104-
for label, score in zip(labels, scores):
105-
if score >= self.threshold:
106-
classifications.append({'label': label, 'score': score})
86+
entries = result[0] if result else []
87+
for entry in entries:
88+
if entry['score'] >= self.threshold:
89+
classifications.append({'label': entry['label'], 'score': entry['score']})
10790

10891
classifications.sort(key=lambda x: x['score'], reverse=True)
10992
return classifications

sct/utils/gliner_adapter.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,19 @@ def __init__(
4848
"Install with: pip install squeakycleantext[gliner]"
4949
)
5050
if onnx:
51-
self.model = GLiNER.from_pretrained(
52-
model_id, load_onnx_model=True, load_tokenizer=True,
53-
)
54-
logger.info("Loaded GLiNER model in ONNX mode: %s", model_id)
51+
try:
52+
self.model = GLiNER.from_pretrained(
53+
model_id, load_onnx_model=True, load_tokenizer=True,
54+
)
55+
logger.info("Loaded GLiNER model in ONNX mode: %s", model_id)
56+
except FileNotFoundError:
57+
logger.warning(
58+
"ONNX model not found for %s (GLiNER issue #314: most models "
59+
"don't ship model.onnx at repo root). Falling back to PyTorch.",
60+
model_id,
61+
)
62+
self.model = GLiNER.from_pretrained(model_id)
63+
self._onnx = False
5564
else:
5665
self.model = GLiNER.from_pretrained(model_id)
5766
if device == 'cuda' and not onnx:

sct/utils/ner.py

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,26 @@
33
import threading
44
from collections import defaultdict
55
import logging
6-
from typing import Dict, List, Optional, Union
6+
from typing import Dict, List, Optional, Sequence, Union
77
from pathlib import Path
88

99
import onnxruntime as ort
1010

1111
from presidio_anonymizer import AnonymizerEngine
1212
from presidio_anonymizer.entities import RecognizerResult
1313

14+
from typing import NamedTuple as _NamedTuple
15+
1416
from sct.utils import constants
1517
from sct.utils.anonymization_map import AnonymizationMap
1618
from sct.utils.onnx_pipeline import load_onnx_ner_model
1719
from sct.config import DEFAULT_NER_MODELS, DEFAULT_NER_ENSEMBLE, NER_ENSEMBLE_DEFAULT_KEYS, LANG_KEYS
1820

21+
22+
class AnonymizeResult(_NamedTuple):
23+
"""Result of text anonymization — lightweight typed container."""
24+
text: str
25+
1926
ort.set_default_logger_severity(3) # Silence ONNX Runtime warnings
2027

2128
logger = logging.getLogger(__name__)
@@ -53,7 +60,8 @@ def __init__(self, cache_dir: Optional[Path] = None, device: Optional[str] = Non
5360
ner_batch_size: int = 8,
5461
ensemble_models: Optional[Dict] = None,
5562
ensemble_default_keys: Optional[tuple] = None,
56-
replacement_mode: str = 'placeholder'):
63+
replacement_mode: str = 'placeholder',
64+
synthetic_replacer=None):
5765
"""Initialize NER processor.
5866
5967
Args:
@@ -67,11 +75,12 @@ def __init__(self, cache_dir: Optional[Path] = None, device: Optional[str] = Non
6775
Required when ner_backend involves GLiNER.
6876
torch_model_names: Language-keyed dict of PyTorch model repo IDs.
6977
Required when ner_backend involves torch.
78+
synthetic_replacer: Shared SyntheticReplacer instance (from TextCleaner).
7079
"""
7180
self._ner_backend = ner_backend
7281
self._ner_batch_size = ner_batch_size
7382
self._replacement_mode = replacement_mode
74-
self._synthetic_replacer = None # Lazy-loaded when replacement_mode='synthetic'
83+
self._synthetic_replacer = synthetic_replacer
7584
self._gliner_pipe = None
7685
self._ensemble_models: Dict[str, tuple] = ensemble_models if ensemble_models is not None else DEFAULT_NER_ENSEMBLE
7786
self._ensemble_default_keys: tuple = (
@@ -195,11 +204,12 @@ def _init_presidio_gliner(self, gliner_config):
195204
try:
196205
from presidio_analyzer.predefined_recognizers import GLiNERRecognizer # noqa: S404
197206
gliner_recognizer = GLiNERRecognizer(
198-
model_path=gliner_config['model'],
207+
model_name=gliner_config['model'],
199208
supported_entities=[
200209
label.upper()
201210
for label in gliner_config.get('labels', ['person', 'organization', 'location'])
202211
],
212+
threshold=gliner_config.get('threshold', 0.4),
203213
)
204214
self._analyzer.registry.add_recognizer(gliner_recognizer)
205215
except ImportError:
@@ -208,13 +218,6 @@ def _init_presidio_gliner(self, gliner_config):
208218
"Install with: pip install presidio-analyzer gliner"
209219
)
210220

211-
def _get_synthetic_replacer(self):
212-
"""Lazily initialize the SyntheticReplacer."""
213-
if self._synthetic_replacer is None:
214-
from sct.utils.synthetic import SyntheticReplacer
215-
self._synthetic_replacer = SyntheticReplacer()
216-
return self._synthetic_replacer
217-
218221
def _get_ensemble_keys(self, language: str) -> tuple:
219222
"""Return ordered model keys to run for the given language."""
220223
return self._ensemble_models.get(language, self._ensemble_default_keys)
@@ -313,9 +316,8 @@ def anonymize_text(self, text, filtered_data, replacement_mode='placeholder',
313316
return self._anonymize_reversible(text, filtered_data, anon_map)
314317

315318
if replacement_mode == 'synthetic':
316-
replacer = self._get_synthetic_replacer()
317-
result_text = replacer.generate_for_entities(text, filtered_data)
318-
return type('AnonymizeResult', (), {'text': result_text})()
319+
result_text = self._synthetic_replacer.generate_for_entities(text, filtered_data)
320+
return AnonymizeResult(text=result_text)
319321

320322
has_custom = any(
321323
items['entity_group'] not in ENTITY_TYPE_MAP
@@ -328,7 +330,7 @@ def anonymize_text(self, text, filtered_data, replacement_mode='placeholder',
328330
for items in sorted_data:
329331
tag = ENTITY_TYPE_MAP.get(items['entity_group'], items['entity_group'])
330332
text = text[:items['start']] + f"<{tag}>" + text[items['end']:]
331-
return type('AnonymizeResult', (), {'text': text})()
333+
return AnonymizeResult(text=text)
332334
else:
333335
# Standard entities only: use Presidio (existing behavior)
334336
analyzer_result = []
@@ -348,13 +350,14 @@ def anonymize_text(self, text, filtered_data, replacement_mode='placeholder',
348350
if 0 <= entry.start < text_length and 0 < entry.end <= text_length
349351
]
350352

351-
return self.engine.anonymize(text=text, analyzer_results=analyzer_result)
353+
engine_result = self.engine.anonymize(text=text, analyzer_results=analyzer_result)
354+
return AnonymizeResult(text=engine_result.text)
352355

353356
def _anonymize_reversible(self, text, filtered_data, anon_map=None):
354357
"""Replace entities with indexed placeholders and populate the map.
355358
356359
Uses right-to-left replacement to preserve character offsets.
357-
Returns an AnonymizeResult-like object with ``.text`` attribute.
360+
Returns an ``AnonymizeResult`` with the anonymized ``.text``.
358361
"""
359362
if anon_map is None:
360363
anon_map = AnonymizationMap()
@@ -373,7 +376,7 @@ def _anonymize_reversible(self, text, filtered_data, anon_map=None):
373376
)
374377
text = text[:item['start']] + placeholder + text[item['end']:]
375378

376-
return type('AnonymizeResult', (), {'text': text, 'anon_map': anon_map})()
379+
return AnonymizeResult(text=text)
377380

378381
def ner_ensemble(self, ner_results, t):
379382
"""Apply ensemble voting across multiple model results.
@@ -440,7 +443,7 @@ def _simple_chunk(self, text: str, max_tokens: int = 384) -> List[str]:
440443
def ner_process(
441444
self,
442445
text: str,
443-
positional_tags: Optional[List[str]] = None,
446+
positional_tags: Optional[Sequence[str]] = None,
444447
ner_confidence_threshold: Optional[float] = None,
445448
language: Optional[str] = None,
446449
anon_map: Optional['AnonymizationMap'] = None,
@@ -474,6 +477,16 @@ def ner_process(
474477
if not chunks:
475478
return text
476479

480+
# Pre-compute GLiNER-only tag set (constant across chunks)
481+
gliner_all_tags = None
482+
if self._ner_backend == 'gliner' and self._gliner_pipe:
483+
gliner_all_tags = set(positional_tags)
484+
gliner_all_tags.update(self._gliner_pipe.label_map.values())
485+
gliner_all_tags.update(
486+
label.upper() for label in self._gliner_pipe.labels
487+
if label not in self._gliner_pipe.label_map
488+
)
489+
477490
# --- Inference + ensemble per chunk ---
478491
ner_clean_text = []
479492
for chunk in chunks:
@@ -486,7 +499,7 @@ def ner_process(
486499
model_name = self._model_names.get(key, key)
487500
model_lock = self._get_lock(model_name)
488501
with model_lock:
489-
batch = self._get_pipeline(key)([chunk])
502+
batch = self._pipelines[key]([chunk])
490503
ner_results.extend(self.ner_data(batch[0], positional_tags))
491504

492505
# GLiNER backend
@@ -495,15 +508,8 @@ def ner_process(
495508
gliner_lock = self._get_lock('gliner')
496509
with gliner_lock:
497510
gliner_batch = self._gliner_pipe([chunk])
498-
if self._ner_backend == 'gliner':
499-
# GLiNER-only: include all mapped entity types
500-
all_tags = set(positional_tags)
501-
all_tags.update(self._gliner_pipe.label_map.values())
502-
all_tags.update(
503-
label.upper() for label in self._gliner_pipe.labels
504-
if label not in self._gliner_pipe.label_map
505-
)
506-
ner_results.extend(self.ner_data(gliner_batch[0], all_tags))
511+
if gliner_all_tags is not None:
512+
ner_results.extend(self.ner_data(gliner_batch[0], gliner_all_tags))
507513
else:
508514
# Ensemble: filter to positional_tags only
509515
ner_results.extend(self.ner_data(gliner_batch[0], positional_tags))

0 commit comments

Comments
 (0)