forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathirunner.h
More file actions
173 lines (149 loc) · 5.51 KB
/
irunner.h
File metadata and controls
173 lines (149 loc) · 5.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
// Interface for text generation runners.
#pragma once
#include <algorithm>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include <executorch/extension/llm/runner/stats.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/result.h>
namespace executorch {
namespace extension {
namespace llm {
class MultimodalInput; // Forward declaration
// Configuration struct for generation parameters.
// Fields are intended to be kept in alphabetical order. temperature, num_bos
// and num_eos currently break that order. They are left where they are
// because reordering members would silently change the meaning of positional
// brace-initialization of this aggregate in existing callers.
struct GenerationConfig {
  // Whether to echo the input prompt in the output
  bool echo = true;
  // Whether to ignore EOS token and continue generating until max_new_tokens
  bool ignore_eos = false;
  // Maximum number of new tokens to generate.
  // If the max_context_len metadata that's serialized in the .pte file exists,
  // then the number of prompt tokens + max_new_tokens won't exceed
  // max_context_len. If this field is -1, we rely on the max_context_len
  // metadata and the seq_len value. See resolve_max_new_tokens for details.
  int32_t max_new_tokens = -1;
  // Whether this is a warmup run (affects perf benchmarking)
  bool warming = false;
  // Maximum number of total tokens.
  // If the .pte file contains the max_context_len metadata, that value
  // overrides this one when it is smaller. If this field is -1, the
  // max_context_len metadata is used directly. See resolve_max_new_tokens for
  // details.
  int32_t seq_len = -1;
  // Temperature for sampling (higher = more random)
  float temperature = 0.8f;
  // Number of BOS / EOS tokens to add to the prompt
  int32_t num_bos = 0;
  int32_t num_eos = 0;

  /**
   * Resolve the maximum number of new tokens to generate based on constraints.
   *
   * The result is the tightest of these limits:
   *   - the room left in the context window
   *     (max_context_len - num_tokens_occupied);
   *   - if seq_len != -1, min(seq_len, max_context_len) - num_tokens_occupied;
   *   - if max_new_tokens != -1, max_new_tokens.
   * A value of -1 in seq_len / max_new_tokens means "unspecified".
   *
   * @param max_context_len The maximum context length supported by the model
   * @param num_tokens_occupied The number of token positions already occupied
   * in the context window (e.g. pos_ after prefill)
   * @return The resolved maximum number of new tokens to generate, clamped to
   * the range [0, INT32_MAX] so that narrowing to int32_t never wraps.
   */
  int32_t resolve_max_new_tokens(
      int64_t max_context_len,
      int64_t num_tokens_occupied) const {
    const bool has_seq_len = seq_len != -1;
    const bool has_max_new_tokens = max_new_tokens != -1;

    int64_t result;
    if (has_seq_len) {
      // seq_len bounds the total token count and may not exceed the model's
      // context length. Subtract what is already occupied.
      result = std::min(static_cast<int64_t>(seq_len), max_context_len) -
          num_tokens_occupied;
    } else {
      // No seq_len: the whole remaining context window is available.
      result = max_context_len - num_tokens_occupied;
    }
    if (has_max_new_tokens) {
      result = std::min(result, static_cast<int64_t>(max_new_tokens));
    }

    // Clamp into the representable, non-negative int32_t range before
    // narrowing. A large max_context_len would otherwise wrap on the cast.
    result = std::max<int64_t>(0, result);
    result = std::min<int64_t>(result, INT32_MAX);
    return static_cast<int32_t>(result);
  }
};
// Base interface for LLM runners.
class ET_EXPERIMENTAL IRunner {
 public:
  virtual ~IRunner() = default;

  /**
   * Check if the runner is loaded and ready for inference.
   *
   * @return true if the runner is loaded, false otherwise
   */
  virtual bool is_loaded() const = 0;

  /**
   * Load the model and prepare for inference.
   *
   * @return Error::Ok if successful, an error otherwise
   */
  virtual runtime::Error load() = 0;

  /**
   * Generate text based on the provided prompt and generation config.
   *
   * @param prompt The input prompt to generate from
   * @param config Generation configuration parameters
   * @param token_callback Callback function called for each generated token
   * @param stats_callback Callback function for generation statistics
   * @return Error::Ok if successful, an error otherwise
   */
  virtual runtime::Error generate(
      const std::string& prompt,
      const GenerationConfig& config,
      std::function<void(const std::string&)> token_callback,
      std::function<void(const Stats&)> stats_callback) = 0;

  /**
   * Prefill multimodal inputs into the KV cache without generating.
   *
   * The base implementation does nothing and returns Error::NotSupported.
   * Runners that support prefill-only execution override it.
   *
   * Note: default arguments on a virtual function are taken from the static
   * type at the call site, not from the override. Overrides should therefore
   * declare the same defaults (0, 0).
   *
   * @param inputs A vector of MultimodalInput objects (text, tokens, images,
   * audio)
   * @param num_bos Number of BOS tokens to prepend during encoding
   * @param num_eos Number of EOS tokens to append during encoding
   * @return The next token predicted after prefill, or an error
   */
  virtual runtime::Result<uint64_t> prefill(
      const std::vector<MultimodalInput>& inputs,
      int32_t num_bos = 0,
      int32_t num_eos = 0) {
    // Unused in the default implementation. The casts keep
    // -Wunused-parameter quiet while preserving the documented names.
    (void)inputs;
    (void)num_bos;
    (void)num_eos;
    return runtime::Error::NotSupported;
  }

  /**
   * Stop the generation process.
   */
  virtual void stop() = 0;

  /**
   * Force remove prefilled tokens and reset KV cache start position.
   *
   * This method removes the prefilled tokens from the KV cache and resets the
   * start position to 0.
   */
  virtual void reset() = 0;
};
} // namespace llm
} // namespace extension
} // namespace executorch