Skip to content

Commit f749c58

Browse files
committed
Version 1.1.2: Bug fixes in portpy.ai module.
1 parent 2979d04 commit f749c58

File tree

5 files changed: +104 additions, −4 deletions

examples/python_files/imrt_dose_prediction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242

4343

4444
# preprocess portpy data
45-
# data_preprocess(in_dir, out_dir)
45+
data_preprocess(in_dir, out_dir)
4646

4747
# **Note** split the data in train and test folder in the output directory before running further code
4848
# e.g. out_dir\train\Lung_Patient_2 out_dir\test\Lung_Patient_9

portpy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@
1515
# a separate commercial license.
1616
# ----------------------------------------------------------------------
1717

18-
__version__ = "1.1.1"
18+
__version__ = "1.1.2"
1919
# Change version here manually to reflect it everywhere
2020
from portpy import photon

portpy/ai/data/__init__.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,76 @@
1515
# a separate commercial license.
1616
# ----------------------------------------------------------------------
1717

18+
"""This package includes all the modules related to data loading and preprocessing
19+
20+
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
21+
You need to implement four functions:
22+
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
23+
-- <__len__>: return the size of dataset.
24+
-- <__getitem__>: get a data point from data loader.
25+
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
26+
27+
Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
28+
See our template dataset class 'template_dataset.py' for more details.
29+
"""
30+
import os
31+
import sys
32+
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
33+
if project_root not in sys.path:
34+
sys.path.append(project_root)
35+
36+
import importlib
37+
import torch.utils.data
38+
from portpy.ai.data.base_dataset import BaseDataset
39+
40+
41+
def find_dataset_using_name(dataset_name):
42+
"""Import the module "data/[dataset_name]_dataset.py".
43+
44+
In the file, the class called DatasetNameDataset() will
45+
be instantiated. It has to be a subclass of BaseDataset,
46+
and it is case-insensitive.
47+
"""
48+
dataset_filename = "portpy.ai.data." + dataset_name + "_dataset"
49+
datasetlib = importlib.import_module(dataset_filename)
50+
51+
dataset = None
52+
target_dataset_name = dataset_name.replace('_', '') + 'dataset'
53+
for name, cls in datasetlib.__dict__.items():
54+
if name.lower() == target_dataset_name.lower() \
55+
and issubclass(cls, BaseDataset):
56+
dataset = cls
57+
58+
if dataset is None:
59+
raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
60+
61+
return dataset
62+
63+
64+
def get_option_setter(dataset_name):
65+
"""Return the static method <modify_commandline_options> of the dataset class."""
66+
dataset_class = find_dataset_using_name(dataset_name)
67+
return dataset_class.modify_commandline_options
68+
69+
70+
def create_dataset(opt):
71+
"""Create a dataset given the option.
72+
73+
This function wraps the class CustomDatasetDataLoader.
74+
This is the main interface between this package and 'train.py'/'test.py'
75+
76+
Example:
77+
from data import create_dataset
78+
dataset = create_dataset(opt)
79+
"""
80+
data_loader = CustomDatasetDataLoader(opt)
81+
dataset = data_loader.load_data()
82+
return dataset
83+
84+
85+
class CustomDatasetDataLoader():
86+
"""Wrapper class of Dataset class that performs multi-threaded data loading"""
87+
1888
def __init__(self, opt):
1989
"""Initialize this class
2090
@@ -44,4 +114,4 @@ def __iter__(self):
44114
for i, data in enumerate(self.dataloader):
45115
if i * self.opt.batch_size >= self.opt.max_dataset_size:
46116
break
47-
yield data
117+
yield data

portpy/ai/data/base_dataset.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,36 @@
1515
# a separate commercial license.
1616
# ----------------------------------------------------------------------
1717

18+
19+
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
20+
21+
It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
22+
"""
23+
import random
24+
import numpy as np
25+
import torch.utils.data as data
26+
from PIL import Image
27+
import torchvision.transforms as transforms
28+
import torchvision.transforms.functional as TF
29+
from abc import ABC, abstractmethod
30+
import torch
31+
import warnings
32+
from scipy.ndimage import affine_transform
33+
34+
35+
warnings.filterwarnings("ignore")
36+
37+
38+
class BaseDataset(data.Dataset, ABC):
39+
"""This class is an abstract base class (ABC) for datasets.
40+
41+
To create a subclass, you need to implement the following four functions:
42+
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
43+
-- <__len__>: return the size of dataset.
44+
-- <__getitem__>: get a data point.
45+
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
46+
"""
47+
1848
def __init__(self, opt):
1949
"""Initialize the class; save the options in the class
2050

portpy/ai/data/dosepred3d_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def __getitem__(self, index):
8888
if self.phase == 'train':
8989
# A, B, OAR, beam, hist, bins = transform_3d_data(A, B, OAR, PTV, beam, hist, bins, transform=False)
9090

91-
A, B, OAR, beam = transform_3d_data(A, B, OAR, PTV, beam, augment=self.transform)
91+
A, B, OAR, beam = transform_3d_data(A, B, OAR, PTV, beam, augment=self.augment)
9292
A = torch.unsqueeze(A, dim=0) # Add channel dimensions as data is 3D
9393
B = torch.unsqueeze(B, dim=0)
9494
beam = torch.unsqueeze(beam, dim=0)

Comments (0)