1- use rustorch_core:: Tensor ;
2- use tch:: { self , Kind , Device as TchDevice } ;
31use anyhow:: Result ;
2+ use rustorch_core:: Tensor ;
43use std:: path:: Path ;
4+ use tch:: { self , Device as TchDevice , Kind } ;
55
66pub struct PyTorchAdapter ;
77
@@ -11,7 +11,7 @@ impl PyTorchAdapter {
1111 let storage = tensor. storage ( ) ;
1212 let data = storage. data ( ) ; // Read lock
1313 let shape: Vec < i64 > = tensor. shape ( ) . iter ( ) . map ( |& x| x as i64 ) . collect ( ) ;
14-
14+
1515 // Create a PyTorch tensor from the data
1616 // Currently assumes f32 and CPU
1717 let t = tch:: Tensor :: from_slice ( & data) ;
@@ -21,10 +21,10 @@ impl PyTorchAdapter {
2121 /// Convert a PyTorch Tensor to a RusTorch Tensor
2222 pub fn from_torch ( tensor : & tch:: Tensor ) -> Result < Tensor > {
2323 let size: Vec < usize > = tensor. size ( ) . iter ( ) . map ( |& x| x as usize ) . collect ( ) ;
24-
24+
2525 // Ensure the tensor is on CPU and is contiguous
2626 let cpu_tensor = tensor. to_device ( TchDevice :: Cpu ) . contiguous ( ) ;
27-
27+
2828 // Check if the tensor is Float (f32)
2929 if cpu_tensor. kind ( ) != Kind :: Float {
3030 // Cast to float if not
@@ -43,28 +43,34 @@ impl PyTorchAdapter {
4343 }
4444
4545 /// Load a PyTorch model (.pth) and return a dictionary of tensors (state_dict)
46- pub fn load_state_dict < P : AsRef < Path > > ( path : P ) -> Result < std:: collections:: HashMap < String , Tensor > > {
46+ pub fn load_state_dict < P : AsRef < Path > > (
47+ path : P ,
48+ ) -> Result < std:: collections:: HashMap < String , Tensor > > {
4749 let tensors = tch:: Tensor :: load_multi ( path) ?;
4850 let mut result = std:: collections:: HashMap :: new ( ) ;
49-
51+
5052 for ( name, tensor) in tensors {
5153 let rt_tensor = Self :: from_torch ( & tensor) ?;
5254 result. insert ( name, rt_tensor) ;
5355 }
54-
56+
5557 Ok ( result)
5658 }
5759
5860 /// Save a dictionary of RusTorch tensors to a .pth file
59- pub fn save_state_dict < P : AsRef < Path > > ( tensors : & std:: collections:: HashMap < String , Tensor > , path : P ) -> Result < ( ) > {
61+ pub fn save_state_dict < P : AsRef < Path > > (
62+ tensors : & std:: collections:: HashMap < String , Tensor > ,
63+ path : P ,
64+ ) -> Result < ( ) > {
6065 let mut named_tensors = Vec :: new ( ) ;
6166 for ( name, tensor) in tensors {
6267 let t = Self :: to_torch ( tensor) ;
6368 named_tensors. push ( ( name. clone ( ) , t) ) ;
6469 }
65-
70+
6671 // Convert to slice of (&str, Tensor)
67- let named_tensors_refs: Vec < ( & str , tch:: Tensor ) > = named_tensors. iter ( )
72+ let named_tensors_refs: Vec < ( & str , tch:: Tensor ) > = named_tensors
73+ . iter ( )
6874 . map ( |( n, t) | ( n. as_str ( ) , t. shallow_clone ( ) ) ) // shallow_clone is cheap
6975 . collect ( ) ;
7076
@@ -132,12 +138,12 @@ mod tests {
132138 assert_eq ! ( rt_tensor_back. shape( ) , shape. as_slice( ) ) ;
133139 assert_eq ! ( * rt_tensor_back. storage( ) . data( ) , data) ;
134140 }
135-
141+
136142 #[ test]
137143 fn test_ops ( ) {
138144 let t1 = Tensor :: new ( & vec ! [ 1.0 , 2.0 , 3.0 , 4.0 ] , & [ 2 , 2 ] ) ;
139145 let t2 = Tensor :: new ( & vec ! [ 1.0 , 1.0 , 1.0 , 1.0 ] , & [ 2 , 2 ] ) ;
140-
146+
141147 let res = ops:: add ( & t1, & t2) . unwrap ( ) ;
142148 assert_eq ! ( * res. storage( ) . data( ) , vec![ 2.0 , 3.0 , 4.0 , 5.0 ] ) ;
143149 }
0 commit comments