Skip to content

Commit 9bae1f1

Browse files
committed
Allows the caller to pass in argv to hl.main()
1 parent 59bcd4b commit 9bae1f1

1 file changed

Lines changed: 41 additions & 16 deletions

File tree

python_bindings/src/halide/halide_/PyGenerator.cpp

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#include "PyGenerator.h"
22

3+
#include <optional>
34
#include <pybind11/embed.h>
5+
#include <string>
6+
#include <vector>
47

58
namespace Halide {
69
namespace PythonBindings {
@@ -126,6 +129,43 @@ class PyGeneratorFactoryProvider : public GeneratorFactoryProvider {
126129
}
127130
};
128131

132+
// Returns a vector of mutable char * pointers corresponding to each string in `strs`.
133+
// `strs` must outlive the input and the pointers are not stable if the std::strings are mutated.
134+
// Arg (pun intended), this is all because generate_filter_main wants a mutable char **argv.
135+
std::vector<char *> get_mutable_c_strs(const std::vector<std::string> &strs) {
136+
std::vector<char *> c_strs;
137+
c_strs.reserve(strs.size());
138+
for (const auto &s : strs) {
139+
c_strs.push_back(const_cast<char *>(s.c_str()));
140+
}
141+
return c_strs;
142+
}
143+
144+
// PyBind11 treats `const std::optional<std::vector<std::string>> &` as an argument
145+
// that can be a list of strings or None.
146+
void main_impl(const std::optional<std::vector<std::string>> &argv) {
147+
// If the caller passed in args, use them.
148+
// Otherwise, parse them from sys.argv.
149+
// We need to make a copy in either case because of how PyBind11 translates
150+
// the input list to a optional vector of strings.
151+
std::vector<std::string> argv_copy;
152+
if (argv.has_value()) {
153+
argv_copy = *argv;
154+
} else {
155+
py::object py_sys_argv = py::module_::import("sys").attr("argv");
156+
argv_copy = args_to_vector<std::string>(py_sys_argv);
157+
}
158+
159+
std::vector<char *> mutable_argv = get_mutable_c_strs(argv_copy);
160+
const int result = Halide::Internal::generate_filter_main((int)mutable_argv.size(), mutable_argv.data(), PyGeneratorFactoryProvider());
161+
if (result != 0) {
162+
// Some paths in generate_filter_main() will fail with user_error or similar (which throws an exception
163+
// due to how libHalide is built for Python), but some paths just return an error code. For consistency,
164+
// handle both by throwing a C++ exception, which pybind11 turns into a Python exception.
165+
throw std::runtime_error("Generator failed: " + std::to_string(result));
166+
}
167+
}
168+
129169
} // namespace
130170

131171
void define_generator(py::module &m) {
@@ -165,22 +205,7 @@ void define_generator(py::module &m) {
165205
return o.str();
166206
});
167207

168-
m.def("main", []() -> void {
169-
py::object argv_object = py::module_::import("sys").attr("argv");
170-
std::vector<std::string> argv_vector = args_to_vector<std::string>(argv_object);
171-
std::vector<char *> argv;
172-
argv.reserve(argv_vector.size());
173-
for (auto &s : argv_vector) {
174-
argv.push_back(const_cast<char *>(s.c_str()));
175-
}
176-
int result = Halide::Internal::generate_filter_main((int)argv.size(), argv.data(), PyGeneratorFactoryProvider());
177-
if (result != 0) {
178-
// Some paths in generate_filter_main() will fail with user_error or similar (which throws an exception
179-
// due to how libHalide is built for python), but some paths just return an error code, so
180-
// be sure to handle both.
181-
throw std::runtime_error("Generator failed: " + std::to_string(result));
182-
}
183-
});
208+
m.def("main", &main_impl, py::arg("argv") = py::none());
184209

185210
m.def("_unique_name", []() -> std::string {
186211
return ::Halide::Internal::unique_name('p');

0 commit comments

Comments
 (0)