Skip to content

Commit b7b819c

Browse files
maxwbuckleyclaude
andcommitted
Load from_pretrained at target dtype; restrict quantize= to int8
Two related changes to the precision / quantization surface, landed together because they form one coherent story. 1. Add `dtype=` to `GLiNER.from_pretrained` Load weights directly at the target floating-point precision. Each state-dict tensor is cast during the `safe_open` read and the random-init model shell is pre-cast via `instance.model.to(dtype)` before `load_state_dict`, so a full fp32 snapshot never co-exists with the loaded weights. Accepts strings (`"bf16"`, `"fp16"`, `"float32"`, ...) or a floating-point `torch.dtype`; non-floating dtypes (e.g. `torch.int8`) are rejected up front with a message pointing at `quantize="int8"` for int paths. Int / bool buffers are preserved in the state dict. Memory impact: for CPU-only loads peak drops from ~2x fp32 to ~1x fp32; for `map_location="cuda"`, the saving is avoiding a simultaneous fp32 GPU state dict + fp32 GPU model plus the separate post-load cast pass. Matches the `dtype=` surface on `transformers.PreTrainedModel.from_pretrained` (string or `torch.dtype`, same semantics), so users coming from HF get a familiar API. Primary target: cold starts and scalable serverless deployments (Lambda, Cloud Run, Modal, RunPod serverless, autoscaled k8s) where startup latency and peak memory drive cost and SLA. The Ray-Serve layer (`gliner.serve.GLiNERFactory`) is wired through: it now passes `map_location` + `dtype` to `from_pretrained` instead of doing a post-load `.to(device=..., dtype=...)` cast, so serving cold starts get the same peak-memory win. 2. Restrict `quantize=` to `"int8"` only Previously `quantize=` accepted `True`, `"fp16"`, `"bf16"`, and `"int8"`. Three of the five effective rows (GPU fp16, GPU bf16, CPU bf16) were just `.to(dtype)` with an extra fp32 intermediate — not quantization. A fourth (CPU fp16) wrapped `nn.Linear` with dynamic-quantized variants but had no documented speed benefit and was asymmetric with the other CPU/GPU combinations. All non-int8 values now raise `ValueError` with a migration message pointing at `dtype=` / `model.to(torch_dtype)`. `quantize="int8"` is unchanged: torchao int8 weight-only on GPU, FBGEMM dynamic quant on CPU. The signature narrows from `Union[bool, str] = False` to `Optional[str] = None`; `model.quantize()` defaults to `"int8"`. The serving CLI (`gliner.serve --quantization`) is narrowed in lockstep to `{int8, None}`; precision stays under `--dtype`. Docs in `docs/usage.md` and `docs/serving.md` are updated: - New "Reduced-precision loading (`dtype`)" section in usage.md. - "Quantization, Compilation & FlashDeBERTa" now shows `dtype="fp16"` as the only half-precision path. - Serving CLI reference and env-var table reflect the int8-only `--quantization` / `GLINER_QUANTIZATION` surface. New tests in `tests/test_quantize_and_dtype.py` (50 cases) cover `_parse_dtype`, `_load_state_dict` cast-on-read, and the int8-only `model.quantize(...)` validation. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent f6137fb commit b7b819c

6 files changed

Lines changed: 396 additions & 94 deletions

File tree

