Skip to content

Commit b704331

Browse files
committed
feat: added Llama3-like initialization test config
1 parent 4dea496 commit b704331

1 file changed

Lines changed: 59 additions & 0 deletions

File tree

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
# Test config: GPT-2 architecture (Llama3-like hyperparameters) with a
# llama3_like weight initializer applied to the raw model.

# Wraps model_raw and applies the llama3_like initialization scheme.
initialized_model:
  component_key: model
  variant_key: model_initialized
  config:
    model:
      instance_key: model_raw
      pass_type: BY_REFERENCE
    model_initializer:
      component_key: model_initialization
      variant_key: llama3_like
      config:
        # Pulled from the raw model definition below via interpolation.
        num_layers: ${model_raw.config.n_layer}
        n_embd: ${model_raw.config.n_embd}

# Raw (uninitialized) model definition; deliberately small for testing.
model_raw:
  component_key: model
  variant_key: gpt2
  config:
    use_meta_device: true
    use_weight_tying: false
    sample_key: "input_ids"
    poe_type: NOPE
    sequence_length: 128
    prediction_key: "logits"
    vocab_size: 2048 # 2K vocab for testing
    n_layer: 4 # 4 layers for testing
    n_head_q: 32
    n_head_kv: 8
    ffn_hidden: 128 # 128 ffn hidden dim for testing
    n_embd: 256 # 256 embedding dim for testing
    dropout: 0.0
    bias: true
    attention_config:
      qkv_transforms:
        - type_hint: RotaryTransform
          config:
            n_embd: ${model_raw.config.n_embd}
            n_head: ${model_raw.config.n_head_q}
            seq_length_dim: -2
            base_freq: 500000
    attention_implementation: pytorch_flash
    activation_type: swiglu
    attention_norm_config:
      norm_type: pytorch_rms_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1.0e-05
    ffn_norm_config:
      norm_type: pytorch_rms_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1.0e-05
    lm_head_norm_config:
      norm_type: pytorch_rms_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1.0e-05

0 commit comments

Comments
 (0)