Skip to content

Commit 8abe7fd

Browse files
maxwbuckleyclaude
andcommitted
Fix three Codex-review findings on variant= probe
[P2] Pass cache_dir through the variant probe. ``_variant_available`` and ``_resolve_variant`` previously didn't accept ``cache_dir``, so a caller using ``from_pretrained(..., cache_dir=X)`` would have ``hf_hub_download`` probe into the *default* HF cache and then ``snapshot_download(..., cache_dir=X)`` could not reuse the probe's download. Result: cold starts paid for the variant weights twice and the user's requested cache location was bypassed. Both helpers now take ``cache_dir`` and forward it to ``try_to_load_from_cache`` and ``hf_hub_download``. [P2] Sharded variant safetensors. ``_variant_allow_patterns`` already included ``model.{variant}.safetensors.index.json`` for forward compatibility, but the actual shard files (``model-XXXXX-of-YYYYY.{variant}.safetensors``) were excluded by the allow-list, so a publisher who shipped a sharded fp16 / bf16 variant would get only the index file pulled — the load would then fail or silently fall back to fp32. Added the ``model-*-of-*.{variant}.safetensors`` glob. [P2] dtype-vs-variant consistency in the outer dispatcher. ``GLiNER.from_pretrained`` (the outer class-level dispatcher) ran the variant probe before checking dtype/variant consistency. When the variant file was missing from the Hub, ``_resolve_variant`` downgraded to ``None`` and the inner consistency check was then skipped — silently accepting ``variant='bf16', dtype='fp16'`` and loading fp16 instead of raising the documented mismatch error. Hoisted the consistency check above the probe in the outer dispatcher to mirror the inner logic. Test coverage: - New ``test_includes_sharded_safetensors_pattern`` asserts the shard glob is per-variant (no cross-variant slip-through, default shards still excluded). - Updated ``test_fp16_and_bf16_differ_only_in_variant_filename`` to account for the new shard pattern entries in the symmetric difference. - New ``test_outer_dispatcher_mismatch_raises_before_probe`` exercises the outer-dispatcher path with a non-existent ``model_id`` and a mismatched ``variant``/``dtype`` pair, asserting the ``ValueError`` fires before any I/O. Guards against the silent-fp16 regression. 90 unit tests pass (was 88 + 2 new). Ruff lint and format clean. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 0d3cf78 commit 8abe7fd

2 files changed

Lines changed: 95 additions & 14 deletions

File tree

gliner/model.py

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -319,19 +319,26 @@ def _normalize_variant(cls, variant) -> Optional[str]:
319319
def _variant_allow_patterns(variant: str) -> list:
320320
"""Return ``snapshot_download(allow_patterns=...)`` for a variant.
321321
322-
The patterns include the variant safetensors file (and its sharded
323-
index, if present) plus the configs and tokenizer assets every load
324-
needs. The default ``model.safetensors`` and ``pytorch_model.bin`` are
325-
deliberately excluded so the caller pays I/O only for the requested
326-
variant.
322+
Includes the single-file variant safetensors, the sharded variant
323+
index, the actual sharded variant safetensors files, and the configs
324+
and tokenizer assets every load needs. The default
325+
``model.safetensors`` and ``pytorch_model.bin`` are deliberately
326+
excluded so the caller pays I/O only for the requested variant.
327+
328+
Sharded checkpoint convention (transformers-style):
329+
``model-00001-of-NNNNN.{variant}.safetensors``
330+
``model.{variant}.safetensors.index.json``
327331
"""
328332
return [
329333
"*.json",
330334
"*.txt",
331335
"spiece.model",
332336
"sentencepiece.bpe.model",
337+
# Single-file variant.
333338
f"model.{variant}.safetensors",
339+
# Sharded variant: index file + per-shard files.
334340
f"model.{variant}.safetensors.index.json",
341+
f"model-*-of-*.{variant}.safetensors",
335342
]
336343

