Skip to content

Commit 9ec7421

Browse files
committed
Fsdp integration
2 parents be6353f + 08e3ccd commit 9ec7421

17 files changed

Lines changed: 430 additions & 94 deletions

File tree

.github/scripts/build-cuda.sh

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,19 @@
22
declare build_arch
33
declare build_os
44
declare cuda_version
5+
declare cuda_targets
56

67
set -xeuo pipefail
78

8-
# By default, target Maxwell through Hopper.
9-
build_capability="50;52;60;61;70;75;80;86;89;90"
9+
if [[ -v cuda_targets ]]; then
10+
build_capability="${cuda_targets}"
11+
else
12+
# By default, target Maxwell through Hopper.
13+
build_capability="50;52;60;61;70;75;80;86;89;90"
1014

11-
# CUDA 12.8: Add sm100 and sm120; remove < sm75 to align with PyTorch 2.7+cu128 minimum
12-
[[ "${cuda_version}" == 12.8.* ]] && build_capability="75;80;86;89;90;100;120"
15+
# CUDA 12.8: Add sm100 and sm120; remove < sm75 to align with PyTorch 2.7+cu128 minimum
16+
[[ "${cuda_version}" == 12.8.* ]] && build_capability="75;80;86;89;90;100;120"
17+
fi
1318

1419
[[ "${build_os}" = windows-* ]] && python3 -m pip install ninja
1520

