Skip to content

Commit 597524c

Browse files
author
learned_optimization authors
committed
No public description
PiperOrigin-RevId: 593160972
1 parent 35120a4 commit 597524c

2 files changed

Lines changed: 58 additions & 80 deletions

File tree

learned_optimization/research/univ_nfn/learned_opt/learned_opts.py

Lines changed: 56 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from learned_optimization.learned_optimizers import base as lopt_base
3636
from learned_optimization.learned_optimizers import common
3737
from learned_optimization.optimizers import base as opt_base
38-
from learned_optimization.research.univ_nfn.nfn import ff_layers as nf_layers
3938
from learned_optimization.research.univ_nfn.nfn import universal_layers
4039
from learned_optimization.research.univ_nfn.nfn import utils as nfu
4140

@@ -89,42 +88,6 @@ class SimpleOptState(flax.struct.PyTreeNode):
8988
state: Any
9089

9190

92-
def flax_to_hk(input_dict):
  """Maps flax parameter structure to haiku parameter structure.

  Layers are renumbered by their iteration order in `input_dict['params']`;
  the flax layer names themselves are not consulted.

  Example:
  >>> input_dict = {
  ...     'params': {
  ...         'Dense_0': {'kernel': W0, 'bias': b0},
  ...         'Dense_1': {'kernel': W1, 'bias': b1}
  ...     }
  ... }
  >>> flax_to_hk(input_dict)
  {'mlp/~/linear_0': {'w': W0, 'b': b0}, 'mlp/~/linear_1': {'w': W1, 'b': b1}}

  Args:
    input_dict: Mapping with an optional 'params' entry whose values each hold
      a 'kernel' and a 'bias'.

  Returns:
    Dict keyed by 'mlp/~/linear_{i}', each value holding 'w' and 'b'. Empty if
    `input_dict` has no 'params' entry.

  Raises:
    KeyError: If a layer entry lacks 'kernel' or 'bias'.
  """
  params = input_dict.get('params', {})
  return {
      f'mlp/~/linear_{i}': {'w': layer['kernel'], 'b': layer['bias']}
      for i, layer in enumerate(params.values())
  }
113-
114-
115-
def hk_to_flax(input_dict):
  """Maps haiku parameter structure to flax parameter structure.

  Each haiku key is turned into 'Dense_{n}', where `n` is the text after the
  last underscore in the key (e.g. 'mlp/~/linear_3' -> 'Dense_3'). A key with
  no underscore is used whole as `n`.

  Args:
    input_dict: Mapping from haiku layer names to dicts holding 'w' and 'b'.

  Returns:
    `{'params': {...}}` where each layer dict holds 'kernel' and 'bias'.

  Raises:
    KeyError: If a layer entry lacks 'w' or 'b'.
  """
  flax_params = {}
  for key, layer_data in input_dict.items():
    # The layer index is whatever follows the last '_' in the haiku name.
    layer_num = key.split('_')[-1]
    flax_params[f'Dense_{layer_num}'] = {
        'kernel': layer_data['w'],
        'bias': layer_data['b'],
    }
  return {'params': flax_params}
126-
127-
12891
def make_hk_perm_spec(mlp_params):
12992
"""Produces perm spec for a haiku mlp."""
13093
perm_spec = {}
@@ -153,17 +116,57 @@ def make_hk_cnn_perm_spec(mlp_params):
153116
return perm_spec
154117

155118

119+
def build_init_fn(scale, shape):
  """Builds an initializer that samples `scale * N(0, 1)` with a fixed shape.

  Args:
    scale: Multiplier applied to the standard-normal samples.
    shape: Shape of the array the initializer produces.

  Returns:
    A callable `(rng, _shape) -> array`. Its second argument is accepted but
    ignored; the `shape` captured here is always used.
  """

  def init_fn(rng, _shape):
    return scale * jax.random.normal(rng, shape)

  return init_fn
