|
1 | 1 | import os |
2 | 2 | from pathlib import Path |
3 | | -from ..runtime.paths import get_binary_dir, get_model_dir |
4 | | -from ..runtime.platform import get_platform_identifier, get_library_name |
5 | | -from ..runtime.downloader import download_file |
| 3 | + |
6 | 4 | from ..exceptions import DependencyMissingError |
| 5 | +from ..runtime.downloader import download_file |
| 6 | +from ..runtime.paths import get_binary_dir, get_model_dir |
| 7 | +from ..runtime.platform import get_library_name, get_platform_identifier |
7 | 8 | from .checksum import verify_checksum |
8 | 9 |
|
# Version tag interpolated into the default binaries release URL below.
POSTALKIT_VERSION = "v1.0.5"

# Base URL under which the per-platform binary archives are looked up.
# POSTALKIT_BINARIES_URL overrides it. `or` (rather than a `.get` default) makes
# a variable that is set but empty fall back to the default, instead of
# producing a host-less download URL.
BINARIES_BASE_URL = os.environ.get("POSTALKIT_BINARIES_URL") or (
    f"https://github.com/jayeshmepani/libpostal-ffi-python/releases/download/{POSTALKIT_VERSION}"
)

# URL of the model data archive. POSTALKIT_MODEL_URL overrides it, with the
# same empty-value fallback as above.
MODEL_DATA_URL = os.environ.get("POSTALKIT_MODEL_URL") or (
    "https://s3.amazonaws.com/libpostal/data/libpostal_data.tar.gz"
)
| 18 | + |
12 | 19 |
|
def _download_and_verify(url: str, tar_path: Path, desc: str):
    """Download an archive, verify it against its ``.sha256`` file, then extract it.

    The checksum is fetched from ``f"{url}.sha256"`` and only its first
    whitespace-separated token is used as the expected hash. Extraction into
    ``tar_path.parent`` happens only after verification has succeeded.

    Args:
        url: Location of the ``.tar.gz`` archive.
        tar_path: Local path the archive is downloaded to.
        desc: Label passed to the downloader; the checksum download uses
            ``f"{desc} Checksum"``.

    Raises:
        DependencyMissingError: If the checksum file cannot be downloaded or
            parsed, or the archive does not match it. The downloaded archive
            and checksum file are removed before raising.

    Errors raised by the initial archive download or by the extraction step
    are not wrapped and propagate unchanged.
    """
    # 1. Download the archive
    download_file(url, tar_path, extract=False, desc=desc)

    # 2. Download the checksum
    checksum_url = f"{url}.sha256"
    checksum_path = tar_path.with_suffix(".tar.gz.sha256")
    try:
        download_file(checksum_url, checksum_path, extract=False, desc=f"{desc} Checksum")
        with open(checksum_path) as f:
            # typical format: "hash filename"
            expected_hash = f.read().strip().split()[0]

        if not verify_checksum(tar_path, expected_hash):
            os.remove(tar_path)
            # Not inside an `except` block, so no `from` clause is needed here.
            # This error is caught and re-wrapped by the handler below.
            raise DependencyMissingError(f"Checksum verification failed for {tar_path.name}")
    except Exception as e:
        # If the checksum file doesn't exist remotely or fails to download,
        # we strictly fail to prevent compromised or corrupt binaries in production.
        if tar_path.exists():
            os.remove(tar_path)
        if checksum_path.exists():
            os.remove(checksum_path)
        # The `from` clause must follow the constructed exception. Writing
        # `raise Error from None(...)` calls `None` and raises TypeError instead.
        raise DependencyMissingError(
            f"Failed to fetch or verify checksum for {tar_path.name}: {e}"
        ) from e

    # 3. Extract after verification
    from ..runtime.downloader import _extract_tar_gz

    _extract_tar_gz(tar_path, tar_path.parent)
42 | 54 |
|
| 55 | + |
def ensure_models() -> Path:
    """Return the model directory, downloading the libpostal data first if needed.

    A ``data_version`` marker file inside the model directory signals that the
    data is already in place; when it exists nothing is downloaded.
    """
    models_root = get_model_dir()
    version_marker = models_root / "data_version"

    # Fast path: the marker is present, so the models were fetched earlier.
    if version_marker.exists():
        return models_root

    print("PostalKit models not found. Downloading (~2GB)...")
    archive = models_root / "libpostal_data.tar.gz"
    _download_and_verify(MODEL_DATA_URL, archive, "Model Data")

    # Written only once download, verification and extraction have returned
    # without raising, so a failed attempt is retried on the next call.
    version_marker.write_text("1")
    return models_root
55 | 68 |
|
| 69 | + |
def ensure_binary() -> Path:
    """Return the path of the platform's shared library, downloading it if absent.

    Lookup order: a library in the bundled binary directory wins; otherwise
    the library in the download directory is used, and is fetched first when
    it does not exist yet. The returned path is not re-checked after the
    download step.
    """
    from ..runtime.paths import get_bundled_binary_dir

    download_dir = get_binary_dir()
    library_file = get_library_name()
    platform_tag = get_platform_identifier()

    # 1. A bundled library takes precedence over anything downloaded.
    bundled = get_bundled_binary_dir() / library_file
    if bundled.exists():
        return bundled

    # 2. Reuse a previously downloaded library when there is one.
    downloaded = download_dir / library_file
    if downloaded.exists():
        return downloaded

    # 3. Fetch and verify the platform-specific archive.
    archive_url = f"{BINARIES_BASE_URL}/libpostal-{platform_tag}.tar.gz"
    archive_path = download_dir / f"libpostal-{platform_tag}.tar.gz"
    print(f"PostalKit binary not found. Downloading for {platform_tag}...")
    _download_and_verify(archive_url, archive_path, "Binary")
    return downloaded
78 | 93 |
|
| 94 | + |
79 | 95 | def ensure_all_assets(): |
80 | 96 | """Download models and binaries if missing.""" |
81 | 97 | ensure_models() |
|
0 commit comments