Skip to content

Commit 4bcaeb0

Browse files
author
learned_optimization authors
committed
No public description
PiperOrigin-RevId: 593217203
1 parent a7322af commit 4bcaeb0

1 file changed

Lines changed: 54 additions & 10 deletions

File tree

learned_optimization/research/univ_nfn/learned_opt/learned_opts.py

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,34 @@ def __call__(self, inp_features):
261261
return self.mod(inp_features, self.perm_spec.unfreeze())
262262

263263

264+
class HybridMLPNFN(nn.Module):
265+
"""MLP + NFN Lopt."""
266+
267+
in_channels: int
268+
hidden_channels: int
269+
out_channels: int
270+
num_layers: int
271+
perm_spec: Any
272+
ptwise_init: bool = False
273+
274+
def setup(self):
275+
out_channels, hidden_channels = self.out_channels, self.hidden_channels
276+
277+
self.mlp = MLPForOpt(hidden_channels, hidden_channels, self.num_layers - 1)
278+
279+
def make_layer(out_chan, in_chan):
280+
if self.ptwise_init:
281+
return universal_layers.PointwiseInitNFLinear(out_chan, in_chan)
282+
else:
283+
return universal_layers.NFLinear(out_chan, in_chan, w_init='lecun')
284+
285+
self.final = make_layer(out_channels, hidden_channels)
286+
287+
def __call__(self, inp_features):
288+
features = universal_layers.nf_relu(self.mlp(inp_features))
289+
return self.final(features, self.perm_spec.unfreeze())
290+
291+
264292
class SGDControl(lopt_base.LearnedOptimizer):
265293
"""SGD where per-parameter learning rates are controlled by a network."""
266294

@@ -457,7 +485,13 @@ class ResidualOptNFN(ResidualOpt):
457485
"""NFN learning a residual on base optimizer."""
458486

459487
def __init__(
460-
self, task, step_mult=0.1, out_mult=1e-4, ptwise_init=False, pos_emb=False
488+
self,
489+
task,
490+
step_mult=0.1,
491+
out_mult=1e-4,
492+
ptwise_init=False,
493+
pos_emb=False,
494+
hybrid=False,
461495
):
462496
example_params = task.init(jax.random.PRNGKey(0))
463497
if 'conv2_d' in example_params:
@@ -468,15 +502,25 @@ def __init__(
468502
perm_spec = make_hk_transformer_perm_spec(example_params)
469503
else:
470504
perm_spec = make_hk_perm_spec(example_params)
471-
network = UnivNFNForOpt(
472-
in_channels=19,
473-
hidden_channels=32,
474-
out_channels=1,
475-
num_layers=4,
476-
perm_spec=perm_spec,
477-
ptwise_init=ptwise_init,
478-
pos_emb=pos_emb,
479-
)
505+
if hybrid:
506+
assert not pos_emb
507+
network = HybridMLPNFN(
508+
in_channels=19,
509+
hidden_channels=32,
510+
out_channels=1,
511+
num_layers=4,
512+
perm_spec=perm_spec,
513+
)
514+
else:
515+
network = UnivNFNForOpt(
516+
in_channels=19,
517+
hidden_channels=32,
518+
out_channels=1,
519+
num_layers=4,
520+
perm_spec=perm_spec,
521+
ptwise_init=ptwise_init,
522+
pos_emb=pos_emb,
523+
)
480524
super().__init__(
481525
network, example_params, step_mult=step_mult, out_mult=out_mult
482526
)

0 commit comments

Comments
 (0)