121+
122+
123+
class PosEmbConv(nn.Module):
  """Add learned position embeddings for spatial dims of conv input.

  The input pytree is flattened; every rank-5 leaf gets a learned embedding
  added to it, and every other leaf is passed through unchanged. The output
  has the same tree structure as the input.
  """

  @nn.compact
  def __call__(self, inp_features):
    """Returns `inp_features` with position embeddings added to rank-5 leaves.

    Args:
      inp_features: Pytree whose leaves expose a `.shape`.

    Returns:
      Pytree with the same structure as `inp_features`.
    """
    features, tree_def = jtu.tree_flatten(inp_features)
    out_features = []
    for i, val in enumerate(features):
      if len(val.shape) == 5:  # conv2d filter: HxWxC1xC2xC
        # One embedding per spatial position and feature channel; the two
        # size-1 axes broadcast over the C1 and C2 dims.
        shape = (val.shape[0], val.shape[1], 1, 1, val.shape[-1])
        scale = 0.17  # roughly 1 / sqrt(32), to match scale of kernel at init
        # The parameter name includes the flattened leaf index, so each
        # rank-5 leaf gets its own embedding.
        pos_emb = self.param(f'pos_emb_{i}', build_init_fn(scale, shape), shape)
        out_features.append(pos_emb + val)
      else:
        out_features.append(val)
    return jtu.tree_unflatten(tree_def, out_features)
140+
141+
142+
def make_hk_irnn_perm_spec(mlp_params):
  """Produces perm spec for a haiku IRNN language model.

  Tested on RNNLM_lm1bbytes_Patch32_IRNN128_Embed64.

  Args:
    mlp_params: Unused; the spec is fixed and does not depend on the params.

  Returns:
    Dict mapping haiku module names to per-parameter tuples of axis labels,
    one label per parameter dimension.
  """
  del mlp_params
  # Axis labels: -1: vocab, 0: embed, 1: hidden.
  # NOTE(review): the -2 label on 'initial_state_0' is not explained in the
  # original; presumably a separate non-permuted axis. Confirm.
  return {
      'embed': {'embeddings': (-1, 0)},
      'irnn/linear': {'b': (1,), 'w': (0, 1)},
      'irnn/linear_1': {'b': (1,), 'w': (1, 1)},
      'linear': {'b': (-1,), 'w': (1, -1)},
      '~': {'initial_state_0': (-2, 1)},
  }
154+
155+
156156
class MLPForOpt(nn.Module):
157157
"""MLP for learned opt."""
158158

159159
hidden_channels: int
160160
out_channels: int
161161
num_layers: int
162+
pos_emb: bool = False
162163

163164
def setup(self):
164165
layers = []
165-
for _ in range(self.num_layers - 1):
166+
for i in range(self.num_layers - 1):
166167
layers.append(nn.Dense(self.hidden_channels))
168+
if i == 0 and self.pos_emb:
169+
layers.append(PosEmbConv())
167170
layers.append(jax.nn.relu)
168171
layers.append(nn.Dense(self.out_channels))
169172
self.mod = nn.Sequential(layers)
@@ -173,38 +176,6 @@ def __call__(self, inp_features):
173176
return jtu.tree_map(self.mod, inp_features)
174177

175178

176-
class NFNForOpt(nn.Module):
  """NFN for learned opt.

  Stacks `num_layers` NF linear layers with `nf_relu` between them. Inputs and
  outputs use the haiku parameter layout; they are converted to and from the
  flax layout around the NF layers.
  """

  in_channels: int
  hidden_channels: int
  out_channels: int
  num_layers: int
  # `pos_enc` and `hnet` are mutually exclusive. Since `pos_enc` defaults to
  # True, enabling `hnet` requires passing `pos_enc=False` explicitly.
  pos_enc: bool = True
  hnet: bool = False

  def setup(self):
    """Builds the sequential stack of NF layers."""
    assert not (self.hnet and self.pos_enc), 'Only one of these can be on.'
    in_channels, hidden_channels = self.in_channels, self.hidden_channels
    layer_cls = lambda *args, **kwargs: nf_layers.NFLinearMlp(
        *args, **kwargs, pe_enabled=self.pos_enc
    )
    if self.hnet:
      layer_cls = nf_layers.NFLinearMlpHNet
    # Layer constructors take (out_channels, in_channels).
    layers = [layer_cls(hidden_channels, in_channels), nf_layers.nf_relu]
    for _ in range(self.num_layers - 2):
      layers.append(layer_cls(hidden_channels, hidden_channels))
      layers.append(nf_layers.nf_relu)
    layers.append(layer_cls(self.out_channels, hidden_channels))
    self.mod = nn.Sequential(layers)

  def __call__(self, inp_features):
    """Applies the NFN to haiku-structured features.

    Args:
      inp_features: Haiku-layout parameter tree of features.

    Returns:
      Haiku-layout tree produced by the network, with the temporary leading
      axis removed again.
    """
    # add batch dimension for nf layers
    inp_features = nfu.tree_expand_dims(inp_features, 0)
    # Only element 0 of the network output is kept before converting back.
    out = flax_to_hk(self.mod(hk_to_flax(inp_features))[0])
    return nfu.tree_squeeze(out, 0)
