|
1 | | -//! AdaWorld backend: implements burn's Backend trait. |
2 | | -//! |
3 | | -//! Delegates all tensor operations to ndarray + crate::simd. |
4 | | -//! This is the entry point — every burn model compiled with `Backend = AdaWorld` |
5 | | -//! runs on our SIMD dispatch with optional AttentionTable compiled attention. |
6 | | -//! |
7 | | -//! # Implementation Status |
8 | | -//! |
9 | | -//! The Backend trait requires ~200+ methods across 7 op traits. |
10 | | -//! Implementation strategy: core ops first (what Whisper/Llama need), |
11 | | -//! then expand coverage guided by burn-backend-tests. |
12 | | -//! |
13 | | -//! Required traits: |
14 | | -//! FloatTensorOps — 84 required methods (+ ~36 with defaults) |
15 | | -//! IntTensorOps — ~50 required methods |
16 | | -//! BoolTensorOps — ~30 required methods |
17 | | -//! ModuleOps — conv, pool, embedding, etc. |
18 | | -//! ActivationOps — relu, sigmoid, gelu (most have defaults) |
19 | | -//! QTensorOps — quantized tensor ops |
20 | | -//! TransactionOps — batch execution |
21 | | -//! |
22 | | -//! # Architecture |
23 | | -//! |
24 | | -//! ```text |
25 | | -//! burn::Tensor<AdaWorld, D> |
26 | | -//! ↓ (burn dispatches via Backend trait) |
27 | | -//! AdaWorld::float_matmul(lhs, rhs) |
28 | | -//! ↓ (check for compiled attention table) |
29 | | -//! ├── AttentionTable[q_idx][k_idx] → O(1) (if compiled) |
30 | | -//! └── ndarray general_mat_mul() → O(d) (fallback to BLAS) |
31 | | -//! ↓ (ndarray delegates to BLAS or matrixmultiply) |
32 | | -//! crate::simd::F32x16 → AVX-512 / AVX2 via LazyLock dispatch |
33 | | -//! ``` |
34 | | -
|
35 | | -use crate::tensor::AdaTensor; |
36 | | - |
37 | | -/// The AdaWorld backend. |
| 1 | +use crate::rand::NdArrayRng; |
| 2 | +use crate::{NdArrayQTensor, NdArrayTensor}; |
| 3 | +use crate::{ |
| 4 | + SharedArray, |
| 5 | + element::{FloatNdArrayElement, IntNdArrayElement, QuantElement}, |
| 6 | +}; |
| 7 | +use alloc::string::String; |
| 8 | +use burn_backend::quantization::{QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue}; |
| 9 | +use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}; |
| 10 | +use burn_backend::{Backend, DType, DeviceId, DeviceOps}; |
| 11 | +use burn_ir::{BackendIr, HandleKind, TensorHandle}; |
| 12 | +use burn_std::BoolStore; |
| 13 | +use burn_std::stub::Mutex; |
| 14 | +use core::marker::PhantomData; |
| 15 | +use rand::SeedableRng; |
| 16 | + |
| 17 | +pub(crate) static SEED: Mutex<Option<NdArrayRng>> = Mutex::new(None); |
| 18 | + |
| 19 | +/// The device type for the ndarray backend. |
| 20 | +#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)] |
| 21 | +pub enum NdArrayDevice { |
| 22 | + /// The CPU device. |
| 23 | + #[default] |
| 24 | + Cpu, |
| 25 | +} |
| 26 | + |
| 27 | +impl DeviceOps for NdArrayDevice {} |
| 28 | + |
| 29 | +impl burn_backend::Device for NdArrayDevice { |
| 30 | + fn from_id(_device_id: DeviceId) -> Self { |
| 31 | + Self::Cpu |
| 32 | + } |
| 33 | + |
| 34 | + fn to_id(&self) -> DeviceId { |
| 35 | + DeviceId { |
| 36 | + type_id: 0, |
| 37 | + index_id: 0, |
| 38 | + } |
| 39 | + } |
| 40 | +} |
| 41 | + |
| 42 | +/// Tensor backend that uses the [ndarray](ndarray) crate for executing tensor operations. |
38 | 43 | /// |
39 | | -/// CPU-only. Uses adaworldapi/ndarray with crate::simd SIMD dispatch. |
40 | | -/// Feature `attention-table` enables bgz-tensor compiled attention path. |
41 | | -#[derive(Clone, Default, Debug)] |
42 | | -pub struct AdaWorld; |
43 | | - |
44 | | -/// CPU device (unit type — there's only one CPU). |
45 | | -#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)] |
46 | | -pub struct CpuDevice; |
47 | | - |
48 | | -// NOTE: Full Backend trait implementation requires ~200+ methods across 7 traits. |
49 | | -// This is tracked as a multi-session effort: |
50 | | -// |
51 | | -// Session 1 (current): Crate skeleton + architecture + tensor primitive |
52 | | -// Session 2: FloatTensorOps core (from_data, matmul, add, mul, exp, reshape, transpose) |
53 | | -// Session 3: IntTensorOps + BoolTensorOps |
54 | | -// Session 4: ModuleOps (conv, embedding) + ActivationOps |
55 | | -// Session 5: QTensorOps + TransactionOps + burn-backend-tests |
56 | | -// |
57 | | -// The implementation follows burn-ndarray's pattern but uses: |
58 | | -// - crate::simd::F32x16 for element-wise ops (not macerator) |
59 | | -// - LazyLock<SimdDispatch> for runtime tier selection (not compile-time features) |
60 | | -// - Optional AttentionTable for compiled attention (unique to this backend) |
| 44 | +/// This backend is compatible with CPUs and can be compiled for almost any platform, including |
| 45 | +/// `wasm`, `arm`, and `x86`. |
| 46 | +#[derive(Clone, Copy, Default, Debug)] |
| 47 | +pub struct NdArray<E = f32, I = i64, Q = i8> |
| 48 | +where |
| 49 | + NdArrayTensor: From<SharedArray<E>>, |
| 50 | + NdArrayTensor: From<SharedArray<I>>, |
| 51 | +{ |
| 52 | + _e: PhantomData<E>, |
| 53 | + _i: PhantomData<I>, |
| 54 | + _q: PhantomData<Q>, |
| 55 | +} |
| 56 | + |
| 57 | +impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> Backend for NdArray<E, I, Q> |
| 58 | +where |
| 59 | + NdArrayTensor: From<SharedArray<E>>, |
| 60 | + NdArrayTensor: From<SharedArray<I>>, |
| 61 | +{ |
| 62 | + type Device = NdArrayDevice; |
| 63 | + |
| 64 | + type FloatTensorPrimitive = NdArrayTensor; |
| 65 | + type FloatElem = E; |
| 66 | + |
| 67 | + type IntTensorPrimitive = NdArrayTensor; |
| 68 | + type IntElem = I; |
| 69 | + |
| 70 | + type BoolTensorPrimitive = NdArrayTensor; |
| 71 | + type BoolElem = bool; |
| 72 | + |
| 73 | + type QuantizedTensorPrimitive = NdArrayQTensor; |
| 74 | + |
| 75 | + fn ad_enabled(_device: &Self::Device) -> bool { |
| 76 | + false |
| 77 | + } |
| 78 | + |
| 79 | + fn name(_device: &Self::Device) -> String { |
| 80 | + String::from("ndarray") |
| 81 | + } |
| 82 | + |
| 83 | + fn seed(_device: &Self::Device, seed: u64) { |
| 84 | + let rng = NdArrayRng::seed_from_u64(seed); |
| 85 | + let mut seed = SEED.lock().unwrap(); |
| 86 | + *seed = Some(rng); |
| 87 | + } |
| 88 | + |
| 89 | + fn dtype_usage(_device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet { |
| 90 | + match dtype { |
| 91 | + DType::F64 |
| 92 | + | DType::F32 |
| 93 | + | DType::Flex32 |
| 94 | + | DType::I64 |
| 95 | + | DType::I32 |
| 96 | + | DType::I16 |
| 97 | + | DType::I8 |
| 98 | + | DType::U64 |
| 99 | + | DType::U32 |
| 100 | + | DType::U16 |
| 101 | + | DType::U8 |
| 102 | + | DType::Bool(BoolStore::Native) => burn_backend::DTypeUsage::general(), |
| 103 | + DType::F16 | DType::BF16 | DType::Bool(_) => burn_backend::DTypeUsageSet::empty(), |
| 104 | + DType::QFloat(scheme) => { |
| 105 | + match scheme { |
| 106 | + QuantScheme { |
| 107 | + level: QuantLevel::Tensor | QuantLevel::Block(_), |
| 108 | + mode: QuantMode::Symmetric, |
| 109 | + #[cfg(not(feature = "export_tests"))] |
| 110 | + value: QuantValue::Q8F | QuantValue::Q8S, |
| 111 | + // For tests, "native" sub-byte quant serves as a reference for value equality. |
| 112 | + // Values are stored as i8 regardless. |
| 113 | + #[cfg(feature = "export_tests")] |
| 114 | + value: |
| 115 | + QuantValue::Q8F |
| 116 | + | QuantValue::Q8S |
| 117 | + | QuantValue::Q4F |
| 118 | + | QuantValue::Q4S |
| 119 | + | QuantValue::Q2F |
| 120 | + | QuantValue::Q2S, |
| 121 | + store: QuantStore::Native, |
| 122 | + .. |
| 123 | + } => burn_backend::DTypeUsage::general(), |
| 124 | + _scheme => burn_backend::DTypeUsageSet::empty(), |
| 125 | + } |
| 126 | + } |
| 127 | + } |
| 128 | + } |
| 129 | + |
| 130 | + fn device_count(_: u16) -> usize { |
| 131 | + 1 |
| 132 | + } |
| 133 | +} |
| 134 | + |
| 135 | +impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BackendIr for NdArray<E, I, Q> |
| 136 | +where |
| 137 | + NdArrayTensor: From<SharedArray<E>>, |
| 138 | + NdArrayTensor: From<SharedArray<I>>, |
| 139 | +{ |
| 140 | + type Handle = HandleKind<Self>; |
| 141 | + |
| 142 | + fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> { |
| 143 | + match handle.handle { |
| 144 | + HandleKind::Float(handle) => handle, |
| 145 | + _ => panic!("Expected float handle, got {}", handle.handle.name()), |
| 146 | + } |
| 147 | + } |
| 148 | + |
| 149 | + fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> { |
| 150 | + match handle.handle { |
| 151 | + HandleKind::Int(handle) => handle, |
| 152 | + _ => panic!("Expected int handle, got {}", handle.handle.name()), |
| 153 | + } |
| 154 | + } |
| 155 | + |
| 156 | + fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> { |
| 157 | + match handle.handle { |
| 158 | + HandleKind::Bool(handle) => handle, |
| 159 | + _ => panic!("Expected bool handle, got {}", handle.handle.name()), |
| 160 | + } |
| 161 | + } |
| 162 | + |
| 163 | + fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> { |
| 164 | + match handle.handle { |
| 165 | + HandleKind::Quantized(handle) => handle, |
| 166 | + _ => panic!("Expected quantized handle, got {}", handle.handle.name()), |
| 167 | + } |
| 168 | + } |
| 169 | + |
| 170 | + fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle { |
| 171 | + HandleKind::Float(tensor) |
| 172 | + } |
| 173 | + |
| 174 | + fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle { |
| 175 | + HandleKind::Int(tensor) |
| 176 | + } |
| 177 | + |
| 178 | + fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle { |
| 179 | + HandleKind::Bool(tensor) |
| 180 | + } |
| 181 | + |
| 182 | + fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle { |
| 183 | + HandleKind::Quantized(tensor) |
| 184 | + } |
| 185 | +} |
| 186 | + |
| 187 | +#[cfg(test)] |
| 188 | +mod tests { |
| 189 | + use super::*; |
| 190 | + use burn_backend::QTensorPrimitive; |
| 191 | + |
| 192 | + #[test] |
| 193 | + fn should_support_dtypes() { |
| 194 | + type B = NdArray<f32>; |
| 195 | + let device = Default::default(); |
| 196 | + |
| 197 | + assert!(B::supports_dtype(&device, DType::F64)); |
| 198 | + assert!(B::supports_dtype(&device, DType::F32)); |
| 199 | + assert!(B::supports_dtype(&device, DType::Flex32)); |
| 200 | + assert!(B::supports_dtype(&device, DType::I64)); |
| 201 | + assert!(B::supports_dtype(&device, DType::I32)); |
| 202 | + assert!(B::supports_dtype(&device, DType::I16)); |
| 203 | + assert!(B::supports_dtype(&device, DType::I8)); |
| 204 | + assert!(B::supports_dtype(&device, DType::U64)); |
| 205 | + assert!(B::supports_dtype(&device, DType::U32)); |
| 206 | + assert!(B::supports_dtype(&device, DType::U16)); |
| 207 | + assert!(B::supports_dtype(&device, DType::U8)); |
| 208 | + assert!(B::supports_dtype(&device, DType::Bool(BoolStore::Native))); |
| 209 | + assert!(B::supports_dtype( |
| 210 | + &device, |
| 211 | + DType::QFloat(NdArrayQTensor::default_scheme()) |
| 212 | + )); |
| 213 | + |
| 214 | + assert!(!B::supports_dtype(&device, DType::F16)); |
| 215 | + assert!(!B::supports_dtype(&device, DType::BF16)); |
| 216 | + // QuantStore::U32 not supported |
| 217 | + assert!(!B::supports_dtype( |
| 218 | + &device, |
| 219 | + DType::QFloat(QuantScheme::default()) |
| 220 | + )); |
| 221 | + } |
| 222 | +} |
0 commit comments