Skip to content

Commit e25e0ab

Browse files
wip on torch flatten unflatten, etc
1 parent f09812d commit e25e0ab

File tree

1 file changed

+67
-11
lines changed

1 file changed

+67
-11
lines changed

bitsandbytes/nn/modules.py

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,21 @@ def __new__(
222222
if data is None:
223223
data = torch.empty(0)
224224

225+
# Handle FakeTensor creation during dynamo tracing
226+
if torch._dynamo.is_compiling() and not isinstance(data, cls):
227+
if isinstance(data, torch._subclasses.FakeTensor):
228+
param = data.as_subclass(cls)
229+
param.requires_grad = requires_grad
230+
param.quant_state = quant_state
231+
param.blocksize = blocksize
232+
param.compress_statistics = compress_statistics
233+
param.quant_type = quant_type
234+
param.quant_storage = quant_storage
235+
param.module = module
236+
param.bnb_quantized = bnb_quantized
237+
return param
238+
239+
# Standard initialization for real tensors
225240
self = torch.Tensor._make_subclass(cls, data, requires_grad)
226241
self.blocksize = blocksize
227242
self.compress_statistics = compress_statistics
@@ -324,26 +339,23 @@ def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ...
324339
def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
325340

326341
def to(self, *args, **kwargs):
    """Move/cast the parameter, quantizing on the first transfer to CUDA.

    Mirrors ``torch.nn.Parameter.to``: accepts any of the standard
    ``to(...)`` calling conventions (device, dtype, tensor, ...).

    Returns:
        Params4bit: ``self._quantize(device)`` on the first move to a CUDA
        device while still unquantized; otherwise a new ``Params4bit``
        wrapping the moved/cast data with the full quantization metadata
        (including ``module`` and ``bnb_quantized``) carried over.
    """
    # ``convert_to_format`` (memory format) is parsed but intentionally unused.
    device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)

    if device is not None and device.type == "cuda" and not self.bnb_quantized:
        # First CUDA transfer of raw weights triggers 4-bit quantization.
        return self._quantize(device)

    # NOTE(review): quant_state.to(device) moves the state's tensors in
    # place (the pre-diff code called it as a bare statement). Do NOT use
    # its return value as the new quant_state — if it returns None the
    # quantization state would silently be dropped.
    if self.quant_state is not None:
        self.quant_state.to(device)

    return Params4bit(
        super().to(device=device, dtype=dtype, non_blocking=non_blocking),
        requires_grad=self.requires_grad,
        quant_state=self.quant_state,
        blocksize=self.blocksize,
        compress_statistics=self.compress_statistics,
        quant_type=self.quant_type,
        quant_storage=self.quant_storage,
        module=self.module,
        bnb_quantized=self.bnb_quantized,
    )
347359
def __tensor_flatten__(self):
348360
"""Return data tensor and non-tensor context"""
349361
ctx = {
@@ -361,6 +373,20 @@ def __tensor_flatten__(self):
361373
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
362374
"""Reconstruct Params4bit from components"""
363375
data = inner_tensors["data"]
376+
377+
# Special handling for FakeTensor reconstruction
378+
if isinstance(data, torch._subclasses.FakeTensor):
379+
param = data.as_subclass(Params4bit)
380+
param.blocksize = ctx["blocksize"]
381+
param.compress_statistics = ctx["compress_statistics"]
382+
param.quant_type = ctx["quant_type"]
383+
param.quant_state = ctx["quant_state"]
384+
param.quant_storage = ctx["quant_storage"]
385+
param.module = ctx["module"]
386+
param.bnb_quantized = ctx["bnb_quantized"]
387+
return param
388+
389+
# Standard reconstruction for real tensors
364390
return Params4bit(
365391
data,
366392
requires_grad=data.requires_grad,
@@ -373,6 +399,21 @@ def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
373399
bnb_quantized=ctx["bnb_quantized"],
374400
)
375401

402+
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    """Dispatch ``func`` through the default handler, then re-wrap.

    Plain ``torch.Tensor`` results are re-subclassed to ``cls`` so that
    ops on this parameter type keep producing instances of it.
    """
    out = super().__torch_function__(func, types, args, {} if kwargs is None else kwargs)
    # Preserve the subclass on tensor outputs that lost it during dispatch.
    needs_rewrap = isinstance(out, torch.Tensor) and not isinstance(out, cls)
    return out.as_subclass(cls) if needs_rewrap else out
409+
410+
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    """Dispatch ``func``, special-casing FakeTensor inputs.

    During dynamo/meta tracing the arguments may be FakeTensors; in that
    case the op is executed directly so the active fake-tensor mode
    produces the (fake) result.
    """
    # BUG FIX: the previous code returned
    # ``torch._C.DispatchKey.Fake(func(...))`` — ``DispatchKey.Fake`` is an
    # enum member, not a callable wrapper, so that expression raised
    # TypeError whenever the fake branch was taken. Return the op result
    # directly instead.
    if any(isinstance(a, torch._subclasses.FakeTensor) for a in args):
        return func(*args, **(kwargs or {}))
    return super().__torch_dispatch__(func, types, args, kwargs)
416+
376417
def detach(self):
377418
"""Create new instance preserving quantization state"""
378419
return type(self)(
@@ -460,21 +501,36 @@ def __init__(
460501
bias (`bool`, defaults to `True`):
461502
Whether the linear class uses the bias term as well.
462503
"""
463-
super().__init__(input_features, output_features, bias, device)
504+
# Bypass nn.Linear's parameter initialization
505+
super(nn.Linear, self).__init__()
506+
self.in_features = input_features
507+
self.out_features = output_features
508+
509+
# Manually register parameters
464510
self.weight = Params4bit(
465-
self.weight.data,
511+
torch.empty((output_features, input_features), dtype=quant_storage),
466512
requires_grad=False,
467513
compress_statistics=compress_statistics,
468514
quant_type=quant_type,
469515
quant_storage=quant_storage,
470516
module=self,
471517
)
472-
# self.persistent_buffers = [] # TODO consider as way to save quant state
518+
519+
if bias:
520+
self.bias = nn.Parameter(torch.empty(output_features))
521+
else:
522+
self.register_parameter("bias", None)
523+
524+
self.reset_parameters()
473525
self.compute_dtype = compute_dtype
474526
self.compute_type_is_set = False
475527
self.quant_state = None
476528
self.quant_storage = quant_storage
477529

530+
def reset_parameters(self):
    """Intentionally a no-op.

    ``nn.Linear.reset_parameters`` would overwrite ``self.weight`` (a
    ``Params4bit`` holding quantized or quantization-ready storage) with
    freshly initialized values, so the default initialization is disabled
    here. Called from ``__init__`` after the parameters are registered.
    """
    # Disable standard initialization
    pass
533+
478534
def set_compute_type(self, x):
479535
if x.dtype in [torch.float32, torch.bfloat16]:
480536
# the input is in a dtype that is safe to compute in, we switch

0 commit comments

Comments
 (0)