Skip to content

Commit 5775dda

Browse files
Ensure temporary is delete is compile fails
Make sure to del cpu_func and my_func that references it.
1 parent 9303966 commit 5775dda

1 file changed

Lines changed: 24 additions & 20 deletions

File tree

cuda_core/examples/strided_memory_view.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,10 @@
7676
extra_compile_args=["-std=c++11"],
7777
)
7878
temp_dir = tempfile.mkdtemp()
79-
cpu_prog.compile(tmpdir=temp_dir)
79+
try:
80+
cpu_prog.compile(tmpdir=temp_dir)
81+
finally:
82+
shutil.rmtree(temp_dir)
8083
saved_sys_path = sys.path
8184
try:
8285
sys.path.append(temp_dir)
@@ -147,25 +150,6 @@ def my_func(arr, work_stream):
147150
cpu_func(cpu_prog.cast("int*", view.ptr), size)
148151

149152

150-
# This takes the CPU path
151-
if FFI:
152-
try:
153-
# Create input array on CPU
154-
arr_cpu = np.zeros(1024, dtype=np.int32)
155-
print(f"before: {arr_cpu[:10]=}")
156-
157-
# Run the workload
158-
my_func(arr_cpu, None)
159-
160-
# Check the result
161-
print(f"after: {arr_cpu[:10]=}")
162-
assert np.allclose(arr_cpu, np.arange(1024, dtype=np.int32))
163-
finally:
164-
# clean up temp directory
165-
del cpu_func
166-
shutil.rmtree(temp_dir)
167-
168-
169153
# This takes the GPU path
170154
if cp:
171155
s = dev.create_stream()
@@ -182,3 +166,23 @@ def my_func(arr, work_stream):
182166
assert cp.allclose(arr_gpu, 1 + cp.arange(1024, dtype=cp.int32))
183167
finally:
184168
s.close()
169+
170+
# This takes the CPU path
171+
if FFI:
172+
try:
173+
# Create input array on CPU
174+
arr_cpu = np.zeros(1024, dtype=np.int32)
175+
print(f"before: {arr_cpu[:10]=}")
176+
177+
# Run the workload
178+
my_func(arr_cpu, None)
179+
180+
# Check the result
181+
print(f"after: {arr_cpu[:10]=}")
182+
assert np.allclose(arr_cpu, np.arange(1024, dtype=np.int32))
183+
finally:
184+
# to allow FFI module to unload, we delete references to
185+
# to cpu_func
186+
del cpu_func, my_func
187+
# clean up temp directory
188+
shutil.rmtree(temp_dir)

0 commit comments

Comments
 (0)