1+ # -*- coding: utf-8 -*-
2+ from __future__ import print_function , division
3+
4+ import torch
5+ import json
6+ import math
7+ import random
8+ import numpy as np
9+ from scipy import ndimage
10+ from pymic import TaskType
11+ from pymic .transform .abstract_transform import AbstractTransform
12+ from pymic .transform .crop import CenterCrop
13+ from pymic .transform .intensity import *
14+ from pymic .util .image_process import *
15+
def get_position_label(roi=96, num_crops=4):
    """
    Sample a random roi x roi window inside a square grid of
    num_crops x num_crops cells (each cell is roi x roi), and compute for
    every cell the fraction of the window's area falling inside it.

    Note that `half = roi // 2`, so the effective window side is `2 * half`
    (equal to `roi` only when `roi` is even).

    :param roi: (int) Side length of each grid cell and of the window.
    :param num_crops: (int) Number of cells along each side of the grid.
    :return: A tuple `(x_min, y_min, labels)` where `(x_min, y_min)` is the
        top-left corner of the window and `labels` is an array of shape
        `(1, num_crops * num_crops)` of overlap fractions (row-major over
        the grid) summing to 1.
    """
    half = roi // 2
    grid_extent = roi * num_crops
    # Draw the window center; x is drawn first, then y (keeps the RNG
    # consumption order stable).
    cx = np.random.randint(low=half, high=grid_extent - half)
    cy = np.random.randint(low=half, high=grid_extent - half)

    x_min, x_max = cx - half, cx + half
    y_min, y_max = cy - half, cy + half

    window_area = roi * roi
    overlaps = []
    for row in range(num_crops):
        for col in range(num_crops):
            cell_x0, cell_x1 = col * roi, (col + 1) * roi
            cell_y0, cell_y1 = row * roi, (row + 1) * roi
            # 1-D intersection lengths along x and y; non-positive means
            # the window misses this cell on that axis.
            dx = min(cell_x1, x_max) - max(cell_x0, x_min)
            dy = min(cell_y1, y_max) - max(cell_y0, y_min)
            overlaps.append((dx * dy) / window_area if dx > 0 and dy > 0 else 0)

    return x_min, y_min, np.asarray(overlaps).reshape(1, num_crops * num_crops)
42+
class Crop4VoCo(CenterCrop):
    """
    Crop a volume into `base_num * base_num` non-overlapping base crops plus
    `roi_num` randomly positioned crops, together with labels giving the
    overlap fraction between each random crop and each base crop. This is
    used for self-supervised pretraining such as VoCo.

    The input image is expected to have a shape of
    `[C, roi_size[0], roi_size[1]*base_num, roi_size[2]*base_num]`.

    The arguments should be written in the `params` dictionary, and it has the
    following fields:

    :param `Crop4VoCo_roi_size`: (int or list/tuple) Size [D, H, W] of each
        crop. An int value is used for all three dimensions. Default is 64.
    :param `Crop4VoCo_roi_num`: (int) Number of random crops. Default is 2.
    :param `Crop4VoCo_base_num`: (int) Number of base crops along each of the
        H and W dimensions. Default is 4.
    :param `Crop4VoCo_inverse`: (optional, bool) Is inverse transform needed
        for inference. Default is `False`. Currently, the inverse transform is
        not supported, and this transform is assumed to be used only during
        the training stage.
    """
    def __init__(self, params):
        roi_size = params.get('Crop4VoCo_roi_size'.lower(), 64)
        if isinstance(roi_size, int):
            self.roi_size = [roi_size] * 3
        else:
            self.roi_size = roi_size
        self.roi_num  = params.get('Crop4VoCo_roi_num'.lower(), 2)
        self.base_num = params.get('Crop4VoCo_base_num'.lower(), 4)

        self.inverse  = params.get('Crop4VoCo_inverse'.lower(), False)
        self.task     = params['Task'.lower()]

    def __call__(self, sample):
        """
        Replace `sample['image']` with a tuple
        `(base_crops, roi_crops, roi_labels)`:

        * `base_crops`: array of shape `[base_num*base_num, C] + roi_size`,
          the grid of non-overlapping crops in row-major (H-then-W) order.
        * `roi_crops`: array of shape `[roi_num, C] + roi_size`, the random
          crops.
        * `roi_labels`: array of shape `[roi_num, base_num*base_num]`, the
          overlap fraction of each random crop with each base crop.
        """
        image = sample['image']
        channel = image.shape[0]
        input_size = image.shape[1:]
        # The input must tile exactly into the base crop grid.
        assert (input_size[0] == self.roi_size[0])
        assert (input_size[1] == self.roi_size[1] * self.base_num)
        assert (input_size[2] == self.roi_size[2] * self.base_num)

        base_num, roi_num, roi_size = self.base_num, self.roi_num, self.roi_size
        crop_size = [channel] + list(roi_size)

        # Non-overlapping base crops covering the whole H-W plane.
        base_crops = []
        for j in range(base_num):
            for i in range(base_num):
                crop_min = [0, 0, roi_size[1] * j, roi_size[2] * i]
                crop_max = [crop_min[d] + crop_size[d] for d in range(4)]
                base_crops.append(
                    crop_ND_volume_with_bounding_box(image, crop_min, crop_max))

        # Random crops with their overlap-fraction labels.
        # NOTE(review): get_position_label is given roi_size[2] for both
        # in-plane axes, which assumes roi_size[1] == roi_size[2] — confirm
        # before using anisotropic in-plane roi sizes.
        roi_crops, roi_labels = [], []
        for _ in range(roi_num):
            x_min, y_min, label = get_position_label(self.roi_size[2], base_num)
            crop_min = [0, 0, y_min, x_min]
            crop_max = [crop_min[d] + crop_size[d] for d in range(4)]
            roi_crops.append(
                crop_ND_volume_with_bounding_box(image, crop_min, crop_max))
            roi_labels.append(label)
        roi_labels = np.concatenate(roi_labels, 0).reshape(roi_num, base_num * base_num)

        sample['image'] = np.stack(base_crops, 0), np.stack(roi_crops, 0), roi_labels
        return sample
106+
107+