Skip to content

Enable batch inference for image slices#1298

Closed
tomathosauce wants to merge 2 commits into
obss:mainfrom
tomathosauce:batch_implementation1
Closed

Enable batch inference for image slices#1298
tomathosauce wants to merge 2 commits into
obss:mainfrom
tomathosauce:batch_implementation1

Conversation

@tomathosauce
Copy link
Copy Markdown

This PR is my attempt at implementing batch inference over all slices of an image. I decided to work on this because I tried the existing pull requests addressing this feature, but unfortunately they did not work in my case.

My current implementation is still somewhat hacky and it is limited to Ultralytics models, but I made sure to include tests demonstrating that it works as intended. I plan to continue refining it as time permits, and I hope others in the community find it useful.

@golden452
Copy link
Copy Markdown

golden452 commented Mar 2, 2026

Thanks a lot for your work!

I can't get it working. Could you help look into this issue?

failed tests:

=========================== short test summary info ============================
FAILED tests/test_predict.py::test_get_prediction_automodel_yolo11 - ValueError: Number of predictions (1) and shift_amount_list (2) do not match.
FAILED tests/test_predict.py::test_prediction_category_remapping - ValueError: Number of predictions (1) and shift_amount_list (2) do not match.
FAILED tests/test_predict.py::test_get_prediction_yolo11 - ValueError: Number of predictions (1) and shift_amount_list (2) do not match.
FAILED tests/test_predict.py::test_ultralytics_yolo11n_prediction - ZeroDivisionError: float division by zero
FAILED tests/test_predict.py::test_video_prediction - ZeroDivisionError: float division by zero
FAILED tests/test_huggingface_model.py::test_get_prediction_huggingface - assert 0 == 10
 +  where 0 = len([])
FAILED tests/test_huggingface_model.py::test_get_prediction_automodel_huggingface - assert 0 == 10
 +  where 0 = len([])
FAILED tests/test_huggingface_model.py::test_get_sliced_prediction_huggingface - NotImplementedError
FAILED tests/test_torchvision.py::TestTorchVisionDetectionModel::test_get_prediction_torchvision - assert 0 == 7
 +  where 0 = len([])
FAILED tests/test_torchvision.py::TestTorchVisionDetectionModel::test_get_sliced_prediction_torchvision - NotImplementedError

errors I got: (on both yolo11 and yolo26 models)

---------------------------------------------------------------------------
ZeroDivisionError                         Traceback (most recent call last)
Cell In[6], line 3
      1 for image_path in images_to_process:
      2     # get batch sliced prediction
----> 3     prediction_result1 = get_sliced_prediction(
      4         image=image_path,
      5         detection_model=yolo26_detection_model,
      6         slice_height=SLICE_SIZE,
      7         slice_width=SLICE_SIZE,
      8         overlap_height_ratio=OVERLAP,
      9         overlap_width_ratio=OVERLAP,
     10         perform_standard_pred=False,
     11         postprocess_type='GREEDYNMM',
     12         postprocess_match_threshold='0.1',
     13         postprocess_match_metric='IOS',
     14         postprocess_class_agnostic=True,
     15         num_batch=4
     16     )

File [~/Desktop/tomato_sahi/sahi_batch/sahi/predict.py:340](http://localhost:8888/lab/tree/tomato_sahi/sahi_batch/tomato_sahi/sahi_batch/sahi/predict.py#line=339), in get_sliced_prediction(image, detection_model, slice_height, slice_width, overlap_height_ratio, overlap_width_ratio, perform_standard_pred, postprocess_type, postprocess_match_metric, postprocess_match_threshold, postprocess_class_agnostic, verbose, merge_buffer_length, auto_slice_resolution, slice_export_prefix, slice_dir, exclude_classes_by_name, exclude_classes_by_id, progress_bar, progress_callback, num_batch)
    337     shift_amount_list.append(slice_image_result.starting_pixels[idx])
    339 # perform batch prediction
--> 340 prediction_result = get_prediction(
    341     image=image_list,
    342     detection_model=detection_model,
    343     shift_amount=shift_amount_list,
    344     full_shape=[
    345         slice_image_result.original_image_height,
    346         slice_image_result.original_image_width,
    347     ],
    348     exclude_classes_by_name=exclude_classes_by_name,
    349     exclude_classes_by_id=exclude_classes_by_id,
    350 )
    352 if isinstance(prediction_result, list):
    353     for prediction in prediction_result:

File [~/Desktop/tomato_sahi/sahi_batch/sahi/predict.py:152](http://localhost:8888/lab/tree/tomato_sahi/sahi_batch/tomato_sahi/sahi_batch/sahi/predict.py#line=151), in get_prediction(image, detection_model, shift_amount, full_shape, postprocess, verbose, exclude_classes_by_name, exclude_classes_by_id)
    149     object_prediction_dict[shift_amount_index].append(obj_preds)
    151 time_end = time.perf_counter() - time_start
--> 152 durations_in_seconds["postprocess"] = time_end [/](http://localhost:8888/) len(object_prediction_dict)
    154 if verbose == 1:
    155     print(
    156         "Prediction performed in",
    157         durations_in_seconds["prediction"],
    158         "seconds.",
    159     )

ZeroDivisionError: float division by zero

Copy link
Copy Markdown

@JiwaniZakir JiwaniZakir left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In ultralytics.py, the else branch of perform_per_image_batch_inference (the standard detection path, no mask/OBB) has a clear bug: it iterates over prediction_result (the full batch) instead of image_result (the single image's result from the outer loop). This means for every image in the batch, processed_result will contain boxes from all images, causing temp_shift_idxs to be wildly inflated and predictions to be duplicated across every slice. It should be [result.boxes.data for result in image_result], matching the pattern used in the has_mask branch.

Additionally, self._original_shape = image_list[0].shape assumes all slices share the same dimensions, which breaks for boundary slices that are typically smaller than interior ones — this could silently corrupt coordinate transformations downstream.

The type annotation temp_results_list: torch.Tensor | np.ndarray[...] = [] is also misleading; the variable is always a plain list and the union type implies it could be a tensor or array, which it never is.

onuralpszr added a commit that referenced this pull request Mar 31, 2026
Override perform_batch_inference to pass image lists directly to YOLO
model for true batch processing. Extract _extract_predictions helper
to avoid duplication. Store per-image shapes for correct mask resizing
in batch mode.

Closes #1113, closes #1298, closes #1318

Signed-off-by: Onuralp SEZER <thunderbirdtr@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants