Skip to content

Commit 4cb8688

Browse files
committed
Add paddle.utils.data.DataLoader, a compatibility wrapper around paddle.io.DataLoader
1 parent 40d3065 commit 4cb8688

3 files changed

Lines changed: 140 additions & 0 deletions

File tree

python/paddle/utils/data/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from .dataloader import (
16+
DataLoader,
1617
default_collate,
1718
get_worker_info,
1819
)
@@ -39,6 +40,7 @@
3940
'ChainDataset',
4041
'ConcatDataset',
4142
'Dataset',
43+
'DataLoader',
4244
'IterableDataset',
4345
'Subset',
4446
'random_split',

python/paddle/utils/data/dataloader.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,78 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from __future__ import annotations
15+
16+
import warnings
17+
from typing import TYPE_CHECKING, Any
18+
19+
from paddle.io import DataLoader as PaddleDataLoader
1420

1521
from ._utils.collate import (
1622
default_collate as default_collate,
1723
)
1824
from ._utils.worker import (
1925
get_worker_info as get_worker_info,
2026
)
27+
28+
if TYPE_CHECKING:
29+
from collections.abc import Callable
30+
31+
from paddle.io.dataloader import BatchSampler
32+
from paddle.io.dataloader.dataset import Dataset
33+
from paddle.io.reader import _CollateFn
34+
35+
36+
class DataLoader(PaddleDataLoader):
    """Wrapper around ``paddle.io.DataLoader`` with an extended constructor.

    On top of the arguments forwarded to the parent class, the constructor
    accepts ``sampler`` (used as the ``batch_sampler``) and a set of options
    that are accepted but ignored: ``pin_memory``, ``multiprocessing_context``,
    ``generator``, ``prefetch_factor``, ``pin_memory_device`` and ``in_order``.

    NOTE(review): the signature looks modelled on
    ``torch.utils.data.DataLoader`` -- confirm against the compatibility spec.
    """

    def __init__(
        self,
        dataset: Dataset[Any],
        batch_size: int | None = 1,
        shuffle: bool | None = None,
        sampler: BatchSampler | None = None,
        batch_sampler: BatchSampler | None = None,
        num_workers: int = 0,
        collate_fn: _CollateFn | None = None,
        pin_memory: bool = False,
        drop_last: bool = False,
        timeout: float = 0,
        worker_init_fn: Callable[[int], None] | None = None,
        multiprocessing_context=None,
        generator=None,
        *,
        prefetch_factor: int | None = None,
        persistent_workers: bool = False,
        pin_memory_device: str = "",
        in_order: bool = True,
    ) -> None:
        """Validate the arguments and delegate to ``paddle.io.DataLoader``.

        Args:
            dataset: Dataset to load from; forwarded unchanged.
            batch_size: Forwarded unchanged.
            shuffle: Forwarded unchanged.
            sampler: If given, it is passed to the parent as
                ``batch_sampler``. Mutually exclusive with ``batch_sampler``.
            batch_sampler: Forwarded unchanged unless ``sampler`` is given.
            num_workers: Forwarded unchanged.
            collate_fn: Forwarded unchanged.
            pin_memory: Ignored (a warning is emitted when truthy).
            drop_last: Forwarded unchanged.
            timeout: Forwarded unchanged.
            worker_init_fn: Forwarded unchanged.
            multiprocessing_context: Ignored (warning when not ``None``).
            generator: Ignored (warning when not ``None``).
            prefetch_factor: Ignored (warning when not ``None``).
            persistent_workers: Forwarded unchanged.
            pin_memory_device: Ignored (warning when non-empty).
            in_order: Ignored (a warning is emitted when falsy).

        Raises:
            ValueError: If both ``sampler`` and ``batch_sampler`` are given.
        """
        # Truthiness checks (rather than ``is True`` / ``is False`` /
        # ``len(...)``) so that values such as ``1``, ``0`` or ``None`` are
        # handled without silently skipping the warning or raising TypeError.
        ignored_option_requested = (
            pin_memory
            or multiprocessing_context is not None
            or generator is not None
            or prefetch_factor is not None
            or pin_memory_device
            or not in_order
        )
        if ignored_option_requested:
            # stacklevel=2 attributes the warning to the code constructing
            # the DataLoader instead of to this wrapper.
            warnings.warn(
                "pin_memory, multiprocessing_context, generator, prefetch_factor, pin_memory_device, in_order are currently not supported in DataLoader and will be ignored.",
                stacklevel=2,
            )

        if sampler is not None:
            if batch_sampler is not None:
                raise ValueError(
                    "Cannot specify both sampler and batch_sampler"
                )
            # The parent class only takes ``batch_sampler``; reuse that slot.
            batch_sampler = sampler

        super().__init__(
            dataset=dataset,
            batch_sampler=batch_sampler,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            collate_fn=collate_fn,
            num_workers=num_workers,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            persistent_workers=persistent_workers,
        )