337344
@classmethod
@@ -340,6 +347,7 @@ def _variant_available(
340347
model_id: str,
341348
variant: str,
342349
revision: Optional[str] = None,
350+
cache_dir: Optional[Union[str, Path]] = None,
343351
token: Union[str, bool, None] = None,
344352
local_files_only: bool = False,
345353
) -> Optional[bool]:
@@ -375,8 +383,11 @@ def _variant_available(
375383
# try_to_load_from_cache validates the repo_id format; an
376384
# HFValidationError here means the input isn't a valid repo_id at
377385
# all (e.g. a non-existent local path), so treat as uncertain.
386+
# ``cache_dir`` must match what ``snapshot_download`` will use, or the
387+
# probe and the actual download diverge (and we'd download the variant
388+
# twice).
378389
try:
379-
cached = try_to_load_from_cache(repo_id=model_id, filename=target, revision=revision)
390+
cached = try_to_load_from_cache(repo_id=model_id, filename=target, revision=revision, cache_dir=cache_dir)
380391
except Exception:
381392
return None
382393
if isinstance(cached, str):
@@ -388,8 +399,16 @@ def _variant_available(
388399

389400
# 4. Try-and-recover via hf_hub_download. Success caches the file so
390401
# the subsequent snapshot_download reuses it (no double download).
402+
# cache_dir must propagate so the probe and snapshot_download share
403+
# the same store.
391404
try:
392-
hf_hub_download(repo_id=model_id, filename=target, revision=revision, token=token)
405+
hf_hub_download(
406+
repo_id=model_id,
407+
filename=target,
408+
revision=revision,
409+
cache_dir=cache_dir,
410+
token=token,
411+
)
393412
return True
394413
except EntryNotFoundError:
395414
return False
@@ -404,6 +423,7 @@ def _resolve_variant(
404423
model_id: str,
405424
variant: Optional[str],
406425
revision: Optional[str] = None,
426+
cache_dir: Optional[Union[str, Path]] = None,
407427
token: Union[str, bool, None] = None,
408428
local_files_only: bool = False,
409429
) -> Optional[str]:
@@ -420,7 +440,12 @@ def _resolve_variant(
420440
if variant is None:
421441
return None
422442
available = cls._variant_available(
423-
model_id, variant, revision=revision, token=token, local_files_only=local_files_only
443+
model_id,
444+
variant,
445+
revision=revision,
446+
cache_dir=cache_dir,
447+
token=token,
448+
local_files_only=local_files_only,
424449
)
425450
if available is False:
426451
# TODO(strict-variant): once half-precision variant files have been
@@ -1046,6 +1071,7 @@ def from_pretrained(
10461071
model_id,
10471072
variant,
10481073
revision=revision,
1074+
cache_dir=cache_dir,
10491075
token=token,
10501076
local_files_only=local_files_only,
10511077
)
@@ -4530,22 +4556,39 @@ def from_pretrained(
45304556
# outer ``GLiNER`` class doesn't inherit from ``BaseGLiNER``; reuse
45314557
# the helpers directly so behavior stays in lockstep.
45324558
normalized_variant = BaseGLiNER._normalize_variant(variant)
4559+
4560+
# dtype-vs-variant consistency check MUST run before the probe.
4561+
# Otherwise, when the variant file is missing on the Hub,
4562+
# ``_resolve_variant`` downgrades to ``None`` and the inner
4563+
# ``from_pretrained``'s consistency check is skipped — silently
4564+
# accepting a ``variant="bf16", dtype="fp16"`` mismatch instead of
4565+
# raising as documented.
4566+
torch_dtype = BaseGLiNER._parse_dtype(dtype)
4567+
if normalized_variant is not None:
4568+
variant_dtype = BaseGLiNER._VARIANT_TO_DTYPE[normalized_variant]
4569+
if torch_dtype is None:
4570+
torch_dtype = variant_dtype
4571+
# Propagate the variant's dtype so the inner cast-on-read still
4572+
# produces the requested precision after a fallback.
4573+
dtype = variant_dtype
4574+
elif torch_dtype != variant_dtype:
4575+
raise ValueError(
4576+
f"variant={normalized_variant!r} requires dtype={variant_dtype}; "
4577+
f"got dtype={torch_dtype}. Drop dtype= to inherit from variant, "
4578+
f"or unset variant= to load the default file."
4579+
)
4580+
45334581
# Probe for availability and warn-and-fall-back to None if the variant
45344582
# file isn't published. The inner from_pretrained will see model_dir
45354583
# is already populated and skip its own probe — no double round-trip.
45364584
normalized_variant = BaseGLiNER._resolve_variant(
45374585
model_id,
45384586
normalized_variant,
45394587
revision=revision,
4588+
cache_dir=cache_dir,
45404589
token=token,
45414590
local_files_only=local_files_only,
45424591
)
4543-
# If the probe downgraded variant -> None but the user asked for a
4544-
# specific variant, propagate the variant's dtype so the inner cast-on-
4545-
# read still produces the requested precision.
4546-
if variant is not None and normalized_variant is None and dtype is None:
4547-
original = BaseGLiNER._normalize_variant(variant)
4548-
dtype = BaseGLiNER._VARIANT_TO_DTYPE[original]
45494592

45504593
model_dir = BaseGLiNER._download_model(
45514594
model_id,

tests/test_quantize_and_dtype.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from torch import nn
1414
from safetensors.torch import save_file
1515

16+
from gliner import GLiNER
1617
from gliner.model import BaseGLiNER
1718

1819

@@ -295,8 +296,26 @@ def test_fp16_and_bf16_differ_only_in_variant_filename(self):
295296
"model.bf16.safetensors",
296297
"model.fp16.safetensors.index.json",
297298
"model.bf16.safetensors.index.json",
299+
# Sharded variant patterns must also differ between fp16 and bf16,
300+
# otherwise large multi-file checkpoints can't pull only the
301+
# requested precision's shards.
302+
"model-*-of-*.fp16.safetensors",
303+
"model-*-of-*.bf16.safetensors",
298304
}
299305

306+
def test_includes_sharded_safetensors_pattern(self):
307+
"""Sharded variant checkpoints place tensor data in
308+
``model-XXXXX-of-YYYYY.{variant}.safetensors`` files; without the
309+
wildcard pattern the index would download but the actual shards
310+
would be filtered out.
311+
"""
312+
patterns = BaseGLiNER._variant_allow_patterns("bf16")
313+
assert "model-*-of-*.bf16.safetensors" in patterns
314+
# Wrong variant must not slip through the sharded match.
315+
assert "model-*-of-*.fp16.safetensors" not in patterns
316+
# Default-variant shards must still be excluded.
317+
assert "model-*-of-*.safetensors" not in patterns
318+
300319

301320
class TestVariantDtypeConsistency:
302321
"""``variant=`` and ``dtype=`` must agree (or only one set).
@@ -339,6 +358,25 @@ def test_int_dtype_against_variant_rejected_by_dtype_parser(self, tmp_path: Path
339358
dtype=torch.int8,
340359
)
341360

361+
def test_outer_dispatcher_mismatch_raises_before_probe(self, tmp_path: Path):
362+
"""Codex review finding: the outer ``GLiNER.from_pretrained`` used to
363+
run the variant probe before checking dtype/variant consistency, so
364+
when the variant file was missing on the Hub the consistency check
365+
was skipped and a ``variant='bf16', dtype='fp16'`` mismatch would
366+
load fp16 silently. This test guards the regression by checking
367+
the mismatch raises even when the model_id is a non-existent path
368+
(which would otherwise be caught later by the download step).
369+
"""
370+
# tmp_path has no gliner_config.json; if the consistency check runs
371+
# first, we get a ValueError. If it runs after the probe (the bug),
372+
# we'd get a FileNotFoundError or warning instead.
373+
with pytest.raises(ValueError, match="variant='bf16' requires"):
374+
GLiNER.from_pretrained(
375+
model_id=str(tmp_path),
376+
variant="bf16",
377+
dtype="fp16",
378+
)
379+
342380

343381
class TestVariantAvailable:
344382
"""``_variant_available`` probe for variant file presence."""

0 commit comments

Comments
 (0)