Skip to content

Commit b75e2a2

Browse files
author
“thucydides”
committed
style: Apply cargo fmt
1 parent 0c19630 commit b75e2a2

4 files changed

Lines changed: 161 additions & 105 deletions

File tree

rustorch-core/src/storage.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
33

44
#[cfg(feature = "cuda")]
55
use cudarc::driver::CudaSlice;
6-
#[cfg(feature = "wgpu_backend")]
7-
use wgpu;
86
#[cfg(feature = "vulkan_backend")]
97
use vulkano::buffer::Subbuffer;
8+
#[cfg(feature = "wgpu_backend")]
9+
use wgpu;
1010

1111
#[derive(Debug, Clone, Copy, PartialEq)]
1212
pub enum Device {

rustorch-pytorch/src/lib.rs

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
use rustorch_core::Tensor;
2-
use tch::{self, Kind, Device as TchDevice};
31
use anyhow::Result;
2+
use rustorch_core::Tensor;
43
use std::path::Path;
4+
use tch::{self, Device as TchDevice, Kind};
55

66
pub struct PyTorchAdapter;
77

@@ -11,7 +11,7 @@ impl PyTorchAdapter {
1111
let storage = tensor.storage();
1212
let data = storage.data(); // Read lock
1313
let shape: Vec<i64> = tensor.shape().iter().map(|&x| x as i64).collect();
14-
14+
1515
// Create a PyTorch tensor from the data
1616
// Currently assumes f32 and CPU
1717
let t = tch::Tensor::from_slice(&data);
@@ -21,10 +21,10 @@ impl PyTorchAdapter {
2121
/// Convert a PyTorch Tensor to a RusTorch Tensor
2222
pub fn from_torch(tensor: &tch::Tensor) -> Result<Tensor> {
2323
let size: Vec<usize> = tensor.size().iter().map(|&x| x as usize).collect();
24-
24+
2525
// Ensure the tensor is on CPU and is contiguous
2626
let cpu_tensor = tensor.to_device(TchDevice::Cpu).contiguous();
27-
27+
2828
// Check if the tensor is Float (f32)
2929
if cpu_tensor.kind() != Kind::Float {
3030
// Cast to float if not
@@ -43,28 +43,34 @@ impl PyTorchAdapter {
4343
}
4444

4545
/// Load a PyTorch model (.pth) and return a dictionary of tensors (state_dict)
46-
pub fn load_state_dict<P: AsRef<Path>>(path: P) -> Result<std::collections::HashMap<String, Tensor>> {
46+
pub fn load_state_dict<P: AsRef<Path>>(
47+
path: P,
48+
) -> Result<std::collections::HashMap<String, Tensor>> {
4749
let tensors = tch::Tensor::load_multi(path)?;
4850
let mut result = std::collections::HashMap::new();
49-
51+
5052
for (name, tensor) in tensors {
5153
let rt_tensor = Self::from_torch(&tensor)?;
5254
result.insert(name, rt_tensor);
5355
}
54-
56+
5557
Ok(result)
5658
}
5759

5860
/// Save a dictionary of RusTorch tensors to a .pth file
59-
pub fn save_state_dict<P: AsRef<Path>>(tensors: &std::collections::HashMap<String, Tensor>, path: P) -> Result<()> {
61+
pub fn save_state_dict<P: AsRef<Path>>(
62+
tensors: &std::collections::HashMap<String, Tensor>,
63+
path: P,
64+
) -> Result<()> {
6065
let mut named_tensors = Vec::new();
6166
for (name, tensor) in tensors {
6267
let t = Self::to_torch(tensor);
6368
named_tensors.push((name.clone(), t));
6469
}
65-
70+
6671
// Convert to slice of (&str, Tensor)
67-
let named_tensors_refs: Vec<(&str, tch::Tensor)> = named_tensors.iter()
72+
let named_tensors_refs: Vec<(&str, tch::Tensor)> = named_tensors
73+
.iter()
6874
.map(|(n, t)| (n.as_str(), t.shallow_clone())) // shallow_clone is cheap
6975
.collect();
7076

@@ -132,12 +138,12 @@ mod tests {
132138
assert_eq!(rt_tensor_back.shape(), shape.as_slice());
133139
assert_eq!(*rt_tensor_back.storage().data(), data);
134140
}
135-
141+
136142
#[test]
137143
fn test_ops() {
138144
let t1 = Tensor::new(&vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
139145
let t2 = Tensor::new(&vec![1.0, 1.0, 1.0, 1.0], &[2, 2]);
140-
146+
141147
let res = ops::add(&t1, &t2).unwrap();
142148
assert_eq!(*res.storage().data(), vec![2.0, 3.0, 4.0, 5.0]);
143149
}

rustorch-vulkan/src/lib.rs

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
use rustorch_core::{Tensor, Storage};
2-
use vulkano::device::{Device, Queue, DeviceCreateInfo, QueueCreateInfo, QueueFlags};
3-
use vulkano::instance::{Instance, InstanceCreateInfo, InstanceExtensions};
4-
use vulkano::memory::allocator::{StandardMemoryAllocator, AllocationCreateInfo, MemoryTypeFilter};
1+
use anyhow::{Context, Result};
2+
use rustorch_core::{Storage, Tensor};
3+
use std::sync::Arc;
54
use vulkano::buffer::{Buffer, BufferCreateInfo, BufferUsage};
65
use vulkano::command_buffer::allocator::StandardCommandBufferAllocator;
76
use vulkano::descriptor_set::allocator::StandardDescriptorSetAllocator;
7+
use vulkano::device::{Device, DeviceCreateInfo, Queue, QueueCreateInfo, QueueFlags};
8+
use vulkano::instance::{Instance, InstanceCreateInfo, InstanceExtensions};
9+
use vulkano::memory::allocator::{AllocationCreateInfo, MemoryTypeFilter, StandardMemoryAllocator};
810
use vulkano::VulkanLibrary;
9-
use std::sync::Arc;
10-
use anyhow::{Result, Context};
1111

1212
pub struct VulkanContext {
1313
pub device: Arc<Device>,
@@ -29,7 +29,8 @@ impl VulkanContext {
2929
},
3030
..Default::default()
3131
},
32-
).context("Failed to create instance")?;
32+
)
33+
.context("Failed to create instance")?;
3334

3435
let physical_device = instance
3536
.enumerate_physical_devices()
@@ -52,12 +53,19 @@ impl VulkanContext {
5253
}],
5354
..Default::default()
5455
},
55-
).context("Failed to create device")?;
56+
)
57+
.context("Failed to create device")?;
5658

5759
let queue = queues.next().unwrap();
5860
let memory_allocator = Arc::new(StandardMemoryAllocator::new_default(device.clone()));
59-
let command_buffer_allocator = Arc::new(StandardCommandBufferAllocator::new(device.clone(), Default::default()));
60-
let descriptor_set_allocator = Arc::new(StandardDescriptorSetAllocator::new(device.clone(), Default::default()));
61+
let command_buffer_allocator = Arc::new(StandardCommandBufferAllocator::new(
62+
device.clone(),
63+
Default::default(),
64+
));
65+
let descriptor_set_allocator = Arc::new(StandardDescriptorSetAllocator::new(
66+
device.clone(),
67+
Default::default(),
68+
));
6169

6270
Ok(Self {
6371
device,
@@ -76,11 +84,13 @@ impl VulkanContext {
7684
..Default::default()
7785
},
7886
AllocationCreateInfo {
79-
memory_type_filter: MemoryTypeFilter::PREFER_DEVICE | MemoryTypeFilter::HOST_SEQUENTIAL_WRITE,
87+
memory_type_filter: MemoryTypeFilter::PREFER_DEVICE
88+
| MemoryTypeFilter::HOST_SEQUENTIAL_WRITE,
8089
..Default::default()
8190
},
8291
data.iter().cloned(),
83-
).context("Failed to create buffer")?;
92+
)
93+
.context("Failed to create buffer")?;
8494

8595
let storage = Storage::new_vulkan(Arc::new(buffer), 0);
8696
Ok(Tensor::new_with_storage(storage, shape))

0 commit comments

Comments
 (0)