33# This source code is licensed under the MIT license found in the
44# LICENSE file in the root directory of this source tree.
55
6-
6+ import warnings
77from bitsandbytes .optim .optimizer import Optimizer2State
88
99
@@ -76,6 +76,8 @@ def __init__(
7676 betas = (0.9 , 0.999 ),
7777 eps = 1e-8 ,
7878 weight_decay = 0 ,
79+ amsgrad = False ,
80+ optim_bits = 32 ,
7981 args = None ,
8082 min_8bit_size = 4096 ,
8183 percentile_clipping = 100 ,
@@ -96,6 +98,12 @@ def __init__(
9698 The epsilon value prevents division by zero in the optimizer.
9799 weight_decay (`float`, defaults to 0.0):
98100 The weight decay value for the optimizer.
101+ amsgrad (`bool`, defaults to `False`):
102+ Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
103+ Note: This parameter is not supported in Adam8bit and must be False.
104+ optim_bits (`int`, defaults to 32):
105+ The number of bits of the optimizer state.
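106+ Note: Adam8bit always keeps its optimizer state in 8 bits, so this parameter exists only for signature compatibility and must be left at its default of 32; any other value raises a `ValueError`.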
99107 args (`object`, defaults to `None`):
100108 An object with additional arguments.
101109 min_8bit_size (`int`, defaults to 4096):
@@ -107,14 +115,23 @@ def __init__(
107115 is_paged (`bool`, defaults to `False`):
108116 Whether the optimizer is a paged optimizer or not.
109117 """
118+ # Validate unsupported parameters
119+ if amsgrad :
120+ raise ValueError ("Adam8bit does not support amsgrad=True" )
121+
122+ if optim_bits != 32 :
123+ # We allow the default value of 32 to maintain compatibility with the function signature,
124+ # but any other value is invalid since Adam8bit always uses 8-bit optimization
125+ raise ValueError ("Adam8bit only supports optim_bits=32 (default value for compatibility)" )
126+
110127 super ().__init__ (
111128 "adam" ,
112129 params ,
113130 lr ,
114131 betas ,
115132 eps ,
116133 weight_decay ,
117- 8 ,
134+ 8 , # Hardcoded to 8 bits
118135 args ,
119136 min_8bit_size ,
120137 percentile_clipping ,
@@ -277,8 +294,10 @@ def __init__(
277294 The weight decay value for the optimizer.
278295 amsgrad (`bool`, defaults to `False`):
279296 Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
297+ Note: This parameter is not supported in PagedAdam8bit and must be False.
280298 optim_bits (`int`, defaults to 32):
281299 The number of bits of the optimizer state.
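300+ Note: PagedAdam8bit always keeps its optimizer state in 8 bits, so this parameter exists only for signature compatibility and must be left at its default of 32; any other value raises a `ValueError`.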
282301 args (`object`, defaults to `None`):
283302 An object with additional arguments.
284303 min_8bit_size (`int`, defaults to 4096):
@@ -290,14 +309,23 @@ def __init__(
290309 is_paged (`bool`, defaults to `False`):
291310 Whether the optimizer is a paged optimizer or not.
292311 """
312+ # Validate unsupported parameters
313+ if amsgrad :
314+ raise ValueError ("PagedAdam8bit does not support amsgrad=True" )
315+
316+ if optim_bits != 32 :
317+ # We allow the default value of 32 to maintain compatibility with the function signature,
318+ # but any other value is invalid since PagedAdam8bit always uses 8-bit optimization
319+ raise ValueError ("PagedAdam8bit only supports optim_bits=32 (default value for compatibility)" )
320+
293321 super ().__init__ (
294322 "adam" ,
295323 params ,
296324 lr ,
297325 betas ,
298326 eps ,
299327 weight_decay ,
300- 8 ,
328+ 8 , # Hardcoded to 8 bits
301329 args ,
302330 min_8bit_size ,
303331 percentile_clipping ,
0 commit comments