-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathloaders.py
More file actions
143 lines (98 loc) · 4.07 KB
/
loaders.py
File metadata and controls
143 lines (98 loc) · 4.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""Helpers for loading the standalone EmbBERT bundle.
The simplified bundle ships one verified pretraining checkpoint together with
the tokenizer, configuration files, project library code, and runnable scripts.
The loader supports both the standalone ``EmbBERT/`` layout and the original
repository layout where the bundle lives in a subdirectory.
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
def bundle_root() -> Path:
    """Locate the EmbBERT bundle directory.

    Returns:
        The absolute, symlink-resolved directory that contains this module.
    """
    module_file = Path(__file__).resolve()
    return module_file.parent
def repo_root() -> Path:
    """Locate the directory that holds the runtime project files.

    Returns:
        The bundle root itself when it contains a ``lib/`` entry (standalone
        layout). In every other case, the bundle root's parent directory
        (original repository layout).
    """
    bundle = bundle_root()
    if (bundle / "lib").exists():
        return bundle
    # Either the parent holds ``lib/`` or nothing does; the parent is the
    # answer in both cases.
    return bundle.parent
def _ensure_repo_on_path() -> None:
    """Make the project root importable by prepending it to ``sys.path``.

    Does nothing when the root string is already present, so repeated calls
    do not add duplicate entries.
    """
    root_entry = str(repo_root())
    if root_entry in sys.path:
        return
    sys.path.insert(0, root_entry)
def load_manifest() -> dict:
    """Read and decode the bundle's ``manifest.json``.

    Returns:
        The decoded manifest dictionary.

    Raises:
        FileNotFoundError: If ``manifest.json`` is missing from the bundle root.
        json.JSONDecodeError: If the manifest is not valid JSON.
    """
    manifest_path = bundle_root() / "manifest.json"
    # Hand raw bytes to json.loads so it detects the JSON encoding itself
    # (UTF-8/16/32, optional BOM) instead of relying on the platform's locale
    # default, which read_text() would use (e.g. cp1252 on Windows).
    return json.loads(manifest_path.read_bytes())
def build_tokenizer(max_length: int = 256):
    """Build a ``PreTrainedTokenizerFast`` from the bundled tokenizer file.

    Args:
        max_length: Maximum sequence length to expose through the tokenizer.
            It is set as the tokenizer's ``model_max_length``, the limit
            transformers applies when truncating without an explicit length.

    Returns:
        A tokenizer configured with the special tokens used by this project.

    Raises:
        FileNotFoundError: If the bundled tokenizer JSON file is missing.
    """
    _ensure_repo_on_path()
    from tokenizers import Tokenizer
    from transformers import PreTrainedTokenizerFast

    tokenizer_path = bundle_root() / "tokenizers" / "bpe_book_corpus_8192.json"
    if not tokenizer_path.is_file():
        # Report a missing artifact the same way the checkpoint loader does,
        # rather than relying on the tokenizers library's own error type.
        raise FileNotFoundError(f"Tokenizer file not found: {tokenizer_path}")
    tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return PreTrainedTokenizerFast(
        tokenizer_object=tokenizer,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        # ``model_max_length`` is the documented constructor option for the
        # tokenizer's length limit. A plain ``max_length`` kwarg does not set it.
        model_max_length=max_length,
    )
def _checkpoint_sort_key(name: str) -> tuple[int, int, str]:
    """Order names like ``checkpoint-<step>`` by numeric step, others last."""
    _, _, suffix = name.rpartition("-")
    if suffix.isdigit():
        return (0, int(suffix), name)
    return (1, 0, name)


def list_pretraining_checkpoints() -> list[str]:
    """Return bundled pretraining checkpoint names in ascending order.

    Names ending in ``-<integer>`` are ordered by that integer (training
    step), so ``checkpoint-99000`` precedes ``checkpoint-616000``. Any other
    names follow, sorted alphabetically.

    Returns:
        Names of manifest artifacts whose ``kind`` is ``"pretraining"``.
    """
    manifest = load_manifest()
    names = [
        artifact["name"]
        for artifact in manifest["artifacts"]
        # Artifacts without a ``kind`` are simply not pretraining checkpoints.
        if artifact.get("kind") == "pretraining"
    ]
    return sorted(names, key=_checkpoint_sort_key)
def load_pretraining_checkpoint(name: str = "checkpoint-616000"):
    """Load a single bundled EmbBERT pretraining checkpoint.

    Args:
        name: Directory name of the checkpoint inside
            ``checkpoints/pretraining`` under the bundle root.

    Returns:
        A ``(model, tokenizer)`` tuple built with the project's current
        ``EmbBERT_Config`` and ``PretrainingClassifier`` classes. The
        tokenizer's length comes from the config's ``max_length``.

    Raises:
        FileNotFoundError: If no such checkpoint path exists.
    """
    _ensure_repo_on_path()
    from lib.Models.EmbBERT import EmbBERT_Config
    from lib.Models.classifiers import PretrainingClassifier

    ckpt = bundle_root().joinpath("checkpoints", "pretraining", name)
    if not ckpt.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt}")

    ckpt_str = str(ckpt)
    cfg = EmbBERT_Config.from_pretrained(ckpt_str)
    model = PretrainingClassifier.from_pretrained(ckpt_str, config=cfg)
    return model, build_tokenizer(max_length=int(cfg.max_length))
def load_latest_pretraining_checkpoint():
    """Load the bundle's default pretraining checkpoint.

    Returns:
        The ``(model, tokenizer)`` tuple for ``checkpoint-616000``.
    """
    default_name = "checkpoint-616000"
    return load_pretraining_checkpoint(name=default_name)
def load_legacy_sequence_metadata() -> dict:
    """Signal that legacy sequence metadata is unavailable in this bundle.

    The simplified bundle keeps only the ``checkpoint-616000`` pretraining
    artifact, so this function never returns.

    Raises:
        FileNotFoundError: On every call.
    """
    message = (
        "Legacy sequence metadata is not included in the simplified EmbBERT bundle."
    )
    raise FileNotFoundError(message)