206-
207-
208179
class UnivNFNForOpt(nn.Module):
209180
"""Universal NFN for learned opt."""
210181

@@ -214,6 +185,7 @@ class UnivNFNForOpt(nn.Module):
214185
num_layers: int
215186
perm_spec: Any
216187
ptwise_init: bool = False
188+
pos_emb: bool = False
217189

218190
def setup(self):
219191
in_channels, hidden_channels = self.in_channels, self.hidden_channels
@@ -224,10 +196,10 @@ def make_layer(out_chan, in_chan):
224196
else:
225197
return universal_layers.NFLinear(out_chan, in_chan, w_init='lecun')
226198

227-
layers = [
228-
make_layer(hidden_channels, in_channels),
229-
universal_layers.nf_relu,
230-
]
199+
layers = [make_layer(hidden_channels, in_channels)]
200+
if self.pos_emb:
201+
layers.append(PosEmbConv())
202+
layers.append(universal_layers.nf_relu)
231203
for _ in range(self.num_layers - 1):
232204
layers.append(make_layer(hidden_channels, hidden_channels))
233205
layers.append(universal_layers.nf_relu)
@@ -434,10 +406,14 @@ def norm_second_moment(p):
434406
class ResidualOptNFN(ResidualOpt):
435407
"""NFN learning a residual on base optimizer."""
436408

437-
def __init__(self, task, step_mult=0.1, out_mult=1e-4, ptwise_init=False):
409+
def __init__(
410+
self, task, step_mult=0.1, out_mult=1e-4, ptwise_init=False, pos_emb=False
411+
):
438412
example_params = task.init(jax.random.PRNGKey(0))
439413
if 'conv2_d' in example_params:
440414
perm_spec = make_hk_cnn_perm_spec(example_params)
415+
elif 'irnn/linear' in example_params:
416+
perm_spec = make_hk_irnn_perm_spec(example_params)
441417
else:
442418
perm_spec = make_hk_perm_spec(example_params)
443419
network = UnivNFNForOpt(
@@ -447,6 +423,7 @@ def __init__(self, task, step_mult=0.1, out_mult=1e-4, ptwise_init=False):
447423
num_layers=4,
448424
perm_spec=perm_spec,
449425
ptwise_init=ptwise_init,
426+
pos_emb=pos_emb,
450427
)
451428
super().__init__(
452429
network, example_params, step_mult=step_mult, out_mult=out_mult
@@ -456,9 +433,11 @@ def __init__(self, task, step_mult=0.1, out_mult=1e-4, ptwise_init=False):
456433
@gin.configurable
457434
class ResidualOptMLP(ResidualOpt):
458435

459-
def __init__(self, task, step_mult=0.1, out_mult=1e-4):
436+
def __init__(self, task, step_mult=0.1, out_mult=1e-4, pos_emb=False):
460437
example_params = task.init(jax.random.PRNGKey(0))
461-
network = MLPForOpt(hidden_channels=32, out_channels=1, num_layers=4)
438+
network = MLPForOpt(
439+
hidden_channels=32, out_channels=1, num_layers=4, pos_emb=pos_emb
440+
)
462441
super().__init__(
463442
network, example_params, step_mult=step_mult, out_mult=out_mult
464443
)

learned_optimization/research/univ_nfn/nfn/siren.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from typing import Any, Callable, Optional, Tuple
2020

2121
from flax import linen as nn
22-
from flax import nn as fnn
2322
import jax
2423
from jax import lax
2524
import jax.numpy as jnp
@@ -113,7 +112,7 @@ class ModulatedLayer(nn.Module):
113112
features: int = 32
114113
is_first: bool = False
115114
synthesis_act: Callable = jnp.sin
116-
modulator_act: Callable = fnn.relu
115+
modulator_act: Callable = nn.relu
117116
precision: Any = None
118117
dtype: Any = jnp.float32
119118
w0_first_layer: float = 30.0
@@ -196,7 +195,7 @@ class ModulatedSiren(nn.Module):
196195
output_dim: int = 3
197196
num_layers: int = 5
198197
synthesis_act: Callable = jnp.sin
199-
modulator_act: Callable = fnn.relu
198+
modulator_act: Callable = nn.relu
200199
final_activation: Callable = lambda x: x
201200
w0_first_layer: float = 30.0
202201
dtype: Any = jnp.float32

0 commit comments

Comments
 (0)