Skip to content

Commit 09815cb

Browse files
authored
Merge pull request #82 from AdaWorldAPI/claude/setup-embedding-pipeline-Fa65C
fix: make simd module public (was pub(crate)) Consumers (bgz-tensor, thinking-engine) need ndarray::simd::bf16_to_f32_batch and other polyfill exports. Was incorrectly restricted to pub(crate). https://claude.ai/code/session_01ChLvBfpJS8dQhHxRD4pYNp
2 parents 95a19ba + 6e6bc85 commit 09815cb

3 files changed

Lines changed: 274 additions & 7 deletions

File tree

.claude/AMX_GOTCHAS.md

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
# AMX Gotchas — Resolved on Stable Rust 1.94
2+
3+
> Updated: 2026-04-03
4+
> CPU: Sapphire Rapids (AMX-TILE + AMX-INT8 + AMX-BF16 confirmed)
5+
> Kernel: 6.18.5 (XCR0 bits 17+18 enabled)
6+
7+
---
8+
9+
## Status
10+
11+
AMX works on **stable Rust 1.94** via `asm!()`. No nightly needed.
12+
13+
```
14+
LDTILECFG: ✓ (load tile configuration)
15+
TILEZERO: ✓ (zero a tile register)
16+
TILERELEASE: ✓ (release tiles)
17+
TDPBUSD: ✓ (u8×i8 tile dot product, 256 MACs/instruction)
18+
```
19+
20+
---
21+
22+
## Gotcha 1: Rust intrinsics are NIGHTLY ONLY
23+
24+
```rust
25+
// This DOES NOT compile on stable:
26+
use std::arch::x86_64::_tile_loadconfig; // error: unstable feature x86_amx_intrinsics
27+
```
28+
29+
**Fix**: Use `asm!()` (stable since Rust 1.59):
30+
```rust
31+
asm!("ldtilecfg [{}]", in(reg) config.data.as_ptr(), options(nostack));
32+
```
33+
34+
Tracking issue: https://github.com/rust-lang/rust/issues/126622
35+
36+
---
37+
38+
## Gotcha 2: Tile config MUST be 64-byte aligned
39+
40+
```rust
41+
// This SEGFAULTS:
42+
let config = [0u8; 64]; // stack-allocated, no alignment guarantee
43+
44+
// This WORKS:
45+
#[repr(C, align(64))]
46+
struct TileConfig { data: [u8; 64] }
47+
let config = TileConfig { data: [0u8; 64] };
48+
```
49+
50+
LDTILECFG reads 64 bytes from the pointer. If not 64-byte aligned,
51+
the CPU raises #GP (general protection fault) → SIGSEGV.
52+
53+
---
54+
55+
## Gotcha 3: rbx is LLVM-reserved
56+
57+
```rust
58+
// This DOES NOT compile:
59+
asm!("cpuid", out("ebx") ebx, ...); // error: rbx is used internally by LLVM
60+
61+
// This WORKS:
62+
let result = core::arch::x86_64::__cpuid_count(7, 0); // stable, handles rbx internally
63+
```
64+
65+
For CPUID leaf 7 (AMX detection): use `__cpuid_count()`, not inline asm.
66+
67+
---
68+
69+
## Gotcha 4: OS must enable AMX via XSETBV
70+
71+
AMX tiles are large (8 KB of state). The OS must opt in via XCR0 bits 17+18.
72+
Linux 5.19+ enables AMX by default. Older kernels: SIGILL on tile instructions.
73+
74+
**Detection (stable)**:
75+
```rust
76+
let xcr0 = core::arch::x86_64::__cpuid_count(0xD, 0);
77+
let tilecfg = (xcr0.eax >> 17) & 1; // bit 17 = XTILECFG
78+
let tiledata = (xcr0.eax >> 18) & 1; // bit 18 = XTILEDATA
79+
// Both must be 1
80+
```
81+
82+
---
83+
84+
## Gotcha 5: TILEZERO/TILERELEASE need manual byte encoding
85+
86+
The Rust assembler on some toolchains doesn't know AMX mnemonics.
87+
Use raw instruction bytes:
88+
89+
```rust
90+
// TILEZERO tmm0
91+
asm!(".byte 0xc4, 0xe2, 0x7b, 0x49, 0xc0", options(nostack, nomem));
92+
93+
// TILEZERO tmm1
94+
asm!(".byte 0xc4, 0xe2, 0x7b, 0x49, 0xc8", options(nostack, nomem));
95+
96+
// TILEZERO tmm2
97+
asm!(".byte 0xc4, 0xe2, 0x7b, 0x49, 0xd0", options(nostack, nomem));
98+
99+
// TILEZERO tmm3
100+
asm!(".byte 0xc4, 0xe2, 0x7b, 0x49, 0xd8", options(nostack, nomem));
101+
102+
// TILERELEASE
103+
asm!(".byte 0xc4, 0xe2, 0x78, 0x49, 0xc0", options(nostack, nomem));
104+
105+
// TDPBUSD tmm0, tmm1, tmm2 (C += A × B)
106+
asm!(".byte 0xc4, 0xe2, 0x73, 0x5e, 0xc1", options(nostack, nomem));
107+
```
108+
109+
Note: LDTILECFG works as a mnemonic:
110+
```rust
111+
asm!("ldtilecfg [{}]", in(reg) ptr, options(nostack));
112+
```
113+
114+
---
115+
116+
## Gotcha 6: Tile config field layout is not obvious
117+
118+
The 64-byte tile config structure:
119+
```
120+
Byte 0: palette (must be 1)
121+
Bytes 1-15: reserved (zero)
122+
Bytes 16-23: rows per tile (tile 0 at byte 16, tile 1 at byte 17, ...)
123+
Bytes 24-47: reserved (zero)
124+
Bytes 48-63: colbytes per tile (tile 0 at [48..49] as u16 LE, tile 1 at [50..51], ...)
125+
```
126+
127+
For TDPBUSD (u8×i8 → i32):
128+
- Tile 0 (C result): rows=16, colbytes=64 (16 × i32 = 64 bytes per row)
129+
- Tile 1 (A input): rows=16, colbytes=64 (16 × 64 u8)
130+
- Tile 2 (B input): rows=16, colbytes=64 (transposed for column access)
131+
132+
**IMPORTANT**: colbytes is a u16 at byte offset 48+2*tile_id (little-endian).
133+
For values ≤ 64, only the low byte matters.
134+
135+
---
136+
137+
## Gotcha 7: TILEZERO with wrong config = SEGFAULT
138+
139+
If you configure tile 0 as 16 rows × 64 colbytes but then TILEZERO tmm0,
140+
it works. But if the config doesn't match what the hardware expects (e.g.,
141+
palette=0 or all zeros), TILEZERO will SEGFAULT.
142+
143+
**Fix**: Always start with the minimal working config:
144+
```rust
145+
cfg.data[0] = 1; // palette 1 (MUST be 1, not 0)
146+
cfg.data[16] = 1; // at least 1 row
147+
cfg.data[48] = 4; // at least 4 colbytes (1 × i32)
148+
```
149+
150+
Then expand to full 16×64 after verifying the minimal config works.
151+
152+
---
153+
154+
## Gotcha 8: is_x86_feature_detected!("amx-tile") is NIGHTLY ONLY
155+
156+
```rust
157+
// DOES NOT compile on stable:
158+
is_x86_feature_detected!("amx-tile") // error: unstable x86_amx_intrinsics
159+
160+
// WORKS on stable:
161+
fn amx_available() -> bool {
162+
let cpuid = core::arch::x86_64::__cpuid_count(7, 0);
163+
let amx_tile = (cpuid.edx >> 24) & 1;
164+
let amx_int8 = (cpuid.edx >> 25) & 1;
165+
amx_tile == 1 && amx_int8 == 1
166+
}
167+
```
168+
169+
Use `__cpuid_count` (stable) for detection, not `is_x86_feature_detected!`.
170+
171+
---
172+
173+
## Hardware Tiers (this session)
174+
175+
```
176+
Tier Feature MACs/instr Detection (stable) CPU
177+
──── ─────── ────────── ────────────────── ───
178+
3 AMX 256 __cpuid_count(7,0).edx bit 24 Sapphire Rapids+
179+
2 avx512vnni 64 is_x86_feature_detected! Cascade Lake+, Zen 4+
180+
1 avxvnniint8 32 is_x86_feature_detected! Arrow Lake (NUC 14)
181+
0 scalar 1 always any
182+
```
183+
184+
Also detectable but not yet kernelized:
185+
- `avxvnniint16`: i16×i16 dot product (VPDPWSSD)
186+
- `amx-bf16`: TDPBF16PS (BF16 tile matmul, for calibration)
187+
188+
---
189+
190+
## Files
191+
192+
```
193+
ndarray/src/simd_amx.rs — AMX detection + VNNI/VNNI2 kernels + quantize
194+
ndarray/src/hpc/amx_matmul.rs — AMX tile ops via inline asm (TDPBUSD)
195+
ndarray/crates/burn/src/ops/matmul.rs — 4-tier dispatch in distance table builder
196+
```
197+
198+
---
199+
200+
## What AMX Enables
201+
202+
```
203+
Distance table build (4096² = 16M dot products):
204+
AMX: ~20 min (all models combined)
205+
avx512vnni: ~1:20h
206+
avxvnniint8: ~2:40h (NUC 14)
207+
scalar: ~24-48h
208+
209+
ThinkingEngine MatVec (per cycle):
210+
AMX: ~44 μs (L1 table fits in 4 tile registers)
211+
avx512vnni: ~175 μs
212+
avxvnniint8: ~350 μs
213+
scalar: ~5 ms
214+
```

