Skip to content

Commit 3ead20d

Browse files
committed
Add validation for loss_scale in precision_config
(cherry picked from commit f0059a7) Signed-off-by: nathon-lee <leejianwoo@gmail.com>
1 parent b050c76 commit 3ead20d

1 file changed

Lines changed: 20 additions & 25 deletions

File tree

deepspeed/runtime/precision_config.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,31 @@ class DeepSpeedFP16Config(DeepSpeedConfigModel):
109109
"""
110110
Automatically cast inputs to fp16
111111
"""
112-
112+
113113
loss_scale: float = 0
114114
"""
115115
Loss scaling value. Default value of 0 means dynamic loss scaling instead of static loss scale.
116116
"""
117117

118+
@field_validator("loss_scale")
@classmethod
def _validate_loss_scale(cls, v):
    """Check that ``loss_scale`` is a finite, non-negative number.

    A value of 0 keeps dynamic loss scaling enabled; any positive finite
    value selects a static loss scale. Bools, non-finite values, and
    negative values are rejected with a ValueError.
    """
    # bool is a subclass of int, so float(True)/float(False) would silently
    # become 1.0/0.0 — reject it explicitly before conversion.
    if isinstance(v, bool):
        raise ValueError("fp16.loss_scale must be a number, not bool")

    scale = float(v)

    # float() happily produces inf/-inf/nan from e.g. float("inf"); none of
    # these is a usable loss scale.
    if not math.isfinite(scale):
        raise ValueError("fp16.loss_scale must be a finite number (not inf/-inf/nan)")

    # Negative scales are meaningless; zero is allowed and means "dynamic".
    if scale < 0:
        raise ValueError("fp16.loss_scale must be >= 0 (0 enables dynamic loss scaling)")

    return scale
136+
118137
initial_scale_power: int = 16
119138
"""
120139
For dynamic loss scaling, set initial loss scale to 2^{initial_scale_power}.
@@ -156,27 +175,3 @@ def dynamic_loss_scale_args(self):
156175
CONSECUTIVE_HYSTERESIS: self.consecutive_hysteresis,
157176
MIN_LOSS_SCALE: self.min_loss_scale,
158177
}
159-
160-
loss_scale: float = 0
161-
"""
162-
Loss scaling value. Default value of 0 means dynamic loss scaling instead of static loss scale.
163-
"""
164-
165-
@field_validator("loss_scale")
166-
@classmethod
167-
def _validate_loss_scale(cls, v):
168-
# Prevent True/False from being treated as 1/0
169-
if isinstance(v, bool):
170-
raise ValueError("fp16.loss_scale must be a number, not bool")
171-
172-
v = float(v)
173-
174-
# Reject inf/-inf/nan
175-
if not math.isfinite(v):
176-
raise ValueError("fp16.loss_scale must be a finite number (not inf/-inf/nan)")
177-
178-
# Reject negative values; 0 still means dynamic loss scaling
179-
if v < 0:
180-
raise ValueError("fp16.loss_scale must be >= 0 (0 enables dynamic loss scaling)")
181-
182-
return v

0 commit comments

Comments
 (0)