Skip to content

Commit c794695

Browse files
committed
perf(gpt2): transpose weights at load time for SIMD-contiguous matmul
Weight matrices are now pre-transposed from [in_dim, out_dim] to [out_dim, in_dim] during safetensors loading. matmul_vec_simd now reads contiguous rows via F32x16::from_slice + mul_add, giving full SIMD utilization (768D = 48 × F32x16). https://claude.ai/code/session_01Y69Vnw751w75iVSBRws7o7
1 parent 929b143 commit c794695

2 files changed

Lines changed: 46 additions & 7 deletions

File tree

src/hpc/gpt2/inference.rs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -364,13 +364,27 @@ fn softmax_simd(x: &mut [f32]) {
364364
/// Matrix-vector multiply: out = input @ weight^T + bias.
365365
/// Weight stored as [out_dim, in_dim] (row-major, pre-transposed at load time).
366366
/// SIMD accelerated for the dot product.
367+
/// Matrix-vector multiply: out[o] = dot(input, weight[o, ..]) + bias[o], i.e. out = input @ weight^T + bias.
368+
/// Weight is PRE-TRANSPOSED to [out_dim, in_dim] for contiguous SIMD access.
369+
/// Each output element reads a contiguous row of in_dim floats.
367370
fn matmul_vec_simd(input: &[f32], weight: &[f32], bias: &[f32], output: &mut [f32], in_dim: usize, out_dim: usize) {
368-
// GPT-2 stores weights as [in_dim, out_dim] (row-major).
369-
// Strided access per output — TODO: transpose at load time for SIMD.
371+
let chunks = in_dim / 16;
372+
let remainder = in_dim % 16;
373+
370374
for o in 0..out_dim {
371-
let mut dot = 0.0f32;
372-
for i in 0..in_dim {
373-
dot += input[i] * weight[i * out_dim + o];
375+
let row_offset = o * in_dim;
376+
let mut acc = F32x16::splat(0.0);
377+
for c in 0..chunks {
378+
let off = c * 16;
379+
let vi = F32x16::from_slice(&input[off..off + 16]);
380+
let vw = F32x16::from_slice(&weight[row_offset + off..row_offset + off + 16]);
381+
acc = vi.mul_add(vw, acc);
382+
}
383+
let mut dot = acc.reduce_sum();
384+
// Scalar tail
385+
let tail_start = chunks * 16;
386+
for i in 0..remainder {
387+
dot += input[tail_start + i] * weight[row_offset + tail_start + i];
374388
}
375389
output[o] = dot + bias[o];
376390
}

src/hpc/gpt2/weights.rs

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,35 @@ impl Gpt2Weights {
111111
});
112112
}
113113

114-
Ok(Gpt2Weights {
114+
let mut weights = Gpt2Weights {
115115
wte, wpe, layers, ln_f_weight, ln_f_bias,
116-
})
116+
};
117+
weights.transpose_weights_for_simd();
118+
Ok(weights)
119+
}
120+
121+
/// Transpose each layer's attn_qkv, attn_out, mlp_fc and mlp_proj weight matrices from [in_dim, out_dim] to [out_dim, in_dim]; no other fields are touched here.
122+
/// After this, matmul can read weight rows contiguously for F32x16 SIMD.
123+
fn transpose_weights_for_simd(&mut self) {
124+
for layer in &mut self.layers {
125+
transpose_matrix(&mut layer.attn_qkv_weight, EMBED_DIM, 3 * EMBED_DIM);
126+
transpose_matrix(&mut layer.attn_out_weight, EMBED_DIM, EMBED_DIM);
127+
transpose_matrix(&mut layer.mlp_fc_weight, EMBED_DIM, MLP_DIM);
128+
transpose_matrix(&mut layer.mlp_proj_weight, MLP_DIM, EMBED_DIM);
129+
}
130+
}
131+
}
132+
133+
/// Transpose a row-major [rows, cols] matrix to [cols, rows], replacing `data` with a newly allocated transposed buffer (not a true in-place transpose).
134+
fn transpose_matrix(data: &mut Vec<f32>, rows: usize, cols: usize) {
135+
assert_eq!(data.len(), rows * cols);
136+
let mut transposed = vec![0.0f32; rows * cols];
137+
for r in 0..rows {
138+
for c in 0..cols {
139+
transposed[c * rows + r] = data[r * cols + c];
140+
}
117141
}
142+
*data = transposed;
118143
}
119144

120145
/// Tensor metadata from safetensors header.

0 commit comments

Comments
 (0)