@@ -109,12 +109,31 @@ class DeepSpeedFP16Config(DeepSpeedConfigModel):
109109 """
110110 Automatically cast inputs to fp16
111111 """
112-
112+
113113 loss_scale : float = 0
114114 """
115115 Loss scaling value. Default value of 0 means dynamic loss scaling instead of static loss scale.
116116 """
117117
@field_validator("loss_scale", mode="before")
@classmethod
def _validate_loss_scale(cls, v):
    """
    Validate the raw ``fp16.loss_scale`` value and normalize it to a float.

    Runs in "before" mode so the checks see the value as the user wrote it.
    A default ("after") validator only receives the value once the ``float``
    field has been coerced, at which point ``True``/``False`` have become
    ``1.0``/``0.0`` and the bool check below could never fire.

    Args:
        v: Raw configured value for ``loss_scale``.

    Returns:
        The value as a finite, non-negative float. 0 still means dynamic
        loss scaling.

    Raises:
        ValueError: If ``v`` is a bool, cannot be converted to a float, is
            inf/-inf/nan, or is negative.
    """
    # Prevent True/False from being treated as 1/0
    if isinstance(v, bool):
        raise ValueError("fp16.loss_scale must be a number, not bool")

    try:
        v = float(v)
    except (TypeError, ValueError, OverflowError):
        # Raw input may be None, a container, a non-numeric string, or an int
        # too large for a float. Re-raise as ValueError so the failure is
        # reported as a validation error rather than an unrelated exception.
        raise ValueError(f"fp16.loss_scale must be a number convertible to float, got {v!r}") from None

    # Reject inf/-inf/nan
    if not math.isfinite(v):
        raise ValueError("fp16.loss_scale must be a finite number (not inf/-inf/nan)")

    # Reject negative values; 0 still means dynamic loss scaling
    if v < 0:
        raise ValueError("fp16.loss_scale must be >= 0 (0 enables dynamic loss scaling)")

    return v
136+
118137 initial_scale_power : int = 16
119138 """
120139 For dynamic loss scaling, set initial loss scale to 2^{initial_scale_power}.
@@ -156,27 +175,3 @@ def dynamic_loss_scale_args(self):
156175 CONSECUTIVE_HYSTERESIS : self .consecutive_hysteresis ,
157176 MIN_LOSS_SCALE : self .min_loss_scale ,
158177 }
159-
160- loss_scale : float = 0
161- """
162- Loss scaling value. Default value of 0 means dynamic loss scaling instead of static loss scale.
163- """
164-
@field_validator("loss_scale", mode="before")
@classmethod
def _validate_loss_scale(cls, v):
    """
    Check the raw ``fp16.loss_scale`` setting and return it as a float.

    The validator is registered in "before" mode: in the default "after"
    mode the ``float`` field would already have been coerced, so a bool
    input would show up as ``1.0``/``0.0`` and slip past the bool check.

    Args:
        v: Raw configured value for ``loss_scale``.

    Returns:
        A finite float >= 0. A value of 0 keeps dynamic loss scaling.

    Raises:
        ValueError: If ``v`` is a bool, is not convertible to a float, is
            inf/-inf/nan, or is negative.
    """
    # Prevent True/False from being treated as 1/0
    if isinstance(v, bool):
        raise ValueError("fp16.loss_scale must be a number, not bool")

    try:
        v = float(v)
    except (TypeError, ValueError, OverflowError):
        # None, containers, non-numeric strings and oversized ints all fail
        # here; surface them uniformly as ValueError for the validation layer.
        raise ValueError(f"fp16.loss_scale must be a number convertible to float, got {v!r}") from None

    # Reject inf/-inf/nan
    if not math.isfinite(v):
        raise ValueError("fp16.loss_scale must be a finite number (not inf/-inf/nan)")

    # Reject negative values; 0 still means dynamic loss scaling
    if v < 0:
        raise ValueError("fp16.loss_scale must be >= 0 (0 enables dynamic loss scaling)")

    return v
0 commit comments