11"""Utilities for downloading CrossScore model checkpoints."""
22
33import os
4+ import urllib .request
5+ import shutil
46from pathlib import Path
57
8+ # Download directly from GitHub (served via Git LFS)
69CHECKPOINT_URL = (
7- "https://huggingface.co /ActiveVisionLab/CrossScore/resolve /main/CrossScore-v1.0.0.ckpt"
10+ "https://github.com /ActiveVisionLab/CrossScore/raw /main/ckpt /CrossScore-v1.0.0.ckpt"
811)
912CHECKPOINT_FILENAME = "CrossScore-v1.0.0.ckpt"
1013
@@ -19,9 +22,10 @@ def get_cache_dir() -> Path:
1922def get_checkpoint_path () -> str :
2023 """Get path to the CrossScore checkpoint, downloading it if necessary.
2124
22- Downloads from HuggingFace Hub on first use and caches locally.
23- Set CROSSSCORE_CACHE_DIR environment variable to customize cache location.
24- Set CROSSSCORE_CKPT_PATH to use a specific local checkpoint file.
25+ Downloads from GitHub (Git LFS) on first use and caches locally at
26+ ~/.cache/crossscore/. Set environment variables to customize:
27+ CROSSSCORE_CKPT_PATH - use a specific local checkpoint file
28+ CROSSSCORE_CACHE_DIR - custom cache directory
2529
2630 Returns:
2731 Path to the checkpoint file.
@@ -39,33 +43,29 @@ def get_checkpoint_path() -> str:
3943 if ckpt_path .exists ():
4044 return str (ckpt_path )
4145
42- print (f"Downloading CrossScore checkpoint to { ckpt_path } ..." )
43- print (f" Source: { CHECKPOINT_URL } " )
44- print (" (Set CROSSSCORE_CKPT_PATH to use a local checkpoint instead)" )
46+ print (f"Downloading CrossScore checkpoint (~129MB)..." )
47+ print (f" From: { CHECKPOINT_URL } " )
48+ print (f" To: { ckpt_path } " )
49+ print (" (Set CROSSSCORE_CKPT_PATH to skip download and use a local file)" )
4550
51+ tmp_path = str (ckpt_path ) + ".tmp"
4652 try :
47- from huggingface_hub import hf_hub_download
53+ urllib .request .urlretrieve (CHECKPOINT_URL , tmp_path , _download_progress )
54+ os .rename (tmp_path , str (ckpt_path ))
55+ except Exception :
56+ if os .path .exists (tmp_path ):
57+ os .remove (tmp_path )
58+ raise
4859
49- downloaded_path = hf_hub_download (
50- repo_id = "ActiveVisionLab/CrossScore" ,
51- filename = CHECKPOINT_FILENAME ,
52- local_dir = str (cache_dir ),
53- )
54- return downloaded_path
55- except ImportError :
56- # Fallback to urllib if huggingface_hub not installed
57- import urllib .request
58- import shutil
60+ print (f"\n Download complete." )
61+ return str (ckpt_path )
5962
60- tmp_path = str (ckpt_path ) + ".tmp"
61- try :
62- with urllib .request .urlopen (CHECKPOINT_URL ) as response , open (tmp_path , "wb" ) as out :
63- shutil .copyfileobj (response , out )
64- os .rename (tmp_path , str (ckpt_path ))
65- except Exception :
66- if os .path .exists (tmp_path ):
67- os .remove (tmp_path )
68- raise
6963
70- print (f" Download complete: { ckpt_path } " )
71- return str (ckpt_path )
64+ def _download_progress (block_count , block_size , total_size ):
65+ """Progress callback for urlretrieve."""
66+ downloaded = block_count * block_size
67+ if total_size > 0 :
68+ pct = min (100 , downloaded * 100 // total_size )
69+ mb_done = downloaded / (1024 * 1024 )
70+ mb_total = total_size / (1024 * 1024 )
71+ print (f"\r { mb_done :.1f} /{ mb_total :.1f} MB ({ pct } %)" , end = "" , flush = True )
0 commit comments