A small Rust library providing unified traits for tensor operations in machine learning models.
- Standardized tensor trait interfaces
- Adapters for popular tensor crates
- Custom error handling for tensor ops
Add the crate to your `Cargo.toml`:
```toml
[dependencies]
tensor-trait = "0.1"
```

Implementing the `Tensor` trait for your own type:

```rust
use tensor_trait::tensor_trait::Tensor;

struct MyTensor {
    data: Vec<f32>,
    shape: Vec<usize>,
}

impl Tensor for MyTensor {
    type Elem = f32;

    fn ndim(&self) -> usize {
        self.shape.len()
    }

    fn shape(&self) -> &[usize] {
        &self.shape
    }

    // Other trait methods...
}
```

Wrapping an `ndarray` array with the provided adapter (this example also uses the `ndarray` crate):

```rust
use ndarray::Array2;
use tensor_trait::adapters::ndarray_adapter::NdArrayTensor;
// The trait must be in scope to call trait methods such as `ndim()`.
use tensor_trait::tensor_trait::Tensor;

let arr = Array2::<f32>::zeros((2, 3));
let tensor = NdArrayTensor::from(arr);
assert_eq!(tensor.ndim(), 2);
```

Reporting failures from tensor operations with `TensorOpError`:

```rust
use tensor_trait::error::TensorOpError;

fn do_tensor_op() -> Result<(), TensorOpError> {
    // ...something that can fail
    Err(TensorOpError::ShapeMismatch)
}
```

Detailed documentation for the traits and adapters is available in the crate docs.