Skip to content

Commit 33266ae

Browse files
Fix FP4 packed access and GEMM K tile validation
Lower packed-FP4 vector loads and stores whose base index is odd or symbolic (i.e. not provably even) to per-lane nibble operations, avoiding silent miscompiles. Reject T.gemm K tiles that are not divisible by the MMA instruction K tile, so that FP4/A8W4 block_K tails cannot be silently skipped.
1 parent 7f254a9 commit 33266ae

2 files changed

Lines changed: 53 additions & 0 deletions

File tree

src/backend/cuda/codegen/codegen_cuda.cc

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2039,6 +2039,30 @@ std::string CodeGenTileLangCUDA::GetVecLoad(DataType t,
20392039
return os.str();
20402040
}
20412041

2042+
if (IsFp4PackedStorage(buffer_var, buffer->dtype) && t.is_float4_e2m1fn() &&
2043+
t.lanes() > 1) {
2044+
arith::Analyzer analyzer;
2045+
bool base_aligned = is_zero(analyzer.Simplify(truncmod(base, 2)));
2046+
if (!base_aligned) {
2047+
// Packed FP4 vector reinterpret is only nibble-aligned for even logical
2048+
// bases. Odd or symbolic bases need per-lane nibble selection.
2049+
std::string vid = GetVarID(buffer_var);
2050+
std::ostringstream os;
2051+
os << "make_fp4_e2_" << t.lanes() << "_t(";
2052+
for (int i = 0; i < t.lanes(); ++i) {
2053+
if (i != 0) {
2054+
os << ", ";
2055+
}
2056+
PrimExpr index = analyzer.Simplify(
2057+
base + IntImm(base.dtype(), static_cast<int64_t>(i)));
2058+
os << "tl_fp4_packed_load((fp4_e2_2_t*)" << vid << ", "
2059+
<< PrintExpr(index) << ")";
2060+
}
2061+
os << ")";
2062+
return os.str();
2063+
}
2064+
}
2065+
20422066
std::string scope;
20432067
if (alloc_storage_scope_.count(buffer_var)) {
20442068
scope = alloc_storage_scope_.at(buffer_var);
@@ -2133,6 +2157,30 @@ void CodeGenTileLangCUDA::PrintVecStore(const BufferNode *buffer, DataType t,
21332157
return;
21342158
}
21352159

2160+
if (IsFp4PackedStorage(buffer_var, buffer->dtype) && t.is_float4_e2m1fn() &&
2161+
t.lanes() > 1) {
2162+
arith::Analyzer analyzer;
2163+
bool base_aligned = is_zero(analyzer.Simplify(truncmod(base, 2)));
2164+
if (!base_aligned) {
2165+
std::ostringstream vec_type;
2166+
PrintType(t, vec_type);
2167+
std::string vid = GetVarID(buffer_var);
2168+
this->PrintIndent();
2169+
this->stream << "{ " << vec_type.str() << " __tl_fp4_vec = " << value
2170+
<< "; ";
2171+
for (int i = 0; i < t.lanes(); ++i) {
2172+
std::ostringstream elem;
2173+
PrintVecElemLoad("__tl_fp4_vec", t, i, elem);
2174+
PrimExpr index = analyzer.Simplify(
2175+
base + IntImm(base.dtype(), static_cast<int64_t>(i)));
2176+
this->stream << "tl_fp4_packed_store((fp4_e2_2_t*)" << vid << ", "
2177+
<< PrintExpr(index) << ", " << elem.str() << "); ";
2178+
}
2179+
this->stream << "}\n";
2180+
return;
2181+
}
2182+
}
2183+
21362184
std::string scope;
21372185
if (alloc_storage_scope_.count(buffer_var)) {
21382186
scope = alloc_storage_scope_.at(buffer_var);

tilelang/cuda/op/gemm/gemm_mma.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ def _make_mma_emitter(self, target: Target, thread_nums: int, thread_var: tir.Va
5757
chunk=self.chunk,
5858
thread_var=thread_var,
5959
)
60+
if self.chunk % emitter.micro_size_k != 0:
61+
raise ValueError(
62+
f"T.gemm K tile ({self.chunk}) must be divisible by MMA instruction K tile "
63+
f"({emitter.micro_size_k}) for A={self.A.dtype}, B={self.B.dtype}"
64+
)
6065
return emitter
6166

6267
def infer_layout(self, target: Target, thread_nums: int):

0 commit comments

Comments
 (0)