Skip to content

Commit 24bec2a

Browse files
authored
[fix] match code with ademamix
also define blocksize as in prev commit
1 parent b412c91 commit 24bec2a

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

bitsandbytes/optim/optimizer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -475,9 +475,9 @@ def init_state(self, group, p, gindex, pindex):
475475
state["qmap2"] = self.name2qmap["udynamic"]
476476

477477
if config["block_wise"]:
478+
blocksize = 256
478479
n = p.numel()
479-
blocks = n // 256
480-
blocks += 1 if n % 256 > 0 else 0
480+
blocks = (n // blocksize) + bool(n % blocksize)
481481

482482
state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
483483
state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
@@ -697,9 +697,9 @@ def init_state(self, group, p, gindex, pindex):
697697
state["qmap1"] = self.name2qmap["dynamic"]
698698

699699
if config["block_wise"]:
700+
blocksize = 256
700701
n = p.numel()
701-
blocks = n // 256
702-
blocks += 1 if n % 256 > 0 else 0
702+
blocks = (n // blocksize) + bool(n % blocksize)
703703

704704
state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
705705
else:

0 commit comments

Comments
 (0)