crates/burn/src/ops/matmul.rs

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,16 +125,69 @@ pub fn try_vnni_matmul_u8(
125125
false
126126
}
127127

128-
/// Build a k×k distance table from k centroids using VNNI if available.
128+
/// Build a k×k COSINE SIMILARITY table from f32 centroids.
129+
///
130+
/// Takes raw f32 centroids, normalizes to unit vectors, quantizes,
131+
/// runs tiered VNNI/AMX dot product, maps to u8 [0, 255].
132+
///
133+
/// This IS the ThinkingEngine's brain. cosine[-1,1] → u8[0,255].
134+
/// 128 = orthogonal. 255 = identical. 0 = opposite.
135+
///
136+
/// centroids_f32: [k × dim] raw f32 centroids (row-major)
137+
/// Returns: [k × k] u8 cosine similarity table
138+
#[cfg(feature = "std")]
139+
pub fn build_cosine_table(centroids_f32: &[f32], k: usize, dim: usize) -> Vec<u8> {
140+
assert_eq!(centroids_f32.len(), k * dim);
141+
142+
// Step 1: Normalize each centroid to unit vector
143+
let mut normed = vec![0.0f32; k * dim];
144+
for i in 0..k {
145+
let row = &centroids_f32[i * dim..(i + 1) * dim];
146+
let norm: f32 = row.iter().map(|v| v * v).sum::<f32>().sqrt();
147+
let inv_norm = if norm > 1e-10 { 1.0 / norm } else { 0.0 };
148+
for d in 0..dim {
149+
normed[i * dim + d] = row[d] * inv_norm;
150+
}
151+
}
152+
153+
// Step 2: Quantize normalized [-1, 1] → u8 [0, 255]
154+
// After normalization, values are in [-1, 1].
155+
// Map: u8 = round((value + 1.0) * 127.5)
156+
let centroids_u8: Vec<u8> = normed.iter()
157+
.map(|&v| ((v + 1.0) * 127.5).round().clamp(0.0, 255.0) as u8)
158+
.collect();
159+
160+
// Step 3: Compute dot products using tiered VNNI dispatch
161+
let raw_dots = build_distance_table_vnni(&centroids_u8, k, dim);
162+
163+
// Step 4: Map i32 dot products → u8 cosine similarity [0, 255]
164+
// The dot product of two unit vectors quantized to u8 [0,255]:
165+
// max dot (identical) = sum of (u8_i)² over dim
166+
// min dot (opposite) = much lower
167+
// Find actual min/max to scale properly
168+
let min_dot = raw_dots.iter().copied().min().unwrap_or(0) as f64;
169+
let max_dot = raw_dots.iter().copied().max().unwrap_or(1) as f64;
170+
let range = (max_dot - min_dot).max(1.0);
171+
172+
let mut table = vec![128u8; k * k]; // 128 = default orthogonal
173+
for i in 0..k {
174+
for j in 0..k {
175+
let raw = raw_dots[i * k + j] as f64;
176+
let normalized = (raw - min_dot) / range; // [0, 1]
177+
table[i * k + j] = (normalized * 255.0).round().clamp(0.0, 255.0) as u8;
178+
}
179+
}
180+
181+
table
182+
}
183+
184+
/// Build a k×k RAW DOT PRODUCT table from u8 centroids using VNNI if available.
129185
///
130186
/// centroids_u8: [k × dim] quantized codebook centroids (u8, row-major)
131187
/// Returns: [k × k] i32 dot product matrix (symmetric)
132188
///
133-
/// Uses VNNI dot product (64 MACs/instruction) for each centroid pair.
134-
/// Symmetric: only computes upper triangle, mirrors to lower.
135-
///
136-
/// This IS the ThinkingEngine's brain construction step.
137-
/// 4096² = 16M dot products. With VNNI: ~1:20h for large dim.
189+
/// For cosine: use build_cosine_table() which normalizes first.
190+
/// This function is for raw dot products when centroids are already u8.
138191
#[cfg(feature = "std")]
139192
pub fn build_distance_table_vnni(centroids_u8: &[u8], k: usize, dim: usize) -> Vec<i32> {
140193
assert_eq!(centroids_u8.len(), k * dim);

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ mod dimension;
232232
/// Portable SIMD types — `crate::simd::f32x16` today, `std::simd::f32x16` tomorrow.
233233
#[cfg(feature = "std")]
234234
#[allow(missing_docs)]
235-
pub(crate) mod simd;
235+
pub mod simd;
236236
#[cfg(all(feature = "std", target_arch = "x86_64"))]
237237
#[allow(missing_docs, dead_code)]
238238
pub(crate) mod simd_avx512;

0 commit comments

Comments
 (0)