Skip to content

Commit d0fb6ea

Browse files
committed
refactor: simplify code
1 parent 0c2b184 commit d0fb6ea

20 files changed

Lines changed: 579 additions & 700 deletions

Cargo.lock

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ memmap2 = "0.9"
99
rayon = "1.10"
1010
tokenizers = { version = "0.22", features = ["http"] }
1111
safetensors = "0.7.0"
12+
serde_json = "1"
1213

1314
[[bin]]
1415
name = "rustllm"

src/components/activation.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ mod tests {
6969
let mut x = vec![1000.0_f32, 1001.0];
7070
softmax(&mut x);
7171
// softmax([1000, 1001]) == softmax([-1, 0]) == [e^-1/(e^-1+1), 1/(e^-1+1)]
72-
let expected = vec![
72+
let expected = [
7373
1.0_f32.exp().recip() / (1.0_f32.exp().recip() + 1.0),
7474
1.0 / (1.0_f32.exp().recip() + 1.0),
7575
];

src/components/attention.rs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ fn dot(x: &[f32], y: &[f32]) -> f32 {
44
x.iter().zip(y.iter()).map(|(a, b)| a * b).sum()
55
}
66

7+
#[allow(clippy::too_many_arguments)]
78
pub fn group_query_attention(
89
attn_out: &mut [f32],
910
query: &[f32],
@@ -38,24 +39,24 @@ pub fn group_query_attention(
3839
// score against every cached key at positions 0..=pos
3940
let mut scores = vec![0.0_f32; pos + 1];
4041

41-
for t in 0..=pos {
42+
for (t, score) in scores.iter_mut().enumerate() {
4243
// index into key_cache for layer, position t, kv head kv_h
4344
let kv_off = layer_idx * max_seq_len * kv_dim + t * kv_dim + kv_h * head_dim;
4445
let k_t = &key_cache[kv_off..kv_off + head_dim];
45-
scores[t] = dot(q_head, k_t) / (head_dim as f32).sqrt();
46+
*score = dot(q_head, k_t) / (head_dim as f32).sqrt();
4647
}
4748

4849
softmax(&mut scores);
4950

5051
// weighted sum of values
5152
let out_head = &mut attn_out[h * head_dim..(h + 1) * head_dim];
5253

53-
for t in 0..=pos {
54+
for (t, &score) in scores.iter().enumerate() {
5455
let kv_off = layer_idx * max_seq_len * kv_dim + t * kv_dim + kv_h * head_dim;
5556
let v_t = &value_cache[kv_off..kv_off + head_dim];
5657

5758
for j in 0..head_dim {
58-
out_head[j] += scores[t] * v_t[j];
59+
out_head[j] += score * v_t[j];
5960
}
6061
}
6162
}
@@ -73,16 +74,17 @@ mod tests {
7374
let head_dim = 2;
7475
let max_seq_len = 4;
7576
let n_embed = n_heads * head_dim; // 4
77+
let n_layers = 1;
7678

7779
// query: [head0: 1,0, head1: 0,1]
7880
let query = vec![1.0, 0.0, 0.0, 1.0];
7981

8082
// key_cache: 1 layer × 4 positions × kv_dim(2). Fill pos=0 with [1, 0]
81-
let mut key_cache = vec![0.0_f32; 1 * max_seq_len * n_kv_heads * head_dim];
83+
let mut key_cache = vec![0.0_f32; n_layers * max_seq_len * n_kv_heads * head_dim];
8284
key_cache[0] = 1.0; // pos 0, kv_head 0, dim 0
8385

8486
// value_cache: same layout. Fill pos=0 with [3, 7]
85-
let mut value_cache = vec![0.0_f32; 1 * max_seq_len * n_kv_heads * head_dim];
87+
let mut value_cache = vec![0.0_f32; n_layers * max_seq_len * n_kv_heads * head_dim];
8688
value_cache[0] = 3.0;
8789
value_cache[1] = 7.0;
8890

@@ -115,20 +117,21 @@ mod tests {
115117
let n_kv_heads = 1;
116118
let head_dim = 2;
117119
let max_seq_len = 4;
120+
let n_layers = 1;
118121

119122
// query head: [1, 0]
120123
let query = vec![1.0_f32, 0.0];
121124

122125
// key at pos=0: [1, 0] → dot with query = 1.0 → score = 1/sqrt(2)
123126
// key at pos=1: [0, 1] → dot with query = 0.0 → score = 0/sqrt(2)
124-
let mut key_cache = vec![0.0_f32; 1 * max_seq_len * n_kv_heads * head_dim];
127+
let mut key_cache = vec![0.0_f32; n_layers * max_seq_len * n_kv_heads * head_dim];
125128
key_cache[0] = 1.0; // pos=0, dim=0
126129
key_cache[1] = 0.0; // pos=0, dim=1
127130
key_cache[2] = 0.0; // pos=1, dim=0
128131
key_cache[3] = 1.0; // pos=1, dim=1
129132

130133
// value at pos=0: [10, 0], pos=1: [0, 10]
131-
let mut value_cache = vec![0.0_f32; 1 * max_seq_len * n_kv_heads * head_dim];
134+
let mut value_cache = vec![0.0_f32; n_layers * max_seq_len * n_kv_heads * head_dim];
132135
value_cache[0] = 10.0;
133136
value_cache[3] = 10.0;
134137

@@ -153,7 +156,7 @@ mod tests {
153156
let total = s + 1.0; // sum(exp(xi))
154157
let w0 = s / total; // softmax weight for pos=0
155158
let w1 = 1.0 / total; // softmax weight for pos=1
156-
let expected = vec![w0 * 10.0, w1 * 10.0];
159+
let expected = [w0 * 10.0, w1 * 10.0];
157160

158161
for (got, exp) in attn_out.iter().zip(expected.iter()) {
159162
assert!((got - exp).abs() < 1e-5, "got {got}, expected {exp}");

src/components/matmul.rs

Lines changed: 45 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,40 @@
1-
use crate::components::weight::Weight;
1+
use crate::{components::quant::vec_dot, format::gguf::GgufType, tensor::Tensor};
22
use rayon::prelude::*;
33

4-
// matrix-vector multiplication
5-
// x shape: (in_channels,)
6-
// w shape: (out_channels, in_channels) stored in row-major order (like safetensors)
7-
// out shape: (out_channels,)
8-
#[allow(dead_code)]
9-
pub fn naive_matmul<W: Weight>(out: &mut [f32], x: &[f32], weight: &[W]) {
10-
for i in 0..out.len() {
11-
let mut sum = 0.0_f32;
12-
13-
for k in 0..x.len() {
14-
sum += x[k] * weight[k + i * x.len()].to_f32();
15-
}
16-
out[i] = sum;
4+
pub fn bf16_to_f32(n: u16) -> f32 {
5+
f32::from_bits((n as u32) << 16)
6+
}
7+
8+
pub trait FloatType: Copy + Sync + Send {
9+
fn to_f32(self) -> f32;
10+
}
11+
12+
impl FloatType for f32 {
13+
fn to_f32(self) -> f32 {
14+
self
1715
}
1816
}
1917

20-
// parallel matmul
21-
pub fn matmul<W: Weight>(out: &mut [f32], x: &[f32], weight: &[W]) {
18+
impl FloatType for u16 {
19+
// f32: [1 sign] [8 exponent] [23 mantissa] = 32 bits
20+
// bf16: [1 sign] [8 exponent] [ 7 mantissa] = 16 bits
21+
// To convert BF16 → f32, put these 16 bits in the upper 16 bits of a 32-bit word
22+
// and zero-fill the bottom
23+
fn to_f32(self) -> f32 {
24+
bf16_to_f32(self)
25+
}
26+
}
27+
28+
pub fn matmul_gguf(out: &mut [f32], x: &[f32], weight: &[u8], dtype: GgufType, n_cols: usize) {
29+
let row_bytes = dtype.row_bytes(n_cols);
30+
31+
out.par_iter_mut().enumerate().for_each(|(i, o)| {
32+
let row = &weight[i * row_bytes..(i + 1) * row_bytes];
33+
*o = vec_dot(row, x, dtype);
34+
});
35+
}
36+
37+
pub fn matmul_float<W: FloatType>(out: &mut [f32], x: &[f32], weight: &[W]) {
2238
out.par_iter_mut().enumerate().for_each(|(i, o)| {
2339
// i = row index, o = &mut f32 (that output element)
2440
let in_channels: usize = x.len();
@@ -30,6 +46,15 @@ pub fn matmul<W: Weight>(out: &mut [f32], x: &[f32], weight: &[W]) {
3046
});
3147
}
3248

49+
/// Matrix-vector multiply
50+
pub fn matmul(out: &mut [f32], x: &[f32], weight: Tensor) {
51+
match weight {
52+
Tensor::F32(w) => matmul_float::<f32>(out, x, w),
53+
Tensor::BF16(w) => matmul_float::<u16>(out, x, w),
54+
Tensor::Quantized { data, dtype } => matmul_gguf(out, x, data, dtype, x.len()),
55+
}
56+
}
57+
3358
#[cfg(test)]
3459
mod tests {
3560
use super::*;
@@ -41,40 +66,31 @@ mod tests {
4166

4267
#[test]
4368
fn test_matmul_2x3() {
44-
// weight: 2 rows × 3 cols (in_channels=3, out has 2 elements)
45-
// row 0: [1, 2, 3]
46-
// row 1: [4, 5, 6]
4769
let weight = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
4870
let x = vec![1.0_f32, 2.0, 3.0];
4971
let mut out = vec![0.0_f32; 2];
50-
let mut parallel_out = vec![0.0_f32; 2];
5172

52-
naive_matmul(&mut out, &x, &weight);
53-
matmul(&mut parallel_out, &x, &weight);
73+
matmul(&mut out, &x, Tensor::F32(&weight));
5474

5575
assert!(approx(out[0], 14.0));
5676
assert!(approx(out[1], 32.0));
57-
58-
// check parallel output
59-
assert!(approx(parallel_out[0], 14.0));
60-
assert!(approx(parallel_out[1], 32.0));
6177
}
6278

6379
#[test]
6480
fn test_matmul_bf16() {
65-
// Helper: convert f32 to bf16 (top 16 bits)
6681
fn f32_to_bf16(x: f32) -> u16 {
6782
(x.to_bits() >> 16) as u16
6883
}
6984

70-
let weight: Vec<u16> = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0]
85+
let weight: Vec<u16> = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0]
7186
.iter()
7287
.map(|&v| f32_to_bf16(v))
7388
.collect();
7489
let x = vec![1.0_f32, 2.0, 3.0];
7590
let mut out = vec![0.0_f32; 2];
7691

77-
matmul(&mut out, &x, &weight);
92+
matmul(&mut out, &x, Tensor::BF16(&weight));
93+
7894
assert!(approx(out[0], 14.0));
7995
assert!(approx(out[1], 32.0));
8096
}

src/components/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,3 @@ pub mod norm;
55
pub mod quant;
66
pub mod rotary_embedding;
77
pub mod sampler;
8-
pub mod weight;

src/components/norm.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1-
use crate::components::weight::Weight;
2-
31
// rmsnorm (in-place)
4-
pub fn rmsnorm<W: Weight>(out: &mut [f32], x: &[f32], weight: &[W]) {
2+
pub fn rmsnorm(out: &mut [f32], x: &[f32], weight: &[f32]) {
53
let sum: f32 = x.iter().map(|&v| v * v).sum();
64

75
// x.len() is usize, that's why we cast it to f32
86
let mean_sq = sum / x.len() as f32;
97
let scale = 1.0 / (mean_sq + 1e-5_f32).sqrt();
108

119
for i in 0..out.len() {
12-
out[i] = x[i] * scale * weight[i].to_f32();
10+
out[i] = x[i] * scale * weight[i];
1311
}
1412
}
1513

@@ -26,7 +24,7 @@ mod tests {
2624
fn test_rmsnorm_uniform_weight() {
2725
let x = vec![1.0_f32, 2.0, 3.0, 4.0];
2826
let weight = vec![1.0_f32; 4];
29-
let expected = vec![0.36514813_f32, 0.73029625, 1.09544444, 1.46059251];
27+
let expected = [0.36514813_f32, 0.73029625, 1.095_444_4, 1.460_592_5];
3028
let mut out = vec![0.0_f32; 4];
3129

3230
rmsnorm(&mut out, &x, &weight);
@@ -40,7 +38,7 @@ mod tests {
4038
fn test_rmsnorm_nonuniform_weight() {
4139
let x = vec![1.0_f32, 2.0, 3.0, 4.0];
4240
let weight = vec![0.5_f32, 1.0, 2.0, 0.5];
43-
let expected = vec![0.18257406_f32, 0.73029625, 2.19088888, 0.73029625];
41+
let expected = [0.18257406_f32, 0.73029625, 2.190_889, 0.73029625];
4442
let mut out = vec![0.0_f32; 4];
4543

4644
rmsnorm(&mut out, &x, &weight);

src/components/quant.rs

Lines changed: 7 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
2-
31
use crate::format::gguf::GgufType;
42

53
fn fp16_to_fp32(number: u16) -> f32 {
@@ -191,7 +189,7 @@ pub fn dequantize_row_q6_k(src: &[u8], dst: &mut [f32]) {
191189

192190
for l in 0..32 {
193191
// Reconstruct 6-bit signed quant: (low4 | high2 << 4) - 32
194-
let q1 = (ql[ql_off + l] & 0xF) as i32 | (((qh[qh_off + l] >> 0) & 3) as i32) << 4;
192+
let q1 = (ql[ql_off + l] & 0xF) as i32 | ((qh[qh_off + l] & 3) as i32) << 4;
195193
let q2 =
196194
(ql[ql_off + l + 32] & 0xF) as i32 | (((qh[qh_off + l] >> 2) & 3) as i32) << 4;
197195
let q3 = (ql[ql_off + l] >> 4) as i32 | (((qh[qh_off + l] >> 4) & 3) as i32) << 4;
@@ -234,7 +232,7 @@ pub fn vec_dot_q6_k(src: &[u8], x: &[f32]) -> f32 {
234232
let xp_off = xp + chunk * 128;
235233

236234
for l in 0..16 {
237-
let q1 = (ql[ql_off + l] & 0xF) as i32 | (((qh[qh_off + l] >> 0) & 3) as i32) << 4;
235+
let q1 = (ql[ql_off + l] & 0xF) as i32 | ((qh[qh_off + l] & 3) as i32) << 4;
238236
let q2 =
239237
(ql[ql_off + l + 32] & 0xF) as i32 | (((qh[qh_off + l] >> 2) & 3) as i32) << 4;
240238
let q3 = (ql[ql_off + l] >> 4) as i32 | (((qh[qh_off + l] >> 4) & 3) as i32) << 4;
@@ -248,7 +246,7 @@ pub fn vec_dot_q6_k(src: &[u8], x: &[f32]) -> f32 {
248246
}
249247

250248
for l in 16..32 {
251-
let q1 = (ql[ql_off + l] & 0xF) as i32 | (((qh[qh_off + l] >> 0) & 3) as i32) << 4;
249+
let q1 = (ql[ql_off + l] & 0xF) as i32 | ((qh[qh_off + l] & 3) as i32) << 4;
252250
let q2 =
253251
(ql[ql_off + l + 32] & 0xF) as i32 | (((qh[qh_off + l] >> 2) & 3) as i32) << 4;
254252
let q3 = (ql[ql_off + l] >> 4) as i32 | (((qh[qh_off + l] >> 4) & 3) as i32) << 4;
@@ -272,16 +270,16 @@ pub fn vec_dot_q6_k(src: &[u8], x: &[f32]) -> f32 {
272270

273271
pub fn vec_dot(src: &[u8], x: &[f32], dtype: GgufType) -> f32 {
274272
match dtype {
275-
GgufType::Q4_K => vec_dot_q4_k(src, x),
276-
GgufType::Q6_K => vec_dot_q6_k(src, x),
273+
GgufType::Q4K => vec_dot_q4_k(src, x),
274+
GgufType::Q6K => vec_dot_q6_k(src, x),
277275
_ => panic!("not implemented"),
278276
}
279277
}
280278

281279
pub fn dequantize_row(src: &[u8], dst: &mut [f32], dtype: GgufType) {
282280
match dtype {
283-
GgufType::Q4_K => dequantize_row_q4_k(src, dst),
284-
GgufType::Q6_K => dequantize_row_q6_k(src, dst),
281+
GgufType::Q4K => dequantize_row_q4_k(src, dst),
282+
GgufType::Q6K => dequantize_row_q6_k(src, dst),
285283
GgufType::F32 => {
286284
// reinterpret bytes as f32 and copy
287285
let floats =
@@ -296,31 +294,3 @@ pub fn dequantize_row(src: &[u8], dst: &mut [f32], dtype: GgufType) {
296294
_ => panic!("dequantize_row: unsupported dtype {:?}", dtype),
297295
}
298296
}
299-
300-
pub fn matmul_gguf(out: &mut [f32], x: &[f32], weight: &[u8], dtype: GgufType, n_cols: usize) {
301-
let row_bytes = dtype.row_bytes(n_cols);
302-
303-
out.par_iter_mut().enumerate().for_each(|(i, o)| {
304-
let row = &weight[i * row_bytes..(i + 1) * row_bytes];
305-
*o = vec_dot(row, x, dtype);
306-
});
307-
}
308-
309-
/// Diagnostic: dequantize each row to f32 first, then do a plain dot product.
310-
/// Same result as matmul_gguf if vec_dot is correct. Use to isolate vec_dot bugs.
311-
pub fn matmul_gguf_naive(
312-
out: &mut [f32],
313-
x: &[f32],
314-
weight: &[u8],
315-
dtype: GgufType,
316-
n_cols: usize,
317-
) {
318-
let row_bytes = dtype.row_bytes(n_cols);
319-
let mut tmp = vec![0.0f32; n_cols];
320-
321-
for (i, o) in out.iter_mut().enumerate() {
322-
let row = &weight[i * row_bytes..(i + 1) * row_bytes];
323-
dequantize_row(row, &mut tmp, dtype);
324-
*o = tmp.iter().zip(x.iter()).map(|(a, b)| a * b).sum();
325-
}
326-
}

src/components/rotary_embedding.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,8 @@ mod tests {
143143
// Split-half pairs: (x[0],x[2]) with cos=1,sin=0 and (x[1],x[3]) with cos=0,sin=1
144144
// q: (3,1)→(3*1-1*0, 3*0+1*1)=(3,1), (4,2)→(4*0-2*1, 4*1+2*0)=(-2,4)
145145
// result: [3, -2, 1, 4]
146-
let expected_query = vec![3.0_f32, -2.0, 1.0, 4.0];
147-
let expected_key = vec![5.0_f32, -8.0, 7.0, 6.0];
146+
let expected_query = [3.0_f32, -2.0, 1.0, 4.0];
147+
let expected_key = [5.0_f32, -8.0, 7.0, 6.0];
148148

149149
// check query
150150
for (got, exp) in query.iter().zip(expected_query.iter()) {

0 commit comments

Comments
 (0)