From 864a26ddd8f804fce80326e777be015de3e6199a Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sun, 15 Mar 2026 22:34:07 +0100 Subject: [PATCH 01/13] First version of the IRGraph printer for C++ code. --- Makefile | 2 + src/CMakeLists.txt | 2 + src/IRGraphCXXPrinter.cpp | 47 ++++++ src/IRGraphCXXPrinter.h | 298 ++++++++++++++++++++++++++++++++++++++ test/internal.cpp | 2 + 5 files changed, 351 insertions(+) create mode 100644 src/IRGraphCXXPrinter.cpp create mode 100644 src/IRGraphCXXPrinter.h diff --git a/Makefile b/Makefile index 7edddd719f81..1a30944417ca 100644 --- a/Makefile +++ b/Makefile @@ -529,6 +529,7 @@ SOURCE_FILES = \ Interval.cpp \ IR.cpp \ IREquality.cpp \ + IRGraphCXXPrinter.cpp \ IRMatch.cpp \ IRMutator.cpp \ IROperator.cpp \ @@ -732,6 +733,7 @@ HEADER_FILES = \ IntrusivePtr.h \ IR.h \ IREquality.h \ + IRGraphCXXPrinter.h \ IRMatch.h \ IRMutator.h \ IROperator.h \ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a373136025a9..6d69e1cd57f3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -139,6 +139,7 @@ target_sources( IntrusivePtr.h IR.h IREquality.h + IRGraphCXXPrinter.h IRMatch.h IRMutator.h IROperator.h @@ -318,6 +319,7 @@ target_sources( Interval.cpp IR.cpp IREquality.cpp + IRGraphCXXPrinter.cpp IRMatch.cpp IRMutator.cpp IROperator.cpp diff --git a/src/IRGraphCXXPrinter.cpp b/src/IRGraphCXXPrinter.cpp new file mode 100644 index 000000000000..875d844d7ee4 --- /dev/null +++ b/src/IRGraphCXXPrinter.cpp @@ -0,0 +1,47 @@ +#include "IRGraphCXXPrinter.h" + +#include "Expr.h" +#include "IR.h" +#include "IREquality.h" +#include "IROperator.h" + +namespace Halide { +namespace Internal { + +void IRGraphCXXPrinter::test() { + // This: + Expr e = Select::make(Mod::make(Ramp::make(10, 314, 8), Broadcast::make(10, 8)) < Variable::make(Int(32), "p"), Broadcast::make(40, 8) + Ramp::make(4, 8, 8), VectorReduce::make(VectorReduce::Add, Ramp::make(0, 1, 16), 8)); + + // Printed by: + IRGraphCXXPrinter p(std::cout); + p.print(e); + + // Prints: + Expr expr_0 = IntImm::make(Type(Type::Int, 32, 1), 10); + Expr expr_1 = IntImm::make(Type(Type::Int, 32, 1), 314); + Expr expr_2 = Ramp::make(expr_0, expr_1, 8); + Expr expr_3 = IntImm::make(Type(Type::Int, 32, 1), 10); + Expr expr_4 = Broadcast::make(expr_3, 8); + Expr expr_5 = Mod::make(expr_2, expr_4); + Expr expr_6 = Variable::make(Type(Type::Int, 32, 1), "p"); + Expr expr_7 = Broadcast::make(expr_6, 8); + Expr expr_8 = LT::make(expr_5, expr_7); + Expr expr_9 = IntImm::make(Type(Type::Int, 32, 1), 40); + Expr expr_10 = Broadcast::make(expr_9, 8); + Expr expr_11 = IntImm::make(Type(Type::Int, 32, 1), 4); + Expr expr_12 = IntImm::make(Type(Type::Int, 32, 1), 8); + Expr expr_13 = Ramp::make(expr_11, expr_12, 8); + Expr expr_14 = Add::make(expr_10, expr_13); + Expr expr_15 = IntImm::make(Type(Type::Int, 32, 1), 0); + Expr expr_16 = IntImm::make(Type(Type::Int, 32, 1), 1); + Expr expr_17 = Ramp::make(expr_15, expr_16, 16); + Expr expr_18 = VectorReduce::make(VectorReduce::Add, expr_17, 8); + Expr expr_19 = Select::make(expr_8, expr_14, expr_18); + + // Now let's see if it matches: + internal_assert(equal(expr_19, e)) << "Expressions don't match:\n\n" + << e << "\n\n" + << expr_19 << "\n"; +} +} // namespace Internal +} // namespace Halide diff --git a/src/IRGraphCXXPrinter.h b/src/IRGraphCXXPrinter.h new file mode 100644 index 000000000000..0c07d283f023 --- /dev/null +++ b/src/IRGraphCXXPrinter.h @@ -0,0 +1,298 @@ +#ifndef HALIDE_IRGRAPHCXXPRINTER_H +#define HALIDE_IRGRAPHCXXPRINTER_H + +#include +#include +#include +#include +#include +#include + +#include "Expr.h" +#include "IR.h" +#include "IRVisitor.h" + +namespace Halide { +namespace Internal { + +class IRGraphCXXPrinter : public IRVisitor { +public: + std::ostream &os; + + // Tracks visited nodes so we don't print them twice (handles the DAG structure) + std::map node_names; + int var_counter = 0; + + IRGraphCXXPrinter(std::ostream &os) : os(os) { + } + + void print(const Expr &e) { + if (e.defined()) { + e.accept(this); + } + } + + void print(const Stmt &s) { + if (s.defined()) { + s.accept(this); + } + } + +private: + // ========================================================================= + // ✨ CLEVER TEMPLATING ✨ + // This SFINAE trick checks if `T::make` can be invoked with `Args...`. + // It will trigger a static_assert if you forget an argument or pass the + // wrong field types, completely preventing generated code compile errors! + // ========================================================================= + template + static constexpr auto check_make_args(Args &&...args) + -> decltype(T::make(std::forward(args)...), std::true_type{}) { + return std::true_type{}; + } + + template + static constexpr std::false_type check_make_args(...) { + return std::false_type{}; + } + + // ========================================================================= + // ARGUMENT STRINGIFIERS + // These convert Halide objects into strings representing C++ code. + // ========================================================================= + + template + std::string to_cpp_arg(const T &x) { + if constexpr (std::is_arithmetic_v) { + return std::to_string(x); + } else { + internal_error << "Not supported to print"; + } + } + + template<> + std::string to_cpp_arg(const Expr &e) { + if (!e.defined()) { + return "Expr()"; + } + e.accept(this); // Visit dependencies first + return node_names.at(e.get()); + } + + std::string to_cpp_arg(const Stmt &s) { + if (!s.defined()) { + return "Stmt()"; + } + s.accept(this); // Visit dependencies first + return node_names.at(s.get()); + } + + std::string to_cpp_arg(const Range &r) { + r.min.accept(this); + r.extent.accept(this); // Visit dependencies first + return "Range(" + node_names.at(r.min.get()) + ", " + node_names.at(r.extent.get()) + ")"; + } + + std::string to_cpp_arg(const std::string &s) { + return "\"" + s + "\""; + } + + std::string to_cpp_arg(Type t) { + std::ostringstream oss; + oss << "Type(Type::" + << (t.is_int() ? "Int" : t.is_uint() ? "UInt" : + t.is_float() ? "Float" : + t.is_bfloat() ? "BFloat" : + "Handle") + << ", " << t.bits() << ", " << t.lanes() << ")"; + return oss.str(); + } + + std::string to_cpp_arg(ForType f) { + switch (f) { + case ForType::Serial: + return "ForType::Serial"; + case ForType::Parallel: + return "ForType::Parallel"; + case ForType::Vectorized: + return "ForType::Vectorized"; + case ForType::Unrolled: + return "ForType::Unrolled"; + case ForType::Extern: + return "ForType::Extern"; + case ForType::GPUBlock: + return "ForType::GPUBlock"; + case ForType::GPUThread: + return "ForType::GPUThread"; + case ForType::GPULane: + return "ForType::GPULane"; + default: + return "ForType::Serial"; + } + } + + std::string to_cpp_arg(const VectorReduce::Operator &op) { + switch (op) { + case VectorReduce::Add: + return "VectorReduce::Add"; + case VectorReduce::SaturatingAdd: + return "VectorReduce::SaturatingAdd"; + case VectorReduce::Mul: + return "VectorReduce::Mul"; + case VectorReduce::Min: + return "VectorReduce::Min"; + case VectorReduce::Max: + return "VectorReduce::Max"; + case VectorReduce::And: + return "VectorReduce::And"; + case VectorReduce::Or: + return "VectorReduce::Or"; + } + internal_error << "Invalid VectorReduce"; + } + + std::string to_cpp_arg(DeviceAPI api) { + return "DeviceAPI::" + std::to_string((int)api); // Or proper switch-case logic + } + + std::string to_cpp_arg(ModulusRemainder align) { + return "ModulusRemainder(" + std::to_string(align.modulus) + ", " + std::to_string(align.remainder) + ")"; + } + + std::string to_cpp_arg(const Parameter &p) { + internal_error << "Not supported to print Parameter"; + } + + template + std::string to_cpp_arg(const std::vector &vec) { + std::string res = "{"; + for (size_t i = 0; i < vec.size(); ++i) { + res += to_cpp_arg(vec[i]); + if (i + 1 < vec.size()) { + res += ", "; + } + } + res += "}"; + return res; + } + + // ========================================================================= + // CORE NODE EMITTER + // ========================================================================= + template + void emit_node(const char *node_type_str, const T *op, Args &&...args) { + // 1. Maintain DAG properties: if we've already generated it, skip. + if (node_names.count(op)) { + return; + } + + // 2. ✨ Check at our compile-time that the signature aligns exactly! + static_assert(decltype(check_make_args(std::forward(args)...))::value, + "Arguments extracted for printer do not match any T::make() signature! " + "Check your VISIT_NODE macro arguments."); + + // 3. Evaluate arguments post-order to build dependencies. + // (C++11 guarantees left-to-right evaluation in brace-init lists) + std::vector printed_args = {to_cpp_arg(args)...}; + + // 4. Generate the actual C++ code + bool is_expr = std::is_base_of_v; + std::string var_name = (is_expr ? "expr_" : "stmt_") + std::to_string(var_counter++); + + os << (is_expr ? "Expr " : "Stmt ") << var_name << " = " << node_type_str << "::make("; + for (size_t i = 0; i < printed_args.size(); ++i) { + os << printed_args[i] << (i + 1 == printed_args.size() ? "" : ", "); + } + os << ");\n"; + + node_names[op] = var_name; + } + +protected: +// ========================================================================= +// VISITOR OVERRIDES +// ========================================================================= + +// Macro handles mapping the IR node pointer to our clever template. +#define VISIT_NODE(NodeType, ...) \ + void visit(const NodeType *op) override { \ + IRVisitor::visit(op); \ + emit_node(#NodeType, op, __VA_ARGS__); \ + } + + // --- 1. Core / Primitive Values --- + VISIT_NODE(IntImm, op->type, op->value) + VISIT_NODE(UIntImm, op->type, op->value) + VISIT_NODE(FloatImm, op->type, op->value) + VISIT_NODE(StringImm, op->value) + + // --- 2. Variable & Broadcast --- + VISIT_NODE(Variable, op->type, op->name /*, op->image, op->param, op->reduction_domain */) + VISIT_NODE(Broadcast, op->value, op->lanes) + + // --- 3. Binary & Unary Math Nodes --- + VISIT_NODE(Add, op->a, op->b) + VISIT_NODE(Sub, op->a, op->b) + VISIT_NODE(Mod, op->a, op->b) + VISIT_NODE(Mul, op->a, op->b) + VISIT_NODE(Div, op->a, op->b) + VISIT_NODE(Min, op->a, op->b) + VISIT_NODE(Max, op->a, op->b) + VISIT_NODE(EQ, op->a, op->b) + VISIT_NODE(NE, op->a, op->b) + VISIT_NODE(LT, op->a, op->b) + VISIT_NODE(LE, op->a, op->b) + VISIT_NODE(GT, op->a, op->b) + VISIT_NODE(GE, op->a, op->b) + VISIT_NODE(And, op->a, op->b) + VISIT_NODE(Or, op->a, op->b) + VISIT_NODE(Not, op->a) + + // --- 4. Casts & Shuffles --- + VISIT_NODE(Cast, op->type, op->value) + VISIT_NODE(Reinterpret, op->type, op->value) + VISIT_NODE(Shuffle, op->vectors, op->indices) + + // --- 5. Complex Expressions --- + VISIT_NODE(Select, op->condition, op->true_value, op->false_value) + VISIT_NODE(Load, op->type, op->name, op->index, op->image, op->param, op->predicate, op->alignment) + VISIT_NODE(Ramp, op->base, op->stride, op->lanes) + VISIT_NODE(Call, op->type, op->name, op->args, op->call_type, op->func, op->value_index, op->image, op->param) + VISIT_NODE(Let, op->name, op->value, op->body) + VISIT_NODE(VectorReduce, op->op, op->value, op->type.lanes()) + + // --- 6. Core Statements --- + VISIT_NODE(LetStmt, op->name, op->value, op->body) + VISIT_NODE(AssertStmt, op->condition, op->message) + VISIT_NODE(Evaluate, op->value) + VISIT_NODE(Block, op->first, op->rest) + VISIT_NODE(IfThenElse, op->condition, op->then_case, op->else_case) + VISIT_NODE(For, op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, op->body) + + // --- 7. Memory / Buffer Operations --- + VISIT_NODE(Store, op->name, op->value, op->index, op->param, op->predicate, op->alignment) + VISIT_NODE(Provide, op->name, op->values, op->args, op->predicate) + VISIT_NODE(Allocate, op->name, op->type, op->memory_type, op->extents, op->condition, op->body, op->new_expr, op->free_function) + VISIT_NODE(Free, op->name) + VISIT_NODE(Realize, op->name, op->types, op->memory_type, op->bounds, op->condition, op->body) + VISIT_NODE(Prefetch, op->name, op->types, op->bounds, op->prefetch, op->condition, op->body) + VISIT_NODE(HoistedStorage, op->name, op->body) + + // --- 8. Concurrency & Sync --- + VISIT_NODE(ProducerConsumer, op->name, op->is_producer, op->body) + VISIT_NODE(Acquire, op->semaphore, op->count, op->body) + VISIT_NODE(Fork, op->first, op->rest) + VISIT_NODE(Atomic, op->producer_name, op->mutex_name, op->body) + +// Clean up macro +#undef VISIT_NODE + +public: + static void test(); +}; + +} // namespace Internal + +} // namespace Halide + +#endif // HALIDE_IRGRAPHCXXPRINTER_H diff --git a/test/internal.cpp b/test/internal.cpp index 08283fa9cf54..448a5960e79a 100644 --- a/test/internal.cpp +++ b/test/internal.cpp @@ -18,12 +18,14 @@ #include "Solve.h" #include "SpirvIR.h" #include "UniquifyVariableNames.h" +#include "IRGraphCXXPrinter.h" using namespace Halide; using namespace Halide::Internal; int main(int argc, const char **argv) { IRPrinter::test(); + IRGraphCXXPrinter::test(); CodeGen_C::test(); ir_equality_test(); bounds_test(); From a6716de1ddec0fd12c71e607953421974856965b Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sun, 15 Mar 2026 22:35:39 +0100 Subject: [PATCH 02/13] typo and clang violation --- src/IRGraphCXXPrinter.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/IRGraphCXXPrinter.cpp b/src/IRGraphCXXPrinter.cpp index 875d844d7ee4..ed1c9b34213a 100644 --- a/src/IRGraphCXXPrinter.cpp +++ b/src/IRGraphCXXPrinter.cpp @@ -41,7 +41,9 @@ void IRGraphCXXPrinter::test() { // Now let's see if it matches: internal_assert(equal(expr_19, e)) << "Expressions don't match:\n\n" << e << "\n\n" - << expr_19 << "\n"; + << expr_19 << "\n"; + + // Here is a bad typo for Alex who likes progamming. Above is a badly intented line. Two typos? } } // namespace Internal } // namespace Halide From 28aefd228401e4fca23e0bc970f8cdb5e3935b35 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sun, 15 Mar 2026 22:41:49 +0100 Subject: [PATCH 03/13] Try again From c0293615d7164a24bb0a30ca67b10c3412a8d0ad Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sun, 15 Mar 2026 23:18:07 +0100 Subject: [PATCH 04/13] Moved implementation of the IRGraphCXXPrinter to the cpp file. --- src/IRGraphCXXPrinter.cpp | 235 +++++++++++++++++++++++++++++- src/IRGraphCXXPrinter.h | 293 +++++++------------------------------- 2 files changed, 284 insertions(+), 244 deletions(-) diff --git a/src/IRGraphCXXPrinter.cpp b/src/IRGraphCXXPrinter.cpp index ed1c9b34213a..20ca73fd5979 100644 --- a/src/IRGraphCXXPrinter.cpp +++ b/src/IRGraphCXXPrinter.cpp @@ -8,6 +8,237 @@ namespace Halide { namespace Internal { +namespace { +// ========================================================================= +// ✨ CLEVER TEMPLATING ✨ +// This SFINAE trick checks if `T::make` can be invoked with `Args...`. +// It will trigger a static_assert if you forget an argument or pass the +// wrong field types, completely preventing generated code compile errors! +// ========================================================================= +template +static constexpr auto check_make_args(Args &&...args) + -> decltype(T::make(std::forward(args)...), std::true_type{}) { + return std::true_type{}; +} + +template +static constexpr std::false_type check_make_args(...) { + return std::false_type{}; +} + +} // namespace + +template +std::string IRGraphCXXPrinter::to_cpp_arg(const T &x) { + if constexpr (std::is_arithmetic_v) { + return std::to_string(x); + } else { + internal_error << "Not supported to print"; + } +} + +template<> +std::string IRGraphCXXPrinter::to_cpp_arg(const Expr &e) { + if (!e.defined()) { + return "Expr()"; + } + include(e); + return node_names.at(e.get()); +} + +template<> +std::string IRGraphCXXPrinter::to_cpp_arg(const Stmt &s) { + if (!s.defined()) { + return "Stmt()"; + } + include(s); + return node_names.at(s.get()); +} + +template<> +std::string IRGraphCXXPrinter::to_cpp_arg(const Range &r) { + include(r.min); + include(r.extent); + return "Range(" + node_names.at(r.min.get()) + ", " + node_names.at(r.extent.get()) + ")"; +} + +template<> +std::string IRGraphCXXPrinter::to_cpp_arg(const std::string &s) { + return "\"" + s + "\""; +} +template<> +std::string IRGraphCXXPrinter::to_cpp_arg(const ForType &f) { + switch (f) { + case ForType::Serial: + return "ForType::Serial"; + case ForType::Parallel: + return "ForType::Parallel"; + case ForType::Vectorized: + return "ForType::Vectorized"; + case ForType::Unrolled: + return "ForType::Unrolled"; + case ForType::Extern: + return "ForType::Extern"; + case ForType::GPUBlock: + return "ForType::GPUBlock"; + case ForType::GPUThread: + return "ForType::GPUThread"; + case ForType::GPULane: + return "ForType::GPULane"; + default: + return "ForType::Serial"; + } +} + +template<> +std::string IRGraphCXXPrinter::to_cpp_arg(const VectorReduce::Operator &op) { + switch (op) { + case VectorReduce::Add: + return "VectorReduce::Add"; + case VectorReduce::SaturatingAdd: + return "VectorReduce::SaturatingAdd"; + case VectorReduce::Mul: + return "VectorReduce::Mul"; + case VectorReduce::Min: + return "VectorReduce::Min"; + case VectorReduce::Max: + return "VectorReduce::Max"; + case VectorReduce::And: + return "VectorReduce::And"; + case VectorReduce::Or: + return "VectorReduce::Or"; + } + internal_error << "Invalid VectorReduce"; +} + +template<> +std::string IRGraphCXXPrinter::to_cpp_arg(const Type &t) { + std::ostringstream oss; + oss << "Type(Type::" + << (t.is_int() ? "Int" : t.is_uint() ? "UInt" : + t.is_float() ? "Float" : + t.is_bfloat() ? "BFloat" : + "Handle") + << ", " << t.bits() << ", " << t.lanes() << ")"; + return oss.str(); +} + +template<> +std::string IRGraphCXXPrinter::to_cpp_arg(const ModulusRemainder &align) { + return "ModulusRemainder(" + std::to_string(align.modulus) + ", " + std::to_string(align.remainder) + ")"; +} + +template +std::string IRGraphCXXPrinter::to_cpp_arg(const std::vector &vec) { + std::string res = "{"; + for (size_t i = 0; i < vec.size(); ++i) { + res += to_cpp_arg(vec[i]); + if (i + 1 < vec.size()) { + res += ", "; + } + } + res += "}"; + return res; +} + +template +void IRGraphCXXPrinter::emit_node(const char *node_type_str, const T *op, Args &&...args) { + if (node_names.count(op)) { + return; + } + + static_assert(decltype(check_make_args(std::forward(args)...))::value, + "Arguments extracted for printer do not match any T::make() signature! " + "Check your VISIT_NODE macro arguments."); + + // Evaluate arguments post-order to build dependencies. + // (C++11 guarantees left-to-right evaluation in brace-init lists) + std::vector printed_args = {to_cpp_arg(args)...}; + + // Generate the actual C++ code + bool is_expr = std::is_base_of_v; + std::string var_name = (is_expr ? "expr_" : "stmt_") + std::to_string(var_counter++); + + os << (is_expr ? "Expr " : "Stmt ") << var_name << " = " << node_type_str << "::make("; + for (size_t i = 0; i < printed_args.size(); ++i) { + os << printed_args[i] << (i + 1 == printed_args.size() ? "" : ", "); + } + os << ");\n"; + + node_names[op] = var_name; +} + +// Macro handles mapping the IR node pointer to our clever template. +#define VISIT_NODE(NodeType, ...) \ + void IRGraphCXXPrinter::visit(const NodeType *op) { \ + IRGraphVisitor::visit(op); \ + emit_node(#NodeType, op, __VA_ARGS__); \ + } + +// --- 1. Core / Primitive Values --- +VISIT_NODE(IntImm, op->type, op->value) +VISIT_NODE(UIntImm, op->type, op->value) +VISIT_NODE(FloatImm, op->type, op->value) +VISIT_NODE(StringImm, op->value) + +// --- 2. Variable & Broadcast --- +VISIT_NODE(Variable, op->type, op->name /*, op->image, op->param, op->reduction_domain */) +VISIT_NODE(Broadcast, op->value, op->lanes) + +// --- 3. Binary & Unary Math Nodes --- +VISIT_NODE(Add, op->a, op->b) +VISIT_NODE(Sub, op->a, op->b) +VISIT_NODE(Mod, op->a, op->b) +VISIT_NODE(Mul, op->a, op->b) +VISIT_NODE(Div, op->a, op->b) +VISIT_NODE(Min, op->a, op->b) +VISIT_NODE(Max, op->a, op->b) +VISIT_NODE(EQ, op->a, op->b) +VISIT_NODE(NE, op->a, op->b) +VISIT_NODE(LT, op->a, op->b) +VISIT_NODE(LE, op->a, op->b) +VISIT_NODE(GT, op->a, op->b) +VISIT_NODE(GE, op->a, op->b) +VISIT_NODE(And, op->a, op->b) +VISIT_NODE(Or, op->a, op->b) +VISIT_NODE(Not, op->a) + +// --- 4. Casts & Shuffles --- +VISIT_NODE(Cast, op->type, op->value) +VISIT_NODE(Reinterpret, op->type, op->value) +VISIT_NODE(Shuffle, op->vectors, op->indices) + +// --- 5. Complex Expressions --- +VISIT_NODE(Select, op->condition, op->true_value, op->false_value) +VISIT_NODE(Load, op->type, op->name, op->index, op->image, op->param, op->predicate, op->alignment) +VISIT_NODE(Ramp, op->base, op->stride, op->lanes) +VISIT_NODE(Call, op->type, op->name, op->args, op->call_type, op->func, op->value_index, op->image, op->param) +VISIT_NODE(Let, op->name, op->value, op->body) +VISIT_NODE(VectorReduce, op->op, op->value, op->type.lanes()) + +// --- 6. Core Statements --- +VISIT_NODE(LetStmt, op->name, op->value, op->body) +VISIT_NODE(AssertStmt, op->condition, op->message) +VISIT_NODE(Evaluate, op->value) +VISIT_NODE(Block, op->first, op->rest) +VISIT_NODE(IfThenElse, op->condition, op->then_case, op->else_case) +VISIT_NODE(For, op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, op->body) + +// --- 7. Memory / Buffer Operations --- +VISIT_NODE(Store, op->name, op->value, op->index, op->param, op->predicate, op->alignment) +VISIT_NODE(Provide, op->name, op->values, op->args, op->predicate) +VISIT_NODE(Allocate, op->name, op->type, op->memory_type, op->extents, op->condition, op->body, op->new_expr, op->free_function) +VISIT_NODE(Free, op->name) +VISIT_NODE(Realize, op->name, op->types, op->memory_type, op->bounds, op->condition, op->body) +VISIT_NODE(Prefetch, op->name, op->types, op->bounds, op->prefetch, op->condition, op->body) +VISIT_NODE(HoistedStorage, op->name, op->body) + +// --- 8. Concurrency & Sync --- +VISIT_NODE(ProducerConsumer, op->name, op->is_producer, op->body) +VISIT_NODE(Acquire, op->semaphore, op->count, op->body) +VISIT_NODE(Fork, op->first, op->rest) +VISIT_NODE(Atomic, op->producer_name, op->mutex_name, op->body) + void IRGraphCXXPrinter::test() { // This: Expr e = Select::make(Mod::make(Ramp::make(10, 314, 8), Broadcast::make(10, 8)) < Variable::make(Int(32), "p"), Broadcast::make(40, 8) + Ramp::make(4, 8, 8), VectorReduce::make(VectorReduce::Add, Ramp::make(0, 1, 16), 8)); @@ -41,9 +272,7 @@ void IRGraphCXXPrinter::test() { // Now let's see if it matches: internal_assert(equal(expr_19, e)) << "Expressions don't match:\n\n" << e << "\n\n" - << expr_19 << "\n"; - - // Here is a bad typo for Alex who likes progamming. Above is a badly intented line. Two typos? + << expr_19 << "\n"; } } // namespace Internal } // namespace Halide diff --git a/src/IRGraphCXXPrinter.h b/src/IRGraphCXXPrinter.h index 0c07d283f023..fee2cf2d6a4a 100644 --- a/src/IRGraphCXXPrinter.h +++ b/src/IRGraphCXXPrinter.h @@ -15,7 +15,7 @@ namespace Halide { namespace Internal { -class IRGraphCXXPrinter : public IRVisitor { +class IRGraphCXXPrinter : public IRGraphVisitor { public: std::ostream &os; @@ -39,253 +39,64 @@ class IRGraphCXXPrinter : public IRVisitor { } private: - // ========================================================================= - // ✨ CLEVER TEMPLATING ✨ - // This SFINAE trick checks if `T::make` can be invoked with `Args...`. - // It will trigger a static_assert if you forget an argument or pass the - // wrong field types, completely preventing generated code compile errors! - // ========================================================================= template - static constexpr auto check_make_args(Args &&...args) - -> decltype(T::make(std::forward(args)...), std::true_type{}) { - return std::true_type{}; - } - - template - static constexpr std::false_type check_make_args(...) { - return std::false_type{}; - } - - // ========================================================================= - // ARGUMENT STRINGIFIERS - // These convert Halide objects into strings representing C++ code. - // ========================================================================= + void emit_node(const char *node_type_str, const T *op, Args &&...args); template - std::string to_cpp_arg(const T &x) { - if constexpr (std::is_arithmetic_v) { - return std::to_string(x); - } else { - internal_error << "Not supported to print"; - } - } - - template<> - std::string to_cpp_arg(const Expr &e) { - if (!e.defined()) { - return "Expr()"; - } - e.accept(this); // Visit dependencies first - return node_names.at(e.get()); - } - - std::string to_cpp_arg(const Stmt &s) { - if (!s.defined()) { - return "Stmt()"; - } - s.accept(this); // Visit dependencies first - return node_names.at(s.get()); - } - - std::string to_cpp_arg(const Range &r) { - r.min.accept(this); - r.extent.accept(this); // Visit dependencies first - return "Range(" + node_names.at(r.min.get()) + ", " + node_names.at(r.extent.get()) + ")"; - } - - std::string to_cpp_arg(const std::string &s) { - return "\"" + s + "\""; - } - - std::string to_cpp_arg(Type t) { - std::ostringstream oss; - oss << "Type(Type::" - << (t.is_int() ? "Int" : t.is_uint() ? "UInt" : - t.is_float() ? "Float" : - t.is_bfloat() ? "BFloat" : - "Handle") - << ", " << t.bits() << ", " << t.lanes() << ")"; - return oss.str(); - } - - std::string to_cpp_arg(ForType f) { - switch (f) { - case ForType::Serial: - return "ForType::Serial"; - case ForType::Parallel: - return "ForType::Parallel"; - case ForType::Vectorized: - return "ForType::Vectorized"; - case ForType::Unrolled: - return "ForType::Unrolled"; - case ForType::Extern: - return "ForType::Extern"; - case ForType::GPUBlock: - return "ForType::GPUBlock"; - case ForType::GPUThread: - return "ForType::GPUThread"; - case ForType::GPULane: - return "ForType::GPULane"; - default: - return "ForType::Serial"; - } - } - - std::string to_cpp_arg(const VectorReduce::Operator &op) { - switch (op) { - case VectorReduce::Add: - return "VectorReduce::Add"; - case VectorReduce::SaturatingAdd: - return "VectorReduce::SaturatingAdd"; - case VectorReduce::Mul: - return "VectorReduce::Mul"; - case VectorReduce::Min: - return "VectorReduce::Min"; - case VectorReduce::Max: - return "VectorReduce::Max"; - case VectorReduce::And: - return "VectorReduce::And"; - case VectorReduce::Or: - return "VectorReduce::Or"; - } - internal_error << "Invalid VectorReduce"; - } - - std::string to_cpp_arg(DeviceAPI api) { - return "DeviceAPI::" + std::to_string((int)api); // Or proper switch-case logic - } - - std::string to_cpp_arg(ModulusRemainder align) { - return "ModulusRemainder(" + std::to_string(align.modulus) + ", " + std::to_string(align.remainder) + ")"; - } - - std::string to_cpp_arg(const Parameter &p) { - internal_error << "Not supported to print Parameter"; - } + std::string to_cpp_arg(const T &x); template - std::string to_cpp_arg(const std::vector &vec) { - std::string res = "{"; - for (size_t i = 0; i < vec.size(); ++i) { - res += to_cpp_arg(vec[i]); - if (i + 1 < vec.size()) { - res += ", "; - } - } - res += "}"; - return res; - } - - // ========================================================================= - // CORE NODE EMITTER - // ========================================================================= - template - void emit_node(const char *node_type_str, const T *op, Args &&...args) { - // 1. Maintain DAG properties: if we've already generated it, skip. - if (node_names.count(op)) { - return; - } - - // 2. ✨ Check at our compile-time that the signature aligns exactly! - static_assert(decltype(check_make_args(std::forward(args)...))::value, - "Arguments extracted for printer do not match any T::make() signature! " - "Check your VISIT_NODE macro arguments."); - - // 3. Evaluate arguments post-order to build dependencies. - // (C++11 guarantees left-to-right evaluation in brace-init lists) - std::vector printed_args = {to_cpp_arg(args)...}; - - // 4. Generate the actual C++ code - bool is_expr = std::is_base_of_v; - std::string var_name = (is_expr ? "expr_" : "stmt_") + std::to_string(var_counter++); - - os << (is_expr ? "Expr " : "Stmt ") << var_name << " = " << node_type_str << "::make("; - for (size_t i = 0; i < printed_args.size(); ++i) { - os << printed_args[i] << (i + 1 == printed_args.size() ? "" : ", "); - } - os << ");\n"; - - node_names[op] = var_name; - } + std::string to_cpp_arg(const std::vector &vec); protected: -// ========================================================================= -// VISITOR OVERRIDES -// ========================================================================= - -// Macro handles mapping the IR node pointer to our clever template. -#define VISIT_NODE(NodeType, ...) \ - void visit(const NodeType *op) override { \ - IRVisitor::visit(op); \ - emit_node(#NodeType, op, __VA_ARGS__); \ - } - - // --- 1. Core / Primitive Values --- - VISIT_NODE(IntImm, op->type, op->value) - VISIT_NODE(UIntImm, op->type, op->value) - VISIT_NODE(FloatImm, op->type, op->value) - VISIT_NODE(StringImm, op->value) - - // --- 2. Variable & Broadcast --- - VISIT_NODE(Variable, op->type, op->name /*, op->image, op->param, op->reduction_domain */) - VISIT_NODE(Broadcast, op->value, op->lanes) - - // --- 3. Binary & Unary Math Nodes --- - VISIT_NODE(Add, op->a, op->b) - VISIT_NODE(Sub, op->a, op->b) - VISIT_NODE(Mod, op->a, op->b) - VISIT_NODE(Mul, op->a, op->b) - VISIT_NODE(Div, op->a, op->b) - VISIT_NODE(Min, op->a, op->b) - VISIT_NODE(Max, op->a, op->b) - VISIT_NODE(EQ, op->a, op->b) - VISIT_NODE(NE, op->a, op->b) - VISIT_NODE(LT, op->a, op->b) - VISIT_NODE(LE, op->a, op->b) - VISIT_NODE(GT, op->a, op->b) - VISIT_NODE(GE, op->a, op->b) - VISIT_NODE(And, op->a, op->b) - VISIT_NODE(Or, op->a, op->b) - VISIT_NODE(Not, op->a) - - // --- 4. Casts & Shuffles --- - VISIT_NODE(Cast, op->type, op->value) - VISIT_NODE(Reinterpret, op->type, op->value) - VISIT_NODE(Shuffle, op->vectors, op->indices) - - // --- 5. Complex Expressions --- - VISIT_NODE(Select, op->condition, op->true_value, op->false_value) - VISIT_NODE(Load, op->type, op->name, op->index, op->image, op->param, op->predicate, op->alignment) - VISIT_NODE(Ramp, op->base, op->stride, op->lanes) - VISIT_NODE(Call, op->type, op->name, op->args, op->call_type, op->func, op->value_index, op->image, op->param) - VISIT_NODE(Let, op->name, op->value, op->body) - VISIT_NODE(VectorReduce, op->op, op->value, op->type.lanes()) - - // --- 6. Core Statements --- - VISIT_NODE(LetStmt, op->name, op->value, op->body) - VISIT_NODE(AssertStmt, op->condition, op->message) - VISIT_NODE(Evaluate, op->value) - VISIT_NODE(Block, op->first, op->rest) - VISIT_NODE(IfThenElse, op->condition, op->then_case, op->else_case) - VISIT_NODE(For, op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, op->body) - - // --- 7. Memory / Buffer Operations --- - VISIT_NODE(Store, op->name, op->value, op->index, op->param, op->predicate, op->alignment) - VISIT_NODE(Provide, op->name, op->values, op->args, op->predicate) - VISIT_NODE(Allocate, op->name, op->type, op->memory_type, op->extents, op->condition, op->body, op->new_expr, op->free_function) - VISIT_NODE(Free, op->name) - VISIT_NODE(Realize, op->name, op->types, op->memory_type, op->bounds, op->condition, op->body) - VISIT_NODE(Prefetch, op->name, op->types, op->bounds, op->prefetch, op->condition, op->body) - VISIT_NODE(HoistedStorage, op->name, op->body) - - // --- 8. Concurrency & Sync --- - VISIT_NODE(ProducerConsumer, op->name, op->is_producer, op->body) - VISIT_NODE(Acquire, op->semaphore, op->count, op->body) - VISIT_NODE(Fork, op->first, op->rest) - VISIT_NODE(Atomic, op->producer_name, op->mutex_name, op->body) - -// Clean up macro -#undef VISIT_NODE + void visit(const IntImm *) override; + void visit(const UIntImm *) override; + void visit(const FloatImm *) override; + void visit(const StringImm *) override; + void visit(const Cast *) override; + void visit(const Reinterpret *) override; + void visit(const Add *) override; + void visit(const Sub *) override; + void visit(const Mul *) override; + void visit(const Div *) override; + void visit(const Mod *) override; + void visit(const Min *) override; + void visit(const Max *) override; + void visit(const EQ *) override; + void visit(const NE *) override; + void visit(const LT *) override; + void visit(const LE *) override; + void visit(const GT *) override; + void visit(const GE *) override; + void visit(const And *) override; + void visit(const Or *) override; + void visit(const Not *) override; + void visit(const Select *) override; + void visit(const Load *) override; + void visit(const Ramp *) override; + void visit(const Broadcast *) override; + void visit(const Let *) override; + void visit(const LetStmt *) override; + void visit(const AssertStmt *) override; + void visit(const ProducerConsumer *) override; + void visit(const Store *) override; + void visit(const Provide *) override; + void visit(const Allocate *) override; + void visit(const Free *) override; + void visit(const Realize *) override; + void visit(const Block *) override; + void visit(const Fork *) override; + void visit(const IfThenElse *) override; + void visit(const Evaluate *) override; + void visit(const Call *) override; + void visit(const Variable *) override; + void visit(const For *) override; + void visit(const Acquire *) override; + void visit(const Shuffle *) override; + void visit(const Prefetch *) override; + void visit(const HoistedStorage *) override; + void visit(const Atomic *) override; + void visit(const VectorReduce *) override; public: static void test(); From 2832427b4ee1e1815568e33c1118457760f5fe46 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sun, 15 Mar 2026 23:19:50 +0100 Subject: [PATCH 05/13] Cleanup --- src/IRGraphCXXPrinter.cpp | 6 ------ src/IRGraphCXXPrinter.h | 2 -- 2 files changed, 8 deletions(-) diff --git a/src/IRGraphCXXPrinter.cpp b/src/IRGraphCXXPrinter.cpp index 20ca73fd5979..84ddab203039 100644 --- a/src/IRGraphCXXPrinter.cpp +++ b/src/IRGraphCXXPrinter.cpp @@ -9,12 +9,6 @@ namespace Halide { namespace Internal { namespace { -// ========================================================================= -// ✨ CLEVER TEMPLATING ✨ -// This SFINAE trick checks if `T::make` can be invoked with `Args...`. -// It will trigger a static_assert if you forget an argument or pass the -// wrong field types, completely preventing generated code compile errors! -// ========================================================================= template static constexpr auto check_make_args(Args &&...args) -> decltype(T::make(std::forward(args)...), std::true_type{}) { diff --git a/src/IRGraphCXXPrinter.h b/src/IRGraphCXXPrinter.h index fee2cf2d6a4a..8ce306576f45 100644 --- a/src/IRGraphCXXPrinter.h +++ b/src/IRGraphCXXPrinter.h @@ -3,9 +3,7 @@ #include #include -#include #include -#include #include #include "Expr.h" From 40d478032d79e45419fde92d3d35ea1f89ad0363 Mon Sep 17 00:00:00 2001 From: "halide-ci[bot]" <266445882+halide-ci[bot]@users.noreply.github.com> Date: Sun, 15 Mar 2026 22:23:26 +0000 Subject: [PATCH 06/13] Apply pre-commit auto-fixes --- test/internal.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/internal.cpp b/test/internal.cpp index 448a5960e79a..5df642b6db3b 100644 --- a/test/internal.cpp +++ b/test/internal.cpp @@ -9,6 +9,7 @@ #include "Generator.h" #include "IR.h" #include "IREquality.h" +#include "IRGraphCXXPrinter.h" #include "IRMatch.h" #include "IRPrinter.h" #include "Interval.h" @@ -18,7 +19,6 @@ #include "Solve.h" #include "SpirvIR.h" #include "UniquifyVariableNames.h" -#include "IRGraphCXXPrinter.h" using namespace Halide; using namespace Halide::Internal; From c19a87db39cd51567d4a6dc5251d58e2efbd3c9d Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sun, 15 Mar 2026 23:31:17 +0100 Subject: [PATCH 07/13] Expand test. --- src/IRGraphCXXPrinter.cpp | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/IRGraphCXXPrinter.cpp b/src/IRGraphCXXPrinter.cpp index 84ddab203039..afd6fbea5c32 100644 --- a/src/IRGraphCXXPrinter.cpp +++ b/src/IRGraphCXXPrinter.cpp @@ -236,6 +236,10 @@ VISIT_NODE(Atomic, op->producer_name, op->mutex_name, op->body) void IRGraphCXXPrinter::test() { // This: Expr e = Select::make(Mod::make(Ramp::make(10, 314, 8), Broadcast::make(10, 8)) < Variable::make(Int(32), "p"), Broadcast::make(40, 8) + Ramp::make(4, 8, 8), VectorReduce::make(VectorReduce::Add, Ramp::make(0, 1, 16), 8)); + e = e * e; // make it a graph + e = cast(Float(32, 8), e); + e = reinterpret(Int(32, 8), e); + e = Shuffle::make_interleave({e, e * Broadcast::make(3, 8)}); // Printed by: IRGraphCXXPrinter p(std::cout); @@ -262,11 +266,19 @@ void IRGraphCXXPrinter::test() { Expr expr_17 = Ramp::make(expr_15, expr_16, 16); Expr expr_18 = VectorReduce::make(VectorReduce::Add, expr_17, 8); Expr expr_19 = Select::make(expr_8, expr_14, expr_18); + Expr expr_20 = Mul::make(expr_19, expr_19); + Expr expr_21 = Cast::make(Type(Type::Float, 32, 8), expr_20); + Expr expr_22 = Reinterpret::make(Type(Type::Int, 32, 8), expr_21); + Expr expr_23 = IntImm::make(Type(Type::Int, 32, 1), 3); + Expr expr_24 = Broadcast::make(expr_23, 8); + Expr expr_25 = Mul::make(expr_22, expr_24); + Expr expr_26 = Shuffle::make({expr_22, expr_25}, {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15}); // Now let's see if it matches: - internal_assert(equal(expr_19, e)) << "Expressions don't match:\n\n" + const Expr &printed = expr_26; + internal_assert(equal(printed, e)) << "Expressions don't match:\n\n" << e << "\n\n" - << expr_19 << "\n"; + << printed << "\n"; } } // namespace Internal } // namespace Halide From c3b1a22f99a57a2370d320ff358997a8be3e6328 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Mon, 16 Mar 2026 18:27:45 +0100 Subject: [PATCH 08/13] Don't pretend that we will IR print Stmts. --- src/IRGraphCXXPrinter.cpp | 5 +++++ src/IRGraphCXXPrinter.h | 25 ++----------------------- 2 files changed, 7 insertions(+), 23 deletions(-) diff --git a/src/IRGraphCXXPrinter.cpp b/src/IRGraphCXXPrinter.cpp index afd6fbea5c32..7077043764f4 100644 --- a/src/IRGraphCXXPrinter.cpp +++ b/src/IRGraphCXXPrinter.cpp @@ -40,6 +40,7 @@ std::string IRGraphCXXPrinter::to_cpp_arg(const Expr &e) { return node_names.at(e.get()); } +// Not used, but leaving in place in case we ever want to expand this to Stmts. template<> std::string IRGraphCXXPrinter::to_cpp_arg(const Stmt &s) { if (!s.defined()) { @@ -60,6 +61,8 @@ template<> std::string IRGraphCXXPrinter::to_cpp_arg(const std::string &s) { return "\"" + s + "\""; } + +// Not used, but leaving in place in case we ever want to expand this to Stmts. template<> std::string IRGraphCXXPrinter::to_cpp_arg(const ForType &f) { switch (f) { @@ -210,6 +213,7 @@ VISIT_NODE(Call, op->type, op->name, op->args, op->call_type, op->func, op->valu VISIT_NODE(Let, op->name, op->value, op->body) VISIT_NODE(VectorReduce, op->op, op->value, op->type.lanes()) +#if 0 // Currently no support yet for Stmts, however, the macros below are already correct. We just can't print everything yet. // --- 6. Core Statements --- VISIT_NODE(LetStmt, op->name, op->value, op->body) VISIT_NODE(AssertStmt, op->condition, op->message) @@ -232,6 +236,7 @@ VISIT_NODE(ProducerConsumer, op->name, op->is_producer, op->body) VISIT_NODE(Acquire, op->semaphore, op->count, op->body) VISIT_NODE(Fork, op->first, op->rest) VISIT_NODE(Atomic, op->producer_name, op->mutex_name, op->body) +#endif void IRGraphCXXPrinter::test() { // This: diff --git a/src/IRGraphCXXPrinter.h b/src/IRGraphCXXPrinter.h index 8ce306576f45..210eec6876d1 100644 --- a/src/IRGraphCXXPrinter.h +++ b/src/IRGraphCXXPrinter.h @@ -30,12 +30,6 @@ class IRGraphCXXPrinter : public IRGraphVisitor { } } - void print(const Stmt &s) { - if (s.defined()) { - s.accept(this); - } - } - private: template void emit_node(const char *node_type_str, const T *op, Args &&...args); @@ -47,6 +41,8 @@ class IRGraphCXXPrinter : public IRGraphVisitor { std::string to_cpp_arg(const std::vector &vec); protected: + using IRGraphVisitor::visit; + void visit(const IntImm *) override; void visit(const UIntImm *) override; void visit(const FloatImm *) override; @@ -74,26 +70,9 @@ class IRGraphCXXPrinter : public IRGraphVisitor { void visit(const Ramp *) override; void visit(const Broadcast *) override; void visit(const Let *) override; - void visit(const LetStmt *) override; - void visit(const AssertStmt *) override; - void visit(const ProducerConsumer *) override; - void visit(const Store *) override; - void visit(const Provide *) override; - void visit(const Allocate *) override; - void visit(const Free *) override; - void visit(const Realize *) override; - void visit(const Block *) override; - void visit(const Fork *) override; - void visit(const IfThenElse *) override; - void visit(const Evaluate *) override; void visit(const Call *) override; void visit(const Variable *) override; - void visit(const For *) override; - void visit(const Acquire *) override; void visit(const Shuffle *) override; - void visit(const Prefetch *) override; - void visit(const HoistedStorage *) override; - void visit(const Atomic *) override; void visit(const VectorReduce *) override; public: From fe19a362fc7bfe7a1164a1cae1839efcd4d27404 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Mon, 16 Mar 2026 19:15:53 +0100 Subject: [PATCH 09/13] more clang-tidy; --- src/IRGraphCXXPrinter.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/IRGraphCXXPrinter.cpp b/src/IRGraphCXXPrinter.cpp index 7077043764f4..620cf5e8c54d 100644 --- a/src/IRGraphCXXPrinter.cpp +++ b/src/IRGraphCXXPrinter.cpp @@ -10,13 +10,13 @@ namespace Internal { namespace { template -static constexpr auto check_make_args(Args &&...args) +constexpr auto check_make_args(Args &&...args) -> decltype(T::make(std::forward(args)...), std::true_type{}) { return std::true_type{}; } template -static constexpr std::false_type check_make_args(...) { +constexpr std::false_type check_make_args(...) { return std::false_type{}; } From 3af8c26d4e0f3c4c0db6ab5bfe7f0fac59103cf6 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Tue, 17 Mar 2026 15:47:07 +0100 Subject: [PATCH 10/13] Support printing Call statements. --- src/IRGraphCXXPrinter.cpp | 244 ++++++++++++++++++++++++++------------ 1 file changed, 165 insertions(+), 79 deletions(-) diff --git a/src/IRGraphCXXPrinter.cpp b/src/IRGraphCXXPrinter.cpp index 620cf5e8c54d..dd30690aed1b 100644 --- a/src/IRGraphCXXPrinter.cpp +++ b/src/IRGraphCXXPrinter.cpp @@ -1,10 +1,13 @@ #include "IRGraphCXXPrinter.h" #include "Expr.h" +#include "Function.h" #include "IR.h" #include "IREquality.h" #include "IROperator.h" +#include + namespace Halide { namespace Internal { @@ -62,50 +65,68 @@ std::string IRGraphCXXPrinter::to_cpp_arg(const std::string &s) { return "\"" + s + "\""; } +#define ENUM_TO_STR(x) \ + case x: \ + return #x; + // Not used, but leaving in place in case we ever want to expand this to Stmts. template<> std::string IRGraphCXXPrinter::to_cpp_arg(const ForType &f) { switch (f) { - case ForType::Serial: - return "ForType::Serial"; - case ForType::Parallel: - return "ForType::Parallel"; - case ForType::Vectorized: - return "ForType::Vectorized"; - case ForType::Unrolled: - return "ForType::Unrolled"; - case ForType::Extern: - return "ForType::Extern"; - case ForType::GPUBlock: - return "ForType::GPUBlock"; - case ForType::GPUThread: - return "ForType::GPUThread"; - case ForType::GPULane: - return "ForType::GPULane"; - default: - return "ForType::Serial"; + ENUM_TO_STR(ForType::Serial); + ENUM_TO_STR(ForType::Parallel); + ENUM_TO_STR(ForType::Vectorized); + ENUM_TO_STR(ForType::Unrolled); + ENUM_TO_STR(ForType::Extern); + ENUM_TO_STR(ForType::GPUBlock); + ENUM_TO_STR(ForType::GPUThread); + ENUM_TO_STR(ForType::GPULane); + } + return ""; +} + +template<> +std::string IRGraphCXXPrinter::to_cpp_arg(const Call::CallType &f) { + switch (f) { + ENUM_TO_STR(Call::CallType::Image); + ENUM_TO_STR(Call::CallType::Extern); + ENUM_TO_STR(Call::CallType::ExternCPlusPlus); + ENUM_TO_STR(Call::CallType::PureExtern); + ENUM_TO_STR(Call::CallType::Halide); + ENUM_TO_STR(Call::CallType::Intrinsic); + ENUM_TO_STR(Call::CallType::PureIntrinsic); } + return ""; } template<> std::string IRGraphCXXPrinter::to_cpp_arg(const VectorReduce::Operator &op) { switch (op) { - case VectorReduce::Add: - return "VectorReduce::Add"; - case VectorReduce::SaturatingAdd: - return "VectorReduce::SaturatingAdd"; - case VectorReduce::Mul: - return "VectorReduce::Mul"; - case VectorReduce::Min: - return "VectorReduce::Min"; - case VectorReduce::Max: - return "VectorReduce::Max"; - case VectorReduce::And: - return "VectorReduce::And"; - case VectorReduce::Or: - return "VectorReduce::Or"; + ENUM_TO_STR(VectorReduce::Add); + ENUM_TO_STR(VectorReduce::SaturatingAdd); + ENUM_TO_STR(VectorReduce::Mul); + ENUM_TO_STR(VectorReduce::Min); + ENUM_TO_STR(VectorReduce::Max); + ENUM_TO_STR(VectorReduce::And); + ENUM_TO_STR(VectorReduce::Or); + } + return ""; +} + +template<> +std::string IRGraphCXXPrinter::to_cpp_arg(const Halide::Parameter &p) { + if (!p.defined()) { + return "Parameter()"; + } + return "/* Parameter */ " + to_cpp_arg(p.name()); +} + +template<> +std::string IRGraphCXXPrinter::to_cpp_arg>(const Halide::Buffer<> &b) { + if (!b.defined()) { + return "Buffer<>()"; } - internal_error << "Invalid VectorReduce"; + return "/* Buffer */ " + to_cpp_arg(b.name()); } template<> @@ -209,7 +230,31 @@ VISIT_NODE(Shuffle, op->vectors, op->indices) VISIT_NODE(Select, op->condition, op->true_value, op->false_value) VISIT_NODE(Load, op->type, op->name, op->index, op->image, op->param, op->predicate, op->alignment) VISIT_NODE(Ramp, op->base, op->stride, op->lanes) -VISIT_NODE(Call, op->type, op->name, op->args, op->call_type, op->func, op->value_index, op->image, op->param) + +void IRGraphCXXPrinter::visit(const Call *op) { + if (op->call_type == Call::Image && op->image.defined()) { + // Variant 4: Convenience constructor for loads from concrete images + emit_node("Call", op, op->image, op->args); + } else if (op->call_type == Call::Image && op->param.defined()) { + // Variant 5: Convenience constructor for loads from image parameters + emit_node("Call", op, op->param, op->args); + } else if (op->call_type == Call::Halide && op->func.defined()) { + // Variant 3: Convenience constructor for calls to other halide functions. + // We wrap the FunctionPtr into a Function object to perfectly match + // the expected `const Function &func` signature. + emit_node("Call", op, Internal::Function(op->func), op->args, op->value_index); + } else if (op->is_intrinsic()) { + + emit_node("Call", op, op->type, op->name, op->args, op->call_type); + } else { + // Variant 2: Fallback to the fully explicit string-name version. + // (Note: Halide's API internally handles mapping string names back + // to IntrinsicOp if it happens to be an intrinsic call). + emit_node("Call", op, op->type, op->name, op->args, op->call_type, + op->func, op->value_index, op->image, op->param); + } +} + VISIT_NODE(Let, op->name, op->value, op->body) VISIT_NODE(VectorReduce, op->op, op->value, op->type.lanes()) @@ -238,52 +283,93 @@ VISIT_NODE(Fork, op->first, op->rest) VISIT_NODE(Atomic, op->producer_name, op->mutex_name, op->body) #endif +#undef ENUM_TO_STR + void IRGraphCXXPrinter::test() { - // This: - Expr e = Select::make(Mod::make(Ramp::make(10, 314, 8), Broadcast::make(10, 8)) < Variable::make(Int(32), "p"), Broadcast::make(40, 8) + Ramp::make(4, 8, 8), VectorReduce::make(VectorReduce::Add, Ramp::make(0, 1, 16), 8)); - e = e * e; // make it a graph - e = cast(Float(32, 8), e); - e = reinterpret(Int(32, 8), e); - e = Shuffle::make_interleave({e, e * Broadcast::make(3, 8)}); - - // Printed by: - IRGraphCXXPrinter p(std::cout); - p.print(e); - - // Prints: - Expr expr_0 = IntImm::make(Type(Type::Int, 32, 1), 10); - Expr expr_1 = IntImm::make(Type(Type::Int, 32, 1), 314); - Expr expr_2 = Ramp::make(expr_0, expr_1, 8); - Expr expr_3 = IntImm::make(Type(Type::Int, 32, 1), 10); - Expr expr_4 = Broadcast::make(expr_3, 8); - Expr expr_5 = Mod::make(expr_2, expr_4); - Expr expr_6 = Variable::make(Type(Type::Int, 32, 1), "p"); - Expr expr_7 = Broadcast::make(expr_6, 8); - Expr expr_8 = LT::make(expr_5, expr_7); - Expr expr_9 = IntImm::make(Type(Type::Int, 32, 1), 40); - Expr expr_10 = Broadcast::make(expr_9, 8); - Expr expr_11 = IntImm::make(Type(Type::Int, 32, 1), 4); - Expr expr_12 = IntImm::make(Type(Type::Int, 32, 1), 8); - Expr expr_13 = Ramp::make(expr_11, expr_12, 8); - Expr expr_14 = Add::make(expr_10, expr_13); - Expr expr_15 = IntImm::make(Type(Type::Int, 32, 1), 0); - Expr expr_16 = IntImm::make(Type(Type::Int, 32, 1), 1); - Expr expr_17 = Ramp::make(expr_15, expr_16, 16); - Expr expr_18 = VectorReduce::make(VectorReduce::Add, expr_17, 8); - Expr expr_19 = Select::make(expr_8, expr_14, expr_18); - Expr expr_20 = Mul::make(expr_19, expr_19); - Expr expr_21 = Cast::make(Type(Type::Float, 32, 8), expr_20); - Expr expr_22 = Reinterpret::make(Type(Type::Int, 32, 8), expr_21); - Expr expr_23 = IntImm::make(Type(Type::Int, 32, 1), 3); - Expr expr_24 = Broadcast::make(expr_23, 8); - Expr expr_25 = Mul::make(expr_22, expr_24); - Expr expr_26 = Shuffle::make({expr_22, expr_25}, {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15}); - - // Now let's see if it matches: - const Expr &printed = expr_26; - internal_assert(equal(printed, e)) << "Expressions don't match:\n\n" - << e << "\n\n" - << printed << "\n"; + { + // This: + Expr e = Select::make(Mod::make(Ramp::make(10, 314, 8), Broadcast::make(10, 8)) < Variable::make(Int(32), "p"), Broadcast::make(40, 8) + Ramp::make(4, 8, 8), VectorReduce::make(VectorReduce::Add, Ramp::make(0, 1, 16), 8)); + e = e * e; // make it a graph + e = cast(Float(32, 8), e); + e = reinterpret(Int(32, 8), e); + e = Shuffle::make_interleave({e, e * Broadcast::make(3, 8)}); + + // Printed by: + std::stringstream ss; + IRGraphCXXPrinter p(ss); + p.print(e); + + std::cout << "Printed expr: " << e << "\n"; + std::cout << ss.str(); + std::cout << "\n"; + + // Prints: + Expr expr_0 = IntImm::make(Type(Type::Int, 32, 1), 10); + Expr expr_1 = IntImm::make(Type(Type::Int, 32, 1), 314); + Expr expr_2 = Ramp::make(expr_0, expr_1, 8); + Expr expr_3 = IntImm::make(Type(Type::Int, 32, 1), 10); + Expr expr_4 = Broadcast::make(expr_3, 8); + Expr expr_5 = Mod::make(expr_2, expr_4); + Expr expr_6 = Variable::make(Type(Type::Int, 32, 1), "p"); + Expr expr_7 = Broadcast::make(expr_6, 8); + Expr expr_8 = LT::make(expr_5, expr_7); + Expr expr_9 = IntImm::make(Type(Type::Int, 32, 1), 40); + Expr expr_10 = Broadcast::make(expr_9, 8); + Expr expr_11 = IntImm::make(Type(Type::Int, 32, 1), 4); + Expr expr_12 = IntImm::make(Type(Type::Int, 32, 1), 8); + Expr expr_13 = Ramp::make(expr_11, expr_12, 8); + Expr expr_14 = Add::make(expr_10, expr_13); + Expr expr_15 = IntImm::make(Type(Type::Int, 32, 1), 0); + Expr expr_16 = IntImm::make(Type(Type::Int, 32, 1), 1); + Expr expr_17 = Ramp::make(expr_15, expr_16, 16); + Expr expr_18 = VectorReduce::make(VectorReduce::Add, expr_17, 8); + Expr expr_19 = Select::make(expr_8, expr_14, expr_18); + Expr expr_20 = Mul::make(expr_19, expr_19); + Expr expr_21 = Cast::make(Type(Type::Float, 32, 8), expr_20); + Expr expr_22 = Reinterpret::make(Type(Type::Int, 32, 8), expr_21); + Expr expr_23 = IntImm::make(Type(Type::Int, 32, 1), 3); + Expr expr_24 = Broadcast::make(expr_23, 8); + Expr expr_25 = Mul::make(expr_22, expr_24); + Expr expr_26 = Shuffle::make({expr_22, expr_25}, {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15}); + + // Now let's see if it matches: + const Expr &printed = expr_26; + internal_assert(equal(printed, e)) << "Expressions don't match:\n\n" + << e << "\n\n" + << printed << "\n"; + } + + { + // An expression Alex was interested in: + Expr imm1 = IntImm::make(Int(16), -32000); + Expr imm2 = UIntImm::make(UInt(16), 1); + Expr cast_imm1 = Cast::make(UInt(64), imm1); + Expr cast_imm2 = Cast::make(UInt(64), imm2); + Expr test_cast = ~(cast_imm1 / cast_imm2); + + // Printed by: + std::stringstream ss; + IRGraphCXXPrinter p(ss); + p.print(test_cast); + + std::cout << "Printed expr: " << test_cast << "\n"; + std::cout << ss.str(); + std::cout << "\n"; + + // Produces: + Expr expr_0 = IntImm::make(Type(Type::Int, 16, 1), -32000); + Expr expr_1 = Cast::make(Type(Type::UInt, 64, 1), expr_0); + Expr expr_2 = UIntImm::make(Type(Type::UInt, 16, 1), 1); + Expr expr_3 = Cast::make(Type(Type::UInt, 64, 1), expr_2); + Expr expr_4 = Div::make(expr_1, expr_3); + Expr expr_5 = Call::make(Type(Type::UInt, 64, 1), "bitwise_not", {expr_4}, Call::CallType::PureIntrinsic); + + // Now let's see if it matches: + const Expr &printed = expr_5; + internal_assert(equal(printed, test_cast)) << "Expressions don't match:\n\n" + << test_cast << "\n\n" + << printed << "\n"; + } } } // namespace Internal } // namespace Halide From 4a59bb2883da03a53a28e817d52ac93765752cb9 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Mon, 23 Mar 2026 20:02:46 +0100 Subject: [PATCH 11/13] Add string-based testing for IRGraphCXXPrinter. --- src/IRGraphCXXPrinter.cpp | 98 ++++++++++++++++++++++----------------- 1 file changed, 55 insertions(+), 43 deletions(-) diff --git a/src/IRGraphCXXPrinter.cpp b/src/IRGraphCXXPrinter.cpp index dd30690aed1b..2e8c867dd63c 100644 --- a/src/IRGraphCXXPrinter.cpp +++ b/src/IRGraphCXXPrinter.cpp @@ -286,6 +286,8 @@ VISIT_NODE(Atomic, op->producer_name, op->mutex_name, op->body) #undef ENUM_TO_STR void IRGraphCXXPrinter::test() { +#define STR(X) #X "\n" +#define CODE(X) X { // This: Expr e = Select::make(Mod::make(Ramp::make(10, 314, 8), Broadcast::make(10, 8)) < Variable::make(Int(32), "p"), Broadcast::make(40, 8) + Ramp::make(4, 8, 8), VectorReduce::make(VectorReduce::Add, Ramp::make(0, 1, 16), 8)); @@ -299,44 +301,49 @@ void IRGraphCXXPrinter::test() { IRGraphCXXPrinter p(ss); p.print(e); - std::cout << "Printed expr: " << e << "\n"; - std::cout << ss.str(); - std::cout << "\n"; - // Prints: - Expr expr_0 = IntImm::make(Type(Type::Int, 32, 1), 10); - Expr expr_1 = IntImm::make(Type(Type::Int, 32, 1), 314); - Expr expr_2 = Ramp::make(expr_0, expr_1, 8); - Expr expr_3 = IntImm::make(Type(Type::Int, 32, 1), 10); - Expr expr_4 = Broadcast::make(expr_3, 8); - Expr expr_5 = Mod::make(expr_2, expr_4); - Expr expr_6 = Variable::make(Type(Type::Int, 32, 1), "p"); - Expr expr_7 = Broadcast::make(expr_6, 8); - Expr expr_8 = LT::make(expr_5, expr_7); - Expr expr_9 = IntImm::make(Type(Type::Int, 32, 1), 40); - Expr expr_10 = Broadcast::make(expr_9, 8); - Expr expr_11 = IntImm::make(Type(Type::Int, 32, 1), 4); - Expr expr_12 = IntImm::make(Type(Type::Int, 32, 1), 8); - Expr expr_13 = Ramp::make(expr_11, expr_12, 8); - Expr expr_14 = Add::make(expr_10, expr_13); - Expr expr_15 = IntImm::make(Type(Type::Int, 32, 1), 0); - Expr expr_16 = IntImm::make(Type(Type::Int, 32, 1), 1); - Expr expr_17 = Ramp::make(expr_15, expr_16, 16); - Expr expr_18 = VectorReduce::make(VectorReduce::Add, expr_17, 8); - Expr expr_19 = Select::make(expr_8, expr_14, expr_18); - Expr expr_20 = Mul::make(expr_19, expr_19); - Expr expr_21 = Cast::make(Type(Type::Float, 32, 8), expr_20); - Expr expr_22 = Reinterpret::make(Type(Type::Int, 32, 8), expr_21); - Expr expr_23 = IntImm::make(Type(Type::Int, 32, 1), 3); - Expr expr_24 = Broadcast::make(expr_23, 8); - Expr expr_25 = Mul::make(expr_22, expr_24); - Expr expr_26 = Shuffle::make({expr_22, expr_25}, {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15}); - - // Now let's see if it matches: +#define RESULT(X) \ + X(Expr expr_0 = IntImm::make(Type(Type::Int, 32, 1), 10);) \ + X(Expr expr_1 = IntImm::make(Type(Type::Int, 32, 1), 314);) \ + X(Expr expr_2 = Ramp::make(expr_0, expr_1, 8);) \ + X(Expr expr_3 = IntImm::make(Type(Type::Int, 32, 1), 10);) \ + X(Expr expr_4 = Broadcast::make(expr_3, 8);) \ + X(Expr expr_5 = Mod::make(expr_2, expr_4);) \ + X(Expr expr_6 = Variable::make(Type(Type::Int, 32, 1), "p");) \ + X(Expr expr_7 = Broadcast::make(expr_6, 8);) \ + X(Expr expr_8 = LT::make(expr_5, expr_7);) \ + X(Expr expr_9 = IntImm::make(Type(Type::Int, 32, 1), 40);) \ + X(Expr expr_10 = Broadcast::make(expr_9, 8);) \ + X(Expr expr_11 = IntImm::make(Type(Type::Int, 32, 1), 4);) \ + X(Expr expr_12 = IntImm::make(Type(Type::Int, 32, 1), 8);) \ + X(Expr expr_13 = Ramp::make(expr_11, expr_12, 8);) \ + X(Expr expr_14 = Add::make(expr_10, expr_13);) \ + X(Expr expr_15 = IntImm::make(Type(Type::Int, 32, 1), 0);) \ + X(Expr expr_16 = IntImm::make(Type(Type::Int, 32, 1), 1);) \ + X(Expr expr_17 = Ramp::make(expr_15, expr_16, 16);) \ + X(Expr expr_18 = VectorReduce::make(VectorReduce::Add, expr_17, 8);) \ + X(Expr expr_19 = Select::make(expr_8, expr_14, expr_18);) \ + X(Expr expr_20 = Mul::make(expr_19, expr_19);) \ + X(Expr expr_21 = Cast::make(Type(Type::Float, 32, 8), expr_20);) \ + X(Expr expr_22 = Reinterpret::make(Type(Type::Int, 32, 8), expr_21);) \ + X(Expr expr_23 = IntImm::make(Type(Type::Int, 32, 1), 3);) \ + X(Expr expr_24 = Broadcast::make(expr_23, 8);) \ + X(Expr expr_25 = Mul::make(expr_22, expr_24);) \ + X(Expr expr_26 = Shuffle::make({expr_22, expr_25}, {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15});) + + std::string expected = RESULT(STR); + internal_assert(expected == ss.str()) << "Generated C++ code was not as expected." + << "Expected:\n" + << expected << "\n\nGenerated:\n" + << ss.str() << "\n"; + + // Now let's see if the IR it produces matches: + RESULT(CODE); const Expr &printed = expr_26; internal_assert(equal(printed, e)) << "Expressions don't match:\n\n" << e << "\n\n" << printed << "\n"; +#undef RESULT } { @@ -352,23 +359,28 @@ void IRGraphCXXPrinter::test() { IRGraphCXXPrinter p(ss); p.print(test_cast); - std::cout << "Printed expr: " << test_cast << "\n"; - std::cout << ss.str(); - std::cout << "\n"; - // Produces: - Expr expr_0 = IntImm::make(Type(Type::Int, 16, 1), -32000); - Expr expr_1 = Cast::make(Type(Type::UInt, 64, 1), expr_0); - Expr expr_2 = UIntImm::make(Type(Type::UInt, 16, 1), 1); - Expr expr_3 = Cast::make(Type(Type::UInt, 64, 1), expr_2); - Expr expr_4 = Div::make(expr_1, expr_3); - Expr expr_5 = Call::make(Type(Type::UInt, 64, 1), "bitwise_not", {expr_4}, Call::CallType::PureIntrinsic); +#define RESULT(X) \ + X(Expr expr_0 = IntImm::make(Type(Type::Int, 16, 1), -32000);) \ + X(Expr expr_1 = Cast::make(Type(Type::UInt, 64, 1), expr_0);) \ + X(Expr expr_2 = UIntImm::make(Type(Type::UInt, 16, 1), 1);) \ + X(Expr expr_3 = Cast::make(Type(Type::UInt, 64, 1), expr_2);) \ + X(Expr expr_4 = Div::make(expr_1, expr_3);) \ + X(Expr expr_5 = Call::make(Type(Type::UInt, 64, 1), "bitwise_not", {expr_4}, Call::CallType::PureIntrinsic);) + + std::string expected = RESULT(STR); + internal_assert(expected == ss.str()) << "Generated C++ code was not as expected." + << "Expected:\n" + << expected << "\n\nGenerated:\n" + << ss.str() << "\n"; // Now let's see if it matches: + RESULT(CODE); const Expr &printed = expr_5; internal_assert(equal(printed, test_cast)) << "Expressions don't match:\n\n" << test_cast << "\n\n" << printed << "\n"; +#undef RESULT } } } // namespace Internal From 6ea527aa3a64e210ac36faa4404e5af857b17500 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Fri, 10 Apr 2026 23:10:17 +0200 Subject: [PATCH 12/13] Move IRGraphCXXPrinter to test/fuzz --- {src => test/fuzz}/IRGraphCXXPrinter.cpp | 0 {src => test/fuzz}/IRGraphCXXPrinter.h | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename {src => test/fuzz}/IRGraphCXXPrinter.cpp (100%) rename {src => test/fuzz}/IRGraphCXXPrinter.h (100%) diff --git a/src/IRGraphCXXPrinter.cpp b/test/fuzz/IRGraphCXXPrinter.cpp similarity index 100% rename from src/IRGraphCXXPrinter.cpp rename to test/fuzz/IRGraphCXXPrinter.cpp diff --git a/src/IRGraphCXXPrinter.h b/test/fuzz/IRGraphCXXPrinter.h similarity index 100% rename from src/IRGraphCXXPrinter.h rename to test/fuzz/IRGraphCXXPrinter.h From ce4c6eef4a691d46e6eeac47f144aafa4de1a877 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Fri, 10 Apr 2026 23:37:38 +0200 Subject: [PATCH 13/13] Move IRGraphCXXPrinter to the fuzz folder and fixup compilation, and make the simplify fuzzer print out the Expr. --- Makefile | 6 ++---- src/CMakeLists.txt | 2 -- test/fuzz/CMakeLists.txt | 14 +++++++++----- test/fuzz/IRGraphCXXPrinter.cpp | 6 ------ test/fuzz/IRGraphCXXPrinter.h | 4 +--- test/fuzz/simplify.cpp | 8 ++++++++ test/internal.cpp | 2 -- 7 files changed, 20 insertions(+), 22 deletions(-) diff --git a/Makefile b/Makefile index 1a30944417ca..f866979a978f 100644 --- a/Makefile +++ b/Makefile @@ -529,7 +529,6 @@ SOURCE_FILES = \ Interval.cpp \ IR.cpp \ IREquality.cpp \ - IRGraphCXXPrinter.cpp \ IRMatch.cpp \ IRMutator.cpp \ IROperator.cpp \ @@ -733,7 +732,6 @@ HEADER_FILES = \ IntrusivePtr.h \ IR.h \ IREquality.h \ - IRGraphCXXPrinter.h \ IRMatch.h \ IRMutator.h \ IROperator.h \ @@ -1288,7 +1286,7 @@ PERFORMANCE_TESTS = $(shell ls $(ROOT_DIR)/test/performance/*.cpp) ERROR_TESTS = $(shell ls $(ROOT_DIR)/test/error/*.cpp) WARNING_TESTS = $(shell ls $(ROOT_DIR)/test/warning/*.cpp) RUNTIME_TESTS = $(shell ls $(ROOT_DIR)/test/runtime/*.cpp) -FUZZ_TESTS = $(filter-out %halide_fuzz_main.cpp, $(shell ls $(ROOT_DIR)/test/fuzz/*.cpp)) +FUZZ_TESTS = $(filter-out %halide_fuzz_main.cpp %IRGraphCXXPrinter.cpp, $(shell ls $(ROOT_DIR)/test/fuzz/*.cpp)) GENERATOR_EXTERNAL_TESTS := $(shell ls $(ROOT_DIR)/test/generator/*test.cpp) GENERATOR_EXTERNAL_TEST_GENERATOR := $(shell ls $(ROOT_DIR)/test/generator/*_generator.cpp) TUTORIALS = $(filter-out %_generate.cpp, $(shell ls $(ROOT_DIR)/tutorial/*.cpp)) @@ -1477,7 +1475,7 @@ $(BIN_DIR)/$(TARGET)/correctness_opencl_runtime: $(ROOT_DIR)/test/correctness/op $(BIN_DIR)/performance_%: $(ROOT_DIR)/test/performance/%.cpp $(TEST_DEPS) $(CXX) $(TEST_CXX_FLAGS) $(OPTIMIZE) $< -I$(INCLUDE_DIR) -I$(ROOT_DIR)/src/runtime -I$(ROOT_DIR)/test/common $(TEST_LD_FLAGS) -o $@ -$(BIN_DIR)/fuzz_%: $(ROOT_DIR)/test/fuzz/%.cpp $(ROOT_DIR)/test/fuzz/halide_fuzz_main.cpp $(ROOT_DIR)/test/fuzz/fuzz_helpers.h $(ROOT_DIR)/test/fuzz/halide_fuzz_main.h $(TEST_DEPS) +$(BIN_DIR)/fuzz_%: $(ROOT_DIR)/test/fuzz/%.cpp $(ROOT_DIR)/test/fuzz/halide_fuzz_main.cpp $(ROOT_DIR)/test/fuzz/fuzz_helpers.h $(ROOT_DIR)/test/fuzz/halide_fuzz_main.h $(ROOT_DIR)/test/fuzz/IRGraphCXXPrinter.cpp $(ROOT_DIR)/test/fuzz/IRGraphCXXPrinter.h $(TEST_DEPS) $(CXX) $(TEST_CXX_FLAGS) -I$(ROOT_DIR)/src/runtime -I$(ROOT_DIR)/test/common $(OPTIMIZE_FOR_BUILD_TIME) $(filter %.cpp,$^) -I$(INCLUDE_DIR) $(TEST_LD_FLAGS) -o $@ -DHALIDE_FUZZER_BACKEND=0 # Error tests that link against libHalide diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6d69e1cd57f3..a373136025a9 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -139,7 +139,6 @@ target_sources( IntrusivePtr.h IR.h IREquality.h - IRGraphCXXPrinter.h IRMatch.h IRMutator.h IROperator.h @@ -319,7 +318,6 @@ target_sources( Interval.cpp IR.cpp IREquality.cpp - IRGraphCXXPrinter.cpp IRMatch.cpp IRMutator.cpp IROperator.cpp diff --git a/test/fuzz/CMakeLists.txt b/test/fuzz/CMakeLists.txt index 97b8372cf2ab..af4ba4fad84d 100644 --- a/test/fuzz/CMakeLists.txt +++ b/test/fuzz/CMakeLists.txt @@ -36,20 +36,24 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *Data, std::size_t Size) { } ]] HAVE_LIBFUZZER_FLAGS) -add_library(Halide_fuzz INTERFACE) +add_library(Halide_fuzz OBJECT) add_library(Halide::fuzz ALIAS Halide_fuzz) +target_sources(Halide_fuzz PRIVATE IRGraphCXXPrinter.cpp) +target_link_libraries(Halide_fuzz PRIVATE Halide::Test) + if (NOT HAVE_LIBFUZZER_FLAGS) if (LIB_FUZZING_ENGINE) message(FATAL_ERROR "Cannot set LIB_FUZZING_ENGINE when not building with -fsanitize=fuzzer or a compatible fuzzing engine.") endif () - target_sources(Halide_fuzz INTERFACE halide_fuzz_main.cpp halide_fuzz_main.h) - target_compile_definitions(Halide_fuzz INTERFACE HALIDE_FUZZER_BACKEND=HALIDE_FUZZER_BACKEND_STDLIB) + target_sources(Halide_fuzz PRIVATE halide_fuzz_main.cpp halide_fuzz_main.h) + target_compile_definitions(Halide_fuzz PUBLIC HALIDE_FUZZER_BACKEND=HALIDE_FUZZER_BACKEND_STDLIB) else () - target_link_libraries(Halide_fuzz INTERFACE ${LIB_FUZZING_ENGINE}) - target_compile_definitions(Halide_fuzz INTERFACE HALIDE_FUZZER_BACKEND=HALIDE_FUZZER_BACKEND_LIBFUZZER) + target_link_libraries(Halide_fuzz PUBLIC ${LIB_FUZZING_ENGINE}) + target_compile_definitions(Halide_fuzz PUBLIC HALIDE_FUZZER_BACKEND=HALIDE_FUZZER_BACKEND_LIBFUZZER) endif () + foreach (fuzzer IN LISTS TEST_NAMES) target_link_libraries("${fuzzer}" PRIVATE Halide::fuzz) endforeach () diff --git a/test/fuzz/IRGraphCXXPrinter.cpp b/test/fuzz/IRGraphCXXPrinter.cpp index 2e8c867dd63c..83a7e3739ad2 100644 --- a/test/fuzz/IRGraphCXXPrinter.cpp +++ b/test/fuzz/IRGraphCXXPrinter.cpp @@ -1,11 +1,5 @@ #include "IRGraphCXXPrinter.h" -#include "Expr.h" -#include "Function.h" -#include "IR.h" -#include "IREquality.h" -#include "IROperator.h" - #include namespace Halide { diff --git a/test/fuzz/IRGraphCXXPrinter.h b/test/fuzz/IRGraphCXXPrinter.h index 210eec6876d1..772ec2e5474e 100644 --- a/test/fuzz/IRGraphCXXPrinter.h +++ b/test/fuzz/IRGraphCXXPrinter.h @@ -6,9 +6,7 @@ #include #include -#include "Expr.h" -#include "IR.h" -#include "IRVisitor.h" +#include namespace Halide { namespace Internal { diff --git a/test/fuzz/simplify.cpp b/test/fuzz/simplify.cpp index 2ddc82e6913d..4964f58df6cc 100644 --- a/test/fuzz/simplify.cpp +++ b/test/fuzz/simplify.cpp @@ -1,6 +1,7 @@ #include "Halide.h" #include +#include "IRGraphCXXPrinter.h" #include "fuzz_helpers.h" #include "random_expr_generator.h" @@ -153,6 +154,13 @@ FUZZ_TEST(simplify, FuzzingContext &fuzz) { return e; }); std::cerr << "Final test case: " << test << "\n"; + + std::cerr << "\n\nC++ code:\n\n"; + IRGraphCXXPrinter printer(std::cerr); + printer.print(test); + std::cerr << "Expr final_expr = " << printer.node_names[test.get()] << ";\n"; + std::cerr << "\n\n"; + return 1; } diff --git a/test/internal.cpp b/test/internal.cpp index 5df642b6db3b..08283fa9cf54 100644 --- a/test/internal.cpp +++ b/test/internal.cpp @@ -9,7 +9,6 @@ #include "Generator.h" #include "IR.h" #include "IREquality.h" -#include "IRGraphCXXPrinter.h" #include "IRMatch.h" #include "IRPrinter.h" #include "Interval.h" @@ -25,7 +24,6 @@ using namespace Halide::Internal; int main(int argc, const char **argv) { IRPrinter::test(); - IRGraphCXXPrinter::test(); CodeGen_C::test(); ir_equality_test(); bounds_test();