Commit 05a513d

JacoCheung and claude committed
perf(commons): C++ Karmarkar-Karp partitioner releases GIL
KK runs on every load-balanced batch shuffle. The pure-Python implementation uses heapq with Python __lt__ callbacks, which hold the GIL the whole time KK runs in the background ThreadPoolExecutor. nsys shows a visible "Waiting for GIL" gap on the main thread (~1.5 ms) during the karmarkar_karp NVTX range: the main thread cannot keep submitting CUDA kernels while KK contends for the GIL.

This commit ports KK to a pybind11 C++ extension (kk_cpu_ops) with `py::gil_scoped_release` around the compute. Output is bit-for-bit identical to the Python implementation (same Set.__lt__ / State.__lt__ tie-breaking).

partitioner.py resolution order:
  1. honour KK_FORCE_PYTHON=1 (escape hatch / parity tests)
  2. top-level import kk_cpu_ops (matches the `python setup.py install` layout; picked up automatically when the Docker image is rebuilt via examples/commons/setup.py)
  3. sibling .so next to the perf_model package (matches `python setup.py build_ext --inplace` during dev iteration)
  4. fall back to the pure-Python implementation otherwise

The `BUILD_EXT_ONLY=kk_cpu_ops` env knob in setup.py rebuilds just the CPU partitioner without spending minutes on the CUDA extensions.

Measurements:
  Isolated micro-benchmark (n=k*4096=32768, k=8):
    Python median 237.6 ms → C++ median 13.5 ms (17.6x)
  nsys NVTX karmarkar_karp duration on the kk worker thread:
    Python ms-scale → C++ avg 135 us / max 270 us (50x shorter)
  Main-thread "Waiting for GIL" gap during the karmarkar_karp NVTX range (per the original nsys screenshot, ~1.36 ms): gone under C++.

Tests (26 cases in examples/tests/commons/test_kk_partitioner.py):
  - 16 parity cases across k = 2/4/8/16 x n/k = 1/4/32/128
  - 3 unequal-size parity cases
  - Adversarial: all-equal workloads, deliberate len(items) ties
  - Sufficiency invariants: partitions cover [0, n) exactly once; KK beats naive chunking on skewed input.
  - test_gil_released_during_compute runs C++ KK in a background thread while the main thread runs a tight Python compute loop, and asserts that the main thread retains > 50 % of its GIL-free baseline throughput. This proves the release actually fires.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
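The GIL-release test described in the message can be sketched roughly as follows. This is a minimal illustration, not the code in test_kk_partitioner.py: the helper names `count_iterations` and `throughput_ratio` and the default duration are assumptions made for this sketch.

```python
import threading
import time
from typing import Callable


def count_iterations(duration_s: float) -> int:
    """Count how many trivial Python loop iterations fit into duration_s."""
    deadline = time.perf_counter() + duration_s
    iterations = 0
    while time.perf_counter() < deadline:
        iterations += 1
    return iterations


def throughput_ratio(background_fn: Callable[[], None], duration_s: float = 0.2) -> float:
    """Main-thread throughput while background_fn runs, relative to running alone.

    A background function that holds the GIL pushes the ratio towards 0.5 or
    below; one that releases the GIL should leave it close to 1.0.
    """
    baseline = max(count_iterations(duration_s), 1)  # uncontended reference
    worker = threading.Thread(target=background_fn)
    worker.start()
    contended = count_iterations(duration_s)  # same loop, now competing
    worker.join()
    return contended / baseline
```

With the extension built, one would pass something like `lambda: kk_cpu_ops.karmarkar_karp(workloads, 8, True)` as `background_fn` and compare the ratio against the 50 % threshold quoted above. Timing-based checks like this are sensitive to machine load, so the threshold needs some slack.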
1 parent 1782d6d commit 05a513d

4 files changed

Lines changed: 554 additions & 57 deletions

Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
+// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+//
+// Karmarkar-Karp k-way partitioning in C++: a drop-in replacement for the
+// pure-Python implementation in `partitioner.py`. The whole compute path
+// releases the GIL so the main Python thread can keep submitting CUDA
+// kernels while the algorithm runs in a background ThreadPoolExecutor.
+//
+// Output is bit-for-bit identical to the Python version (same tie-breaking
+// rules) so it can be swapped in without changing downstream behaviour.
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include <algorithm>
+#include <cstdint>
+#include <functional>
+#include <stdexcept>
+#include <utility>
+#include <vector>
+
+namespace py = pybind11;
+
+namespace {
+
+struct Set {
+  int64_t sum = 0;
+  std::vector<std::pair<int64_t, int64_t>> items;  // (idx, val)
+
+  void add(int64_t idx, int64_t val) {
+    items.emplace_back(idx, val);
+    sum += val;
+  }
+
+  void merge_from(Set& other) {
+    items.reserve(items.size() + other.items.size());
+    for (auto& it : other.items) {
+      items.push_back(it);
+      sum += it.second;
+    }
+  }
+
+  // Matches Python `Set.__lt__`:
+  //   if sum != other.sum: return sum < other.sum
+  //   if len(items) != len(other.items): return len(items) < len(other.items)
+  //   return items < other.items  # lexicographic
+  bool operator<(const Set& other) const {
+    if (sum != other.sum) return sum < other.sum;
+    if (items.size() != other.items.size())
+      return items.size() < other.items.size();
+    return items < other.items;
+  }
+  bool operator>(const Set& other) const { return other < *this; }
+};
+
+struct State {
+  int k;
+  std::vector<Set> sets;  // maintained in *descending* order (sets[0] largest)
+
+  explicit State(int k_) : k(k_), sets(k_) {}
+
+  // ``items`` has length in [1, k]; element i goes into sets[i] (matching
+  // Python init), then sets are sorted descending.
+  void init_from(const std::vector<std::pair<int64_t, int64_t>>& items) {
+    for (size_t i = 0; i < items.size(); ++i) {
+      sets[i].add(items[i].first, items[i].second);
+    }
+    std::sort(sets.begin(), sets.end(), std::greater<Set>());
+  }
+
+  // Python `merge`: pair sets[i] ↔ other.sets[k-1-i], then resort descending.
+  void merge_with(State& other) {
+    for (int i = 0; i < k; ++i) {
+      sets[i].merge_from(other.sets[k - 1 - i]);
+    }
+    std::sort(sets.begin(), sets.end(), std::greater<Set>());
+  }
+
+  int64_t spread() const { return sets.front().sum - sets.back().sum; }
+
+  // Heap ordering. Python uses a min-heap (`heapq`) with `State.__lt__`
+  // flipped so the state with the LARGEST spread is popped first:
+  //   if spread != other.spread: return spread > other.spread
+  //   return sets[0] > other.sets[0]
+  //
+  // ``std::priority_queue`` / ``std::push_heap`` give a max-heap based on
+  // ``operator<``: the element where ``a < b`` is true for every other ``b``
+  // gets popped LAST. So define ``operator<`` such that "smaller" means
+  // "lower priority" (popped later), which means we want LARGER spread
+  // (and, on tie, larger ``sets[0]``) to compare as GREATER.
+  bool operator<(const State& other) const {
+    const int64_t s0 = spread();
+    const int64_t s1 = other.spread();
+    if (s0 != s1) return s0 < s1;
+    return sets.front() < other.sets.front();
+  }
+};
+
+std::vector<std::vector<int64_t>> karmarkar_karp_cpp(
+    std::vector<int64_t> workloads,
+    int k_partitions,
+    bool equal_size) {
+  // Release the GIL for the entire compute. ``workloads`` was already
+  // converted from the Python list into a std::vector by pybind11's STL
+  // caster before this point, so we do not touch any Python object until
+  // we return.
+  py::gil_scoped_release release;
+
+  if (k_partitions <= 0) {
+    throw std::invalid_argument("k_partitions must be > 0");
+  }
+  const size_t n = workloads.size();
+  if (equal_size && (n % static_cast<size_t>(k_partitions) != 0)) {
+    throw std::invalid_argument(
+        "len(workloads) must be divisible by k_partitions when equal_size=True");
+  }
+  if (n == 0) {
+    return std::vector<std::vector<int64_t>>(k_partitions);
+  }
+
+  // Match Python's ``sorted([(workload, i) for i, workload in enumerate(workloads)])``,
+  // i.e. ascending by (workload, idx). std::pair<int64_t,int64_t>::operator< is
+  // lexicographic, so a plain std::sort on (workload, idx) does it.
+  std::vector<std::pair<int64_t, int64_t>> sorted_workloads;
+  sorted_workloads.reserve(n);
+  for (size_t i = 0; i < n; ++i) {
+    sorted_workloads.emplace_back(workloads[i], static_cast<int64_t>(i));
+  }
+  std::sort(sorted_workloads.begin(), sorted_workloads.end());
+
+  // Build initial heap of States.
+  std::vector<State> heap;
+  heap.reserve(equal_size ? n / k_partitions : n);
+
+  if (equal_size) {
+    std::vector<std::pair<int64_t, int64_t>> group;
+    group.reserve(k_partitions);
+    for (size_t off = 0; off < n; off += k_partitions) {
+      group.clear();
+      for (int i = 0; i < k_partitions; ++i) {
+        const auto& [workload, idx] = sorted_workloads[off + i];
+        // Python: items.append((idx, workload)) (note: (idx, workload), not (workload, idx))
+        group.emplace_back(idx, workload);
+      }
+      State s(k_partitions);
+      s.init_from(group);
+      heap.push_back(std::move(s));
+    }
+  } else {
+    std::vector<std::pair<int64_t, int64_t>> single(1);
+    for (const auto& [workload, idx] : sorted_workloads) {
+      single[0] = {idx, workload};
+      State s(k_partitions);
+      s.init_from(single);
+      heap.push_back(std::move(s));
+    }
+  }
+  std::make_heap(heap.begin(), heap.end());
+
+  while (heap.size() > 1) {
+    std::pop_heap(heap.begin(), heap.end());
+    State s0 = std::move(heap.back());
+    heap.pop_back();
+
+    std::pop_heap(heap.begin(), heap.end());
+    State s1 = std::move(heap.back());
+    heap.pop_back();
+
+    s0.merge_with(s1);
+    heap.push_back(std::move(s0));
+    std::push_heap(heap.begin(), heap.end());
+  }
+
+  // Extract partitions from the surviving state.
+  State& final_state = heap.front();
+  std::vector<std::vector<int64_t>> partitions(k_partitions);
+  for (int i = 0; i < k_partitions; ++i) {
+    auto& src = final_state.sets[i].items;
+    auto& dst = partitions[i];
+    dst.reserve(src.size());
+    for (const auto& [idx, _val] : src) {
+      dst.push_back(idx);
+    }
+  }
+  return partitions;
+}
+
+}  // namespace
+
+PYBIND11_MODULE(kk_cpu_ops, m) {
+  m.doc() =
+      "C++ Karmarkar-Karp k-way partitioning. Releases the GIL during compute "
+      "so the main Python thread can keep submitting CUDA kernels.";
+  m.def(
+      "karmarkar_karp",
+      &karmarkar_karp_cpp,
+      py::arg("workloads"),
+      py::arg("k_partitions"),
+      py::arg("equal_size"),
+      "Identical output to commons.perf_model.partitioner.karmarkar_karp "
+      "(same tie-breaking rules), but with the GIL released for the entire "
+      "compute.");
+}
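For readers who want to see the ordering rules from the comments above in executable form, here is a compact Python sketch of the same scheme. It is written from the comparison rules quoted in this file, not copied from partitioner.py; the names `kk_reference`, `_Set`, and `_State` are hypothetical, and parity with either shipped implementation is not claimed.

```python
import heapq
from typing import List, Tuple

Item = Tuple[int, int]  # (idx, val), same layout as Set::items above


class _Set:
    def __init__(self) -> None:
        self.sum = 0
        self.items: List[Item] = []

    def add(self, idx: int, val: int) -> None:
        self.items.append((idx, val))
        self.sum += val

    def key(self):
        # Ordering quoted for Set.__lt__: sum, then len(items), then items.
        return (self.sum, len(self.items), self.items)


class _State:
    def __init__(self, items: List[Item], k: int) -> None:
        self.k = k
        self.sets = [_Set() for _ in range(k)]
        for i, (idx, val) in enumerate(items):
            self.sets[i].add(idx, val)
        self.sets.sort(key=_Set.key, reverse=True)  # descending, sets[0] largest

    @property
    def spread(self) -> int:
        return self.sets[0].sum - self.sets[-1].sum

    def merge(self, other: "_State") -> None:
        # Pair our i-th largest set with the other state's i-th smallest.
        for i in range(self.k):
            for idx, val in other.sets[self.k - 1 - i].items:
                self.sets[i].add(idx, val)
        self.sets.sort(key=_Set.key, reverse=True)

    def __lt__(self, other: "_State") -> bool:
        # Flipped on purpose: heapq is a min-heap, largest spread pops first.
        if self.spread != other.spread:
            return self.spread > other.spread
        return self.sets[0].key() > other.sets[0].key()


def kk_reference(workloads: List[int], k: int, equal_size: bool = False) -> List[List[int]]:
    if not workloads:
        return [[] for _ in range(k)]
    ordered = sorted((w, i) for i, w in enumerate(workloads))
    step = k if equal_size else 1  # equal_size seeds each state with k items
    heap = [
        _State([(i, w) for w, i in ordered[off:off + step]], k)
        for off in range(0, len(ordered), step)
    ]
    heapq.heapify(heap)
    while len(heap) > 1:
        a = heapq.heappop(heap)
        b = heapq.heappop(heap)
        a.merge(b)
        heapq.heappush(heap, a)
    return [[idx for idx, _ in s.items] for s in heap[0].sets]
```

For example, `kk_reference([8, 7, 6, 5, 4], 2)` splits the workloads into groups summing to 16 and 14. The sketch does not validate that `len(workloads)` is divisible by `k` when `equal_size=True`; the C++ function above raises in that case.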

examples/commons/perf_model/partitioner.py

Lines changed: 66 additions & 18 deletions
@@ -27,6 +27,7 @@
 # limitations under the License.
 
 import heapq
+import os
 from typing import Any, List, Tuple, Union
 
 import numpy as np
@@ -38,10 +39,74 @@
 Tensor = None
 nvtx = None
 
+# Optional C++ accelerator. Same output as the Python implementation but
+# releases the GIL for the entire compute, so the main thread can keep
+# submitting CUDA kernels while KK runs in a background ThreadPoolExecutor.
+# Set ``KK_FORCE_PYTHON=1`` to bypass the C++ path (useful for parity tests).
+#
+# Resolution order:
+#   1. Honour ``KK_FORCE_PYTHON=1`` → no native module.
+#   2. Top-level import: the location used by ``python setup.py install``
+#      inside the container (``/usr/local/lib/.../dist-packages``).
+#   3. Sibling .so next to the ``perf_model`` package: the location used by
+#      ``python setup.py build_ext --inplace`` during dev iteration.
+_FORCE_PYTHON = os.environ.get("KK_FORCE_PYTHON", "0") == "1"
+_kk_cpu_ops = None
+if not _FORCE_PYTHON:
+    try:
+        import kk_cpu_ops as _kk_cpu_ops  # type: ignore[import-not-found,no-redef]
+    except ImportError:
+        import glob as _glob
+        import importlib.util as _importlib_util
+
+        _so_glob = os.path.join(
+            os.path.dirname(os.path.dirname(__file__)),
+            "kk_cpu_ops*.so",
+        )
+        _matches = sorted(_glob.glob(_so_glob))
+        if _matches:
+            _spec = _importlib_util.spec_from_file_location("kk_cpu_ops", _matches[0])
+            if _spec is not None and _spec.loader is not None:
+                _kk_cpu_ops = _importlib_util.module_from_spec(_spec)
+                _spec.loader.exec_module(_kk_cpu_ops)
+
 
 def karmarkar_karp(
     workloads: Union[np.ndarray, List[int], Tensor], k_partitions: int, equal_size: bool
 ):
+    """K-way load-balanced partitioning via Karmarkar-Karp.
+
+    Returns ``k_partitions`` lists of original indices. When the C++ accelerator
+    ``kk_cpu_ops`` is importable, the heavy heap traversal runs without the
+    GIL; otherwise the pure-Python fallback below is used. Output is
+    bit-identical between the two paths (same tie-breaking).
+    """
+    if nvtx is not None:
+        nvtx.range_push("karmarkar_karp")
+    try:
+        # Normalize to a plain Python list of ints. Tensors / ndarrays both
+        # have ``.tolist()``; built-in lists do not, so a hasattr check picks
+        # the right branch.
+        if hasattr(workloads, "tolist"):
+            workloads = workloads.tolist()
+
+        if _kk_cpu_ops is not None:
+            partitions = _kk_cpu_ops.karmarkar_karp(workloads, k_partitions, equal_size)
+        else:
+            partitions = _karmarkar_karp_python(workloads, k_partitions, equal_size)
+
+        if equal_size:
+            for partition in partitions:
+                assert len(partition) * k_partitions == len(
+                    workloads
+                ), f"{len(partition)} * {k_partitions} != {len(workloads)}"
+        return partitions
+    finally:
+        if nvtx is not None:
+            nvtx.range_pop()  # karmarkar_karp
+
+
+def _karmarkar_karp_python(workloads: List[int], k_partitions: int, equal_size: bool):
     # see: https://en.wikipedia.org/wiki/Largest_differencing_method
     class Set:
         def __init__(self) -> None:
@@ -114,14 +179,6 @@ def __repr__(self) -> str:
             repr_str += "]"
             return repr_str
 
-    if nvtx is not None:
-        nvtx.range_push("karmarkar_karp")
-
-    workloads = (
-        workloads.tolist()
-        if isinstance(workloads, Tensor) and Tensor is not None
-        else workloads
-    )
     sorted_workloads = sorted([(workload, i) for i, workload in enumerate(workloads)])
     states_pq: List[Any] = []
     if equal_size:
@@ -145,16 +202,7 @@ def __repr__(self) -> str:
         state0.merge(state1)
         heapq.heappush(states_pq, state0)
 
-    final_state = states_pq[0]
-    partitions = final_state.get_partitions()
-    if equal_size:
-        for i, partition in enumerate(partitions):
-            assert len(partition) * k_partitions == len(
-                workloads
-            ), f"{len(partition)} * {k_partitions} != {len(workloads)}"
-    if nvtx is not None:
-        nvtx.range_pop()  # karmarkar_karp
-    return partitions
+    return states_pq[0].get_partitions()
 
 
 if __name__ == "__main__":
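The resolution order added to partitioner.py (environment override, then top-level import, then a sibling .so, then fallback) can also be expressed as a reusable helper. The sketch below is an illustration under stated assumptions: the function name `load_optional_extension` and its parameters are hypothetical and do not exist in the repository, which performs these steps inline at module import time.

```python
import glob
import importlib
import importlib.util
import os
from types import ModuleType
from typing import Optional


def load_optional_extension(name: str, search_dir: str, force_env: str) -> Optional[ModuleType]:
    """Return the native module, or None so the caller uses its Python fallback."""
    # 1. Escape hatch: an environment variable set to "1" disables the native path.
    if os.environ.get(force_env, "0") == "1":
        return None
    # 2. Installed layout: the module is importable by its top-level name.
    try:
        return importlib.import_module(name)
    except ImportError:
        pass
    # 3. In-place build: look for a shared object in search_dir.
    matches = sorted(glob.glob(os.path.join(search_dir, f"{name}*.so")))
    if not matches:
        return None
    spec = importlib.util.spec_from_file_location(name, matches[0])
    if spec is None or spec.loader is None:
        return None
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
```

A caller would then branch on the result, for instance `ops = load_optional_extension("kk_cpu_ops", some_dir, "KK_FORCE_PYTHON")` followed by `if ops is not None: ...`. Returning None rather than raising keeps the pure-Python path as the default when nothing native is available.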
