@@ -3,8 +3,8 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
+import logging
 from typing import Any, Optional, TypeVar, Union, overload
-import warnings
 
 import torch
 from torch import Tensor, device, dtype, nn
@@ -20,6 +20,8 @@
 from bitsandbytes.optim import GlobalOptimManager
 from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer
 
+logger = logging.getLogger(__name__)
+
 T = TypeVar("T", bound="torch.nn.Module")
 
 
@@ -443,7 +445,7 @@ def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
         return
 
     if getattr(module, "quant_state", None) is None:
-        warnings.warn(
+        logger.warning(
             "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
         )
 
@@ -536,15 +538,13 @@ def set_compute_type(self, x):
             if self.compute_dtype in [None, torch.float32] and (x.numel() == x.shape[-1]):
                 # single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
                 # warn the user about this
-                warnings.warn(
+                logger.warning(
                     "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.",
                 )
-                warnings.filterwarnings("ignore", message=".*inference.")
             if self.compute_dtype in [None, torch.float32] and (x.numel() != x.shape[-1]):
-                warnings.warn(
+                logger.warning(
                     "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.",
                 )
-                warnings.filterwarnings("ignore", message=".*inference or training")
 
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         """
@@ -877,7 +877,7 @@ def __init__(
         blocksize = self.weight.blocksize
 
         if embedding_dim % blocksize != 0:
-            warnings.warn(
+            logger.warning(
                 f"Embedding size {embedding_dim} is not divisible by block size {blocksize}. "
                 "This will lead to slow inference.",
             )
@@ -1164,9 +1164,8 @@ def forward(self, x):
         if self.outlier_dim is None:
             tracer = OutlierTracer.get_instance()
             if not tracer.is_initialized():
-                print("Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer")
+                logger.warning("Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer")
             outlier_idx = tracer.get_outliers(self.weight)
-            # print(outlier_idx, tracer.get_hvalue(self.weight))
             self.outlier_dim = outlier_idx
 
         if not self.is_quantized:
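Because the module-level logger is created with `logging.getLogger(__name__)`, downstream code can now control these messages through the standard logging hierarchy instead of the global warning filters the removed `warnings.filterwarnings("ignore", ...)` calls used to mutate. A minimal sketch of consumer-side configuration, assuming this file is `bitsandbytes/nn/modules.py` (so the logger name below matches):

```python
import logging

# Surface warnings from bitsandbytes layers via the root handler.
logging.basicConfig(level=logging.WARNING)

# Or silence just the slow-inference hints from this module without
# touching Python's global warning filters. The logger name assumes
# the file lives at bitsandbytes/nn/modules.py.
logging.getLogger("bitsandbytes.nn.modules").setLevel(logging.ERROR)
```

One behavioral difference worth noting: `warnings.warn` deduplicates repeated messages per call site by default, while `logger.warning` emits on every call unless the consumer filters it.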