Commit 0cafa7f

Specify blocksize (#1586)
* [fix] define blocksize: showing just the number is a bit confusing
* [fix] match code with ademamix: also define blocksize, as in the previous commit
1 parent: 12c4096

2 files changed: 6 additions & 5 deletions

bitsandbytes/optim/ademamix.py

Lines changed: 2 additions & 1 deletion
@@ -166,8 +166,9 @@ def init_state(self, group, p, gindex, pindex):
         self.name2qmap["dynamic"] = state["qmap1"] = self.name2qmap["dynamic"].to(p.device)
         self.name2qmap["udynamic"] = state["qmap2"] = self.name2qmap["udynamic"].to(p.device)

+        blocksize = 256
         n = p.numel()
-        blocks = (n // 256) + bool(n % 256)
+        blocks = (n // blocksize) + bool(n % blocksize)

         state["absmax1"] = torch.zeros((2, blocks), dtype=torch.float32, device=p.device)
         state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)

bitsandbytes/optim/optimizer.py

Lines changed: 4 additions & 4 deletions
@@ -475,9 +475,9 @@ def init_state(self, group, p, gindex, pindex):
             state["qmap2"] = self.name2qmap["udynamic"]

             if config["block_wise"]:
+                blocksize = 256
                 n = p.numel()
-                blocks = n // 256
-                blocks += 1 if n % 256 > 0 else 0
+                blocks = (n // blocksize) + bool(n % blocksize)

                 state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
                 state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)

@@ -697,9 +697,9 @@ def init_state(self, group, p, gindex, pindex):
             state["qmap1"] = self.name2qmap["dynamic"]

             if config["block_wise"]:
+                blocksize = 256
                 n = p.numel()
-                blocks = n // 256
-                blocks += 1 if n % 256 > 0 else 0
+                blocks = (n // blocksize) + bool(n % blocksize)

                 state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
             else:
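Both hunks in optimizer.py replace the older two-step rounding (integer division followed by a conditional increment) with the single expression already used in ademamix.py. The two forms are equivalent; a quick check, assuming nothing beyond the arithmetic shown in the diff:

blocksize = 256
for n in range(10 * blocksize + 1):
    old = n // 256 + (1 if n % 256 > 0 else 0)    # removed form
    new = (n // blocksize) + bool(n % blocksize)  # added form
    assert old == new                             # same block count for every n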
