Skip to content
3 changes: 2 additions & 1 deletion python_bindings/src/halide/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ def install_dir():
active_generator_context,
alias,
funcs,
Generator,
generator,
main,
Generator,
Comment thread
alexreinking marked this conversation as resolved.
GeneratorParam,
InputBuffer,
InputScalar,
Expand Down
18 changes: 15 additions & 3 deletions python_bindings/src/halide/_generator_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from enum import Enum
from functools import total_ordering
from .halide_ import (
_,
Comment thread
alexreinking marked this conversation as resolved.
_generate_filter_main,
_unique_name,
_UnspecifiedType,
ArgInfo,
ArgInfoDirection,
ArgInfoKind,
Expand All @@ -21,9 +25,6 @@
Type,
UInt,
Var,
_,
_UnspecifiedType,
_unique_name,
)
from inspect import isclass
from typing import Any, Optional
Expand Down Expand Up @@ -892,3 +893,14 @@ def funcs(names: str) -> tuple[Func, ...]:
def vars(names: str) -> tuple[Var, ...]:
"""Given a space-delimited string, create a Var for each substring and return as a tuple."""
return tuple(Var(n) for n in names.split(" "))


def main(argv: Optional[list[str]] = None):
"""Entrypoint for invoking all registered generators.

Args:
argv: A list of command-line arguments to pass to the generator. If None, uses sys.argv.
"""
if argv is None:
argv = sys.argv
_generate_filter_main(argv)
54 changes: 28 additions & 26 deletions python_bindings/src/halide/halide_/PyGenerator.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "PyGenerator.h"

#include <pybind11/embed.h>
#include <string>
#include <vector>

namespace Halide {
namespace PythonBindings {
Expand All @@ -15,16 +17,6 @@ using Halide::Parameter;
using Halide::Internal::ArgInfoDirection;
using Halide::Internal::ArgInfoKind;

template<typename T>
Comment thread
alexreinking marked this conversation as resolved.
std::map<std::string, T> dict_to_map(const py::dict &dict) {
_halide_user_assert(!dict.is(py::none()));
std::map<std::string, T> m;
for (auto it : dict) {
m[it.first.cast<std::string>()] = it.second.cast<T>();
}
return m;
}

class PyGeneratorBase : public AbstractGenerator {
// The name declared in the Python function's decorator
const std::string name_;
Expand Down Expand Up @@ -165,22 +157,32 @@ void define_generator(py::module &m) {
return o.str();
});

m.def("main", []() -> void {
py::object argv_object = py::module_::import("sys").attr("argv");
std::vector<std::string> argv_vector = args_to_vector<std::string>(argv_object);
std::vector<char *> argv;
argv.reserve(argv_vector.size());
for (auto &s : argv_vector) {
argv.push_back(const_cast<char *>(s.c_str()));
}
int result = Halide::Internal::generate_filter_main((int)argv.size(), argv.data(), PyGeneratorFactoryProvider());
if (result != 0) {
// Some paths in generate_filter_main() will fail with user_error or similar (which throws an exception
// due to how libHalide is built for python), but some paths just return an error code, so
// be sure to handle both.
throw std::runtime_error("Generator failed: " + std::to_string(result));
}
});
m.def("_generate_filter_main", //
[](const std::vector<std::string> &arguments) -> void {
if (arguments.empty()) {
throw std::invalid_argument("No arguments provided to _generate_filter_main");
}

// POSIX requires argv to be mutable and null-terminated
std::vector<char *> argv;
argv.reserve(arguments.size() + 1);
for (const auto &s : arguments) {
argv.push_back(const_cast<char *>(s.c_str()));
}
argv.push_back(nullptr);

const int result = Halide::Internal::generate_filter_main(
static_cast<int>(argv.size()) - 1, argv.data(), PyGeneratorFactoryProvider());
if (result != 0) {
// Some paths in generate_filter_main() will fail with user_error
// or similar (which throws an exception due to how libHalide is
// built for Python), but other paths just return an error code.
// For consistency, handle both by throwing a C++ exception, which
// PyBind11 turns into a Python exception.
throw std::runtime_error("Generator failed: " + std::to_string(result));
} //
},
py::arg("argv"));

m.def("_unique_name", []() -> std::string {
return ::Halide::Internal::unique_name('p');
Expand Down
Loading