Skip to content

Commit b1abb5d

Browse files
authored
Setup download progress (#2289)
* log download progress * refactor * refactor * format
1 parent fd0a2ab commit b1abb5d

3 files changed

Lines changed: 55 additions & 11 deletions

File tree

gptqmodel/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@
77
# even minor versions are release
88
# 5.2.0 => release, 5.1.0 => devel
99
# micro version (5.2.x) denotes patch fix, i.e. 5.2.1 is a patch fix release
10-
__version__ = "5.6.99"
10+
__version__ = "5.6.12"

setup.py

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import subprocess
88
import sys
99
import tarfile
10-
import urllib.request
1110
from pathlib import Path
1211
from shutil import rmtree
1312

@@ -30,10 +29,11 @@ def _ensure_cutlass_source() -> Path:
3029

3130
archive_path = deps_dir / f"cutlass-v{CUTLASS_VERSION}.tar.gz"
3231
if not archive_path.exists():
33-
print(f"Downloading CUTLASS v{CUTLASS_VERSION} ...")
34-
with urllib.request.urlopen(CUTLASS_RELEASE_URL) as response:
35-
data = response.read()
36-
archive_path.write_bytes(data)
32+
_download_with_progress(
33+
CUTLASS_RELEASE_URL,
34+
str(archive_path),
35+
title=f"Downloading CUTLASS v{CUTLASS_VERSION}",
36+
)
3737

3838
if cutlass_root.exists():
3939
rmtree(cutlass_root)
@@ -421,6 +421,51 @@ def _resolve_wheel_url(tag_name: str, wheel_name: str) -> str:
421421
return DEFAULT_WHEEL_URL_TEMPLATE.format(tag_name=tag_name, wheel_name=wheel_name)
422422

423423

424+
def _download_with_progress(url: str, dest_path: str, title: str = "Downloading") -> None:
    """Download ``url`` to ``dest_path``, printing progress lines to stdout.

    The payload is first written to a sibling ``<dest_path>.part`` file and
    only moved onto ``dest_path`` once the transfer has finished, so a failed
    or interrupted download never leaves a truncated file at ``dest_path``.

    Args:
        url: Location to fetch; anything ``urllib.request.urlretrieve`` accepts.
        dest_path: Filesystem path the downloaded file is stored at.
        title: Prefix printed in front of every progress line.

    Raises:
        Whatever ``urllib.request.urlretrieve`` or ``os.replace`` raises; the
        partial ``.part`` file is removed before the exception propagates.
    """
    import os
    import time
    import urllib.request as req

    start_time = time.time()
    last_draw_time = 0.0
    last_print_percent = -1

    def _format_bytes(num_bytes: float) -> str:
        """Render a byte count using binary units (B, KiB, ... TiB)."""
        value = float(max(num_bytes, 0.0))
        for unit in ("B", "KiB", "MiB", "GiB"):
            if value < 1024.0:
                # Plain bytes are shown as an integer, larger units with one decimal.
                return f"{int(value)}B" if unit == "B" else f"{value:0.1f}{unit}"
            value /= 1024.0
        # Anything that did not fit in GiB is reported in TiB, however large.
        return f"{value:0.1f}TiB"

    def _reporthook(block_num: int, block_size: int, total_size: int) -> None:
        """Progress callback invoked by ``urlretrieve`` after each block."""
        nonlocal last_draw_time, last_print_percent
        now = time.time()
        size_known = bool(total_size and total_size > 0)

        # block_num * block_size is an estimate that overshoots on the final
        # (short) block; never report more bytes than the announced total.
        downloaded = block_num * block_size
        if size_known:
            downloaded = min(downloaded, total_size)
        speed = downloaded / max(now - start_time, 1e-6)

        if size_known:
            percent = downloaded * 100 // total_size
            # Emit each percentage step once, including the final 100%.
            if percent == last_print_percent:
                return
            subtitle = (
                f"{percent:3d}% ({_format_bytes(downloaded)}/{_format_bytes(total_size)}) "
                f"{_format_bytes(speed)}/s"
            )
            print(f"{title} {subtitle}", flush=True)
            last_print_percent = percent
            last_draw_time = now
        else:
            # Unknown total size: throttle output to one line per second.
            if (now - last_draw_time) < 1.0:
                return
            subtitle = f"{_format_bytes(downloaded)} {_format_bytes(speed)}/s"
            print(f"{title} {subtitle}", flush=True)
            last_draw_time = now

    partial_path = os.fspath(dest_path) + ".part"
    try:
        req.urlretrieve(url, partial_path, reporthook=_reporthook)
        # Publish the file only after the whole payload has been written.
        os.replace(partial_path, dest_path)
    except BaseException:
        # Also covers KeyboardInterrupt: drop the partial file, then re-raise.
        try:
            os.remove(partial_path)
        except OSError:
            pass
        raise
467+
468+
424469
# Decide HAS_CUDA_V8 / HAS_CUDA_V9 without torch
425470
HAS_CUDA_V8 = False
426471
HAS_CUDA_V9 = False
@@ -833,12 +878,12 @@ def run(self):
833878
print(f"Resolved wheel URL: {wheel_url}\nwheel name={wheel_filename}")
834879

835880
try:
836-
import urllib.request as req
837-
req.urlretrieve(wheel_url, os.path.join(self.dist_dir, wheel_filename))
838-
839881
if not os.path.exists(self.dist_dir):
840882
os.makedirs(self.dist_dir)
841883

884+
wheel_path = os.path.join(self.dist_dir, wheel_filename)
885+
886+
_download_with_progress(wheel_url, wheel_path, title="Downloading wheel")
842887
print("Raw wheel path", wheel_filename)
843888
except BaseException:
844889
env_info = [f"python={python_version}", f"torch={TORCH_VERSION or 'unknown'}"]

tests/test_bitblas_gptq_v2.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33
import pytest
44
import torch
55

6+
from gptqmodel import BACKEND, GPTQModel
67
from gptqmodel.nn_modules.qlinear.bitblas import (
78
BITBLAS_AVAILABLE,
89
import_bitblas,
910
)
1011

11-
from gptqmodel import GPTQModel, BACKEND
12-
from gptqmodel.quantization.config import FORMAT
1312

1413
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for BitBLAS")
1514
@pytest.mark.skipif(not BITBLAS_AVAILABLE, reason="BitBLAS backend is not available")

0 commit comments

Comments
 (0)