Commit 4b9e7f7

Merge pull request #348 from maxwbuckley/from-pretrained-dtype
Load from_pretrained at target dtype; restrict `quantize=` to int8
2 parents f6137fb + b7b819c commit 4b9e7f7

6 files changed

Lines changed: 396 additions & 94 deletions

docs/serving.md

Lines changed: 5 additions & 3 deletions
@@ -145,8 +145,10 @@ result = client.predict(
 Model Configuration:
   --model         Model name or path (required)
   --device        cuda or cpu (default: cuda)
-  --dtype         float32, float16, bfloat16 (default: bfloat16)
-  --quantization  fp16, bf16, int8 (default: None)
+  --dtype         float32, float16/fp16, bfloat16/bf16 (default: bfloat16)
+                  Weights are loaded directly at this precision; the fp32
+                  intermediate is never materialized.
+  --quantization  int8 (default: None). For precision changes use --dtype.
 
 Batching:
   --max-batch-size  Max batch size (default: 32)
@@ -199,7 +201,7 @@ docker run --gpus all -p 8000:8000 \
 | `GLINER_MAX_BATCH_SIZE` | `32` | Max batch size |
 | `GLINER_NUM_REPLICAS` | `1` | Number of replicas |
 | `GLINER_MEMORY_FRACTION` | `0.8` | GPU memory fraction |
-| `GLINER_QUANTIZATION` | - | Quantization (fp16/bf16/int8) |
+| `GLINER_QUANTIZATION` | - | Quantization (`int8` only; use `GLINER_DTYPE` for precision) |
 | `GLINER_ENABLE_FLASHDEBERTA` | `false` | Enable FlashDeBERTa |
 | `GLINER_ENABLE_PACKING` | `false` | Enable sequence packing |
 | `GLINER_DISABLE_COMPILE` | `false` | Disable torch.compile |
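
Deployments that still set `GLINER_QUANTIZATION=fp16` or `bf16` need to move that value to `GLINER_DTYPE`. The sketch below shows one way a deployment script might do that translation before starting the server. The helper `resolve_precision` is hypothetical and not part of GLiNER, and the `bfloat16` fallback is an assumption borrowed from the `--dtype` CLI default shown above.

```python
from typing import Mapping, Optional, Tuple

# Values that used to be accepted by GLINER_QUANTIZATION but are really
# precision changes, now expressed through GLINER_DTYPE.
LEGACY_PRECISION_VALUES = {"fp16", "bf16"}


def resolve_precision(env: Mapping[str, str]) -> Tuple[str, Optional[str]]:
    """Return (dtype, quantization) from environment-style settings.

    Hypothetical pre-flight helper for deployment scripts; it is not the
    server's own parsing logic.
    """
    # Assumption: fall back to the CLI default when GLINER_DTYPE is unset.
    dtype = env.get("GLINER_DTYPE", "bfloat16")
    quantization = env.get("GLINER_QUANTIZATION") or None

    if quantization in LEGACY_PRECISION_VALUES:
        # Legacy setting: treat it as a dtype request unless an explicit
        # GLINER_DTYPE is already present, then clear the quantization.
        if "GLINER_DTYPE" not in env:
            dtype = quantization
        quantization = None
    elif quantization not in (None, "int8"):
        raise ValueError(
            f"GLINER_QUANTIZATION={quantization!r} is not supported; "
            "use 'int8' or leave it unset"
        )
    return dtype, quantization


print(resolve_precision({"GLINER_QUANTIZATION": "bf16"}))
print(resolve_precision({"GLINER_DTYPE": "float16", "GLINER_QUANTIZATION": "int8"}))
```

An explicit `GLINER_DTYPE` wins over a legacy quantization value here; that precedence is a design choice of this sketch, not documented behavior.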

docs/usage.md

Lines changed: 36 additions & 11 deletions
@@ -366,27 +366,53 @@ model = GLiNER.from_pretrained(
 print(f"Model is on: {model.device}")
 ```
 
+### Reduced-precision loading (`dtype`)
+
+Pass `dtype` to `from_pretrained` to load the weights directly at the target floating-point precision — no intermediate fp32 copy, no post-load cast:
+
+```python
+from gliner import GLiNER
+import torch
+
+# Either a string or a torch.dtype
+model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1", dtype="bf16", map_location="cuda")
+model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1", dtype=torch.bfloat16, map_location="cuda")
+```
+
+Accepted values: `"fp16"` / `"float16"` / `"half"`, `"bf16"` / `"bfloat16"`, `"fp32"` / `"float32"` / `"float"`, or any floating-point `torch.dtype`. Int/bool buffers are left untouched; non-floating dtypes (e.g. `torch.int8`) are rejected — use `quantize="int8"` for that path.
+
+**Why use `dtype` instead of `quantize="bf16"`:**
+- `quantize` casts *after* the full fp32 state dict + fp32 model are already in memory.
+- `dtype` casts each tensor *as it is read* from the safetensors file and pre-casts the model shell before `load_state_dict`, so a fully-fp32 snapshot never co-exists with the loaded weights. For CPU-only loads, peak host memory during load drops from ~2× fp32 to ~1× fp32 for bf16/fp16. For `map_location="cuda"`, the state dict streams to GPU while the shell is CPU-side, so the saving is avoiding a simultaneous fp32 GPU state dict + fp32 GPU model — not quite a 2×→1× total-footprint reduction, but still a meaningful win on the GPU peak and on the separate post-load cast pass.
+
+**When it matters:** cold starts and scalable serverless deployments (AWS Lambda, Cloud Run, Modal, RunPod serverless, autoscaled Kubernetes pods, etc.) — startup latency and peak memory directly drive cost and SLA:
+- Shorter cold-start on every new container (one pass instead of load + cast).
+- Lower peak memory lets instances fit on smaller memory tiers and reduces boot-time OOMs under memory pressure.
+- Faster first-inference latency after a scale-from-zero event.
+
+`dtype` covers plain precision changes (bf16/fp16/fp32). For int8 / torchao / CPU dynamic quantization, keep using `quantize` (see below). The two can be combined if desired.
+
 ### Quantization, Compilation & FlashDeBERTa
 
-Use `quantize=True` and `compile_torch_model=True` for up to ~1.9x faster GPU inference with zero quality loss:
+Combine `dtype="fp16"` (or `"bf16"`) with `compile_torch_model=True` for up to ~1.9x faster GPU inference with zero quality loss:
 
 ```python
 from gliner import GLiNER
 
 model = GLiNER.from_pretrained(
     "urchade/gliner_medium-v2.1",
     map_location="cuda",
-    quantize=True,  # or "fp16", "bf16"
+    dtype="fp16",  # or "bf16" — see "Reduced-precision loading" above
     compile_torch_model=True,
 )
 ```
 
 Or apply after loading:
 
 ```python
+import torch
 model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1", map_location="cuda")
-model.quantize()  # fp16 half-precision (default)
-model.quantize("bf16")  # bfloat16 — better numerical stability, slightly less speedup
+model.to(torch.float16)  # fp16 half-precision
 model.compile()  # torch.compile with dynamic shapes
 ```
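
The "~2× fp32 to ~1× fp32" figure for CPU-only loads follows from simple byte counting, sketched below. The parameter count is a made-up illustrative number, and the model is deliberately rough: it counts only one state dict plus one model shell resident at the peak, and ignores buffers, activations, and allocator overhead.

```python
# Bytes per element for the floating-point formats discussed above.
BYTES_PER_ELEMENT = {"fp32": 4, "fp16": 2, "bf16": 2}


def peak_load_bytes(num_params: int, target: str, cast_after_load: bool) -> int:
    """Rough peak host memory (weights only) for a CPU-only load."""
    if cast_after_load:
        # Old path: the fp32 state dict and the fp32 model are both
        # resident before the downcast happens.
        return num_params * BYTES_PER_ELEMENT["fp32"] * 2
    # dtype= path: tensors are cast as they are read and the model shell
    # is pre-cast, so both copies are already at the target precision.
    return num_params * BYTES_PER_ELEMENT[target] * 2


NUM_PARAMS = 200_000_000  # hypothetical model size, for illustration only
fp32_copy = NUM_PARAMS * BYTES_PER_ELEMENT["fp32"]

old_peak = peak_load_bytes(NUM_PARAMS, "bf16", cast_after_load=True)
new_peak = peak_load_bytes(NUM_PARAMS, "bf16", cast_after_load=False)

print(f"cast after load: {old_peak / fp32_copy:.1f}x one fp32 copy")
print(f"load at dtype:   {new_peak / fp32_copy:.1f}x one fp32 copy")
```

With `target="fp32"` the two paths come out equal under this model, which is consistent with the saving being stated only for bf16/fp16.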

@@ -397,15 +423,14 @@ Compilation is especially beneficial for short sequences, where the overhead of
 | Condition | F1 | Speedup |
 |-----------|:---:|:---:|
 | GPU fp32 (baseline) | 0.8107 | 1.00x |
-| + quantize (fp16) | 0.8107 | 1.35x |
+| + `dtype="fp16"` | 0.8107 | 1.35x |
 | + compile | 0.8107 | 1.31x |
-| **+ quantize + compile** | **0.8107** | **1.94x** |
+| **+ `dtype="fp16"` + compile** | **0.8107** | **1.94x** |
 
-**Quantization options:**
-- `quantize=True` or `quantize="fp16"` — float16 half-precision. Best GPU speedup (~1.35x).
-- `quantize="bf16"` — bfloat16. Better numerical stability, slightly less speedup (~1.2x).
-- `quantize="int8"` — int8 quantization. On CPU, uses built-in FBGEMM int8 kernels (~1.6x speedup). On GPU, uses [torchao](https://github.com/pytorch/ao) int8 weight-only quantization (~50% memory reduction, no speed gain). Intended for models fine-tuned with quantization-aware training (QAT). Stock DeBERTa-based models lose accuracy with int8.
-- On CPU, fp16/bf16 quantization reduces memory usage but does not improve speed.
+**`quantize=` vs `dtype=`:**
+- `dtype="fp16"` / `"bf16"` — plain precision change via efficient load (see the dedicated section above). This is the only way to get half-precision inference.
+- `quantize="int8"` — real int8 quantization. On CPU, built-in FBGEMM kernels (~1.6x speedup). On GPU, [torchao](https://github.com/pytorch/ao) int8 weight-only quantization (~50% memory reduction, no speed gain). Intended for models fine-tuned with quantization-aware training (QAT); stock DeBERTa-based models lose accuracy with int8.
+- `quantize=` accepts only `"int8"` (or `None`). Passing `True`, `"fp16"`, or `"bf16"` raises with a migration message — those were precision downcasts, not quantization, and are handled exclusively by `dtype=` / `model.to(...)` now.
 
 **Compilation notes:**
 - `compile_torch_model=True` uses [torch.compile](https://pytorch.org/docs/stable/torch.compiler.html) which JIT-compiles the model via [Triton](https://github.com/triton-lang/triton) kernels. The first inference call will be slower due to compilation, but all subsequent calls benefit from the compiled graph. This is only available on **Linux and WSL** (not native Windows or macOS).
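
The argument rules described in this diff (string aliases for `dtype`, and `quantize` restricted to `"int8"` or `None`) can be pictured with the pure-Python sketch below. It is only an illustration of the documented behavior, not the library's code: the function names are hypothetical, the real `dtype` argument also accepts `torch.dtype` objects, and the exception type (`ValueError`) and message wording are assumptions.

```python
from typing import Optional

# String aliases listed in the docs, mapped to a canonical dtype name.
_DTYPE_ALIASES = {
    "fp16": "float16", "float16": "float16", "half": "float16",
    "bf16": "bfloat16", "bfloat16": "bfloat16",
    "fp32": "float32", "float32": "float32", "float": "float32",
}


def normalize_dtype(name: str) -> str:
    """Map a documented alias to its canonical floating-point dtype name."""
    if name not in _DTYPE_ALIASES:
        # Non-floating requests such as "int8" belong to the quantize path.
        raise ValueError(
            f"unsupported dtype {name!r}; for int8 use quantize='int8'"
        )
    return _DTYPE_ALIASES[name]


def check_quantize(value: object) -> Optional[str]:
    """Accept only 'int8' or None, mirroring the restricted quantize= argument."""
    if value is None:
        return None
    if isinstance(value, str) and value == "int8":
        return "int8"
    # Assumed wording; the source only says a migration message is raised.
    raise ValueError(
        f"quantize={value!r} is no longer supported; "
        "use dtype='fp16' or dtype='bf16' for precision changes"
    )


print(normalize_dtype("half"))
print(check_quantize("int8"))
```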
