Skip to content

Commit 45c20ef

Browse files
committed
Remove torch.jit constructs from network architectures
Replace @torch.jit.script, torch.jit.is_scripting() guards, @torch.jit.unused, and @torch.jit.export annotations with plain Python across all network modules. Convert torch.jit-unfriendly patterns (ModuleList iteration, Python containers) back to idiomatic code. Also adds .pt2 loading support to MMARS, removes outdated TorchScript notes from MetaTensor, and refactors _compute_acr_mask in the coil sensitivity model. Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent be18b63 commit 45c20ef

File tree

17 files changed

+98
-77
lines changed

17 files changed

+98
-77
lines changed

monai/apps/detection/networks/retinanet_network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ def forward(self, images: Tensor) -> Any:
332332
features = self.feature_extractor(images)
333333
if isinstance(features, Tensor):
334334
feature_maps = [features]
335-
elif torch.jit.isinstance(features, dict[str, Tensor]):
335+
elif isinstance(features, dict):
336336
feature_maps = list(features.values())
337337
else:
338338
feature_maps = list(features)

monai/apps/detection/utils/anchor_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def grid_anchors(self, grid_sizes: list[list[int]], strides: list[list[Tensor]])
257257
for axis in range(self.spatial_dims)
258258
]
259259

260-
# to support torchscript, cannot directly use torch.meshgrid(shifts_centers).
260+
# unpack before passing to torch.meshgrid for compatibility.
261261
shifts_centers = list(torch.meshgrid(shifts_centers[: self.spatial_dims], indexing="ij"))
262262

263263
for axis in range(self.spatial_dims):

