Add low_cpu_mem_usage=True to from_pretrained for ~2x faster cold starts#355
Conversation
`dtype="bf16"` already loads at the target precision, but it doesn't make
the load any faster. `from_pretrained` still:
1. Allocates a fp32 random-initialized model shell (~745 MB for
gliner_medium-v2.1).
2. Runs Kaiming/Xavier init over every parameter.
3. Casts the entire shell to bf16 (when dtype= is set).
4. Overwrites every value with the loaded weights.
Steps 1-3 are all thrown away. This commit adds an opt-in
`low_cpu_mem_usage=True` flag that skips them: the model graph is built
under `torch.device("meta")` (shape descriptors, no allocation, no init
compute), the state dict is read at the target dtype, and
`load_state_dict(assign=True)` swaps the loaded tensors directly into the
meta-shell parameter slots in one pass. A small post-fix re-materializes
non-persistent buffers (DeBERTa's `position_ids`) that the state dict
doesn't carry.
Measured on RTX 5090 with `urchade/gliner_medium-v2.1`, n=12 reps per
mode, OS page cache warmed, Welch t-tested:
CPU bf16: 3.30s -> 1.60s (2.06x faster, 1700ms saved, t=+14.67)
CPU fp32: 3.04s -> 1.45s (2.10x faster, 1591ms saved, t=+12.81)
CUDA bf16: 3.16s -> 1.61s (1.96x faster, 1543ms saved, t=+20.96)
All effect sizes |t| > 12 — far above the noise floor. Stdev also drops
~3x (0.38s -> 0.12s) because there's much less work happening in the
load path.
Peak host RSS also improves:
- CPU bf16: 1597 MB -> 1225 MB (-23%)
- CPU fp32: 1598 MB -> 170 MB (-89%; safetensors mmap reuse)
- CUDA bf16: 1361 MB -> 1004 MB (-26%)
The fp32 case is dramatic because safetensors mmaps the on-disk file and
we never copy it into anonymous memory.
Verified bit-identical to the standard path: 0 missing keys, 0 unexpected
keys, all 224 parameters byte-compare equal, predictions match end-to-end
on a held-out sentence. Existing test suite passes (200 unit tests, 1
pre-existing skip, 1 pre-existing import error in
tests/test_infer_packing.py unrelated to this change).
Default remains `False` — the path is opt-in until it has runtime
exposure across more architectures. Wired through both
`BaseGLiNER.from_pretrained` (line 768) and the outer
`GLiNER.from_pretrained` dispatcher (line 4262).
Adds 4 unit tests for `_materialize_meta_buffers` (54 total in
test_quantize_and_dtype.py, all passing). docs/usage.md gains a
"Skipping the random-init shell" subsection under the existing dtype=
section, with the benchmark table.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Validation across cached GLiNER architectures revealed two real gaps in
the original meta-init path:
1. ``token_type_ids`` — BERT-family standard non-persistent buffer,
used by the BGE labels encoder in BiEncoderSpanGLiNER. Canonical
value is zeros — now handled in ``_materialize_meta_buffers``.
2. ``rotary_emb.inv_freq`` — RoPE inverse-frequency buffer in
ModernBERT (and ettin-encoder, used by knowledgator/gliner-bi-base-v2.0).
The canonical value is computed as
``1 / (base ** (arange(0, dim, 2) / dim))`` where ``base`` varies
per-architecture (10000 for standard, 160000 for ModernBERT local
attention) and isn't recoverable from the buffer alone. The
previous "zero-fill + warn" behavior would have shipped silently
broken inference (zeros break RoPE attention).
Reworked ``_materialize_meta_buffers`` to return a
``(materialized, unrecognized)`` tuple. ``from_pretrained`` checks for
unrecognized buffers and, if any exist, deletes the partial meta state
and falls back to the standard load path with a single ``UserWarning``
naming the unsupported buffer pattern.
Net effect:
- DeBERTa-based architectures (UniEncoderSpan, UniEncoderToken):
unchanged — full meta-init speedup.
- BERT-family bi-encoder (BGE labels encoder): now also uses meta-init
via the new ``token_type_ids`` handler.
- RoPE-based bi-encoders (ModernBERT, ettin): auto-fall-back to the
standard path with a clear warning. Bit-identical loaded params
via the fallback. No risk of silently broken inference.
Validation script ``benchmarks/low_cpu_mem_usage/arch_validation.py``
covers 6 cached models across 3 dispatcher classes:
urchade/gliner_small-v2.1 UniEncoderSpanGLiNER OK (meta path)
urchade/gliner_large-v2.1 UniEncoderSpanGLiNER OK (meta path)
gliner-community/gliner_small-v2.5 UniEncoderSpanGLiNER OK (meta path)
knowledgator/gliner-multitask-large-v0.5 UniEncoderTokenGLiNER OK (meta path)
knowledgator/gliner-bi-base-v2.0 BiEncoderSpanGLiNER OK (auto-fallback: inv_freq)
knowledgator/modern-gliner-bi-base-v1.0 BiEncoderSpanGLiNER OK (auto-fallback: inv_freq)
All 6 produce parameters bit-identical (sha256 hash) to the standard
load path. Inference predictions on a held-out sentence match across
all DeBERTa models; bi-encoder inference is skipped because the
baseline forward path is broken upstream
(``BertModel.forward() got an unexpected keyword argument 'token_lengths'``)
— the load itself succeeds for both baseline and lowmem.
Test suite: 55 cases pass (54 -> 55 with the new
``test_token_type_ids_restored_to_zeros`` and the contract change to
``test_unknown_buffer_returned_as_unrecognized``). ruff lint and format
clean. Performance unchanged — single-rep sanity at cuda_lowmem_bf16
is 1.91s vs cuda_baseline_bf16 2.80s, in line with the n=12 result of
1.61s vs 3.16s reported in the original commit.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Codex review finding [P2]: with ``low_cpu_mem_usage=True`` and the
default ``strict=False``, ``load_state_dict(assign=True)`` succeeds
even if the checkpoint is missing a parameter — but that parameter
stays on the meta device. The subsequent ``instance.model.to(map_location)``
then raises ``NotImplementedError: Cannot copy out of meta tensor``.
The standard load path would have kept the random-initialized value
and loaded successfully, so the meta path is a strict regression for
this case.
Fix: after ``load_state_dict(assign=True)``, scan ``named_parameters()``
for any tensor still on meta. If any are found (or the existing
unrecognized-buffer check fired), discard the partial meta state and
fall back to the standard load path. The fallback warning now names
the cause — either a list of unrecognized non-persistent buffers (e.g.
RoPE ``inv_freq``) or a sample of the missing parameter names —
truncated for readability when the missing-key set is large.
End-to-end verification on a synthetic ``urchade/gliner_medium-v2.1``
clone with ``span_rep_layer.span_rep_layer.out_project.0.bias``
removed from ``model.safetensors``:
- ``low_cpu_mem_usage=True``: load succeeds via the fallback,
user-visible UserWarning names the missing key, ``param dtype`` is
correct (bfloat16). Pre-fix this would have raised
``NotImplementedError`` from ``.to()``.
Adds ``TestMetaParamFallbackContract`` (3 cases) asserting the
underlying contract: ``load_state_dict(assign=True, strict=False)``
leaves params on meta when keys are missing, the post-assign scan
finds them, and ``.to()`` on a remaining meta param raises. 58 unit
tests pass total (was 55).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
Ran a Codex review on this branch. One finding, addressed in [P2] Fall back when assign-load leaves meta parameters Fix: after End-to-end verification on a synthetic
Adds |
Resolves conflicts between low_cpu_mem_usage= (this branch) and variant= (merged from main). Both knobs are independent and now coexist in from_pretrained: variant= narrows the download, low_cpu_mem_usage= skips the random-init shell during load. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
Resolved the merge conflicts with |
|
@maxwbuckley , solid contribution, thank you. |
Add
low_cpu_mem_usage=Truetofrom_pretrainedfor ~2× faster cold startsWhy
Cold-start latency in
from_pretrainedhas three layers:variant=PR (#354) addresses this: when the publisher has uploaded a half-precision file, the download shrinks by ~half (745 MB → 372 MB forgliner_medium-v2.1).from_pretrainedstill:dtype=is set).Steps 1–3 are all thrown away. The randoms are never used. The cast operates on data that's about to be overwritten.
The two lower-layer changes are orthogonal and stack.
variant=shrinks the download;low_cpu_mem_usage=Trueshrinks the per-load CPU work that fires regardless of what the publisher uploaded.What this PR adds
A new opt-in
low_cpu_mem_usage: bool = Falseflag on bothBaseGLiNER.from_pretrainedand the outerGLiNER.from_pretraineddispatcher:When
True:torch.device("meta")— shape descriptors only, no allocation, no init compute.dtype=cast-on-read path).load_state_dict(state_dict, assign=True)swaps the loaded tensors directly into the meta-shell parameter slots in one pass.position_ids,token_type_ids) the state dict didn't carry.inv_freq, whosebasevaries per-architecture). The fallback emits a singleUserWarningnaming the unsupported buffer and produces a bit-identical model via the standard load path.map_location.The auto-fallback is the key design choice: rather than ship silently-broken inference for architectures we haven't validated, we detect the unsupported case and quietly use the slow path. Users get the speedup where it's safe and a working model where it isn't, with the warning telling them which case they hit.
Benchmark results
urchade/gliner_medium-v2.1, RTX 5090, n=12 reps per mode, OS page cache warmed (2 discarded warmups), Welch t-tested. Each load runs in a fresh subprocess so peak memory isn't contaminated by prior allocations. Modes interleave within each rep block to defuse warm-cache bias.Wall-clock load time
low_cpu_mem_usage=TrueAll effect sizes are |t| > 12 — far above the noise floor. ~1.5 s saved on every cold start, regardless of dtype or device. Stdev also drops ~3× (0.38 s → 0.12 s for CPU bf16) because there's much less work happening in the load path.
Peak host RSS
low_cpu_mem_usage=TrueThe CPU fp32 case is dramatic because safetensors mmaps the on-disk file and the loaded tensors are views into the mmap — we never copy the model into anonymous memory at all.
GPU peak unchanged at 585 MB (the
dtype=work in PR #348 already covered that).Raw data:
benchmarks/low_cpu_mem_usage/results.json. Driver:benchmarks/low_cpu_mem_usage/run_bench.py.Architecture validation
Validated across every cached GLiNER architecture I had locally — script at
benchmarks/low_cpu_mem_usage/arch_validation.py. Each model loads twice (baseline vs.low_cpu_mem_usage=True), the parameters are SHA-256 hashed end-to-end, and predictions are compared on a held-out sentence:urchade/gliner_small-v2.1UniEncoderSpanGLiNERurchade/gliner_large-v2.1UniEncoderSpanGLiNERgliner-community/gliner_small-v2.5UniEncoderSpanGLiNERknowledgator/gliner-multitask-large-v0.5UniEncoderTokenGLiNERknowledgator/gliner-bi-base-v2.0BiEncoderSpanGLiNERinv_freqfrom ettin RoPE)knowledgator/modern-gliner-bi-base-v1.0BiEncoderSpanGLiNERinv_freqfrom ModernBERT RoPE)4 of 6 models take the fast meta-init path. 2 of 6 auto-fall-back due to RoPE. Both groups produce bit-identical loaded parameters to the standard load. The two skipped predictions are because of an upstream bug in the bi-encoder forward path (
BertModel.forward() got an unexpected keyword argument 'token_lengths') that fails on the baseline too — unrelated to this PR.The bi-encoder inference fail is a pre-existing GLiNER bug on these specific bi-encoder repos; the load itself works correctly.
Architectures not validated
I didn't have cached checkpoints for these dispatcher classes:
BiEncoderTokenGLiNERUniEncoderSpanDecoderGLiNER/UniEncoderTokenDecoderGLiNERUniEncoderSpanRelexGLiNER/UniEncoderTokenRelexGLiNERIf any of them register non-persistent buffers we don't recognize, they'll auto-fall-back to the standard path with a warning naming the unsupported buffer. They will not silently produce wrong inference. Worth running the validation script against representative checkpoints once they're available.
Correctness verification
load_state_dict(assign=True)on the DeBERTa path.map_location="cuda"produces a model with all parameters oncuda:0at the requested dtype.No regressions
tests/test_infer_packing.py(ModuleNotFoundError: No module named 'tests.utils_infer'— broken before this PR)._materialize_meta_buffers: position_ids round-trip, no-op when nothing on meta, nested-module recursion, token_type_ids round-trip, unknown-buffer returns asunrecognized(so the caller can fall back).ruff checkandruff format --checkclean.Files changed
gliner/model.py—_materialize_meta_buffersstatic helper returning(materialized, unrecognized).low_cpu_mem_usage: bool = Falseargument added to bothfrom_pretrainedentry points. The innerfrom_pretrainedbranches between standard path and meta-init; on unrecognized buffers, the meta state is discarded and the standard path runs as a fallback. Standard path remains byte-for-byte unchanged.tests/test_quantize_and_dtype.py— newTestMaterializeMetaBuffersclass with 5 cases. 55 total tests passing.docs/usage.md— new "Skipping the random-init shell (low_cpu_mem_usage)" subsection with the benchmark table.What this stacks with
gliner_medium-v2.1dtype="bf16"(already merged)variant="bf16"(PR #354)low_cpu_mem_usage=True(this PR)Used together, the cold-start path on
gliner_medium-v2.1against a publisher who has uploaded a bf16 variant goes from "download 745 MB, build fp32 shell, init, cast, load" (~3 s + download) to "download 372 MB, build meta shell, load directly into bf16" (~1.5 s + half the download).Why opt-in (default
False)Real downsides — none dealbreakers, all reasons to ship cautiously:
load_state_dict(assign=True)requires PyTorch 2.1+. Default-on would break users pinned to older torch.assign=Truereplaces parameters with new tensors; if a model ties input embeddings to output projections (RoBERTa-style), the tie breaks. DeBERTa-v3 doesn't tie, so current GLiNER models are safe — but a future fine-tune on a tied-weight backbone would silently degrade.BaseGLiNERfor a custom architecture would see their_create_modelrun undertorch.device("meta")if the default flips. Subtle bugs in custom code wouldn't be caught.Recommended rollout: ship as opt-in, validate the remaining 5 architectures against representative checkpoints, then flip the default. The auto-fallback design means flipping the default is a smaller risk than it would be otherwise — anything we missed produces a fallback warning, not a wrong model.
Followups
_materialize_meta_buffersto recognize them turns auto-fallback into full speedup. A small contribution from anyone hitting the warning.gliner.serveCLI flag. The serving layer is the canonical cold-start use case; adding--low-cpu-mem-usagetogliner.serve.__main__is the obvious follow-up. Deliberately scoped out of this PR.