.github/workflows/tests.yml

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
name: Unit tests
2+
3+
on:
4+
workflow_dispatch:
5+
schedule:
6+
# Every day at 02:15 AM UTC
7+
- cron: "15 2 * * *"
8+
push:
9+
branches: [testing-ci]
10+
11+
concurrency:
12+
group: ${{ github.workflow }}-${{ github.ref }}
13+
cancel-in-progress: true
14+
15+
jobs:
16+
17+
build-cpu:
18+
strategy:
19+
matrix:
20+
os: [ubuntu-22.04, windows-2025]
21+
arch: [x86_64]
22+
runs-on: ${{ matrix.os }}
23+
steps:
24+
- uses: actions/checkout@v4
25+
26+
- name: Setup MSVC
27+
if: startsWith(matrix.os, 'windows')
28+
uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl
29+
30+
- name: Build C++
31+
run: bash .github/scripts/build-cpu.sh
32+
env:
33+
build_os: ${{ matrix.os }}
34+
build_arch: ${{ matrix.arch }}
35+
36+
- name: Upload build artifact
37+
uses: actions/upload-artifact@v4
38+
with:
39+
name: lib_cpu_${{ matrix.os }}_${{ matrix.arch }}
40+
path: output/${{ matrix.os }}/${{ matrix.arch }}/*
41+
retention-days: 7
42+
43+
build-cuda:
44+
strategy:
45+
matrix:
46+
cuda_version: ["11.8.0", "12.8.1"]
47+
os: [ubuntu-22.04, windows-2025]
48+
arch: [x86_64]
49+
runs-on: ${{ matrix.os }}
50+
51+
steps:
52+
- uses: actions/checkout@v4
53+
54+
- name: Install CUDA Toolkit
55+
uses: Jimver/cuda-toolkit@v0.2.23
56+
if: startsWith(matrix.os, 'windows')
57+
id: cuda-toolkit
58+
with:
59+
cuda: ${{ matrix.cuda_version }}
60+
method: "network"
61+
sub-packages: '["nvcc","cudart","cusparse","cublas","thrust","nvrtc_dev","cublas_dev","cusparse_dev"]'
62+
use-github-cache: false
63+
64+
- name: Setup MSVC
65+
if: startsWith(matrix.os, 'windows')
66+
uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl
67+
68+
# We're running on T4 only for now, so we only target sm75.
69+
- name: Build C++ / CUDA
70+
run: bash .github/scripts/build-cuda.sh
71+
env:
72+
build_os: ${{ matrix.os }}
73+
build_arch: x86_64
74+
cuda_version: ${{ matrix.cuda_version }}
75+
cuda_targets: "75"
76+
77+
- name: Upload build artifact
78+
uses: actions/upload-artifact@v4
79+
with:
80+
name: lib_cuda_${{matrix.cuda_version}}_${{ matrix.os }}_${{ matrix.arch }}
81+
path: output/${{ matrix.os }}/${{ matrix.arch }}/*
82+
retention-days: 7
83+
84+
cpu-tests:
85+
needs: build-cpu
86+
strategy:
87+
fail-fast: false
88+
matrix:
89+
os: [ubuntu-22.04, windows-2025]
90+
arch: [x86_64]
91+
torch_version: ["2.7.0"]
92+
runs-on: ${{ matrix.os }}
93+
env:
94+
BNB_TEST_DEVICE: cpu
95+
steps:
96+
- uses: actions/checkout@v4
97+
98+
- name: Download build artifact
99+
uses: actions/download-artifact@v4
100+
with:
101+
name: lib_cpu_${{ matrix.os }}_${{ matrix.arch }}
102+
path: bitsandbytes/
103+
merge-multiple: true
104+
105+
- name: Setup Python
106+
uses: actions/setup-python@v5
107+
with:
108+
python-version: 3.9
109+
110+
- name: Install dependencies
111+
run: |
112+
pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/cpu
113+
pip install -e ".[test]"
114+
pip install pytest-cov
115+
116+
- name: Show installed packages
117+
run: pip list
118+
119+
- name: Run tests
120+
run: pytest
121+
122+
cuda-tests:
123+
needs: build-cuda
124+
strategy:
125+
fail-fast: false
126+
matrix:
127+
os: [ubuntu-22.04, windows-2025]
128+
arch: [x86_64]
129+
cuda_version: ["11.8.0", "12.8.1"]
130+
include:
131+
- cuda_version: "11.8.0"
132+
torch_version: "2.4.1"
133+
pypi_index: "https://download.pytorch.org/whl/cu118"
134+
- cuda_version: "12.8.1"
135+
torch_version: "2.7.0"
136+
pypi_index: "https://download.pytorch.org/whl/cu128"
137+
exclude:
138+
# Our current T4 Windows runner has a driver too old (471.11)
139+
# and cannot support CUDA 12+. Skip for now.
140+
- os: windows-2025
141+
cuda_version: "12.8.1"
142+
runs-on:
143+
labels: ${{ contains(matrix.os, 'windows') && 'CUDA-Windows-x64' || 'CUDA-Linux-x64' }}
144+
env:
145+
BNB_TEST_DEVICE: cuda
146+
steps:
147+
- name: Show GPU Information
148+
run: nvidia-smi
149+
150+
- uses: actions/checkout@v4
151+
152+
- name: Download build artifact
153+
uses: actions/download-artifact@v4
154+
with:
155+
name: lib_cuda_${{ matrix.cuda_version }}_${{ matrix.os }}_${{ matrix.arch }}
156+
path: bitsandbytes/
157+
merge-multiple: true
158+
159+
- name: Setup Python
160+
uses: actions/setup-python@v5
161+
with:
162+
python-version: 3.9
163+
164+
- name: Install dependencies
165+
run: |
166+
pip install torch==${{ matrix.torch_version }} --index-url ${{ matrix.pypi_index }}
167+
pip install -e ".[test]"
168+
pip install pytest-cov
169+
170+
- name: Show installed packages
171+
run: pip list
172+
173+
- name: Run tests
174+
run: pytest

bitsandbytes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
# This is a signal for integrations with transformers/diffusers.
2323
# Eventually we may remove this but it is currently required for compatibility.
24-
features = {"multi-backend"}
24+
features = {"multi_backend"}
2525
supported_torch_devices = {
2626
"cpu",
2727
"cuda", # NVIDIA/AMD GPU

bitsandbytes/autograd/_functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,8 @@ def forward(
236236
ctx.state = state
237237

238238
ctx.grad_shape = input_shape
239-
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
239+
ctx.dtype_A = A.dtype
240+
ctx.dtype_bias = None if bias is None else bias.dtype
240241

241242
if any(ctx.needs_input_grad[:2]):
242243
ctx.tensors = (CAt, subA, A)

bitsandbytes/backends/cpu/ops.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Sequence
12
import ctypes as ct
23
from typing import Optional
34

@@ -119,6 +120,10 @@ def _(
119120
) -> tuple[torch.Tensor, torch.Tensor]:
120121
torch._check_is_size(blocksize)
121122
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
123+
torch._check(
124+
A.dtype in [torch.bfloat16, torch.float16, torch.float32],
125+
lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
126+
)
122127

123128
n = A.numel()
124129

@@ -140,3 +145,73 @@ def _(
140145
packed = packed.squeeze().view(quant_storage).unsqueeze(1)
141146

142147
return packed, absmax.float()
148+
149+
150+
@register_kernel("bitsandbytes::dequantize_4bit", "cpu")
151+
def _(
152+
A: torch.Tensor,
153+
absmax: torch.Tensor,
154+
blocksize: int,
155+
quant_type: str,
156+
shape: Sequence[int],
157+
dtype: torch.dtype,
158+
) -> torch.Tensor:
159+
torch._check_is_size(blocksize)
160+
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
161+
torch._check(
162+
dtype in [torch.bfloat16, torch.float16, torch.float32],
163+
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
164+
)
165+
torch._check(
166+
A.dtype == torch.uint8,
167+
lambda: f"Blockwise 4bit dequantization on CPU only supports uint8 storage, got {A.dtype}",
168+
)
169+
170+
A = A.view(-1, 1)
171+
172+
# Grab upper and lower nibbles. Using int64 for indexing in the LUT.
173+
upper = (A >> 4).to(torch.int64)
174+
lower = (A & 0x0F).to(torch.int64)
175+
176+
# Expand to blocks
177+
blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize)
178+
179+
# Dequantize
180+
blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None]
181+
182+
# Reshape to original shape
183+
blocks = blocks.reshape(-1, *shape[1:])
184+
185+
return blocks.to(dtype)
186+
187+
188+
@register_kernel("bitsandbytes::gemv_4bit", "cpu")
189+
def _(
190+
A: torch.Tensor,
191+
B: torch.Tensor,
192+
shapeB: Sequence[int],
193+
absmax: torch.Tensor,
194+
code: torch.Tensor,
195+
blocksize: int,
196+
) -> torch.Tensor:
197+
# TODO: We need to determine whether `code` is NF4, FP4, or other.
198+
# Right now we assume NF4, as this is the only one supported on CPU.
199+
200+
B_dq = torch.ops.bitsandbytes.dequantize_4bit.default(
201+
B,
202+
absmax,
203+
blocksize,
204+
"nf4",
205+
shape=shapeB,
206+
dtype=A.dtype,
207+
)
208+
209+
# User called gemv with B.t(), so we need to transpose it back.
210+
# if B.shape[0] == 1:
211+
# B_dq = B_dq.t()
212+
213+
return torch.nn.functional.linear(
214+
A,
215+
B_dq,
216+
bias=None,
217+
)

bitsandbytes/backends/cuda/ops.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -22,45 +22,6 @@ def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
2222
_int8_linear_matmul_impl(A, B, out)
2323

2424

25-
@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "cuda")
26-
def _(
27-
A: torch.Tensor,
28-
CA: torch.Tensor,
29-
CB: torch.Tensor,
30-
SCA: torch.Tensor,
31-
SCB: torch.Tensor,
32-
outlier_cols: Optional[torch.Tensor] = None,
33-
bias: Optional[torch.Tensor] = None,
34-
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
35-
subB = None
36-
37-
if outlier_cols is not None and outlier_cols.numel():
38-
# Extract the inputs with outliers in original precision
39-
subA = A[:, outlier_cols].contiguous()
40-
41-
# Dequantize the corresponding weight columns
42-
subB = (
43-
torch.ops.bitsandbytes.int8_vectorwise_dequant.default(CB[:, outlier_cols].contiguous(), SCB)
44-
.to(A.dtype)
45-
.t()
46-
)
47-
48-
# TODO: if state.has_fp16_weights: subB = B[:, outlier_cols].t()
49-
50-
else:
51-
# Needed for torch.compile when there are no outliers.
52-
subA = torch.empty(0, device=A.device, dtype=A.dtype)
53-
54-
# Int8 Matmul + Dequant + Bias
55-
output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype)
56-
57-
if subB is not None:
58-
# Add the outlier columns back to the output
59-
output = output.addmm(subA, subB)
60-
61-
return output, subA
62-
63-
6425
def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
6526
A, B = B, A
6627

0 commit comments

Comments
 (0)