monai/apps/mmars/mmars.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def load_from_mmar(
205205
mmar_dir: target directory to store the MMAR, default is the mmars subfolder under `torch.hub.get_dir()`.
206206
progress: whether to display a progress bar when downloading the content.
207207
version: version number of the MMAR. Set it to `-1` to use `item[Keys.VERSION]`.
208-
map_location: pytorch API parameter for `torch.load` or `torch.jit.load`.
208+
map_location: pytorch API parameter for ``torch.load`` or ``torch.jit.load`` (legacy ``.ts`` files).
209209
pretrained: whether to load the pretrained weights after initializing a network module.
210210
weights_only: whether to load only the weights instead of initializing the network module and assign weights.
211211
model_key: a key to search in the model file or config file for the model dictionary.
@@ -232,12 +232,26 @@ def load_from_mmar(
232232
_model_file = model_dir / item.get(Keys.MODEL_FILE, model_file)
233233
logger.info(f'\n*** "{item.get(Keys.NAME)}" available at {model_dir}.')
234234

235-
# loading with `torch.jit.load`
235+
# loading with `torch.export.load` for .pt2 files
236+
if _model_file.name.endswith(".pt2"):
237+
if not pretrained:
238+
warnings.warn("Loading an ExportedProgram, 'pretrained' option ignored.", stacklevel=2)
239+
if weights_only:
240+
warnings.warn("Loading an ExportedProgram, 'weights_only' option ignored.", stacklevel=2)
241+
return torch.export.load(str(_model_file))
242+
243+
# loading with `torch.jit.load` for legacy .ts files
236244
if _model_file.name.endswith(".ts"):
245+
warnings.warn(
246+
"Loading TorchScript (.ts) models is deprecated since MONAI v1.5 and will be removed in v1.7. "
247+
"Use torch.export (.pt2) format instead.",
248+
FutureWarning,
249+
stacklevel=2,
250+
)
237251
if not pretrained:
238-
warnings.warn("Loading a ScriptModule, 'pretrained' option ignored.")
252+
warnings.warn("Loading a ScriptModule, 'pretrained' option ignored.", stacklevel=2)
239253
if weights_only:
240-
warnings.warn("Loading a ScriptModule, 'weights_only' option ignored.")
254+
warnings.warn("Loading a ScriptModule, 'weights_only' option ignored.", stacklevel=2)
241255
return torch.jit.load(_model_file, map_location=map_location)
242256

243257
# loading with `torch.load`

monai/apps/reconstruction/networks/nets/coil_sensitivity_model.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -85,32 +85,39 @@ def __init__(
8585
self.spatial_dims = spatial_dims
8686
self.coil_dim = coil_dim
8787

88-
def get_fully_sampled_region(self, mask: Tensor) -> tuple[int, int]:
88+
def _compute_acr_mask(self, mask: Tensor) -> Tensor:
8989
"""
90-
Extracts the size of the fully-sampled part of the kspace. Note that when a kspace
91-
is under-sampled, a part of its center is fully sampled. This part is called the Auto
92-
Calibration Region (ACR). ACR is used for sensitivity map computation.
90+
Compute a boolean mask for the Auto Calibration Region (ACR) — the contiguous
91+
fully-sampled center of the k-space sampling mask.
92+
93+
Uses pure tensor operations (``cumprod``) instead of while-loops so that
94+
the computation is compatible with ``torch.export``.
9395
9496
Args:
95-
mask: the under-sampling mask of shape (..., S, 1) where S denotes the sampling dimension
97+
mask: the under-sampling mask of shape (..., S, 1) where S denotes the sampling dimension.
9698
9799
Returns:
98-
A tuple containing
99-
(1) left index of the region
100-
(2) right index of the region
101-
102-
Note:
103-
Suppose the mask is of shape (1,1,20,1). If this function returns 8,12 as left and right
104-
indices, then it means that the fully-sampled center region has size 4 starting from 8 to 12.
100+
A boolean tensor broadcastable to ``masked_kspace`` that is True inside the ACR.
105101
"""
106-
left = right = mask.shape[-2] // 2
107-
while mask[..., right, :]:
108-
right += 1
102+
s_len = mask.shape[-2]
103+
center = s_len // 2
104+
105+
# Flatten to 1-D along the sampling axis
106+
m = mask.reshape(-1)[:s_len].bool()
107+
108+
# Count consecutive True values from center going right
109+
right_count = torch.cumprod(m[center:].int(), dim=0).sum()
110+
# Count consecutive True values from center going left (including center)
111+
left_count = torch.cumprod(m[: center + 1].flip(0).int(), dim=0).sum()
112+
num_low_freqs = left_count + right_count - 1
109113

110-
while mask[..., left, :]:
111-
left -= 1
114+
# Build a boolean mask over the sampling dimension
115+
start = (s_len - num_low_freqs + 1) // 2
116+
freq_idx = torch.arange(s_len, device=mask.device)
117+
acr_1d = (freq_idx >= start) & (freq_idx < start + num_low_freqs)
112118

113-
return left + 1, right
119+
# Reshape to (..., S, 1) so it broadcasts against masked_kspace
120+
return acr_1d.view(*([1] * (mask.ndim - 2)), s_len, 1)
114121

115122
def forward(self, masked_kspace: Tensor, mask: Tensor) -> Tensor:
116123
"""
@@ -122,13 +129,10 @@ def forward(self, masked_kspace: Tensor, mask: Tensor) -> Tensor:
122129
Returns:
123130
predicted coil sensitivity maps with shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data.
124131
"""
125-
left, right = self.get_fully_sampled_region(mask)
126-
num_low_freqs = right - left # size of the fully-sampled center
132+
acr_mask = self._compute_acr_mask(mask)
127133

128134
# take out the fully-sampled region and set the rest of the data to zero
129-
x = torch.zeros_like(masked_kspace)
130-
start = (mask.shape[-2] - num_low_freqs + 1) // 2 # this marks the start of center extraction
131-
x[..., start : start + num_low_freqs, :] = masked_kspace[..., start : start + num_low_freqs, :]
135+
x = masked_kspace * acr_mask
132136

133137
# apply inverse fourier to the extracted fully-sampled data
134138
x = ifftn_centered_t(x, spatial_dims=self.spatial_dims, is_complex=True)

monai/data/meta_tensor.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,6 @@ class MetaTensor(MetaObj, torch.Tensor):
8484
assert torch.all(m2.affine == affine)
8585
8686
Notes:
87-
- Requires pytorch 1.9 or newer for full compatibility.
88-
- Older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may
89-
not work if `im` is of type `MetaTensor`. This can be resolved with
90-
`torch.jit.trace(net, im.as_tensor())`.
91-
- For pytorch < 1.8, sharing `MetaTensor` instances across processes may not be supported.
92-
- For pytorch < 1.9, next(iter(meta_tensor)) returns a torch.Tensor.
93-
see: https://github.com/pytorch/pytorch/issues/54457
9487
- A warning will be raised if in the constructor `affine` is not `None` and
9588
`meta` already contains the key `affine`.
9689
- You can query whether the `MetaTensor` is a batch with the `is_batch` attribute.

monai/networks/blocks/feature_pyramid_network.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def __init__(
206206
def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
207207
"""
208208
This is equivalent to self.inner_blocks[idx](x),
209-
but torchscript doesn't support this yet
209+
but module indexing with a variable is used for compatibility
210210
"""
211211
num_blocks = len(self.inner_blocks)
212212
if idx < 0:
@@ -220,7 +220,7 @@ def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
220220
def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor:
221221
"""
222222
This is equivalent to self.layer_blocks[idx](x),
223-
but torchscript doesn't support this yet
223+
but module indexing with a variable is used for compatibility
224224
"""
225225
num_blocks = len(self.layer_blocks)
226226
if idx < 0:

monai/networks/blocks/selfattention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,13 @@ def __init__(
112112

113113
if use_combined_linear:
114114
self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias)
115-
self.to_q = self.to_k = self.to_v = nn.Identity() # add to enable torchscript
115+
self.to_q = self.to_k = self.to_v = nn.Identity() # placeholder for unused code path
116116
self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
117117
else:
118118
self.to_q = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
119119
self.to_k = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
120120
self.to_v = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
121-
self.qkv = nn.Identity() # add to enable torchscript
121+
self.qkv = nn.Identity() # placeholder for unused code path
122122
self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads)
123123
self.out_rearrange = Rearrange("b l h d -> b h (l d)")
124124
self.drop_output = nn.Dropout(dropout_rate)

monai/networks/blocks/squeeze_and_excitation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
8181
y = self.fc(y).view([b, c] + [1] * (x.ndim - 2))
8282
result = x * y
8383

84-
# Residual connection is moved here instead of providing an override of forward in ResidualSELayer since
85-
# Torchscript has an issue with using super().
84+
# Residual connection is applied here rather than in a forward override in ResidualSELayer.
8685
if self.add_residual:
8786
result += x
8887

monai/networks/layers/factories.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,7 @@ def instance_nvfuser_factory(dim):
262262
It only supports 3d tensors as the input. It also requires to use with CUDA and non-Windows OS.
263263
In this function, if the required library `apex.normalization.InstanceNorm3dNVFuser` does not exist,
264264
`nn.InstanceNorm3d` will be returned instead.
265-
This layer is based on a customized autograd function, which is not supported in TorchScript currently.
266-
Please switch to use `nn.InstanceNorm3d` if TorchScript is necessary.
265+
This layer is based on a customized autograd function.
267266
268267
Please check the following link for more details about how to install `apex`:
269268
https://github.com/NVIDIA/apex#installation

monai/networks/layers/simplelayers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def __init__(self, *shape: int) -> None:
163163

164164
def forward(self, x: torch.Tensor) -> torch.Tensor:
165165
shape = list(self.shape)
166-
shape[0] = x.shape[0] # done this way for Torchscript
166+
shape[0] = x.shape[0]
167167
return x.reshape(shape)
168168

169169

0 commit comments

Comments (0)