docs/serving.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,10 @@ result = client.predict(
145145
Model Configuration:
146146
--model Model name or path (required)
147147
--device cuda or cpu (default: cuda)
148-
--dtype float32, float16, bfloat16 (default: bfloat16)
149-
--quantization fp16, bf16, int8 (default: None)
148+
--dtype float32, float16/fp16, bfloat16/bf16 (default: bfloat16)
149+
Weights are loaded directly at this precision; the fp32
150+
intermediate is never materialized.
151+
--quantization int8 (default: None). For precision changes use --dtype.
150152
151153
Batching:
152154
--max-batch-size Max batch size (default: 32)
@@ -199,7 +201,7 @@ docker run --gpus all -p 8000:8000 \
199201
| `GLINER_MAX_BATCH_SIZE` | `32` | Max batch size |
200202
| `GLINER_NUM_REPLICAS` | `1` | Number of replicas |
201203
| `GLINER_MEMORY_FRACTION` | `0.8` | GPU memory fraction |
202-
| `GLINER_QUANTIZATION` | - | Quantization (fp16/bf16/int8) |
204+
| `GLINER_QUANTIZATION` | - | Quantization (`int8` only; use `GLINER_DTYPE` for precision) |
203205
| `GLINER_ENABLE_FLASHDEBERTA` | `false` | Enable FlashDeBERTa |
204206
| `GLINER_ENABLE_PACKING` | `false` | Enable sequence packing |
205207
| `GLINER_DISABLE_COMPILE` | `false` | Disable torch.compile |

docs/usage.md

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -366,27 +366,53 @@ model = GLiNER.from_pretrained(
366366
print(f"Model is on: {model.device}")
367367
```
368368

369+
### Reduced-precision loading (`dtype`)
370+
371+
Pass `dtype` to `from_pretrained` to load the weights directly at the target floating-point precision — no intermediate fp32 copy, no post-load cast:
372+
373+
```python
374+
from gliner import GLiNER
375+
import torch
376+
377+
# Either a string or a torch.dtype
378+
model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1", dtype="bf16", map_location="cuda")
379+
model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1", dtype=torch.bfloat16, map_location="cuda")
380+
```
381+
382+
Accepted values: `"fp16"` / `"float16"` / `"half"`, `"bf16"` / `"bfloat16"`, `"fp32"` / `"float32"` / `"float"`, or any floating-point `torch.dtype`. Int/bool buffers are left untouched; non-floating dtypes (e.g. `torch.int8`) are rejected — use `quantize="int8"` for that path.
383+
384+
**Why use `dtype` instead of `quantize="bf16"`:**
385+
- `quantize` casts *after* the full fp32 state dict + fp32 model are already in memory.
386+
- `dtype` casts each tensor *as it is read* from the safetensors file and pre-casts the model shell before `load_state_dict`, so a fully-fp32 snapshot never co-exists with the loaded weights. For CPU-only loads, peak host memory during load drops from ~2× fp32 to ~1× fp32 for bf16/fp16. For `map_location="cuda"`, the state dict streams to GPU while the shell is CPU-side, so the saving is avoiding a simultaneous fp32 GPU state dict + fp32 GPU model — not quite a 2×→1× total-footprint reduction, but still a meaningful win on the GPU peak and on the separate post-load cast pass.
387+
388+
**When it matters:** cold starts and scalable serverless deployments (AWS Lambda, Cloud Run, Modal, RunPod serverless, autoscaled Kubernetes pods, etc.) — startup latency and peak memory directly drive cost and SLA:
389+
- Shorter cold-start on every new container (one pass instead of load + cast).
390+
- Lower peak memory lets instances fit on smaller memory tiers and reduces boot-time OOMs under memory pressure.
391+
- Faster first-inference latency after a scale-from-zero event.
392+
393+
`dtype` covers plain precision changes (bf16/fp16/fp32). For int8 / torchao / CPU dynamic quantization, keep using `quantize` (see below). The two can be combined if desired.
394+
369395
### Quantization, Compilation & FlashDeBERTa
370396

371-
Use `quantize=True` and `compile_torch_model=True` for up to ~1.9x faster GPU inference with zero quality loss:
397+
Combine `dtype="fp16"` (or `"bf16"`) with `compile_torch_model=True` for up to ~1.9x faster GPU inference with zero quality loss:
372398

373399
```python
374400
from gliner import GLiNER
375401

376402
model = GLiNER.from_pretrained(
377403
"urchade/gliner_medium-v2.1",
378404
map_location="cuda",
379-
quantize=True, # or "fp16", "bf16"
405+
dtype="fp16", # or "bf16" — see "Reduced-precision loading" above
380406
compile_torch_model=True,
381407
)
382408
```
383409

384410
Or apply after loading:
385411

386412
```python
413+
import torch
387414
model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1", map_location="cuda")
388-
model.quantize() # fp16 half-precision (default)
389-
model.quantize("bf16") # bfloat16 — better numerical stability, slightly less speedup
415+
model.to(torch.float16) # fp16 half-precision
390416
model.compile() # torch.compile with dynamic shapes
391417
```
392418

@@ -397,15 +423,14 @@ Compilation is especially beneficial for short sequences, where the overhead of
397423
| Condition | F1 | Speedup |
398424
|-----------|:---:|:---:|
399425
| GPU fp32 (baseline) | 0.8107 | 1.00x |
400-
| + quantize (fp16) | 0.8107 | 1.35x |
426+
| + `dtype="fp16"` | 0.8107 | 1.35x |
401427
| + compile | 0.8107 | 1.31x |
402-
| **+ quantize + compile** | **0.8107** | **1.94x** |
428+
| **+ `dtype="fp16"` + compile** | **0.8107** | **1.94x** |
403429

404-
**Quantization options:**
405-
- `quantize=True` or `quantize="fp16"` — float16 half-precision. Best GPU speedup (~1.35x).
406-
- `quantize="bf16"` — bfloat16. Better numerical stability, slightly less speedup (~1.2x).
407-
- `quantize="int8"` — int8 quantization. On CPU, uses built-in FBGEMM int8 kernels (~1.6x speedup). On GPU, uses [torchao](https://github.com/pytorch/ao) int8 weight-only quantization (~50% memory reduction, no speed gain). Intended for models fine-tuned with quantization-aware training (QAT). Stock DeBERTa-based models lose accuracy with int8.
408-
- On CPU, fp16/bf16 quantization reduces memory usage but does not improve speed.
430+
**`quantize=` vs `dtype=`:**
431+
- `dtype="fp16"` / `"bf16"` — plain precision change via efficient load (see the dedicated section above). This is the only way to get half-precision inference.
432+
- `quantize="int8"` — real int8 quantization. On CPU, built-in FBGEMM kernels (~1.6x speedup). On GPU, [torchao](https://github.com/pytorch/ao) int8 weight-only quantization (~50% memory reduction, no speed gain). Intended for models fine-tuned with quantization-aware training (QAT); stock DeBERTa-based models lose accuracy with int8.
433+
- `quantize=` accepts only `"int8"` (or `None`). Passing `True`, `"fp16"`, or `"bf16"` raises with a migration message — those were precision downcasts, not quantization, and are handled exclusively by `dtype=` / `model.to(...)` now.
409434

410435
**Compilation notes:**
411436
- `compile_torch_model=True` uses [torch.compile](https://pytorch.org/docs/stable/torch.compiler.html) which JIT-compiles the model via [Triton](https://github.com/triton-lang/triton) kernels. The first inference call will be slower due to compilation, but all subsequent calls benefit from the compiled graph. This is only available on **Linux and WSL** (not native Windows or macOS).

0 commit comments

Comments
 (0)