Skip to content

Commit d2cd03c

Browse files
realAsma authored and claude committed
Move checkpoint_dir helpers from library to examples/llm_ptq
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent c50c4a7 commit d2cd03c

3 files changed

Lines changed: 35 additions & 41 deletions

File tree

examples/llm_ptq/example_utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import copy
1717
import glob
18+
import hashlib
1819
import inspect
1920
import json
2021
import logging
@@ -854,3 +855,35 @@ def copy_custom_model_files(source_path: str, export_path: str, trust_remote_cod
854855
print(f"Successfully copied {len(copied_files)} custom model files to {export_path}")
855856
else:
856857
print("No custom model files found to copy")
858+
859+
860+
def needs_checkpoint_path_update(quant_cfg: dict) -> bool:
    """Check if quant_cfg has a checkpoint_dir that should be auto-resolved to a unique subpath."""
    algo = quant_cfg.get("algorithm")
    # A missing algorithm, or one given as a plain string name, carries no
    # per-run options dict and therefore no checkpoint_dir to rewrite.
    if isinstance(algo, (str, type(None))):
        return False
    return algo.get("checkpoint_dir") is not None
866+
867+
868+
def resolve_checkpoint_dir(quant_cfg: dict, model_path: str) -> dict:
    """Append a unique ``<model_name>_<config_hash>`` subdirectory to checkpoint_dir.

    Allows a single recipe to be reused across models without checkpoint collisions.
    Must only be called when :func:`needs_checkpoint_path_update` returns True.
    """
    base_dir = quant_cfg["algorithm"]["checkpoint_dir"]

    trimmed = model_path.rstrip("/")
    # HF-style repo ids ("org/model") flatten to "org--model"; absolute or
    # slash-free paths keep only their final component.
    if os.path.isabs(trimmed) or "/" not in trimmed:
        model_name = Path(trimmed).name
    else:
        model_name = trimmed.replace("/", "--")

    # Short, stable digest of the full config so distinct recipes never share
    # a checkpoint directory (default=str handles non-JSON-native values).
    digest = hashlib.sha256(
        json.dumps(quant_cfg, sort_keys=True, default=str).encode()
    ).hexdigest()
    unique_dir = os.path.join(base_dir, f"{model_name}_{digest[:8]}")

    # Return a modified copy; the caller's config object is left untouched.
    resolved = copy.deepcopy(quant_cfg)
    resolved["algorithm"]["checkpoint_dir"] = unique_dir
    return resolved

examples/llm_ptq/hf_ptq.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
is_enc_dec,
3535
is_nemotron_vl,
3636
load_mtp_weights,
37+
needs_checkpoint_path_update,
38+
resolve_checkpoint_dir,
3739
run_nemotron_vl_preview,
3840
)
3941
from torch.utils.data import DataLoader
@@ -1105,11 +1107,6 @@ def quantize_main(
11051107
quant_cfg = copy.deepcopy(quant_cfg)
11061108
_set_kv_cache_constant_amax(quant_cfg["quant_cfg"])
11071109

1108-
from modelopt.torch.quantization.utils.layerwise_calib import (
1109-
needs_checkpoint_path_update,
1110-
resolve_checkpoint_dir,
1111-
)
1112-
11131110
if needs_checkpoint_path_update(quant_cfg):
11141111
quant_cfg = resolve_checkpoint_dir(quant_cfg, args.pyt_ckpt_path)
11151112
print(f"Auto-resolved checkpoint_dir: {quant_cfg['algorithm']['checkpoint_dir']}")

modelopt/torch/quantization/utils/layerwise_calib.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -506,42 +506,6 @@ def _save_layer(
506506
_write_manifest(checkpoint_dir, idx, num_layers)
507507

508508

509-
def needs_checkpoint_path_update(quant_cfg: dict) -> bool:
510-
"""Check if quant_cfg has a checkpoint_dir that should be auto-resolved to a unique subpath."""
511-
algorithm = quant_cfg.get("algorithm")
512-
if algorithm is None or isinstance(algorithm, str):
513-
return False
514-
return algorithm.get("checkpoint_dir") is not None
515-
516-
517-
def resolve_checkpoint_dir(quant_cfg: dict, model_path: str) -> dict:
518-
"""Append a unique ``<model_name>_<config_hash>`` subdirectory to checkpoint_dir.
519-
520-
Allows a single recipe to be reused across models without checkpoint collisions.
521-
Must only be called when :func:`needs_checkpoint_path_update` returns True.
522-
"""
523-
import copy
524-
import hashlib
525-
from pathlib import Path
526-
527-
algorithm = quant_cfg["algorithm"]
528-
base_dir = algorithm["checkpoint_dir"]
529-
530-
name = model_path.rstrip("/")
531-
if "/" in name and not os.path.isabs(name):
532-
name = name.replace("/", "--")
533-
else:
534-
name = Path(name).name
535-
536-
config_hash = hashlib.sha256(
537-
json.dumps(quant_cfg, sort_keys=True, default=str).encode()
538-
).hexdigest()[:8]
539-
540-
quant_cfg = copy.deepcopy(quant_cfg)
541-
quant_cfg["algorithm"]["checkpoint_dir"] = os.path.join(base_dir, f"{name}_{config_hash}")
542-
return quant_cfg
543-
544-
545509
def detect_resume_point(checkpoint_dir: str) -> tuple[int, dict] | None:
546510
"""Detect where to resume from an existing checkpoint directory.
547511

0 commit comments

Comments (0)