1- use std:: fmt;
2- use std:: error:: Error ;
3-
4- /// Tensor error type
5- #[ derive( Debug ) ]
6- pub enum TensorError {
7- ShapeMismatch ,
8- InvalidType ,
9- UnsafeOperation ,
10- Other ( String ) ,
11- }
1+ use {
2+ crate :: markers:: { Compatible , DataType , Layout , Normalization , QuantParams } ,
3+ image:: DynamicImage ,
4+ } ;
5+
6+ /// Converts a `DynamicImage` to a tensor of type `D` with normalization `N` and layout `L`.
7+ pub fn to_tensor < D : DataType < QuantParams = ( ) > , N : Normalization , L : Layout > (
8+ img : & DynamicImage ,
9+ ) -> Vec < D :: Repr >
10+ where
11+ ( ) : Compatible < N , D > ,
12+ {
13+ let rgb = img. to_rgb8 ( ) ;
14+ let ( w, h) = rgb. dimensions ( ) ;
15+ let ( w, h) = ( w as usize , h as usize ) ;
16+
17+ let mut out = vec ! [ D :: Repr :: default ( ) ; w * h * L :: CHANNELS ] ;
1218
13- impl fmt:: Display for TensorError {
14- fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
15- match self {
16- TensorError :: ShapeMismatch => write ! ( f, "Shape mismatch" ) ,
17- TensorError :: InvalidType => write ! ( f, "Invalid tensor type" ) ,
18- TensorError :: UnsafeOperation => write ! ( f, "Unsafe operation attempted" ) ,
19- TensorError :: Other ( msg) => write ! ( f, "{}" , msg) ,
19+ // Single-pass: read pixel -> convert to f32 (0..1 or 0..255) -> normalize -> dtype -> store by L::index
20+ for y in 0 ..h {
21+ for x in 0 ..w {
22+ let p = rgb. get_pixel ( x as u32 , y as u32 ) ;
23+ // base floats
24+ let mut f = [ p[ 0 ] as f32 , p[ 1 ] as f32 , p[ 2 ] as f32 ] ;
25+ N :: apply ( & mut f) ;
26+
27+ // store
28+ for c in 0 ..3 {
29+ let idx = L :: index ( w, h, x, y, c) ;
30+ out[ idx] = D :: from_f32 ( f[ c] , & ( ) ) ;
31+ }
2032 }
2133 }
34+ out
2235}
2336
24- impl Error for TensorError { }
37+ /// Converts a `DynamicImage` to a tensor of type `D` with normalization `N`, layout `L`, and quantization parameters `quant_params`.
38+ pub fn to_tensor_with_quant < D : DataType < QuantParams = QuantParams > , N : Normalization , L : Layout > (
39+ img : & DynamicImage ,
40+ quant_params : D :: QuantParams ,
41+ ) -> Vec < D :: Repr >
42+ where
43+ ( ) : Compatible < N , D > ,
44+ {
45+ let rgb = img. to_rgb8 ( ) ;
46+ let ( w, h) = rgb. dimensions ( ) ;
47+ let ( w, h) = ( w as usize , h as usize ) ;
2548
26- /// Improved Tensor struct
27- #[ derive( Debug , Clone ) ]
28- pub struct Tensor < T > {
29- data : Vec < T > ,
30- shape : Vec < usize > ,
31- }
49+ let mut out = vec ! [ D :: Repr :: default ( ) ; w * h * L :: CHANNELS ] ;
50+
51+ // Single-pass: read pixel -> convert to f32 -> normalize -> dtype -> store by L::index
52+ for y in 0 ..h {
53+ for x in 0 ..w {
54+ let p = rgb. get_pixel ( x as u32 , y as u32 ) ;
55+ // base floats
56+ let mut f = [ p[ 0 ] as f32 , p[ 1 ] as f32 , p[ 2 ] as f32 ] ;
57+ N :: apply ( & mut f) ;
3258
33- impl < T > Tensor < T > {
34- pub fn new ( data : Vec < T > , shape : Vec < usize > ) -> Result < Self , TensorError > {
35- let expected_len : usize = shape . iter ( ) . product ( ) ;
36- if data . len ( ) != expected_len {
37- return Err ( TensorError :: ShapeMismatch ) ;
59+ // store
60+ for c in 0 .. 3 {
61+ let idx = L :: index ( w , h , x , y , c ) ;
62+ out [ idx ] = D :: from_f32 ( f [ c ] , & quant_params ) ;
63+ }
3864 }
39- Ok ( Tensor { data, shape } )
4065 }
66+ out
67+ }
4168
42- pub fn shape ( & self ) -> & [ usize ] {
43- & self . shape
44- }
69+ #[ cfg( test) ]
70+ mod tests {
71+ use {
72+ super :: * ,
73+ crate :: markers:: * ,
74+ image:: { DynamicImage , RgbImage } ,
75+ } ;
4576
46- pub fn data ( & self ) -> & [ T ] {
47- & self . data
77+ // Build a tiny RGB with distinctive channels:
78+ // base(x,y) = 10*x + y
79+ // R = base
80+ // G = 100 + base
81+ // B = 200 + base
82+ fn make_distinct_rgb ( w : u32 , h : u32 ) -> DynamicImage {
83+ let mut img = RgbImage :: new ( w, h) ;
84+ for y in 0 ..h {
85+ for x in 0 ..w {
86+ let base = 10 * x + y;
87+ img. put_pixel ( x, y, image:: Rgb ( [ base as u8 , ( 100 + base) as u8 , ( 200 + base) as u8 ] ) ) ;
88+ }
89+ }
90+ DynamicImage :: ImageRgb8 ( img)
4891 }
49- }
5092
51- // Usage of unsafe blocks should be minimized.
52- // Example: replace unsafe indexing with safe methods.
53- impl < T > Tensor < T > {
54- pub fn get ( & self , idx : usize ) -> Result < & T , TensorError > {
55- self . data . get ( idx) . ok_or ( TensorError :: Other ( "Index out of bounds" . to_string ( ) ) )
93+ #[ test]
94+ fn test_to_tensor ( ) {
95+ let img = make_distinct_rgb ( 2 , 2 ) ;
96+ let tensor: Vec < u8 > = to_tensor :: < U8 , Identity , CHW > ( & img) ;
97+ assert_eq ! ( tensor. len( ) , 2 * 2 * 3 ) ;
98+ // Add more test logic here as needed
5699 }
57- }
58100
59- // Any FFI or memory manipulation should be wrapped safely and errors propagated.
101+ #[ test]
102+ fn test_to_tensor_with_quant ( ) {
103+ let img = make_distinct_rgb ( 2 , 2 ) ;
104+ let quant_params = QuantParams { scale : 0.5 , zero_point : 128 } ;
105+ let tensor: Vec < u8 > = to_tensor_with_quant :: < U8 , Identity , CHW > ( & img, quant_params) ;
106+ assert_eq ! ( tensor. len( ) , 2 * 2 * 3 ) ;
107+ // Add more test logic here as needed
108+ }
109+ }
0 commit comments