Skip to content

Commit bb027f6

Browse files
committed
Refactors image loading and processing
Improves image loading error handling in distributed processing. Streamlines image processing by removing unnecessary conversions and using a more direct method for creating the tensor from the image data. Improves code readability by introducing formatting and consistent coding styles.
1 parent 50a0006 commit bb027f6

5 files changed

Lines changed: 43 additions & 73 deletions

File tree

Cargo.lock

Lines changed: 2 additions & 40 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ image = { version = "0.25", default-features = false }
3838
imageops-kit = { git = "https://github.com/nusu-github/imageops-kit", version = "0.1.0" }
3939
imageproc = { version = "0.25", default-features = false }
4040
indicatif = "0.18"
41-
ndarray = "0.16"
42-
nshare = { version = "0.10", default-features = false, features = ["image", "ndarray"] }
41+
ndarray = "0.16.1"
4342
ort = "2.0.0-rc.10"
4443
parking_lot = { version = "0.12", features = ["hardware-lock-elision"] }
4544
rayon = "1.11"

src/distributed.rs

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -284,11 +284,12 @@ impl<M: BatchImageSegmentationModel + 'static> BatchProcessor for GpuInferenceBa
284284
batch_for_load
285285
.par_iter()
286286
.map(|job| {
287-
image::open(&job.payload.input_path)
288-
.map_err(|e| (job.id.clone(), e))
287+
image::open(&job.payload.input_path).map_err(|e| (job.id.clone(), e))
289288
})
290289
.collect()
291-
}).await.unwrap();
290+
})
291+
.await
292+
.unwrap();
292293

293294
let mut images = Vec::with_capacity(batch_size);
294295
let mut error_indices = Vec::new();
@@ -305,10 +306,10 @@ impl<M: BatchImageSegmentationModel + 'static> BatchProcessor for GpuInferenceBa
305306

306307
for idx in error_indices.iter().rev() {
307308
let mut failed_job = batch.remove(*idx);
308-
failed_job.payload.metadata.insert(
309-
"error".to_string(),
310-
"Image loading error".to_string(),
311-
);
309+
failed_job
310+
.payload
311+
.metadata
312+
.insert("error".to_string(), "Image loading error".to_string());
312313
failed_job.job_type = JobType::Postprocessing;
313314
}
314315

@@ -345,10 +346,9 @@ impl<M: BatchImageSegmentationModel + 'static> BatchProcessor for GpuInferenceBa
345346
return;
346347
}
347348

348-
job.payload.metadata.insert(
349-
"segmentation_complete".to_string(),
350-
"true".to_string(),
351-
);
349+
job.payload
350+
.metadata
351+
.insert("segmentation_complete".to_string(), "true".to_string());
352352
job.payload
353353
.metadata
354354
.insert("batch_size".to_string(), batch_size.to_string());
@@ -359,7 +359,9 @@ impl<M: BatchImageSegmentationModel + 'static> BatchProcessor for GpuInferenceBa
359359
job.job_type = JobType::Postprocessing;
360360
});
361361
batch
362-
}).await.unwrap();
362+
})
363+
.await
364+
.unwrap();
363365

364366
Ok(batch)
365367
})

src/model.rs

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ use image::{
1111
use imageops_kit::{AlphaMaskError, Image, ModifyAlphaExt, NormalizedFrom, PaddingExt};
1212
use imageproc::map::map_colors;
1313
use ndarray::prelude::*;
14-
use nshare::AsNdarray3;
1514
use ort::value::TensorRef;
1615
use ort::{
1716
execution_providers::{
@@ -102,10 +101,7 @@ impl Model {
102101
})?
103102
.with_optimized_model_path(cache_dir.join(format!(
104103
"{}.optimized.onnx",
105-
model_path
106-
.file_stem()
107-
.unwrap_or_default()
108-
.to_string_lossy()
104+
model_path.file_stem().unwrap_or_default().to_string_lossy()
109105
)))
110106
.map_err(|e| AnimeSegError::Model {
111107
operation: "optimized model path configuration".to_string(),
@@ -215,7 +211,12 @@ impl crate::traits::BatchImageSegmentationModel for Model {
215211
dimensions.push(img.dimensions());
216212
}
217213

218-
let batch_shape = (batch_size, 3, self.image_size as usize, self.image_size as usize);
214+
let batch_shape = (
215+
batch_size,
216+
3,
217+
self.image_size as usize,
218+
self.image_size as usize,
219+
);
219220
let mut batch_tensor = Array4::<f32>::zeros(batch_shape);
220221

221222
for (i, tensor) in batch_tensors.iter().enumerate() {
@@ -260,15 +261,27 @@ where
260261
let image = imageops::resize(image, image_size, image_size, FilterType::Lanczos3);
261262
let (w, h) = image.dimensions();
262263
let zero = S::zero();
263-
let (image, (x, y)) = image
264-
.to_square(Rgb([zero, zero, zero]))
264+
let (image, (x, y)) =
265+
image
266+
.to_square(Rgb([zero, zero, zero]))
267+
.map_err(|e| AnimeSegError::ImageProcessing {
268+
path: "unknown".to_string(),
269+
operation: "パディング追加".to_string(),
270+
source: Box::new(e),
271+
})?;
272+
273+
let (width, height) = image.dimensions();
274+
let channels = 3;
275+
let raw = image.into_raw();
276+
277+
let array_hwc = Array3::from_shape_vec((height as usize, width as usize, channels), raw)
265278
.map_err(|e| AnimeSegError::ImageProcessing {
266279
path: "unknown".to_string(),
267-
operation: "パディング追加".to_string(),
280+
operation: "Converting an image to an ndarray".to_string(),
268281
source: Box::new(e),
269282
})?;
283+
let tensor = array_hwc.permuted_axes([2, 0, 1]).insert_axis(Axis(0));
270284

271-
let tensor = image.as_ndarray3().slice_move(s![NewAxis, ..;-1, .., ..]);
272285
let max = S::DEFAULT_MAX_VALUE.into();
273286
let tensor = if max == (<f32 as Primitive>::DEFAULT_MAX_VALUE) {
274287
tensor.map(|v| (*v).into())

src/queue.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,7 @@ impl QueueProvider for InMemoryQueueProvider {
119119
let notifier = self.get_or_create_notifier(queue_name);
120120
{
121121
let mut queues = self.queues.lock();
122-
let queue = queues
123-
.entry(queue_name.to_string())
124-
.or_default();
122+
let queue = queues.entry(queue_name.to_string()).or_default();
125123
queue.push_back(job);
126124
}
127125
notifier.notify_one();
@@ -133,9 +131,7 @@ impl QueueProvider for InMemoryQueueProvider {
133131
loop {
134132
{
135133
let mut queues = self.queues.lock();
136-
let queue = queues
137-
.entry(queue_name.to_string())
138-
.or_default();
134+
let queue = queues.entry(queue_name.to_string()).or_default();
139135

140136
if queue.len() < max_size {
141137
queue.push_back(job);
@@ -151,9 +147,7 @@ impl QueueProvider for InMemoryQueueProvider {
151147

152148
async fn dequeue(&self, queue_name: &str) -> Result<Option<Job>> {
153149
let mut queues = self.queues.lock();
154-
let queue = queues
155-
.entry(queue_name.to_string())
156-
.or_default();
150+
let queue = queues.entry(queue_name.to_string()).or_default();
157151
Ok(queue.pop_front())
158152
}
159153

0 commit comments

Comments
 (0)