Skip to content

Commit 21be8d5

Browse files
Ramdam17 and claude authored
fix(sync): cache Metal command queue to stop GPU resource leak (#279)
accorr_metal and the shared Metal dispatch created a new MTLCommandQueue on every call. Command queues are a device-limited resource, so a long surrogate loop exhausted the ceiling and newCommandQueue() returned None, crashing with "'NoneType' object has no attribute 'commandBuffer'" after ~7,800 calls (PLAN-014). Cache a single command queue per device (_get_command_queue), run the dispatch inside an objc.autorelease_pool() to reclaim the transient command buffer and encoder, and release the four inline uint32 const buffers in the finally block. Verified on M4 Max: 36k dispatches with flat RSS (~430 MB), no crash, no "Context leak detected", sync parity tests still pass (37 passed / 3 CuPy skipped). Co-authored-by: Claude Fable 5 <noreply@anthropic.com>
1 parent 3deb4d8 commit 21be8d5

2 files changed

Lines changed: 135 additions & 64 deletions

File tree

hypyp/sync/kernels/_metal_dispatch.py

Lines changed: 88 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,24 @@
1111
1212
ACCorr uses an extended layout with buffer(2) = angle and buffer(3) = output,
1313
so it has its own dispatch function.
14+
15+
Memory
16+
------
17+
Three measures keep GPU memory flat across a long surrogate loop:
18+
19+
1. A single command queue per device, cached by ``_get_command_queue``.
20+
Command queues are device-limited; creating one per dispatch exhausts the
21+
ceiling and ``newCommandQueue()`` then returns ``None`` (the historical
22+
crash).
23+
2. The dispatch runs inside an ``objc.autorelease_pool()`` so the autoreleased
24+
command buffer and encoder are reclaimed on each call.
25+
3. Every owned Metal buffer (data buffers and the four ``uint32`` constant
26+
buffers) is released in a ``finally`` block.
1427
"""
1528

29+
import contextlib
1630
import struct
31+
from functools import lru_cache
1732

1833
import numpy as np
1934

@@ -22,13 +37,41 @@
2237
if METAL_AVAILABLE:
2338
import Metal
2439

40+
try:
41+
import objc
42+
43+
# Drains the autoreleased Metal temporaries (command buffer and
44+
# encoder) at the end of each call. Without a pool draining in a
45+
# tight Python loop they accumulate on the top-level autorelease
46+
# pool until the device refuses a new command queue
47+
# (``newCommandQueue()`` returns None) — the long-run leak.
48+
_autorelease_pool = objc.autorelease_pool
49+
except ImportError: # pragma: no cover - objc ships with pyobjc Metal
50+
_autorelease_pool = contextlib.nullcontext
51+
else: # pragma: no cover - only exercised without PyObjC Metal
52+
_autorelease_pool = contextlib.nullcontext
53+
2554

2655
def make_const_buffer(device, value):
    """Wrap a single ``uint32`` scalar in a newly allocated Metal buffer.

    ``value`` is packed with the ``struct`` format ``'I'`` and handed to
    ``device`` as a 4-byte buffer created with
    ``Metal.MTLResourceStorageModeShared``. The new buffer is returned
    to the caller.
    """
    allocate = device.newBufferWithBytes_length_options_
    payload = struct.pack('I', value)
    storage_mode = Metal.MTLResourceStorageModeShared
    return allocate(payload, 4, storage_mode)
3059

3160

61+
@lru_cache(maxsize=8)
62+
def _get_command_queue(device):
63+
"""
64+
Return a command queue for ``device``, cached one-per-device.
65+
66+
Command queues are a heavyweight, device-limited resource (the device
67+
refuses a new one past a small ceiling). Creating one per dispatch — as
68+
the original code did — exhausts that ceiling in a long surrogate loop
69+
and ``newCommandQueue()`` then returns ``None``. A single queue is the
70+
idiomatic Metal pattern: many command buffers are enqueued onto it.
71+
"""
72+
return device.newCommandQueue()
73+
74+
3275
def run_pairwise_kernel(complex_signal, compile_fn):
3376
"""
3477
Shared dispatch for pairwise Metal kernels with standard buffer layout.
@@ -78,43 +121,57 @@ def run_pairwise_kernel(complex_signal, compile_fn):
78121
buf_pj = device.newBufferWithBytes_length_options_(
79122
idx_j.tobytes(), idx_j.nbytes, Metal.MTLResourceStorageModeShared)
80123

81-
# Dispatch
124+
# Constant buffers held in named locals so they can be released in the
125+
# finally block (passing them inline to setBuffer would leak them).
126+
buf_n_ef = make_const_buffer(device, n_ef)
127+
buf_n_ch = make_const_buffer(device, C)
128+
buf_n_t = make_const_buffer(device, T)
129+
buf_n_pairs = make_const_buffer(device, n_pairs)
130+
131+
# Dispatch — wrapped in an autorelease pool so the autoreleased command
132+
# buffer / encoder are reclaimed each call (see the module docstring).
82133
try:
83-
queue = device.newCommandQueue()
84-
cmd_buffer = queue.commandBuffer()
85-
encoder = cmd_buffer.computeCommandEncoder()
86-
87-
encoder.setComputePipelineState_(pipeline)
88-
encoder.setBuffer_offset_atIndex_(buf_s, 0, 0)
89-
encoder.setBuffer_offset_atIndex_(buf_c, 0, 1)
90-
encoder.setBuffer_offset_atIndex_(buf_out, 0, 2)
91-
encoder.setBuffer_offset_atIndex_(buf_pi, 0, 3)
92-
encoder.setBuffer_offset_atIndex_(buf_pj, 0, 4)
93-
encoder.setBuffer_offset_atIndex_(make_const_buffer(device, n_ef), 0, 5)
94-
encoder.setBuffer_offset_atIndex_(make_const_buffer(device, C), 0, 6)
95-
encoder.setBuffer_offset_atIndex_(make_const_buffer(device, T), 0, 7)
96-
encoder.setBuffer_offset_atIndex_(make_const_buffer(device, n_pairs), 0, 8)
97-
98-
total_threads = n_ef * n_pairs
99-
threads_per_group = min(256, pipeline.maxTotalThreadsPerThreadgroup())
100-
101-
encoder.dispatchThreads_threadsPerThreadgroup_(
102-
Metal.MTLSize(total_threads, 1, 1),
103-
Metal.MTLSize(threads_per_group, 1, 1))
104-
encoder.endEncoding()
105-
106-
cmd_buffer.commit()
107-
cmd_buffer.waitUntilCompleted()
108-
109-
out_ptr = buf_out.contents()
110-
membuf = out_ptr.as_buffer(out_nbytes)
111-
result = np.frombuffer(membuf, dtype=np.float32).copy().reshape(n_ef, C, C)
134+
with _autorelease_pool():
135+
queue = _get_command_queue(device)
136+
cmd_buffer = queue.commandBuffer()
137+
encoder = cmd_buffer.computeCommandEncoder()
138+
139+
encoder.setComputePipelineState_(pipeline)
140+
encoder.setBuffer_offset_atIndex_(buf_s, 0, 0)
141+
encoder.setBuffer_offset_atIndex_(buf_c, 0, 1)
142+
encoder.setBuffer_offset_atIndex_(buf_out, 0, 2)
143+
encoder.setBuffer_offset_atIndex_(buf_pi, 0, 3)
144+
encoder.setBuffer_offset_atIndex_(buf_pj, 0, 4)
145+
encoder.setBuffer_offset_atIndex_(buf_n_ef, 0, 5)
146+
encoder.setBuffer_offset_atIndex_(buf_n_ch, 0, 6)
147+
encoder.setBuffer_offset_atIndex_(buf_n_t, 0, 7)
148+
encoder.setBuffer_offset_atIndex_(buf_n_pairs, 0, 8)
149+
150+
total_threads = n_ef * n_pairs
151+
threads_per_group = min(256, pipeline.maxTotalThreadsPerThreadgroup())
152+
153+
encoder.dispatchThreads_threadsPerThreadgroup_(
154+
Metal.MTLSize(total_threads, 1, 1),
155+
Metal.MTLSize(threads_per_group, 1, 1))
156+
encoder.endEncoding()
157+
158+
cmd_buffer.commit()
159+
cmd_buffer.waitUntilCompleted()
160+
161+
out_ptr = buf_out.contents()
162+
membuf = out_ptr.as_buffer(out_nbytes)
163+
result = np.frombuffer(membuf, dtype=np.float32).copy().reshape(n_ef, C, C)
112164

113165
return result.reshape(E, F, C, C)
114166
finally:
115-
# Critical: Release all Metal buffers to prevent GPU memory leak
167+
# Release every owned Metal buffer to prevent GPU memory growth —
168+
# the data buffers and the four const buffers alike.
116169
buf_s.release()
117170
buf_c.release()
118171
buf_out.release()
119172
buf_pi.release()
120173
buf_pj.release()
174+
buf_n_ef.release()
175+
buf_n_ch.release()
176+
buf_n_t.release()
177+
buf_n_pairs.release()

hypyp/sync/kernels/metal_accorr.py

Lines changed: 47 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import numpy as np
1616

1717
from . import METAL_AVAILABLE
18-
from ._metal_dispatch import make_const_buffer
18+
from ._metal_dispatch import make_const_buffer, _autorelease_pool, _get_command_queue
1919

2020
if METAL_AVAILABLE:
2121
import Metal
@@ -150,45 +150,59 @@ def accorr_metal(complex_signal: np.ndarray) -> np.ndarray:
150150
buf_pj = device.newBufferWithBytes_length_options_(
151151
idx_j.tobytes(), idx_j.nbytes, Metal.MTLResourceStorageModeShared)
152152

153-
# Dispatch
153+
# Constant buffers held in named locals so they can be released in the
154+
# finally block (passing them inline to setBuffer would leak them).
155+
buf_n_ef = make_const_buffer(device, n_ef)
156+
buf_n_ch = make_const_buffer(device, C)
157+
buf_n_t = make_const_buffer(device, T)
158+
buf_n_pairs = make_const_buffer(device, n_pairs)
159+
160+
# Dispatch — wrapped in an autorelease pool so the autoreleased command
161+
# buffer / encoder are reclaimed each call (see the module docstring).
154162
try:
155-
queue = device.newCommandQueue()
156-
cmd_buffer = queue.commandBuffer()
157-
encoder = cmd_buffer.computeCommandEncoder()
158-
159-
encoder.setComputePipelineState_(pipeline)
160-
encoder.setBuffer_offset_atIndex_(buf_s, 0, 0)
161-
encoder.setBuffer_offset_atIndex_(buf_c, 0, 1)
162-
encoder.setBuffer_offset_atIndex_(buf_angle, 0, 2)
163-
encoder.setBuffer_offset_atIndex_(buf_out, 0, 3)
164-
encoder.setBuffer_offset_atIndex_(buf_pi, 0, 4)
165-
encoder.setBuffer_offset_atIndex_(buf_pj, 0, 5)
166-
encoder.setBuffer_offset_atIndex_(make_const_buffer(device, n_ef), 0, 6)
167-
encoder.setBuffer_offset_atIndex_(make_const_buffer(device, C), 0, 7)
168-
encoder.setBuffer_offset_atIndex_(make_const_buffer(device, T), 0, 8)
169-
encoder.setBuffer_offset_atIndex_(make_const_buffer(device, n_pairs), 0, 9)
170-
171-
total_threads = n_ef * n_pairs
172-
threads_per_group = min(256, pipeline.maxTotalThreadsPerThreadgroup())
173-
174-
encoder.dispatchThreads_threadsPerThreadgroup_(
175-
Metal.MTLSize(total_threads, 1, 1),
176-
Metal.MTLSize(threads_per_group, 1, 1))
177-
encoder.endEncoding()
178-
179-
cmd_buffer.commit()
180-
cmd_buffer.waitUntilCompleted()
181-
182-
out_ptr = buf_out.contents()
183-
membuf = out_ptr.as_buffer(out_nbytes)
184-
result = np.frombuffer(membuf, dtype=np.float32).copy().reshape(n_ef, C, C)
163+
with _autorelease_pool():
164+
queue = _get_command_queue(device)
165+
cmd_buffer = queue.commandBuffer()
166+
encoder = cmd_buffer.computeCommandEncoder()
167+
168+
encoder.setComputePipelineState_(pipeline)
169+
encoder.setBuffer_offset_atIndex_(buf_s, 0, 0)
170+
encoder.setBuffer_offset_atIndex_(buf_c, 0, 1)
171+
encoder.setBuffer_offset_atIndex_(buf_angle, 0, 2)
172+
encoder.setBuffer_offset_atIndex_(buf_out, 0, 3)
173+
encoder.setBuffer_offset_atIndex_(buf_pi, 0, 4)
174+
encoder.setBuffer_offset_atIndex_(buf_pj, 0, 5)
175+
encoder.setBuffer_offset_atIndex_(buf_n_ef, 0, 6)
176+
encoder.setBuffer_offset_atIndex_(buf_n_ch, 0, 7)
177+
encoder.setBuffer_offset_atIndex_(buf_n_t, 0, 8)
178+
encoder.setBuffer_offset_atIndex_(buf_n_pairs, 0, 9)
179+
180+
total_threads = n_ef * n_pairs
181+
threads_per_group = min(256, pipeline.maxTotalThreadsPerThreadgroup())
182+
183+
encoder.dispatchThreads_threadsPerThreadgroup_(
184+
Metal.MTLSize(total_threads, 1, 1),
185+
Metal.MTLSize(threads_per_group, 1, 1))
186+
encoder.endEncoding()
187+
188+
cmd_buffer.commit()
189+
cmd_buffer.waitUntilCompleted()
190+
191+
out_ptr = buf_out.contents()
192+
membuf = out_ptr.as_buffer(out_nbytes)
193+
result = np.frombuffer(membuf, dtype=np.float32).copy().reshape(n_ef, C, C)
185194

186195
return result.reshape(E, F, C, C)
187196
finally:
188-
# Critical: Release all Metal buffers to prevent GPU memory leak
197+
# Release every owned Metal buffer to prevent GPU memory growth —
198+
# the data buffers and the four const buffers alike.
189199
buf_s.release()
190200
buf_c.release()
191201
buf_angle.release()
192202
buf_out.release()
193203
buf_pi.release()
194204
buf_pj.release()
205+
buf_n_ef.release()
206+
buf_n_ch.release()
207+
buf_n_t.release()
208+
buf_n_pairs.release()

0 commit comments

Comments
 (0)