Skip to content

Commit 13b1b7e

Browse files
committed
feat: burn-adaworld — full burn-ndarray backend copied into ndarray workspace
Copied upstream burn-ndarray (tracel-ai/burn main) into crates/burn-adaworld/. 30 tests passing. Compiles clean with upstream burn git deps. Source: ~11,700 lines (8,906 core + 2,782 SIMD via macerator). Edition: 2024 (Rust 1.85+, we run 1.93/1.94). Dependencies: burn-backend, burn-std, burn-ir, burn-autodiff from git main. This is the baseline to augment with: 1. Replace macerator SIMD with crate::simd F32x16 + LazyLock dispatch 2. Add bgz-tensor AttentionTable compiled attention path 3. Add SimilarityTable as BF16-equivalent scoring 4. Head-to-head benchmark vs upstream burn-ndarray Knowledge transfer: burn-ndarray's Backend trait implementation is the reference for implementing AdaWorld-specific optimizations. The matmul path (ops/matmul.rs) delegates to ndarray::linalg::general_mat_mul which hits BLAS. We can intercept this with AttentionTable for compiled attention layers. https://claude.ai/code/session_01Y69Vnw751w75iVSBRws7o7
1 parent eef500d commit 13b1b7e

43 files changed

Lines changed: 12075 additions & 433 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

Cargo.lock

Lines changed: 367 additions & 210 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/burn-adaworld/Cargo.toml

Lines changed: 61 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,75 @@
11
[package]
22
name = "burn-adaworld"
33
version = "0.1.0"
4-
edition = "2021"
4+
edition = "2024"
55
license = "MIT OR Apache-2.0"
66
publish = false
77
description = """
8-
Burn backend powered by adaworldapi/ndarray with:
9-
- crate::simd F32x16 via LazyLock dispatch (AVX-512 → AVX2 → scalar)
10-
- bgz-tensor AttentionTable for O(1) compiled attention (optional)
11-
- CAM-PQ product quantization for 170× compression (optional)
12-
- SimilarityTable as BF16-precision cosine replacement (256 levels, O(1))
13-
14-
The consumer sees burn's Tensor<B, D> API. Behind it:
15-
matmul() → checks for compiled AttentionTable → falls through to BLAS.
16-
All SIMD via crate::simd only. Consumer never sees hardware.
8+
Burn ndarray backend forked into adaworldapi/ndarray for SIMD augmentation.
9+
Source: upstream burn-ndarray (tracel-ai/burn, v0.21.0-pre.2).
10+
Goal: replace macerator SIMD with crate::simd F32x16 + LazyLock dispatch,
11+
add bgz-tensor AttentionTable compiled attention path.
1712
"""
1813

14+
[features]
15+
default = ["std", "simd", "multi-threads"]
16+
multi-threads = ["rayon", "ndarray/rayon", "matrixmultiply/threading"]
17+
simd = ["macerator", "bytemuck", "seq-macro", "itertools"]
18+
std = [
19+
"burn-autodiff",
20+
"burn-std/std",
21+
"burn-backend/std",
22+
"burn-ir/std",
23+
"ndarray/std",
24+
"matrixmultiply/std",
25+
"rand/std",
26+
"rand/std_rng",
27+
"num-traits/std",
28+
"macerator/std",
29+
]
30+
blas-openblas = ["blas-src/openblas", "ndarray/blas", "openblas-src"]
31+
blas-openblas-system = ["blas-src/openblas", "ndarray/blas", "openblas-src/system"]
32+
blas-netlib = ["blas-src/netlib", "ndarray/blas"]
33+
export_tests = []
34+
1935
[dependencies]
20-
# Upstream burn — Backend trait + tensor API
21-
burn-backend = "0.21.0-pre.2"
22-
burn-tensor = "0.21.0-pre.2"
36+
# Upstream burn crates (from git main — matches source code we copied)
37+
burn-autodiff = { git = "https://github.com/tracel-ai/burn.git", default-features = false, optional = true }
38+
burn-std = { git = "https://github.com/tracel-ai/burn.git", default-features = false }
39+
burn-ir = { git = "https://github.com/tracel-ai/burn.git", default-features = false }
40+
burn-backend = { git = "https://github.com/tracel-ai/burn.git", default-features = false }
41+
42+
# ndarray — uses our workspace root (adaworldapi/ndarray with SIMD + HPC)
43+
ndarray = { path = "../..", default-features = false }
44+
45+
# Matrix multiply
46+
matrixmultiply = { version = "0.3", default-features = false }
47+
48+
# Element traits
49+
num-traits = { version = "0.2", default-features = false }
50+
libm = "0.2"
51+
atomic_float = "1"
52+
const-random = "0.1"
53+
paste = "1"
2354

