forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtext_prefiller.h
More file actions
92 lines (80 loc) · 2.77 KB
/
text_prefiller.h
File metadata and controls
92 lines (80 loc) · 2.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
// Given the tokens of a text prompt (already encoded by a tokenizer), prefill
// the KV cache of an LLM.
#pragma once
#include <executorch/extension/llm/runner/text_decoder_runner.h>
namespace executorch {
namespace extension {
namespace llm {
/**
 * Feeds already-tokenized prompt text to an LLM during the prefill phase.
 *
 * The prefiller borrows a TextDecoderRunner: load() and is_loaded() forward
 * straight to it, and the prefill entry points are defined out of line.
 */
class ET_EXPERIMENTAL TextPrefiller {
 public:
  /**
   * Build a prefiller on top of an existing decoder runner.
   * @param text_decoder_runner Borrowed (non-owning) decoder runner. It must
   * stay alive for as long as this object is used; load() and is_loaded()
   * dereference it without a null check.
   * @param use_kv_cache Whether a KV cache is in use. NOTE(review): the exact
   * effect is decided by the out-of-line implementation - confirm there.
   * @param enable_parallel_prefill Whether parallel prefill is enabled.
   * NOTE(review): the exact effect is decided by the out-of-line
   * implementation - confirm there.
   * @param max_seq_len Maximum sequence length; 128 when omitted.
   */
  TextPrefiller(
      TextDecoderRunner* text_decoder_runner,
      bool use_kv_cache,
      bool enable_parallel_prefill,
      int64_t max_seq_len = 128);

  virtual ~TextPrefiller() = default;

  /**
   * Run prefill over a tokenized text prompt.
   * @param prompt_tokens Prompt tokens for the LLM Module, as produced by the
   * tokenizer.
   * @param start_pos Position in the KV cache at which this input starts.
   * Taken by non-const reference. NOTE(review): presumably updated by the
   * implementation - confirm against the definition.
   * @return The next token emitted by the LLM Module once prefill is done.
   */
  virtual ::executorch::runtime::Result<uint64_t> prefill(
      std::vector<uint64_t>& prompt_tokens,
      int64_t& start_pos);

  /**
   * Helper that runs prefill over a single chunk of tokens.
   * @param prompt_tokens The chunk of prompt tokens to process.
   * @param start_pos Position in the KV cache at which this chunk starts.
   * Taken by non-const reference. NOTE(review): presumably updated by the
   * implementation - confirm against the definition.
   * @return The next token emitted by the LLM Module after this chunk.
   */
  virtual ::executorch::runtime::Result<uint64_t> prefill_chunk(
      std::vector<uint64_t>& prompt_tokens,
      int64_t& start_pos);

  /**
   * Load the resources the prefiller depends on by asking the underlying
   * decoder runner to load. Call this before any of the prefill methods.
   * @return Whatever the decoder runner's load() reports.
   */
  ::executorch::runtime::Error load() {
    return text_decoder_runner_->load();
  }

  /**
   * Report whether the prefiller is ready for use.
   * @return True when the decoder runner says its method is loaded, false
   * otherwise.
   */
  bool is_loaded() const {
    // Readiness is delegated entirely to the borrowed decoder runner.
    return text_decoder_runner_->is_method_loaded();
  }

 private:
  /**
   * Non-owning pointer. TextPrefiller never manages the lifetime of the
   * TextDecoderRunner; that is the job of the outer class or entity (likely
   * Runner) that created the runner and handed it to this object.
   */
  TextDecoderRunner* text_decoder_runner_;
  bool use_kv_cache_;
  bool enable_parallel_prefill_;
  int64_t max_seq_len_;
};
} // namespace llm
} // namespace extension
} // namespace executorch
namespace torch {
namespace executor {
// Deprecated alias: makes TextPrefiller reachable under the legacy
// `torch::executor` namespace for code that has not yet switched over.
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::extension::llm::TextPrefiller;
} // namespace executor
} // namespace torch