Skip to content

Commit 4396187

Browse files
fix: replace deprecated torch._check_is_size with torch._check (#1940)
PyTorch's torch._check_is_size is being removed in a future release per pytorch/pytorch#169400 ("Use _check(i >= 0) instead"). This commit replaces all 23 occurrences with the recommended torch._check pattern, matching the existing torch._check style already used elsewhere in these files.

Affected files (23 occurrences):
- bitsandbytes/_ops.py (8)
- bitsandbytes/backends/triton/ops.py (5)
- bitsandbytes/backends/default/ops.py (4)
- bitsandbytes/backends/cpu/ops.py (3)
- bitsandbytes/backends/cuda/ops.py (2)
- bitsandbytes/backends/hpu/ops.py (1)

torch._check is available in all torch >= 2.0; bitsandbytes requires torch >= 2.2 per pyproject.toml, so no backwards-compat shim is needed.

Co-authored-by: neil-the-nowledgable <254185769+neil-the-nowledgable@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a57d8e2 commit 4396187

6 files changed

Lines changed: 23 additions & 23 deletions

File tree

bitsandbytes/_ops.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def _(
183183
shape: Sequence[int],
184184
dtype: torch.dtype,
185185
) -> torch.Tensor:
186-
torch._check_is_size(blocksize)
186+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
187187
return torch.empty(shape, dtype=dtype, device=A.device)
188188

189189

@@ -203,7 +203,7 @@ def _(
203203
dtype: torch.dtype,
204204
out: torch.Tensor,
205205
) -> None:
206-
torch._check_is_size(blocksize)
206+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
207207
torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
208208
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
209209
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
@@ -219,7 +219,7 @@ def _(
219219
def _(
220220
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
221221
) -> tuple[torch.Tensor, torch.Tensor]:
222-
torch._check_is_size(blocksize)
222+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
223223

224224
n = A.numel()
225225
blocks = -(n // -blocksize)
@@ -236,7 +236,7 @@ def _(
236236

237237
@register_fake("bitsandbytes::dequantize_blockwise")
238238
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
239-
torch._check_is_size(blocksize)
239+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
240240
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
241241
return torch.empty_like(A, dtype=dtype)
242242

@@ -251,7 +251,7 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
251251
def _(
252252
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
253253
):
254-
torch._check_is_size(blocksize)
254+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
255255
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
256256
torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
257257
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
@@ -263,7 +263,7 @@ def _(
263263

264264
@register_fake("bitsandbytes::quantize_blockwise")
265265
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
266-
torch._check_is_size(blocksize)
266+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
267267
n = A.numel()
268268
blocks = -(n // -blocksize)
269269
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
@@ -281,7 +281,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
281281
def _(
282282
A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int
283283
) -> torch.Tensor:
284-
torch._check_is_size(blocksize)
284+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
285285
torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
286286
torch._check(
287287
A.dtype in [torch.float16, torch.bfloat16, torch.float32],
@@ -311,7 +311,7 @@ def _(
311311
blocksize: int,
312312
out: torch.Tensor,
313313
) -> None:
314-
torch._check_is_size(blocksize)
314+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
315315
torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
316316
torch._check(
317317
A.dtype in [torch.float16, torch.bfloat16, torch.float32],

bitsandbytes/backends/cpu/ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def _(A: torch.Tensor, B: torch.Tensor):
3535

3636
@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
3737
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
38-
torch._check_is_size(blocksize)
38+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
3939

4040
n = A.numel()
4141
blocks = -(n // -blocksize)
@@ -94,7 +94,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
9494
def _(
9595
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
9696
) -> torch.Tensor:
97-
torch._check_is_size(blocksize)
97+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
9898
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
9999

100100
out = torch.empty_like(A, dtype=dtype)
@@ -146,7 +146,7 @@ def _(
146146
shape: Sequence[int],
147147
dtype: torch.dtype,
148148
) -> torch.Tensor:
149-
torch._check_is_size(blocksize)
149+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
150150
torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
151151
torch._check(
152152
dtype in [torch.bfloat16, torch.float16, torch.float32],

bitsandbytes/backends/cuda/ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def _get_col_absmax(
221221
@register_kernel("bitsandbytes::quantize_blockwise", "cuda")
222222
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
223223
A = A.contiguous()
224-
torch._check_is_size(blocksize)
224+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
225225

226226
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
227227

@@ -464,7 +464,7 @@ def _gemv_4bit_impl(
464464
blocksize: int,
465465
out: torch.Tensor,
466466
) -> None:
467-
torch._check_is_size(blocksize)
467+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
468468

469469
# Note: these checks are not strictly necessary, and cost more than they are worth, so they are commented out for now.
470470
# torch._check(

bitsandbytes/backends/default/ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def _(A: torch.Tensor, threshold=0.0):
175175

176176
@register_kernel("bitsandbytes::quantize_blockwise", "default")
177177
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
178-
torch._check_is_size(blocksize)
178+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
179179

180180
n = A.numel()
181181
rem = n % blocksize
@@ -201,7 +201,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
201201

202202
@register_kernel("bitsandbytes::dequantize_blockwise", "default")
203203
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
204-
torch._check_is_size(blocksize)
204+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
205205
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
206206

207207
out = code[A.reshape(-1).int()]
@@ -220,7 +220,7 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
220220
def _(
221221
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
222222
) -> tuple[torch.Tensor, torch.Tensor]:
223-
torch._check_is_size(blocksize)
223+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
224224
torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
225225
torch._check(
226226
A.dtype in [torch.bfloat16, torch.float16, torch.float32],
@@ -317,7 +317,7 @@ def _(
317317
shape: Sequence[int],
318318
dtype: torch.dtype,
319319
) -> torch.Tensor:
320-
torch._check_is_size(blocksize)
320+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
321321
torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
322322
torch._check(
323323
dtype in [torch.bfloat16, torch.float16, torch.float32],

bitsandbytes/backends/hpu/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def _(
2525
shape: Sequence[int],
2626
dtype: torch.dtype,
2727
) -> torch.Tensor:
28-
torch._check_is_size(blocksize)
28+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
2929
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4, got {quant_type}")
3030
torch._check(
3131
A.dtype in [torch.bfloat16, torch.uint8],

bitsandbytes/backends/triton/ops.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616

1717
def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
18-
torch._check_is_size(blocksize)
18+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
1919
# torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}")
2020
with torch_accelerator_module.device(A.device):
2121
out, absmax = kernels_8bit_quant.quantize_blockwise_triton(A.contiguous(), code, blocksize)
@@ -25,7 +25,7 @@ def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> t
2525
def dequantize_blockwise(
2626
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
2727
) -> torch.Tensor:
28-
torch._check_is_size(blocksize)
28+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
2929
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
3030
# torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")
3131
with torch_accelerator_module.device(A.device):
@@ -47,7 +47,7 @@ def dequantize_blockwise_inplace(
4747
dtype: torch.dtype,
4848
out: torch.Tensor,
4949
) -> None:
50-
torch._check_is_size(blocksize)
50+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
5151
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
5252
torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
5353
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
@@ -67,7 +67,7 @@ def dequantize_blockwise_inplace(
6767
def quantize_4bit(
6868
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
6969
) -> tuple[torch.Tensor, torch.Tensor]:
70-
torch._check_is_size(blocksize)
70+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
7171
# torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
7272
torch._check(
7373
A.dtype in [torch.bfloat16, torch.float16, torch.float32],
@@ -109,7 +109,7 @@ def dequantize_4bit(
109109
shape: Sequence[int],
110110
dtype: torch.dtype,
111111
) -> torch.Tensor:
112-
torch._check_is_size(blocksize)
112+
torch._check(blocksize >= 0, lambda: f"Blocksize must be non-negative, got {blocksize}")
113113
# torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on XPU, got {quant_type}")
114114
torch._check(
115115
dtype in [torch.bfloat16, torch.float16, torch.float32],

0 commit comments

Comments
 (0)