Skip to content

Commit 1c38f3f

Browse files
authored
Fix the implementation of the select() binding (#8862)
The new implementation corrects the select() grammar and gives better error messages when nested expressions (inside tuples) fail to cast / convert to the appropriate type. Also add human-friendly __str__ overloads to Expr and Tuple
1 parent 1b9be55 commit 1c38f3f

8 files changed

Lines changed: 108 additions & 86 deletions

File tree

python_bindings/src/halide/halide_/PyExpr.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ void define_expr(py::module &m) {
5757
o << "<undefined halide.Expr>";
5858
}
5959
return o.str();
60+
})
61+
.def("__str__", [](const Expr &e) -> std::string {
62+
std::ostringstream o;
63+
o << e;
64+
return o.str();
6065
});
6166

6267
add_binary_operators(expr_class);

python_bindings/src/halide/halide_/PyIROperator.cpp

Lines changed: 54 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,59 @@ namespace Halide {
1818

1919
namespace PythonBindings {
2020

21+
namespace {
22+
23+
bool is_expr(const py::handle &obj) {
24+
try {
25+
(void)obj.cast<Expr>(); // Check if casting succeeds
26+
return true;
27+
} catch (const py::cast_error &) {
28+
return false;
29+
}
30+
}
31+
32+
template<typename T>
33+
T cast_arg(const py::handle &arg) {
34+
try {
35+
return arg.cast<T>();
36+
} catch (const py::cast_error &) {
37+
_halide_user_error
38+
<< "select(): Expected " << py::str(py::type::of<T>().attr("__name__"))
39+
<< " but got " << py::str(arg.get_type().attr("__name__")) << ": "
40+
<< py::str(arg);
41+
}
42+
}
43+
44+
template<typename TCond, typename TVal>
45+
py::object py_select_reduce(const py::args &args) {
46+
auto false_case = cast_arg<TVal>(args[args.size() - 1]);
47+
for (size_t pos = args.size() - 1; pos >= 2; pos -= 2) {
48+
auto true_case = cast_arg<TVal>(args[pos - 1]);
49+
auto condition = cast_arg<TCond>(args[pos - 2]);
50+
false_case = select(condition, true_case, false_case);
51+
}
52+
return py::cast(false_case);
53+
}
54+
55+
py::object py_select(const py::args &args) {
56+
if (args.size() < 3) {
57+
throw py::value_error("select() must have at least 3 arguments");
58+
}
59+
60+
if (args.size() % 2 != 1) {
61+
throw py::value_error("select() must have an odd number of arguments");
62+
}
63+
64+
if (is_expr(args[0])) { // If the condition is an Expr, then ...
65+
return is_expr(args[1]) ? // ... we need to check the value's kind.
66+
py_select_reduce<Expr, Expr>(args) :
67+
py_select_reduce<Expr, Tuple>(args);
68+
}
69+
return py_select_reduce<Tuple, Tuple>(args); // Otherwise, the value must be a tuple, too.
70+
}
71+
72+
} // namespace
73+
2174
void define_operators(py::module &m) {
2275
m.def("max", [](const py::args &args) -> Expr {
2376
if (args.size() < 2) {
@@ -48,85 +101,7 @@ void define_operators(py::module &m) {
48101
m.def("abs", &abs);
49102
m.def("absd", &absd);
50103

51-
m.def("select", [](const py::args &args) -> py::object {
52-
if (args.size() < 3) {
53-
throw py::value_error("select() must have at least 3 arguments");
54-
}
55-
if ((args.size() % 2) == 0) {
56-
throw py::value_error("select() must have an odd number of arguments");
57-
}
58-
59-
// Tricky set of options here:
60-
//
61-
// - (Expr, Expr, Expr, [Expr, Expr...]) -> Expr
62-
// - (Expr, Tuple, Tuple, [Tuple, Tuple...]) -> Tuple [Tuples must be same arity]
63-
// - (Tuple, Tuple, Tuple, [Tuple, Tuple...]) -> Tuple [Tuples must be same arity]
64-
//
65-
// It's made trickier by the fact that it's hard to do a reliable "is-a" check for Tuple here,
66-
// so we'll do the slow-but-reliable approach of just trying to cast to Tuple and catching
67-
// exceptions.
68-
69-
std::string tuple_error_msg;
70-
try {
71-
int pos = (int)args.size() - 1;
72-
Tuple false_value = args[pos--].cast<Tuple>();
73-
bool has_tuple_cond = false;
74-
bool has_expr_cond = false;
75-
while (pos > 0) {
76-
Tuple true_value = args[pos--].cast<Tuple>();
77-
// Note that 'condition' can be either Expr or Tuple, but must be consistent across all
78-
py::object py_cond = args[pos--];
79-
Expr expr_cond;
80-
Tuple tuple_cond(expr_cond);
81-
try {
82-
tuple_cond = py_cond.cast<Tuple>();
83-
has_tuple_cond = true;
84-
} catch (...) {
85-
expr_cond = py_cond.cast<Expr>();
86-
has_expr_cond = true;
87-
}
88-
89-
if (has_tuple_cond && has_expr_cond) {
90-
// We don't want to throw an error here, since the catch(...) would catch it,
91-
// and it would be hard to distinguish from other errors. Just set the string here
92-
// and jump to the error reporter outside the catch.
93-
tuple_error_msg = "select() on Tuples may not mix Expr and Tuple for the condition elements.";
94-
goto handle_tuple_error;
95-
}
96-
97-
if (expr_cond.defined()) {
98-
false_value = select(expr_cond, true_value, false_value);
99-
} else {
100-
if (tuple_cond.size() != true_value.size() || true_value.size() != false_value.size()) {
101-
// We don't want to throw an error here, since the catch(...) would catch it,
102-
// and it would be hard to distinguish from other errors. Just set the string here
103-
// and jump to the error reporter outside the catch.
104-
tuple_error_msg = "select() on Tuples requires all Tuples to have identical sizes.";
105-
goto handle_tuple_error;
106-
}
107-
false_value = select(tuple_cond, true_value, false_value);
108-
}
109-
}
110-
return to_python_tuple(false_value);
111-
112-
} catch (...) {
113-
// fall thru and try the Expr case
114-
}
115-
116-
handle_tuple_error:
117-
if (!tuple_error_msg.empty()) {
118-
_halide_user_assert(false) << tuple_error_msg;
119-
}
120-
121-
int pos = (int)args.size() - 1;
122-
Expr false_expr_value = args[pos--].cast<Expr>();
123-
while (pos > 0) {
124-
Expr true_expr_value = args[pos--].cast<Expr>();
125-
Expr condition_expr = args[pos--].cast<Expr>();
126-
false_expr_value = select(condition_expr, true_expr_value, false_expr_value);
127-
}
128-
return py::cast(false_expr_value);
129-
});
104+
m.def("select", py_select);
130105

131106
m.def("mux", static_cast<Expr (*)(const Expr &, const std::vector<Expr> &)>(&mux));
132107
m.def("mux", static_cast<Expr (*)(const Expr &, const Tuple &)>(&mux));

python_bindings/src/halide/halide_/PyTuple.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ void define_tuple(py::module &m) {
6666
std::ostringstream o;
6767
o << "<halide.Tuple of size " << t.size() << ">";
6868
return o.str();
69+
})
70+
.def("__str__", [](const Tuple &t) -> std::string {
71+
std::ostringstream o;
72+
o << t;
73+
return o.str();
6974
});
7075

7176
py::implicitly_convertible<py::tuple, Tuple>();

python_bindings/test/correctness/iroperator.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import contextlib
2-
import halide as hl
3-
import sys
42
import io
3+
import math
4+
import sys
5+
6+
import halide as hl
57

68

79
# redirect_stdout() requires Python3, alas
@@ -59,6 +61,30 @@ def test_select():
5961
assert b[3] == 3
6062

6163

64+
def test_select_bad_argmax():
65+
x = hl.Var()
66+
f = hl.Func()
67+
f[x] = hl.sin(hl.f32(math.pi) * x / 16.0)
68+
69+
r = hl.RDom([(0, 10)])
70+
g = hl.Func()
71+
72+
g[()] = (0, f.type().min())
73+
try:
74+
g[()] = hl.select(f[r] > g[()][1], (f[r], r), g[()])
75+
except hl.HalideError as e:
76+
assert (
77+
"Error: The second and third arguments to a select do not have a matching type:"
78+
in str(e)
79+
)
80+
81+
g[()] = hl.select(f[r] > g[()][1], (r, f[r]), g[()])
82+
83+
idx, val = g.realize([])
84+
assert idx[()] == 8
85+
assert val[()] == 1.0
86+
87+
6288
def test_mux():
6389
c = hl.Var()
6490
f = hl.Func()
@@ -105,5 +131,6 @@ def test_minmax():
105131
test_print_expr()
106132
test_print_when()
107133
test_select()
134+
test_select_bad_argmax()
108135
test_mux()
109136
test_minmax()

python_bindings/test/correctness/memoize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def test_memoize():
1313
output[x] = f[x]
1414

1515
result = output.realize([3])
16-
assert list(result) == [1., 1., 1.]
16+
assert list(result) == [1.0, 1.0, 1.0]
1717

1818

1919
def main():

python_bindings/test/correctness/tuple_select.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,7 @@ def test_tuple_select():
7070
(x-100, y-200))
7171
# fmt: on
7272
except hl.HalideError as e:
73-
assert (
74-
"select() on Tuples may not mix Expr and Tuple for the condition elements."
75-
in str(e)
76-
)
73+
assert "select(): Expected Tuple but got Expr: ((x + y) < 100)" in str(e)
7774
else:
7875
assert False, "Did not see expected exception!"
7976

src/IRPrinter.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ ostream &operator<<(ostream &out, const Type &type) {
5555
}
5656
return out;
5757
}
58+
5859
ostream &operator<<(ostream &stream, const Expr &ir) {
5960
if (!ir.defined()) {
6061
stream << "(undefined)";
@@ -65,6 +66,14 @@ ostream &operator<<(ostream &stream, const Expr &ir) {
6566
return stream;
6667
}
6768

69+
ostream &operator<<(ostream &stream, const Tuple &ir) {
70+
stream << "(";
71+
for (size_t i = 0; i < ir.size(); i++) {
72+
stream << ir[i] << ", "; // keep the trailing comma
73+
}
74+
return stream << ")";
75+
}
76+
6877
ostream &operator<<(ostream &stream, const Buffer<> &buffer) {
6978
bool include_data = Internal::ends_with(buffer.name(), "_gpu_source_kernels");
7079
stream << "buffer " << buffer.name() << " = {";

src/IRPrinter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ namespace Halide {
2626
* human-readable form */
2727
std::ostream &operator<<(std::ostream &stream, const Expr &);
2828

29+
/** Emit a tuple on an output stream (such as std::cout) in
30+
* human-readable form */
31+
std::ostream &operator<<(std::ostream &stream, const Tuple &);
32+
2933
/** Emit a halide type on an output stream (such as std::cout) in
3034
* human-readable form */
3135
std::ostream &operator<<(std::ostream &stream, const Type &);

0 commit comments

Comments
 (0)