Skip to content

Commit 400990c

Browse files
committed
lint
1 parent 2e0fcf2 commit 400990c

6 files changed

Lines changed: 30 additions & 30 deletions

File tree

src/backend/metal/op/copy.cc

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ Stmt LowerSIMDGroupCopy(const CopyNode &op, const LowerArgs &T,
5353

5454
int warp_size = TargetGetWarpSize(T.target);
5555
const auto *block_size_imm = T.thread_bounds->extent.as<IntImmNode>();
56-
ICHECK(block_size_imm)
57-
<< "simdgroup copy requires constant thread bounds";
56+
ICHECK(block_size_imm) << "simdgroup copy requires constant thread bounds";
5857
int block_size = block_size_imm->value;
5958
int num_warps = block_size / warp_size;
6059
PrimExpr warp_id = FloorDiv(T.thread_var, warp_size);
@@ -112,12 +111,12 @@ Stmt LowerSIMDGroupCopy(const CopyNode &op, const LowerArgs &T,
112111
dst_col_base + warp_n * (warp_col_tiles * kNPerWarp) + j * kNPerWarp;
113112
PrimExpr ptr = Call(DataType::Handle(), builtin::address_of(),
114113
{BufferLoad(op.dst, {row, col})});
115-
stmts.push_back(Evaluate(Call(
116-
DataType::Handle(), builtin::simdgroup_store(),
117-
{op.src->data, IntImm(DataType::Int(32), tile_idx), ptr, dst_stride,
118-
IntImm(DataType::Int(32), kMPerWarp),
119-
IntImm(DataType::Int(32), kNPerWarp),
120-
Cast(DataType::Bool(), IntImm(DataType::Int(32), 0))})));
114+
stmts.push_back(Evaluate(
115+
Call(DataType::Handle(), builtin::simdgroup_store(),
116+
{op.src->data, IntImm(DataType::Int(32), tile_idx), ptr,
117+
dst_stride, IntImm(DataType::Int(32), kMPerWarp),
118+
IntImm(DataType::Int(32), kNPerWarp),
119+
Cast(DataType::Bool(), IntImm(DataType::Int(32), 0))})));
121120
}
122121
}
123122
if (stmts.size() == 1) {

src/backend/metal/op/gemm.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,8 @@ namespace {
2424

2525
constexpr const char *kMetalSIMDGroup = "metal.simdgroup";
2626

27-
std::pair<int, int>
28-
ComputeMetalWarpPartition(const GemmWarpPolicyNode &policy, int M, int N,
29-
int num_warps) {
27+
std::pair<int, int> ComputeMetalWarpPartition(const GemmWarpPolicyNode &policy,
28+
int M, int N, int num_warps) {
3029
int m_warp = 1, n_warp = 1;
3130
constexpr int kMPerWarp = 8;
3231
constexpr int kNPerWarp = 8;

src/runtime/metal/metal_module.h

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,28 @@
1414
namespace tvm {
1515
namespace codegen {
1616

17-
inline ffi::Module MetalModuleCreate(ffi::Map<ffi::String, ffi::Bytes> smap,
18-
ffi::Map<ffi::String, runtime::FunctionInfo> fmap,
19-
ffi::String fmt, ffi::String source) {
17+
inline ffi::Module
18+
MetalModuleCreate(ffi::Map<ffi::String, ffi::Bytes> smap,
19+
ffi::Map<ffi::String, runtime::FunctionInfo> fmap,
20+
ffi::String fmt, ffi::String source) {
2021
auto fcreate = ffi::Function::GetGlobal("ffi.Module.create.metal");
2122
if (fcreate.has_value()) {
2223
return (*fcreate)(smap, fmt, fmap,
23-
ffi::Map<ffi::String, ffi::String>{{"metal", source}})
24+
ffi::Map<ffi::String, ffi::String>{{"metal", source}})
2425
.cast<ffi::Module>();
2526
}
2627
auto fallback = ffi::Function::GetGlobal("ffi.Module.create.metal_fallback");
2728
if (fallback.has_value()) {
2829
return (*fallback)(smap, fmt, fmap,
29-
ffi::Map<ffi::String, ffi::String>{{"metal", source}})
30+
ffi::Map<ffi::String, ffi::String>{{"metal", source}})
3031
.cast<ffi::Module>();
3132
}
3233
LOG(FATAL) << "Metal module factory not available.";
3334
// Unreachable; LOG(FATAL) aborts.
3435
__builtin_unreachable();
3536
}
3637

37-
} // namespace codegen
38-
} // namespace tvm
38+
} // namespace codegen
39+
} // namespace tvm
3940

40-
#endif // TVM_RUNTIME_METAL_METAL_MODULE_H_
41+
#endif // TVM_RUNTIME_METAL_METAL_MODULE_H_

src/target/codegen_metal.cc

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,8 @@ void CodeGenTileLangMetal::VisitStmt_(const AllocBufferNode *op) {
335335
size_t constant_size = 1;
336336
for (const auto &dim : op->buffer->shape) {
337337
const IntImmNode *dim_imm = dim.as<IntImmNode>();
338-
TVM_FFI_ICHECK(dim_imm) << "Can only handle constant size stack allocation for now";
338+
TVM_FFI_ICHECK(dim_imm)
339+
<< "Can only handle constant size stack allocation for now";
339340
constant_size *= dim_imm->value;
340341
}
341342
TVM_FFI_ICHECK_GT(constant_size, 0)
@@ -346,12 +347,13 @@ void CodeGenTileLangMetal::VisitStmt_(const AllocBufferNode *op) {
346347
alloc_storage_scope_[op->buffer->data.get()] = scope;
347348
if (scope == "metal.simdgroup") {
348349
TVM_FFI_ICHECK(dtype == DataType::Float(16) ||
349-
dtype == DataType::Float(32) ||
350-
dtype == DataType::BFloat(16))
350+
dtype == DataType::Float(32) ||
351+
dtype == DataType::BFloat(16))
351352
<< "Only float16, float32, and bfloat16 are supported, but got "
352353
<< dtype;
353-
TVM_FFI_ICHECK(constant_size % 64 == 0) << "Only 8x8 matrix is supported, but got "
354-
<< constant_size << " bytes\n";
354+
TVM_FFI_ICHECK(constant_size % 64 == 0)
355+
<< "Only 8x8 matrix is supported, but got " << constant_size
356+
<< " bytes\n";
355357

356358
std::ostringstream dtype_os;
357359
PrintType(dtype, dtype_os);
@@ -405,7 +407,8 @@ void CodeGenTileLangMetal::VisitExpr_(const CallNode *op,
405407
<< "but expression " << ffi::GetRef<Call>(op) << " calls PrimFunc "
406408
<< op->op;
407409
auto f_check_simdgroup_shape = [](PrimExpr col, PrimExpr row) {
408-
TVM_FFI_ICHECK(col->IsInstance<IntImmNode>() && row->IsInstance<IntImmNode>())
410+
TVM_FFI_ICHECK(col->IsInstance<IntImmNode>() &&
411+
row->IsInstance<IntImmNode>())
409412
<< "Only constant shape is supported for simdgroup matrix, but got "
410413
<< col << "x" << row;
411414
int col_val = col.as<IntImmNode>()->value;
@@ -517,7 +520,7 @@ ffi::Module BuildTileLangMetal(IRModule mod, Target target) {
517520
}
518521

519522
return MetalModuleCreate(std::move(smap), ExtractFuncInfo(mod),
520-
ffi::String(fmt), ffi::String(source_maker.str()));
523+
ffi::String(fmt), ffi::String(source_maker.str()));
521524
}
522525

523526
TVM_FFI_STATIC_INIT_BLOCK() {

src/target/codegen_metal.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class CodeGenTileLangMetal final : public CodeGenC {
5353
void PrintVecElemStore(const std::string &vec, DataType t, int i,
5454
const std::string &value) final;
5555
// overload visitor
56-
void VisitStmt_(const AllocBufferNode *op) final; // NOLINT(*)
56+
void VisitStmt_(const AllocBufferNode *op) final; // NOLINT(*)
5757
void VisitExpr_(const SelectNode *op, std::ostream &os) final; // NOLINT(*)
5858
void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*)
5959
void VisitExpr_(const CallNode *op, std::ostream &os) final; // NOLINT(*)

tilelang/tileop/gemm/gemm_metal.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,7 @@ def lower(
2424
self, layout_map: dict, target: Target, thread_bounds: Range, thread_var: tir.Var, mbar_phase_expr: tir.PrimExpr | None = None
2525
):
2626
thread_nums = thread_bounds.extent
27-
m_warp, n_warp = self.policy.compute_warp_partition(
28-
self.M, self.N, thread_nums, target, GEMM_INST_METAL
29-
)
27+
m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, GEMM_INST_METAL)
3028
warp_row_tiles = int(self.M // m_warp)
3129
warp_col_tiles = int(self.N // n_warp)
3230

0 commit comments

Comments
 (0)