-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsimd_amx.rs
More file actions
413 lines (373 loc) · 15.4 KB
/
Copy pathsimd_amx.rs
File metadata and controls
413 lines (373 loc) · 15.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
//! AMX (Advanced Matrix Extensions) — confirmed working via inline asm on stable Rust 1.94.
//!
//! AMX provides hardware tile matrix multiplication:
//! TDPBUSD: 16×16 tile of u8×i8 → i32 = 256 MACs per instruction
//! TDPBF16PS: 16×16 tile of BF16×BF16 → f32
//!
//! Status: HARDWARE CONFIRMED + OS ENABLED + INLINE ASM TESTED
//! AMX-TILE: ✓ (LDTILECFG, TILEZERO, TILERELEASE all work)
//! AMX-INT8: ✓ (TDPBUSD available)
//! AMX-BF16: ✓ (TDPBF16PS available)
//! Kernel: 6.18.5 (XCR0 bits 17+18 set)
//!
//! Rust intrinsics: NIGHTLY ONLY (issue #126622)
//! Inline asm: STABLE (works on Rust 1.94, tested)
//!
//! Inline asm encoding (verified working):
//! LDTILECFG: asm!("ldtilecfg [{}]", in(reg) ptr, options(nostack))
//! TILEZERO t0: asm!(".byte 0xc4, 0xe2, 0x7b, 0x49, 0xc0", options(nostack, nomem))
//! TILERELEASE: asm!(".byte 0xc4, 0xe2, 0x78, 0x49, 0xc0", options(nostack, nomem))
//!
//! ThinkingEngine tiers:
//! AMX: 256 MACs/instr ~44 μs/cycle (via inline asm, stable)
//! VNNI: 64 MACs/instr ~175 μs/cycle (stable intrinsics)
//! F32x16: 16 MACs/instr ~400 μs/cycle (stable)
//! F64x8: 8 MACs/instr ~700 μs/cycle (stable)
//!
//! Codebook distance table build: AMX reduces 24-48h → ~1:20h.
// ═══════════════════════════════════════════════════════════════════════════
// Detection (stable — just CPUID, no AMX instructions)
// ═══════════════════════════════════════════════════════════════════════════
/// Check if AMX hardware is present AND OS-enabled.
///
/// Two checks required:
/// 1. CPUID.07H.0H:EDX bits 24 (AMX-TILE) + 25 (AMX-INT8) = CPU supports it
/// 2. XCR0 bits 17 (TILECFG) + 18 (TILEDATA) = OS has enabled tile state
///
/// The XCR0 check is critical: even if CPUID reports AMX, the hypervisor
/// may not have enabled the XSTATE for tiles. Without OS enablement,
/// LDTILECFG will SIGILL.
///
/// Previous bug: used CPUID leaf 0xD (reports what CPU supports for XSAVE)
/// instead of _xgetbv(0) (reports what OS actually enabled). The old check
/// could return true on a hypervisor that advertises AMX in CPUID but
/// hasn't set XCR0 bits 17+18.
#[cfg(target_arch = "x86_64")]
pub fn amx_available() -> bool {
// Step 1: CPU supports AMX-TILE + AMX-INT8?
let cpuid = core::arch::x86_64::__cpuid_count(7, 0);
let amx_tile = (cpuid.edx >> 24) & 1;
let amx_int8 = (cpuid.edx >> 25) & 1;
if amx_tile == 0 || amx_int8 == 0 { return false; }
// Step 2: OS enabled XSAVE? (CPUID.01H:ECX bit 27 = OSXSAVE)
let cpuid_01 = core::arch::x86_64::__cpuid(1);
let osxsave = (cpuid_01.ecx >> 27) & 1;
if osxsave == 0 { return false; }
// Step 3: OS actually enabled tile state in XCR0?
// _xgetbv(0) reads the ACTUAL XCR0 register (what the OS set),
// not the CPUID-reported capability.
// Bit 17 = TILECFG, Bit 18 = TILEDATA. Both must be set.
let xcr0: u64 = unsafe { core::arch::x86_64::_xgetbv(0) };
let tilecfg = (xcr0 >> 17) & 1;
let tiledata = (xcr0 >> 18) & 1;
if tilecfg == 0 || tiledata == 0 { return false; }
// Step 4: Request XCOMP_PERM for TILEDATA.
// Linux kernel 5.19+: processes must call prctl(ARCH_REQ_XCOMP_PERM, 18)
// to request permission for TILEDATA (XFEATURE 18) before using AMX.
// Without this, LDTILECFG will SIGILL even if XCR0 bits are set.
// The prctl either succeeds (0) or fails (-1) — idempotent, safe to call
// multiple times.
#[cfg(target_os = "linux")]
{
const SYS_PRCTL: i64 = 157; // x86_64 syscall number for prctl
const ARCH_REQ_XCOMP_PERM: i64 = 0x1023;
const XFEATURE_XTILEDATA: i64 = 18;
// SAFETY: syscall(prctl, ARCH_REQ_XCOMP_PERM, 18) is a simple permission
// request. It either grants tile permission (returns 0) or fails (returns
// -errno). No side effects on failure. Idempotent.
let ret: i64;
unsafe {
core::arch::asm!(
"syscall",
inlateout("rax") SYS_PRCTL => ret,
in("rdi") ARCH_REQ_XCOMP_PERM,
in("rsi") XFEATURE_XTILEDATA,
in("rdx") 0i64,
in("r10") 0i64,
in("r8") 0i64,
lateout("rcx") _,
lateout("r11") _,
options(nostack),
);
}
if ret != 0 { return false; }
}
true
}
#[cfg(not(target_arch = "x86_64"))]
pub fn amx_available() -> bool { false }
/// AMX capability report.
pub fn amx_report() -> String {
#[cfg(target_arch = "x86_64")]
{
let cpuid = core::arch::x86_64::__cpuid_count(7, 0);
let tile = (cpuid.edx >> 24) & 1 == 1;
let int8 = (cpuid.edx >> 25) & 1 == 1;
let bf16 = (cpuid.edx >> 22) & 1 == 1;
format!("AMX: TILE={} INT8={} BF16={} available={}", tile, int8, bf16, amx_available())
}
#[cfg(not(target_arch = "x86_64"))]
{ "AMX: not x86_64".to_string() }
}
// ═══════════════════════════════════════════════════════════════════════════
// VNNI kernel (stable intrinsics — the bridge until AMX stabilizes)
// ═══════════════════════════════════════════════════════════════════════════
/// VNNI u8×i8 dot product: 64 multiply-accumulates per instruction.
///
/// Computes: for each 32-bit lane, sum of 4 products: u8[k] × i8[k].
/// 16 lanes × 4 products = 64 MACs total.
///
/// Used by ThinkingEngine for the u8 distance table × i8 energy MatVec.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512vnni")]
pub unsafe fn vnni_dpbusd(
acc: core::arch::x86_64::__m512i,
a: core::arch::x86_64::__m512i, // 64 × u8
b: core::arch::x86_64::__m512i, // 64 × i8 (energy, quantized)
) -> core::arch::x86_64::__m512i {
core::arch::x86_64::_mm512_dpbusd_epi32(acc, a, b)
}
/// Complete VNNI MatVec: one row of distance table × energy vector.
///
/// Row: &[u8] of length N (one row of distance table).
/// Energy: &[i8] of length N (quantized energy).
/// Returns: i32 dot product (sum of all N u8×i8 products).
///
/// Processes 64 elements per VNNI instruction.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512vnni")]
pub unsafe fn vnni_dot_u8_i8(row: &[u8], energy: &[i8]) -> i32 {
use core::arch::x86_64::*;
let n = row.len().min(energy.len());
let chunks = n / 64;
let mut acc = _mm512_setzero_si512();
for c in 0..chunks {
let off = c * 64;
let a = _mm512_loadu_si512(row[off..].as_ptr() as *const __m512i);
let b = _mm512_loadu_si512(energy[off..].as_ptr() as *const __m512i);
acc = _mm512_dpbusd_epi32(acc, a, b);
}
// Horizontal sum of 16 i32 lanes
_mm512_reduce_add_epi32(acc)
}
/// VNNI MatVec for the entire distance table × energy vector.
///
/// table: &[u8] of size N×N (row-major distance table).
/// energy_i8: &[i8] of size N (quantized energy).
/// result: &mut [i32] of size N (output: accumulated dot products).
///
/// This IS the ThinkingEngine's core loop at VNNI resolution.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512vnni")]
pub unsafe fn vnni_matvec(
table: &[u8],
energy_i8: &[i8],
result: &mut [i32],
n: usize,
) {
for i in 0..n {
if energy_i8.iter().all(|&e| e == 0) { result[i] = 0; continue; }
let row = &table[i * n..(i + 1) * n];
result[i] = vnni_dot_u8_i8(row, energy_i8);
}
}
/// AVX-VNNI (ymm, 256-bit) dot product: 32 MACs per VPDPBUSD instruction.
/// For CPUs with avxvnniint8 but NOT avx512vnni (Arrow Lake, NUC 14 i9-185H, etc.)
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avxvnniint8")]
pub unsafe fn vnni2_dot_u8_i8(row: &[u8], energy: &[i8]) -> i32 {
use core::arch::x86_64::*;
let n = row.len().min(energy.len());
let chunks = n / 32;
let mut acc = _mm256_setzero_si256();
for c in 0..chunks {
let off = c * 32;
let a = _mm256_loadu_si256(row[off..].as_ptr() as *const __m256i);
let b = _mm256_loadu_si256(energy[off..].as_ptr() as *const __m256i);
// VPDPBUSD ymm: 8 lanes × 4 u8×i8 products = 32 MACs
acc = _mm256_dpbusd_epi32(acc, a, b);
}
// Horizontal sum of 8 i32 lanes
let hi128 = _mm256_extracti128_si256(acc, 1);
let lo128 = _mm256_castsi256_si128(acc);
let sum128 = _mm_add_epi32(lo128, hi128);
let sum64 = _mm_add_epi32(sum128, _mm_srli_si128(sum128, 8));
let sum32 = _mm_add_epi32(sum64, _mm_srli_si128(sum64, 4));
let mut total = _mm_extract_epi32(sum32, 0);
// Scalar remainder
let offset = chunks * 32;
for i in offset..n {
total += row[i] as i32 * energy[i] as i32;
}
total
}
/// VNNI2 MatVec for the entire distance table × energy vector (ymm path).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avxvnniint8")]
pub unsafe fn vnni2_matvec(
table: &[u8],
energy_i8: &[i8],
result: &mut [i32],
n: usize,
) {
for i in 0..n {
let row = &table[i * n..(i + 1) * n];
result[i] = vnni2_dot_u8_i8(row, energy_i8);
}
}
/// Scalar fallback for VNNI dot product (non-x86 or no VNNI).
pub fn vnni_dot_u8_i8_scalar(row: &[u8], energy: &[i8]) -> i32 {
let n = row.len().min(energy.len());
let mut acc = 0i32;
for i in 0..n {
acc += row[i] as i32 * energy[i] as i32;
}
acc
}
/// Scalar MatVec fallback.
pub fn vnni_matvec_scalar(
table: &[u8],
energy_i8: &[i8],
result: &mut [i32],
n: usize,
) {
for i in 0..n {
let row = &table[i * n..(i + 1) * n];
result[i] = vnni_dot_u8_i8_scalar(row, energy_i8);
}
}
/// Runtime-dispatched VNNI MatVec: avx512vnni → avxvnniint8 → scalar i32.
///
/// Three tiers, checked in order (first match wins):
/// avx512vnni — 64 MACs/instr (zmm, Cascade Lake+, Zen 4+)
/// avxvnniint8 — 32 MACs/instr (ymm, Arrow Lake, NUC 14 i9-185H)
/// scalar i32 — only for non-x86 or testing
///
/// IMPORTANT: avxvnniint8 (VNNI2, 256-bit) is NEVER reached when
/// avx512vnni (VNNI512) is present. This is correct:
/// - CPUs with avx512vnni always have 512-bit VPDPBUSD (faster)
/// - avxvnniint8 exists ONLY for CPUs that dropped AVX-512
/// but added 256-bit VNNI (Arrow Lake, Meteor Lake U-series)
/// - The two instructions have DIFFERENT encodings:
/// avx512vnni: EVEX-encoded VPDPBUSD zmm (512-bit)
/// avxvnniint8: VEX-encoded VPDPBUSD ymm (256-bit)
/// - Running EVEX VPDPBUSD on a VEX-only CPU = SIGILL
/// - Running VEX VPDPBUSD on an EVEX CPU = works but wastes half the width
///
/// The thinking engine's cycle_auto() dispatches:
/// VNNI detected → cycle_vnni() → this function
/// No VNNI → cycle() → F32x16 FMA (never reaches here)
pub fn matvec_dispatch(
table: &[u8],
energy_i8: &[i8],
result: &mut [i32],
n: usize,
) {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512vnni") {
unsafe { vnni_matvec(table, energy_i8, result, n); }
return;
}
if is_x86_feature_detected!("avxvnniint8") {
unsafe { vnni2_matvec(table, energy_i8, result, n); }
return;
}
}
// Non-x86 or no VNNI: i32 scalar accumulate.
// On x86, the thinking engine uses F32x16 FMA instead of reaching here.
vnni_matvec_scalar(table, energy_i8, result, n);
}
// ═══════════════════════════════════════════════════════════════════════════
// Quantize energy f64 → i8 for VNNI path
// ═══════════════════════════════════════════════════════════════════════════
/// Quantize f64 energy vector to i8 for VNNI MatVec.
/// Maps [0.0, max_energy] → [0, 127].
pub fn quantize_energy_i8(energy: &[f64], output: &mut [i8]) {
let n = energy.len().min(output.len());
let max_e = energy.iter().cloned().fold(0.0f64, f64::max);
if max_e < 1e-15 {
for o in output[..n].iter_mut() { *o = 0; }
return;
}
let scale = 127.0 / max_e;
for i in 0..n {
output[i] = (energy[i] * scale).round().clamp(0.0, 127.0) as i8;
}
}
/// Dequantize i32 result back to f64.
pub fn dequantize_result_f64(result: &[i32], output: &mut [f64], scale: f64) {
for (i, &r) in result.iter().enumerate() {
if i < output.len() {
output[i] = r as f64 * scale;
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_amx_detection() {
let available = amx_available();
let report = amx_report();
eprintln!("{}", report);
eprintln!("AMX available: {}", available);
}
#[test]
fn test_vnni_dot_scalar() {
let row = vec![128u8; 64]; // similarity = 0.5
let energy = vec![10i8; 64]; // energy = 10
let dot = vnni_dot_u8_i8_scalar(&row, &energy);
assert_eq!(dot, 128 * 10 * 64);
eprintln!("Scalar dot: {}", dot);
}
#[test]
fn test_vnni_matvec_scalar() {
let n = 64;
let mut table = vec![128u8; n * n];
for i in 0..n { table[i * n + i] = 255; } // diagonal = max
let energy = vec![10i8; n];
let mut result = vec![0i32; n];
vnni_matvec_scalar(&table, &energy, &mut result, n);
// Each row: 63 × 128 × 10 + 1 × 255 × 10 = 80640 + 2550 = 83190
assert!(result[0] > 0);
eprintln!("MatVec result[0]: {}", result[0]);
}
#[test]
fn test_vnni_dispatch() {
let n = 64;
let mut table = vec![128u8; n * n];
for i in 0..n { table[i * n + i] = 255; }
let energy = vec![10i8; n];
let mut result = vec![0i32; n];
matvec_dispatch(&table, &energy, &mut result, n);
assert!(result[0] > 0);
#[cfg(target_arch = "x86_64")]
eprintln!("VNNI available: {}", is_x86_feature_detected!("avx512vnni"));
eprintln!("Dispatch result[0]: {}", result[0]);
}
#[test]
fn test_quantize_energy() {
let energy = [0.0, 0.5, 1.0, 0.25, 0.75];
let mut quant = [0i8; 5];
quantize_energy_i8(&energy, &mut quant);
assert_eq!(quant[0], 0);
assert_eq!(quant[2], 127); // max maps to 127
assert!(quant[1] > 50 && quant[1] < 70); // ~63
eprintln!("Quantized: {:?}", quant);
}
#[test]
fn test_vnni_matches_scalar() {
let n = 128;
let table: Vec<u8> = (0..n*n).map(|i| (i % 256) as u8).collect();
let energy: Vec<i8> = (0..n).map(|i| (i % 100) as i8).collect();
let mut scalar_result = vec![0i32; n];
vnni_matvec_scalar(&table, &energy, &mut scalar_result, n);
let mut dispatch_result = vec![0i32; n];
matvec_dispatch(&table, &energy, &mut dispatch_result, n);
for i in 0..n {
assert_eq!(scalar_result[i], dispatch_result[i],
"mismatch at row {}: scalar={} dispatch={}", i, scalar_result[i], dispatch_result[i]);
}
}
}