
Commit 2f67ea5

committed
add reference schedule-free
1 parent d9e61e6 commit 2f67ea5

4 files changed

Lines changed: 331 additions & 0 deletions

File tree

submissions/self_tuning/schedule_free_adamw/__init__.py

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
q
Lines changed: 321 additions & 0 deletions
@@ -0,0 +1,321 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Dict, Iterator, List, Tuple
from absl import logging
import torch
import torch.distributed.nn as dist_nn

from algorithmic_efficiency import spec
from algorithmic_efficiency.pytorch_utils import pytorch_setup

USE_PYTORCH_DDP = pytorch_setup()[0]
HPARAMS = {
    "dropout_rate": 0.1,
    "learning_rate": 0.0025,
    "one_minus_beta1": 0.1,
    "beta2": 0.9955159689799007,
    "weight_decay": 0.08121616522670176,
    "warmup_factor": 0.02,
    "weight_lr_power": 2,
    "label_smoothing": 0.2,
    "r": 0.75,
    "conformer_bs": 192,
}


class AdamWScheduleFree(torch.optim.Optimizer):
  r"""Schedule Free AdamW
  """
  def __init__(self, params,
               lr=1e-3,
               betas=(0.9, 0.999),
               eps=1e-8,
               weight_decay=0,
               weight_lr_power=2,
               warmup_steps=0,
               r=0,
               ):
    defaults = dict(lr=lr,
                    betas=betas,
                    eps=eps,
                    r=r,
                    k=0,
                    weight_sum=0.0,
                    lr_max=0.0,
                    warmup_steps=warmup_steps,
                    weight_lr_power=weight_lr_power,
                    weight_decay=weight_decay)

    super().__init__(params, defaults)

  def reset(self):
    for group in self.param_groups:
      group['k'] = 0
      group['lr_max'] = 0
      group['weight_sum'] = 0

      for p in group['params']:
        # Reset per-parameter state back to the stored initial point x0.
        state = self.state[p]
        state['z'].copy_(state['x0'])
        p.data.copy_(state['x0'])
        state['exp_avg_sq'].zero_()

  def step(self, closure):
    """Performs a single optimization step.

    Arguments:
      closure (callable): A closure that reevaluates the model
        and returns the loss.
    """
    # Swap to extrapolated point:
    for group in self.param_groups:
      beta1, beta2 = group['betas']
      r = group['r']
      k = group['k']

      for p in group['params']:
        # State initialization
        state = self.state[p]
        if 'z' not in state:
          state['z'] = torch.clone(p.data)
          state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.bfloat16)
          state['x0'] = p.data.cpu()

        z = state['z']

        # Extrapolate
        # p = p + (1-beta1)*(z-p)
        # p.data.mul_(beta1).add_(z, alpha=1-beta1)
        p.data.lerp_(end=z, weight=1 - beta1)

    # Evaluate gradient at extrapolated point
    loss = closure()

    for group in self.param_groups:
      eps = group['eps']
      k = group['k']
      warmup_steps = group['warmup_steps']

      if k < warmup_steps:
        sched = (k + 1) / warmup_steps
      else:
        sched = 1.0
      annealed_lr = group['lr'] * sched

      lr = max(annealed_lr, eps)

      decay = group['weight_decay']
      beta1, beta2 = group['betas']
      weight_lr_power = group['weight_lr_power']

      r = group['r']
      lr_max = group['lr_max'] = max(lr, group['lr_max'])

      weight = ((k + 1)**r) * (lr_max**weight_lr_power)
      weight_sum = group['weight_sum'] = group['weight_sum'] + weight

      ckp1 = weight / weight_sum

      bias_correction2 = 1 - beta2**(k + 1)
      step_size = lr * math.sqrt(bias_correction2)

      for p in group['params']:
        if p.grad is None:
          continue
        grad = p.grad.data

        state = self.state[p]

        exp_avg_sq = state['exp_avg_sq']
        z = state['z']

        # Unextrapolate
        # p = (p - (1-beta1)*z)/beta1
        # p.data.sub_(z, alpha=1-beta1).div_(beta1)
        p.data.lerp_(end=z, weight=1 - 1 / beta1)

        # Decay the second moment running average coefficient
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        denom = exp_avg_sq.sqrt().add_(eps)

        z.addcdiv_(grad, denom, value=-step_size)

        # Decay
        z.sub_(p.data, alpha=step_size * decay)

        ### Take step
        # p.data.mul_(1-ckp1).add_(z, alpha=ckp1)
        p.data.lerp_(end=z, weight=ckp1)

      group['k'] = k + 1
    return loss

def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
  del model_state

  optimizer = AdamWScheduleFree(
      model_params.parameters(),
      lr=HPARAMS['learning_rate'],
      betas=(1.0 - HPARAMS['one_minus_beta1'], HPARAMS['beta2']),
      warmup_steps=int(HPARAMS['warmup_factor'] * workload.step_hint * 0.75),
      weight_decay=HPARAMS['weight_decay'],
      weight_lr_power=HPARAMS['weight_lr_power'],
      r=HPARAMS['r'])

  optimizer_state = {
      'optimizer': optimizer,
      'max_checked_eval_step': -1,
      'has_forced_reset': False,
      'first_eval': False,
  }
  return optimizer_state

def update_params(workload: spec.Workload,
                  current_param_container: spec.ParameterContainer,
                  current_params_types: spec.ParameterTypeTree,
                  model_state: spec.ModelAuxiliaryState,
                  hyperparameters: spec.Hyperparameters,
                  batch: Dict[str, spec.Tensor],
                  loss_type: spec.LossType,
                  optimizer_state: spec.OptimizerState,
                  eval_results: List[Tuple[int, float]],
                  global_step: int,
                  rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del current_params_types
  del loss_type
  del hyperparameters

  metric_name = workload.target_metric_name
  # TODO - remove force_reset
  eval_step = len(eval_results)
  if (global_step > workload.step_hint * 0.10 and
      optimizer_state['max_checked_eval_step'] < eval_step):
    optimizer_state['max_checked_eval_step'] = eval_step
    # Don't do resetting on workloads that don't run eval often enough.
    if len(eval_results) >= 4:  # and optimizer_state["first_eval_is_far_from_target"]:
      val_metric = f"validation/{metric_name}"
      initial_eval = eval_results[0][1][val_metric]
      latest_eval = eval_results[-1][1][val_metric]
      second_latest_eval = eval_results[-2][1][val_metric]
      third_latest_eval = eval_results[-3][1][val_metric]
      fourth_latest_eval = eval_results[-4][1][val_metric]
      MARGIN = 0.01
      if metric_name in ["loss", "wer"]:
        # Metrics where lower is better are flipped so that higher is better.
        initial_eval = -initial_eval
        latest_eval = -latest_eval
        second_latest_eval = -second_latest_eval
        third_latest_eval = -third_latest_eval
        fourth_latest_eval = -fourth_latest_eval
      # Higher is better from here on.
      # Reset if the eval values are far from the target (i.e. worse than the
      # initial eval) and have stayed far from the target for 4 evals.
      if (latest_eval - initial_eval < MARGIN and
          latest_eval - second_latest_eval < MARGIN and
          second_latest_eval - third_latest_eval < MARGIN and
          third_latest_eval - fourth_latest_eval < MARGIN):
        # Reset parameters since we appear to have diverged.
        logging.info("Resetting All Weights")
        logging.info(f"Global Step: {global_step}")
        optimizer_state['has_forced_reset'] = True

        # Perform reset
        del model_state
        model_state = None
        optimizer_state['optimizer'].reset()

        # Decrease learning rate by 2x if it diverged.
        for param_group in optimizer_state['optimizer'].param_groups:
          param_group['lr'] = param_group['lr'] / 2.0

  ###########

  current_model = current_param_container
  current_model.train()

  new_model_state = None

  def closure():
    nonlocal new_model_state
    optimizer_state['optimizer'].zero_grad()

    logits_batch, new_model_state = workload.model_fn(
        params=current_model,
        augmented_and_preprocessed_input_batch=batch,
        model_state=model_state,
        mode=spec.ForwardPassMode.TRAIN,
        rng=rng,
        update_batch_norm=True)

    loss_dict = workload.loss_fn(
        label_batch=batch['targets'],
        logits_batch=logits_batch,
        mask_batch=batch.get('weights'),
        label_smoothing=HPARAMS['label_smoothing'])
    summed_loss = loss_dict['summed']
    n_valid_examples = loss_dict['n_valid_examples']
    if USE_PYTORCH_DDP:
      # Use dist_nn.all_reduce to ensure correct loss and gradient scaling.
      summed_loss = dist_nn.all_reduce(summed_loss)
      n_valid_examples = dist_nn.all_reduce(n_valid_examples)
    loss = summed_loss / n_valid_examples

    loss.backward()
    return loss

  loss = optimizer_state['optimizer'].step(closure)

  return (optimizer_state, current_param_container, new_model_state)


def get_batch_size(workload_name):
  # Return the global batch size.
  if workload_name == 'criteo1tb':
    return 262_144
  elif workload_name == 'fastmri':
    return 16  # 32
  elif workload_name == 'imagenet_resnet':
    return 1024
  elif workload_name == 'imagenet_vit':
    return 1024
  elif workload_name == 'librispeech_conformer':
    return 224
  elif workload_name == 'librispeech_deepspeech':
    return 128  # 256
  elif workload_name == 'ogbg':
    return 512
  elif workload_name == 'wmt':
    return 128
  elif workload_name == 'mnist':
    return 16
  elif workload_name == 'imagenet_resnet_gelu':
    return 512
  elif workload_name == 'imagenet_resnet_silu':
    return 512
  elif workload_name == 'finewebedu_lm':
    return 64
  else:
    raise ValueError(f'Unsupported workload name: {workload_name}.')

def data_selection(workload: spec.Workload,
                   input_queue: Iterator[Dict[str, spec.Tensor]],
                   optimizer_state: spec.OptimizerState,
                   current_param_container: spec.ParameterContainer,
                   model_state: spec.ModelAuxiliaryState,
                   hyperparameters: spec.Hyperparameters,
                   global_step: int,
                   rng: spec.RandomState) -> Dict[str, spec.Tensor]:
  """Select data from the infinitely repeating, pre-shuffled input queue.
  Each element of the queue is a batch of training examples and labels.
  """
  del workload
  del optimizer_state
  del current_param_container
  del model_state
  del hyperparameters
  del global_step
  del rng
  batch = next(input_queue)
  return batch
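
As a reading aid (this note is not part of the committed file), the per-parameter update implemented by step() above can be written as the schedule-free recursion below, where x_k is the averaged iterate held in p.data, z_k is the base iterate held in state['z'], v_k is the beta2 running average of squared gradients, gamma_k is the warmed-up learning rate, and \hat{\gamma}_k = \gamma_k \sqrt{1 - \beta_2^{k+1}}:

$$y_k = \beta_1 x_k + (1 - \beta_1) z_k \quad \text{(the closure evaluates the gradient at } y_k\text{)}$$
$$z_{k+1} = z_k - \hat{\gamma}_k \, \frac{\nabla f(y_k)}{\sqrt{v_{k+1}} + \epsilon} - \hat{\gamma}_k \lambda x_k$$
$$x_{k+1} = (1 - c_{k+1}) x_k + c_{k+1} z_{k+1}, \qquad c_{k+1} = \frac{w_{k+1}}{\sum_{i=1}^{k+1} w_i}, \qquad w_{k+1} = (k+1)^r \, \gamma_{\max}^{\rho}$$

with \rho = weight_lr_power, \lambda = weight_decay, and \gamma_{\max} the largest warmed-up learning rate observed so far.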
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
submission_name: Schedule Free AdamW
submission_folder: schedule_free_adamw
authors: Alice Yang, Aaron Defazio, Konstantin Mishchenko
affiliations: Meta AI, Samsung AI
version: '1.0'
ruleset: self-tuning
framework: PyTorch
description: >-
  A self-tuning version of Schedule Free AdamW ([Defazio et al., 2024](https://openreview.net/forum?id=0XeNkkENuI)) using a single hyperparameter configuration.
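
For readers who want to exercise the optimizer outside the AlgoPerf harness, the following is a minimal sketch of the closure-based interface (not part of the commit). The toy model, random data, and warmup value are illustrative assumptions; AdamWScheduleFree and HPARAMS refer to the definitions in the Python file above.

import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 1)  # toy stand-in for a real workload model
optimizer = AdamWScheduleFree(
    model.parameters(),
    lr=HPARAMS['learning_rate'],
    betas=(1.0 - HPARAMS['one_minus_beta1'], HPARAMS['beta2']),
    warmup_steps=100,  # illustrative; the submission derives this from workload.step_hint
    weight_decay=HPARAMS['weight_decay'],
    weight_lr_power=HPARAMS['weight_lr_power'],
    r=HPARAMS['r'])

inputs, targets = torch.randn(32, 10), torch.randn(32, 1)

def closure():
  # The optimizer swaps the parameters to the extrapolated point before calling
  # this, so the forward/backward pass must live inside the closure.
  optimizer.zero_grad()
  loss = F.mse_loss(model(inputs), targets)
  loss.backward()
  return loss

for _ in range(10):
  loss = optimizer.step(closure)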