24-
# Our ndarray with SIMD + HPC extensions
25-
ndarray = { path = "../..", features = ["std"] }
55+
# Random
56+
rand = { version = "0.10", default-features = false, features = ["std_rng"] }
2657

27-
# Standard deps
58+
# Serialization
2859
serde = { version = "1", features = ["derive"] }
29-
half = { version = "2", features = ["num-traits"] }
30-
num-traits = "0.2"
31-
rand = "0.8"
3260

33-
[dev-dependencies]
34-
burn-tensor-testgen = "0.21.0-pre.2"
61+
# SIMD (macerator — upstream burn's choice, will augment with crate::simd)
62+
macerator = { version = "0.3", default-features = false, optional = true }
63+
bytemuck = { version = "1", optional = true }
64+
seq-macro = { version = "0.3", optional = true }
65+
itertools = { version = "0.14", optional = true }
3566

36-
[features]
37-
default = ["std"]
38-
std = []
39-
# Enable bgz-tensor AttentionTable path for compiled attention
40-
attention-table = []
41-
# Enable multi-threaded execution via rayon
42-
multi-threads = ["ndarray/rayon"]
67+
# Parallel
68+
rayon = { version = "1", optional = true }
69+
70+
# BLAS (optional)
71+
blas-src = { version = "0.10", default-features = false, optional = true }
72+
openblas-src = { version = "0.10", optional = true }
73+
74+
[dev-dependencies]
75+
bytes = "1"
Lines changed: 221 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,222 @@
1-
//! AdaWorld backend: implements burn's Backend trait.
2-
//!
3-
//! Delegates all tensor operations to ndarray + crate::simd.
4-
//! This is the entry point — every burn model compiled with `Backend = AdaWorld`
5-
//! runs on our SIMD dispatch with optional AttentionTable compiled attention.
6-
//!
7-
//! # Implementation Status
8-
//!
9-
//! The Backend trait requires ~200+ methods across 7 op traits.
10-
//! Implementation strategy: core ops first (what Whisper/Llama need),
11-
//! then expand coverage guided by burn-backend-tests.
12-
//!
13-
//! Required traits:
14-
//! FloatTensorOps — 84 required methods (+ ~36 with defaults)
15-
//! IntTensorOps — ~50 required methods
16-
//! BoolTensorOps — ~30 required methods
17-
//! ModuleOps — conv, pool, embedding, etc.
18-
//! ActivationOps — relu, sigmoid, gelu (most have defaults)
19-
//! QTensorOps — quantized tensor ops
20-
//! TransactionOps — batch execution
21-
//!
22-
//! # Architecture
23-
//!
24-
//! ```text
25-
//! burn::Tensor<AdaWorld, D>
26-
//! ↓ (burn dispatches via Backend trait)
27-
//! AdaWorld::float_matmul(lhs, rhs)
28-
//! ↓ (check for compiled attention table)
29-
//! ├── AttentionTable[q_idx][k_idx] → O(1) (if compiled)
30-
//! └── ndarray general_mat_mul() → O(d) (fallback to BLAS)
31-
//! ↓ (ndarray delegates to BLAS or matrixmultiply)
32-
//! crate::simd::F32x16 → AVX-512 / AVX2 via LazyLock dispatch
33-
//! ```
34-
35-
use crate::tensor::AdaTensor;
36-
37-
/// The AdaWorld backend.
1+
use crate::rand::NdArrayRng;
2+
use crate::{NdArrayQTensor, NdArrayTensor};
3+
use crate::{
4+
SharedArray,
5+
element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
6+
};
7+
use alloc::string::String;
8+
use burn_backend::quantization::{QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue};
9+
use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
10+
use burn_backend::{Backend, DType, DeviceId, DeviceOps};
11+
use burn_ir::{BackendIr, HandleKind, TensorHandle};
12+
use burn_std::BoolStore;
13+
use burn_std::stub::Mutex;
14+
use core::marker::PhantomData;
15+
use rand::SeedableRng;
16+
17+
pub(crate) static SEED: Mutex<Option<NdArrayRng>> = Mutex::new(None);
18+
19+
/// The device type for the ndarray backend.
20+
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
21+
pub enum NdArrayDevice {
22+
/// The CPU device.
23+
#[default]
24+
Cpu,
25+
}
26+
27+
impl DeviceOps for NdArrayDevice {}
28+
29+
impl burn_backend::Device for NdArrayDevice {
30+
fn from_id(_device_id: DeviceId) -> Self {
31+
Self::Cpu
32+
}
33+
34+
fn to_id(&self) -> DeviceId {
35+
DeviceId {
36+
type_id: 0,
37+
index_id: 0,
38+
}
39+
}
40+
}
41+
42+
/// Tensor backend that uses the [ndarray](ndarray) crate for executing tensor operations.
3843
///
39-
/// CPU-only. Uses adaworldapi/ndarray with crate::simd SIMD dispatch.
40-
/// Feature `attention-table` enables bgz-tensor compiled attention path.
41-
#[derive(Clone, Default, Debug)]
42-
pub struct AdaWorld;
43-
44-
/// CPU device (unit type — there's only one CPU).
45-
#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
46-
pub struct CpuDevice;
47-
48-
// NOTE: Full Backend trait implementation requires ~200+ methods across 7 traits.
49-
// This is tracked as a multi-session effort:
50-
//
51-
// Session 1 (current): Crate skeleton + architecture + tensor primitive
52-
// Session 2: FloatTensorOps core (from_data, matmul, add, mul, exp, reshape, transpose)
53-
// Session 3: IntTensorOps + BoolTensorOps
54-
// Session 4: ModuleOps (conv, embedding) + ActivationOps
55-
// Session 5: QTensorOps + TransactionOps + burn-backend-tests
56-
//
57-
// The implementation follows burn-ndarray's pattern but uses:
58-
// - crate::simd::F32x16 for element-wise ops (not macerator)
59-
// - LazyLock<SimdDispatch> for runtime tier selection (not compile-time features)
60-
// - Optional AttentionTable for compiled attention (unique to this backend)
44+
/// This backend is compatible with CPUs and can be compiled for almost any platform, including
45+
/// `wasm`, `arm`, and `x86`.
46+
#[derive(Clone, Copy, Default, Debug)]
47+
pub struct NdArray<E = f32, I = i64, Q = i8>
48+
where
49+
NdArrayTensor: From<SharedArray<E>>,
50+
NdArrayTensor: From<SharedArray<I>>,
51+
{
52+
_e: PhantomData<E>,
53+
_i: PhantomData<I>,
54+
_q: PhantomData<Q>,
55+
}
56+
57+
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> Backend for NdArray<E, I, Q>
58+
where
59+
NdArrayTensor: From<SharedArray<E>>,
60+
NdArrayTensor: From<SharedArray<I>>,
61+
{
62+
type Device = NdArrayDevice;
63+
64+
type FloatTensorPrimitive = NdArrayTensor;
65+
type FloatElem = E;
66+
67+
type IntTensorPrimitive = NdArrayTensor;
68+
type IntElem = I;
69+
70+
type BoolTensorPrimitive = NdArrayTensor;
71+
type BoolElem = bool;
72+
73+
type QuantizedTensorPrimitive = NdArrayQTensor;
74+
75+
fn ad_enabled(_device: &Self::Device) -> bool {
76+
false
77+
}
78+
79+
fn name(_device: &Self::Device) -> String {
80+
String::from("ndarray")
81+
}
82+
83+
fn seed(_device: &Self::Device, seed: u64) {
84+
let rng = NdArrayRng::seed_from_u64(seed);
85+
let mut seed = SEED.lock().unwrap();
86+
*seed = Some(rng);
87+
}
88+
89+
fn dtype_usage(_device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
90+
match dtype {
91+
DType::F64
92+
| DType::F32
93+
| DType::Flex32
94+
| DType::I64
95+
| DType::I32
96+
| DType::I16
97+
| DType::I8
98+
| DType::U64
99+
| DType::U32
100+
| DType::U16
101+
| DType::U8
102+
| DType::Bool(BoolStore::Native) => burn_backend::DTypeUsage::general(),
103+
DType::F16 | DType::BF16 | DType::Bool(_) => burn_backend::DTypeUsageSet::empty(),
104+
DType::QFloat(scheme) => {
105+
match scheme {
106+
QuantScheme {
107+
level: QuantLevel::Tensor | QuantLevel::Block(_),
108+
mode: QuantMode::Symmetric,
109+
#[cfg(not(feature = "export_tests"))]
110+
value: QuantValue::Q8F | QuantValue::Q8S,
111+
// For tests, "native" sub-byte quant serves as a reference for value equality.
112+
// Values are stored as i8 regardless.
113+
#[cfg(feature = "export_tests")]
114+
value:
115+
QuantValue::Q8F
116+
| QuantValue::Q8S
117+
| QuantValue::Q4F
118+
| QuantValue::Q4S
119+
| QuantValue::Q2F
120+
| QuantValue::Q2S,
121+
store: QuantStore::Native,
122+
..
123+
} => burn_backend::DTypeUsage::general(),
124+
_scheme => burn_backend::DTypeUsageSet::empty(),
125+
}
126+
}
127+
}
128+
}
129+
130+
fn device_count(_: u16) -> usize {
131+
1
132+
}
133+
}
134+
135+
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BackendIr for NdArray<E, I, Q>
136+
where
137+
NdArrayTensor: From<SharedArray<E>>,
138+
NdArrayTensor: From<SharedArray<I>>,
139+
{
140+
type Handle = HandleKind<Self>;
141+
142+
fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
143+
match handle.handle {
144+
HandleKind::Float(handle) => handle,
145+
_ => panic!("Expected float handle, got {}", handle.handle.name()),
146+
}
147+
}
148+
149+
fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
150+
match handle.handle {
151+
HandleKind::Int(handle) => handle,
152+
_ => panic!("Expected int handle, got {}", handle.handle.name()),
153+
}
154+
}
155+
156+
fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
157+
match handle.handle {
158+
HandleKind::Bool(handle) => handle,
159+
_ => panic!("Expected bool handle, got {}", handle.handle.name()),
160+
}
161+
}
162+
163+
fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
164+
match handle.handle {
165+
HandleKind::Quantized(handle) => handle,
166+
_ => panic!("Expected quantized handle, got {}", handle.handle.name()),
167+
}
168+
}
169+
170+
fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
171+
HandleKind::Float(tensor)
172+
}
173+
174+
fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
175+
HandleKind::Int(tensor)
176+
}
177+
178+
fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
179+
HandleKind::Bool(tensor)
180+
}
181+
182+
fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
183+
HandleKind::Quantized(tensor)
184+
}
185+
}
186+
187+
#[cfg(test)]
188+
mod tests {
189+
use super::*;
190+
use burn_backend::QTensorPrimitive;
191+
192+
#[test]
193+
fn should_support_dtypes() {
194+
type B = NdArray<f32>;
195+
let device = Default::default();
196+
197+
assert!(B::supports_dtype(&device, DType::F64));
198+
assert!(B::supports_dtype(&device, DType::F32));
199+
assert!(B::supports_dtype(&device, DType::Flex32));
200+
assert!(B::supports_dtype(&device, DType::I64));
201+
assert!(B::supports_dtype(&device, DType::I32));
202+
assert!(B::supports_dtype(&device, DType::I16));
203+
assert!(B::supports_dtype(&device, DType::I8));
204+
assert!(B::supports_dtype(&device, DType::U64));
205+
assert!(B::supports_dtype(&device, DType::U32));
206+
assert!(B::supports_dtype(&device, DType::U16));
207+
assert!(B::supports_dtype(&device, DType::U8));
208+
assert!(B::supports_dtype(&device, DType::Bool(BoolStore::Native)));
209+
assert!(B::supports_dtype(
210+
&device,
211+
DType::QFloat(NdArrayQTensor::default_scheme())
212+
));
213+
214+
assert!(!B::supports_dtype(&device, DType::F16));
215+
assert!(!B::supports_dtype(&device, DType::BF16));
216+
// QuantStore::U32 not supported
217+
assert!(!B::supports_dtype(
218+
&device,
219+
DType::QFloat(QuantScheme::default())
220+
));
221+
}
222+
}

0 commit comments

Comments
 (0)