Skip to content

Commit 0b2b6b7

Browse files
alexreinkingGitHub Copilot
andauthored
Add __iter__ to RDom bindings (#8861)
* Add __iter__ to RDom Python bindings Previously, iterating over an RDom with list() or a for loop would fail with a KeyError. While __len__ and __getitem__ were defined, the Python iterator protocol requires __iter__ to return an iterator object. This adds a lightweight RDomIterator helper and exposes it via __iter__, allowing natural iteration: r = hl.RDom([(0, 10), (0, 20)]) for v in r: ... list(r) # returns [RVar, RVar] Adds test coverage in the existing rdom.py correctness test. * Fix Python test label propagation in AddPythonTest.cmake The LABEL argument was not being propagated to set_tests_properties, causing -L filters to not work correctly. Now both 'python' and the specific label (e.g., 'python_correctness') are applied to each test. Co-Authored-By: GitHub Copilot <github-copilot@users.noreply.github.com>
1 parent 1c38f3f commit 0b2b6b7

3 files changed

Lines changed: 44 additions & 3 deletions

File tree

python_bindings/cmake/AddPythonTest.cmake

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@ function(add_python_test)
1313
cmake_path(GET ARG_FILE STEM test_name)
1414
set(test_name "${ARG_LABEL}_${test_name}")
1515

16+
set(test_labels python ${ARG_LABEL})
17+
1618
add_test(
1719
NAME "${test_name}"
1820
COMMAND ${Halide_PYTHON_LAUNCHER} "$<TARGET_FILE:Python::Interpreter>" "$<SHELL_PATH:${CMAKE_CURRENT_SOURCE_DIR}/${ARG_FILE}>" ${ARG_TEST_ARGS}
1921
)
2022
set_tests_properties(
2123
"${test_name}"
2224
PROPERTIES
23-
LABELS "python"
25+
LABELS "${test_labels}"
2426
ENVIRONMENT "${ARG_ENVIRONMENT}"
2527
ENVIRONMENT_MODIFICATION "${ARG_PYTHONPATH}"
2628
SKIP_REGULAR_EXPRESSION "\\[SKIP\\]"

python_bindings/src/halide/halide_/PyRDom.cpp

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,33 @@ void define_rvar(py::module &m) {
3030
void define_rdom(py::module &m) {
3131
define_rvar(m);
3232

33+
// A small iterator wrapper to expose RDom iteration to Python.
34+
// It holds a copy of the RDom and an index, and implements
35+
// the iterator protocol (__iter__ and __next__).
36+
struct RDomIterator {
37+
RDom rd;
38+
int idx = 0;
39+
RDomIterator() = default;
40+
RDomIterator(const RDom &r) : rd(r) {
41+
}
42+
RVar next() {
43+
if (idx >= rd.dimensions()) {
44+
throw py::stop_iteration();
45+
}
46+
return rd[idx++];
47+
}
48+
RDomIterator &iter() {
49+
return *this;
50+
}
51+
};
52+
53+
// Expose the iterator type to Python so we can return it from __iter__.
54+
py::class_<RDomIterator>(m, "_RDomIterator")
55+
.def(py::init<>())
56+
.def(py::init<const RDom &>())
57+
.def("__iter__", &RDomIterator::iter)
58+
.def("__next__", &RDomIterator::next);
59+
3360
auto rdom_class =
3461
py::class_<RDom>(m, "RDom")
3562
.def(py::init<>())
@@ -41,12 +68,13 @@ void define_rdom(py::module &m) {
4168
.def("same_as", &RDom::same_as)
4269
.def("dimensions", &RDom::dimensions)
4370
.def("__len__", &RDom::dimensions)
71+
.def("__iter__", [](const RDom &r) { return RDomIterator(r); }, py::keep_alive<0, 1>())
4472
.def("where", &RDom::where, py::arg("predicate"))
4573
.def("__getitem__", [](RDom &r, const int i) -> RVar {
4674
if (i < 0 || i >= r.dimensions()) {
4775
throw pybind11::key_error();
4876
}
49-
return r[i];
77+
return r[i]; //
5078
})
5179
.def_readonly("x", &RDom::x)
5280
.def_readonly("y", &RDom::y)
@@ -55,7 +83,7 @@ void define_rdom(py::module &m) {
5583
.def("__repr__", [](const RDom &r) -> std::string {
5684
std::ostringstream o;
5785
o << "<halide.RDom " << r << ">";
58-
return o.str();
86+
return o.str(); //
5987
});
6088

6189
add_binary_operators(rdom_class);

python_bindings/test/correctness/rdom.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,16 @@ def test_rdom():
4646
return 0
4747

4848

49+
def test_rdom_iter():
50+
# Verify that RDom is iterable and yields RVars matching index access
51+
r = hl.RDom([(0, 10), (0, 20)], "r")
52+
it = list(r)
53+
assert len(it) == len(r)
54+
assert all(hasattr(v, "min") and hasattr(v, "extent") for v in it)
55+
for i in range(len(r)):
56+
assert str(it[i]) == str(r[i])
57+
58+
4959
def test_implicit_pure_definition():
5060
a = np.random.ranf((2, 3)).astype(np.float32)
5161
expected = a.sum(axis=1)
@@ -67,3 +77,4 @@ def test_implicit_pure_definition():
6777
if __name__ == "__main__":
6878
test_rdom()
6979
test_implicit_pure_definition()
80+
test_rdom_iter()

0 commit comments

Comments
 (0)