Skip to content

Commit aefc6bf

Browse files
committed
fix lint
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent d7c6ef0 commit aefc6bf

File tree

6 files changed

+44
-27
lines changed

6 files changed

+44
-27
lines changed

bitsandbytes/backends/cpu/ops.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,10 @@ def _(
322322

323323

324324
def _compute_update_norm_and_scale(
325-
update: torch.Tensor, unorm_vec: Optional[torch.Tensor], max_unorm: float, param_norm: float,
325+
update: torch.Tensor,
326+
unorm_vec: Optional[torch.Tensor],
327+
max_unorm: float,
328+
param_norm: float,
326329
) -> float:
327330
"""Compute trust-ratio scaling factor for LAMB/LARS and store update norm."""
328331
if max_unorm <= 0.0:
@@ -446,26 +449,35 @@ def _optimizer_update_32bit_cpu(
446449

447450

448451
@torch.no_grad()
449-
def _dequant_blockwise_fp32_direct(A_uint8: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor,
450-
blocksize: int) -> torch.Tensor:
452+
def _dequant_blockwise_fp32_direct(
453+
A_uint8: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int
454+
) -> torch.Tensor:
451455
"""Dequantize blockwise via direct C lib call, avoiding torch.ops dispatch overhead."""
452456
n = A_uint8.numel()
453457
out = torch.empty(n, dtype=torch.float32, device=A_uint8.device)
454458
lib.cdequantize_blockwise_cpu_fp32(
455-
get_ptr(code), get_ptr(A_uint8.reshape(-1)), get_ptr(absmax),
456-
get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(n),
459+
get_ptr(code),
460+
get_ptr(A_uint8.reshape(-1)),
461+
get_ptr(absmax),
462+
get_ptr(out),
463+
ct.c_longlong(blocksize),
464+
ct.c_longlong(n),
457465
)
458466
return out.reshape(A_uint8.shape)
459467

460468

461-
def _quant_blockwise_fp32_direct(A_fp32: torch.Tensor, code: torch.Tensor,
462-
absmax_out: torch.Tensor, out_uint8: torch.Tensor,
463-
blocksize: int) -> None:
469+
def _quant_blockwise_fp32_direct(
470+
A_fp32: torch.Tensor, code: torch.Tensor, absmax_out: torch.Tensor, out_uint8: torch.Tensor, blocksize: int
471+
) -> None:
464472
"""Quantize blockwise via direct C lib call, writing into existing buffers (zero-alloc)."""
465473
n = A_fp32.numel()
466474
lib.cquantize_blockwise_cpu_fp32(
467-
get_ptr(code), get_ptr(A_fp32.reshape(-1)), get_ptr(absmax_out),
468-
get_ptr(out_uint8.reshape(-1)), ct.c_longlong(blocksize), ct.c_longlong(n),
475+
get_ptr(code),
476+
get_ptr(A_fp32.reshape(-1)),
477+
get_ptr(absmax_out),
478+
get_ptr(out_uint8.reshape(-1)),
479+
ct.c_longlong(blocksize),
480+
ct.c_longlong(n),
469481
)
470482

471483

bitsandbytes/optim/optimizer.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,20 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5-
import logging
65
from collections import abc as container_abcs, defaultdict
76
from copy import deepcopy
87
from itertools import chain
8+
import logging
99
from typing import Optional
1010
import warnings
1111

1212
import torch
1313

14-
logger = logging.getLogger(__name__)
15-
1614
import bitsandbytes.functional as F
1715
from bitsandbytes.utils import sync_gpu
1816

17+
logger = logging.getLogger(__name__)
18+
1919

2020
class MockArgs:
2121
def __init__(self, initial_data):
@@ -375,8 +375,7 @@ def get_state_buffer(self, p, dtype=torch.float32):
375375
if p.device.type == "cpu":
376376
if self.is_paged and not getattr(self, "_cpu_paged_warned", False):
377377
warnings.warn(
378-
"Paged optimizers are not supported on CPU. "
379-
"Falling back to non-paged optimizer behavior.",
378+
"Paged optimizers are not supported on CPU. Falling back to non-paged optimizer behavior.",
380379
stacklevel=2,
381380
)
382381
self._cpu_paged_warned = True

csrc/cpu_ops.cpp

100755100644
Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -252,11 +252,8 @@ struct LUTCache {
252252
float fp[4];
253253
compute_fingerprint(code, fp);
254254
for (int i = 0; i < kLUTCacheSlots; ++i) {
255-
if (cached_codes[i] == code &&
256-
cached_fingerprints[i][0] == fp[0] &&
257-
cached_fingerprints[i][1] == fp[1] &&
258-
cached_fingerprints[i][2] == fp[2] &&
259-
cached_fingerprints[i][3] == fp[3]) {
255+
if (cached_codes[i] == code && cached_fingerprints[i][0] == fp[0] && cached_fingerprints[i][1] == fp[1] &&
256+
cached_fingerprints[i][2] == fp[2] && cached_fingerprints[i][3] == fp[3]) {
260257
return luts[i];
261258
}
262259
}

csrc/cpu_ops.h

100755100644
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ static inline float fp16_to_float(uint16_t h) {
196196

197197
if (exp == 0) {
198198
if (mant == 0) {
199-
bits = sign << 31; // zero
199+
bits = sign << 31; // zero
200200
} else {
201201
// subnormal fp16 -> normal fp32
202202
exp = 1;
@@ -208,7 +208,7 @@ static inline float fp16_to_float(uint16_t h) {
208208
bits = (sign << 31) | ((exp + 127 - 15) << 23) | (mant << 13);
209209
}
210210
} else if (exp == 0x1F) {
211-
bits = (sign << 31) | (0xFF << 23) | (mant ? (mant << 13) : 0); // Inf or NaN
211+
bits = (sign << 31) | (0xFF << 23) | (mant ? (mant << 13) : 0); // Inf or NaN
212212
} else {
213213
bits = (sign << 31) | ((exp + 127 - 15) << 23) | (mant << 13);
214214
}

examples/cpu/cpu_training.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,10 @@ def run_single(args):
164164

165165
ds = prepare_data(tokenizer, args.dataset, args.max_length)
166166
dataloader = torch.utils.data.DataLoader(
167-
ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn,
167+
ds,
168+
batch_size=args.batch_size,
169+
shuffle=True,
170+
collate_fn=collate_fn,
168171
)
169172

170173
optimizer = create_optimizer(model, args.optimizer, args.lr)
@@ -198,7 +201,10 @@ def run_compare(args):
198201

199202
ds = prepare_data(tokenizer, args.dataset, args.max_length, num_samples=100)
200203
dataloader = torch.utils.data.DataLoader(
201-
ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn,
204+
ds,
205+
batch_size=args.batch_size,
206+
shuffle=False,
207+
collate_fn=collate_fn,
202208
)
203209

204210
results = {}
@@ -362,4 +368,3 @@ def main():
362368
# Training runtime: 3.2s
363369
# Steps/sec: 9.5
364370
# Optimizer: bnb.optim.adamw8bit | Dtype: bf16
365-

tests/test_optim.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,8 +425,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name, device):
425425
else:
426426
assert err.mean() < 0.00006
427427
# Lion on CPU fp16 has slightly higher relative error due to sign-based updates at boundary
428-
relerr_thr = 0.00062 if (device == "cpu" and optim_name == "lion8bit_blockwise" and gtype == torch.float16) else 0.0006
429-
assert relerr.mean() < relerr_thr
428+
relerr_the = (
429+
0.00062
430+
if (device == "cpu" and optim_name == "lion8bit_blockwise" and gtype == torch.float16)
431+
else 0.0006
432+
)
433+
assert relerr.mean() < relerr_the
430434

431435
errors.append(err.mean().item())
432436
relerrors.append(relerr.mean().item())

0 commit comments

Comments
 (0)