Skip to content

Commit a7322af

Browse files
author
learned_optimization authors
committed
No public description
PiperOrigin-RevId: 593193056
1 parent 597524c commit a7322af

1 file changed

Lines changed: 52 additions & 0 deletions

File tree

learned_optimization/research/univ_nfn/learned_opt/learned_opts.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,56 @@ def make_hk_irnn_perm_spec(mlp_params):
153153
return perm_spec
154154

155155

156+
def make_hk_transformer_perm_spec(mlp_params):
  """Make perm spec for a transformer_lm.

  The spec maps every parameter to a tuple with one permutation index per
  axis. Index convention (for a two-block model):
    -1, -2: vocab (input embedding / output projection),
    0: embed, 1: hidden, 2: embed_2, 3: hidden_2, 4: embed_3.
  In general, block `i` reads embed index `2 * i`, uses hidden index
  `2 * i + 1` for its attention and MLP hidden axes, and its second MLP
  layer writes embed index `2 * i + 2`.

  Example (parameter shapes of a two-block model):
  {'transformer/embed': {'embeddings': (32100, 32)},
   'transformer/h0_attn/key': {'b': (128,), 'w': (32, 128)},
   'transformer/h0_attn/linear': {'b': (32,), 'w': (128, 32)},
   'transformer/h0_attn/query': {'b': (128,), 'w': (32, 128)},
   'transformer/h0_attn/value': {'b': (128,), 'w': (32, 128)},
   'transformer/h0_ln_1': {'offset': (32,), 'scale': (32,)},
   'transformer/h0_ln_2': {'offset': (32,), 'scale': (32,)},
   'transformer/h0_mlp/linear': {'b': (128,), 'w': (32, 128)},
   'transformer/h0_mlp/linear_1': {'b': (32,), 'w': (128, 32)},
   'transformer/h1_attn/key': {'b': (128,), 'w': (32, 128)},
   'transformer/h1_attn/linear': {'b': (32,), 'w': (128, 32)},
   'transformer/h1_attn/query': {'b': (128,), 'w': (32, 128)},
   'transformer/h1_attn/value': {'b': (128,), 'w': (32, 128)},
   'transformer/h1_ln_1': {'offset': (32,), 'scale': (32,)},
   'transformer/h1_ln_2': {'offset': (32,), 'scale': (32,)},
   'transformer/h1_mlp/linear': {'b': (128,), 'w': (32, 128)},
   'transformer/h1_mlp/linear_1': {'b': (32,), 'w': (128, 32)},
   'transformer/h_f': {'offset': (32,), 'scale': (32,)},
   'transformer/linear': {'b': (32100,), 'w': (32, 32100)}}

  Args:
    mlp_params: mapping keyed by parameter-module name. Only the keys are
      inspected: names of the form 'transformer/h<N>_...' determine how many
      transformer blocks the spec covers. May be None or contain no such keys,
      in which case the two-block layout shown above is produced.

  Returns:
    A dict with the same module/parameter names as the transformer, whose
    leaves are tuples of permutation indices (one per parameter axis).
  """
  import re  # Local import: keeps this block self-contained.

  block_pattern = re.compile(r'transformer/h(\d+)_')
  block_ids = set()
  if mlp_params is not None:
    for module_name in mlp_params:
      found = block_pattern.match(str(module_name))
      if found:
        block_ids.add(int(found.group(1)))
  # Fall back to two blocks (the layout this function always returned before
  # it inspected `mlp_params`) when no block keys are present.
  num_blocks = max(block_ids) + 1 if block_ids else 2

  perm_spec = {'transformer/embed': {'embeddings': (-1, 0)}}
  for block in range(num_blocks):
    embed_in = 2 * block
    hidden = embed_in + 1
    embed_out = embed_in + 2
    prefix = 'transformer/h%d_' % block
    # Key/query/value and the first MLP layer all map embed -> hidden.
    perm_spec[prefix + 'attn/key'] = {'b': (hidden,), 'w': (embed_in, hidden)}
    perm_spec[prefix + 'attn/linear'] = {
        'b': (embed_in,),
        'w': (hidden, embed_in),
    }
    perm_spec[prefix + 'attn/query'] = {
        'b': (hidden,),
        'w': (embed_in, hidden),
    }
    perm_spec[prefix + 'attn/value'] = {
        'b': (hidden,),
        'w': (embed_in, hidden),
    }
    perm_spec[prefix + 'ln_1'] = {'offset': (embed_in,), 'scale': (embed_in,)}
    perm_spec[prefix + 'ln_2'] = {'offset': (embed_in,), 'scale': (embed_in,)}
    perm_spec[prefix + 'mlp/linear'] = {
        'b': (hidden,),
        'w': (embed_in, hidden),
    }
    # The second MLP layer introduces the next block's embed index.
    perm_spec[prefix + 'mlp/linear_1'] = {
        'b': (embed_out,),
        'w': (hidden, embed_out),
    }

  final_embed = 2 * num_blocks
  perm_spec['transformer/h_f'] = {
      'offset': (final_embed,),
      'scale': (final_embed,),
  }
  perm_spec['transformer/linear'] = {'b': (-2,), 'w': (final_embed, -2)}
  return perm_spec
204+
205+
156206
class MLPForOpt(nn.Module):
157207
"""MLP for learned opt."""
158208

@@ -414,6 +464,8 @@ def __init__(
414464
perm_spec = make_hk_cnn_perm_spec(example_params)
415465
elif 'irnn/linear' in example_params:
416466
perm_spec = make_hk_irnn_perm_spec(example_params)
467+
elif 'transformer/embed' in example_params:
468+
perm_spec = make_hk_transformer_perm_spec(example_params)
417469
else:
418470
perm_spec = make_hk_perm_spec(example_params)
419471
network = UnivNFNForOpt(

0 commit comments

Comments
 (0)