Commit cc6d1bc — "use logging"
Parent: 75ef76b
3 files changed: 16 additions, 11 deletions

File tree

bitsandbytes/autograd/_functions.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+import logging
 from math import prod
 from typing import Optional
 import warnings
@@ -8,6 +9,8 @@

 import bitsandbytes.functional as F

+logger = logging.getLogger(__name__)
+
 # The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
 # https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py

@@ -123,7 +126,7 @@ def forward(

         # Cast A to fp16
         if A.dtype != torch.float16 and not _is_compiling():
-            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
+            logger.warning("MatMul8bitLt: inputs will be cast from %s to float16 during quantization", A.dtype)

         if len(A.shape) == 3:
             A = A.reshape(-1, A.shape[-1])
```
(Note: indentation inside hunk bodies is reconstructed — the scrape dropped leading whitespace.)

bitsandbytes/nn/modules.py

Lines changed: 8 additions & 9 deletions
```diff
@@ -3,8 +3,8 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
+import logging
 from typing import Any, Optional, TypeVar, Union, overload
-import warnings

 import torch
 from torch import Tensor, device, dtype, nn
@@ -20,6 +20,8 @@
 from bitsandbytes.optim import GlobalOptimManager
 from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer

+logger = logging.getLogger(__name__)
+
 T = TypeVar("T", bound="torch.nn.Module")


@@ -443,7 +445,7 @@ def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Line
         return

     if getattr(module, "quant_state", None) is None:
-        warnings.warn(
+        logger.warning(
             "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
         )

@@ -536,15 +538,13 @@ def set_compute_type(self, x):
         if self.compute_dtype in [None, torch.float32] and (x.numel() == x.shape[-1]):
             # single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
             # warn the user about this
-            warnings.warn(
+            logger.warning(
                 "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.",
             )
-            warnings.filterwarnings("ignore", message=".*inference.")
         if self.compute_dtype in [None, torch.float32] and (x.numel() != x.shape[-1]):
-            warnings.warn(
+            logger.warning(
                 "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.",
             )
-            warnings.filterwarnings("ignore", message=".*inference or training")

     def _save_to_state_dict(self, destination, prefix, keep_vars):
         """
@@ -877,7 +877,7 @@ def __init__(
         blocksize = self.weight.blocksize

         if embedding_dim % blocksize != 0:
-            warnings.warn(
+            logger.warning(
                 f"Embedding size {embedding_dim} is not divisible by block size {blocksize}. "
                 "This will lead to slow inference.",
             )
@@ -1164,9 +1164,8 @@ def forward(self, x):
         if self.outlier_dim is None:
             tracer = OutlierTracer.get_instance()
             if not tracer.is_initialized():
-                print("Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer")
+                logger.warning("Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer")
             outlier_idx = tracer.get_outliers(self.weight)
-            # print(outlier_idx, tracer.get_hvalue(self.weight))
             self.outlier_dim = outlier_idx

         if not self.is_quantized:
```
(Note: indentation inside hunk bodies is reconstructed — the scrape dropped leading whitespace.)

bitsandbytes/utils.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -1,9 +1,12 @@
 import json
+import logging
 import shlex
 import subprocess

 import torch

+logger = logging.getLogger(__name__)
+

 def outlier_hook(module, input):
     assert isinstance(module, torch.nn.Linear)
@@ -65,7 +68,7 @@ def get_hvalue(self, weight):

     def get_outliers(self, weight):
         if not self.is_initialized():
-            print("Outlier tracer is not initialized...")
+            logger.warning("Outlier tracer is not initialized...")
             return None
         hvalue = self.get_hvalue(weight)
         if hvalue in self.hvalue2outlier_idx:
```
(Note: indentation inside hunk bodies is reconstructed — the scrape dropped leading whitespace.)

0 commit comments