@@ -85,32 +85,39 @@ def __init__(
8585 self .spatial_dims = spatial_dims
8686 self .coil_dim = coil_dim
8787
88- def get_fully_sampled_region (self , mask : Tensor ) -> tuple [ int , int ] :
88+ def _compute_acr_mask (self , mask : Tensor ) -> Tensor :
8989 """
90- Extracts the size of the fully-sampled part of the kspace. Note that when a kspace
91- is under-sampled, a part of its center is fully sampled. This part is called the Auto
92- Calibration Region (ACR). ACR is used for sensitivity map computation.
90+ Compute a boolean mask for the Auto Calibration Region (ACR) — the contiguous
91+ fully-sampled center of the k-space sampling mask.
92+
93+ Uses pure tensor operations (``cumprod``) instead of while-loops so that
94+ the computation is compatible with ``torch.export``.
9395
9496 Args:
95- mask: the under-sampling mask of shape (..., S, 1) where S denotes the sampling dimension
97+ mask: the under-sampling mask of shape (..., S, 1) where S denotes the sampling dimension.
9698
9799 Returns:
98- A tuple containing
99- (1) left index of the region
100- (2) right index of the region
101-
102- Note:
103- Suppose the mask is of shape (1,1,20,1). If this function returns 8,12 as left and right
104- indices, then it means that the fully-sampled center region has size 4 starting from 8 to 12.
100+ A boolean tensor broadcastable to ``masked_kspace`` that is True inside the ACR.
105101 """
106- left = right = mask .shape [- 2 ] // 2
107- while mask [..., right , :]:
108- right += 1
102+ s_len = mask .shape [- 2 ]
103+ center = s_len // 2
104+
105+ # Flatten to 1-D along the sampling axis
106+ m = mask .reshape (- 1 )[:s_len ].bool ()
107+
108+ # Count consecutive True values from center going right
109+ right_count = torch .cumprod (m [center :].int (), dim = 0 ).sum ()
110+ # Count consecutive True values from center going left (including center)
111+ left_count = torch .cumprod (m [: center + 1 ].flip (0 ).int (), dim = 0 ).sum ()
112+ num_low_freqs = left_count + right_count - 1
109113
110- while mask [..., left , :]:
111- left -= 1
114+ # Build a boolean mask over the sampling dimension
115+ start = (s_len - num_low_freqs + 1 ) // 2
116+ freq_idx = torch .arange (s_len , device = mask .device )
117+ acr_1d = (freq_idx >= start ) & (freq_idx < start + num_low_freqs )
112118
113- return left + 1 , right
119+ # Reshape to (..., S, 1) so it broadcasts against masked_kspace
120+ return acr_1d .view (* ([1 ] * (mask .ndim - 2 )), s_len , 1 )
114121
115122 def forward (self , masked_kspace : Tensor , mask : Tensor ) -> Tensor :
116123 """
@@ -122,13 +129,10 @@ def forward(self, masked_kspace: Tensor, mask: Tensor) -> Tensor:
122129 Returns:
123130 predicted coil sensitivity maps with shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data.
124131 """
125- left , right = self .get_fully_sampled_region (mask )
126- num_low_freqs = right - left # size of the fully-sampled center
132+ acr_mask = self ._compute_acr_mask (mask )
127133
128134 # take out the fully-sampled region and set the rest of the data to zero
129- x = torch .zeros_like (masked_kspace )
130- start = (mask .shape [- 2 ] - num_low_freqs + 1 ) // 2 # this marks the start of center extraction
131- x [..., start : start + num_low_freqs , :] = masked_kspace [..., start : start + num_low_freqs , :]
135+ x = masked_kspace * acr_mask
132136
133137 # apply inverse fourier to the extracted fully-sampled data
134138 x = ifftn_centered_t (x , spatial_dims = self .spatial_dims , is_complex = True )
0 commit comments