Skip to content

Commit 9edeb0f

Browse files
committed
required changes
1 parent e9cc3de commit 9edeb0f

2 files changed

Lines changed: 56 additions & 5 deletions

File tree

bitsandbytes/optim/adam.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
6+
import warnings
77
from bitsandbytes.optim.optimizer import Optimizer2State
88

99

@@ -76,6 +76,8 @@ def __init__(
7676
betas=(0.9, 0.999),
7777
eps=1e-8,
7878
weight_decay=0,
79+
amsgrad=False,
80+
optim_bits=32,
7981
args=None,
8082
min_8bit_size=4096,
8183
percentile_clipping=100,
@@ -96,6 +98,12 @@ def __init__(
9698
The epsilon value prevents division by zero in the optimizer.
9799
weight_decay (`float`, defaults to 0.0):
98100
The weight decay value for the optimizer.
101+
amsgrad (`bool`, defaults to `False`):
102+
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
103+
Note: This parameter is not supported in Adam8bit and must be False.
104+
optim_bits (`int`, defaults to 32):
105+
The number of bits of the optimizer state.
106+
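Note: Adam8bit always keeps its optimizer state in 8 bits, so this parameter exists only for signature compatibility; leave it at the default of 32, since any other value raises a `ValueError`.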
99107
args (`object`, defaults to `None`):
100108
An object with additional arguments.
101109
min_8bit_size (`int`, defaults to 4096):
@@ -107,14 +115,23 @@ def __init__(
107115
is_paged (`bool`, defaults to `False`):
108116
Whether the optimizer is a paged optimizer or not.
109117
"""
118+
# Validate unsupported parameters
119+
if amsgrad:
120+
raise ValueError("Adam8bit does not support amsgrad=True")
121+
122+
if optim_bits != 32:
123+
# We allow the default value of 32 to maintain compatibility with the function signature,
124+
# but any other value is invalid since Adam8bit always uses 8-bit optimization
125+
raise ValueError("Adam8bit only supports optim_bits=32 (default value for compatibility)")
126+
110127
super().__init__(
111128
"adam",
112129
params,
113130
lr,
114131
betas,
115132
eps,
116133
weight_decay,
117-
8,
134+
8, # Hardcoded to 8 bits
118135
args,
119136
min_8bit_size,
120137
percentile_clipping,
@@ -277,8 +294,10 @@ def __init__(
277294
The weight decay value for the optimizer.
278295
amsgrad (`bool`, defaults to `False`):
279296
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
297+
Note: This parameter is not supported in PagedAdam8bit and must be False.
280298
optim_bits (`int`, defaults to 32):
281299
The number of bits of the optimizer state.
300+
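Note: PagedAdam8bit always keeps its optimizer state in 8 bits, so this parameter exists only for signature compatibility; leave it at the default of 32, since any other value raises a `ValueError`.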
282301
args (`object`, defaults to `None`):
283302
An object with additional arguments.
284303
min_8bit_size (`int`, defaults to 4096):
@@ -290,14 +309,23 @@ def __init__(
290309
is_paged (`bool`, defaults to `False`):
291310
Whether the optimizer is a paged optimizer or not.
292311
"""
312+
# Validate unsupported parameters
313+
if amsgrad:
314+
raise ValueError("PagedAdam8bit does not support amsgrad=True")
315+
316+
if optim_bits != 32:
317+
# We allow the default value of 32 to maintain compatibility with the function signature,
318+
# but any other value is invalid since PagedAdam8bit always uses 8-bit optimization
319+
raise ValueError("PagedAdam8bit only supports optim_bits=32 (default value for compatibility)")
320+
293321
super().__init__(
294322
"adam",
295323
params,
296324
lr,
297325
betas,
298326
eps,
299327
weight_decay,
300-
8,
328+
8, # Hardcoded to 8 bits
301329
args,
302330
min_8bit_size,
303331
percentile_clipping,

bitsandbytes/optim/adamw.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55
from bitsandbytes.optim.optimizer import Optimizer2State
6+
import warnings
67

78

89
class AdamW(Optimizer2State):
@@ -98,8 +99,10 @@ def __init__(
9899
The weight decay value for the optimizer.
99100
amsgrad (`bool`, defaults to `False`):
100101
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
102+
Note: This parameter is not supported in AdamW8bit and must be False.
101103
optim_bits (`int`, defaults to 32):
102104
The number of bits of the optimizer state.
105+
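Note: AdamW8bit always keeps its optimizer state in 8 bits, so this parameter exists only for signature compatibility; leave it at the default of 32, since any other value raises a `ValueError`.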
103106
args (`object`, defaults to `None`):
104107
An object with additional arguments.
105108
min_8bit_size (`int`, defaults to 4096):
@@ -111,14 +114,23 @@ def __init__(
111114
is_paged (`bool`, defaults to `False`):
112115
Whether the optimizer is a paged optimizer or not.
113116
"""
117+
# Validate unsupported parameters
118+
if amsgrad:
119+
raise ValueError("AdamW8bit does not support amsgrad=True")
120+
121+
if optim_bits != 32:
122+
# We allow the default value of 32 to maintain compatibility with the function signature,
123+
# but any other value is invalid since AdamW8bit always uses 8-bit optimization
124+
raise ValueError("AdamW8bit only supports optim_bits=32 (default value for compatibility)")
125+
114126
super().__init__(
115127
"adam",
116128
params,
117129
lr,
118130
betas,
119131
eps,
120132
weight_decay,
121-
8,
133+
8, # Hardcoded to 8 bits
122134
args,
123135
min_8bit_size,
124136
percentile_clipping,
@@ -279,8 +291,10 @@ def __init__(
279291
The weight decay value for the optimizer.
280292
amsgrad (`bool`, defaults to `False`):
281293
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
294+
Note: This parameter is not supported in PagedAdamW8bit and must be False.
282295
optim_bits (`int`, defaults to 32):
283296
The number of bits of the optimizer state.
297+
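Note: PagedAdamW8bit always keeps its optimizer state in 8 bits, so this parameter exists only for signature compatibility; leave it at the default of 32, since any other value raises a `ValueError`.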
284298
args (`object`, defaults to `None`):
285299
An object with additional arguments.
286300
min_8bit_size (`int`, defaults to 4096):
@@ -292,14 +306,23 @@ def __init__(
292306
is_paged (`bool`, defaults to `False`):
293307
Whether the optimizer is a paged optimizer or not.
294308
"""
309+
# Validate unsupported parameters
310+
if amsgrad:
311+
raise ValueError("PagedAdamW8bit does not support amsgrad=True")
312+
313+
if optim_bits != 32:
314+
# We allow the default value of 32 to maintain compatibility with the function signature,
315+
# but any other value is invalid since PagedAdamW8bit always uses 8-bit optimization
316+
raise ValueError("PagedAdamW8bit only supports optim_bits=32 (default value for compatibility)")
317+
295318
super().__init__(
296319
"adam",
297320
params,
298321
lr,
299322
betas,
300323
eps,
301324
weight_decay,
302-
8,
325+
8, # Hardcoded to 8 bits
303326
args,
304327
min_8bit_size,
305328
percentile_clipping,

0 commit comments

Comments
 (0)