Skip to content

Commit 772a567

Browse files
akopichxnnpack-bot
authored andcommitted
Avoid dynamic allocation in WASM assembler for storing function types
PiperOrigin-RevId: 545204376
1 parent fab42b3 commit 772a567

2 files changed

Lines changed: 37 additions & 14 deletions

File tree

src/xnnpack/array-helpers.h

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,19 @@ void ArrayApplyImpl(std::array<T, N>&& args, F&& f,
1414
}
1515

1616
template <typename T, size_t N, typename F,
17-
typename Indx = std::make_index_sequence<N> >
17+
typename Indx = std::make_index_sequence<N>>
1818
void ArrayApply(std::array<T, N>&& args, F&& f) {
1919
return ArrayApplyImpl(std::move(args), f, Indx{});
2020
}
2121

2222
template <size_t... Is, typename V>
23-
std::array<V, sizeof...(Is)> MakeArrayImpl(
23+
constexpr std::array<V, sizeof...(Is)> MakeArrayImpl(
2424
V value, std::integer_sequence<size_t, Is...>) {
2525
return {((void)Is, value)...};
2626
}
2727

2828
template <size_t N, typename V>
29-
std::array<V, N> MakeArray(V value) {
29+
constexpr std::array<V, N> MakeArray(V value) {
3030
return MakeArrayImpl(value, std::make_index_sequence<N>{});
3131
}
3232

@@ -36,21 +36,24 @@ static constexpr T kDefault{};
3636
template <typename T, size_t max_size>
3737
class ArrayPrefix {
3838
public:
39-
explicit ArrayPrefix(size_t size) : size_(size) { assert(size_ <= max_size); }
39+
constexpr ArrayPrefix(size_t size, T t)
40+
: size_(size), array_(MakeArray<max_size>(t)) {
41+
assert(size_ <= max_size);
42+
}
4043

41-
ArrayPrefix(size_t size, T t) : size_(size), array_(MakeArray<max_size>(t)) {
44+
explicit constexpr ArrayPrefix(size_t size) : size_(size) {
4245
assert(size_ <= max_size);
4346
}
4447

45-
template <typename Array>
46-
explicit ArrayPrefix(Array&& array) : ArrayPrefix({}) {
48+
template <typename Array,
49+
typename = std::enable_if_t<!std::is_integral_v<Array>>>
50+
explicit constexpr ArrayPrefix(Array&& array) : ArrayPrefix({}) {
4751
for (const auto& v : array) {
48-
array_[size_] = v;
49-
size_++;
52+
push_back(v);
5053
}
5154
}
5255

53-
ArrayPrefix(std::initializer_list<T> init)
56+
constexpr ArrayPrefix(std::initializer_list<T> init)
5457
: ArrayPrefix(init.size(), kDefault<T>) {
5558
assert(size_ <= max_size);
5659
std::copy(init.begin(), init.end(), begin());
@@ -76,6 +79,10 @@ class ArrayPrefix {
7679
assert(index < size_);
7780
return array_[index];
7881
}
82+
void push_back(const T& t) {
83+
assert(size_ + 1 < max_size);
84+
array_[size_++] = t;
85+
}
7986
size_t size() const { return size_; }
8087

8188
private:

src/xnnpack/wasm-assembler.h

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ static uint32_t VectorEncodingLength(
7979
}
8080

8181
struct ResultType {
82-
ResultType(std::initializer_list<ValType> codes) : type(kNoTypeCode) {
82+
constexpr ResultType(std::initializer_list<ValType> codes)
83+
: type(kNoTypeCode) {
8384
switch (codes.size()) {
8485
case 0:
8586
break;
@@ -99,20 +100,32 @@ struct ResultType {
99100
static constexpr byte kNoTypeCode = 0;
100101
};
101102

103+
template <>
104+
static constexpr ResultType kDefault<ResultType>{};
105+
102106
inline bool operator==(const ResultType& lhs, const ResultType& rhs) {
103107
return lhs.type == rhs.type;
104108
}
105109

106110
static constexpr size_t kMaxParamsCount = 16;
107-
using Params = ArrayPrefix<ValType, kMaxParamsCount>;
111+
struct Params : ArrayPrefix<ValType, kMaxParamsCount> {
112+
using ArrayPrefix::ArrayPrefix;
113+
};
114+
115+
template <>
116+
static constexpr Params kDefault<Params>{0, kDefault<ValType>};
108117

109118
struct FuncType {
110-
FuncType(const Params& params, ResultType result)
119+
constexpr FuncType(const Params& params, ResultType result)
111120
: params(params), result(result) {}
112121
Params params;
113122
ResultType result;
114123
};
115124

125+
template <>
126+
static constexpr FuncType kDefault<FuncType>{kDefault<Params>,
127+
kDefault<ResultType>};
128+
116129
inline bool operator==(const FuncType& lhs, const FuncType& rhs) {
117130
return lhs.result == rhs.result &&
118131
std::equal(lhs.params.begin(), lhs.params.end(), rhs.params.begin(),
@@ -786,9 +799,12 @@ class WasmAssembler : public AssemblerBase, public internal::WasmOps {
786799
internal::StoreEncodedU32(n, [this](byte b) { emit8(b); });
787800
}
788801

802+
static constexpr size_t kMaxNumFuncTypes = 16;
803+
789804
std::vector<Function> functions_;
790-
std::vector<FuncType> func_types_;
791805
std::vector<Export> exports_;
806+
internal::ArrayPrefix<FuncType, kMaxNumFuncTypes> func_types_{
807+
0, internal::kDefault<FuncType>};
792808
};
793809

794810
} // namespace xnnpack

0 commit comments

Comments
 (0)