Skip to content

Commit 618be56

Browse files
committed
issue/1113
1 parent 3f0a98c commit 618be56

19 files changed

Lines changed: 6664 additions & 0 deletions
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#ifndef __INFINIOP_AWQ_MARLIN_GEMM_API_H__
2+
#define __INFINIOP_AWQ_MARLIN_GEMM_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
#include <cstdint>
6+
7+
typedef struct InfiniopDescriptor *infiniopAwqMarlinGemmDescriptor_t;
8+
9+
__INFINI_C __export infiniStatus_t infiniopCreateAwqMarlinGemmDescriptor(infiniopHandle_t handle,
10+
infiniopAwqMarlinGemmDescriptor_t *desc_ptr,
11+
infiniopTensorDescriptor_t out_desc,
12+
infiniopTensorDescriptor_t a_desc,
13+
infiniopTensorDescriptor_t b_desc,
14+
infiniopTensorDescriptor_t b_bias_desc,
15+
infiniopTensorDescriptor_t b_scales_desc,
16+
infiniopTensorDescriptor_t a_scales_desc,
17+
infiniopTensorDescriptor_t global_scales_desc,
18+
infiniopTensorDescriptor_t b_zeros_desc,
19+
infiniopTensorDescriptor_t g_idx_desc,
20+
infiniopTensorDescriptor_t perm_desc);
21+
22+
__INFINI_C __export infiniStatus_t infiniopGetAwqMarlinGemmWorkspaceSize(infiniopAwqMarlinGemmDescriptor_t desc, size_t *size);
23+
24+
__INFINI_C __export infiniStatus_t infiniopAwqMarlinGemm(infiniopAwqMarlinGemmDescriptor_t desc,
25+
void *workspace,
26+
size_t workspace_size,
27+
void *c,
28+
const void *a,
29+
const void *b,
30+
void *b_bias,
31+
void *b_scales,
32+
void *a_scales,
33+
void *global_scales,
34+
void *b_zeros,
35+
void *g_idx,
36+
void *perm,
37+
int64_t b_q_type_id,
38+
bool is_k_full,
39+
bool use_atomic_add,
40+
bool use_fp32_reduce,
41+
bool is_zp_float,
42+
void *stream);
43+
44+
__INFINI_C __export infiniStatus_t infiniopDestroyAwqMarlinGemmDescriptor(infiniopAwqMarlinGemmDescriptor_t desc);
45+
46+
#endif
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#ifndef AWQ_MARLIN_GEMM_H
2+
#define AWQ_MARLIN_GEMM_H
3+
4+
#include "../../operator.h"
5+
#include "info.h"
6+
7+
#define DESCRIPTOR(NAMESPACE) \
8+
\
9+
namespace op::awq_marlin_gemm::NAMESPACE { \
10+
class Descriptor final : public InfiniopDescriptor { \
11+
struct Opaque; \
12+
Opaque *_opaque; \
13+
AwqMarlinGemmInfo _info; \
14+
size_t _workspace_size; \
15+
\
16+
Descriptor( \
17+
Opaque *opaque, \
18+
AwqMarlinGemmInfo info, \
19+
size_t workspace_size, \
20+
infiniDevice_t device_type, \
21+
int device_id) \
22+
: InfiniopDescriptor{device_type, device_id}, \
23+
_opaque(opaque), \
24+
_info(info), \
25+
_workspace_size(workspace_size) {} \
26+
\
27+
public: \
28+
~Descriptor(); \
29+
\
30+
size_t workspaceSize() const { return _workspace_size; } \
31+
\
32+
static infiniStatus_t create( \
33+
infiniopHandle_t handle, \
34+
Descriptor **desc_ptr, \
35+
infiniopTensorDescriptor_t out_desc, \
36+
infiniopTensorDescriptor_t a_desc, \
37+
infiniopTensorDescriptor_t b_desc, \
38+
infiniopTensorDescriptor_t b_bias_desc, \
39+
infiniopTensorDescriptor_t b_scales_desc, \
40+
infiniopTensorDescriptor_t a_scales_desc, \
41+
infiniopTensorDescriptor_t global_scales_desc, \
42+
infiniopTensorDescriptor_t b_zeros_desc, \
43+
infiniopTensorDescriptor_t g_idx_desc, \
44+
infiniopTensorDescriptor_t perm_desc); \
45+
\
46+
infiniStatus_t calculate( \
47+
void *workspace, size_t workspace_size, \
48+
void *c, \
49+
const void *a, const void *b, \
50+
void *b_bias, void *b_scales, void *a_scales, void *global_scales, \
51+
void *b_zeros, void *g_idx, void *perm, \
52+
int64_t b_q_type_id, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float, \
53+
void *stream) const; \
54+
}; \
55+
}
56+
57+
#endif // AWQ_MARLIN_GEMM_H
Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
#pragma once
2+
3+
#include <cstdint>
4+
#include <tuple>
5+
#include <variant>
6+
#include <string>
7+
#include <type_traits>
8+
#include <stdexcept>
9+
#include <limits>
10+
11+
namespace vllm
12+
{
13+
14+
class ScalarType
15+
{
16+
public:
17+
enum NanRepr : uint8_t
18+
{
19+
NAN_NONE = 0,
20+
NAN_IEEE_754 = 1,
21+
NAN_EXTD_RANGE_MAX_MIN = 2,
22+
23+
NAN_REPR_ID_MAX
24+
};
25+
26+
constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_,
27+
int32_t bias, bool finite_values_only = false,
28+
NanRepr nan_repr = NAN_IEEE_754)
29+
: exponent(exponent),
30+
mantissa(mantissa),
31+
signed_(signed_),
32+
bias(bias),
33+
finite_values_only(finite_values_only),
34+
nan_repr(nan_repr) {}
35+
36+
// -----------------------
37+
// Integer
38+
// -----------------------
39+
static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0)
40+
{
41+
return ScalarType(0, size_bits - 1, true, bias);
42+
}
43+
44+
static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0)
45+
{
46+
return ScalarType(0, size_bits, false, bias);
47+
}
48+
49+
// -----------------------
50+
// Floating point(constexpr安全:不做检查)
51+
// -----------------------
52+
static constexpr ScalarType float_IEEE754(uint8_t exponent,
53+
uint8_t mantissa)
54+
{
55+
return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
56+
}
57+
58+
static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa,
59+
bool finite_values_only,
60+
NanRepr nan_repr)
61+
{
62+
return ScalarType(exponent, mantissa, true, 0,
63+
finite_values_only, nan_repr);
64+
}
65+
66+
// -----------------------
67+
// Runtime checked(可选)
68+
// -----------------------
69+
static inline ScalarType float_checked(uint8_t exponent,
70+
uint8_t mantissa,
71+
bool finite_values_only,
72+
NanRepr nan_repr)
73+
{
74+
if (!(nan_repr < NAN_REPR_ID_MAX))
75+
throw std::runtime_error("Invalid NanRepr");
76+
77+
if (!(mantissa > 0 && exponent > 0))
78+
throw std::runtime_error("mantissa/exponent must > 0");
79+
80+
if (nan_repr == NAN_IEEE_754)
81+
throw std::runtime_error("use float_IEEE754");
82+
83+
return float_(exponent, mantissa, finite_values_only, nan_repr);
84+
}
85+
86+
uint8_t const exponent;
87+
uint8_t const mantissa;
88+
bool const signed_;
89+
int32_t const bias;
90+
91+
bool const finite_values_only;
92+
NanRepr const nan_repr;
93+
94+
using Id = int64_t;
95+
96+
private:
97+
template <typename T_>
98+
static constexpr size_t member_id_field_width()
99+
{
100+
using T = std::decay_t<T_>;
101+
return std::is_same<T, bool>::value ? 1 : sizeof(T) * 8;
102+
}
103+
104+
template <typename Fn, typename Init, typename Member, typename... Rest>
105+
static constexpr auto reduce_members_helper(Fn f, Init val, Member member,
106+
Rest... rest)
107+
{
108+
auto new_val = f(val, member);
109+
if constexpr (sizeof...(rest) > 0)
110+
{
111+
return reduce_members_helper(f, new_val, rest...);
112+
}
113+
else
114+
{
115+
return new_val;
116+
}
117+
}
118+
119+
template <typename Fn, typename Init>
120+
constexpr auto reduce_members(Fn f, Init init) const
121+
{
122+
return reduce_members_helper(f, init, exponent, mantissa, signed_, bias,
123+
finite_values_only, nan_repr);
124+
}
125+
126+
template <typename Fn, typename Init>
127+
static constexpr auto reduce_member_types(Fn f, Init init)
128+
{
129+
constexpr auto dummy = ScalarType(0, 0, false, 0, false, NAN_NONE);
130+
return dummy.reduce_members(f, init);
131+
}
132+
133+
static constexpr auto id_size_bits()
134+
{
135+
return reduce_member_types(
136+
[](int acc, auto member) -> int
137+
{
138+
return acc + member_id_field_width<decltype(member)>();
139+
},
140+
0);
141+
}
142+
143+
public:
144+
constexpr Id id() const
145+
{
146+
static_assert(id_size_bits() <= sizeof(Id) * 8,
147+
"ScalarType id too large");
148+
149+
auto fn = [](std::pair<Id, uint32_t> result, auto member)
150+
{
151+
auto [id, offset] = result;
152+
constexpr auto bits = member_id_field_width<decltype(member)>();
153+
return std::pair<Id, uint32_t>{
154+
id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) << offset,
155+
offset + bits};
156+
};
157+
158+
return reduce_members(fn, std::pair<Id, uint32_t>{}).first;
159+
}
160+
161+
static constexpr ScalarType from_id(Id id)
162+
{
163+
auto fn = [id](auto result, auto member)
164+
{
165+
using T = decltype(member);
166+
auto [tuple, offset] = result;
167+
constexpr auto bits = member_id_field_width<T>();
168+
auto val = static_cast<T>((id >> offset) & ((uint64_t(1) << bits) - 1));
169+
return std::pair{std::tuple_cat(tuple, std::make_tuple(val)), offset + bits};
170+
};
171+
172+
auto [args, _] =
173+
reduce_member_types(fn, std::pair<std::tuple<>, int>{});
174+
175+
return std::apply([](auto... xs)
176+
{ return ScalarType(xs...); }, args);
177+
}
178+
179+
constexpr int64_t size_bits() const
180+
{
181+
return mantissa + exponent + (signed_ ? 1 : 0);
182+
}
183+
184+
constexpr bool is_signed() const { return signed_; }
185+
constexpr bool is_integer() const { return exponent == 0; }
186+
constexpr bool is_floating_point() const { return exponent > 0; }
187+
188+
constexpr bool is_ieee_754() const
189+
{
190+
return is_floating_point() && !finite_values_only &&
191+
nan_repr == NAN_IEEE_754;
192+
}
193+
194+
constexpr bool has_nans() const
195+
{
196+
return is_floating_point() && nan_repr != NAN_NONE;
197+
}
198+
199+
constexpr bool has_infs() const
200+
{
201+
return is_floating_point() && !finite_values_only;
202+
}
203+
204+
constexpr bool has_bias() const { return bias != 0; }
205+
206+
std::string str() const
207+
{
208+
if (is_floating_point())
209+
{
210+
auto ret = "float" + std::to_string(size_bits()) + "_e" +
211+
std::to_string(exponent) + "m" + std::to_string(mantissa);
212+
213+
if (!is_ieee_754())
214+
{
215+
if (finite_values_only)
216+
ret += "f";
217+
if (nan_repr != NAN_NONE)
218+
ret += "n";
219+
}
220+
return ret;
221+
}
222+
else
223+
{
224+
auto ret = (signed_ ? "int" : "uint") +
225+
std::to_string(size_bits());
226+
if (has_bias())
227+
ret += "b" + std::to_string(bias);
228+
return ret;
229+
}
230+
}
231+
232+
constexpr bool operator==(ScalarType const &other) const
233+
{
234+
return mantissa == other.mantissa &&
235+
exponent == other.exponent &&
236+
bias == other.bias &&
237+
signed_ == other.signed_ &&
238+
finite_values_only == other.finite_values_only &&
239+
nan_repr == other.nan_repr;
240+
}
241+
};
242+
243+
using ScalarTypeId = ScalarType::Id;
244+
245+
// -----------------------
246+
// 原始常量(完全保留)
247+
// -----------------------
248+
249+
static inline constexpr auto kS4 = ScalarType::int_(4);
250+
static inline constexpr auto kU4 = ScalarType::uint(4);
251+
static inline constexpr auto kU4B8 = ScalarType::uint(4, 8);
252+
static inline constexpr auto kS8 = ScalarType::int_(8);
253+
static inline constexpr auto kU8 = ScalarType::uint(8);
254+
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
255+
256+
static inline constexpr auto kFE2M1f =
257+
ScalarType::float_(2, 1, true, ScalarType::NAN_NONE);
258+
static inline constexpr auto kFE3M2f =
259+
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
260+
static inline constexpr auto kFE4M3fn =
261+
ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
262+
static inline constexpr auto kFE8M0fnu =
263+
ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
264+
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
265+
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
266+
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);
267+
268+
// 🔥 关键:alias(不能丢!)
269+
270+
static inline constexpr auto kInt4 = kS4;
271+
static inline constexpr auto kUint4 = kU4;
272+
static inline constexpr auto kUint4b8 = kU4B8;
273+
static inline constexpr auto kInt8 = kS8;
274+
static inline constexpr auto kUint8 = kU8;
275+
static inline constexpr auto kUint8b128 = kU8B128;
276+
277+
static inline constexpr auto kFloat4_e2m1f = kFE2M1f;
278+
static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
279+
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
280+
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
281+
static inline constexpr auto kFloat16_e8m7 = kFE8M7;
282+
static inline constexpr auto kFloat16_e5m10 = kFE5M10;
283+
284+
// ⭐ 这些就是你报错缺失的
285+
static inline constexpr auto kHalf = kFE5M10;
286+
static inline constexpr auto kFloat16 = kHalf;
287+
static inline constexpr auto kBFloat16 = kFE8M7;
288+
289+
static inline constexpr auto kFloat16Id = kFloat16.id();
290+
291+
} // namespace vllm

0 commit comments

Comments
 (0)