diff --git a/tiatoolbox/models/architecture/kongnet.py b/tiatoolbox/models/architecture/kongnet.py index 3c75a273a..2bfa16028 100644 --- a/tiatoolbox/models/architecture/kongnet.py +++ b/tiatoolbox/models/architecture/kongnet.py @@ -862,9 +862,14 @@ def infer_batch( imgs = imgs.to(device=device, dtype=torch.float32) imgs = imgs.permute(0, 3, 1, 2) # to NCHW + try: + target_channels = model.target_channels + except AttributeError: + target_channels = model.module.target_channels + with torch.inference_mode(): logits = model(imgs) - target_logits = logits[:, model.target_channels, :, :] + target_logits = logits[:, target_channels, :, :] probs = torch.nn.functional.sigmoid(target_logits) probs = probs.permute(0, 2, 3, 1) # to NHWC diff --git a/tiatoolbox/models/engine/nucleus_detector.py b/tiatoolbox/models/engine/nucleus_detector.py index e2eedbaa9..55dfd4dcd 100644 --- a/tiatoolbox/models/engine/nucleus_detector.py +++ b/tiatoolbox/models/engine/nucleus_detector.py @@ -431,10 +431,10 @@ def post_process_wsi( # min_distance and postproc_tile_shape cannot be None here min_distance = kwargs.get("min_distance") if min_distance is None: - min_distance = self.model.min_distance + min_distance = self._get_model_attr("min_distance") tile_shape = kwargs.get("tile_shape") if tile_shape is None: - tile_shape = self.model.tile_shape + tile_shape = self._get_model_attr("tile_shape") # Add halo (overlap) around each block for post-processing depth_h = min_distance