Skip to content

Commit fed9aea

Browse files
authored
Merge branch 'main' into gh/kimishpatel/213/head
2 parents 7608f53 + 5e60898 commit fed9aea

4 files changed

Lines changed: 68 additions & 30 deletions

File tree

extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ Tensor& flash_attention_kernel_out(
273273
// we might consider another approach
274274
if (seq_len >= 768) {
275275
sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
276+
ctx,
276277
output,
277278
query,
278279
key,
@@ -289,6 +290,7 @@ Tensor& flash_attention_kernel_out(
289290
nullopt);
290291
} else if (seq_len >= 192) {
291292
sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
293+
ctx,
292294
output,
293295
query,
294296
key,
@@ -305,6 +307,7 @@ Tensor& flash_attention_kernel_out(
305307
nullopt);
306308
} else {
307309
sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
310+
ctx,
308311
output,
309312
query,
310313
key,
@@ -418,6 +421,7 @@ Tensor& custom_sdpa_out_impl(
418421
// we might consider another approach
419422
if (seq_len >= 768) {
420423
sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
424+
ctx,
421425
output,
422426
q,
423427
k,
@@ -437,6 +441,7 @@ Tensor& custom_sdpa_out_impl(
437441
num_keys_for_causal_attention);
438442
} else if (seq_len >= 192) {
439443
sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
444+
ctx,
440445
output,
441446
q,
442447
k,
@@ -456,6 +461,7 @@ Tensor& custom_sdpa_out_impl(
456461
num_keys_for_causal_attention);
457462
} else {
458463
sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
464+
ctx,
459465
output,
460466
q,
461467
k,

extension/llm/custom_ops/op_sdpa_impl.h

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,7 @@ TODO: Just handle conversion of bool mask to float
543543
*/
544544
template <typename scalar_t, int64_t q_split_size, int64_t kv_split_size>
545545
void cpu_flash_attention(
546+
RuntimeContext& ctx,
546547
Tensor& output,
547548
const Tensor& query,
548549
const Tensor& key,
@@ -763,29 +764,34 @@ void cpu_flash_attention(
763764

764765
// Since all intermediate compute is accum_t, we need to
765766
// allocate a buffer accordingly.
766-
int64_t size_of_intermediate_precision = sizeof(accum_t);
767-
int64_t size_bytes = size_per_thread * num_thread * query.element_size() *
768-
size_of_intermediate_precision;
769-
std::vector<char> buf_vec(size_bytes);
770-
void* buf = reinterpret_cast<void*>(buf_vec.data());
771-
// Need to double check the following
772-
size_bytes = num_thread * qSplitSize * kvSplitSize * query.element_size();
773-
std::vector<char> buf_reduced_vec(size_bytes);
774-
void* buf_reduced = reinterpret_cast<void*>(buf_reduced_vec.data());
775-
// at::Tensor buf_reduced = at::empty(
776-
// {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0},
777-
// query.options());
767+
int64_t size_bytes = size_per_thread * num_thread * sizeof(accum_t);
768+
std::unique_ptr<char[]> allocated_buf;
769+
void* buf;
770+
Result<void*> scratch = ctx.allocate_temp(size_bytes, 64);
771+
if (!scratch.ok()) {
772+
allocated_buf = std::make_unique<char[]>(size_bytes);
773+
buf = allocated_buf.get();
774+
} else {
775+
buf = scratch.get();
776+
}
777+
void* buf_reduced = nullptr;
778778
int64_t size_per_thread_qdq_vec = kvSplitSize * headSize;
779779
// Let's align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads,
780780
// by padding with right number of per thread elements
781-
constexpr int64_t kAlignment = 64;
782-
size_per_thread_qdq_vec =
783-
(size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1));
784781
int64_t size_per_thread_qdq_bytes = size_per_thread_qdq_vec * sizeof(accum_t);
785782
int64_t size_qdq_bytes = size_per_thread_qdq_bytes * num_thread;
786-
std::vector<char> scratch_for_quant_dequant_vec(size_qdq_bytes);
787-
accum_t* scratch_for_quant_dequant =
788-
reinterpret_cast<accum_t*>(scratch_for_quant_dequant_vec.data());
783+
std::unique_ptr<char[]> allocated_buf_for_qdq;
784+
accum_t* scratch_for_quant_dequant;
785+
Result<void*> scratch_for_quant_dequant_res =
786+
ctx.allocate_temp(size_qdq_bytes, 64);
787+
if (!scratch_for_quant_dequant_res.ok()) {
788+
allocated_buf_for_qdq = std::make_unique<char[]>(size_qdq_bytes);
789+
scratch_for_quant_dequant =
790+
reinterpret_cast<accum_t*>(allocated_buf_for_qdq.get());
791+
} else {
792+
scratch_for_quant_dequant =
793+
reinterpret_cast<accum_t*>(scratch_for_quant_dequant_res.get());
794+
}
789795

790796
// Data ptrs
791797
const scalar_t* q_data = query.const_data_ptr<scalar_t>();
@@ -819,6 +825,7 @@ void cpu_flash_attention(
819825
// Initialize max and sum
820826
fill_stub(
821827
qk_max_data, -std::numeric_limits<accum_t>::infinity(), qBlockSize);
828+
fill_stub(qk_sum_data, static_cast<accum_t>(0), qBlockSize);
822829
// Original flash sdpa wasn't really meant to be used
823830
// for decode the way we are using via start_pos here.
824831
// Thus when num_keys is 1 during decode phase, we
@@ -850,6 +857,7 @@ void cpu_flash_attention(
850857
is_causal ? std::min(m + start_pos + qBlockSize, kvSize) : kvSize;
851858
int64_t m_start_pos = m + start_pos;
852859
auto j_kv = j / num_reps;
860+
fill_stub(dst_data, static_cast<accum_t>(0), qSplitSize * headSize);
853861
for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
854862
int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
855863
// Calculate scale * q @ k.T

extension/module/module.cpp

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,17 @@ runtime::Result<std::unique_ptr<runtime::DataLoader>> make_data_loader(
7878
Module::Module(
7979
const std::string& file_path,
8080
const LoadMode load_mode,
81-
std::unique_ptr<runtime::EventTracer> event_tracer)
81+
std::unique_ptr<runtime::EventTracer> event_tracer,
82+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
83+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator)
8284
: file_path_(file_path),
8385
load_mode_(load_mode),
84-
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
85-
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
86+
memory_allocator_(
87+
memory_allocator ? std::move(memory_allocator)
88+
: std::make_unique<MallocMemoryAllocator>()),
89+
temp_allocator_(
90+
temp_allocator ? std::move(temp_allocator)
91+
: std::make_unique<MallocMemoryAllocator>()),
8692
event_tracer_(std::move(event_tracer)) {
8793
runtime::runtime_init();
8894
}
@@ -91,11 +97,17 @@ Module::Module(
9197
const std::string& file_path,
9298
const std::string& data_map_path,
9399
const LoadMode load_mode,
94-
std::unique_ptr<runtime::EventTracer> event_tracer)
100+
std::unique_ptr<runtime::EventTracer> event_tracer,
101+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
102+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator)
95103
: file_path_(file_path),
96104
load_mode_(load_mode),
97-
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
98-
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
105+
memory_allocator_(
106+
memory_allocator ? std::move(memory_allocator)
107+
: std::make_unique<MallocMemoryAllocator>()),
108+
temp_allocator_(
109+
temp_allocator ? std::move(temp_allocator)
110+
: std::make_unique<MallocMemoryAllocator>()),
99111
event_tracer_(std::move(event_tracer)) {
100112
if (!data_map_path.empty()) {
101113
data_files_.push_back(data_map_path);
@@ -107,12 +119,18 @@ Module::Module(
107119
const std::string& file_path,
108120
std::vector<std::string> data_files,
109121
const LoadMode load_mode,
110-
std::unique_ptr<runtime::EventTracer> event_tracer)
122+
std::unique_ptr<runtime::EventTracer> event_tracer,
123+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
124+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator)
111125
: file_path_(file_path),
112126
data_files_(std::move(data_files)),
113127
load_mode_(load_mode),
114-
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
115-
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
128+
memory_allocator_(
129+
memory_allocator ? std::move(memory_allocator)
130+
: std::make_unique<MallocMemoryAllocator>()),
131+
temp_allocator_(
132+
temp_allocator ? std::move(temp_allocator)
133+
: std::make_unique<MallocMemoryAllocator>()),
116134
event_tracer_(std::move(event_tracer)) {
117135
runtime::runtime_init();
118136
}

extension/module/module.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ class Module {
6363
explicit Module(
6464
const std::string& file_path,
6565
const LoadMode load_mode = LoadMode::File,
66-
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
66+
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
67+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
68+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr);
6769

6870
/**
6971
* Constructs an instance by loading a program from a file with specified
@@ -78,7 +80,9 @@ class Module {
7880
const std::string& file_path,
7981
const std::string& data_map_path,
8082
const LoadMode load_mode = LoadMode::File,
81-
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
83+
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
84+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
85+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr);
8286

8387
/**
8488
* Constructs an instance by loading a program from a file with specified
@@ -93,7 +97,9 @@ class Module {
9397
const std::string& file_path,
9498
std::vector<std::string> data_files,
9599
const LoadMode load_mode = LoadMode::File,
96-
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
100+
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
101+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
102+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr);
97103

98104
/**
99105
* Constructs an instance with the provided data loader and memory allocator.

0 commit comments

Comments
 (0)