text_decoder_runner.cpp
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
// Given inputs, run a text decoder and return logits.
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/kernels/portable/cpu/util/arange_util.h>
#include <ctime>
#include <executorch/extension/llm/runner/stats.h>
namespace executorch {
namespace extension {
namespace llm {
// NOTE: we observed ~2x loading performance increase on iPhone 15
// and a ~5% improvement on Galaxy S22 by switching to
// FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
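
// The runner keeps non-owning pointers: the caller is expected to keep the
// Module and IOManager alive for the lifetime of the TextDecoderRunner.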
TextDecoderRunner::TextDecoderRunner(Module* module, IOManager* io_manager)
    : module_(module), io_manager_(io_manager) {}
// This function is functional, meaning it shouldn't modify any state of the
// input. It should be safe to call multiple times with the same inputs. The
// outer loop (call site) is responsible for managing state.
::executorch::runtime::Result<executorch::aten::Tensor> TextDecoderRunner::step(
    TensorPtr& tokens,
    int64_t start_pos) {
  // ET_LOG(Info, "Input token %" PRIu64, input_token);
  auto method_meta_result = module_->method_meta("forward");
  if (!method_meta_result.ok()) {
    return method_meta_result.error();
  }
  auto method_meta = std::move(*method_meta_result);
  // If only 1 input, we are not using kv cache
  bool use_kv_cache = method_meta.num_inputs() > 1;
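
  // Two execution paths follow: with a KV cache, the exported "forward" method
  // takes the new tokens plus their cache positions (prepared by the
  // IOManager); without one, the full token sequence is passed in directly.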
  std::vector<int64_t> cache_positions;
  if (use_kv_cache) {
    auto start_pos_tensor_result = populate_start_pos_or_cache_position(
        module_, start_pos, cache_positions, tokens->numel(), "forward");
    if (!start_pos_tensor_result.ok()) {
      return start_pos_tensor_result.error();
    }
    auto start_pos_tensor = std::move(*start_pos_tensor_result);

    std::vector<runtime::EValue> inputs;
    auto inputs_res = io_manager_->prepare_decode(tokens, start_pos_tensor);
    ET_CHECK_OK_OR_RETURN_ERROR(inputs_res.error());
    inputs = inputs_res.get();
    auto outputs_res = module_->forward(inputs);
    ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());

    auto update_err = io_manager_->update_decode(outputs_res.get());
    ET_CHECK_OK_OR_RETURN_ERROR(update_err);

    ET_CHECK_MSG(
        outputs_res.get().size() == 1,
        "More than one output returned from executing LLM.");
    ET_CHECK_MSG(
        outputs_res.get()[0].isTensor(),
        "Non-Tensor output returned from executing LLM.");

    // Return the logits tensor
    return outputs_res.get()[0].toTensor();
  } else { // no kv cache
    (void)start_pos; // unused

    auto outputs_res = module_->forward(tokens);
    ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());

    ET_CHECK_MSG(
        outputs_res.get().size() == 1,
        "More than one output returned from executing LLM.");
    ET_CHECK_MSG(
        outputs_res.get()[0].isTensor(),
        "Non-Tensor output returned from executing LLM.");

    // Return the logits tensor
    return outputs_res.get()[0].toTensor();
  }
}
} // namespace llm
} // namespace extension
} // namespace executorch
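
// Usage sketch (illustrative only; the helper and variable names below are
// assumptions, not part of this file): a caller owns the Module and IOManager,
// wraps the current token id(s) in a TensorPtr, and calls step() once per
// decode position, e.g.
//
//   executorch::extension::llm::TextDecoderRunner runner(&module, &io_manager);
//   auto tokens = executorch::extension::make_tensor_ptr<int64_t>({next_token});
//   auto logits = runner.step(tokens, pos);
//   if (logits.ok()) { /* sample the next token from *logits */ }
//
// The surrounding generation loop advances pos and feeds the sampled token
// back in, consistent with the "functional" contract documented on step().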