-
Notifications
You must be signed in to change notification settings - Fork 64
Expand file tree
/
Copy pathstatic_batching_compiler.hpp
More file actions
36 lines (28 loc) · 971 Bytes
/
static_batching_compiler.hpp
File metadata and controls
36 lines (28 loc) · 971 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#pragma once
#include "graph_compiler.hpp"
#include <cstddef>
#include <functional>
#include <memory>
#include <tuple>
#include <unordered_map>
namespace infinilm::engine {

/// @brief GraphCompiler that keeps one compiled result per static
///        (batch_size, seq_len) shape.
///
/// Compiled results are stored in a hash map keyed by the shape pair; how the
/// map is populated and queried lives in the implementation file
/// (NOTE(review): presumably compile() fills it and get_compiled() looks it up
/// by the shape of the incoming input — confirm against the .cpp).
class StaticBatchingCompiler : public GraphCompiler {
public:
    /// @param model   Shared handle to the model this compiler works on.
    /// @param barrier Non-owning pointer to a RankBarrier; its role is defined
    ///                in the implementation (NOTE(review): presumably used to
    ///                synchronize ranks during compilation — confirm).
    StaticBatchingCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier);

    /// Overrides GraphCompiler::compile().
    void compile() override;

    /// Overrides GraphCompiler::get_compiled(); returns the Compiled entry
    /// associated with @p input (selection logic is in the implementation).
    Compiled get_compiled(const InfinilmModel::Input &input) override;

private:
    /// Shape key of a compiled entry: (batch_size, seq_len).
    using ShapeKey = std::tuple<size_t, size_t>;

    /// Hash functor for ShapeKey; std::hash has no specialization for std::tuple.
    struct TupleHash {
        size_t operator()(const std::tuple<size_t, size_t> &key) const noexcept {
            const size_t batch_hash = std::hash<size_t>{}(std::get<0>(key));
            const size_t seq_hash = std::hash<size_t>{}(std::get<1>(key));
            // boost::hash_combine-style mixing: 64-bit golden-ratio constant
            // plus shifted copies of the first hash, folded in with XOR.
            const auto mixed = seq_hash + 0x9e3779b97f4a7c15ULL + (batch_hash << 6) + (batch_hash >> 2);
            return batch_hash ^ mixed;
        }
    };

    /// One stored entry per shape.
    struct CompiledResult {
        // NOTE(review): presumably the input captured when this shape was
        // compiled — confirm in the implementation.
        InfinilmModel::Input input;
        // The compiled artifact handed back to callers of get_compiled().
        Compiled compiled;
    };

    /// Compiled results keyed by (batch_size, seq_len).
    std::unordered_map<ShapeKey, CompiledResult, TupleHash> compiled_map_;
};

} // namespace infinilm::engine