Skip to content

Commit 7ff438c

Browse files
committed
feat: implementation fixed cargo
1 parent 41ed46f commit 7ff438c

1 file changed

Lines changed: 94 additions & 44 deletions

File tree

Lines changed: 94 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,109 @@
1-
use std::fmt;
2-
use std::error::Error;
3-
4-
/// Tensor error type
5-
#[derive(Debug)]
6-
pub enum TensorError {
7-
ShapeMismatch,
8-
InvalidType,
9-
UnsafeOperation,
10-
Other(String),
11-
}
1+
use {
2+
crate::markers::{Compatible, DataType, Layout, Normalization, QuantParams},
3+
image::DynamicImage,
4+
};
5+
6+
/// Converts a `DynamicImage` to a tensor of type `D` with normalization `N` and layout `L`.
7+
pub fn to_tensor<D: DataType<QuantParams = ()>, N: Normalization, L: Layout>(
8+
img: &DynamicImage,
9+
) -> Vec<D::Repr>
10+
where
11+
(): Compatible<N, D>,
12+
{
13+
let rgb = img.to_rgb8();
14+
let (w, h) = rgb.dimensions();
15+
let (w, h) = (w as usize, h as usize);
16+
17+
let mut out = vec![D::Repr::default(); w * h * L::CHANNELS];
1218

13-
impl fmt::Display for TensorError {
14-
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
15-
match self {
16-
TensorError::ShapeMismatch => write!(f, "Shape mismatch"),
17-
TensorError::InvalidType => write!(f, "Invalid tensor type"),
18-
TensorError::UnsafeOperation => write!(f, "Unsafe operation attempted"),
19-
TensorError::Other(msg) => write!(f, "{}", msg),
19+
// Single-pass: read pixel -> convert to f32 (0..1 or 0..255) -> normalize -> dtype -> store by L::index
20+
for y in 0..h {
21+
for x in 0..w {
22+
let p = rgb.get_pixel(x as u32, y as u32);
23+
// base floats
24+
let mut f = [p[0] as f32, p[1] as f32, p[2] as f32];
25+
N::apply(&mut f);
26+
27+
// store
28+
for c in 0..3 {
29+
let idx = L::index(w, h, x, y, c);
30+
out[idx] = D::from_f32(f[c], &());
31+
}
2032
}
2133
}
34+
out
2235
}
2336

24-
impl Error for TensorError {}
37+
/// Converts a `DynamicImage` to a tensor of type `D` with normalization `N`, layout `L`, and quantization parameters `quant_params`.
38+
pub fn to_tensor_with_quant<D: DataType<QuantParams = QuantParams>, N: Normalization, L: Layout>(
39+
img: &DynamicImage,
40+
quant_params: D::QuantParams,
41+
) -> Vec<D::Repr>
42+
where
43+
(): Compatible<N, D>,
44+
{
45+
let rgb = img.to_rgb8();
46+
let (w, h) = rgb.dimensions();
47+
let (w, h) = (w as usize, h as usize);
2548

26-
/// Improved Tensor struct
27-
#[derive(Debug, Clone)]
28-
pub struct Tensor<T> {
29-
data: Vec<T>,
30-
shape: Vec<usize>,
31-
}
49+
let mut out = vec![D::Repr::default(); w * h * L::CHANNELS];
50+
51+
// Single-pass: read pixel -> convert to f32 -> normalize -> dtype -> store by L::index
52+
for y in 0..h {
53+
for x in 0..w {
54+
let p = rgb.get_pixel(x as u32, y as u32);
55+
// base floats
56+
let mut f = [p[0] as f32, p[1] as f32, p[2] as f32];
57+
N::apply(&mut f);
3258

33-
impl<T> Tensor<T> {
34-
pub fn new(data: Vec<T>, shape: Vec<usize>) -> Result<Self, TensorError> {
35-
let expected_len: usize = shape.iter().product();
36-
if data.len() != expected_len {
37-
return Err(TensorError::ShapeMismatch);
59+
// store
60+
for c in 0..3 {
61+
let idx = L::index(w, h, x, y, c);
62+
out[idx] = D::from_f32(f[c], &quant_params);
63+
}
3864
}
39-
Ok(Tensor { data, shape })
4065
}
66+
out
67+
}
4168

42-
pub fn shape(&self) -> &[usize] {
43-
&self.shape
44-
}
69+
#[cfg(test)]
70+
mod tests {
71+
use {
72+
super::*,
73+
crate::markers::*,
74+
image::{DynamicImage, RgbImage},
75+
};
4576

46-
pub fn data(&self) -> &[T] {
47-
&self.data
77+
// Build a tiny RGB with distinctive channels:
78+
// base(x,y) = 10*x + y
79+
// R = base
80+
// G = 100 + base
81+
// B = 200 + base
82+
fn make_distinct_rgb(w: u32, h: u32) -> DynamicImage {
83+
let mut img = RgbImage::new(w, h);
84+
for y in 0..h {
85+
for x in 0..w {
86+
let base = 10 * x + y;
87+
img.put_pixel(x, y, image::Rgb([base as u8, (100 + base) as u8, (200 + base) as u8]));
88+
}
89+
}
90+
DynamicImage::ImageRgb8(img)
4891
}
49-
}
5092

51-
// Usage of unsafe blocks should be minimized.
52-
// Example: replace unsafe indexing with safe methods.
53-
impl<T> Tensor<T> {
54-
pub fn get(&self, idx: usize) -> Result<&T, TensorError> {
55-
self.data.get(idx).ok_or(TensorError::Other("Index out of bounds".to_string()))
93+
#[test]
94+
fn test_to_tensor() {
95+
let img = make_distinct_rgb(2, 2);
96+
let tensor: Vec<u8> = to_tensor::<U8, Identity, CHW>(&img);
97+
assert_eq!(tensor.len(), 2 * 2 * 3);
98+
// Add more test logic here as needed
5699
}
57-
}
58100

59-
// Any FFI or memory manipulation should be wrapped safely and errors propagated.
101+
#[test]
102+
fn test_to_tensor_with_quant() {
103+
let img = make_distinct_rgb(2, 2);
104+
let quant_params = QuantParams { scale: 0.5, zero_point: 128 };
105+
let tensor: Vec<u8> = to_tensor_with_quant::<U8, Identity, CHW>(&img, quant_params);
106+
assert_eq!(tensor.len(), 2 * 2 * 3);
107+
// Add more test logic here as needed
108+
}
109+
}

0 commit comments

Comments
 (0)