
Commit a3f84f8

feat: add Qwen3-8B

1 parent b594867

11 files changed

Lines changed: 1260 additions & 0 deletions


CMakeLists.txt

Lines changed: 8 additions & 0 deletions
@@ -190,6 +190,14 @@ add_executable(llama3
 )
 link_infini_train_exe(llama3)

+add_executable(qwen3
+    example/qwen3/main.cc
+    example/common/tiny_shakespeare_dataset.cc
+    example/common/utils.cc
+    example/qwen3/checkpoint_loader.cc
+    example/common/tokenizer.cc
+)
+link_infini_train_exe(qwen3)
 # Tools
 add_subdirectory(tools/infini_run)
 set_target_properties(infini_run PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})

example/qwen3/checkpoint_loader.cc

Lines changed: 440 additions & 0 deletions
Large diffs are not rendered by default.

example/qwen3/checkpoint_loader.h

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
#pragma once

#include <memory>
#include <string>

namespace infini_train::nn {
class TransformerModel;
} // namespace infini_train::nn

namespace qwen3 {
std::shared_ptr<infini_train::nn::TransformerModel> LoadFromLLMC(const std::string &filepath);
} // namespace qwen3
example/qwen3/config.h

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
#pragma once

#include "infini_train/include/nn/modules/transformer/transformer_config.h"

namespace nn = infini_train::nn;
namespace qwen3 {
inline nn::TransformerConfig Qwen3Config() {
    return {.block_size = 40960,
            .vocab_size = 151936,
            .original_vocab_size = 151936,
            .n_layer = 36,
            .n_head = 32,
            .n_kv_head = 8,
            .n_embd = 4096,
            .attention_type = nn::AttentionType::kRoPE,
            .activation_type = nn::MLPType::kSwiGLU,
            .norm_type = nn::NormType::kRMSNorm,
            .add_bias_linear = false,
            .add_bias_lm_head = false,
            .tie_weights = false,
            .ffn_expansion_ratio = 4.5f, // 4096*4.5*2/3 = 12288
            .ffn_dim_multiplier = std::nullopt,
            .multiple_of = 1,
            .rope_theta = 1000000.0f,
            .use_scaled_rope = false,
            .norm_eps = 1e-6f};
}
} // namespace qwen3
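The ffn_expansion_ratio comment above encodes the SwiGLU MLP sizing. A minimal sketch of that arithmetic, assuming the framework derives the hidden dimension as 2/3 * ratio * n_embd rounded up to a multiple of multiple_of (the rounding rule is an assumption; only the 4096*4.5*2/3 = 12288 result is taken from the comment):

    import math

    n_embd = 4096
    ffn_expansion_ratio = 4.5
    multiple_of = 1

    hidden = 2 * ffn_expansion_ratio * n_embd / 3            # SwiGLU 2/3 factor
    hidden = multiple_of * math.ceil(hidden / multiple_of)   # round up to multiple_of
    print(hidden)  # 12288, matching Qwen3-8B's intermediate_size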

example/qwen3/main.cc

Lines changed: 450 additions & 0 deletions
Large diffs are not rendered by default.

infini_train/include/nn/modules/transformer/causal_self_attention.h

Lines changed: 7 additions & 0 deletions
@@ -5,6 +5,7 @@
 #include <vector>

 #include "infini_train/include/nn/modules/module.h"
+#include "infini_train/include/nn/modules/normalization.h"
 #include "infini_train/include/nn/modules/transformer/transformer_config.h"

 namespace infini_train::nn {
@@ -15,6 +16,9 @@ class CausalSelfAttention : public infini_train::nn::CloneableModule<CausalSelfA
     static constexpr char kCAttnLayerName[] = "c_attn";
     static constexpr char kCProjLayerName[] = "c_proj";

+    static constexpr char kQNormLayerName[] = "q_norm";
+    static constexpr char kKNormLayerName[] = "k_norm";
+
     static constexpr char kParamBiasName[] = "bias";

     explicit CausalSelfAttention(const TransformerConfig &config);
@@ -32,6 +36,9 @@ class CausalSelfAttention : public infini_train::nn::CloneableModule<CausalSelfA
     int64_t n_rep_ = 0;
     int64_t head_dim_ = 0;

+    std::shared_ptr<infini_train::nn::RMSNorm> q_norm_;
+    std::shared_ptr<infini_train::nn::RMSNorm> k_norm_;
+
     // Setup method for different attention modes
     void SetupAttention(const TransformerConfig &config);
infini_train/include/nn/modules/transformer/transformer_config.h

Lines changed: 4 additions & 0 deletions
@@ -61,6 +61,10 @@ struct TransformerConfig {
     bool flash = false; // flash attention
     int64_t max_gen_batch_size = 4; // max batch size during inference

+    // ===== Q-K Norm (Qwen3-specific) =====
+    bool use_qk_norm = false;
+    float qk_norm_eps = 1e-6f;
+
     bool UseGQA() const;
     int GetChunkSize() const;
 };
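For reference, a minimal PyTorch sketch (not the framework's code) of what these two fields turn on: an RMSNorm with a [head_dim]-sized weight applied to every query and key head before RoPE, as wired up in causal_self_attention.cc below. The shapes and the rms_norm helper are illustrative only.

    import torch

    def rms_norm(x, weight, eps=1e-6):
        # normalize over the last dim (head_dim), then scale by the learned weight
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

    B, T, H, D = 2, 8, 32, 128                # batch, seq, heads, head_dim
    q = torch.randn(B, T, H, D)
    q_norm_weight = torch.ones(D)             # corresponds to q_norm.weight in the checkpoint
    q = rms_norm(q, q_norm_weight, eps=1e-6)  # same role as q_norm_ in CausalSelfAttention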

infini_train/src/nn/modules/transformer/causal_self_attention.cc

Lines changed: 19 additions & 0 deletions
@@ -21,6 +21,13 @@ namespace infini_train::nn {
 CausalSelfAttention::CausalSelfAttention(const TransformerConfig &config) : CloneableModule(kType), config_(config) {
     SetupAttention(config);

+    if (config_.use_qk_norm) {
+        q_norm_ = std::make_shared<nn::RMSNorm>(head_dim_, config_.qk_norm_eps);
+        k_norm_ = std::make_shared<nn::RMSNorm>(head_dim_, config_.qk_norm_eps);
+        modules_[kQNormLayerName] = q_norm_;
+        modules_[kKNormLayerName] = k_norm_;
+    }
+
     int64_t qkv_dim = (config.n_head + 2 * n_kv_head_) * head_dim_;
     // qkv: ColumnParallel (do not gather output)
     modules_[kCAttnLayerName] = std::make_shared<nn::parallel::ColumnParallelLinear>(
@@ -212,6 +219,18 @@ CausalSelfAttention::ForwardWithRoPE(const std::vector<std::shared_ptr<infini_tr
     // v: (B, T, KV_local, D)
     auto v = qkv->Slice(2, q_size_local + kv_size_local, q_size_local + 2 * kv_size_local)->View({B, T, KV_local, D});

+    if (config_.use_qk_norm) {
+        auto q_shape = q->Dims(); // [B, T, H_local, D]
+        q = q->View({B * T * H_local, D});
+        q = (*q_norm_)({q})[0];
+        q = q->View(q_shape);
+
+        auto k_shape = k->Dims(); // [B, T, KV_local, D]
+        k = k->View({B * T * KV_local, D});
+        k = (*k_norm_)({k})[0];
+        k = k->View(k_shape);
+    }
+
     // -> RoPE on q, k
     // q: (B, T, H_local, D)
     // k: (B, T, KV_local, D)

convert_hf_qwen3_to_llmc.py

Lines changed: 252 additions & 0 deletions
@@ -0,0 +1,252 @@
#!/usr/bin/env python3
"""
Convert HuggingFace Qwen3-8B checkpoint to InfiniTrain LLMC format.

Usage:
    python convert_hf_qwen3_to_llmc.py \
        --hf-path ./Qwen3-8B \
        --output qwen3-8b-fp32.llmc \
        --tp-size 1
"""

import argparse
import struct
import os
import sys

import torch
from safetensors.torch import load_file


# ============================================================
# Qwen3 magic number (different from llama3's 20240803)
# ============================================================
K_QWEN3_MAGIC = 20240804
K_LLMC_FP32_VERSION = 3


def parse_args():
    parser = argparse.ArgumentParser(description="Convert HF Qwen3 to LLMC format")
    parser.add_argument("--hf-path", required=True, help="Path to HF Qwen3 checkpoint dir")
    parser.add_argument("--output", required=True, help="Output LLMC file path")
    parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size (default 1)")
    parser.add_argument("--tp-rank", type=int, default=0, help="Tensor parallel rank (default 0)")
    return parser.parse_args()


def load_hf_weights(hf_path):
    """Load all HF safetensors / pytorch_model.bin into a flat dict."""
    print(f"[1/4] Loading HF weights from {hf_path}...")
    state_dict = {}

    if os.path.exists(os.path.join(hf_path, "model.safetensors.index.json")):
        import json
        with open(os.path.join(hf_path, "model.safetensors.index.json")) as f:
            index = json.load(f)
        loaded_files = set()
        for key, filename in index["weight_map"].items():
            if filename not in loaded_files:
                shard = load_file(os.path.join(hf_path, filename), device="cpu")
                state_dict.update(shard)
                loaded_files.add(filename)
                print(f"    Loaded {filename} ({len(shard)} tensors)")
    elif os.path.exists(os.path.join(hf_path, "model.safetensors")):
        state_dict = load_file(os.path.join(hf_path, "model.safetensors"), device="cpu")
    else:
        # Fallback: pytorch_model.bin
        state_dict = torch.load(
            os.path.join(hf_path, "pytorch_model.bin"), map_location="cpu"
        )

    print(f"    Total tensors: {len(state_dict)}")
    return state_dict


def write_header(f, config, tp_rank, tp_size):
    """Write the 1024-byte LLMC header."""
    header = [0] * 256  # 256 int32 slots

    header[0] = K_QWEN3_MAGIC  # magic
    header[1] = K_LLMC_FP32_VERSION  # version
    header[2] = config["block_size"]
    header[3] = config["vocab_size"]
    header[4] = config["n_layer"]
    header[5] = config["n_head"]
    header[6] = config["n_kv_head"]
    header[7] = config["n_embd"]

    # Qwen3 has no ffn_dim_multiplier, write 0.0
    # Pack as float then unpack as int32 for the header slot
    ffn_mult_bytes = struct.pack("f", 0.0)
    ffn_mult_int = struct.unpack("i", ffn_mult_bytes)[0]
    header[8] = ffn_mult_int

    header[9] = config["multiple_of"]

    norm_eps_bytes = struct.pack("f", config["norm_eps"])
    header[10] = struct.unpack("i", norm_eps_bytes)[0]

    rope_theta_bytes = struct.pack("f", config["rope_theta"])
    header[11] = struct.unpack("i", rope_theta_bytes)[0]

    header[12] = int(config["use_scaled_rope"])
    header[13] = config.get("max_gen_bs", 4)
    header[14] = 1  # version_major
    header[15] = 0  # version_minor

    # Write 256 int32 values
    data = struct.pack(f"{len(header)}i", *header)
    assert len(data) == 1024
    f.write(data)
    print(f"[2/4] Header written (magic={K_QWEN3_MAGIC}, vocab={config['vocab_size']}, layers={config['n_layer']})")


def write_matrix(f, tensor):
    """Write a 2D tensor as fp32 row-major."""
    t = tensor.float().cpu()
    assert t.dim() == 2, f"Expected 2D tensor, got {t.dim()}D: {t.shape}"
    arr = t.contiguous().numpy().flatten().tolist()
    data = struct.pack(f"{len(arr)}f", *arr)
    f.write(data)


def write_vector(f, tensor):
    """Write a 1D tensor as fp32."""
    t = tensor.float().cpu()
    assert t.dim() == 1, f"Expected 1D tensor, got {t.dim()}D: {t.shape}"
    arr = t.contiguous().numpy().tolist()
    data = struct.pack(f"{len(arr)}f", *arr)
    f.write(data)


def shard_rows(tensor, tp_rank, tp_size):
    """Row-parallel shard: split along dim 0."""
    if tp_size == 1:
        return tensor
    chunks = torch.chunk(tensor, tp_size, dim=0)
    return chunks[tp_rank]


def shard_cols(tensor, tp_rank, tp_size):
    """Column-parallel shard: split along dim 1."""
    if tp_size == 1:
        return tensor
    chunks = torch.chunk(tensor, tp_size, dim=1)
    return chunks[tp_rank]


def convert(hf_path, output_path, tp_size=1, tp_rank=0):
    """Main conversion pipeline."""

    # ---- Load HF weights ----
    sd = load_hf_weights(hf_path)

    # ---- Build config from HF ----
    import json
    with open(os.path.join(hf_path, "config.json")) as f:
        hf_config = json.load(f)

    config = {
        "block_size": hf_config.get("max_position_embeddings", 40960),
        "vocab_size": hf_config["vocab_size"],
        "n_layer": hf_config["num_hidden_layers"],
        "n_head": hf_config["num_attention_heads"],
        "n_kv_head": hf_config["num_key_value_heads"],
        "n_embd": hf_config["hidden_size"],
        "norm_eps": hf_config.get("rms_norm_eps", 1e-6),
        "rope_theta": hf_config.get("rope_theta", 1000000.0),
        "use_scaled_rope": False,
        "multiple_of": 1,
        "max_gen_bs": 4,
    }

    head_dim = config["n_embd"] // config["n_head"]
    q_out = config["n_embd"]  # 4096
    kv_out = config["n_kv_head"] * head_dim  # 8 * 128 = 1024
    ffn_hidden = hf_config["intermediate_size"]  # 12288

    print(f"[3/4] Config: {config}")
    print(f"    head_dim={head_dim}, q_out={q_out}, kv_out={kv_out}, ffn_hidden={ffn_hidden}")

    # ---- Write LLMC file ----
    with open(output_path, "wb") as f:
        # Header
        write_header(f, config, tp_rank, tp_size)

        # 1. wte.weight [vocab_size, n_embd] → row-shard for TP
        wte = sd["model.embed_tokens.weight"]
        wte_shard = shard_rows(wte, tp_rank, tp_size)
        write_matrix(f, wte_shard)
        print(f"    wte: {wte.shape} → shard {wte_shard.shape}")

        n_layer = config["n_layer"]

        for i in range(n_layer):
            prefix = f"model.layers.{i}"

            if (i + 1) % 6 == 0 or i == 0:
                print(f"    Layer {i}/{n_layer - 1}...")

            # 2. ln_1.weight (input_layernorm) [n_embd] — full copy
            ln1 = sd[f"{prefix}.input_layernorm.weight"]
            write_vector(f, ln1)

            # ===== New: Q-K Norm =====
            # 3. q_norm.weight [head_dim] — full copy
            q_norm_w = sd[f"{prefix}.self_attn.q_norm.weight"]
            write_vector(f, q_norm_w)

            # 4. k_norm.weight [head_dim] — full copy
            k_norm_w = sd[f"{prefix}.self_attn.k_norm.weight"]
            write_vector(f, k_norm_w)

            # 5. c_attn.weight [q_out + 2*kv_out, n_embd] — row-shard
            # HF: q_proj, k_proj, v_proj are separate → concat
            q_proj = sd[f"{prefix}.self_attn.q_proj.weight"]  # [n_embd, n_embd]
            k_proj = sd[f"{prefix}.self_attn.k_proj.weight"]  # [kv_out, n_embd]
            v_proj = sd[f"{prefix}.self_attn.v_proj.weight"]  # [kv_out, n_embd]
            c_attn = torch.cat([q_proj, k_proj, v_proj], dim=0)  # [q+2kv, n_embd]
            c_attn_shard = shard_rows(c_attn, tp_rank, tp_size)
            write_matrix(f, c_attn_shard)

            # 6. c_proj (attn o_proj) [n_embd, n_embd] — col-shard
            o_proj = sd[f"{prefix}.self_attn.o_proj.weight"]  # [n_embd, n_embd]
            o_proj_shard = shard_cols(o_proj, tp_rank, tp_size)
            write_matrix(f, o_proj_shard)

            # 7. ln_2.weight (post_attention_layernorm) [n_embd] — full copy
            ln2 = sd[f"{prefix}.post_attention_layernorm.weight"]
            write_vector(f, ln2)

            # 8. c_fc (gate_proj) [ffn_hidden, n_embd] — row-shard
            gate_proj = sd[f"{prefix}.mlp.gate_proj.weight"]
            gate_shard = shard_rows(gate_proj, tp_rank, tp_size)
            write_matrix(f, gate_shard)

            # 9. c_fc2 (up_proj) [ffn_hidden, n_embd] — row-shard
            up_proj = sd[f"{prefix}.mlp.up_proj.weight"]
            up_shard = shard_rows(up_proj, tp_rank, tp_size)
            write_matrix(f, up_shard)

            # 10. c_proj (mlp down_proj) [n_embd, ffn_hidden] — col-shard
            down_proj = sd[f"{prefix}.mlp.down_proj.weight"]
            down_shard = shard_cols(down_proj, tp_rank, tp_size)
            write_matrix(f, down_shard)

        # 11. ln_f.weight (model.norm) [n_embd] — full copy
        ln_f = sd["model.norm.weight"]
        write_vector(f, ln_f)

        # 12. lm_head.weight [vocab_size, n_embd] — row-shard for TP
        lm_head = sd["lm_head.weight"]
        lm_head_shard = shard_rows(lm_head, tp_rank, tp_size)
        write_matrix(f, lm_head_shard)
        print(f"    lm_head: {lm_head.shape} → shard {lm_head_shard.shape}")

    file_size = os.path.getsize(output_path)
    print(f"[4/4] Done! Output: {output_path} ({file_size / 1e9:.2f} GB)")


if __name__ == "__main__":
    args = parse_args()
    convert(args.hf_path, args.output, args.tp_size, args.tp_rank)
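A quick sanity check (not part of this commit) that re-reads the 1024-byte header written by write_header above and confirms the magic, version, and model dimensions; qwen3-8b-fp32.llmc is the example --output path from the docstring.

    import struct

    with open("qwen3-8b-fp32.llmc", "rb") as f:
        header = struct.unpack("256i", f.read(1024))

    assert header[0] == 20240804, "bad magic"    # K_QWEN3_MAGIC
    assert header[1] == 3, "bad version"         # K_LLMC_FP32_VERSION
    block_size, vocab, n_layer, n_head, n_kv_head, n_embd = header[2:8]
    norm_eps = struct.unpack("f", struct.pack("i", header[10]))[0]    # float slots are bit-cast
    rope_theta = struct.unpack("f", struct.pack("i", header[11]))[0]
    print(n_layer, n_head, n_kv_head, n_embd, norm_eps, rope_theta)
    # expected for Qwen3-8B: 36 32 8 4096 ~1e-6 1000000.0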

scripts/convert_input_to_bin.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
from transformers import AutoTokenizer
import struct

tokenizer = AutoTokenizer.from_pretrained("/var/qy_home/jiyiming/Qwen3-8B")

with open("input.txt", "r") as f:
    text = f.read()

ids = tokenizer.encode(text)

with open("tiny_shakespeare_qwen3.bin", "wb") as f:
    for tid in ids:
        f.write(struct.pack("I", tid))

print(f"Wrote {len(ids)} tokens")
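An optional round-trip check (not part of this commit): read the uint32 token ids back and decode a short prefix with the same tokenizer; the hard-coded paths above are assumed to exist locally.

    import struct
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("/var/qy_home/jiyiming/Qwen3-8B")

    with open("tiny_shakespeare_qwen3.bin", "rb") as f:
        data = f.read()

    ids = list(struct.unpack(f"{len(data) // 4}I", data))
    print(len(ids), tokenizer.decode(ids[:32]))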
