|
1 | 1 | #include "PyGenerator.h" |
2 | 2 |
|
| 3 | +#include <optional> |
3 | 4 | #include <pybind11/embed.h> |
| 5 | +#include <string> |
| 6 | +#include <vector> |
4 | 7 |
|
5 | 8 | namespace Halide { |
6 | 9 | namespace PythonBindings { |
@@ -126,6 +129,43 @@ class PyGeneratorFactoryProvider : public GeneratorFactoryProvider { |
126 | 129 | } |
127 | 130 | }; |
128 | 131 |
|
| 132 | +// Returns a vector of mutable char * pointers corresponding to each string in `strs`. |
| 133 | +// `strs` must outlive the input and the pointers are not stable if the std::strings are mutated. |
| 134 | +// Arg (pun intended), this is all because generate_filter_main wants a mutable char **argv. |
| 135 | +std::vector<char *> get_mutable_c_strs(const std::vector<std::string> &strs) { |
| 136 | + std::vector<char *> c_strs; |
| 137 | + c_strs.reserve(strs.size()); |
| 138 | + for (const auto &s : strs) { |
| 139 | + c_strs.push_back(const_cast<char *>(s.c_str())); |
| 140 | + } |
| 141 | + return c_strs; |
| 142 | +} |
| 143 | + |
| 144 | +// PyBind11 treats `const std::optional<std::vector<std::string>> &` as an argument |
| 145 | +// that can be a list of strings or None. |
| 146 | +void main_impl(const std::optional<std::vector<std::string>> &argv) { |
| 147 | + // If the caller passed in args, use them. |
| 148 | + // Otherwise, parse them from sys.argv. |
| 149 | + // We need to make a copy in either case because of how PyBind11 translates |
| 150 | + // the input list to a optional vector of strings. |
| 151 | + std::vector<std::string> argv_copy; |
| 152 | + if (argv.has_value()) { |
| 153 | + argv_copy = *argv; |
| 154 | + } else { |
| 155 | + py::object py_sys_argv = py::module_::import("sys").attr("argv"); |
| 156 | + argv_copy = args_to_vector<std::string>(py_sys_argv); |
| 157 | + } |
| 158 | + |
| 159 | + std::vector<char *> mutable_argv = get_mutable_c_strs(argv_copy); |
| 160 | + const int result = Halide::Internal::generate_filter_main((int)mutable_argv.size(), mutable_argv.data(), PyGeneratorFactoryProvider()); |
| 161 | + if (result != 0) { |
| 162 | + // Some paths in generate_filter_main() will fail with user_error or similar (which throws an exception |
| 163 | + // due to how libHalide is built for Python), but some paths just return an error code. For consistency, |
| 164 | + // handle both by throwing a C++ exception, which pybind11 turns into a Python exception. |
| 165 | + throw std::runtime_error("Generator failed: " + std::to_string(result)); |
| 166 | + } |
| 167 | +} |
| 168 | + |
129 | 169 | } // namespace |
130 | 170 |
|
131 | 171 | void define_generator(py::module &m) { |
@@ -165,22 +205,7 @@ void define_generator(py::module &m) { |
165 | 205 | return o.str(); |
166 | 206 | }); |
167 | 207 |
|
168 | | - m.def("main", []() -> void { |
169 | | - py::object argv_object = py::module_::import("sys").attr("argv"); |
170 | | - std::vector<std::string> argv_vector = args_to_vector<std::string>(argv_object); |
171 | | - std::vector<char *> argv; |
172 | | - argv.reserve(argv_vector.size()); |
173 | | - for (auto &s : argv_vector) { |
174 | | - argv.push_back(const_cast<char *>(s.c_str())); |
175 | | - } |
176 | | - int result = Halide::Internal::generate_filter_main((int)argv.size(), argv.data(), PyGeneratorFactoryProvider()); |
177 | | - if (result != 0) { |
178 | | - // Some paths in generate_filter_main() will fail with user_error or similar (which throws an exception |
179 | | - // due to how libHalide is built for python), but some paths just return an error code, so |
180 | | - // be sure to handle both. |
181 | | - throw std::runtime_error("Generator failed: " + std::to_string(result)); |
182 | | - } |
183 | | - }); |
| 208 | + m.def("main", &main_impl, py::arg("argv") = py::none()); |
184 | 209 |
|
185 | 210 | m.def("_unique_name", []() -> std::string { |
186 | 211 | return ::Halide::Internal::unique_name('p'); |
|
0 commit comments