3535from learned_optimization .learned_optimizers import base as lopt_base
3636from learned_optimization .learned_optimizers import common
3737from learned_optimization .optimizers import base as opt_base
38- from learned_optimization .research .univ_nfn .nfn import ff_layers as nf_layers
3938from learned_optimization .research .univ_nfn .nfn import universal_layers
4039from learned_optimization .research .univ_nfn .nfn import utils as nfu
4140
@@ -89,42 +88,6 @@ class SimpleOptState(flax.struct.PyTreeNode):
8988 state : Any
9089
9190
92- def flax_to_hk (input_dict ):
93- """Maps flax parameter structure to haiku parameter structure.
94-
95- Example:
96- >>> input_dict = {
97- ... 'params': {
98- ... 'Dense_0': {'kernel': W0, 'bias': b0},
99- ... 'Dense_1': {'kernel': W1, 'bias': b1}
100- ... }
101- ... }
102- >>> transform_dict(input_dict)
103- {'mlp/~/linear_0': {'w': W0, 'b': b0}, 'mlp/~/linear_1': {'w': W1, 'b': b1}}
104- """
105- params = input_dict .get ('params' , {})
106- output_dict = {}
107- for i , (_ , layer_data ) in enumerate (params .items ()):
108- # Constructing new key and sub-dictionary format
109- new_key = f'mlp/~/linear_{ i } '
110- new_data = {'w' : layer_data ['kernel' ], 'b' : layer_data ['bias' ]}
111- output_dict [new_key ] = new_data
112- return output_dict
113-
114-
115- def hk_to_flax (input_dict ):
116- """Maps haiku parameter structure to flax parameter structure."""
117- output_dict = {'params' : {}}
118-
119- for key , layer_data in input_dict .items ():
120- # Extracting the layer number from the key
121- layer_num = key .split ('_' )[- 1 ] # Get the part after the last '_'
122- original_layer_name = f'Dense_{ layer_num } '
123- original_data = {'kernel' : layer_data ['w' ], 'bias' : layer_data ['b' ]}
124- output_dict ['params' ][original_layer_name ] = original_data
125- return output_dict
126-
127-
12891def make_hk_perm_spec (mlp_params ):
12992 """Produces perm spec for a haiku mlp."""
13093 perm_spec = {}
@@ -153,17 +116,57 @@ def make_hk_cnn_perm_spec(mlp_params):
153116 return perm_spec
154117
155118
119+ def build_init_fn (scale , shape ):
120+ return lambda rng , _shape : scale * jax .random .normal (rng , shape )
121+
122+
123+ class PosEmbConv (nn .Module ):
124+ """Add learned position embeddings for spatial dims of conv input."""
125+
126+ @nn .compact
127+ def __call__ (self , inp_features ):
128+ features , tree_def = jtu .tree_flatten (inp_features )
129+ out_features = []
130+ for i , val in enumerate (features ):
131+ if len (val .shape ) == 5 : # conv2d filter: HxWxC1xC2xC
132+ shape = (val .shape [0 ], val .shape [1 ], 1 , 1 , val .shape [- 1 ])
133+ scale = 0.17 # roughly 1 / sqrt(32), to match scale of kernel at init
134+ pos_emb = self .param (f'pos_emb_{ i } ' , build_init_fn (scale , shape ), shape )
135+ out_features .append (pos_emb + val )
136+ else :
137+ out_features .append (val )
138+ out_features = jtu .tree_unflatten (tree_def , out_features )
139+ return out_features
140+
141+
142+ def make_hk_irnn_perm_spec (mlp_params ):
143+ """Tested on RNNLM_lm1bbytes_Patch32_IRNN128_Embed64."""
144+ # -1: vocab, 0: embed, 1: hidden
145+ del mlp_params
146+ perm_spec = {
147+ 'embed' : {'embeddings' : (- 1 , 0 )},
148+ 'irnn/linear' : {'b' : (1 ,), 'w' : (0 , 1 )},
149+ 'irnn/linear_1' : {'b' : (1 ,), 'w' : (1 , 1 )},
150+ 'linear' : {'b' : (- 1 ,), 'w' : (1 , - 1 )},
151+ '~' : {'initial_state_0' : (- 2 , 1 )},
152+ }
153+ return perm_spec
154+
155+
156156class MLPForOpt (nn .Module ):
157157 """MLP for learned opt."""
158158
159159 hidden_channels : int
160160 out_channels : int
161161 num_layers : int
162+ pos_emb : bool = False
162163
163164 def setup (self ):
164165 layers = []
165- for _ in range (self .num_layers - 1 ):
166+ for i in range (self .num_layers - 1 ):
166167 layers .append (nn .Dense (self .hidden_channels ))
168+ if i == 0 and self .pos_emb :
169+ layers .append (PosEmbConv ())
167170 layers .append (jax .nn .relu )
168171 layers .append (nn .Dense (self .out_channels ))
169172 self .mod = nn .Sequential (layers )
@@ -173,38 +176,6 @@ def __call__(self, inp_features):
173176 return jtu .tree_map (self .mod , inp_features )
174177
175178
176- class NFNForOpt (nn .Module ):
177- """NFN for learned opt."""
178-
179- in_channels : int
180- hidden_channels : int
181- out_channels : int
182- num_layers : int
183- pos_enc : bool = True
184- hnet : bool = False
185-
186- def setup (self ):
187- assert not (self .hnet and self .pos_enc ), 'Only one of these can be on.'
188- in_channels , hidden_channels = self .in_channels , self .hidden_channels
189- layer_cls = lambda * args , ** kwargs : nf_layers .NFLinearMlp (
190- * args , ** kwargs , pe_enabled = self .pos_enc
191- )
192- if self .hnet :
193- layer_cls = nf_layers .NFLinearMlpHNet
194- layers = [layer_cls (hidden_channels , in_channels ), nf_layers .nf_relu ]
195- for _ in range (self .num_layers - 2 ):
196- layers .append (layer_cls (hidden_channels , hidden_channels ))
197- layers .append (nf_layers .nf_relu )
198- layers .append (layer_cls (self .out_channels , hidden_channels ))
199- self .mod = nn .Sequential (layers )
200-
201- def __call__ (self , inp_features ):
202- # add batch dimension for nf layers
203- inp_features = nfu .tree_expand_dims (inp_features , 0 )
204- out = flax_to_hk (self .mod (hk_to_flax (inp_features ))[0 ])
205- return nfu .tree_squeeze (out , 0 )
206-
207-
208179class UnivNFNForOpt (nn .Module ):
209180 """Univeral NFN for learned opt."""
210181
@@ -214,6 +185,7 @@ class UnivNFNForOpt(nn.Module):
214185 num_layers : int
215186 perm_spec : Any
216187 ptwise_init : bool = False
188+ pos_emb : bool = False
217189
218190 def setup (self ):
219191 in_channels , hidden_channels = self .in_channels , self .hidden_channels
@@ -224,10 +196,10 @@ def make_layer(out_chan, in_chan):
224196 else :
225197 return universal_layers .NFLinear (out_chan , in_chan , w_init = 'lecun' )
226198
227- layers = [
228- make_layer ( hidden_channels , in_channels ),
229- universal_layers . nf_relu ,
230- ]
199+ layers = [make_layer ( hidden_channels , in_channels )]
200+ if self . pos_emb :
201+ layers . append ( PosEmbConv ())
202+ layers . append ( universal_layers . nf_relu )
231203 for _ in range (self .num_layers - 1 ):
232204 layers .append (make_layer (hidden_channels , hidden_channels ))
233205 layers .append (universal_layers .nf_relu )
@@ -434,10 +406,14 @@ def norm_second_moment(p):
434406class ResidualOptNFN (ResidualOpt ):
435407 """NFN learning a residual on base optimizer."""
436408
437- def __init__ (self , task , step_mult = 0.1 , out_mult = 1e-4 , ptwise_init = False ):
409+ def __init__ (
410+ self , task , step_mult = 0.1 , out_mult = 1e-4 , ptwise_init = False , pos_emb = False
411+ ):
438412 example_params = task .init (jax .random .PRNGKey (0 ))
439413 if 'conv2_d' in example_params :
440414 perm_spec = make_hk_cnn_perm_spec (example_params )
415+ elif 'irnn/linear' in example_params :
416+ perm_spec = make_hk_irnn_perm_spec (example_params )
441417 else :
442418 perm_spec = make_hk_perm_spec (example_params )
443419 network = UnivNFNForOpt (
@@ -447,6 +423,7 @@ def __init__(self, task, step_mult=0.1, out_mult=1e-4, ptwise_init=False):
447423 num_layers = 4 ,
448424 perm_spec = perm_spec ,
449425 ptwise_init = ptwise_init ,
426+ pos_emb = pos_emb ,
450427 )
451428 super ().__init__ (
452429 network , example_params , step_mult = step_mult , out_mult = out_mult
@@ -456,9 +433,11 @@ def __init__(self, task, step_mult=0.1, out_mult=1e-4, ptwise_init=False):
456433@gin .configurable
457434class ResidualOptMLP (ResidualOpt ):
458435
459- def __init__ (self , task , step_mult = 0.1 , out_mult = 1e-4 ):
436+ def __init__ (self , task , step_mult = 0.1 , out_mult = 1e-4 , pos_emb = False ):
460437 example_params = task .init (jax .random .PRNGKey (0 ))
461- network = MLPForOpt (hidden_channels = 32 , out_channels = 1 , num_layers = 4 )
438+ network = MLPForOpt (
439+ hidden_channels = 32 , out_channels = 1 , num_layers = 4 , pos_emb = pos_emb
440+ )
462441 super ().__init__ (
463442 network , example_params , step_mult = step_mult , out_mult = out_mult
464443 )
0 commit comments