Skip to content

Commit 1b9717f

Browse files
rchen152learned_optimization authors
authored andcommitted
Internal change
PiperOrigin-RevId: 530350759
1 parent 2892fae commit 1b9717f

8 files changed

Lines changed: 33 additions & 33 deletions

File tree

learned_optimization/learned_optimizers/adafac_mlp_lopt.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -404,14 +404,15 @@ def init(
404404
iteration=jnp.asarray(0, dtype=jnp.int32),
405405
num_steps=jnp.asarray(num_steps))
406406

407-
def update(self,
408-
opt_state: AdafacMLPLOptState,
409-
grad: opt_base.Gradient,
410-
loss: jnp.ndarray,
411-
model_state: Optional[opt_base.ModelState] = None,
412-
is_valid: bool = False,
413-
key: Optional[PRNGKey] = None) -> AdafacMLPLOptState:
414-
407+
def update(
408+
self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
409+
opt_state: AdafacMLPLOptState,
410+
grad: opt_base.Gradient,
411+
loss: jnp.ndarray,
412+
model_state: Optional[opt_base.ModelState] = None,
413+
is_valid: bool = False,
414+
key: Optional[PRNGKey] = None,
415+
) -> AdafacMLPLOptState:
415416
mom_roll, rms_roll, fac_vec_roll = self._get_rolling()
416417
next_mom_rolling = mom_roll.update(opt_state.mom_rolling, grad)
417418
next_rms_rolling = rms_roll.update(opt_state.rms_rolling, grad)

learned_optimization/learned_optimizers/adafac_nominal.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -477,14 +477,15 @@ def init(
477477
iteration=jnp.asarray(0, dtype=jnp.int32),
478478
num_steps=jnp.asarray(num_steps))
479479

480-
def update(self,
481-
opt_state: AdafacMLPLOptState,
482-
grad: opt_base.Gradient,
483-
loss: jnp.ndarray,
484-
model_state: Optional[opt_base.ModelState] = None,
485-
is_valid: bool = False,
486-
key: Optional[PRNGKey] = None) -> AdafacMLPLOptState:
487-
480+
def update(
481+
self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
482+
opt_state: AdafacMLPLOptState,
483+
grad: opt_base.Gradient,
484+
loss: jnp.ndarray,
485+
model_state: Optional[opt_base.ModelState] = None,
486+
is_valid: bool = False,
487+
key: Optional[PRNGKey] = None,
488+
) -> AdafacMLPLOptState:
488489
mom_roll, rms_roll, fac_vec_roll = self._get_rolling()
489490
next_mom_rolling = mom_roll.update(opt_state.mom_rolling, grad)
490491
next_rms_rolling = rms_roll.update(opt_state.rms_rolling, grad)

learned_optimization/learned_optimizers/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def get_params(self, state):
222222
def get_state(self, state):
223223
return self.opts[0].get_state(state.inner_opt_states[0])
224224

225-
def update(self, opt_state, grad, model_state=None, **kwargs):
225+
def update(self, opt_state, grad, model_state=None, **kwargs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
226226
# apply to both opts
227227
new_opt_states = [
228228
opt.update(os, grad, model_state=model_state, **kwargs)

learned_optimization/learned_optimizers/mlp_lopt.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,14 +115,15 @@ def init(self,
115115
rolling_features=common.vec_rolling_mom(decays).init(params),
116116
iteration=jnp.asarray(0, dtype=jnp.int32))
117117

118-
def update(self,
119-
opt_state: MLPLOptState,
120-
grad: Any,
121-
loss: float,
122-
model_state: Any = None,
123-
is_valid: bool = False,
124-
key: Optional[PRNGKey] = None) -> MLPLOptState:
125-
118+
def update(
119+
self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
120+
opt_state: MLPLOptState,
121+
grad: Any,
122+
loss: float,
123+
model_state: Any = None,
124+
is_valid: bool = False,
125+
key: Optional[PRNGKey] = None,
126+
) -> MLPLOptState:
126127
next_rolling_features = common.vec_rolling_mom(decays).update(
127128
opt_state.rolling_features, grad)
128129

learned_optimization/optimizers/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def init(self, params, model_state=None, num_steps=None, **kwargs):
140140
dir_opt_state=self.direction_opt.init(
141141
params, model_state=model_state, num_steps=num_steps, **kwargs))
142142

143-
def update(self, opt_state, grad, model_state=None, **kwargs):
143+
def update(self, opt_state, grad, model_state=None, **kwargs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
144144
base_params = opt_state.params
145145

146146
next_mag_opt_state = self.magnitude_opt.update(

learned_optimization/optimizers/learning_rate_schedules.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,7 @@ def __init__(self,
7777
self.constant_fraction = constant_fraction
7878
self.warmup_fraction = warmup_fraction
7979

80-
def __call__(self, global_step, max_steps) -> chex.Array:
81-
80+
def __call__(self, global_step, max_steps) -> chex.Array: # pytype: disable=signature-mismatch # overriding-parameter-count-checks
8281
def fload32(x):
8382
"""Convert input to float32."""
8483
return jnp.asarray(x, dtype=onp.float32)

learned_optimization/outer_trainers/full_grad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def init_worker_state(self, worker_weights: gradient_learner.WorkerWeights,
148148
key: PRNGKey) -> UnrollState:
149149
return UnrollState()
150150

151-
def compute_gradient_estimate(
151+
def compute_gradient_estimate( # pytype: disable=signature-mismatch # overriding-parameter-count-checks
152152
self,
153153
worker_weights,
154154
key,

learned_optimization/research/hysteresis/truncated_es_shared_noise.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,13 @@ def init_worker_state(self, worker_weights: gradient_learner.WorkerWeights,
9797
epsilons=epsilons)
9898

9999
@profile.wrap()
100-
def compute_gradient_estimate(
100+
def compute_gradient_estimate( # pytype: disable=signature-mismatch # overriding-parameter-count-checks
101101
self,
102102
worker_weights,
103103
key: PRNGKey,
104-
state:
105-
TruncatedESSharedNoiseAttributes, # this is the same state returned by init_worker_state
104+
state: TruncatedESSharedNoiseAttributes, # this is the same state returned by init_worker_state
106105
with_summary=False,
107106
) -> Tuple[gradient_learner.GradientEstimatorOut, Mapping[str, jnp.ndarray]]:
108-
109107
# because we have a for loop we let haiku manages the key
110108
rng = hk.PRNGSequence(key)
111109

0 commit comments

Comments
 (0)