@@ -153,6 +153,56 @@ def make_hk_irnn_perm_spec(mlp_params):
153153 return perm_spec
154154
155155
def make_hk_transformer_perm_spec(mlp_params):
  """Make perm spec for a transformer_lm.

  The returned dict has one entry per parameter module; each parameter maps
  to a tuple with one permutation index per axis of that parameter.

  Index convention, for a model with n blocks:
    -1: vocab axis of the input embedding table.
    -2: vocab axis of the output projection.
    2 * i: embed axis entering block i.
    2 * i + 1: hidden axis inside block i (used by both attention and MLP).
    2 * i + 2: embed axis leaving block i (written by `h{i}_mlp/linear_1`).
    2 * n: embed axis seen by the final layer norm and output projection.

  Blocks are discovered from the `transformer/h<i>_...` keys of `mlp_params`,
  so the spec follows the depth of the model it is given. When no block key
  can be found (including `mlp_params=None`), the two-block layout is
  returned.

  Example parameter shapes for a two-block model:
    {'transformer/embed': {'embeddings': (32100, 32)},
     'transformer/h0_attn/key': {'b': (128,), 'w': (32, 128)},
     'transformer/h0_attn/linear': {'b': (32,), 'w': (128, 32)},
     'transformer/h0_attn/query': {'b': (128,), 'w': (32, 128)},
     'transformer/h0_attn/value': {'b': (128,), 'w': (32, 128)},
     'transformer/h0_ln_1': {'offset': (32,), 'scale': (32,)},
     'transformer/h0_ln_2': {'offset': (32,), 'scale': (32,)},
     'transformer/h0_mlp/linear': {'b': (128,), 'w': (32, 128)},
     'transformer/h0_mlp/linear_1': {'b': (32,), 'w': (128, 32)},
     'transformer/h1_attn/key': {'b': (128,), 'w': (32, 128)},
     'transformer/h1_attn/linear': {'b': (32,), 'w': (128, 32)},
     'transformer/h1_attn/query': {'b': (128,), 'w': (32, 128)},
     'transformer/h1_attn/value': {'b': (128,), 'w': (32, 128)},
     'transformer/h1_ln_1': {'offset': (32,), 'scale': (32,)},
     'transformer/h1_ln_2': {'offset': (32,), 'scale': (32,)},
     'transformer/h1_mlp/linear': {'b': (128,), 'w': (32, 128)},
     'transformer/h1_mlp/linear_1': {'b': (32,), 'w': (128, 32)},
     'transformer/h_f': {'offset': (32,), 'scale': (32,)},
     'transformer/linear': {'b': (32100,), 'w': (32, 32100)}}

  Args:
    mlp_params: Mapping keyed by module name (as in the example above). Only
      the keys are inspected, to find which blocks exist; values are ignored.

  Returns:
    Dict from module name to {param name: tuple of permutation indices}.
  """
  block_prefix = 'transformer/h'

  # Collect the block numbers present in the params. 'transformer/h_f' has no
  # digits after the prefix, so the final layer norm is not taken for a block.
  block_ids = set()
  for name in mlp_params or ():
    if not isinstance(name, str) or not name.startswith(block_prefix):
      continue
    digits = name[len(block_prefix):].split('_', 1)[0]
    if digits.isdigit():
      block_ids.add(int(digits))
  if not block_ids:
    # Historical behaviour: the spec used to be hard-coded for two blocks.
    block_ids = {0, 1}

  perm_spec = {'transformer/embed': {'embeddings': (-1, 0)}}
  for position, block in enumerate(sorted(block_ids)):
    embed_in = 2 * position
    hidden = embed_in + 1
    embed_out = embed_in + 2
    scope = f'{block_prefix}{block}_'
    perm_spec.update({
        scope + 'attn/key': {'b': (hidden,), 'w': (embed_in, hidden)},
        scope + 'attn/linear': {'b': (embed_in,), 'w': (hidden, embed_in)},
        scope + 'attn/query': {'b': (hidden,), 'w': (embed_in, hidden)},
        scope + 'attn/value': {'b': (hidden,), 'w': (embed_in, hidden)},
        scope + 'ln_1': {'offset': (embed_in,), 'scale': (embed_in,)},
        scope + 'ln_2': {'offset': (embed_in,), 'scale': (embed_in,)},
        scope + 'mlp/linear': {'b': (hidden,), 'w': (embed_in, hidden)},
        # The second MLP layer is where the block switches to a new embed
        # index, which the next block (or the final layers) then consumes.
        scope + 'mlp/linear_1': {'b': (embed_out,), 'w': (hidden, embed_out)},
    })

  final_embed = 2 * len(block_ids)
  perm_spec['transformer/h_f'] = {
      'offset': (final_embed,),
      'scale': (final_embed,),
  }
  perm_spec['transformer/linear'] = {'b': (-2,), 'w': (final_embed, -2)}
  return perm_spec
204+
205+
156206class MLPForOpt (nn .Module ):
157207 """MLP for learned opt."""
158208
@@ -414,6 +464,8 @@ def __init__(
414464 perm_spec = make_hk_cnn_perm_spec (example_params )
415465 elif 'irnn/linear' in example_params :
416466 perm_spec = make_hk_irnn_perm_spec (example_params )
467+ elif 'transformer/embed' in example_params :
468+ perm_spec = make_hk_transformer_perm_spec (example_params )
417469 else :
418470 perm_spec = make_hk_perm_spec (example_params )
419471 network = UnivNFNForOpt (
0 commit comments