33
44import multiprocessing as mp
55
6- from cuda .core .experimental import Buffer , Device , DeviceMemoryResource
6+ from cuda .core .experimental import Buffer , DeviceMemoryResource
77from utility import IPCBufferTestHelper
88
99CHILD_TIMEOUT_SEC = 20
@@ -21,7 +21,7 @@ def test_main(self, ipc_device, ipc_memory_resource):
2121
2222 # Start the child process.
2323 queue = mp .Queue ()
24- process = mp .Process (target = self .child_main , args = (mr , queue ))
24+ process = mp .Process (target = self .child_main , args = (device , mr , queue ))
2525 process .start ()
2626
2727 # Allocate and fill memory.
@@ -39,8 +39,8 @@ def test_main(self, ipc_device, ipc_memory_resource):
3939 # Verify that the buffer was modified.
4040 helper .verify_buffer (flipped = True )
4141
42- def child_main (self , mr , queue ):
43- device = Device ()
42+ def child_main (self , device , mr , queue ):
43+ device . set_current ()
4444 buffer = queue .get (timeout = CHILD_TIMEOUT_SEC )
4545 helper = IPCBufferTestHelper (device , buffer )
4646 helper .verify_buffer (flipped = False )
@@ -64,8 +64,8 @@ def test_main(self, ipc_device, ipc_memory_resource):
6464 q2 .put (buffer2 )
6565
6666 # Start the child processes.
67- p1 = mp .Process (target = self .child_main , args = (mr , 1 , q1 ))
68- p2 = mp .Process (target = self .child_main , args = (mr , 2 , q2 ))
67+ p1 = mp .Process (target = self .child_main , args = (device , mr , 1 , q1 ))
68+ p2 = mp .Process (target = self .child_main , args = (device , mr , 2 , q2 ))
6969 p1 .start ()
7070 p2 .start ()
7171
@@ -79,10 +79,10 @@ def test_main(self, ipc_device, ipc_memory_resource):
7979 IPCBufferTestHelper (device , buffer1 ).verify_buffer (flipped = False )
8080 IPCBufferTestHelper (device , buffer2 ).verify_buffer (flipped = True )
8181
82- def child_main (self , mr , idx , queue ):
82+ def child_main (self , device , mr , idx , queue ):
8383 # Note: passing the mr registers it so that buffers can be passed
8484 # directly.
85- device = Device ()
85+ device . set_current ()
8686 buffer1 = queue .get (timeout = CHILD_TIMEOUT_SEC )
8787 buffer2 = queue .get (timeout = CHILD_TIMEOUT_SEC )
8888 if idx == 1 :
@@ -104,8 +104,8 @@ def test_main(self, ipc_device, ipc_memory_resource):
104104
105105 # Start children.
106106 q1 , q2 = (mp .Queue () for _ in range (2 ))
107- p1 = mp .Process (target = self .child_main , args = (alloc_handle , 1 , q1 ))
108- p2 = mp .Process (target = self .child_main , args = (alloc_handle , 2 , q2 ))
107+ p1 = mp .Process (target = self .child_main , args = (device , alloc_handle , 1 , q1 ))
108+ p2 = mp .Process (target = self .child_main , args = (device , alloc_handle , 2 , q2 ))
109109 p1 .start ()
110110 p2 .start ()
111111
@@ -125,11 +125,10 @@ def test_main(self, ipc_device, ipc_memory_resource):
125125 IPCBufferTestHelper (device , buf1 ).verify_buffer (starting_from = 1 )
126126 IPCBufferTestHelper (device , buf2 ).verify_buffer (starting_from = 2 )
127127
128- def child_main (self , alloc_handle , idx , queue ):
128+ def child_main (self , device , alloc_handle , idx , queue ):
129129 """Fills a shared memory buffer."""
130130 # In this case, the device needs to be set up (passing the mr does it
131131 # implicitly in other tests).
132- device = Device ()
133132 device .set_current ()
134133 mr = DeviceMemoryResource .from_allocation_handle (device , alloc_handle )
135134 buffer_descriptor = queue .get (timeout = CHILD_TIMEOUT_SEC )
@@ -149,8 +148,8 @@ def test_main(self, ipc_device, ipc_memory_resource):
149148
150149 # Start children.
151150 q1 , q2 = (mp .Queue () for _ in range (2 ))
152- p1 = mp .Process (target = self .child_main , args = (alloc_handle , 1 , q1 ))
153- p2 = mp .Process (target = self .child_main , args = (alloc_handle , 2 , q2 ))
151+ p1 = mp .Process (target = self .child_main , args = (device , alloc_handle , 1 , q1 ))
152+ p2 = mp .Process (target = self .child_main , args = (device , alloc_handle , 2 , q2 ))
154153 p1 .start ()
155154 p2 .start ()
156155
@@ -170,9 +169,8 @@ def test_main(self, ipc_device, ipc_memory_resource):
170169 IPCBufferTestHelper (device , buf1 ).verify_buffer (starting_from = 1 )
171170 IPCBufferTestHelper (device , buf2 ).verify_buffer (starting_from = 2 )
172171
173- def child_main (self , alloc_handle , idx , queue ):
172+ def child_main (self , device , alloc_handle , idx , queue ):
174173 """Fills a shared memory buffer."""
175- device = Device ()
176174 device .set_current ()
177175
178176 # Register the memory resource.
0 commit comments