test/legacy_test/test_api_compatibility_part4.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,5 +1283,75 @@ def test_dygraph_Compatibility(self):
12831283
paddle.enable_static()
12841284

12851285

1286+
class RandomDataset(paddle.utils.data.Dataset):
    """Map-style dataset that produces a random sample on every access."""

    def __init__(self, num_samples):
        """Remember how many samples the dataset reports via ``len()``."""
        self.num_samples = num_samples

    def __len__(self):
        """Return the sample count given at construction time."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return a fresh ``(image, label)`` pair; ``idx`` is not used.

        The image is a float32 vector of length 784 and the label is an
        int64 array of shape ``(1,)`` drawn from ``randint(0, 9)``.
        """
        # Draw the image before the label so the global NumPy RNG stream is
        # consumed in the same order on every call.
        pixels = np.random.random(size=(784,))
        target = np.random.randint(0, 9, size=(1,))
        return pixels.astype('float32'), target.astype('int64')
1297+
1298+
1299+
class TestDataLoaderAPI(unittest.TestCase):
    """Exercise ``paddle.utils.data.DataLoader`` in dynamic-graph mode."""

    def setUp(self):
        """Seed NumPy and build a dataset plus a matching batch sampler."""
        np.random.seed(255)
        self.batch_num = 4
        self.batch_size = 8
        self.dataset = RandomDataset(self.batch_num * self.batch_size)
        self.batch_sampler = paddle.utils.data.BatchSampler(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
        )

    def iter_loader_data(self, loader):
        """Iterate ``loader`` three times and check every batch's shapes."""
        for _ in range(3):
            for image, label in loader():
                relu = paddle.nn.functional.relu(image)
                self.assertEqual(image.shape, [self.batch_size, 784])
                self.assertEqual(label.shape, [self.batch_size, 1])
                self.assertEqual(relu.shape, [self.batch_size, 784])

    def test_dygraph_Compatibility(self):
        """Build the loader three different ways and iterate each one."""
        paddle.disable_static()
        # try/finally so static mode is re-enabled even when an assertion
        # fails; otherwise later tests would inherit dynamic mode.
        try:
            # case 1: positional batch_size with shuffle / drop_last
            loader = paddle.utils.data.DataLoader(
                self.dataset,
                self.batch_size,
                shuffle=True,
                num_workers=0,
                drop_last=True,
            )
            self.iter_loader_data(loader)
            # case 2: explicit batch_sampler, via the ``dataloader`` submodule
            loader = paddle.utils.data.dataloader.DataLoader(
                dataset=self.dataset,
                batch_sampler=self.batch_sampler,
            )
            self.iter_loader_data(loader)
            # case 3: the same sampler passed through the ``sampler`` keyword
            loader = paddle.utils.data.DataLoader(
                dataset=self.dataset,
                sampler=self.batch_sampler,
            )
            self.iter_loader_data(loader)
        finally:
            paddle.enable_static()

    def test_error(self):
        """Passing both ``sampler`` and ``batch_sampler`` raises ValueError."""
        paddle.disable_static()
        # try/finally so static mode is re-enabled even if the expected
        # exception is not raised and assertRaises fails.
        try:
            with self.assertRaises(ValueError):
                paddle.utils.data.dataloader.DataLoader(
                    dataset=self.dataset,
                    sampler=self.batch_sampler,
                    batch_sampler=self.batch_sampler,
                )
        finally:
            paddle.enable_static()
1354+
1355+
12861356
if __name__ == "__main__":
12871357
unittest.main()

0 commit comments

Comments
 (0)