Skip to content

Commit 48adf1d

Browse files
committed
update dataset and requirement version
1 parent 92a9bf9 commit 48adf1d

File tree

7 files changed

+57
-39
lines changed

7 files changed

+57
-39
lines changed

pymic/net_run/noisy_label/nll_clslsr.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,15 +165,20 @@ def get_confidence_map(cfg_file):
165165
transform_list.append(one_transform)
166166
data_transform = transforms.Compose(transform_list)
167167

168+
stage_dir = config['dataset']['train_dir']
168169
csv_file = config['dataset']['train_csv']
169170
modal_num = config['dataset'].get('modal_num', 1)
170-
stage_dir = config['dataset']['train_dir']
171+
stage_dim = config['dataset'].get('train_dim', 3)
172+
lab_key = config['dataset'].get('train_label_key', 'label')
173+
171174
dataset = NiftyDataset(root_dir = stage_dir,
172-
csv_file = csv_file,
173-
modal_num = modal_num,
174-
with_label= True,
175-
transform = data_transform,
176-
task = agent.task_type)
175+
csv_file = csv_file,
176+
modal_num = modal_num,
177+
image_dim = stage_dim,
178+
allow_missing_modal = False,
179+
label_key = lab_key,
180+
transform = data_transform,
181+
task = agent.task_type)
177182

178183
agent.set_datasets(None, None, dataset)
179184
agent.transform_list = transform_list

pymic/net_run/noisy_label/nll_dast.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,17 @@ def get_noisy_dataset_from_config(self):
129129
data_transform = transforms.Compose(transform_list)
130130

131131
modal_num = self.config['dataset'].get('modal_num', 1)
132-
csv_file = self.config['dataset'].get('train_csv_noise', None)
132+
stage_dim = self.config['dataset'].get('train_dim', 3)
133+
lab_key = self.config['dataset'].get('train_label_key', 'label')
134+
csv_file = self.config['dataset'].get('train_csv_noise', None)
133135
dataset = NiftyDataset(root_dir = self.config['dataset']['train_dir'],
134-
csv_file = csv_file,
135-
modal_num = modal_num,
136-
with_label= True,
137-
transform = data_transform ,
138-
task = self.task_type)
136+
csv_file = csv_file,
137+
modal_num = modal_num,
138+
image_dim = stage_dim,
139+
allow_missing_modal = False,
140+
label_key = lab_key,
141+
transform = data_transform,
142+
task = self.task_type)
139143
return dataset
140144

141145

