From 186b27cc5996f6d3f17e41b137fa6be7d7593c03 Mon Sep 17 00:00:00 2001 From: gozdeg Date: Thu, 4 Jun 2026 15:51:15 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Fix=20multi-gpu=20error=20(DataP?= =?UTF-8?q?arallel=20Err)=20for=20kongnet=20and=20nucleus=5Fdetector?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tiatoolbox/models/architecture/kongnet.py | 7 ++++++- tiatoolbox/models/engine/nucleus_detector.py | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tiatoolbox/models/architecture/kongnet.py b/tiatoolbox/models/architecture/kongnet.py index 3c75a273a..2bfa16028 100644 --- a/tiatoolbox/models/architecture/kongnet.py +++ b/tiatoolbox/models/architecture/kongnet.py @@ -862,9 +862,14 @@ def infer_batch( imgs = imgs.to(device=device, dtype=torch.float32) imgs = imgs.permute(0, 3, 1, 2) # to NCHW + try: + target_channels = model.target_channels + except AttributeError: + target_channels = model.module.target_channels + with torch.inference_mode(): logits = model(imgs) - target_logits = logits[:, model.target_channels, :, :] + target_logits = logits[:, target_channels, :, :] probs = torch.nn.functional.sigmoid(target_logits) probs = probs.permute(0, 2, 3, 1) # to NHWC diff --git a/tiatoolbox/models/engine/nucleus_detector.py b/tiatoolbox/models/engine/nucleus_detector.py index e2eedbaa9..55dfd4dcd 100644 --- a/tiatoolbox/models/engine/nucleus_detector.py +++ b/tiatoolbox/models/engine/nucleus_detector.py @@ -431,10 +431,10 @@ def post_process_wsi( # min_distance and postproc_tile_shape cannot be None here min_distance = kwargs.get("min_distance") if min_distance is None: - min_distance = self.model.min_distance + min_distance = self._get_model_attr("min_distance") tile_shape = kwargs.get("tile_shape") if tile_shape is None: - tile_shape = self.model.tile_shape + tile_shape = self._get_model_attr("tile_shape") # Add halo (overlap) around each block for post-processing depth_h = min_distance