Skip to content

Commit 17b508c

Browse files
authored
Add test to check if importing DALI doesn't break Torch process forking (#3669)
- adds a test with PyTorch to make sure that importing DALI doesn't interfere with torch.multiprocessing Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
1 parent 1ed4bf6 commit 17b508c

2 files changed

Lines changed: 33 additions & 0 deletions

File tree

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import torch
16+
from torch.multiprocessing import Process
17+
# we need this import to check if it is safe to import DALI and not touch the CUDA runtime
18+
# that could crash forked process
19+
import nvidia.dali as dali
20+
21+
22+
def task_function():
23+
torch.cuda.set_device(0)
24+
25+
26+
def test_actual_proc():
27+
phase_process = Process(target=task_function)
28+
# phase_process.daemon = True
29+
phase_process.start()
30+
phase_process.join()
31+
assert phase_process.exitcode == 0
32+

qa/TL0_python_self_test_frameworks/test_pytorch.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ test_body() {
1212
nosetests --verbose test_external_source_pytorch_dlpack.py
1313
nosetests --verbose test_external_source_parallel_pytorch.py
1414
nosetests --verbose test_backend_impl_torch_dlpack.py
15+
nosetests --verbose test_dali_fork_torch.py
1516
nosetests --verbose --attr 'pytorch' test_external_source_impl_utils.py
1617
nosetests --verbose --attr 'pytorch' test_pipeline_debug.py
1718
}

0 commit comments

Comments
 (0)