@@ -89,80 +89,7 @@ def compute_total_samples(
8989 return total_samples , samples_per_volume
9090
9191
92- def calculate_inference_grid (
93- volume_shape : Tuple [int , int , int ],
94- patch_size : Tuple [int , int , int ],
95- stride : Tuple [int , int , int ],
96- ) -> Tuple [np .ndarray , Tuple [int , int , int ]]:
97- """
98- Calculate grid of patch positions for sliding-window inference.
99-
100- This function generates all patch positions needed to cover a volume
101- with overlapping patches, using the specified stride.
102-
103- Args:
104- volume_shape: Shape of the input volume (D, H, W)
105- patch_size: Size of each patch (D, H, W)
106- stride: Stride between patch centers (D, H, W)
107-
108- Returns:
109- positions: Array of shape (N, 3) containing (z, y, x) start positions
110- grid_shape: Tuple (num_z, num_y, num_x) indicating grid dimensions
111-
112- Examples:
113- >>> volume_shape = (256, 256, 256)
114- >>> patch_size = (128, 128, 128)
115- >>> stride = (64, 64, 64)
116- >>> positions, grid = calculate_inference_grid(volume_shape, patch_size, stride)
117- >>> print(f"Grid shape: {grid}")
118- >>> # Grid shape: (3, 3, 3)
119- >>> print(f"Total patches: {len(positions)}")
120- >>> # Total patches: 27
121-
122- Note:
123- The last patch in each dimension is "tucked in" to ensure it fits
124- within the volume boundaries, matching the legacy v1 behavior.
125- """
126- volume_shape = np .array (volume_shape )
127- patch_size = np .array (patch_size )
128- stride = np .array (stride )
129-
130- # Calculate grid dimensions
131- grid_shape = count_volume (volume_shape , patch_size , stride )
132- grid_shape = tuple (grid_shape )
133-
134- positions = []
135-
136- # Generate all grid positions
137- for z_idx in range (grid_shape [0 ]):
138- for y_idx in range (grid_shape [1 ]):
139- for x_idx in range (grid_shape [2 ]):
140- # Calculate position with boundary handling
141- # Normal case: multiply by stride
142- # Boundary case: tuck in to ensure patch fits
143- z = (
144- z_idx * stride [0 ]
145- if z_idx < grid_shape [0 ] - 1
146- else volume_shape [0 ] - patch_size [0 ]
147- )
148- y = (
149- y_idx * stride [1 ]
150- if y_idx < grid_shape [1 ] - 1
151- else volume_shape [1 ] - patch_size [1 ]
152- )
153- x = (
154- x_idx * stride [2 ]
155- if x_idx < grid_shape [2 ] - 1
156- else volume_shape [2 ] - patch_size [2 ]
157- )
158-
159- positions .append ([z , y , x ])
160-
161- return np .array (positions , dtype = np .int32 ), grid_shape
162-
163-
16492__all__ = [
16593 "count_volume" ,
16694 "compute_total_samples" ,
167- "calculate_inference_grid" ,
16895]
0 commit comments