Commit 61c490d

Merge pull request #354 from maxwbuckley/add-download-variant
Add variant= to from_pretrained for selective half-precision downloads
2 parents e93c046 + 8abe7fd commit 61c490d

3 files changed

Lines changed: 645 additions & 23 deletions

docs/usage.md

Lines changed: 18 additions & 0 deletions
@@ -392,6 +392,24 @@ Accepted values: `"fp16"` / `"float16"` / `"half"`, `"bf16"` / `"bfloat16"`, `"f

`dtype` covers plain precision changes (bf16/fp16/fp32). For int8 / torchao / CPU dynamic quantization, keep using `quantize` (see below). The two can be combined if desired.

#### Selective download (`variant`)
`dtype=` casts in memory but the on-disk file is still fp32, so the bytes pulled from the Hub don't shrink. If a publisher uploads a half-precision variant of the file (`model.fp16.safetensors` or `model.bf16.safetensors`, following the transformers naming convention), pass `variant=` to download *only* that file:
```python
model = GLiNER.from_pretrained("org/gliner_bf16-v1", variant="bf16")
# Halves bytes-on-the-wire vs. the default fp32 download (~745 MB -> ~370 MB
# for gliner_medium-v2.1) when a bf16 file is published.
```
Behavior — `variant=` is a *best-effort hint*, not a hard requirement:
- `variant=None` (default): unchanged — pulls the whole repo and loads `model.safetensors`.
- `variant="fp16"` / `"bf16"` and the variant **is** published: `snapshot_download` is filtered with `allow_patterns` so only `model.{variant}.safetensors` (plus configs and tokenizer assets) is fetched. `dtype=` is inferred from `variant`; passing both with mismatched precisions raises.
- `variant="fp16"` / `"bf16"` and the variant **is not** published: a `UserWarning` is emitted and the loader falls back to the default fp32 file plus an in-memory cast — same outcome as passing `dtype=` alone, no error, no I/O win. The warning text tells the user the publisher hasn't uploaded the file so the bandwidth savings didn't apply.
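As a rough illustration of the three cases above, here is a minimal sketch of the decision logic. It is not GLiNER's actual loader: `resolve_variant`, the alias table, and the choice of `ValueError` are hypothetical stand-ins (the docs only say that a mismatch "raises"), and the repo listing is passed in as a plain list so the sketch runs offline.

```python
import warnings

# Hypothetical alias table, limited to spellings the docs mention for `dtype`.
_DTYPE_ALIASES = {
    "fp16": "fp16", "float16": "fp16", "half": "fp16",
    "bf16": "bf16", "bfloat16": "bf16",
}

def resolve_variant(repo_files, variant=None, dtype=None):
    """Return (weights_filename, effective_dtype) for the three documented cases."""
    if variant is None:
        # Default path: unchanged behavior, load the regular weights file.
        return "model.safetensors", dtype

    # Mismatched precisions are rejected; the exception type is an assumption.
    if dtype is not None and _DTYPE_ALIASES.get(dtype, dtype) != variant:
        raise ValueError(f"variant={variant!r} conflicts with dtype={dtype!r}")

    candidate = f"model.{variant}.safetensors"
    if candidate in repo_files:
        # Variant is published: fetch only this file, dtype inferred from variant.
        return candidate, variant

    # Variant is missing: warn and fall back to fp32 weights plus in-memory cast.
    warnings.warn(
        f"{candidate} is not published; falling back to model.safetensors "
        f"with an in-memory cast to {variant} (no bandwidth savings).",
        UserWarning,
        stacklevel=2,
    )
    return "model.safetensors", variant
```

In the fallback case the returned dtype still equals the requested variant, which mirrors the "same outcome as passing `dtype=` alone" behavior described in the last bullet.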
This is the lever to pull for cold-start cost when bytes-on-the-wire dominate. Set `variant="bf16"` and forget about it — if the publisher has the variant file you get the I/O savings, and if they don't you get the in-memory `dtype=` behavior with a one-line warning. The probe uses `huggingface_hub.HfApi().list_repo_files` (one cheap API call) before downloading.
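The probe-then-filter step might look roughly like the sketch below. `HfApi().list_repo_files` and the `allow_patterns` argument of `snapshot_download` are real `huggingface_hub` features; the helper name `variant_patterns`, the injectable `list_files` argument, and the specific glob patterns for configs and tokenizer assets are assumptions for illustration, not necessarily what the loader uses.

```python
def variant_patterns(repo_id, variant, list_files=None):
    """Return allow_patterns for a published variant, or None if it is absent."""
    if list_files is None:
        # Lazy import so the sketch can be exercised offline with a fake lister.
        from huggingface_hub import HfApi
        list_files = HfApi().list_repo_files  # metadata call; no weights fetched

    weights = f"model.{variant}.safetensors"
    if weights not in list_files(repo_id):
        return None  # caller falls back to the default fp32 download

    # Assumed patterns covering configs and tokenizer assets.
    return [weights, "*.json", "*.txt", "*.model"]
```

With a non-`None` result, the download could then be restricted via `snapshot_download(repo_id, allow_patterns=patterns)`; a `None` result corresponds to the warn-and-fall-back path.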
### Quantization, Compilation & FlashDeBERTa

Combine `dtype="fp16"` (or `"bf16"`) with `compile_torch_model=True` for up to ~1.9x faster GPU inference with zero quality loss:
