-
Notifications
You must be signed in to change notification settings - Fork 1k
Expand file tree
/
Copy pathop_update_cache.cpp
More file actions
304 lines (263 loc) · 10.3 KB
/
Copy pathop_update_cache.cpp
File metadata and controls
304 lines (263 loc) · 10.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/extension/llm/custom_ops/op_update_cache.h>

#include <algorithm>
#include <cinttypes>
#include <cstdint>
#include <cstring>

#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
// @lint-ignore CLANGTIDY facebook-unused-include-check
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
namespace torch {
namespace executor {
namespace native {
namespace {
// Helper function to validate cache parameters
bool validate_cache_params(
const Tensor& quantized_value,
const Tensor& quantized_cache,
int64_t start_pos,
int64_t seq_length,
const optional<Tensor>& indices = nullopt) {
ET_CHECK_OR_RETURN_FALSE(
quantized_cache.dim() == 4, "quantized cache must be a 4D tensor");
ET_CHECK_OR_RETURN_FALSE(
quantized_value.dim() == 4, "quantized_value must be a 4D tensor");
if (indices.has_value()) {
const auto& indices_tensor = indices.value();
ET_CHECK_OR_RETURN_FALSE(
indices_tensor.dim() == 2,
"indices must be a 2D tensor [batch_size, seq_len]");
ET_CHECK_OR_RETURN_FALSE(
indices_tensor.size(0) == quantized_value.size(0),
"indices batch dimension must match value batch dimension");
ET_CHECK_OR_RETURN_FALSE(
indices_tensor.size(1) == quantized_value.size(1),
"indices sequence length dimension must match value sequence length dimension");
ET_CHECK_OR_RETURN_FALSE(
indices_tensor.scalar_type() == ScalarType::Long,
"indices must be of Long (int64_t) type");
ET_CHECK_OR_RETURN_FALSE(
is_contiguous_dim_order(
indices_tensor.dim_order().data(), indices_tensor.dim()),
"indices must be in contiguous dim order");
} else {
ET_CHECK_OR_RETURN_FALSE(
start_pos < quantized_cache.size(1),
"start_pos: %" PRId64 " must be less than cache size at dim 1: %zd",
start_pos,
quantized_cache.size(1));
ET_CHECK_OR_RETURN_FALSE(
(start_pos + seq_length) <= quantized_cache.size(1),
"start_post + seq_length must be less than max seq length supported by cache."
"start pos: %" PRId64 ", seq_length: %" PRId64
"."
"cache size: %zd",
start_pos,
seq_length,
quantized_cache.size(1));
}
// Make sure they are in contiguous dim order
ET_CHECK_OR_RETURN_FALSE(
is_contiguous_dim_order(
quantized_cache.dim_order().data(), quantized_cache.dim()),
"quantized cache must be in contiguous dim order");
ET_CHECK_OR_RETURN_FALSE(
is_contiguous_dim_order(
quantized_value.dim_order().data(), quantized_value.dim()),
"quantized value must be in contiguous dim order");
return true;
}
// Copies `value` into `cache` along the sequence dimension (dim 1).
//
// Without `indices`, the value.size(1) tokens of each batch row are written as
// one contiguous run starting at cache position `start_pos`. With `indices`
// (int64, indexed as [batch, seq]), token `s` of batch row `b` is written to
// cache position indices[b][s], and `start_pos` is not used.
//
// Mismatched shapes, element sizes or dim orders, null data, and out-of-range
// indices abort via ET_CHECK_MSG. The caller is expected to have validated
// the non-indices start_pos range already. `output` is never written; it is
// returned only to satisfy the out-variant calling convention.
Tensor& update_cache_impl(
    RuntimeContext& ctx,
    const Tensor& value,
    Tensor& cache,
    const int64_t start_pos,
    Tensor& output,
    const optional<Tensor>& indices = nullopt) {
  (void)ctx;
  ET_CHECK_MSG(
      value.size(0) == cache.size(0),
      "projected_value batch size (%zd) should be equal to the cache batch size (%zd).",
      value.size(0),
      cache.size(0));
  ET_CHECK_MSG(
      value.size(2) == cache.size(2),
      "projected_value number of heads (%zd) should be equal to the cache number of heads (%zd).",
      value.size(2),
      cache.size(2));
  ET_CHECK_MSG(
      value.size(3) == cache.size(3),
      "projected_value embedding dimension (%zd) should be equal to the cache embedding dimension (%zd).",
      value.size(3),
      cache.size(3));
  ET_CHECK_MSG(
      value.element_size() == cache.element_size(),
      "projected_value data type size (%zd) should be equal to the cache data type size (%zd).",
      value.element_size(),
      cache.element_size());
  ET_CHECK_MSG(
      is_contiguous_dim_order(value.dim_order().data(), value.dim()),
      "projected value must be in contiguous dim order");
  ET_CHECK_MSG(
      is_contiguous_dim_order(cache.dim_order().data(), cache.dim()),
      "cache must be in contiguous dim order");

  // An empty update copies nothing. Returning early also avoids dividing by a
  // zero batch/sequence size, and avoids rejecting the null data pointer an
  // empty tensor may carry.
  if (value.numel() == 0) {
    return output;
  }

  const auto* value_bytes =
      static_cast<const uint8_t*>(value.const_data_ptr());
  auto* cache_bytes = static_cast<uint8_t*>(cache.mutable_data_ptr());
  ET_CHECK_MSG(value_bytes, "projected_value data is null");
  ET_CHECK_MSG(cache_bytes, "cache data is null");

  // All byte offsets use 64-bit arithmetic. SizesType may be 32-bit, and a
  // large cache can exceed 2^31 bytes, so storing offsets in it would
  // truncate them. The element sizes of value and cache were checked equal.
  const int64_t elem_size = value.element_size();
  const int64_t cache_batch_stride = cache.strides()[0];
  const int64_t cache_seq_stride = cache.strides()[1];
  const int64_t value_batch_stride = value.strides()[0];
  const int64_t value_seq_stride = value.strides()[1];
  const int64_t batch_size = value.size(0);
  const int64_t seq_len = value.size(1);

  // Both tensors are contiguous and agree on dims 2 and 3, so one token's
  // (dim 2 * dim 3) elements form a contiguous chunk in each of them.
  const size_t bytes_per_token =
      static_cast<size_t>(value.size(2) * value.size(3) * elem_size);

  if (indices.has_value()) {
    const Tensor& indices_tensor = indices.value();
    const int64_t* indices_data =
        static_cast<const int64_t*>(indices_tensor.const_data_ptr());
    const int64_t indices_batch_stride = indices_tensor.strides()[0];
    const int64_t indices_seq_stride = indices_tensor.strides()[1];

    for (int64_t b = 0; b < batch_size; ++b) {
      for (int64_t s = 0; s < seq_len; ++s) {
        const int64_t target_pos =
            indices_data[b * indices_batch_stride + s * indices_seq_stride];
        ET_CHECK_MSG(
            target_pos >= 0 && target_pos < cache.size(1),
            "Index out of bounds: %" PRId64 " not in [0, %zd)",
            target_pos,
            cache.size(1));
        const size_t cache_offset = static_cast<size_t>(
            (b * cache_batch_stride + target_pos * cache_seq_stride) *
            elem_size);
        const size_t value_offset = static_cast<size_t>(
            (b * value_batch_stride + s * value_seq_stride) * elem_size);
        std::memcpy(
            cache_bytes + cache_offset,
            value_bytes + value_offset,
            bytes_per_token);
      }
    }
  } else {
    // The seq_len tokens of one batch row are contiguous in both tensors, so
    // each row is copied with a single memcpy.
    const size_t bytes_per_batch_row =
        bytes_per_token * static_cast<size_t>(seq_len);
    for (int64_t b = 0; b < batch_size; ++b) {
      const size_t cache_offset = static_cast<size_t>(
          (b * cache_batch_stride + start_pos * cache_seq_stride) * elem_size);
      const size_t value_offset =
          static_cast<size_t>(b * value_batch_stride * elem_size);
      std::memcpy(
          cache_bytes + cache_offset,
          value_bytes + value_offset,
          bytes_per_batch_row);
    }
  }

  // Nothing reads `output`; it is only a placeholder for the out-variant.
  return output;
}
} // anonymous namespace
// Grows the cache's sequence dimension (dim 1) when needed, for the
// DYNAMIC_UNBOUND lazily allocated KV cache.
//
// A resize to max(start_pos + value.size(1), cache.size(1)) along dim 1 is
// attempted when either:
//   * start_pos + value.size(1) exceeds cache.size(1), or
//   * the cache has no data yet (a lazy DYNAMIC_UNBOUND cache starts with
//     capacity_bytes=0 and null data).
// Dims 0, 2 and 3 are kept as they are.
//
// Returns false (after logging) if either tensor is not 4D or resize_tensor
// fails, e.g. presumably for a cache whose shape cannot grow. Otherwise
// returns true.
static bool maybe_resize_cache(
    RuntimeContext& ctx,
    const Tensor& value,
    Tensor& cache,
    int64_t start_pos) {
  (void)ctx;
  ET_CHECK_OR_RETURN_FALSE(cache.dim() == 4, "cache must be a 4D tensor");
  ET_CHECK_OR_RETURN_FALSE(value.dim() == 4, "value must be a 4D tensor");

  const int64_t seq_len = value.size(1);
  const int64_t required_seq = start_pos + seq_len;
  const bool too_small = required_seq > cache.size(1);
  const bool unallocated = cache.const_data_ptr() == nullptr;
  if (!too_small && !unallocated) {
    return true;
  }

  // Explicit template argument: size() returns ssize_t, which is not the same
  // type as int64_t on every platform.
  const int64_t new_seq = std::max<int64_t>(required_seq, cache.size(1));
  executorch::aten::SizesType new_sizes[] = {
      static_cast<executorch::aten::SizesType>(cache.size(0)),
      static_cast<executorch::aten::SizesType>(new_seq),
      static_cast<executorch::aten::SizesType>(cache.size(2)),
      static_cast<executorch::aten::SizesType>(cache.size(3)),
  };
  const Error err =
      resize_tensor(cache, {new_sizes, static_cast<size_t>(cache.dim())});
  // Log the failure so the caller's generic InvalidArgument is diagnosable.
  ET_CHECK_OR_RETURN_FALSE(
      err == Error::Ok,
      "failed to resize cache dim 1 to %" PRId64 " (current size: %zd)",
      new_seq,
      cache.size(1));
  return true;
}
// Kernel for llama::update_cache.out (no indices).
//
// Writes `value` into `cache` at sequence positions
// [start_pos, start_pos + value.size(1)) for every batch row. It first grows
// the cache when needed (maybe_resize_cache), then validates the arguments.
// Any failure in those steps is reported through `ctx` as InvalidArgument
// and `output` is returned. `output` itself is not written.
Tensor& update_cache_out(
    RuntimeContext& ctx,
    const Tensor& value,
    Tensor& cache,
    const int64_t start_pos,
    Tensor& output) {
  // maybe_resize_cache checks that both tensors are 4D. Read value.size(1)
  // only after that check, so a lower-rank input fails with InvalidArgument
  // instead of aborting inside size().
  ET_KERNEL_CHECK(
      ctx,
      maybe_resize_cache(ctx, value, cache, start_pos),
      InvalidArgument,
      output);
  const int64_t seq_len = value.size(1);
  ET_KERNEL_CHECK(
      ctx,
      validate_cache_params(value, cache, start_pos, seq_len),
      InvalidArgument,
      output);
  return update_cache_impl(ctx, value, cache, start_pos, output);
}
// Kernel for llama::update_cache_with_indices.out.
//
// Writes token `s` of batch row `b` of `value` to cache position
// indices[b][s]. `start_pos` is not used to place the data, but it is still
// passed to maybe_resize_cache and so affects whether the cache is grown.
// NOTE(review): growth is sized from start_pos + seq_len, not from the
// largest index; confirm this is intended for lazily allocated caches.
// Any failure in resizing or validation is reported through `ctx` as
// InvalidArgument and `output` is returned. `output` itself is not written.
Tensor& update_cache_with_indices_out(
    RuntimeContext& ctx,
    const Tensor& value,
    Tensor& cache,
    const int64_t start_pos,
    const Tensor& indices,
    Tensor& output) {
  // maybe_resize_cache checks that both tensors are 4D. Read value.size(1)
  // only after that check, so a lower-rank input fails with InvalidArgument
  // instead of aborting inside size().
  ET_KERNEL_CHECK(
      ctx,
      maybe_resize_cache(ctx, value, cache, start_pos),
      InvalidArgument,
      output);
  const int64_t seq_len = value.size(1);
  ET_KERNEL_CHECK(
      ctx,
      validate_cache_params(value, cache, start_pos, seq_len, indices),
      InvalidArgument,
      output);
  return update_cache_impl(ctx, value, cache, start_pos, output, indices);
}
} // namespace native
} // namespace executor
} // namespace torch
// Registers the update-cache kernels under the `llama` op namespace.
//
// Despite the out-variant signature, update_cache is effectively an in-place
// tensor update: it writes into `cache` and returns `output` untouched.
// It assumes 4D tensors in contiguous dim order and indexes along the
// sequence dimension (dim 1).
// NOTE(review): an earlier comment said this op would be renamed to
// update_cache "in later diffs"; confirm whether that is still pending.
EXECUTORCH_LIBRARY(
llama,
"update_cache.out",
torch::executor::native::update_cache_out);
// Register update_cache_with_indices.out. It places each token at the cache
// position given by the matching entry of an int64 `indices` tensor, rather
// than in one contiguous run starting at start_pos.
EXECUTORCH_LIBRARY(
llama,
"update_cache_with_indices.out",
torch::executor::native::update_cache_with_indices_out);