Skip to content

Commit d77e01c

Browse files
Replace print/warnings with logging in library modules (#1883)
* use logging * fix
1 parent 75ef76b commit d77e01c

File tree

4 files changed: +33 −28 lines changed

bitsandbytes/autograd/_functions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dataclasses import dataclass
2+
import logging
23
from math import prod
34
from typing import Optional
45
import warnings
@@ -8,6 +9,8 @@
89

910
import bitsandbytes.functional as F
1011

12+
logger = logging.getLogger(__name__)
13+
1114
# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
1215
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
1316

@@ -123,7 +126,7 @@ def forward(
123126

124127
# Cast A to fp16
125128
if A.dtype != torch.float16 and not _is_compiling():
126-
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
129+
logger.warning("MatMul8bitLt: inputs will be cast from %s to float16 during quantization", A.dtype)
127130

128131
if len(A.shape) == 3:
129132
A = A.reshape(-1, A.shape[-1])

bitsandbytes/nn/modules.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55
import copy
6+
import logging
67
from typing import Any, Optional, TypeVar, Union, overload
7-
import warnings
88

99
import torch
1010
from torch import Tensor, device, dtype, nn
@@ -20,6 +20,8 @@
2020
from bitsandbytes.optim import GlobalOptimManager
2121
from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer
2222

23+
logger = logging.getLogger(__name__)
24+
2325
T = TypeVar("T", bound="torch.nn.Module")
2426

2527

@@ -443,7 +445,7 @@ def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Line
443445
return
444446

445447
if getattr(module, "quant_state", None) is None:
446-
warnings.warn(
448+
logger.warning(
447449
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
448450
)
449451

@@ -536,15 +538,13 @@ def set_compute_type(self, x):
536538
if self.compute_dtype in [None, torch.float32] and (x.numel() == x.shape[-1]):
537539
# single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
538540
# warn the user about this
539-
warnings.warn(
541+
logger.warning(
540542
"Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.",
541543
)
542-
warnings.filterwarnings("ignore", message=".*inference.")
543544
if self.compute_dtype in [None, torch.float32] and (x.numel() != x.shape[-1]):
544-
warnings.warn(
545+
logger.warning(
545546
"Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.",
546547
)
547-
warnings.filterwarnings("ignore", message=".*inference or training")
548548

549549
def _save_to_state_dict(self, destination, prefix, keep_vars):
550550
"""
@@ -877,7 +877,7 @@ def __init__(
877877
blocksize = self.weight.blocksize
878878

879879
if embedding_dim % blocksize != 0:
880-
warnings.warn(
880+
logger.warning(
881881
f"Embedding size {embedding_dim} is not divisible by block size {blocksize}. "
882882
"This will lead to slow inference.",
883883
)
@@ -1164,9 +1164,8 @@ def forward(self, x):
11641164
if self.outlier_dim is None:
11651165
tracer = OutlierTracer.get_instance()
11661166
if not tracer.is_initialized():
1167-
print("Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer")
1167+
logger.warning("Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer")
11681168
outlier_idx = tracer.get_outliers(self.weight)
1169-
# print(outlier_idx, tracer.get_hvalue(self.weight))
11701169
self.outlier_dim = outlier_idx
11711170

11721171
if not self.is_quantized:

bitsandbytes/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import json
2+
import logging
23
import shlex
34
import subprocess
45

56
import torch
67

8+
logger = logging.getLogger(__name__)
9+
710

811
def outlier_hook(module, input):
912
assert isinstance(module, torch.nn.Linear)
@@ -65,7 +68,7 @@ def get_hvalue(self, weight):
6568

6669
def get_outliers(self, weight):
6770
if not self.is_initialized():
68-
print("Outlier tracer is not initialized...")
71+
logger.warning("Outlier tracer is not initialized...")
6972
return None
7073
hvalue = self.get_hvalue(weight)
7174
if hvalue in self.hvalue2outlier_idx:

tests/test_modules.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import contextlib
12
import inspect
3+
import logging
24

35
import pytest
46
import torch
@@ -8,6 +10,12 @@
810
from tests.helpers import get_available_devices, id_formatter, is_supported_on_hpu
911

1012

13+
@contextlib.contextmanager
14+
def caplog_at_level(caplog, level, logger_name):
15+
with caplog.at_level(level, logger=logger_name):
16+
yield
17+
18+
1119
class MockArgs:
1220
def __init__(self, initial_data):
1321
for key in initial_data:
@@ -453,46 +461,38 @@ def test_embedding_error(device, embedding_class, input_shape, embedding_dim, qu
453461

454462

455463
@pytest.mark.parametrize("device", get_available_devices())
456-
def test_4bit_linear_warnings(device):
464+
def test_4bit_linear_warnings(device, caplog):
457465
dim1 = 64
458466

459-
with pytest.warns(UserWarning, match=r"inference or training"):
460-
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
461-
net = net.to(device)
462-
inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
463-
net(inp)
464-
with pytest.warns(UserWarning, match=r"inference."):
465-
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
466-
net = net.to(device)
467-
inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
468-
net(inp)
469-
470-
with pytest.warns(UserWarning) as record:
467+
with caplog_at_level(caplog, logging.WARNING, "bitsandbytes.nn.modules"):
471468
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
472469
net = net.to(device)
473470
inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
474471
net(inp)
472+
assert any("inference or training" in msg for msg in caplog.messages)
475473

474+
caplog.clear()
475+
with caplog_at_level(caplog, logging.WARNING, "bitsandbytes.nn.modules"):
476476
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
477477
net = net.to(device)
478478
inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
479479
net(inp)
480-
481-
assert len(record) == 2
480+
assert any("inference." in msg for msg in caplog.messages)
482481

483482

484483
@pytest.mark.parametrize("device", get_available_devices())
485-
def test_4bit_embedding_warnings(device):
484+
def test_4bit_embedding_warnings(device, caplog):
486485
num_embeddings = 128
487486
default_block_size = 64
488487

489-
with pytest.warns(UserWarning, match=r"inference."):
488+
with caplog_at_level(caplog, logging.WARNING, "bitsandbytes.nn.modules"):
490489
net = bnb.nn.Embedding4bit(
491490
num_embeddings=num_embeddings, embedding_dim=default_block_size + 1, quant_type="nf4"
492491
)
493492
net.to(device)
494493
inp = torch.randint(low=0, high=num_embeddings, size=(1,), device=device)
495494
net(inp)
495+
assert any("inference" in msg for msg in caplog.messages)
496496

497497

498498
def test_4bit_embedding_weight_fsdp_fix(requires_cuda):

Comments (0) — 0 commit comments