Commit 49a768a

TimDettmers and claude committed
Remove deprecated APIs: research module, non-blockwise optimizers, and legacy quantization functions
Remove all remaining deprecated code that has been emitting FutureWarning since v0.45.0 (December 2024). Two prior cleanup rounds (v0.47.0, v0.49.0) already removed the easier items; this finishes the job.

- Delete quantize(), dequantize(), quantize_no_absmax(), dequantize_no_absmax(), optimizer_update_8bit(), percentile_clipping(), and the str2optimizer8bit dispatch table from functional.py
- Remove the non-blockwise 8-bit optimizer path from Optimizer2State and Optimizer1State; LAMB/LARS now use blockwise quantization
- Remove percentile_clipping and block_wise parameters from all ~33 optimizer class constructors
- Delete bitsandbytes/research/ (FP8 matmul, SwitchBack)
- Delete bitsandbytes/nn/triton_based_modules.py, SwitchBackLinearBnb, and the orphaned bitsandbytes/triton/ kernel directory
- Remove dead MatmulLtState fields (CxB, CxBt, formatB, _tile_indices)
- Delete test_deprecated.py, test_triton.py; clean test_autograd.py, test_optim.py, test_functional.py
- Remove benchmarking/switchback/ and update docs

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a2c92f7 commit 49a768a
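The deleted quantize()/dequantize() pair used a single global absmax scale per tensor, while the surviving blockwise path scales each block by its own absmax, bounding the quantization error per block. As a rough, self-contained sketch of the blockwise absmax idea (plain NumPy, not the actual bitsandbytes kernels; the function names here are illustrative only):

```python
import numpy as np

def quantize_blockwise_sketch(x, blocksize=4):
    """Quantize to int8 using one absmax scale per block of `blocksize` values."""
    x = np.asarray(x, dtype=np.float32)
    pad = (-x.size) % blocksize                    # pad so the data reshapes evenly
    blocks = np.pad(x, (0, pad)).reshape(-1, blocksize)
    absmax = np.abs(blocks).max(axis=1)            # one scale per block
    absmax = np.where(absmax == 0, 1.0, absmax)    # avoid division by zero
    q = np.round(blocks / absmax[:, None] * 127).astype(np.int8)
    return q, absmax, x.size

def dequantize_blockwise_sketch(q, absmax, n):
    """Invert the blockwise quantization and drop the padding."""
    out = q.astype(np.float32) / 127 * absmax[:, None]
    return out.reshape(-1)[:n]
```

The real bitsandbytes.functional.quantize_blockwise/dequantize_blockwise run fused CUDA kernels and return a QuantState, and the library additionally maps values through a quantization code rather than a plain linear grid, but the per-block absmax scaling is the same idea.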


43 files changed: +32 −3191 lines

benchmarking/switchback/README.md

Lines changed: 0 additions & 4 deletions
This file was deleted.

benchmarking/switchback/info_a100_py2.jsonl

Lines changed: 0 additions & 60 deletions
This file was deleted.

benchmarking/switchback/make_plot_with_jsonl.py

Lines changed: 0 additions & 151 deletions
This file was deleted.
Binary file (−34.1 KB) not shown. This file was deleted.

benchmarking/switchback/speed_benchmark.py

Lines changed: 0 additions & 160 deletions
This file was deleted.

bitsandbytes/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@

 import torch

-from . import _ops, research, utils
+from . import _ops, utils
 from .autograd._functions import (
     MatmulLtState,
     matmul,

bitsandbytes/autograd/_functions.py

Lines changed: 0 additions & 12 deletions
@@ -53,16 +53,12 @@ def get_current_outlier_idx(self):

 @dataclass
 class MatmulLtState:
-    _tile_indices: Optional[torch.Tensor] = None  # TODO: remove
-
     force_no_igemmlt: bool = False

     CB: Optional[torch.Tensor] = None
-    CxB: Optional[torch.Tensor] = None  # TODO: Deprecate/remove
     SB: Optional[torch.Tensor] = None
     SCB: Optional[torch.Tensor] = None

-    CxBt: Optional[torch.Tensor] = None  # TODO: Deprecate/remove
     SBt: Optional[torch.Tensor] = None
     CBt: Optional[torch.Tensor] = None

@@ -75,22 +71,15 @@ class MatmulLtState:
     is_training = True
     has_fp16_weights = True
     use_pool = False
-    formatB = "row"  # TODO: Deprecate/remove

     def reset_grads(self):
         self.CB = None
-        self.CxB = None
         self.SB = None
         self.SCB = None

-        self.CxBt = None
         self.SBt = None
         self.CBt = None

-    @property
-    def tile_indices(self):
-        raise ValueError("tile_indices is no longer supported.")
-

 class MatMul8bitLt(torch.autograd.Function):
     @staticmethod
@@ -293,7 +282,6 @@ def backward(ctx, grad_output):

 class MatMul4Bit(torch.autograd.Function):
     # forward is the same, but we added the fallback for pre-turing GPUs
-    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

     @staticmethod
     def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None):
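After these removals, the state object keeps only the live int8 buffers and flags. A minimal standalone sketch of the trimmed dataclass (field types loosened to Optional[object] so it runs without torch; this is an approximation for illustration, not the actual class definition):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MatmulLtState:
    # Approximate shape of the state after CxB, CxBt, formatB, and
    # _tile_indices were removed; tensor fields typed loosely here.
    force_no_igemmlt: bool = False
    CB: Optional[object] = None
    SB: Optional[object] = None
    SCB: Optional[object] = None
    SBt: Optional[object] = None
    CBt: Optional[object] = None
    is_training: bool = True
    has_fp16_weights: bool = True
    use_pool: bool = False

    def reset_grads(self):
        # Drop all cached quantized buffers; flags are left untouched.
        self.CB = None
        self.SB = None
        self.SCB = None
        self.SBt = None
        self.CBt = None
```

reset_grads() now clears exactly the five remaining buffers, with no dead branches for the removed Cx* fields.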

bitsandbytes/backends/utils.py

Lines changed: 1 addition & 2 deletions
@@ -4,9 +4,8 @@

 import torch

 try:
-    import triton.language as tl  # noqa: F401
-
     import triton  # noqa: F401
+    import triton.language as tl  # noqa: F401

     triton_available = True
 except ImportError:
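This hunk only reorders the imports so the parent package is tried before its submodule; either import raises ImportError when Triton is absent, and the flag pattern absorbs it. A generic sketch of that availability-flag pattern, using a deliberately nonexistent package (nonexistent_optional_dep is hypothetical):

```python
# Optional-dependency detection: attempt the import once at module load
# and record the outcome in a flag, instead of try/except at every call site.
try:
    import nonexistent_optional_dep  # noqa: F401  (hypothetical package)
    import nonexistent_optional_dep.sublang  # noqa: F401

    dep_available = True
except ImportError:
    dep_available = False
```

Callers then branch on the flag (e.g. skip a fast path when dep_available is False) rather than catching ImportError themselves.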

0 commit comments
