77import subprocess
88import sys
99import tarfile
10- import urllib .request
1110from pathlib import Path
1211from shutil import rmtree
1312
@@ -30,10 +29,11 @@ def _ensure_cutlass_source() -> Path:
3029
3130 archive_path = deps_dir / f"cutlass-v{ CUTLASS_VERSION } .tar.gz"
3231 if not archive_path .exists ():
33- print (f"Downloading CUTLASS v{ CUTLASS_VERSION } ..." )
34- with urllib .request .urlopen (CUTLASS_RELEASE_URL ) as response :
35- data = response .read ()
36- archive_path .write_bytes (data )
32+ _download_with_progress (
33+ CUTLASS_RELEASE_URL ,
34+ str (archive_path ),
35+ title = f"Downloading CUTLASS v{ CUTLASS_VERSION } " ,
36+ )
3737
3838 if cutlass_root .exists ():
3939 rmtree (cutlass_root )
@@ -421,6 +421,51 @@ def _resolve_wheel_url(tag_name: str, wheel_name: str) -> str:
421421 return DEFAULT_WHEEL_URL_TEMPLATE .format (tag_name = tag_name , wheel_name = wheel_name )
422422
423423
424+ def _download_with_progress (url : str , dest_path : str , title : str = "Downloading" ) -> None :
425+ """Download url to dest_path with simple stdout progress updates."""
426+ import time
427+ import urllib .request as req
428+
429+ start_time = time .time ()
430+ last_draw_time = 0.0
431+ last_print_percent = - 1
432+
433+ def _format_bytes (num_bytes : float ) -> str :
434+ units = ["B" , "KiB" , "MiB" , "GiB" , "TiB" ]
435+ value = float (max (num_bytes , 0.0 ))
436+ for unit in units :
437+ if value < 1024.0 or unit == units [- 1 ]:
438+ return f"{ value :0.1f} { unit } " if unit != "B" else f"{ int (value )} B"
439+ value /= 1024.0
440+ return f"{ value :0.1f} TiB"
441+
442+ def _reporthook (block_num : int , block_size : int , total_size : int ) -> None :
443+ nonlocal last_draw_time , last_print_percent
444+ now = time .time ()
445+ downloaded = block_num * block_size
446+ speed = downloaded / max (now - start_time , 1e-6 )
447+
448+ if total_size and total_size > 0 :
449+ percent = min (int (downloaded * 100 / total_size ), 100 )
450+ if percent == last_print_percent and percent != 100 :
451+ return
452+ subtitle = (
453+ f"{ percent :3d} % ({ _format_bytes (downloaded )} /{ _format_bytes (total_size )} ) "
454+ f"{ _format_bytes (speed )} /s"
455+ )
456+ print (f"{ title } { subtitle } " , flush = True )
457+ last_print_percent = percent
458+ last_draw_time = now
459+ else :
460+ if (now - last_draw_time ) < 1.0 :
461+ return
462+ subtitle = f"{ _format_bytes (downloaded )} { _format_bytes (speed )} /s"
463+ print (f"{ title } { subtitle } " , flush = True )
464+ last_draw_time = now
465+
466+ req .urlretrieve (url , dest_path , reporthook = _reporthook )
467+
468+
424469# Decide HAS_CUDA_V8 / HAS_CUDA_V9 without torch
425470HAS_CUDA_V8 = False
426471HAS_CUDA_V9 = False
@@ -833,12 +878,12 @@ def run(self):
833878 print (f"Resolved wheel URL: { wheel_url } \n wheel name={ wheel_filename } " )
834879
835880 try :
836- import urllib .request as req
837- req .urlretrieve (wheel_url , os .path .join (self .dist_dir , wheel_filename ))
838-
839881 if not os .path .exists (self .dist_dir ):
840882 os .makedirs (self .dist_dir )
841883
884+ wheel_path = os .path .join (self .dist_dir , wheel_filename )
885+
886+ _download_with_progress (wheel_url , wheel_path , title = "Downloading wheel" )
842887 print ("Raw wheel path" , wheel_filename )
843888 except BaseException :
844889 env_info = [f"python={ python_version } " , f"torch={ TORCH_VERSION or 'unknown' } " ]
0 commit comments