Skip to content

Commit 01dfe52

Browse files
ninatumartinarroyo
andcommitted
Add gradient clipping options to optimizer
Introduces options for clipping gradients by global norm or by value, configurable via `config.opt_enable_grad_global_norm_clipping` and `config.opt_enable_grad_clipping`, as well as `config.max_grad_norm` and `config.max_grad_value`. Co-authored-by: martinarroyo <martinarroyo@google.com>
1 parent 4afed9f commit 01dfe52

14 files changed

Lines changed: 48 additions & 1 deletion

src/maxdiffusion/configs/base14.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
206206
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
207207
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
208208
adam_weight_decay: 1.e-2 # AdamW Weight decay
209+
opt_enable_grad_clipping: False
210+
max_grad_value: 1.0
211+
opt_enable_grad_global_norm_clipping: False
209212
max_grad_norm: 1.0
210213

211214
enable_profiler: False

src/maxdiffusion/configs/base21.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
211211
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
212212
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
213213
adam_weight_decay: 1.e-2 # AdamW Weight decay
214+
opt_enable_grad_clipping: False
215+
max_grad_value: 1.0
216+
opt_enable_grad_global_norm_clipping: False
214217
max_grad_norm: 1.0
215218

216219
enable_profiler: False

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
221221
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
222222
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
223223
adam_weight_decay: 1.e-2 # AdamW Weight decay
224+
opt_enable_grad_clipping: False
225+
max_grad_value: 1.0
226+
opt_enable_grad_global_norm_clipping: False
224227
max_grad_norm: 1.0
225228

226229
enable_profiler: False

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
245245
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
246246
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
247247
adam_weight_decay: 0 # AdamW Weight decay
248+
opt_enable_grad_clipping: False
249+
max_grad_value: 1.0
250+
opt_enable_grad_global_norm_clipping: False
248251
max_grad_norm: 1.0
249252

250253
enable_profiler: False

src/maxdiffusion/configs/base_flux_dev_multi_res.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
232232
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
233233
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
234234
adam_weight_decay: 1.e-2 # AdamW Weight decay
235+
opt_enable_grad_clipping: False
236+
max_grad_value: 1.0
237+
opt_enable_grad_global_norm_clipping: False
235238
max_grad_norm: 1.0
236239

237240
enable_profiler: False

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
240240
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
241241
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
242242
adam_weight_decay: 1.e-2 # AdamW Weight decay
243+
opt_enable_grad_clipping: False
244+
max_grad_value: 1.0
245+
opt_enable_grad_global_norm_clipping: False
243246
max_grad_norm: 1.0
244247

245248
enable_profiler: False

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
301301
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
302302
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
303303
adam_weight_decay: 0 # AdamW Weight decay
304+
opt_enable_grad_clipping: False
305+
max_grad_value: 1.0
306+
opt_enable_grad_global_norm_clipping: False
304307
max_grad_norm: 1.0
305308

306309
enable_profiler: False

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
257257
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
258258
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
259259
adam_weight_decay: 0 # AdamW Weight decay
260+
opt_enable_grad_clipping: False
261+
max_grad_value: 1.0
262+
opt_enable_grad_global_norm_clipping: False
260263
max_grad_norm: 1.0
261264

262265
enable_profiler: False

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
268268
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
269269
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
270270
adam_weight_decay: 0 # AdamW Weight decay
271+
opt_enable_grad_clipping: False
272+
max_grad_value: 1.0
273+
opt_enable_grad_global_norm_clipping: False
271274
max_grad_norm: 1.0
272275

273276
enable_profiler: False

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
263263
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
264264
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
265265
adam_weight_decay: 0 # AdamW Weight decay
266+
opt_enable_grad_clipping: False
267+
max_grad_value: 1.0
268+
opt_enable_grad_global_norm_clipping: False
266269
max_grad_norm: 1.0
267270

268271
enable_profiler: False

0 commit comments

Comments
 (0)