Commit b6bfbeb

fix optimizer
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 1baac13 commit b6bfbeb

1 file changed

Lines changed: 2 additions & 1 deletion

bitsandbytes/backends/triton/kernels_optim.py

@@ -123,7 +123,8 @@ def _optimizer_precondition_1state_32bit(
 
     if OPTIMIZER_ID == 0:  # MOMENTUM
         if step == 1:
-            s1_vals = g_vals
+            # Cast to fp32 to avoid type mismatch: s1_vals is fp32 but g_vals may be fp16.
+            s1_vals = g_vals.to(tl.float32)
         else:
             s1_vals = s1_vals * beta1 + g_vals
         update_norm = s1_vals * s1_vals
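
Note on the change: the sketch below is a pure-Python model of why the cast matters, not the Triton kernel itself. The `Block` class, its `scale_add` helper, the `momentum_state` function, and the "fp32 wins" promotion rule are hypothetical stand-ins introduced only for illustration. It assumes, as the added code comment states, that the optimizer state is fp32 while gradients may be fp16, and it shows that without the cast the two branches leave `s1_vals` with different dtypes.

```python
from dataclasses import dataclass


# Hypothetical stand-in for a Triton block: tracks only a dtype tag and values.
@dataclass
class Block:
    dtype: str  # "fp16" or "fp32"
    vals: list

    def to(self, dtype):
        # Model of a cast: same values, new dtype tag.
        return Block(dtype, list(self.vals))

    def scale_add(self, beta1, other):
        # Assumption of this model: mixed fp16/fp32 arithmetic promotes to fp32.
        dtype = "fp32" if "fp32" in (self.dtype, other.dtype) else "fp16"
        vals = [s * beta1 + g for s, g in zip(self.vals, other.vals)]
        return Block(dtype, vals)


def momentum_state(step, s1_vals, g_vals, beta1, cast_first_step):
    """Mirror the MOMENTUM branch from the diff for a single step."""
    if step == 1:
        # Before the fix the gradient is used as-is; after it, cast to fp32.
        return g_vals.to("fp32") if cast_first_step else g_vals
    return s1_vals.scale_add(beta1, g_vals)


state = Block("fp32", [0.0, 0.0])   # optimizer state, fp32
grad = Block("fp16", [0.5, -1.0])   # gradient, possibly fp16

# Collect the dtype of s1_vals produced by each branch (step 1 and step 2).
before_fix = {momentum_state(s, state, grad, 0.9, False).dtype for s in (1, 2)}
after_fix = {momentum_state(s, state, grad, 0.9, True).dtype for s in (1, 2)}

print(sorted(before_fix))  # ['fp16', 'fp32']: branches disagree
print(sorted(after_fix))   # ['fp32']: branches agree
```

In this model the cast makes both branches yield an fp32 state, which is the property the commit's code comment describes.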
