Add low_cpu_mem_usage=True to from_pretrained for ~2x faster cold starts #355

Merged
Ingvarstep merged 4 commits into urchade:main from maxwbuckley:add-low-cpu-mem-usage on Apr 29, 2026

Conversation

@maxwbuckley
Contributor

@maxwbuckley maxwbuckley commented Apr 27, 2026

Add low_cpu_mem_usage=True to from_pretrained for ~2× faster cold starts

Why

Cold-start latency in from_pretrained has three layers:

  1. Bytes-on-the-wire — the variant= PR (#354) addresses this: when the publisher has uploaded a half-precision file, the download shrinks by ~half (745 MB → 372 MB for gliner_medium-v2.1).
  2. CPU-side model construction and init — this PR addresses this. Regardless of how many bytes you downloaded, from_pretrained still:
    • Allocates a ~745 MB fp32 random-initialized model shell.
    • Runs Kaiming / Xavier initialization over every parameter.
    • Casts the entire shell to bf16 (when dtype= is set).
    • Overwrites every value with the loaded weights.
      The first three steps are all thrown away: the random values are never used, and the cast operates on data that is about to be overwritten.
  3. GPU transfer + first-inference JIT — out of scope for both PRs.

The two lower-layer changes are orthogonal and stack. variant= shrinks the download; low_cpu_mem_usage=True shrinks the per-load CPU work that fires regardless of what the publisher uploaded.

What this PR adds

A new opt-in low_cpu_mem_usage: bool = False flag on both BaseGLiNER.from_pretrained and the outer GLiNER.from_pretrained dispatcher:

from gliner import GLiNER

model = GLiNER.from_pretrained(
    "urchade/gliner_medium-v2.1",
    dtype="bf16",
    low_cpu_mem_usage=True,  # opt-in; the default stays False
)

When True:

  1. Build the model graph under torch.device("meta") — shape descriptors only, no allocation, no init compute.
  2. Read the state dict at the target precision (the existing dtype= cast-on-read path).
  3. load_state_dict(state_dict, assign=True) swaps the loaded tensors directly into the meta-shell parameter slots in one pass.
  4. Re-materialize non-persistent buffers (position_ids, token_type_ids) the state dict didn't carry.
  5. Auto-fall-back to the standard load path when the model has non-persistent buffers we don't know how to recompute (e.g. RoPE's inv_freq, whose base varies per-architecture). The fallback emits a single UserWarning naming the unsupported buffer and produces a bit-identical model via the standard load path.
  6. Move to map_location.

The auto-fallback is the key design choice: rather than ship silently broken inference for architectures we haven't validated, we detect the unsupported case and fall back to the slow path. Users get the speedup where it's safe and a working model where it isn't, with the warning telling them which case they hit.
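
The control flow of that fallback can be sketched independently of PyTorch. Everything here is hypothetical (`load_with_fallback`, `fast_load`, `standard_load`, and the warning text are illustrative names, not the PR's code); the only thing taken from the description is the contract that the fast path reports a list of unrecognized buffers and that a single UserWarning names them.

```python
import warnings
from typing import Callable, List, Tuple


def load_with_fallback(
    fast_load: Callable[[], Tuple[object, List[str]]],
    standard_load: Callable[[], object],
    max_names: int = 3,
) -> object:
    """Try the fast path; use the standard path if it reports problems."""
    model, unrecognized = fast_load()
    if not unrecognized:
        return model

    # Name the cause, truncated so a long list stays readable.
    shown = ", ".join(unrecognized[:max_names])
    if len(unrecognized) > max_names:
        shown += f", ... (+{len(unrecognized) - max_names} more)"
    warnings.warn(
        f"low_cpu_mem_usage=True: unsupported buffers ({shown}); "
        "falling back to the standard load path.",
        UserWarning,
        stacklevel=2,
    )
    del model  # drop the partial meta state before the slow path runs
    return standard_load()
```

Returning the problem list from the fast path, rather than raising, keeps the decision to fall back in one place.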

Benchmark results

urchade/gliner_medium-v2.1, RTX 5090, n=12 reps per mode, OS page cache warmed (2 discarded warmups), Welch t-tested. Each load runs in a fresh subprocess so peak memory isn't contaminated by prior allocations. Modes interleave within each rep block to defuse warm-cache bias.
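
A rough sketch of that measurement shape, assuming a Unix host: each measurement runs in its own interpreter and the modes alternate inside every rep. This is not the actual driver (that is benchmarks/low_cpu_mem_usage/run_bench.py); `CHILD`, `measure_once`, and `run_interleaved` are hypothetical names, and a plain bytes allocation stands in for the model load.

```python
import json
import subprocess
import sys

# Child script: a bytes allocation stands in for the model load.
# resource is Unix-only; ru_maxrss is in KiB on Linux and bytes on macOS.
CHILD = """
import json, resource, sys, time
size_mb = int(sys.argv[1])
start = time.perf_counter()
payload = b"\\x01" * (size_mb * 1024 * 1024)
seconds = time.perf_counter() - start
peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
print(json.dumps({"seconds": seconds, "ru_maxrss": peak}))
"""


def measure_once(size_mb: int) -> dict:
    """Run one load in a fresh interpreter so peak RSS starts from zero."""
    done = subprocess.run(
        [sys.executable, "-c", CHILD, str(size_mb)],
        check=True,
        capture_output=True,
        text=True,
    )
    return json.loads(done.stdout)


def run_interleaved(modes: dict, reps: int) -> dict:
    """Alternate modes inside each rep so cache warmth is shared evenly."""
    results = {name: [] for name in modes}
    for _ in range(reps):
        for name, size_mb in modes.items():
            results[name].append(measure_once(size_mb))
    return results
```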

Wall-clock load time

| pairing | baseline | low_cpu_mem_usage=True | speedup | saved | Welch t |
|---|---|---|---|---|---|
| CPU bf16 | 3.30 s | 1.60 s | 2.06× | 1700 ms | +14.67 |
| CPU fp32 | 3.04 s | 1.45 s | 2.10× | 1591 ms | +12.81 |
| CUDA bf16 | 3.16 s | 1.61 s | 1.96× | 1543 ms | +20.96 |

Every comparison has a Welch t-statistic with |t| > 12, far above the noise floor. That is ~1.5 s saved on every cold start, regardless of dtype or device. The standard deviation also drops ~3× (0.38 s → 0.12 s for CPU bf16) because the load path does much less work.
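
For readers who want to sanity-check the CPU bf16 row, Welch's t-statistic can be recomputed from the summary figures quoted here. This is a sketch using the rounded means and standard deviations from the text, so it lands near the reported +14.67 rather than on it; the raw per-rep data lives in results.json.

```python
import math


def welch_t(mean_a: float, std_a: float, n_a: int,
            mean_b: float, std_b: float, n_b: int) -> float:
    """Welch's t: difference of means over the unpooled standard error."""
    standard_error = math.sqrt(std_a**2 / n_a + std_b**2 / n_b)
    return (mean_a - mean_b) / standard_error


# CPU bf16 row: baseline 3.30 s (sd 0.38), low_cpu_mem_usage 1.60 s (sd 0.12), n=12.
t_cpu_bf16 = welch_t(3.30, 0.38, 12, 1.60, 0.12, 12)
speedup = 3.30 / 1.60
print(f"t = {t_cpu_bf16:.2f}, speedup = {speedup:.2f}x")
```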

Peak host RSS

| pairing | baseline | low_cpu_mem_usage=True | reduction |
|---|---|---|---|
| CPU bf16 | 1597 MB | 1225 MB | −23% |
| CPU fp32 | 1598 MB | 170 MB | −89% |
| CUDA bf16 | 1361 MB | 1004 MB | −26% |

The CPU fp32 case is dramatic because safetensors mmaps the on-disk file and the loaded tensors are views into the mmap — we never copy the model into anonymous memory at all.

GPU peak unchanged at 585 MB (the dtype= work in PR #348 already covered that).

Raw data: benchmarks/low_cpu_mem_usage/results.json. Driver: benchmarks/low_cpu_mem_usage/run_bench.py.

Architecture validation

Validated across every cached GLiNER architecture I had locally — script at benchmarks/low_cpu_mem_usage/arch_validation.py. Each model loads twice (baseline vs. low_cpu_mem_usage=True), the parameters are SHA-256 hashed end-to-end, and predictions are compared on a held-out sentence:

| model | dispatcher class | result | path taken |
|---|---|---|---|
| urchade/gliner_small-v2.1 | UniEncoderSpanGLiNER | params identical, preds match | meta-init |
| urchade/gliner_large-v2.1 | UniEncoderSpanGLiNER | params identical, preds match | meta-init |
| gliner-community/gliner_small-v2.5 | UniEncoderSpanGLiNER | params identical, preds match | meta-init |
| knowledgator/gliner-multitask-large-v0.5 | UniEncoderTokenGLiNER | params identical, preds match | meta-init |
| knowledgator/gliner-bi-base-v2.0 | BiEncoderSpanGLiNER | params identical (preds skipped: pre-existing forward bug) | auto-fallback (inv_freq from ettin RoPE) |
| knowledgator/modern-gliner-bi-base-v1.0 | BiEncoderSpanGLiNER | params identical (preds skipped: pre-existing forward bug) | auto-fallback (inv_freq from ModernBERT RoPE) |

4 of 6 models take the fast meta-init path. 2 of 6 auto-fall-back due to RoPE. Both groups produce bit-identical loaded parameters to the standard load. The two skipped predictions are because of an upstream bug in the bi-encoder forward path (BertModel.forward() got an unexpected keyword argument 'token_lengths') that fails on the baseline too — unrelated to this PR.
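
The "bit-identical" check amounts to hashing every parameter's raw bytes in a stable order and comparing digests between the two load paths. The sketch below shows the idea on plain bytes; `fingerprint` is a hypothetical helper, not the function in arch_validation.py, and how each tensor is turned into bytes is left to the caller.

```python
import hashlib
from typing import Mapping


def fingerprint(named_blobs: Mapping[str, bytes]) -> str:
    """SHA-256 over (name, length, bytes) triples in sorted-name order."""
    digest = hashlib.sha256()
    for name in sorted(named_blobs):
        blob = named_blobs[name]
        digest.update(name.encode("utf-8"))
        # Length prefix keeps adjacent blobs from blurring into each other.
        digest.update(len(blob).to_bytes(8, "little"))
        digest.update(blob)
    return digest.hexdigest()


# For a torch model, one possible way to get each blob (an assumption, not
# the PR's code) is the tensor's raw bytes, for example:
#   t.detach().cpu().contiguous().reshape(-1).view(torch.uint8).numpy().tobytes()
baseline = {"proj.weight": b"\x00\x01\x02\x03", "proj.bias": b"\x04\x05"}
lowmem = {"proj.bias": b"\x04\x05", "proj.weight": b"\x00\x01\x02\x03"}
print(fingerprint(baseline) == fingerprint(lowmem))
```

Hashing raw bytes rather than comparing with a tolerance is what makes the claim "bit-identical" instead of "numerically close".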

The bi-encoder inference failure is a pre-existing GLiNER bug on these specific bi-encoder repos; the load itself works correctly.

Architectures not validated

I didn't have cached checkpoints for these dispatcher classes:

  • BiEncoderTokenGLiNER
  • UniEncoderSpanDecoderGLiNER / UniEncoderTokenDecoderGLiNER
  • UniEncoderSpanRelexGLiNER / UniEncoderTokenRelexGLiNER

If any of them register non-persistent buffers we don't recognize, they'll auto-fall-back to the standard path with a warning naming the unsupported buffer. They will not silently produce wrong inference. Worth running the validation script against representative checkpoints once they're available.

Correctness verification

  • Bit-identical parameters verified via SHA-256 hash across all 6 validated models (3 dispatcher classes).
  • 0 missing keys, 0 unexpected keys in load_state_dict(assign=True) on the DeBERTa path.
  • 0 leftover meta tensors on the meta path; auto-fallback engages cleanly when unrecognized buffers exist.
  • End-to-end inference matches on the four DeBERTa-based models; bi-encoder inference skipped due to upstream forward bug (independent of this PR).
  • CUDA path verified: map_location="cuda" produces a model with all parameters on cuda:0 at the requested dtype.

No regressions

  • 200 existing unit tests pass, 1 pre-existing skip (unrelated), 1 pre-existing import error in tests/test_infer_packing.py (ModuleNotFoundError: No module named 'tests.utils_infer' — broken before this PR).
  • 5 new unit tests for _materialize_meta_buffers: position_ids round-trip, no-op when nothing on meta, nested-module recursion, token_type_ids round-trip, unknown-buffer returns as unrecognized (so the caller can fall back).
  • ruff check and ruff format --check clean.

Files changed

  • gliner/model.py: _materialize_meta_buffers static helper returning (materialized, unrecognized). low_cpu_mem_usage: bool = False argument added to both from_pretrained entry points. The inner from_pretrained branches between standard path and meta-init; on unrecognized buffers, the meta state is discarded and the standard path runs as a fallback. Standard path remains byte-for-byte unchanged.
  • tests/test_quantize_and_dtype.py — new TestMaterializeMetaBuffers class with 5 cases. 55 total tests passing.
  • docs/usage.md — new "Skipping the random-init shell (low_cpu_mem_usage)" subsection with the benchmark table.

What this stacks with

| optimization | what it does | typical savings on gliner_medium-v2.1 |
|---|---|---|
| dtype="bf16" (already merged) | cast-on-read; no fp32 state dict in memory | ~37% GPU peak; 0% wall-clock |
| variant="bf16" (PR #354) | download only the bf16 file from Hub | ~373 MB bytes-on-wire (when published) |
| low_cpu_mem_usage=True (this PR) | skip random-init compute + fp32 alloc + cast pass | ~50% wall-clock, 23–89% peak host RSS |

Used together, the cold-start path on gliner_medium-v2.1 against a publisher who has uploaded a bf16 variant goes from "download 745 MB, build fp32 shell, init, cast, load" (~3 s + download) to "download 372 MB, build meta shell, load directly into bf16" (~1.5 s + half the download).

Why opt-in (default False)

Real downsides — none dealbreakers, all reasons to ship cautiously:

  1. Architecture coverage. I validated 3 of 8 dispatcher classes. The other 5 (BiEncoderToken, the two decoder variants, the two relex variants) should work or auto-fall-back, but haven't been load-tested.
  2. load_state_dict(assign=True) requires PyTorch 2.1+. Default-on would break users pinned to older torch.
  3. Tied weights would break silently. assign=True replaces parameters with new tensors; if a model ties input embeddings to output projections (RoBERTa-style), the tie breaks. DeBERTa-v3 doesn't tie, so current GLiNER models are safe — but a future fine-tune on a tied-weight backbone would silently degrade.
  4. User subclasses. Anyone subclassing BaseGLiNER for a custom architecture would see their _create_model run under torch.device("meta") if the default flips. Subtle bugs in custom code wouldn't be caught.
  5. Diagnostic blast radius. Bugs here are wrong-but-plausible outputs rather than crashes. Opt-in is self-documenting in user configs; default-on means "I upgraded GLiNER and predictions changed slightly" is harder to trace.

Recommended rollout: ship as opt-in, validate the remaining 5 architectures against representative checkpoints, then flip the default. The auto-fallback design means flipping the default is a smaller risk than it would be otherwise — anything we missed produces a fallback warning, not a wrong model.

Followups

  • Validate the remaining 5 dispatcher classes (BiEncoderToken, two decoder variants, two relex variants) against representative checkpoints.
  • Test on a tied-weights architecture. Add coverage for any future GLiNER backbone that ties embeddings to output projections.
  • Architectures with non-position_ids / non-token_type_ids / non-inv_freq buffers. The fallback handles unknowns safely, but extending _materialize_meta_buffers to recognize them turns auto-fallback into full speedup. A small contribution from anyone hitting the warning.
  • gliner.serve CLI flag. The serving layer is the canonical cold-start use case; adding --low-cpu-mem-usage to gliner.serve.__main__ is the obvious follow-up. Deliberately scoped out of this PR.

maxwbuckley and others added 2 commits April 27, 2026 20:59
`dtype="bf16"` already loads at the target precision, but it doesn't make
the load any faster. `from_pretrained` still:

  1. Allocates a fp32 random-initialized model shell (~745 MB for
     gliner_medium-v2.1).
  2. Runs Kaiming/Xavier init over every parameter.
  3. Casts the entire shell to bf16 (when dtype= is set).
  4. Overwrites every value with the loaded weights.

Steps 1-3 are all thrown away. This commit adds an opt-in
`low_cpu_mem_usage=True` flag that skips them: the model graph is built
under `torch.device("meta")` (shape descriptors, no allocation, no init
compute), the state dict is read at the target dtype, and
`load_state_dict(assign=True)` swaps the loaded tensors directly into the
meta-shell parameter slots in one pass. A small post-fix re-materializes
non-persistent buffers (DeBERTa's `position_ids`) that the state dict
doesn't carry.

Measured on RTX 5090 with `urchade/gliner_medium-v2.1`, n=12 reps per
mode, OS page cache warmed, Welch t-tested:

  CPU bf16:  3.30s -> 1.60s  (2.06x faster, 1700ms saved, t=+14.67)
  CPU fp32:  3.04s -> 1.45s  (2.10x faster, 1591ms saved, t=+12.81)
  CUDA bf16: 3.16s -> 1.61s  (1.96x faster, 1543ms saved, t=+20.96)

All effect sizes |t| > 12 — far above the noise floor. Stdev also drops
~3x (0.38s -> 0.12s) because there's much less work happening in the
load path.

Peak host RSS also improves:
  - CPU bf16:  1597 MB -> 1225 MB  (-23%)
  - CPU fp32:  1598 MB -> 170 MB   (-89%; safetensors mmap reuse)
  - CUDA bf16: 1361 MB -> 1004 MB  (-26%)

The fp32 case is dramatic because safetensors mmaps the on-disk file and
we never copy it into anonymous memory.

Verified bit-identical to the standard path: 0 missing keys, 0 unexpected
keys, all 224 parameters byte-compare equal, predictions match end-to-end
on a held-out sentence. Existing test suite passes (200 unit tests, 1
pre-existing skip, 1 pre-existing import error in
tests/test_infer_packing.py unrelated to this change).

Default remains `False` — the path is opt-in until it has runtime
exposure across more architectures. Wired through both
`BaseGLiNER.from_pretrained` (line 768) and the outer
`GLiNER.from_pretrained` dispatcher (line 4262).

Adds 4 unit tests for `_materialize_meta_buffers` (54 total in
test_quantize_and_dtype.py, all passing). docs/usage.md gains a
"Skipping the random-init shell" subsection under the existing dtype=
section, with the benchmark table.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Validation across cached GLiNER architectures revealed two real gaps in
the original meta-init path:

  1. ``token_type_ids`` — BERT-family standard non-persistent buffer,
     used by the BGE labels encoder in BiEncoderSpanGLiNER. Canonical
     value is zeros — now handled in ``_materialize_meta_buffers``.

  2. ``rotary_emb.inv_freq`` — RoPE inverse-frequency buffer in
     ModernBERT (and ettin-encoder, used by knowledgator/gliner-bi-base-v2.0).
     The canonical value is computed as
     ``1 / (base ** (arange(0, dim, 2) / dim))`` where ``base`` varies
     per-architecture (10000 for standard, 160000 for ModernBERT global
     attention) and isn't recoverable from the buffer alone. The
     previous "zero-fill + warn" behavior would have shipped silently
     broken inference (zeros break RoPE attention).
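
A small illustration of why the base matters, written in plain Python to mirror the formula above (`rope_inv_freq` is a hypothetical helper; real backbones compute this with tensor ops and take `base` from their config):

```python
def rope_inv_freq(dim: int, base: float) -> list:
    """1 / (base ** (arange(0, dim, 2) / dim)) as a plain Python list."""
    return [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]


standard = rope_inv_freq(dim=8, base=10000.0)
large_base = rope_inv_freq(dim=8, base=160000.0)

# Same shape, different values: the buffer's shape alone cannot reveal the base.
for a, b in zip(standard, large_base):
    print(f"{a:.6f}  {b:.6f}")
```

Only the first entry agrees between the two; every other frequency differs, so guessing the base (or zero-filling) would change the attention pattern without raising any error.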

Reworked ``_materialize_meta_buffers`` to return a
``(materialized, unrecognized)`` tuple. ``from_pretrained`` checks for
unrecognized buffers and, if any exist, deletes the partial meta state
and falls back to the standard load path with a single ``UserWarning``
naming the unsupported buffer pattern.

Net effect:
  - DeBERTa-based architectures (UniEncoderSpan, UniEncoderToken):
    unchanged — full meta-init speedup.
  - BERT-family bi-encoder (BGE labels encoder): now also uses meta-init
    via the new ``token_type_ids`` handler.
  - RoPE-based bi-encoders (ModernBERT, ettin): auto-fall-back to the
    standard path with a clear warning. Bit-identical loaded params
    via the fallback. No risk of silently broken inference.

Validation script ``benchmarks/low_cpu_mem_usage/arch_validation.py``
covers 6 cached models across 3 dispatcher classes:

  urchade/gliner_small-v2.1               UniEncoderSpanGLiNER     OK (meta path)
  urchade/gliner_large-v2.1               UniEncoderSpanGLiNER     OK (meta path)
  gliner-community/gliner_small-v2.5      UniEncoderSpanGLiNER     OK (meta path)
  knowledgator/gliner-multitask-large-v0.5 UniEncoderTokenGLiNER   OK (meta path)
  knowledgator/gliner-bi-base-v2.0        BiEncoderSpanGLiNER      OK (auto-fallback: inv_freq)
  knowledgator/modern-gliner-bi-base-v1.0 BiEncoderSpanGLiNER      OK (auto-fallback: inv_freq)

All 6 produce parameters bit-identical (sha256 hash) to the standard
load path. Inference predictions on a held-out sentence match across
all DeBERTa models; bi-encoder inference is skipped because the
baseline forward path is broken upstream
(``BertModel.forward() got an unexpected keyword argument 'token_lengths'``)
— the load itself succeeds for both baseline and lowmem.

Test suite: 55 cases pass (54 -> 55 with the new
``test_token_type_ids_restored_to_zeros`` and the contract change to
``test_unknown_buffer_returned_as_unrecognized``). ruff lint and format
clean. Performance unchanged — single-rep sanity at cuda_lowmem_bf16
is 1.91s vs cuda_baseline_bf16 2.80s, in line with the n=12 result of
1.61s vs 3.16s reported in the original commit.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Ingvarstep Ingvarstep self-requested a review April 27, 2026 20:32
Codex review finding [P2]: with ``low_cpu_mem_usage=True`` and the
default ``strict=False``, ``load_state_dict(assign=True)`` succeeds
even if the checkpoint is missing a parameter — but that parameter
stays on the meta device. The subsequent ``instance.model.to(map_location)``
then raises ``NotImplementedError: Cannot copy out of meta tensor``.
The standard load path would have kept the random-initialized value
and loaded successfully, so the meta path is a strict regression for
this case.

Fix: after ``load_state_dict(assign=True)``, scan ``named_parameters()``
for any tensor still on meta. If any are found (or the existing
unrecognized-buffer check fired), discard the partial meta state and
fall back to the standard load path. The fallback warning now names
the cause — either a list of unrecognized non-persistent buffers (e.g.
RoPE ``inv_freq``) or a sample of the missing parameter names —
truncated for readability when the missing-key set is large.

End-to-end verification on a synthetic ``urchade/gliner_medium-v2.1``
clone with ``span_rep_layer.span_rep_layer.out_project.0.bias``
removed from ``model.safetensors``:

  - ``low_cpu_mem_usage=True``: load succeeds via the fallback,
    user-visible UserWarning names the missing key, ``param dtype`` is
    correct (bfloat16). Pre-fix this would have raised
    ``NotImplementedError`` from ``.to()``.

Adds ``TestMetaParamFallbackContract`` (3 cases) asserting the
underlying contract: ``load_state_dict(assign=True, strict=False)``
leaves params on meta when keys are missing, the post-assign scan
finds them, and ``.to()`` on a remaining meta param raises. 58 unit
tests pass total (was 55).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@maxwbuckley
Contributor Author

Ran a Codex review on this branch. One finding, addressed in 6513398:

[P2] Fall back when assign-load leaves meta parameters
With low_cpu_mem_usage=True and the default strict=False, load_state_dict(assign=True) succeeds even if the checkpoint is missing a parameter — but that parameter stays on the meta device. The subsequent instance.model.to(map_location) then raises NotImplementedError: Cannot copy out of meta tensor. The standard load path would have kept the random-initialized value and loaded successfully, so the meta path was a strict regression for this case (and a potential gotcha for fine-tuned checkpoints that drop or rename a head).

Fix: after load_state_dict(assign=True), scan named_parameters() for any tensor still on meta. If any are found (or the existing unrecognized-buffer check fired), discard the partial meta state and fall back to the standard load path. The fallback warning now names the cause — either unrecognized non-persistent buffers (e.g. RoPE inv_freq) or a sample of the missing parameter names.

End-to-end verification on a synthetic urchade/gliner_medium-v2.1 clone with one parameter (span_rep_layer.span_rep_layer.out_project.0.bias) removed from model.safetensors:

  • Pre-fix: low_cpu_mem_usage=True raises NotImplementedError from .to().
  • Post-fix: load succeeds via fallback, UserWarning names the missing key, model loads at bf16 as requested.

Adds TestMetaParamFallbackContract (3 cases) asserting the underlying PyTorch contract: load_state_dict(assign=True, strict=False) leaves params on meta when keys are missing, the post-assign scan finds them, and .to() on a remaining meta param raises. 58 tests pass total (was 55).

Resolves conflicts between low_cpu_mem_usage= (this branch) and variant=
(merged from main). Both knobs are independent and now coexist in
from_pretrained: variant= narrows the download, low_cpu_mem_usage=
skips the random-init shell during load.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@maxwbuckley
Contributor Author

Resolved the merge conflicts with main (the variant= PR landed on top of this branch's diff). Both knobs are independent and now coexist in from_pretrained: variant= narrows the download, low_cpu_mem_usage= skips the random-init shell during load. Lint clean and all 98 tests in tests/test_quantize_and_dtype.py pass.

@Ingvarstep Ingvarstep merged commit e9c3c7d into urchade:main Apr 29, 2026
3 checks passed
@Ingvarstep
Collaborator

@maxwbuckley, solid contribution, thank you.
