@@ -2039,6 +2039,30 @@ std::string CodeGenTileLangCUDA::GetVecLoad(DataType t,
20392039 return os.str ();
20402040 }
20412041
2042+ if (IsFp4PackedStorage (buffer_var, buffer->dtype ) && t.is_float4_e2m1fn () &&
2043+ t.lanes () > 1 ) {
2044+ arith::Analyzer analyzer;
2045+ bool base_aligned = is_zero (analyzer.Simplify (truncmod (base, 2 )));
2046+ if (!base_aligned) {
2047+ // Packed FP4 vector reinterpret is only nibble-aligned for even logical
2048+ // bases. Odd or symbolic bases need per-lane nibble selection.
2049+ std::string vid = GetVarID (buffer_var);
2050+ std::ostringstream os;
2051+ os << " make_fp4_e2_" << t.lanes () << " _t(" ;
2052+ for (int i = 0 ; i < t.lanes (); ++i) {
2053+ if (i != 0 ) {
2054+ os << " , " ;
2055+ }
2056+ PrimExpr index = analyzer.Simplify (
2057+ base + IntImm (base.dtype (), static_cast <int64_t >(i)));
2058+ os << " tl_fp4_packed_load((fp4_e2_2_t*)" << vid << " , "
2059+ << PrintExpr (index) << " )" ;
2060+ }
2061+ os << " )" ;
2062+ return os.str ();
2063+ }
2064+ }
2065+
20422066 std::string scope;
20432067 if (alloc_storage_scope_.count (buffer_var)) {
20442068 scope = alloc_storage_scope_.at (buffer_var);
@@ -2133,6 +2157,30 @@ void CodeGenTileLangCUDA::PrintVecStore(const BufferNode *buffer, DataType t,
21332157 return ;
21342158 }
21352159
2160+ if (IsFp4PackedStorage (buffer_var, buffer->dtype ) && t.is_float4_e2m1fn () &&
2161+ t.lanes () > 1 ) {
2162+ arith::Analyzer analyzer;
2163+ bool base_aligned = is_zero (analyzer.Simplify (truncmod (base, 2 )));
2164+ if (!base_aligned) {
2165+ std::ostringstream vec_type;
2166+ PrintType (t, vec_type);
2167+ std::string vid = GetVarID (buffer_var);
2168+ this ->PrintIndent ();
2169+ this ->stream << " { " << vec_type.str () << " __tl_fp4_vec = " << value
2170+ << " ; " ;
2171+ for (int i = 0 ; i < t.lanes (); ++i) {
2172+ std::ostringstream elem;
2173+ PrintVecElemLoad (" __tl_fp4_vec" , t, i, elem);
2174+ PrimExpr index = analyzer.Simplify (
2175+ base + IntImm (base.dtype (), static_cast <int64_t >(i)));
2176+ this ->stream << " tl_fp4_packed_store((fp4_e2_2_t*)" << vid << " , "
2177+ << PrintExpr (index) << " , " << elem.str () << " ); " ;
2178+ }
2179+ this ->stream << " }\n " ;
2180+ return ;
2181+ }
2182+ }
2183+
21362184 std::string scope;
21372185 if (alloc_storage_scope_.count (buffer_var)) {
21382186 scope = alloc_storage_scope_.at (buffer_var);
0 commit comments