Skip to content

Commit 92a9bf9

Browse files
committed
Create crop4voco.py
1 parent 254491e commit 92a9bf9

File tree

1 file changed

+107
-0
lines changed

1 file changed

+107
-0
lines changed

pymic/transform/crop4voco.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import print_function, division
3+
4+
import torch
5+
import json
6+
import math
7+
import random
8+
import numpy as np
9+
from scipy import ndimage
10+
from pymic import TaskType
11+
from pymic.transform.abstract_transform import AbstractTransform
12+
from pymic.transform.crop import CenterCrop
13+
from pymic.transform.intensity import *
14+
from pymic.util.image_process import *
15+
16+
def get_position_label(roi=96, num_crops=4):
17+
half = roi // 2
18+
max_roi = roi * num_crops
19+
center_x, center_y = np.random.randint(low=half, high=max_roi - half), \
20+
np.random.randint(low=half, high=max_roi - half)
21+
22+
x_min, x_max = center_x - half, center_x + half
23+
y_min, y_max = center_y - half, center_y + half
24+
25+
total_area = roi * roi
26+
labels = []
27+
for j in range(num_crops):
28+
for i in range(num_crops):
29+
crop_x_min, crop_x_max = i * roi, (i + 1) * roi
30+
crop_y_min, crop_y_max = j * roi, (j + 1) * roi
31+
32+
dx = min(crop_x_max, x_max) - max(crop_x_min, x_min)
33+
dy = min(crop_y_max, y_max) - max(crop_y_min, y_min)
34+
if dx <= 0 or dy <= 0:
35+
area = 0
36+
else:
37+
area = (dx * dy) / total_area
38+
labels.append(area)
39+
40+
labels = np.asarray(labels).reshape(1, num_crops * num_crops)
41+
return x_min, y_min, labels
42+
43+
class Crop4VoCo(CenterCrop):
44+
"""
45+
Randomly crop an volume into two views with augmentation. This is used for
46+
self-supervised pretraining such as DeSD.
47+
48+
The arguments should be written in the `params` dictionary, and it has the
49+
following fields:
50+
51+
:param `DualViewCrop_output_size`: (list/tuple) Desired output size [D, H, W].
52+
The output channel is the same as the input channel.
53+
:param `DualViewCrop_scale_lower_bound`: (list/tuple) Lower bound of the range of scale
54+
for each dimension. e.g. (1.0, 0.5, 0.5).
55+
param `DualViewCrop_scale_upper_bound`: (list/tuple) Upper bound of the range of scale
56+
for each dimension. e.g. (1.0, 2.0, 2.0).
57+
:param `DualViewCrop_inverse`: (optional, bool) Is inverse transform needed for inference.
58+
Default is `False`. Currently, the inverse transform is not supported, and
59+
this transform is assumed to be used only during training stage.
60+
"""
61+
def __init__(self, params):
62+
roi_size = params.get('Crop4VoCo_roi_size'.lower(), 64)
63+
if isinstance(roi_size, int):
64+
self.roi_size = [roi_size] * 3
65+
else:
66+
self.roi_size = roi_size
67+
self.roi_num = params.get('Crop4VoCo_roi_num'.lower(), 2)
68+
self.base_num = params.get('Crop4VoCo_base_num'.lower(), 4)
69+
70+
self.inverse = params.get('Crop4VoCo_inverse'.lower(), False)
71+
self.task = params['Task'.lower()]
72+
73+
def __call__(self, sample):
74+
image = sample['image']
75+
channel, input_size = image.shape[0], image.shape[1:]
76+
input_dim = len(input_size)
77+
# print(input_size, self.roi_size)
78+
assert(input_size[0] == self.roi_size[0])
79+
assert(input_size[1] == self.roi_size[1] * self.base_num)
80+
assert(input_size[2] == self.roi_size[2] * self.base_num)
81+
82+
base_num, roi_num, roi_size = self.base_num, self.roi_num, self.roi_size
83+
base_crops, roi_crops, roi_labels = [], [], []
84+
crop_size = [channel] + list(roi_size)
85+
for j in range(base_num):
86+
for i in range(base_num):
87+
crop_min = [0, 0, roi_size[1]*j, roi_size[2]*i]
88+
crop_max = [crop_min[d] + crop_size[d] for d in range(4)]
89+
crop_out = crop_ND_volume_with_bounding_box(image, crop_min, crop_max)
90+
base_crops.append(crop_out)
91+
92+
for i in range(roi_num):
93+
x_min, y_min, label = get_position_label(self.roi_size[2], base_num)
94+
# print('label', label)
95+
crop_min = [0, 0, y_min, x_min]
96+
crop_max = [crop_min[d] + crop_size[d] for d in range(4)]
97+
crop_out = crop_ND_volume_with_bounding_box(image, crop_min, crop_max)
98+
roi_crops.append(crop_out)
99+
roi_labels.append(label)
100+
roi_labels = np.concatenate(roi_labels, 0).reshape(roi_num, base_num * base_num)
101+
102+
base_crops = np.stack(base_crops, 0)
103+
roi_crops = np.stack(roi_crops, 0)
104+
sample['image'] = base_crops, roi_crops, roi_labels
105+
return sample
106+
107+

0 commit comments

Comments
 (0)