docs/usage.md
Lines changed: 36 additions & 11 deletions
Original file line number
Diff line number
Diff line change
@@ -366,27 +366,53 @@ model = GLiNER.from_pretrained(
366
366
print(f"Model is on: {model.device}")
367
367
```
368
368
369
+
### Reduced-precision loading (`dtype`)
370
+
371
+
Pass `dtype` to `from_pretrained` to load the weights directly at the target floating-point precision — no intermediate fp32 copy, no post-load cast:
372
+
373
+
```python
374
+
from gliner import GLiNER
375
+
import torch
376
+
377
+
# Either a string or a torch.dtype
378
+
model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1", dtype="bf16", map_location="cuda")
379
+
model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1", dtype=torch.bfloat16, map_location="cuda")
380
+
```
381
+
382
+
Accepted values: `"fp16"` / `"float16"` / `"half"`, `"bf16"` / `"bfloat16"`, `"fp32"` / `"float32"` / `"float"`, or any floating-point `torch.dtype`. Int/bool buffers are left untouched; non-floating dtypes (e.g. `torch.int8`) are rejected — use `quantize="int8"` for that path.
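The alias groups listed above can be pictured as a small lookup table. The following is a minimal sketch only: `normalize_dtype_name` and `_DTYPE_ALIASES` are hypothetical names invented for illustration, the `ValueError` is an assumed error type, and none of this is GLiNER's actual implementation.

```python
# Hypothetical illustration of the documented string aliases.
# Not GLiNER's real code: the helper name and the error type are assumptions.
_DTYPE_ALIASES = {
    "fp16": "float16", "float16": "float16", "half": "float16",
    "bf16": "bfloat16", "bfloat16": "bfloat16",
    "fp32": "float32", "float32": "float32", "float": "float32",
}


def normalize_dtype_name(name: str) -> str:
    """Map a documented alias to a canonical floating-point dtype name."""
    try:
        return _DTYPE_ALIASES[name]
    except KeyError:
        # Anything outside the table (e.g. "int8") is rejected, mirroring
        # the documented advice to use quantize="int8" for that path.
        raise ValueError(
            f"Unsupported dtype {name!r}. "
            'For int8, use quantize="int8" instead.'
        ) from None


print(normalize_dtype_name("half"))  # float16
print(normalize_dtype_name("bf16"))  # bfloat16
```

With real PyTorch installed, the canonical name could then be resolved with `getattr(torch, name)`, since `torch.float16`, `torch.bfloat16`, and `torch.float32` are all attributes of the `torch` module.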
383
+
384
+
**Why use `dtype` instead of `quantize="bf16"`:**
385
+
- `quantize` casts *after* the full fp32 state dict + fp32 model are already in memory.
386
+
- `dtype` casts each tensor *as it is read* from the safetensors file and pre-casts the model shell before `load_state_dict`, so a fully-fp32 snapshot never co-exists with the loaded weights. For CPU-only loads, peak host memory during load drops from ~2× fp32 to ~1× fp32 for bf16/fp16. For `map_location="cuda"`, the state dict streams to GPU while the shell is CPU-side, so the saving comes from avoiding a simultaneous fp32 GPU state dict + fp32 GPU model. That is not quite a 2×→1× total-footprint reduction, but it is still a meaningful win on the GPU peak, and it removes the separate post-load cast pass.
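The CPU-only figures above (~2× fp32 versus ~1× fp32) follow from simple arithmetic: during a load, one state dict and one model copy are alive at the same time. Here is a rough back-of-the-envelope sketch; the function name and the 200M parameter count are hypothetical, and the estimate ignores integer buffers, allocator overhead, and everything else in the process.

```python
# Rough estimate of peak host memory for a CPU-only load.
# Assumption: exactly one state dict and one model copy coexist at the peak.
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2}


def peak_load_bytes(num_params: int, target: str, cast_during_load: bool) -> int:
    """Return the approximate peak bytes held while loading the weights."""
    fp32_copy = num_params * BYTES_PER_ELEMENT["float32"]
    target_copy = num_params * BYTES_PER_ELEMENT[target]
    if cast_during_load:
        # dtype= path: state dict and model shell are both at target precision.
        return target_copy + target_copy
    # Post-load cast path: fp32 state dict and fp32 model coexist before the cast.
    return fp32_copy + fp32_copy


num_params = 200_000_000  # hypothetical model size, for illustration only
post_load = peak_load_bytes(num_params, "bfloat16", cast_during_load=False)
during_load = peak_load_bytes(num_params, "bfloat16", cast_during_load=True)
print(f"post-load cast: {post_load / 1e9:.1f} GB")    # 1.6 GB, i.e. 2x fp32
print(f"cast during load: {during_load / 1e9:.1f} GB")  # 0.8 GB, i.e. 1x fp32
```

For a half-precision target, two 2-byte copies add up to one fp32 copy, which is where the "~1× fp32" figure comes from.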
387
+
388
+
**When it matters:** cold starts and scalable serverless deployments (AWS Lambda, Cloud Run, Modal, RunPod serverless, autoscaled Kubernetes pods, etc.) — startup latency and peak memory directly drive cost and SLA:
389
+
- Shorter cold-start on every new container (one pass instead of load + cast).
390
+
- Lower peak memory lets instances fit on smaller memory tiers and reduces boot-time OOMs under memory pressure.
391
+
- Faster first-inference latency after a scale-from-zero event.
392
+
393
+
`dtype` covers plain precision changes (bf16/fp16/fp32). For int8 / torchao / CPU dynamic quantization, keep using `quantize` (see below). The two can be combined if desired.
394
+
369
395
### Quantization, Compilation & FlashDeBERTa
370
396
371
-
Use `quantize=True` and `compile_torch_model=True` for up to ~1.9x faster GPU inference with zero quality loss:
397
+
Combine `dtype="fp16"` (or `"bf16"`) with `compile_torch_model=True` for up to ~1.9x faster GPU inference with zero quality loss:
372
398
373
399
```python
374
400
from gliner import GLiNER
375
401
376
402
model = GLiNER.from_pretrained(
377
403
"urchade/gliner_medium-v2.1",
378
404
map_location="cuda",
379
-
quantize=True, # or "fp16", "bf16"
405
+
dtype="fp16", # or "bf16" — see "Reduced-precision loading" above
380
406
compile_torch_model=True,
381
407
)
382
408
```
383
409
384
410
Or apply after loading:
385
411
386
412
```python
413
+
import torch
387
414
model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1", map_location="cuda")
388
-
model.quantize() # fp16 half-precision (default)
389
-
model.quantize("bf16") # bfloat16 — better numerical stability, slightly less speedup
415
+
model.to(torch.float16) # fp16 half-precision
390
416
model.compile() # torch.compile with dynamic shapes
391
417
```
392
418
@@ -397,15 +423,14 @@ Compilation is especially beneficial for short sequences, where the overhead of
- `quantize=True` or `quantize="fp16"`: float16 half-precision. Best GPU speedup (~1.35x).
406
-
- `quantize="bf16"`: bfloat16. Better numerical stability, slightly less speedup (~1.2x).
407
-
- `quantize="int8"`: int8 quantization. On CPU, uses built-in FBGEMM int8 kernels (~1.6x speedup). On GPU, uses [torchao](https://github.com/pytorch/ao) int8 weight-only quantization (~50% memory reduction, no speed gain). Intended for models fine-tuned with quantization-aware training (QAT). Stock DeBERTa-based models lose accuracy with int8.
408
-
- On CPU, fp16/bf16 quantization reduces memory usage but does not improve speed.
430
+
**`quantize=` vs `dtype=`:**
431
+
- `dtype="fp16"` / `"bf16"`: plain precision change via efficient load (see the dedicated section above). This is the only load-time way to get half-precision inference; to change precision after loading, use `model.to(...)`.
432
+
- `quantize="int8"`: real int8 quantization. On CPU, built-in FBGEMM kernels (~1.6x speedup). On GPU, [torchao](https://github.com/pytorch/ao) int8 weight-only quantization (~50% memory reduction, no speed gain). Intended for models fine-tuned with quantization-aware training (QAT); stock DeBERTa-based models lose accuracy with int8.
433
+
- `quantize=` accepts only `"int8"` (or `None`). Passing `True`, `"fp16"`, or `"bf16"` raises an error with a migration message, because those values were precision downcasts, not quantization, and are now handled exclusively by `dtype=` / `model.to(...)`.
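The split described in the list above can be sketched as a small argument check. This is only an illustration of the documented behaviour: `validate_quantize` is a hypothetical name, `ValueError` is an assumed exception type, and the message wording is invented rather than copied from the library.

```python
# Hypothetical sketch of the documented quantize= contract.
# Not GLiNER's real code: the function name, exception type, and message text
# are assumptions made for illustration.
_LEGACY_PRECISION_VALUES = ("fp16", "bf16")


def validate_quantize(quantize):
    """Accept None or "int8"; reject legacy precision values with a migration hint."""
    if quantize is None or quantize == "int8":
        return quantize
    if quantize is True or quantize in _LEGACY_PRECISION_VALUES:
        raise ValueError(
            f"quantize={quantize!r} is no longer supported: it was a precision "
            'downcast, not quantization. Use dtype="fp16" / dtype="bf16" at load '
            "time, or model.to(...) after loading."
        )
    raise ValueError(f"Unsupported quantize value: {quantize!r}")


print(validate_quantize("int8"))  # int8
print(validate_quantize(None))    # None
```

Keeping the legacy values in a separate branch lets the error point callers at `dtype=` directly instead of failing with a generic message.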
409
434
410
435
**Compilation notes:**
411
436
- `compile_torch_model=True` uses [torch.compile](https://pytorch.org/docs/stable/torch.compiler.html), which JIT-compiles the model via [Triton](https://github.com/triton-lang/triton) kernels. The first inference call will be slower due to compilation, but all subsequent calls benefit from the compiled graph. This is only available on **Linux and WSL** (not native Windows or macOS).