Skip to content

Commit dc5bdd3

Browse files
Refactor MONetBundleInferenceOperator type checking and formatting
- Improved the type checking for the model_network parameter to enhance readability and maintainability. - Adjusted formatting in the predict method for better clarity and consistency in multimodal data handling. - These changes contribute to cleaner code and improved functionality within the MONAI framework.
1 parent 6908b4c commit dc5bdd3

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

monai/deploy/operators/monet_bundle_inference_operator.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from monai.deploy.utils.importutil import optional_import
1717
from monai.transforms import ConcatItemsd, ResampleToMatch
1818
from monai.deploy.core.models.torch_model import TorchScriptModel
19+
1920
torch, _ = optional_import("torch", "1.10.2")
2021
MetaTensor, _ = optional_import("monai.data.meta_tensor", name="MetaTensor")
2122
__all__ = ["MONetBundleInferenceOperator"]
@@ -71,10 +72,14 @@ def _set_model_network(self, model_network):
7172
model_network : torch.nn.Module or torch.jit.ScriptModule
7273
The model network to be used for inference.
7374
"""
74-
if not isinstance(model_network, torch.nn.Module) and not torch.jit.isinstance(model_network, torch.jit.ScriptModule) and not isinstance(model_network, TorchScriptModel):
75+
if (
76+
not isinstance(model_network, torch.nn.Module)
77+
and not torch.jit.isinstance(model_network, torch.jit.ScriptModule)
78+
and not isinstance(model_network, TorchScriptModel)
79+
):
7580
raise TypeError("model_network must be an instance of torch.nn.Module or torch.jit.ScriptModule")
7681
self._nnunet_predictor.predictor.network = model_network
77-
82+
7883
def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:
7984
"""Predicts output using the inferer. If multimodal data is provided as keyword arguments,
8085
it concatenates the data with the main input data."""
@@ -85,9 +90,8 @@ def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ..
8590
multimodal_data = {"image": data}
8691
for key in kwargs.keys():
8792
if isinstance(kwargs[key], MetaTensor):
88-
multimodal_data[key] = ResampleToMatch(mode="bilinear")(kwargs[key], img_dst=data
89-
)
90-
data = ConcatItemsd(keys=list(multimodal_data.keys()),name="image")(multimodal_data)["image"]
93+
multimodal_data[key] = ResampleToMatch(mode="bilinear")(kwargs[key], img_dst=data)
94+
data = ConcatItemsd(keys=list(multimodal_data.keys()), name="image")(multimodal_data)["image"]
9195
if len(data.shape) == 4:
9296
data = data[None]
9397
return self._nnunet_predictor(data)

0 commit comments

Comments
 (0)