@@ -261,6 +261,34 @@ def __call__(self, inp_features):
261261 return self .mod (inp_features , self .perm_spec .unfreeze ())
262262
263263
class HybridMLPNFN(nn.Module):
  """Hybrid learned-optimizer network: an MLP followed by one NF linear layer.

  The input features go through an `MLPForOpt` stack, then `nf_relu`, and are
  finally projected by a single permutation-spec-aware linear layer.

  Attributes:
    in_channels: Declared for interface parity; not read anywhere in this class.
    hidden_channels: Passed as the first two arguments of `MLPForOpt` and as
      the second argument of the final layer.
    out_channels: Passed as the first argument of the final layer.
    num_layers: `MLPForOpt` is constructed with `num_layers - 1`; the final
      NF linear layer is built separately.
    perm_spec: Spec object; `perm_spec.unfreeze()` is handed to the final layer
      on every call.
    ptwise_init: When true the final layer is a `PointwiseInitNFLinear`,
      otherwise an `NFLinear` with `w_init='lecun'`.
  """

  in_channels: int
  hidden_channels: int
  out_channels: int
  num_layers: int
  perm_spec: Any
  ptwise_init: bool = False

  def setup(self):
    """Builds the MLP and the final NF linear projection."""
    width = self.hidden_channels
    n_out = self.out_channels

    self.mlp = MLPForOpt(width, width, self.num_layers - 1)

    # Only the final projection is an NF layer; its flavour depends on the
    # `ptwise_init` flag.
    if self.ptwise_init:
      self.final = universal_layers.PointwiseInitNFLinear(n_out, width)
    else:
      self.final = universal_layers.NFLinear(n_out, width, w_init='lecun')

  def __call__(self, inp_features):
    """Applies the MLP, `nf_relu`, then the final layer with the perm spec."""
    hidden = self.mlp(inp_features)
    activated = universal_layers.nf_relu(hidden)
    return self.final(activated, self.perm_spec.unfreeze())
290+
291+
264292class SGDControl (lopt_base .LearnedOptimizer ):
265293 """SGD where per-parameter learning rates are controlled by a network."""
266294
@@ -457,7 +485,13 @@ class ResidualOptNFN(ResidualOpt):
457485 """NFN learning a residual on base optimizer."""
458486
459487 def __init__ (
460- self , task , step_mult = 0.1 , out_mult = 1e-4 , ptwise_init = False , pos_emb = False
488+ self ,
489+ task ,
490+ step_mult = 0.1 ,
491+ out_mult = 1e-4 ,
492+ ptwise_init = False ,
493+ pos_emb = False ,
494+ hybrid = False ,
461495 ):
462496 example_params = task .init (jax .random .PRNGKey (0 ))
463497 if 'conv2_d' in example_params :
@@ -468,15 +502,25 @@ def __init__(
468502 perm_spec = make_hk_transformer_perm_spec (example_params )
469503 else :
470504 perm_spec = make_hk_perm_spec (example_params )
471- network = UnivNFNForOpt (
472- in_channels = 19 ,
473- hidden_channels = 32 ,
474- out_channels = 1 ,
475- num_layers = 4 ,
476- perm_spec = perm_spec ,
477- ptwise_init = ptwise_init ,
478- pos_emb = pos_emb ,
479- )
505+ if hybrid :
506+ assert not pos_emb
507+ network = HybridMLPNFN (
508+ in_channels = 19 ,
509+ hidden_channels = 32 ,
510+ out_channels = 1 ,
511+ num_layers = 4 ,
512+ perm_spec = perm_spec ,
513+ )
514+ else :
515+ network = UnivNFNForOpt (
516+ in_channels = 19 ,
517+ hidden_channels = 32 ,
518+ out_channels = 1 ,
519+ num_layers = 4 ,
520+ perm_spec = perm_spec ,
521+ ptwise_init = ptwise_init ,
522+ pos_emb = pos_emb ,
523+ )
480524 super ().__init__ (
481525 network , example_params , step_mult = step_mult , out_mult = out_mult
482526 )
0 commit comments