Skip to content

Commit c860f3f

Browse files
authored
Warn when multiprocessing start method is 'fork' (#1309)
* Warn when multiprocessing start method is 'fork'

  CUDA does not support the fork() system call. Forked subprocesses exhibit undefined behavior, including failure to initialize CUDA contexts and devices.

  Add warning checks in multiprocessing reduction functions for IPC objects (DeviceMemoryResource, IPCAllocationHandle, Event) that warn when the start method is 'fork'. The warning is emitted once per process when IPC objects are serialized.

  Fixes #1136

* Skip multiprocessing warning tests on Windows

  Change mempool_device to ipc_device fixture for tests that require IPC-enabled memory resources. The ipc_device fixture properly skips on Windows where IPC is not supported.

* Add reset_fork_warning function and rename check_multiprocessing_start_method

  - Add reset_fork_warning() function for testing purposes
  - Rename _check_multiprocessing_start_method to check_multiprocessing_start_method (remove leading underscore)
  - Update all tests to use reset_fork_warning() instead of directly accessing internal flag
  - Fix trailing whitespace
1 parent 5c42278 commit c860f3f

File tree

4 files changed

+202
-0
lines changed

4 files changed

+202
-0
lines changed

cuda_core/cuda/core/experimental/_event.pyx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Optional
2121
from cuda.core.experimental._context import Context
2222
from cuda.core.experimental._utils.cuda_utils import (
2323
CUDAError,
24+
check_multiprocessing_start_method,
2425
driver,
2526
)
2627
if TYPE_CHECKING:
@@ -300,6 +301,7 @@ cdef class IPCEventDescriptor:
300301

301302

302303
def _reduce_event(event):
    """Reduction function registered with multiprocessing for ``Event``.

    Runs the fork start-method check first, then returns the
    ``(callable, args)`` pair used to rebuild the event: the event's
    ``from_ipc_descriptor`` and a one-element tuple holding its IPC
    descriptor.
    """
    check_multiprocessing_start_method()
    rebuild = event.from_ipc_descriptor
    descriptor = event.get_ipc_descriptor()
    return rebuild, (descriptor,)
304306

305307
multiprocessing.reduction.register(Event, _reduce_event)

cuda_core/cuda/core/experimental/_memory/_ipc.pyx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ from cuda.bindings cimport cydriver
1010
from cuda.core.experimental._memory._buffer cimport Buffer
1111
from cuda.core.experimental._stream cimport default_stream
1212
from cuda.core.experimental._utils.cuda_utils cimport HANDLE_RETURN
13+
from cuda.core.experimental._utils.cuda_utils import check_multiprocessing_start_method
1314

1415
import multiprocessing
1516
import os
@@ -129,6 +130,7 @@ cdef class IPCAllocationHandle:
129130

130131

131132
def _reduce_allocation_handle(alloc_handle):
    """Reduction function registered with multiprocessing for ``IPCAllocationHandle``.

    Runs the fork start-method check, wraps the handle in a
    ``multiprocessing.reduction.DupFd`` and returns the ``(callable, args)``
    pair: ``_reconstruct_allocation_handle`` together with the handle's
    concrete type, the ``DupFd`` wrapper and the handle's uuid.
    """
    check_multiprocessing_start_method()
    dup_fd = multiprocessing.reduction.DupFd(alloc_handle.handle)
    rebuild_args = (type(alloc_handle), dup_fd, alloc_handle.uuid)
    return _reconstruct_allocation_handle, rebuild_args
134136

@@ -141,6 +143,7 @@ multiprocessing.reduction.register(IPCAllocationHandle, _reduce_allocation_handl
141143

142144

143145
def _deep_reduce_device_memory_resource(mr):
146+
check_multiprocessing_start_method()
144147
from .._device import Device
145148
device = Device(mr.device_id)
146149
alloc_handle = mr.get_allocation_handle()

cuda_core/cuda/core/experimental/_utils/cuda_utils.pyx

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
import functools
66
from functools import partial
77
import importlib.metadata
8+
import multiprocessing
9+
import platform
10+
import warnings
811
from collections import namedtuple
912
from collections.abc import Sequence
1013
from contextlib import ExitStack
@@ -283,3 +286,48 @@ class Transaction:
283286
"""
284287
# pop_all() empties this stack so no callbacks are triggered on exit.
285288
self._stack.pop_all()
289+
290+
291+
# Track whether we've already warned about fork method
292+
_fork_warning_checked = False
293+
294+
295+
def reset_fork_warning():
    """Re-arm the one-time fork warning check (test helper).

    Clears the module-level ``_fork_warning_checked`` flag so that the next
    call to ``check_multiprocessing_start_method()`` performs its check again
    instead of returning early.
    """
    global _fork_warning_checked
    _fork_warning_checked = False
303+
304+
305+
def check_multiprocessing_start_method():
306+
"""Check if multiprocessing start method is 'fork' and warn if so."""
307+
global _fork_warning_checked
308+
if _fork_warning_checked:
309+
return
310+
_fork_warning_checked = True
311+
312+
# Common warning message parts
313+
common_message = (
314+
"CUDA does not support. Forked subprocesses exhibit undefined behavior, "
315+
"including failure to initialize CUDA contexts and devices. Set the start method "
316+
"to 'spawn' before creating processes that use CUDA. "
317+
"Use: multiprocessing.set_start_method('spawn')"
318+
)
319+
320+
try:
321+
start_method = multiprocessing.get_start_method()
322+
if start_method == "fork":
323+
message = f"multiprocessing start method is 'fork', which {common_message}"
324+
warnings.warn(message, UserWarning, stacklevel=3)
325+
except RuntimeError:
326+
# get_start_method() can raise RuntimeError if start method hasn't been set
327+
# In this case, default is 'fork' on Linux, so we should warn
328+
if platform.system() == "Linux":
329+
message = (
330+
f"multiprocessing start method is not set and defaults to 'fork' on Linux, "
331+
f"which {common_message}"
332+
)
333+
warnings.warn(message, UserWarning, stacklevel=3)
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""
5+
Test that warnings are emitted when multiprocessing start method is 'fork'
6+
and IPC objects are serialized.
7+
8+
These tests use mocking to simulate the 'fork' start method without actually
9+
using fork, avoiding the need for subprocess isolation.
10+
"""
11+
12+
import warnings
13+
from unittest.mock import patch
14+
15+
from cuda.core.experimental import DeviceMemoryResource, DeviceMemoryResourceOptions, EventOptions
16+
from cuda.core.experimental._event import _reduce_event
17+
from cuda.core.experimental._memory._ipc import (
18+
_deep_reduce_device_memory_resource,
19+
_reduce_allocation_handle,
20+
)
21+
from cuda.core.experimental._utils.cuda_utils import reset_fork_warning
22+
23+
24+
def test_warn_on_fork_method_device_memory_resource(ipc_device):
    """Test that warning is emitted when DeviceMemoryResource is pickled with fork method."""
    device = ipc_device
    device.set_current()
    options = DeviceMemoryResourceOptions(max_size=2097152, ipc_enabled=True)
    mr = DeviceMemoryResource(device, options=options)

    # Close the memory resource even if an assertion below fails.
    try:
        with patch("multiprocessing.get_start_method", return_value="fork"), warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            # Reset the warning flag to allow testing
            reset_fork_warning()

            # Trigger the reduction function directly
            _deep_reduce_device_memory_resource(mr)

            # Check that warning was emitted
            assert len(w) == 1, f"Expected 1 warning, got {len(w)}: {[str(warning.message) for warning in w]}"
            warning = w[0]
            assert warning.category is UserWarning
            assert "fork" in str(warning.message).lower()
            assert "spawn" in str(warning.message).lower()
            assert "undefined behavior" in str(warning.message).lower()
    finally:
        mr.close()
49+
50+
51+
def test_warn_on_fork_method_allocation_handle(ipc_device):
    """Test that warning is emitted when IPCAllocationHandle is pickled with fork method."""
    device = ipc_device
    device.set_current()
    options = DeviceMemoryResourceOptions(max_size=2097152, ipc_enabled=True)
    mr = DeviceMemoryResource(device, options=options)

    # Close the memory resource even if an assertion below fails.
    try:
        alloc_handle = mr.get_allocation_handle()

        with patch("multiprocessing.get_start_method", return_value="fork"), warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            # Reset the warning flag to allow testing
            reset_fork_warning()

            # Trigger the reduction function directly
            _reduce_allocation_handle(alloc_handle)

            # Check that warning was emitted
            assert len(w) == 1
            warning = w[0]
            assert warning.category is UserWarning
            assert "fork" in str(warning.message).lower()
    finally:
        mr.close()
75+
76+
77+
def test_warn_on_fork_method_event(mempool_device):
    """Test that warning is emitted when Event is pickled with fork method."""
    device = mempool_device
    device.set_current()
    stream = device.create_stream()
    ipc_event_options = EventOptions(ipc_enabled=True)
    event = stream.record(options=ipc_event_options)

    # Close the event even if an assertion below fails.
    try:
        with patch("multiprocessing.get_start_method", return_value="fork"), warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            # Reset the warning flag to allow testing
            reset_fork_warning()

            # Trigger the reduction function directly
            _reduce_event(event)

            # Check that warning was emitted
            assert len(w) == 1
            warning = w[0]
            assert warning.category is UserWarning
            assert "fork" in str(warning.message).lower()
    finally:
        event.close()
101+
102+
103+
def test_no_warning_with_spawn_method(ipc_device):
    """Test that no warning is emitted when start method is 'spawn'."""
    device = ipc_device
    device.set_current()
    options = DeviceMemoryResourceOptions(max_size=2097152, ipc_enabled=True)
    mr = DeviceMemoryResource(device, options=options)

    # Close the memory resource even if an assertion below fails.
    try:
        with patch("multiprocessing.get_start_method", return_value="spawn"), warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            # Reset the warning flag to allow testing
            reset_fork_warning()

            # Trigger the reduction function directly
            _deep_reduce_device_memory_resource(mr)

            # Check that no fork-related warning was emitted
            fork_warnings = [warning for warning in w if "fork" in str(warning.message).lower()]
            assert len(fork_warnings) == 0, f"Unexpected warning: {fork_warnings[0].message if fork_warnings else None}"
    finally:
        mr.close()
124+
125+
126+
def test_warning_emitted_only_once(ipc_device):
    """Test that warning is only emitted once even when multiple objects are pickled."""
    device = ipc_device
    device.set_current()
    options = DeviceMemoryResourceOptions(max_size=2097152, ipc_enabled=True)

    # Track every resource that was actually created so that all of them are
    # closed even if a later construction or an assertion fails.
    resources = []
    try:
        mr1 = DeviceMemoryResource(device, options=options)
        resources.append(mr1)
        mr2 = DeviceMemoryResource(device, options=options)
        resources.append(mr2)

        with patch("multiprocessing.get_start_method", return_value="fork"), warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            # Reset the warning flag to allow testing
            reset_fork_warning()

            # Trigger reduction multiple times
            _deep_reduce_device_memory_resource(mr1)
            _deep_reduce_device_memory_resource(mr2)

            # Check that warning was emitted only once
            fork_warnings = [warning for warning in w if "fork" in str(warning.message).lower()]
            assert len(fork_warnings) == 1, f"Expected 1 warning, got {len(fork_warnings)}"
    finally:
        for mr in resources:
            mr.close()

0 commit comments

Comments
 (0)