Skip to content

Commit 029cfd0

Browse files
Krzysztof Rymskicopybara-github
authored andcommitted
Int8 + microscaling support for kv cache formats.
Right now multiplication is done by converting to corresponding float format. Can yield up to 2x improvements for membw constrained shapes PiperOrigin-RevId: 880748493
1 parent d2806fb commit 029cfd0

11 files changed

Lines changed: 565 additions & 16 deletions

compression/compress-inl.h

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,142 @@ struct CompressTraits<SfpStream> {
444444
}
445445
};
446446

447+
template <>
448+
struct CompressTraits<int8_t> {
449+
using Packed = int8_t;
450+
451+
static size_t CompressBound(size_t num) { return num * sizeof(Packed); }
452+
453+
template <class DF, HWY_IF_F32_D(DF)>
454+
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw,
455+
size_t num, CompressPerThread& /*tls*/,
456+
const PackedSpan<Packed>& packed,
457+
const size_t packed_ofs) {
458+
const hn::Repartition<int32_t, DF> di32;
459+
const hn::Repartition<int16_t, DF> di16;
460+
const hn::Repartition<int8_t, DF> di8;
461+
const auto di16_16 = hn::Half<decltype(di16)>();
462+
const auto di8_16 = hn::Half<decltype(di8)>();
463+
using VF = hn::Vec<DF>;
464+
const size_t NF = hn::Lanes(df);
465+
466+
size_t i = 0;
467+
if (num >= 2 * NF) {
468+
for (; i <= num - 2 * NF; i += 2 * NF) {
469+
const VF v0 = hn::LoadU(df, raw + i);
470+
const VF v1 = hn::LoadU(df, raw + i + NF);
471+
const auto vi32_0 = hn::NearestInt(v0);
472+
const auto vi32_1 = hn::NearestInt(v1);
473+
const auto vi16 = hn::OrderedDemote2To(di16, vi32_0, vi32_1);
474+
const auto vi8 = hn::OrderedDemote2To(
475+
di8_16, hn::UpperHalf(di16_16, vi16), hn::LowerHalf(di16_16, vi16));
476+
hn::StoreU(vi8, di8_16, packed.ptr + packed_ofs + i);
477+
}
478+
}
479+
const size_t remaining = num - i;
480+
if (remaining > 0) {
481+
HWY_ALIGN float buf[2 * NF];
482+
hwy::ZeroBytes(buf, 2 * NF * sizeof(float));
483+
for (size_t j = 0; j < remaining; ++j) buf[j] = raw[i + j];
484+
const VF v0 = hn::LoadU(df, buf);
485+
const VF v1 = hn::LoadU(df, buf + NF);
486+
const auto vi32_0 = hn::NearestInt(v0);
487+
const auto vi32_1 = hn::NearestInt(v1);
488+
const auto vi16 = hn::OrderedDemote2To(di16, vi32_0, vi32_1);
489+
const auto vi8 = hn::OrderedDemote2To(
490+
di8_16, hn::UpperHalf(di16_16, vi16), hn::LowerHalf(di16_16, vi16));
491+
hn::StoreN(vi8, di8_16, packed.ptr + packed_ofs + i, remaining);
492+
}
493+
}
494+
495+
static float ToFloatSlow(const Packed x) { return static_cast<float>(x); }
496+
497+
template <class DF, HWY_IF_F32_D(DF)>
498+
static HWY_INLINE void Load2(DF df, const PackedSpan<const Packed>& packed,
499+
const size_t packed_ofs, hn::Vec<DF>& raw0,
500+
hn::Vec<DF>& raw1) {
501+
const hn::Repartition<int32_t, DF> di32;
502+
const hn::Repartition<int16_t, DF> di16;
503+
const hn::Rebind<int8_t, decltype(di16)> di8_half;
504+
505+
const auto vec_i8 = hn::LoadU(di8_half, packed.ptr + packed_ofs);
506+
const auto vec_i16 = hn::PromoteTo(di16, vec_i8);
507+
const auto vec_i32_0 = hn::PromoteLowerTo(di32, vec_i16);
508+
const auto vec_i32_1 = hn::PromoteUpperTo(di32, vec_i16);
509+
510+
raw0 = hn::ConvertTo(df, vec_i32_0);
511+
raw1 = hn::ConvertTo(df, vec_i32_1);
512+
}
513+
514+
template <class DBF, HWY_IF_BF16_D(DBF)>
515+
static HWY_INLINE void Load2(DBF dbf, const PackedSpan<const Packed>& packed,
516+
const size_t packed_ofs, hn::Vec<DBF>& raw0,
517+
hn::Vec<DBF>& raw1) {
518+
const hn::Repartition<float, DBF> df;
519+
using VF = hn::Vec<decltype(df)>;
520+
const size_t NF = hn::Lanes(df);
521+
522+
VF f0, f1, f2, f3;
523+
Load2(df, packed, packed_ofs, f0, f1);
524+
Load2(df, packed, packed_ofs + 2 * NF, f2, f3);
525+
526+
raw0 = hn::OrderedDemote2To(dbf, f0, f1);
527+
raw1 = hn::OrderedDemote2To(dbf, f2, f3);
528+
}
529+
530+
template <class DF, HWY_IF_F32_D(DF)>
531+
static HWY_INLINE void DecompressAndZeroPad(
532+
DF df, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
533+
float* HWY_RESTRICT raw, size_t num) {
534+
using VF = hn::Vec<decltype(df)>;
535+
const size_t NF = hn::Lanes(df);
536+
537+
size_t i = 0;
538+
if (num >= 2 * NF) {
539+
for (; i <= num - 2 * NF; i += 2 * NF) {
540+
VF raw0, raw1;
541+
Load2(df, packed, packed_ofs + i, raw0, raw1);
542+
hn::StoreU(raw0, df, raw + i);
543+
hn::StoreU(raw1, df, raw + i + NF);
544+
}
545+
}
546+
547+
const size_t remaining = num - i;
548+
if (HWY_UNLIKELY(remaining != 0)) {
549+
for (size_t j = 0; j < remaining; ++j) {
550+
raw[i + j] = static_cast<float>(packed.ptr[packed_ofs + i + j]);
551+
}
552+
}
553+
}
554+
555+
template <class DBF, HWY_IF_BF16_D(DBF)>
556+
static HWY_INLINE void DecompressAndZeroPad(
557+
DBF dbf, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
558+
BF16* HWY_RESTRICT raw, size_t num) {
559+
const hn::Repartition<float, DBF> df;
560+
const size_t NF = hn::Lanes(df);
561+
size_t i = 0;
562+
const size_t NBF = hn::Lanes(dbf);
563+
if (num >= NBF) {
564+
for (; i <= num - NBF; i += NBF) {
565+
hn::Vec<decltype(df)> f0, f1;
566+
Load2(df, packed, packed_ofs + i, f0, f1);
567+
auto vbf = hn::OrderedDemote2To(dbf, f0, f1);
568+
hn::StoreU(vbf, dbf, raw + i);
569+
}
570+
}
571+
const size_t remaining = num - i;
572+
if (remaining > 0) {
573+
HWY_ALIGN float buf[2 * hn::MaxLanes(df)];
574+
DecompressAndZeroPad(df, packed, packed_ofs + i, buf, remaining);
575+
auto f0 = hn::LoadU(df, buf);
576+
auto f1 = hn::LoadU(df, buf + NF);
577+
auto vbf = hn::OrderedDemote2To(dbf, f0, f1);
578+
hn::StoreN(vbf, dbf, raw + i, remaining);
579+
}
580+
}
581+
};
582+
447583
// Integer quantization.
448584
template <>
449585
struct CompressTraits<I8Stream> {

compression/compress_test.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ struct TestDecompress2 {
126126
HWY_ASSERT(stats.L1().Max() <= 0.08f);
127127
HWY_ASSERT(IsInside(0.02, 0.05, stats.WeightedAverageL1()));
128128
HWY_ASSERT(IsInside(18.0, 62.0, stats.GeomeanValueDivL1()));
129+
} else if constexpr (hwy::IsSame<Packed, int8_t>()) {
130+
HWY_ASSERT(stats.L1().Max() <= 0.6f);
129131
} else {
130132
HWY_ABORT("Unhandled type requested by ForeachPackedAndRawType");
131133
}
@@ -200,6 +202,8 @@ struct TestShortLengths {
200202
HWY_ASSERT(stats.L1().Max() <= 0.14f);
201203
HWY_ASSERT(IsInside(7E-5, 0.06, stats.WeightedAverageL1()));
202204
HWY_ASSERT(IsInside(11.0, 180.0, stats.GeomeanValueDivL1()));
205+
} else if constexpr (hwy::IsSame<Packed, int8_t>()) {
206+
HWY_ASSERT(stats.L1().Max() <= 0.6f);
203207
} else {
204208
HWY_ABORT("Unhandled type requested by ForeachPackedAndRawType");
205209
}

compression/types.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,11 @@ constexpr bool IsF32() {
192192
return hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
193193
}
194194

195+
template <typename Packed>
196+
constexpr bool IsInt8() {
197+
return hwy::IsSame<hwy::RemoveCvRef<Packed>, int8_t>();
198+
}
199+
195200
template <typename Packed>
196201
constexpr bool IsBF16() {
197202
return hwy::IsSame<hwy::RemoveCvRef<Packed>, BF16>();
@@ -231,12 +236,13 @@ enum class Type {
231236
kI8,
232237
kU16,
233238
kU8,
239+
kInt8,
234240
};
235241
// These are used in `ModelConfig.Specifier`, hence the strings will not
236242
// change, though new ones may be added.
237-
static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
238-
"nuq", "f64", "u32", "u64",
239-
"i8", "u16", "u8"};
243+
static constexpr const char* kTypeStrings[] = {
244+
"unknown", "f32", "bf16", "sfp", "nuq", "f64",
245+
"u32", "u64", "i8", "u16", "u8", "int8"};
240246
static constexpr size_t kNumTypes =
241247
sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
242248
static constexpr size_t kTypeBits[] = {
@@ -251,6 +257,7 @@ static constexpr size_t kTypeBits[] = {
251257
8 * sizeof(I8Stream),
252258
8 * sizeof(uint16_t),
253259
8 * sizeof(uint8_t),
260+
8 * sizeof(int8_t),
254261
};
255262

256263
static inline bool EnumValid(Type type) {
@@ -281,6 +288,8 @@ constexpr Type TypeEnum() {
281288
return Type::kU16;
282289
} else if constexpr (hwy::IsSame<Packed, uint8_t>()) {
283290
return Type::kU8;
291+
} else if constexpr (hwy::IsSame<Packed, int8_t>()) {
292+
return Type::kInt8;
284293
} else {
285294
return Type::kUnknown;
286295
}

gemma/flash_attention.cc

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,6 +1260,52 @@ static HWY_NOINLINE void ApplyMasking(
12601260
}
12611261
}
12621262

1263+
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
1264+
static HWY_INLINE void MultiplyByScale(DF df, const BF16* scales, VF& x0_p0,
1265+
VF& x0_p1, VF& x1_p0, VF& x1_p1,
1266+
VF& x2_p0, VF& x2_p1, VF& x3_p0,
1267+
VF& x3_p1, VF& x4_p0, VF& x4_p1,
1268+
VF& x5_p0, VF& x5_p1, VF& x6_p0,
1269+
VF& x6_p1, VF& x7_p0, VF& x7_p1) {
1270+
const size_t kTileSize = hn::Lanes(df);
1271+
const PackedSpan<const BF16> scales_span =
1272+
MakeConstSpan(scales, 2 * kTileSize);
1273+
VF scales_p0, scales_p1;
1274+
Decompress2(df, scales_span, 0, scales_p0, scales_p1);
1275+
if constexpr (kNumQueries >= 1) {
1276+
x0_p0 = hn::Mul(x0_p0, scales_p0);
1277+
x0_p1 = hn::Mul(x0_p1, scales_p1);
1278+
}
1279+
if constexpr (kNumQueries >= 2) {
1280+
x1_p0 = hn::Mul(x1_p0, scales_p0);
1281+
x1_p1 = hn::Mul(x1_p1, scales_p1);
1282+
}
1283+
if constexpr (kNumQueries >= 3) {
1284+
x2_p0 = hn::Mul(x2_p0, scales_p0);
1285+
x2_p1 = hn::Mul(x2_p1, scales_p1);
1286+
}
1287+
if constexpr (kNumQueries >= 4) {
1288+
x3_p0 = hn::Mul(x3_p0, scales_p0);
1289+
x3_p1 = hn::Mul(x3_p1, scales_p1);
1290+
}
1291+
if constexpr (kNumQueries >= 5) {
1292+
x4_p0 = hn::Mul(x4_p0, scales_p0);
1293+
x4_p1 = hn::Mul(x4_p1, scales_p1);
1294+
}
1295+
if constexpr (kNumQueries >= 6) {
1296+
x5_p0 = hn::Mul(x5_p0, scales_p0);
1297+
x5_p1 = hn::Mul(x5_p1, scales_p1);
1298+
}
1299+
if constexpr (kNumQueries >= 7) {
1300+
x6_p0 = hn::Mul(x6_p0, scales_p0);
1301+
x6_p1 = hn::Mul(x6_p1, scales_p1);
1302+
}
1303+
if constexpr (kNumQueries >= 8) {
1304+
x7_p0 = hn::Mul(x7_p0, scales_p0);
1305+
x7_p1 = hn::Mul(x7_p1, scales_p1);
1306+
}
1307+
}
1308+
12631309
// Performs tiled flash attention for arbitrary number of queries
12641310
// It depends on kv being tiled.
12651311
// Runs 2 loops one over tiles, and inner one over queries(up to 4 at a time).
@@ -1400,6 +1446,21 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
14001446
false,
14011447
"Query type type not supported, only float and BF16 are supported");
14021448
}
1449+
// microscaling
1450+
// TODO: Change to more generic function to inform if we should use
1451+
// microscaling or not.
1452+
constexpr bool kUseMicroScaling = IsInt8<KV_T>();
1453+
if constexpr (kUseMicroScaling) {
1454+
// After end of the tile, we have kTileSize * 2 bfloat16 for the
1455+
// microscaling scales for K and V.
1456+
const BF16* microscaling_scales_k =
1457+
reinterpret_cast<const BF16*>(tile_base + qkv_dim * 2 * kTileSize) +
1458+
pos_in_tile;
1459+
MultiplyByScale<kNumQueries>(df, microscaling_scales_k, x_0_p_0, x_0_p_1,
1460+
x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0,
1461+
x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1,
1462+
x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
1463+
}
14031464

14041465
constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
14051466
constexpr int kSecondHalfAmountOfQueries =
@@ -1433,6 +1494,15 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
14331494
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1,
14341495
x_7_p_0, x_7_p_1, max_logits, exp_denominator_sums, scales, q_group_idx,
14351496
kNumQueriesPerGroup);
1497+
if constexpr (kUseMicroScaling) {
1498+
const BF16* microscaling_scales_v =
1499+
reinterpret_cast<const BF16*>(tile_base + qkv_dim * 2 * kTileSize) +
1500+
kTileSize + pos_in_tile;
1501+
MultiplyByScale<kNumQueries>(df, microscaling_scales_v, x_0_p_0, x_0_p_1,
1502+
x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0,
1503+
x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1,
1504+
x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
1505+
}
14361506
if constexpr (IsF32<Q_T>()) {
14371507
MulByConstAndAddTileUpTo8<kNumQueries>(
14381508
df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,

0 commit comments

Comments
 (0)