2323#include " codegen_metal.h"
2424
2525#include < tvm/ffi/reflection/registry.h>
26- #include < tvm/tir /transform.h>
26+ #include < tvm/tirx /transform.h>
2727
2828#include < algorithm>
2929#include < sstream>
@@ -79,7 +79,7 @@ void CodeGenTileLangMetal::AddFunction(const GlobalVar &gvar,
7979
8080 // add to alloc buffer type.
8181 auto global_symbol = func->GetAttr <ffi::String>(tvm::attr::kGlobalSymbol );
82- ICHECK (global_symbol.has_value ())
82+ TVM_FFI_ICHECK (global_symbol.has_value ())
8383 << " CodeGenC: Expect PrimFunc to have the global_symbol attribute" ;
8484
8585 // Function header.
@@ -128,7 +128,7 @@ void CodeGenTileLangMetal::AddFunction(const GlobalVar &gvar,
128128 decl_stream << " struct " << arg_buf_type << " {\n " ;
129129 for (size_t i = num_buffer; i < func->params .size (); ++i) {
130130 Var v = func->params [i];
131- ICHECK (!v.dtype ().is_handle ());
131+ TVM_FFI_ICHECK (!v.dtype ().is_handle ());
132132 std::string vid = AllocVarID (v.get ());
133133 std::ostringstream vref;
134134 if (v.dtype ().bits () == 32 ) {
@@ -152,11 +152,11 @@ void CodeGenTileLangMetal::AddFunction(const GlobalVar &gvar,
152152 decl_stream << " };\n\n " ;
153153 }
154154 // Setup the thread group info.
155- ICHECK_EQ (name_supply_->FreshName (" threadIdx" ), " threadIdx" );
156- ICHECK_EQ (name_supply_->FreshName (" blockIdx" ), " blockIdx" );
155+ TVM_FFI_ICHECK_EQ (name_supply_->FreshName (" threadIdx" ), " threadIdx" );
156+ TVM_FFI_ICHECK_EQ (name_supply_->FreshName (" blockIdx" ), " blockIdx" );
157157 int work_dim = 0 ;
158158 auto launch_params =
159- func->GetAttr <ffi::Array<ffi::String>>(tir ::attr::kKernelLaunchParams )
159+ func->GetAttr <ffi::Array<ffi::String>>(tirx ::attr::kKernelLaunchParams )
160160 .value ();
161161 for (const auto &tag : launch_params) {
162162 if (tag != runtime::launch_param::kUseDynamicSharedMemoryTag ) {
@@ -186,7 +186,7 @@ void CodeGenTileLangMetal::AddFunction(const GlobalVar &gvar,
186186}
187187
188188void CodeGenTileLangMetal::BindThreadIndex (const IterVar &iv) {
189- ICHECK (!var_idmap_.count (iv->var .get ()));
189+ TVM_FFI_ICHECK (!var_idmap_.count (iv->var .get ()));
190190 // if we only have threadIdx.x
191191 // metal will directly print as threadIdx
192192 std::string vname = iv->thread_tag ;
@@ -201,7 +201,7 @@ void CodeGenTileLangMetal::PrintType(DataType t,
201201 std::ostream &os) { // NOLINT(*)
202202 int lanes = t.lanes ();
203203 if (t.is_handle ()) {
204- ICHECK_EQ (lanes, 1 ) << " do not yet support vector types" ;
204+ TVM_FFI_ICHECK_EQ (lanes, 1 ) << " do not yet support vector types" ;
205205 os << " void*" ;
206206 return ;
207207 }
@@ -326,40 +326,46 @@ void CodeGenTileLangMetal::PrintStorageScope(const std::string &scope,
326326 }
327327}
328328
329- void CodeGenTileLangMetal::VisitStmt_ (const AllocateNode *op) {
330- ICHECK (! is_zero ( op->condition ));
331- std::string vid = AllocVarID (op->buffer_var .get ());
329+ void CodeGenTileLangMetal::VisitStmt_ (const AllocBufferNode *op) {
330+ TVM_FFI_ICHECK ( op->buffer . defined ( ));
331+ std::string vid = AllocVarID (op->buffer -> data .get ());
332332
333333 this ->PrintIndent ();
334- size_t constant_size = op->ConstantAllocationSize ();
335- ICHECK_GT (constant_size, 0 )
334+ // Compute constant_size from buffer shape
335+ size_t constant_size = 1 ;
336+ for (const auto &dim : op->buffer ->shape ) {
337+ const IntImmNode *dim_imm = dim.as <IntImmNode>();
338+ TVM_FFI_ICHECK (dim_imm) << " Can only handle constant size stack allocation for now" ;
339+ constant_size *= dim_imm->value ;
340+ }
341+ TVM_FFI_ICHECK_GT (constant_size, 0 )
336342 << " Can only handle constant size stack allocation for now" ;
337343
338- auto scope = GetPtrStorageScope (op->buffer_var );
339- alloc_storage_scope_[op->buffer_var .get ()] = scope;
344+ DataType dtype = op->buffer ->dtype ;
345+ auto scope = GetPtrStorageScope (op->buffer ->data );
346+ alloc_storage_scope_[op->buffer ->data .get ()] = scope;
340347 if (scope == " metal.simdgroup" ) {
341- ICHECK (op-> dtype == DataType::Float (16 ) ||
342- op-> dtype == DataType::Float (32 ) ||
343- op-> dtype == DataType::BFloat (16 ))
348+ TVM_FFI_ICHECK ( dtype == DataType::Float (16 ) ||
349+ dtype == DataType::Float (32 ) ||
350+ dtype == DataType::BFloat (16 ))
344351 << " Only float16, float32, and bfloat16 are supported, but got "
345- << op-> dtype ;
346- ICHECK (constant_size % 64 == 0 ) << " Only 8x8 matrix is supported, but got "
352+ << dtype;
353+ TVM_FFI_ICHECK (constant_size % 64 == 0 ) << " Only 8x8 matrix is supported, but got "
347354 << constant_size << " bytes\n " ;
348355
349356 std::ostringstream dtype_os;
350- PrintType (op-> dtype , dtype_os);
357+ PrintType (dtype, dtype_os);
351358 std::string dtype_str = dtype_os.str ();
352- simdgroup_dtype_[op->buffer_var .get ()] = dtype_str;
359+ simdgroup_dtype_[op->buffer -> data .get ()] = dtype_str;
353360 stream << " simdgroup_" << dtype_str << " 8x8 " << vid << ' ['
354361 << constant_size / 64 << " ];\n " ;
355362 } else {
356363 PrintStorageScope (scope, stream);
357- PrintType (op-> dtype , stream);
364+ PrintType (dtype, stream);
358365 stream << ' ' << vid << ' [' << constant_size << " ];\n " ;
359366 }
360367
361- RegisterHandleType (op->buffer_var .get (), op->dtype );
362- this ->PrintStmt (op->body );
368+ RegisterHandleType (op->buffer ->data .get (), dtype);
363369}
364370
365371void CodeGenTileLangMetal::VisitExpr_ (const SelectNode *op,
@@ -394,26 +400,26 @@ void CodeGenTileLangMetal::VisitExpr_(const BroadcastNode *op,
394400
395401void CodeGenTileLangMetal::VisitExpr_ (const CallNode *op,
396402 std::ostream &os) { // NOLINT(*)
397- CHECK (!op->op .as <GlobalVarNode>())
403+ TVM_FFI_ICHECK (!op->op .as <GlobalVarNode>())
398404 << " CodegenMetal does not support inter-function calls, "
399405 << " but expression " << ffi::GetRef<Call>(op) << " calls PrimFunc "
400406 << op->op ;
401407 auto f_check_simdgroup_shape = [](PrimExpr col, PrimExpr row) {
402- ICHECK (col->IsInstance <IntImmNode>() && row->IsInstance <IntImmNode>())
408+ TVM_FFI_ICHECK (col->IsInstance <IntImmNode>() && row->IsInstance <IntImmNode>())
403409 << " Only constant shape is supported for simdgroup matrix, but got "
404410 << col << " x" << row;
405411 int col_val = col.as <IntImmNode>()->value ;
406412 int row_val = row.as <IntImmNode>()->value ;
407- ICHECK (col_val == 8 && row_val == 8 )
413+ TVM_FFI_ICHECK (col_val == 8 && row_val == 8 )
408414 << " Only 8x8 matrix is supported, but got " << col_val << " x"
409415 << row_val;
410416 };
411417 if (op->op .same_as (builtin::make_filled_simdgroup_matrix ())) {
412- ICHECK_EQ (op->args .size (), 5 );
418+ TVM_FFI_ICHECK_EQ (op->args .size (), 5 );
413419 Var var = Downcast<Var>(op->args [0 ]);
414420 // Get the data type of the simdgroup matrix
415421 auto it = simdgroup_dtype_.find (var.get ());
416- ICHECK (it != simdgroup_dtype_.end ())
422+ TVM_FFI_ICHECK (it != simdgroup_dtype_.end ())
417423 << " Cannot find variable allocation for simdgroup: " << var;
418424 const std::string &dtype_str = it->second ;
419425 f_check_simdgroup_shape (op->args [3 ], op->args [4 ]);
@@ -422,19 +428,19 @@ void CodeGenTileLangMetal::VisitExpr_(const CallNode *op,
422428 << PrintExpr (op->args [3 ]) << " , " << PrintExpr (op->args [4 ]) << " >("
423429 << PrintExpr (op->args [2 ]) << " )" ;
424430 } else if (op->op .same_as (builtin::simdgroup_load ())) {
425- ICHECK_EQ (op->args .size (), 7 );
431+ TVM_FFI_ICHECK_EQ (op->args .size (), 7 );
426432 f_check_simdgroup_shape (op->args [4 ], op->args [5 ]);
427433 os << " simdgroup_load(" << PrintExpr (op->args [0 ]) << " ["
428434 << PrintExpr (op->args [1 ]) << " ], " << PrintExpr (op->args [2 ]) << " , "
429435 << PrintExpr (op->args [3 ]) << " , 0, " << PrintExpr (op->args [6 ]) << " )" ;
430436 } else if (op->op .same_as (builtin::simdgroup_store ())) {
431- ICHECK_EQ (op->args .size (), 7 );
437+ TVM_FFI_ICHECK_EQ (op->args .size (), 7 );
432438 f_check_simdgroup_shape (op->args [4 ], op->args [5 ]);
433439 os << " simdgroup_store(" << PrintExpr (op->args [0 ]) << " ["
434440 << PrintExpr (op->args [1 ]) << " ], " << PrintExpr (op->args [2 ]) << " , "
435441 << PrintExpr (op->args [3 ]) << " , 0, " << PrintExpr (op->args [6 ]) << " )" ;
436442 } else if (op->op .same_as (builtin::simdgroup_multiply_accumulate ())) {
437- ICHECK_EQ (op->args .size (), 8 );
443+ TVM_FFI_ICHECK_EQ (op->args .size (), 8 );
438444 os << " simdgroup_multiply_accumulate(" //
439445 << PrintExpr (op->args [0 ]) << " [" << PrintExpr (op->args [1 ]) << " ], " //
440446 << PrintExpr (op->args [2 ]) << " [" << PrintExpr (op->args [3 ]) << " ], " //
@@ -475,28 +481,28 @@ void CodeGenTileLangMetal::VisitExpr_(const FloatImmNode *op,
475481
476482ffi::Module BuildTileLangMetal (IRModule mod, Target target) {
477483 bool output_ssa = false ;
478- mod = tir ::transform::PointerValueTypeRewrite ()(std::move (mod));
484+ mod = tirx ::transform::PointerValueTypeRewrite ()(std::move (mod));
479485
480486 std::ostringstream source_maker;
481- std::unordered_map<std::string, std::string > smap;
487+ ffi::Map<ffi::String, ffi::Bytes > smap;
482488 const auto fmetal_compile =
483489 tvm::ffi::Function::GetGlobal (" tvm_callback_metal_compile" );
484490 std::string fmt = fmetal_compile ? " metallib" : " metal" ;
485491
486492 for (auto kv : mod->functions ) {
487- ICHECK (kv.second ->IsInstance <PrimFuncNode>())
493+ TVM_FFI_ICHECK (kv.second ->IsInstance <tirx:: PrimFuncNode>())
488494 << " CodeGenTileLangMetal: Can only take PrimFunc" ;
489495 auto global_symbol =
490496 kv.second ->GetAttr <ffi::String>(tvm::attr::kGlobalSymbol );
491- ICHECK (global_symbol.has_value ());
497+ TVM_FFI_ICHECK (global_symbol.has_value ());
492498 std::string func_name = global_symbol.value ();
493499
494500 source_maker << " // Function: " << func_name << " \n " ;
495501 CodeGenTileLangMetal cg (target);
496502 cg.Init (output_ssa);
497- auto f = Downcast<PrimFunc>(kv.second );
503+ auto f = Downcast<tirx:: PrimFunc>(kv.second );
498504 auto calling_conv = f->GetAttr <Integer>(tvm::attr::kCallingConv );
499- ICHECK (calling_conv == CallingConv::kDeviceKernelLaunch )
505+ TVM_FFI_ICHECK (calling_conv == CallingConv::kDeviceKernelLaunch )
500506 << " CodeGenTileLangMetal: expect calling_conv equals "
501507 " CallingConv::kDeviceKernelLaunch" ;
502508
@@ -507,10 +513,11 @@ ffi::Module BuildTileLangMetal(IRModule mod, Target target) {
507513 if (fmetal_compile) {
508514 fsource = (*fmetal_compile)(fsource, target).cast <std::string>();
509515 }
510- smap[ func_name] = fsource;
516+ smap. Set ( func_name, ffi::Bytes ( std::move ( fsource))) ;
511517 }
512518
513- return MetalModuleCreate (smap, ExtractFuncInfo (mod), fmt, source_maker.str ());
519+ return MetalModuleCreate (std::move (smap), ExtractFuncInfo (mod),
520+ ffi::String (fmt), ffi::String (source_maker.str ()));
514521}
515522
516523TVM_FFI_STATIC_INIT_BLOCK () {
0 commit comments