|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import hashlib |
| 4 | +import importlib |
| 5 | +import logging |
| 6 | +import os |
| 7 | +import shutil |
| 8 | +import sys |
| 9 | +import sysconfig |
| 10 | +import tempfile |
| 11 | +import urllib.error |
| 12 | +import urllib.request |
3 | 13 | from functools import lru_cache |
4 | 14 | from typing import Final, List, Tuple |
5 | 15 |
|
6 | 16 | import spacy |
| 17 | +from filelock import FileLock |
| 18 | + |
| 19 | +logger = logging.getLogger(__name__) |
7 | 20 |
|
8 | 21 | CACHE_MAX_SIZE: Final[int] = 128 |
9 | 22 |
|
10 | | -try: |
11 | | - _nlp = spacy.load("en_core_web_sm") |
12 | | -except OSError: |
13 | | - raise OSError( |
14 | | - "The spacy model 'en_core_web_sm' is required but not installed. " |
15 | | - "Install it with: python -m spacy download en_core_web_sm" |
16 | | - ) |
| 23 | +_SPACY_MODEL_NAME: Final[str] = "en_core_web_sm" |
| 24 | +_SPACY_MODEL_VERSION: Final[str] = "3.8.0" |
| 25 | +_SPACY_MODEL_URL: Final[str] = ( |
| 26 | + f"https://github.com/explosion/spacy-models/releases/download/" |
| 27 | + f"{_SPACY_MODEL_NAME}-{_SPACY_MODEL_VERSION}/" |
| 28 | + f"{_SPACY_MODEL_NAME}-{_SPACY_MODEL_VERSION}-py3-none-any.whl" |
| 29 | +) |
| 30 | +_SPACY_MODEL_SHA256: Final[str] = "1932429db727d4bff3deed6b34cfc05df17794f4a52eeb26cf8928f7c1a0fb85" |
| 31 | + |
| 32 | + |
| 33 | +_DOWNLOAD_TIMEOUT_SECONDS: Final[int] = 120 |
| 34 | +_INSTALL_LOCK_PATH: Final[str] = os.path.join( |
| 35 | + tempfile.gettempdir(), f"{_SPACY_MODEL_NAME}.install.lock" |
| 36 | +) |
| 37 | + |
| 38 | + |
| 39 | +def _download_with_timeout(url: str, dest: str) -> None: |
| 40 | + """Download a URL to a local file with a socket-level timeout.""" |
| 41 | + try: |
| 42 | + with urllib.request.urlopen(url, timeout=_DOWNLOAD_TIMEOUT_SECONDS) as resp: |
| 43 | + with open(dest, "wb") as out: |
| 44 | + shutil.copyfileobj(resp, out) |
| 45 | + except urllib.error.URLError as exc: |
| 46 | + raise RuntimeError( |
| 47 | + f"Failed to download spaCy model from {url}: {exc}. " |
| 48 | + "Check your network connection and try again." |
| 49 | + ) from exc |
| 50 | + |
| 51 | + |
| 52 | +def _install_spacy_model() -> None: |
| 53 | + """Download and install the pinned spaCy model wheel using the `installer` library.""" |
| 54 | + from installer import install |
| 55 | + from installer.destinations import SchemeDictionaryDestination |
| 56 | + from installer.sources import WheelFile |
| 57 | + from installer.utils import get_launcher_kind |
| 58 | + |
| 59 | + with tempfile.TemporaryDirectory() as tmp: |
| 60 | + whl_path = os.path.join(tmp, f"{_SPACY_MODEL_NAME}-{_SPACY_MODEL_VERSION}-py3-none-any.whl") |
| 61 | + logger.info("Downloading spaCy model %s %s …", _SPACY_MODEL_NAME, _SPACY_MODEL_VERSION) |
| 62 | + _download_with_timeout(_SPACY_MODEL_URL, whl_path) |
| 63 | + |
| 64 | + with open(whl_path, "rb") as f: |
| 65 | + sha256 = hashlib.sha256(f.read()).hexdigest() |
| 66 | + if sha256 != _SPACY_MODEL_SHA256: |
| 67 | + raise RuntimeError( |
| 68 | + f"Hash mismatch for {_SPACY_MODEL_NAME}: " |
| 69 | + f"expected {_SPACY_MODEL_SHA256}, got {sha256}" |
| 70 | + ) |
| 71 | + |
| 72 | + # Install into a staging directory to avoid races with other processes |
| 73 | + staging = os.path.join(tmp, "staging") |
| 74 | + paths = sysconfig.get_paths() |
| 75 | + staged_paths = paths.copy() |
| 76 | + staged_paths["purelib"] = staging |
| 77 | + staged_paths["platlib"] = staging |
| 78 | + |
| 79 | + destination = SchemeDictionaryDestination( |
| 80 | + staged_paths, |
| 81 | + interpreter=sys.executable, |
| 82 | + script_kind=get_launcher_kind(), |
| 83 | + ) |
| 84 | + with WheelFile.open(whl_path) as source: |
| 85 | + install(source=source, destination=destination, additional_metadata={}) |
| 86 | + |
| 87 | + # Move installed packages from staging into real site-packages. |
| 88 | + # The caller holds _INSTALL_LOCK_PATH so no other process races here. |
| 89 | + # Any dst that already exists is a remnant of a previous failed install |
| 90 | + # (spacy.load() just failed), so remove it before moving to avoid |
| 91 | + # shutil.move placing src *inside* an existing directory. |
| 92 | + site_packages = paths["purelib"] |
| 93 | + for item in os.listdir(staging): |
| 94 | + src = os.path.join(staging, item) |
| 95 | + dst = os.path.join(site_packages, item) |
| 96 | + try: |
| 97 | + if os.path.isdir(dst): |
| 98 | + shutil.rmtree(dst) |
| 99 | + elif os.path.exists(dst): |
| 100 | + os.remove(dst) |
| 101 | + shutil.move(src, dst) |
| 102 | + except OSError as exc: |
| 103 | + raise RuntimeError( |
| 104 | + f"Failed to install {_SPACY_MODEL_NAME} to {site_packages}: {exc}. " |
| 105 | + "Ensure the site-packages directory is writable, or pre-install the model " |
| 106 | + f"with: python -m spacy download {_SPACY_MODEL_NAME}" |
| 107 | + ) from exc |
| 108 | + |
| 109 | + logger.info("Installed %s %s", _SPACY_MODEL_NAME, _SPACY_MODEL_VERSION) |
| 110 | + |
| 111 | + |
| 112 | +def _load_spacy_model() -> spacy.language.Language: |
| 113 | + try: |
| 114 | + return spacy.load(_SPACY_MODEL_NAME) |
| 115 | + except OSError: |
| 116 | + pass |
| 117 | + |
| 118 | + # Serialize model installation across processes with an exclusive file lock. |
| 119 | + # A well-known path in the system temp dir is visible to all processes |
| 120 | + # regardless of their working directory. |
| 121 | + with FileLock(_INSTALL_LOCK_PATH, timeout=-1): |
| 122 | + # Double-check: another process may have installed while we waited. |
| 123 | + importlib.invalidate_caches() |
| 124 | + try: |
| 125 | + return spacy.load(_SPACY_MODEL_NAME) |
| 126 | + except OSError: |
| 127 | + pass |
| 128 | + _install_spacy_model() |
| 129 | + importlib.invalidate_caches() |
| 130 | + try: |
| 131 | + return spacy.load(_SPACY_MODEL_NAME) |
| 132 | + except OSError as exc: |
| 133 | + raise RuntimeError( |
| 134 | + f"Installed {_SPACY_MODEL_NAME} but spacy.load() still failed. " |
| 135 | + "Check site-packages permissions and installation integrity." |
| 136 | + ) from exc |
| 137 | + |
| 138 | + |
| 139 | +@lru_cache(maxsize=1) |
| 140 | +def _get_nlp() -> spacy.language.Language: |
| 141 | + """Load the spaCy model on first use and cache it for the lifetime of the process.""" |
| 142 | + return _load_spacy_model() |
17 | 143 |
|
18 | 144 |
|
19 | 145 | def _process(text: str) -> spacy.tokens.Doc: |
20 | 146 | """Run the spaCy pipeline once. All public functions extract what they need from the Doc.""" |
21 | 147 | # -- str() handles numpy.str_ from OCR pipelines -- |
22 | | - return _nlp(str(text)) |
| 148 | + return _get_nlp()(str(text)) |
23 | 149 |
|
24 | 150 |
|
25 | 151 | def sent_tokenize(text: str) -> List[str]: |
|
0 commit comments