Skip to content

Commit 2e0fcf2

Browse files
committed
upd
1 parent 6ca25d8 commit 2e0fcf2

8 files changed

Lines changed: 97 additions & 203 deletions

File tree

src/backend/metal/op/copy.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include "op/utils.h"
1010
#include "target/utils.h"
1111

12-
#include <tvm/tir/builtin.h>
12+
#include <tvm/tirx/builtin.h>
1313

1414
#include <algorithm>
1515
#include <cmath>

src/backend/metal/op/fill.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
#include "transform/loop_partition.h"
1313
#include "transform/loop_vectorize.h"
1414

15-
#include <tvm/tir/builtin.h>
15+
#include <tvm/tirx/builtin.h>
1616

1717
namespace tvm {
1818
namespace tl {
1919

20-
using namespace tir;
20+
using namespace tirx;
2121

2222
namespace metal {
2323

src/backend/metal/op/gemm.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@
77

88
#include "target/utils.h"
99

10+
#include <tvm/runtime/logging.h>
11+
1012
#include <cmath>
1113
#include <limits>
1214
#include <utility>
1315

1416
namespace tvm {
1517
namespace tl {
1618

17-
using namespace tir;
19+
using namespace tirx;
1820

1921
namespace metal {
2022

src/backend/metal/op/utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ namespace tvm {
1212
namespace tl {
1313
namespace metal {
1414

15-
inline bool IsSIMDGroupBuffer(const tir::Buffer &buffer) {
15+
inline bool IsSIMDGroupBuffer(const Buffer &buffer) {
1616
return buffer.defined() && buffer.scope() == "metal.simdgroup";
1717
}
1818

19-
inline bool IsRegisterBuffer(const tir::Buffer &buffer) {
19+
inline bool IsRegisterBuffer(const Buffer &buffer) {
2020
return IsFragmentBuffer(buffer) || IsSIMDGroupBuffer(buffer);
2121
}
2222

src/runtime/metal/metal_module.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#ifndef TVM_RUNTIME_METAL_METAL_MODULE_H_
2+
#define TVM_RUNTIME_METAL_METAL_MODULE_H_
3+
4+
#include <tvm/ffi/container/map.h>
5+
#include <tvm/ffi/extra/module.h>
6+
#include <tvm/ffi/function.h>
7+
#include <tvm/ir/module.h>
8+
#include <tvm/runtime/logging.h>
9+
#include <tvm/target/codegen.h>
10+
11+
#include <string>
12+
#include <utility>
13+
14+
namespace tvm {
15+
namespace codegen {
16+
17+
// Builds a Metal ffi::Module by delegating to a globally registered factory.
//
// The global function "ffi.Module.create.metal" is looked up first; when it is
// not registered, "ffi.Module.create.metal_fallback" is tried instead. The
// factory that is found is called as (smap, fmt, fmap, {"metal": source}) and
// its result is cast to ffi::Module. Note that the factory receives fmt before
// fmap, which is not the order of this function's own parameters.
//
// smap:   String -> Bytes map, forwarded to the factory unchanged
//         (presumably kernel name -> kernel payload; confirm against callers).
// fmap:   String -> runtime::FunctionInfo map, forwarded unchanged.
// fmt:    format string, forwarded unchanged.
// source: stored under the key "metal" in a String -> String map that is
//         passed as the factory's last argument.
//
// Returns the module produced by the factory. When neither factory is
// registered, the function reports "Metal module factory not available."
// through LOG(FATAL) and does not return a value.
inline ffi::Module MetalModuleCreate(ffi::Map<ffi::String, ffi::Bytes> smap,
18+
ffi::Map<ffi::String, runtime::FunctionInfo> fmap,
19+
ffi::String fmt, ffi::String source) {
20+
// Preferred factory.
auto fcreate = ffi::Function::GetGlobal("ffi.Module.create.metal");
21+
if (fcreate.has_value()) {
22+
return (*fcreate)(smap, fmt, fmap,
23+
ffi::Map<ffi::String, ffi::String>{{"metal", source}})
24+
.cast<ffi::Module>();
25+
}
26+
// Secondary factory, called with exactly the same arguments as the preferred
// one.
auto fallback = ffi::Function::GetGlobal("ffi.Module.create.metal_fallback");
27+
if (fallback.has_value()) {
28+
return (*fallback)(smap, fmt, fmap,
29+
ffi::Map<ffi::String, ffi::String>{{"metal", source}})
30+
.cast<ffi::Module>();
31+
}
32+
LOG(FATAL) << "Metal module factory not available.";
33+
// Unreachable; LOG(FATAL) aborts.
// NOTE(review): __builtin_unreachable() is a GCC/Clang builtin, so this header
// presumably is only ever built with those compilers -- confirm.
34+
__builtin_unreachable();
35+
}
36+
37+
} // namespace codegen
38+
} // namespace tvm
39+
40+
#endif // TVM_RUNTIME_METAL_METAL_MODULE_H_

src/target/codegen_metal.cc

Lines changed: 48 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
#include "codegen_metal.h"
2424

2525
#include <tvm/ffi/reflection/registry.h>
26-
#include <tvm/tir/transform.h>
26+
#include <tvm/tirx/transform.h>
2727

2828
#include <algorithm>
2929
#include <sstream>
@@ -79,7 +79,7 @@ void CodeGenTileLangMetal::AddFunction(const GlobalVar &gvar,
7979

8080
// add to alloc buffer type.
8181
auto global_symbol = func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
82-
ICHECK(global_symbol.has_value())
82+
TVM_FFI_ICHECK(global_symbol.has_value())
8383
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
8484

8585
// Function header.
@@ -128,7 +128,7 @@ void CodeGenTileLangMetal::AddFunction(const GlobalVar &gvar,
128128
decl_stream << "struct " << arg_buf_type << " {\n";
129129
for (size_t i = num_buffer; i < func->params.size(); ++i) {
130130
Var v = func->params[i];
131-
ICHECK(!v.dtype().is_handle());
131+
TVM_FFI_ICHECK(!v.dtype().is_handle());
132132
std::string vid = AllocVarID(v.get());
133133
std::ostringstream vref;
134134
if (v.dtype().bits() == 32) {
@@ -152,11 +152,11 @@ void CodeGenTileLangMetal::AddFunction(const GlobalVar &gvar,
152152
decl_stream << "};\n\n";
153153
}
154154
// Setup the thread group info.
155-
ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
156-
ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
155+
TVM_FFI_ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
156+
TVM_FFI_ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
157157
int work_dim = 0;
158158
auto launch_params =
159-
func->GetAttr<ffi::Array<ffi::String>>(tir::attr::kKernelLaunchParams)
159+
func->GetAttr<ffi::Array<ffi::String>>(tirx::attr::kKernelLaunchParams)
160160
.value();
161161
for (const auto &tag : launch_params) {
162162
if (tag != runtime::launch_param::kUseDynamicSharedMemoryTag) {
@@ -186,7 +186,7 @@ void CodeGenTileLangMetal::AddFunction(const GlobalVar &gvar,
186186
}
187187

188188
void CodeGenTileLangMetal::BindThreadIndex(const IterVar &iv) {
189-
ICHECK(!var_idmap_.count(iv->var.get()));
189+
TVM_FFI_ICHECK(!var_idmap_.count(iv->var.get()));
190190
// if we only have threadIdx.x
191191
// metal will directly print as threadIdx
192192
std::string vname = iv->thread_tag;
@@ -201,7 +201,7 @@ void CodeGenTileLangMetal::PrintType(DataType t,
201201
std::ostream &os) { // NOLINT(*)
202202
int lanes = t.lanes();
203203
if (t.is_handle()) {
204-
ICHECK_EQ(lanes, 1) << "do not yet support vector types";
204+
TVM_FFI_ICHECK_EQ(lanes, 1) << "do not yet support vector types";
205205
os << "void*";
206206
return;
207207
}
@@ -326,40 +326,46 @@ void CodeGenTileLangMetal::PrintStorageScope(const std::string &scope,
326326
}
327327
}
328328

329-
void CodeGenTileLangMetal::VisitStmt_(const AllocateNode *op) {
330-
ICHECK(!is_zero(op->condition));
331-
std::string vid = AllocVarID(op->buffer_var.get());
329+
void CodeGenTileLangMetal::VisitStmt_(const AllocBufferNode *op) {
330+
TVM_FFI_ICHECK(op->buffer.defined());
331+
std::string vid = AllocVarID(op->buffer->data.get());
332332

333333
this->PrintIndent();
334-
size_t constant_size = op->ConstantAllocationSize();
335-
ICHECK_GT(constant_size, 0)
334+
// Compute constant_size (total element count, not a byte count) as the product of the constant buffer shape dimensions
335+
size_t constant_size = 1;
336+
for (const auto &dim : op->buffer->shape) {
337+
const IntImmNode *dim_imm = dim.as<IntImmNode>();
338+
TVM_FFI_ICHECK(dim_imm) << "Can only handle constant size stack allocation for now";
339+
constant_size *= dim_imm->value;
340+
}
341+
TVM_FFI_ICHECK_GT(constant_size, 0)
336342
<< "Can only handle constant size stack allocation for now";
337343

338-
auto scope = GetPtrStorageScope(op->buffer_var);
339-
alloc_storage_scope_[op->buffer_var.get()] = scope;
344+
DataType dtype = op->buffer->dtype;
345+
auto scope = GetPtrStorageScope(op->buffer->data);
346+
alloc_storage_scope_[op->buffer->data.get()] = scope;
340347
if (scope == "metal.simdgroup") {
341-
ICHECK(op->dtype == DataType::Float(16) ||
342-
op->dtype == DataType::Float(32) ||
343-
op->dtype == DataType::BFloat(16))
348+
TVM_FFI_ICHECK(dtype == DataType::Float(16) ||
349+
dtype == DataType::Float(32) ||
350+
dtype == DataType::BFloat(16))
344351
<< "Only float16, float32, and bfloat16 are supported, but got "
345-
<< op->dtype;
346-
ICHECK(constant_size % 64 == 0) << "Only 8x8 matrix is supported, but got "
352+
<< dtype;
353+
TVM_FFI_ICHECK(constant_size % 64 == 0) << "Only 8x8 matrix is supported, but got "
347354
<< constant_size << " bytes\n";
348355

349356
std::ostringstream dtype_os;
350-
PrintType(op->dtype, dtype_os);
357+
PrintType(dtype, dtype_os);
351358
std::string dtype_str = dtype_os.str();
352-
simdgroup_dtype_[op->buffer_var.get()] = dtype_str;
359+
simdgroup_dtype_[op->buffer->data.get()] = dtype_str;
353360
stream << "simdgroup_" << dtype_str << "8x8 " << vid << '['
354361
<< constant_size / 64 << "];\n";
355362
} else {
356363
PrintStorageScope(scope, stream);
357-
PrintType(op->dtype, stream);
364+
PrintType(dtype, stream);
358365
stream << ' ' << vid << '[' << constant_size << "];\n";
359366
}
360367

361-
RegisterHandleType(op->buffer_var.get(), op->dtype);
362-
this->PrintStmt(op->body);
368+
RegisterHandleType(op->buffer->data.get(), dtype);
363369
}
364370

365371
void CodeGenTileLangMetal::VisitExpr_(const SelectNode *op,
@@ -394,26 +400,26 @@ void CodeGenTileLangMetal::VisitExpr_(const BroadcastNode *op,
394400

395401
void CodeGenTileLangMetal::VisitExpr_(const CallNode *op,
396402
std::ostream &os) { // NOLINT(*)
397-
CHECK(!op->op.as<GlobalVarNode>())
403+
TVM_FFI_ICHECK(!op->op.as<GlobalVarNode>())
398404
<< "CodegenMetal does not support inter-function calls, "
399405
<< "but expression " << ffi::GetRef<Call>(op) << " calls PrimFunc "
400406
<< op->op;
401407
auto f_check_simdgroup_shape = [](PrimExpr col, PrimExpr row) {
402-
ICHECK(col->IsInstance<IntImmNode>() && row->IsInstance<IntImmNode>())
408+
TVM_FFI_ICHECK(col->IsInstance<IntImmNode>() && row->IsInstance<IntImmNode>())
403409
<< "Only constant shape is supported for simdgroup matrix, but got "
404410
<< col << "x" << row;
405411
int col_val = col.as<IntImmNode>()->value;
406412
int row_val = row.as<IntImmNode>()->value;
407-
ICHECK(col_val == 8 && row_val == 8)
413+
TVM_FFI_ICHECK(col_val == 8 && row_val == 8)
408414
<< "Only 8x8 matrix is supported, but got " << col_val << "x"
409415
<< row_val;
410416
};
411417
if (op->op.same_as(builtin::make_filled_simdgroup_matrix())) {
412-
ICHECK_EQ(op->args.size(), 5);
418+
TVM_FFI_ICHECK_EQ(op->args.size(), 5);
413419
Var var = Downcast<Var>(op->args[0]);
414420
// Get the data type of the simdgroup matrix
415421
auto it = simdgroup_dtype_.find(var.get());
416-
ICHECK(it != simdgroup_dtype_.end())
422+
TVM_FFI_ICHECK(it != simdgroup_dtype_.end())
417423
<< "Cannot find variable allocation for simdgroup: " << var;
418424
const std::string &dtype_str = it->second;
419425
f_check_simdgroup_shape(op->args[3], op->args[4]);
@@ -422,19 +428,19 @@ void CodeGenTileLangMetal::VisitExpr_(const CallNode *op,
422428
<< PrintExpr(op->args[3]) << ", " << PrintExpr(op->args[4]) << ">("
423429
<< PrintExpr(op->args[2]) << ")";
424430
} else if (op->op.same_as(builtin::simdgroup_load())) {
425-
ICHECK_EQ(op->args.size(), 7);
431+
TVM_FFI_ICHECK_EQ(op->args.size(), 7);
426432
f_check_simdgroup_shape(op->args[4], op->args[5]);
427433
os << "simdgroup_load(" << PrintExpr(op->args[0]) << "["
428434
<< PrintExpr(op->args[1]) << "], " << PrintExpr(op->args[2]) << ", "
429435
<< PrintExpr(op->args[3]) << ", 0, " << PrintExpr(op->args[6]) << ")";
430436
} else if (op->op.same_as(builtin::simdgroup_store())) {
431-
ICHECK_EQ(op->args.size(), 7);
437+
TVM_FFI_ICHECK_EQ(op->args.size(), 7);
432438
f_check_simdgroup_shape(op->args[4], op->args[5]);
433439
os << "simdgroup_store(" << PrintExpr(op->args[0]) << "["
434440
<< PrintExpr(op->args[1]) << "], " << PrintExpr(op->args[2]) << ", "
435441
<< PrintExpr(op->args[3]) << ", 0, " << PrintExpr(op->args[6]) << ")";
436442
} else if (op->op.same_as(builtin::simdgroup_multiply_accumulate())) {
437-
ICHECK_EQ(op->args.size(), 8);
443+
TVM_FFI_ICHECK_EQ(op->args.size(), 8);
438444
os << "simdgroup_multiply_accumulate(" //
439445
<< PrintExpr(op->args[0]) << "[" << PrintExpr(op->args[1]) << "], " //
440446
<< PrintExpr(op->args[2]) << "[" << PrintExpr(op->args[3]) << "], " //
@@ -475,28 +481,28 @@ void CodeGenTileLangMetal::VisitExpr_(const FloatImmNode *op,
475481

476482
ffi::Module BuildTileLangMetal(IRModule mod, Target target) {
477483
bool output_ssa = false;
478-
mod = tir::transform::PointerValueTypeRewrite()(std::move(mod));
484+
mod = tirx::transform::PointerValueTypeRewrite()(std::move(mod));
479485

480486
std::ostringstream source_maker;
481-
std::unordered_map<std::string, std::string> smap;
487+
ffi::Map<ffi::String, ffi::Bytes> smap;
482488
const auto fmetal_compile =
483489
tvm::ffi::Function::GetGlobal("tvm_callback_metal_compile");
484490
std::string fmt = fmetal_compile ? "metallib" : "metal";
485491

486492
for (auto kv : mod->functions) {
487-
ICHECK(kv.second->IsInstance<PrimFuncNode>())
493+
TVM_FFI_ICHECK(kv.second->IsInstance<tirx::PrimFuncNode>())
488494
<< "CodeGenTileLangMetal: Can only take PrimFunc";
489495
auto global_symbol =
490496
kv.second->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
491-
ICHECK(global_symbol.has_value());
497+
TVM_FFI_ICHECK(global_symbol.has_value());
492498
std::string func_name = global_symbol.value();
493499

494500
source_maker << "// Function: " << func_name << "\n";
495501
CodeGenTileLangMetal cg(target);
496502
cg.Init(output_ssa);
497-
auto f = Downcast<PrimFunc>(kv.second);
503+
auto f = Downcast<tirx::PrimFunc>(kv.second);
498504
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
499-
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
505+
TVM_FFI_ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
500506
<< "CodeGenTileLangMetal: expect calling_conv equals "
501507
"CallingConv::kDeviceKernelLaunch";
502508

@@ -507,10 +513,11 @@ ffi::Module BuildTileLangMetal(IRModule mod, Target target) {
507513
if (fmetal_compile) {
508514
fsource = (*fmetal_compile)(fsource, target).cast<std::string>();
509515
}
510-
smap[func_name] = fsource;
516+
smap.Set(func_name, ffi::Bytes(std::move(fsource)));
511517
}
512518

513-
return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str());
519+
return MetalModuleCreate(std::move(smap), ExtractFuncInfo(mod),
520+
ffi::String(fmt), ffi::String(source_maker.str()));
514521
}
515522

516523
TVM_FFI_STATIC_INIT_BLOCK() {

src/target/codegen_metal.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class CodeGenTileLangMetal final : public CodeGenC {
5353
void PrintVecElemStore(const std::string &vec, DataType t, int i,
5454
const std::string &value) final;
5555
// overload visitor
56-
void VisitStmt_(const AllocateNode *op) final; // NOLINT(*)
56+
void VisitStmt_(const AllocBufferNode *op) final; // NOLINT(*)
5757
void VisitExpr_(const SelectNode *op, std::ostream &os) final; // NOLINT(*)
5858
void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*)
5959
void VisitExpr_(const CallNode *op, std::ostream &os) final; // NOLINT(*)

0 commit comments

Comments
 (0)