@@ -18,6 +18,59 @@ namespace Halide {
1818
1919namespace PythonBindings {
2020
21+ namespace {
22+
23+ bool is_expr (const py::handle &obj) {
24+ try {
25+ (void )obj.cast <Expr>(); // Check if casting succeeds
26+ return true ;
27+ } catch (const py::cast_error &) {
28+ return false ;
29+ }
30+ }
31+
32+ template <typename T>
33+ T cast_arg (const py::handle &arg) {
34+ try {
35+ return arg.cast <T>();
36+ } catch (const py::cast_error &) {
37+ _halide_user_error
38+ << " select(): Expected " << py::str (py::type::of<T>().attr (" __name__" ))
39+ << " but got " << py::str (arg.get_type ().attr (" __name__" )) << " : "
40+ << py::str (arg);
41+ }
42+ }
43+
44+ template <typename TCond, typename TVal>
45+ py::object py_select_reduce (const py::args &args) {
46+ auto false_case = cast_arg<TVal>(args[args.size () - 1 ]);
47+ for (size_t pos = args.size () - 1 ; pos >= 2 ; pos -= 2 ) {
48+ auto true_case = cast_arg<TVal>(args[pos - 1 ]);
49+ auto condition = cast_arg<TCond>(args[pos - 2 ]);
50+ false_case = select (condition, true_case, false_case);
51+ }
52+ return py::cast (false_case);
53+ }
54+
55+ py::object py_select (const py::args &args) {
56+ if (args.size () < 3 ) {
57+ throw py::value_error (" select() must have at least 3 arguments" );
58+ }
59+
60+ if (args.size () % 2 != 1 ) {
61+ throw py::value_error (" select() must have an odd number of arguments" );
62+ }
63+
64+ if (is_expr (args[0 ])) { // If the condition is an Expr, then ...
65+ return is_expr (args[1 ]) ? // ... we need to check the value's kind.
66+ py_select_reduce<Expr, Expr>(args) :
67+ py_select_reduce<Expr, Tuple>(args);
68+ }
69+ return py_select_reduce<Tuple, Tuple>(args); // Otherwise, the value must be a tuple, too.
70+ }
71+
72+ } // namespace
73+
2174void define_operators (py::module &m) {
2275 m.def (" max" , [](const py::args &args) -> Expr {
2376 if (args.size () < 2 ) {
@@ -48,85 +101,7 @@ void define_operators(py::module &m) {
48101 m.def (" abs" , &abs);
49102 m.def (" absd" , &absd);
50103
51- m.def (" select" , [](const py::args &args) -> py::object {
52- if (args.size () < 3 ) {
53- throw py::value_error (" select() must have at least 3 arguments" );
54- }
55- if ((args.size () % 2 ) == 0 ) {
56- throw py::value_error (" select() must have an odd number of arguments" );
57- }
58-
59- // Tricky set of options here:
60- //
61- // - (Expr, Expr, Expr, [Expr, Expr...]) -> Expr
62- // - (Expr, Tuple, Tuple, [Tuple, Tuple...]) -> Tuple [Tuples must be same arity]
63- // - (Tuple, Tuple, Tuple, [Tuple, Tuple...]) -> Tuple [Tuples must be same arity]
64- //
65- // It's made trickier by the fact that it's hard to do a reliable "is-a" check for Tuple here,
66- // so we'll do the slow-but-reliable approach of just trying to cast to Tuple and catching
67- // exceptions.
68-
69- std::string tuple_error_msg;
70- try {
71- int pos = (int )args.size () - 1 ;
72- Tuple false_value = args[pos--].cast <Tuple>();
73- bool has_tuple_cond = false ;
74- bool has_expr_cond = false ;
75- while (pos > 0 ) {
76- Tuple true_value = args[pos--].cast <Tuple>();
77- // Note that 'condition' can be either Expr or Tuple, but must be consistent across all
78- py::object py_cond = args[pos--];
79- Expr expr_cond;
80- Tuple tuple_cond (expr_cond);
81- try {
82- tuple_cond = py_cond.cast <Tuple>();
83- has_tuple_cond = true ;
84- } catch (...) {
85- expr_cond = py_cond.cast <Expr>();
86- has_expr_cond = true ;
87- }
88-
89- if (has_tuple_cond && has_expr_cond) {
90- // We don't want to throw an error here, since the catch(...) would catch it,
91- // and it would be hard to distinguish from other errors. Just set the string here
92- // and jump to the error reporter outside the catch.
93- tuple_error_msg = " select() on Tuples may not mix Expr and Tuple for the condition elements." ;
94- goto handle_tuple_error;
95- }
96-
97- if (expr_cond.defined ()) {
98- false_value = select (expr_cond, true_value, false_value);
99- } else {
100- if (tuple_cond.size () != true_value.size () || true_value.size () != false_value.size ()) {
101- // We don't want to throw an error here, since the catch(...) would catch it,
102- // and it would be hard to distinguish from other errors. Just set the string here
103- // and jump to the error reporter outside the catch.
104- tuple_error_msg = " select() on Tuples requires all Tuples to have identical sizes." ;
105- goto handle_tuple_error;
106- }
107- false_value = select (tuple_cond, true_value, false_value);
108- }
109- }
110- return to_python_tuple (false_value);
111-
112- } catch (...) {
113- // fall thru and try the Expr case
114- }
115-
116- handle_tuple_error:
117- if (!tuple_error_msg.empty ()) {
118- _halide_user_assert (false ) << tuple_error_msg;
119- }
120-
121- int pos = (int )args.size () - 1 ;
122- Expr false_expr_value = args[pos--].cast <Expr>();
123- while (pos > 0 ) {
124- Expr true_expr_value = args[pos--].cast <Expr>();
125- Expr condition_expr = args[pos--].cast <Expr>();
126- false_expr_value = select (condition_expr, true_expr_value, false_expr_value);
127- }
128- return py::cast (false_expr_value);
129- });
104+ m.def (" select" , py_select);
130105
131106 m.def (" mux" , static_cast <Expr (*)(const Expr &, const std::vector<Expr> &)>(&mux));
132107 m.def (" mux" , static_cast <Expr (*)(const Expr &, const Tuple &)>(&mux));
0 commit comments