RusTorch is designed as a modular, high-performance deep learning framework. This document outlines its core architectural components and design philosophy.
RusTorch follows a layered architecture similar to PyTorch but leverages Rust's type system and ownership model for safety and concurrency.
graph TD
A[User Code] --> B[RusTorch NN]
B --> C[RusTorch Core]
C --> D[Autograd Engine]
C --> E[Storage & Allocator]
C --> F[JIT Compiler]
E --> G[CPU Backend - Rayon]
E --> H[CUDA Backend - Cudarc]
E --> I[Metal/Wasm Backends]
This crate provides the Tensor struct, which is the central data structure.
- Tensor: A wrapper around `Arc<TensorImpl>`. It provides a view into the underlying storage.
- Shape & Strides: Handles N-dimensional indexing and broadcasting.
- Storage: A contiguous memory block (Vec on CPU, CudaSlice on GPU).
- Autograd: Each tensor holds a `Mutex<Option<Tensor>>` for gradients and an `Option<Arc<dyn BackwardOp>>` for the computational graph.
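To make the relationship between the handle, its shape and strides, and the shared storage concrete, here is a minimal CPU-only sketch. The field names, `Storage`, `from_vec`, `get`, and `transpose2d` are assumptions made for illustration and are not taken from the RusTorch source; the graph pointer is omitted for brevity.

```rust
use std::sync::{Arc, Mutex};

/// Hypothetical stand-in for the storage described above (CPU only here).
#[derive(Debug)]
pub struct Storage {
    pub data: Vec<f32>,
}

/// Hypothetical inner state: shared storage plus view metadata.
pub struct TensorImpl {
    pub storage: Arc<Storage>,
    pub shape: Vec<usize>,
    pub strides: Vec<usize>,
    pub offset: usize,
    /// Gradient slot, to be filled in by a backward pass.
    pub grad: Mutex<Option<Tensor>>,
}

/// Cheap-to-clone handle: cloning copies the Arc, not the data.
#[derive(Clone)]
pub struct Tensor(pub Arc<TensorImpl>);

/// Row-major (C-contiguous) strides for a shape, measured in elements.
pub fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

impl Tensor {
    pub fn from_vec(data: Vec<f32>, shape: &[usize]) -> Tensor {
        assert_eq!(data.len(), shape.iter().product::<usize>());
        Tensor(Arc::new(TensorImpl {
            storage: Arc::new(Storage { data }),
            shape: shape.to_vec(),
            strides: contiguous_strides(shape),
            offset: 0,
            grad: Mutex::new(None),
        }))
    }

    /// Read one element by N-dimensional index: offset + sum(index[d] * strides[d]).
    pub fn get(&self, index: &[usize]) -> f32 {
        let t = &self.0;
        assert_eq!(index.len(), t.shape.len());
        let mut flat = t.offset;
        for ((&i, &dim), &stride) in index.iter().zip(&t.shape).zip(&t.strides) {
            assert!(i < dim, "index out of bounds");
            flat += i * stride;
        }
        t.storage.data[flat]
    }

    /// Transposed view of a 2-D tensor: swaps shape and strides, shares storage.
    pub fn transpose2d(&self) -> Tensor {
        let t = &self.0;
        assert_eq!(t.shape.len(), 2);
        Tensor(Arc::new(TensorImpl {
            storage: Arc::clone(&t.storage),
            shape: vec![t.shape[1], t.shape[0]],
            strides: vec![t.strides[1], t.strides[0]],
            offset: t.offset,
            grad: Mutex::new(None),
        }))
    }
}
```

In this sketch a transpose only rewrites metadata, which is what "a view into the underlying storage" amounts to: both tensors point at the same `Storage` allocation.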
- Autograd Engine: Implements reverse-mode automatic differentiation. It builds a dynamic graph (DAG) during the forward pass. When `.backward()` is called, it traverses the graph in reverse topological order (currently recursive DFS) to compute gradients.
- JIT Compiler: An experimental module (`jit.rs`) that captures the computation graph into an Intermediate Representation (IR). It performs static optimizations like:
  - Operator Fusion: Combining `Conv2d` + `ReLU` into a single kernel to reduce memory bandwidth.
  - Dead Code Elimination: Removing unused graph nodes.
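The two JIT optimizations can be illustrated on a toy IR. The `Op` enum, the `Graph` struct, and both pass functions below are hypothetical and do not reflect the actual contents of `jit.rs`; the sketch assumes nodes are stored in topological order and refer to their operands by index.

```rust
/// Hypothetical IR op set; operands refer to earlier nodes by index.
#[derive(Debug, Clone, PartialEq)]
pub enum Op {
    Input,
    Conv2d(usize),
    Relu(usize),
    Add(usize, usize),
    FusedConvRelu(usize),
}

impl Op {
    fn operands(&self) -> Vec<usize> {
        match *self {
            Op::Input => vec![],
            Op::Conv2d(a) | Op::Relu(a) | Op::FusedConvRelu(a) => vec![a],
            Op::Add(a, b) => vec![a, b],
        }
    }

    /// Same op with every operand index passed through `f`.
    fn remap(&self, f: impl Fn(usize) -> usize) -> Op {
        match *self {
            Op::Input => Op::Input,
            Op::Conv2d(a) => Op::Conv2d(f(a)),
            Op::Relu(a) => Op::Relu(f(a)),
            Op::FusedConvRelu(a) => Op::FusedConvRelu(f(a)),
            Op::Add(a, b) => Op::Add(f(a), f(b)),
        }
    }
}

#[derive(Debug, Clone, PartialEq)]
pub struct Graph {
    pub nodes: Vec<Op>,
    pub outputs: Vec<usize>,
}

/// Rewrite Relu(Conv2d(x)) into FusedConvRelu(x) when the conv has no other user.
pub fn fuse_conv_relu(g: &mut Graph) {
    let mut uses = vec![0usize; g.nodes.len()];
    for op in &g.nodes {
        for a in op.operands() {
            uses[a] += 1;
        }
    }
    for &o in &g.outputs {
        uses[o] += 1;
    }
    for i in 0..g.nodes.len() {
        if let Op::Relu(c) = g.nodes[i] {
            if let Op::Conv2d(x) = g.nodes[c] {
                if uses[c] == 1 {
                    // The conv node is left behind as dead code.
                    g.nodes[i] = Op::FusedConvRelu(x);
                }
            }
        }
    }
}

/// Drop nodes unreachable from the outputs and renumber the survivors.
pub fn eliminate_dead_code(g: &mut Graph) {
    let mut live = vec![false; g.nodes.len()];
    let mut stack = g.outputs.clone();
    while let Some(n) = stack.pop() {
        if !live[n] {
            live[n] = true;
            stack.extend(g.nodes[n].operands());
        }
    }
    // new_index[old] is meaningful only for live nodes.
    let mut new_index = vec![0usize; g.nodes.len()];
    let mut kept = Vec::new();
    for (old, op) in g.nodes.iter().enumerate() {
        if live[old] {
            new_index[old] = kept.len();
            kept.push(op.remap(|a| new_index[a]));
        }
    }
    g.outputs = g.outputs.iter().map(|&o| new_index[o]).collect();
    g.nodes = kept;
}
```

Running fusion first and dead code elimination second matters in this sketch: fusion leaves the original convolution node orphaned, and the second pass removes it.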
Built on top of core, this crate defines the Module trait and implements common layers.
- Module Trait: Defines the interface for all layers.
  - `forward(&self, input: &Tensor) -> Tensor`
  - `parameters(&self) -> Vec<Tensor>`
- Layers:
  - `Linear`, `Conv2d`: Standard learnable layers.
  - `RNN`, `LSTM`, `GRU`: Recurrent layers with state management.
  - `Transformer`: Multi-head attention and encoder blocks.
- Optimizers: `SGD` and `Adam` implement parameter updates. They track parameter references and apply gradients.
- Data: `Dataset` and `DataLoader` provide multi-threaded data pipeline primitives.
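The following sketch shows how the `Module` trait, a `Linear` layer, and an `SGD` optimizer could fit together. Only the two trait method signatures come from the list above. The 1-D `Tensor` stand-in, `Linear::from_weights`, and the `SGD` fields are assumptions made so the example compiles on its own, and gradients are set by hand rather than by a backward pass.

```rust
use std::sync::{Arc, Mutex};

/// Hypothetical 1-D stand-in for the real Tensor: shared values plus a gradient slot.
#[derive(Clone)]
pub struct Tensor {
    pub data: Arc<Mutex<Vec<f32>>>,
    pub grad: Arc<Mutex<Option<Vec<f32>>>>,
}

impl Tensor {
    pub fn new(values: Vec<f32>) -> Self {
        Tensor {
            data: Arc::new(Mutex::new(values)),
            grad: Arc::new(Mutex::new(None)),
        }
    }

    pub fn to_vec(&self) -> Vec<f32> {
        self.data.lock().unwrap().clone()
    }
}

/// The two methods listed for the Module trait.
pub trait Module {
    fn forward(&self, input: &Tensor) -> Tensor;
    fn parameters(&self) -> Vec<Tensor>;
}

/// y = W x + b, with W stored row-major as out_features x in_features.
pub struct Linear {
    pub weight: Tensor,
    pub bias: Tensor,
    pub in_features: usize,
    pub out_features: usize,
}

impl Linear {
    /// Hypothetical constructor taking explicit values, to keep the sketch deterministic.
    pub fn from_weights(weight: Vec<f32>, bias: Vec<f32>) -> Self {
        assert!(!bias.is_empty() && weight.len() % bias.len() == 0);
        let (in_features, out_features) = (weight.len() / bias.len(), bias.len());
        Linear {
            weight: Tensor::new(weight),
            bias: Tensor::new(bias),
            in_features,
            out_features,
        }
    }
}

impl Module for Linear {
    fn forward(&self, input: &Tensor) -> Tensor {
        let (x, w, b) = (input.to_vec(), self.weight.to_vec(), self.bias.to_vec());
        assert_eq!(x.len(), self.in_features);
        let y: Vec<f32> = (0..self.out_features)
            .map(|o| {
                let row = &w[o * self.in_features..(o + 1) * self.in_features];
                row.iter().zip(&x).map(|(wi, xi)| wi * xi).sum::<f32>() + b[o]
            })
            .collect();
        Tensor::new(y)
    }

    fn parameters(&self) -> Vec<Tensor> {
        // Clones share the underlying buffers, so an optimizer updates the layer in place.
        vec![self.weight.clone(), self.bias.clone()]
    }
}

/// Plain gradient descent: p <- p - lr * grad, skipping parameters without a gradient.
pub struct SGD {
    params: Vec<Tensor>,
    lr: f32,
}

impl SGD {
    pub fn new(params: Vec<Tensor>, lr: f32) -> Self {
        SGD { params, lr }
    }

    pub fn step(&self) {
        for p in &self.params {
            let grad = p.grad.lock().unwrap();
            if let Some(g) = grad.as_ref() {
                let mut data = p.data.lock().unwrap();
                for (w, gi) in data.iter_mut().zip(g) {
                    *w -= self.lr * gi;
                }
            }
        }
    }
}
```

Because `parameters()` returns handles that share storage with the layer, the optimizer can hold them for the whole training run, which is one way to read "track parameter references".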
- CPU: Uses `Rayon` for work-stealing parallelism. Operations like `matmul` and `conv2d` are parallelized across batch and channel dimensions.
- CUDA: (In progress) Uses `cudarc` to manage GPU memory and launch PTX kernels. The architecture allows seamless switching between devices via the `Device` abstraction.
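As an illustration of batch-level parallelism behind a device switch, the sketch below splits a batched matrix multiply across threads. To stay dependency-free it uses `std::thread::scope` with one thread per batch entry instead of Rayon's work-stealing pool, so it shows the partitioning idea rather than the actual backend. The `Device` variants and the `batched_matmul` signature are assumptions.

```rust
use std::thread;

/// Hypothetical device tag; only the CPU path is implemented in this sketch.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Device {
    Cpu,
    Cuda(usize),
}

/// Batched matmul: `a` is [batch, m, k], `b` is [k, n], the result is [batch, m, n].
/// All buffers are row-major. Each batch entry is computed on its own scoped thread.
pub fn batched_matmul(
    a: &[f32],
    b: &[f32],
    batch: usize,
    m: usize,
    k: usize,
    n: usize,
    device: Device,
) -> Result<Vec<f32>, String> {
    if device != Device::Cpu {
        return Err(format!("{:?} backend is not available in this sketch", device));
    }
    assert!(m > 0 && k > 0 && n > 0, "dimensions must be non-zero");
    assert_eq!(a.len(), batch * m * k);
    assert_eq!(b.len(), k * n);

    let mut out = vec![0.0f32; batch * m * n];
    thread::scope(|s| {
        // Disjoint output chunks let each thread write without locking.
        for (a_i, out_i) in a.chunks(m * k).zip(out.chunks_mut(m * n)) {
            s.spawn(move || {
                for row in 0..m {
                    for col in 0..n {
                        let mut acc = 0.0;
                        for p in 0..k {
                            acc += a_i[row * k + p] * b[p * n + col];
                        }
                        out_i[row * n + col] = acc;
                    }
                }
            });
        }
    });
    Ok(out)
}
```

Splitting the output with `chunks_mut` gives each thread exclusive access to its slice, so the borrow checker can verify the absence of data races at compile time. A Rayon-based version would apply the same partitioning through a parallel iterator.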
- DistributedDataParallel (DDP): Implements data parallelism by replicating the model on multiple workers.
- AllReduce: Gradients are synchronized across workers using a ring-reduction algorithm (currently simulated, extensible to MPI/NCCL).
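A ring all-reduce can be simulated in a single process, which is roughly what "currently simulated" suggests. The sketch below is one possible form of that simulation, not the project's implementation: `ring_all_reduce` is a hypothetical function that sums the gradient buffers of all workers using the usual reduce-scatter and all-gather phases.

```rust
/// Simulated ring all-reduce (sum). `workers[i]` is worker i's gradient buffer;
/// all buffers must have the same length. Runs sequentially, with no real communication.
pub fn ring_all_reduce(workers: &mut [Vec<f32>]) {
    let n = workers.len();
    if n < 2 {
        return;
    }
    let len = workers[0].len();
    assert!(workers.iter().all(|w| w.len() == len));

    // Chunk c covers bounds[c]..bounds[c + 1]; chunk sizes may differ by one element.
    let bounds: Vec<usize> = (0..=n).map(|c| c * len / n).collect();

    // Phase 1, reduce-scatter: in each step worker i sends one chunk to worker i + 1,
    // which adds it to its own copy. After n - 1 steps worker i holds the complete
    // sum for chunk (i + 1) % n.
    for step in 0..n - 1 {
        // Snapshot outgoing chunks first so every send in a step sees pre-step data.
        let sends: Vec<(usize, Vec<f32>)> = (0..n)
            .map(|i| {
                let c = (i + n - step) % n;
                (c, workers[i][bounds[c]..bounds[c + 1]].to_vec())
            })
            .collect();
        for (i, (c, chunk)) in sends.into_iter().enumerate() {
            let dst = &mut workers[(i + 1) % n][bounds[c]..bounds[c + 1]];
            for (d, s) in dst.iter_mut().zip(chunk) {
                *d += s;
            }
        }
    }

    // Phase 2, all-gather: the completed chunks travel around the ring once more,
    // overwriting the stale partial sums on each worker.
    for step in 0..n - 1 {
        let sends: Vec<(usize, Vec<f32>)> = (0..n)
            .map(|i| {
                let c = (i + 1 + n - step) % n;
                (c, workers[i][bounds[c]..bounds[c + 1]].to_vec())
            })
            .collect();
        for (i, (c, chunk)) in sends.into_iter().enumerate() {
            workers[(i + 1) % n][bounds[c]..bounds[c + 1]].copy_from_slice(&chunk);
        }
    }
}
```

Each worker sends and receives only `2 * (n - 1)` chunks of about `len / n` elements, so per-worker traffic stays close to twice the buffer size regardless of the number of workers. For data-parallel training the summed gradients would typically be divided by `n` afterwards to obtain the average.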
- Safety First: We use `Arc` and `Mutex` to manage shared state (like gradients) safely across threads. Rust's borrow checker prevents data races.
- Zero-Cost Abstractions: High-level APIs (like `Module`) compile down to efficient low-level code. We avoid overhead where possible.
- Interoperability: The API mirrors PyTorch to minimize the learning curve for Python users.
- Extensibility: The `BackwardOp` trait allows users to define custom differentiable operations easily.
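To show how a user-defined operation could plug into reverse-mode differentiation, here is a scalar-valued sketch. The `BackwardOp` method signature, the `Node` struct, and the helper functions are assumptions made for illustration; the real trait operates on tensors and may look different. The traversal uses a recursive DFS to build a topological order and then walks it in reverse.

```rust
use std::collections::HashSet;
use std::sync::{Arc, Mutex};

/// Hypothetical shape of a backward op: given the input values and the upstream
/// gradient, return one gradient per input.
pub trait BackwardOp: Send + Sync {
    fn backward(&self, inputs: &[f64], grad_output: f64) -> Vec<f64>;
}

/// Scalar graph node standing in for a tensor.
pub struct Node {
    pub value: f64,
    pub grad: Mutex<f64>,
    pub op: Option<Arc<dyn BackwardOp>>,
    pub inputs: Vec<Arc<Node>>,
}

fn node(value: f64, op: Option<Arc<dyn BackwardOp>>, inputs: Vec<Arc<Node>>) -> Arc<Node> {
    Arc::new(Node { value, grad: Mutex::new(0.0), op, inputs })
}

pub fn leaf(value: f64) -> Arc<Node> {
    node(value, None, vec![])
}

/// Built-in op: z = a * b.
struct MulBackward;
impl BackwardOp for MulBackward {
    fn backward(&self, inputs: &[f64], g: f64) -> Vec<f64> {
        vec![g * inputs[1], g * inputs[0]]
    }
}
pub fn mul(a: &Arc<Node>, b: &Arc<Node>) -> Arc<Node> {
    node(a.value * b.value, Some(Arc::new(MulBackward)), vec![a.clone(), b.clone()])
}

/// User-defined op: y = x^3, with dy/dx = 3 x^2.
struct CubeBackward;
impl BackwardOp for CubeBackward {
    fn backward(&self, inputs: &[f64], g: f64) -> Vec<f64> {
        vec![g * 3.0 * inputs[0] * inputs[0]]
    }
}
pub fn cube(x: &Arc<Node>) -> Arc<Node> {
    node(x.value.powi(3), Some(Arc::new(CubeBackward)), vec![x.clone()])
}

/// Post-order recursive DFS: every node is pushed after all of its inputs.
fn topo(n: &Arc<Node>, seen: &mut HashSet<*const Node>, order: &mut Vec<Arc<Node>>) {
    if seen.insert(Arc::as_ptr(n)) {
        for input in &n.inputs {
            topo(input, seen, order);
        }
        order.push(n.clone());
    }
}

/// Seed the root gradient with 1 and propagate in reverse topological order.
pub fn backward(root: &Arc<Node>) {
    let mut order = Vec::new();
    topo(root, &mut HashSet::new(), &mut order);
    *root.grad.lock().unwrap() = 1.0;
    for n in order.iter().rev() {
        if let Some(op) = &n.op {
            let g = *n.grad.lock().unwrap();
            let values: Vec<f64> = n.inputs.iter().map(|i| i.value).collect();
            for (input, gi) in n.inputs.iter().zip(op.backward(&values, g)) {
                // Accumulate, since a node may feed several consumers.
                *input.grad.lock().unwrap() += gi;
            }
        }
    }
}
```

For example, with `x = leaf(2.0)` and `y = leaf(5.0)`, calling `backward` on `mul(&cube(&x), &y)` leaves a gradient of 60 on `x` (3 · 2² · 5) and 8 on `y` (2³). Adding a new differentiable operation only requires another `BackwardOp` implementation plus a forward helper; the traversal itself stays unchanged.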
- Dynamic Graph Optimization: Enhancing the JIT to support dynamic control flow.
- XLA Integration: Lowering the IR to XLA/HLO for broader hardware support (TPUs).
- Quantization: Native support for int8 inference.