pymic/net_run/predict.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,26 @@ def main():
2121
exit()
2222
parser = argparse.ArgumentParser()
2323
parser.add_argument("cfg", help="configuration file for testing")
24-
parser.add_argument("-test_csv", help="the csv file for testing images",
25-
required=False, default=None)
26-
parser.add_argument("-output_dir", help="the output dir for inference results",
24+
parser.add_argument("--test_csv", help="the csv file for testing images",
2725
required=False, default=None)
28-
parser.add_argument("-ckpt_dir", help="the dir for trained model",
26+
parser.add_argument("--test_dir", help="the dir for testing images",
2927
required=False, default=None)
30-
parser.add_argument("-ckpt_mode", help="the mode for chekpoint: 0-latest, 1-best, 2-customized",
28+
parser.add_argument("--output_dir", help="the output dir for inference results",
3129
required=False, default=None)
32-
parser.add_argument("-ckpt_name", help="the name chekpoint if ckpt_mode = 2",
30+
parser.add_argument("--ckpt_dir", help="the dir for trained model",
3331
required=False, default=None)
34-
parser.add_argument("-gpus", help="the gpus for runing, e.g., [0]",
32+
parser.add_argument("--ckpt_mode", help="the mode for chekpoint: 0-latest, 1-best, 2-customized",
33+
required=False, default=None)
34+
parser.add_argument("--ckpt_name", help="the name chekpoint if ckpt_mode = 2",
35+
required=False, default=None)
36+
parser.add_argument("--gpus", help="the gpus for runing, e.g., [0]",
3537
required=False, default=None)
3638
args = parser.parse_args()
3739
if(not os.path.isfile(args.cfg)):
3840
raise ValueError("The config file does not exist: " + args.cfg)
3941
config = parse_config(args)
4042
config = synchronize_config(config)
43+
print(config)
4144
log_dir = config['testing']['output_dir']
4245
if(not os.path.exists(log_dir)):
4346
os.makedirs(log_dir, exist_ok=True)

pymic/net_run/semi_sup/ssl_abstract.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,16 @@ def get_unlabeled_dataset_from_config(self):
5252
self.transform_list.append(one_transform)
5353
data_transform = transforms.Compose(self.transform_list)
5454

55-
csv_file = self.config['dataset'].get('train_csv_unlab', None)
55+
csv_file = self.config['dataset'].get('train_csv_unlab', None)
56+
stage_dim = self.config['dataset'].get('train_dim', 3)
5657
dataset = NiftyDataset(root_dir = train_dir,
5758
csv_file = csv_file,
5859
modal_num = modal_num,
59-
with_label= False,
60-
transform = data_transform )
60+
image_dim = stage_dim,
61+
allow_missing_modal = False,
62+
label_key = None,
63+
transform = data_transform,
64+
task = self.task_type)
6165
return dataset
6266

6367
def create_dataset(self):

pymic/net_run/train.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,15 @@ def main():
5454
exit()
5555
parser = argparse.ArgumentParser()
5656
parser.add_argument("cfg", help="configuration file for training")
57-
parser.add_argument("-train_csv", help="the csv file for training images",
57+
parser.add_argument("--train_csv", help="the csv file for training images",
5858
required=False, default=None)
59-
parser.add_argument("-valid_csv", help="the csv file for validation images",
59+
parser.add_argument("--valid_csv", help="the csv file for validation images",
6060
required=False, default=None)
61-
parser.add_argument("-ckpt_dir", help="the output dir for trained model",
61+
parser.add_argument("--ckpt_dir", help="the output dir for trained model",
6262
required=False, default=None)
63-
parser.add_argument("-iter_max", help="the maximal iteration number for training",
63+
parser.add_argument("--iter_max", help="the maximal iteration number for training",
6464
required=False, default=None)
65-
parser.add_argument("-gpus", help="the gpus for runing, e.g., [0]",
65+
parser.add_argument("--gpus", help="the gpus for runing, e.g., [0]",
6666
required=False, default=None)
6767
args = parser.parse_args()
6868
if(not os.path.isfile(args.cfg)):

pymic/util/evaluation_cls.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,13 +176,13 @@ def main():
176176
:param pred_prob_csv: (str) The csv file for prediction probability.
177177
"""
178178
parser = argparse.ArgumentParser()
179-
parser.add_argument("-cfg", help="configuration file for evaluation",
179+
parser.add_argument("--cfg", help="configuration file for evaluation",
180180
required=False, default=None)
181-
parser.add_argument("-metric", help="evaluation metrics, e.g., accuracy, or [accuracy, auc]",
181+
parser.add_argument("--metric", help="evaluation metrics, e.g., accuracy, or [accuracy, auc]",
182182
required=False, default=None)
183-
parser.add_argument("-gt_csv", help="csv file for ground truth",
183+
parser.add_argument("--gt_csv", help="csv file for ground truth",
184184
required=False, default=None)
185-
parser.add_argument("-pred_prob_csv", help="csv file for probability prediction",
185+
parser.add_argument("--pred_prob_csv", help="csv file for probability prediction",
186186
required=False, default=None)
187187
args = parser.parse_args()
188188
print(args)

requirements.txt

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
h5py
22
matplotlib>=3.1.2
3-
numpy>=1.17.4
4-
pandas>=0.25.3
5-
scikit-image>=0.16.2
6-
scikit-learn>=0.22
7-
scipy>=1.3.3
8-
SimpleITK>=2.0.0
3+
numpy>=1.23.5
4+
pandas>=1.5.2
5+
scikit-image>=0.19.3
6+
scikit-learn>=1.2.0
7+
scipy>=1.10.0
8+
SimpleITK>=2.0.2
99
tensorboard
1010
tensorboardX
11-
torch>=1.1.12
12-
torchvision>=0.13.0
11+
torch>=1.13.1
12+
torchvision>=0.14.1
13+
causal-conv1d>=1.5.0
14+
mamba-ssm>=2.2.4

0 commit comments

Comments
 (0)