diff --git a/taichi/ir/ir_builder.cpp b/taichi/ir/ir_builder.cpp index b28412cd441f4..bfc99fcbfc744 100644 --- a/taichi/ir/ir_builder.cpp +++ b/taichi/ir/ir_builder.cpp @@ -138,6 +138,12 @@ LoopIndexStmt *IRBuilder::get_loop_index(Stmt *loop, int index) { return insert(Stmt::make_typed(loop, index)); } +ConstStmt *IRBuilder::get_bool(bool value) { + return insert(Stmt::make_typed(TypedConstant( + TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::u1), + value))); +} + ConstStmt *IRBuilder::get_int32(int32 value) { return insert(Stmt::make_typed(TypedConstant( TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i32), diff --git a/taichi/ir/ir_builder.h b/taichi/ir/ir_builder.h index c585ed7be425e..5c3aab031577b 100644 --- a/taichi/ir/ir_builder.h +++ b/taichi/ir/ir_builder.h @@ -130,6 +130,7 @@ class IRBuilder { LoopIndexStmt *get_loop_index(Stmt *loop, int index = 0); // Constants. TODO: add more types + ConstStmt *get_bool(bool value); ConstStmt *get_int32(int32 value); ConstStmt *get_int64(int64 value); ConstStmt *get_uint32(uint32 value); diff --git a/tests/cpp/ir/ir_builder_test.cpp b/tests/cpp/ir/ir_builder_test.cpp index c76e4febdee64..9410cdf37a16f 100644 --- a/tests/cpp/ir/ir_builder_test.cpp +++ b/tests/cpp/ir/ir_builder_test.cpp @@ -4,12 +4,30 @@ #include "taichi/ir/statements.h" #include "tests/cpp/program/test_program.h" #include "tests/cpp/ir/ndarray_kernel.h" +#include "taichi/ir/transforms.h" #ifdef TI_WITH_VULKAN #include "taichi/rhi/vulkan/vulkan_loader.h" #endif namespace taichi::lang { +TEST(IRBuilder, Bool) { + IRBuilder builder; + auto *bool_true = builder.get_bool(true); + auto *bool_false = builder.get_bool(false); + builder.create_and(bool_true, bool_false); + auto block = builder.extract_ir(); + + std::string ir_string; + irpass::print(block->get_ir_root(), &ir_string); + EXPECT_STREQ(ir_string.c_str(), R"(kernel { + $0 = const true + $1 = const false + $2 = bit_and $0 $1 +} +)"); +} + TEST(IRBuilder, Basic) { IRBuilder builder; auto *lhs = builder.get_int32(40);