|
11 | 11 |
|
12 | 12 | ACCorr uses an extended layout with buffer(2) = angle and buffer(3) = output, |
13 | 13 | so it has its own dispatch function. |
| 14 | +
|
| 15 | +Memory |
| 16 | +------ |
| 17 | +Three measures keep GPU memory flat across a long surrogate loop: |
| 18 | +
|
| 19 | +1. A single command queue per device, cached by ``_get_command_queue``. |
| 20 | + Command queues are device-limited; creating one per dispatch exhausts the |
| 21 | + ceiling and ``newCommandQueue()`` then returns ``None`` (the historical |
| 22 | + crash). |
| 23 | +2. The dispatch runs inside an ``objc.autorelease_pool()`` so the autoreleased |
| 24 | + command buffer and encoder are reclaimed on each call. |
| 25 | +3. Every owned Metal buffer (data buffers and the four ``uint32`` constant |
| 26 | + buffers) is released in a ``finally`` block. |
14 | 27 | """ |
15 | 28 |
|
| 29 | +import contextlib |
16 | 30 | import struct |
| 31 | +from functools import lru_cache |
17 | 32 |
|
18 | 33 | import numpy as np |
19 | 34 |
|
|
22 | 37 | if METAL_AVAILABLE: |
23 | 38 | import Metal |
24 | 39 |
|
| 40 | + try: |
| 41 | + import objc |
| 42 | + |
| 43 | + # Drains the autoreleased Metal temporaries (command buffer and |
| 44 | + # encoder) at the end of each call. Without a pool draining in a |
| 45 | + # tight Python loop they accumulate on the top-level autorelease |
| 46 | + # pool until the device refuses a new command queue |
| 47 | + # (``newCommandQueue()`` returns None) — the long-run leak. |
| 48 | + _autorelease_pool = objc.autorelease_pool |
| 49 | + except ImportError: # pragma: no cover - objc ships with pyobjc Metal |
| 50 | + _autorelease_pool = contextlib.nullcontext |
| 51 | +else: # pragma: no cover - only exercised without PyObjC Metal |
| 52 | + _autorelease_pool = contextlib.nullcontext |
| 53 | + |
25 | 54 |
|
26 | 55 | def make_const_buffer(device, value): |
27 | 56 | """Create a Metal buffer containing a single uint32 constant.""" |
28 | 57 | return device.newBufferWithBytes_length_options_( |
29 | 58 | struct.pack('I', value), 4, Metal.MTLResourceStorageModeShared) |
30 | 59 |
|
31 | 60 |
|
| 61 | +@lru_cache(maxsize=8) |
| 62 | +def _get_command_queue(device): |
| 63 | + """ |
| 64 | + Return a command queue for ``device``, cached one-per-device. |
| 65 | +
|
| 66 | + Command queues are a heavyweight, device-limited resource (the device |
| 67 | + refuses a new one past a small ceiling). Creating one per dispatch — as |
| 68 | + the original code did — exhausts that ceiling in a long surrogate loop |
| 69 | + and ``newCommandQueue()`` then returns ``None``. A single queue is the |
| 70 | + idiomatic Metal pattern: many command buffers are enqueued onto it. |
| 71 | + """ |
| 72 | + return device.newCommandQueue() |
| 73 | + |
| 74 | + |
32 | 75 | def run_pairwise_kernel(complex_signal, compile_fn): |
33 | 76 | """ |
34 | 77 | Shared dispatch for pairwise Metal kernels with standard buffer layout. |
@@ -78,43 +121,57 @@ def run_pairwise_kernel(complex_signal, compile_fn): |
78 | 121 | buf_pj = device.newBufferWithBytes_length_options_( |
79 | 122 | idx_j.tobytes(), idx_j.nbytes, Metal.MTLResourceStorageModeShared) |
80 | 123 |
|
81 | | - # Dispatch |
| 124 | + # Constant buffers held in named locals so they can be released in the |
| 125 | + # finally block (passing them inline to setBuffer would leak them). |
| 126 | + buf_n_ef = make_const_buffer(device, n_ef) |
| 127 | + buf_n_ch = make_const_buffer(device, C) |
| 128 | + buf_n_t = make_const_buffer(device, T) |
| 129 | + buf_n_pairs = make_const_buffer(device, n_pairs) |
| 130 | + |
| 131 | + # Dispatch — wrapped in an autorelease pool so the autoreleased command |
| 132 | + # buffer / encoder are reclaimed each call (see the module docstring). |
82 | 133 | try: |
83 | | - queue = device.newCommandQueue() |
84 | | - cmd_buffer = queue.commandBuffer() |
85 | | - encoder = cmd_buffer.computeCommandEncoder() |
86 | | - |
87 | | - encoder.setComputePipelineState_(pipeline) |
88 | | - encoder.setBuffer_offset_atIndex_(buf_s, 0, 0) |
89 | | - encoder.setBuffer_offset_atIndex_(buf_c, 0, 1) |
90 | | - encoder.setBuffer_offset_atIndex_(buf_out, 0, 2) |
91 | | - encoder.setBuffer_offset_atIndex_(buf_pi, 0, 3) |
92 | | - encoder.setBuffer_offset_atIndex_(buf_pj, 0, 4) |
93 | | - encoder.setBuffer_offset_atIndex_(make_const_buffer(device, n_ef), 0, 5) |
94 | | - encoder.setBuffer_offset_atIndex_(make_const_buffer(device, C), 0, 6) |
95 | | - encoder.setBuffer_offset_atIndex_(make_const_buffer(device, T), 0, 7) |
96 | | - encoder.setBuffer_offset_atIndex_(make_const_buffer(device, n_pairs), 0, 8) |
97 | | - |
98 | | - total_threads = n_ef * n_pairs |
99 | | - threads_per_group = min(256, pipeline.maxTotalThreadsPerThreadgroup()) |
100 | | - |
101 | | - encoder.dispatchThreads_threadsPerThreadgroup_( |
102 | | - Metal.MTLSize(total_threads, 1, 1), |
103 | | - Metal.MTLSize(threads_per_group, 1, 1)) |
104 | | - encoder.endEncoding() |
105 | | - |
106 | | - cmd_buffer.commit() |
107 | | - cmd_buffer.waitUntilCompleted() |
108 | | - |
109 | | - out_ptr = buf_out.contents() |
110 | | - membuf = out_ptr.as_buffer(out_nbytes) |
111 | | - result = np.frombuffer(membuf, dtype=np.float32).copy().reshape(n_ef, C, C) |
| 134 | + with _autorelease_pool(): |
| 135 | + queue = _get_command_queue(device) |
| 136 | + cmd_buffer = queue.commandBuffer() |
| 137 | + encoder = cmd_buffer.computeCommandEncoder() |
| 138 | + |
| 139 | + encoder.setComputePipelineState_(pipeline) |
| 140 | + encoder.setBuffer_offset_atIndex_(buf_s, 0, 0) |
| 141 | + encoder.setBuffer_offset_atIndex_(buf_c, 0, 1) |
| 142 | + encoder.setBuffer_offset_atIndex_(buf_out, 0, 2) |
| 143 | + encoder.setBuffer_offset_atIndex_(buf_pi, 0, 3) |
| 144 | + encoder.setBuffer_offset_atIndex_(buf_pj, 0, 4) |
| 145 | + encoder.setBuffer_offset_atIndex_(buf_n_ef, 0, 5) |
| 146 | + encoder.setBuffer_offset_atIndex_(buf_n_ch, 0, 6) |
| 147 | + encoder.setBuffer_offset_atIndex_(buf_n_t, 0, 7) |
| 148 | + encoder.setBuffer_offset_atIndex_(buf_n_pairs, 0, 8) |
| 149 | + |
| 150 | + total_threads = n_ef * n_pairs |
| 151 | + threads_per_group = min(256, pipeline.maxTotalThreadsPerThreadgroup()) |
| 152 | + |
| 153 | + encoder.dispatchThreads_threadsPerThreadgroup_( |
| 154 | + Metal.MTLSize(total_threads, 1, 1), |
| 155 | + Metal.MTLSize(threads_per_group, 1, 1)) |
| 156 | + encoder.endEncoding() |
| 157 | + |
| 158 | + cmd_buffer.commit() |
| 159 | + cmd_buffer.waitUntilCompleted() |
| 160 | + |
| 161 | + out_ptr = buf_out.contents() |
| 162 | + membuf = out_ptr.as_buffer(out_nbytes) |
| 163 | + result = np.frombuffer(membuf, dtype=np.float32).copy().reshape(n_ef, C, C) |
112 | 164 |
|
113 | 165 | return result.reshape(E, F, C, C) |
114 | 166 | finally: |
115 | | - # Critical: Release all Metal buffers to prevent GPU memory leak |
| 167 | + # Release every owned Metal buffer to prevent GPU memory growth — |
| 168 | + # the data buffers and the four const buffers alike. |
116 | 169 | buf_s.release() |
117 | 170 | buf_c.release() |
118 | 171 | buf_out.release() |
119 | 172 | buf_pi.release() |
120 | 173 | buf_pj.release() |
| 174 | + buf_n_ef.release() |
| 175 | + buf_n_ch.release() |
| 176 | + buf_n_t.release() |
| 177 | + buf_n_pairs.release() |
0 commit comments