Commit b7b819c
Load from_pretrained at target dtype; restrict quantize= to int8
Two related changes to the precision / quantization surface, landed
together because they form one coherent story.
1. Add `dtype=` to `GLiNER.from_pretrained`
Load weights directly at the target floating-point precision. Each
state-dict tensor is cast during the `safe_open` read and the
random-init model shell is pre-cast via `instance.model.to(dtype)`
before `load_state_dict`, so a full fp32 snapshot never co-exists
with the loaded weights. Accepts strings (`"bf16"`, `"fp16"`,
`"float32"`, ...) or a floating-point `torch.dtype`; non-floating
dtypes (e.g. `torch.int8`) are rejected up front with a message
pointing at `quantize="int8"` for int paths. Int / bool buffers are
preserved in the state dict.
Memory impact: for CPU-only loads peak drops from ~2x fp32 to ~1x
fp32; for `map_location="cuda"`, the saving is avoiding a
simultaneous fp32 GPU state dict + fp32 GPU model plus the separate
post-load cast pass. Matches the `dtype=` surface on
`transformers.PreTrainedModel.from_pretrained` (string or
`torch.dtype`, same semantics), so users coming from HF get a
familiar API.
Primary target: cold starts and scalable serverless deployments
(Lambda, Cloud Run, Modal, RunPod serverless, autoscaled k8s) where
startup latency and peak memory drive cost and SLA.
The Ray-Serve layer (`gliner.serve.GLiNERFactory`) is wired through:
it now passes `map_location` + `dtype` to `from_pretrained` instead
of doing a post-load `.to(device=..., dtype=...)` cast, so serving
cold starts get the same peak-memory win.
2. Restrict `quantize=` to `"int8"` only
Previously `quantize=` accepted `True`, `"fp16"`, `"bf16"`, and
`"int8"`. Three of the five effective rows (GPU fp16, GPU bf16, CPU
bf16) were just `.to(dtype)` with an extra fp32 intermediate — not
quantization. A fourth (CPU fp16) wrapped `nn.Linear` with
dynamic-quantized variants but had no documented speed benefit and
was asymmetric with the other CPU/GPU combinations.
All non-int8 values now raise `ValueError` with a migration message
pointing at `dtype=` / `model.to(torch_dtype)`. `quantize="int8"`
is unchanged: torchao int8 weight-only on GPU, FBGEMM dynamic
quant on CPU. The signature narrows from `Union[bool, str] = False`
to `Optional[str] = None`; `model.quantize()` defaults to `"int8"`.
The serving CLI (`gliner.serve --quantization`) is narrowed in
lockstep to `{int8, None}`; precision stays under `--dtype`.
Docs in `docs/usage.md` and `docs/serving.md` are updated:
- New "Reduced-precision loading (`dtype`)" section in usage.md.
- "Quantization, Compilation & FlashDeBERTa" now shows `dtype="fp16"`
as the only half-precision path.
- Serving CLI reference and env-var table reflect the int8-only
`--quantization` / `GLINER_QUANTIZATION` surface.
New tests in `tests/test_quantize_and_dtype.py` (50 cases) cover
`_parse_dtype`, `_load_state_dict` cast-on-read, and the int8-only
`model.quantize(...)` validation.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>1 parent f6137fb commit b7b819c
6 files changed
Lines changed: 396 additions & 94 deletions
File tree
- docs
- gliner
- serve
- tests
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
145 | 145 | | |
146 | 146 | | |
147 | 147 | | |
148 | | - | |
149 | | - | |
| 148 | + | |
| 149 | + | |
| 150 | + | |
| 151 | + | |
150 | 152 | | |
151 | 153 | | |
152 | 154 | | |
| |||
199 | 201 | | |
200 | 202 | | |
201 | 203 | | |
202 | | - | |
| 204 | + | |
203 | 205 | | |
204 | 206 | | |
205 | 207 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
366 | 366 | | |
367 | 367 | | |
368 | 368 | | |
| 369 | + | |
| 370 | + | |
| 371 | + | |
| 372 | + | |
| 373 | + | |
| 374 | + | |
| 375 | + | |
| 376 | + | |
| 377 | + | |
| 378 | + | |
| 379 | + | |
| 380 | + | |
| 381 | + | |
| 382 | + | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
| 391 | + | |
| 392 | + | |
| 393 | + | |
| 394 | + | |
369 | 395 | | |
370 | 396 | | |
371 | | - | |
| 397 | + | |
372 | 398 | | |
373 | 399 | | |
374 | 400 | | |
375 | 401 | | |
376 | 402 | | |
377 | 403 | | |
378 | 404 | | |
379 | | - | |
| 405 | + | |
380 | 406 | | |
381 | 407 | | |
382 | 408 | | |
383 | 409 | | |
384 | 410 | | |
385 | 411 | | |
386 | 412 | | |
| 413 | + | |
387 | 414 | | |
388 | | - | |
389 | | - | |
| 415 | + | |
390 | 416 | | |
391 | 417 | | |
392 | 418 | | |
| |||
397 | 423 | | |
398 | 424 | | |
399 | 425 | | |
400 | | - | |
| 426 | + | |
401 | 427 | | |
402 | | - | |
| 428 | + | |
403 | 429 | | |
404 | | - | |
405 | | - | |
406 | | - | |
407 | | - | |
408 | | - | |
| 430 | + | |
| 431 | + | |
| 432 | + | |
| 433 | + | |
409 | 434 | | |
410 | 435 | | |
411 | 436 | | |
| |||
0 commit comments