forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmultimodal_runner.h
More file actions
168 lines (151 loc) · 6.19 KB
/
multimodal_runner.h
File metadata and controls
168 lines (151 loc) · 6.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
// A simple multimodal LLM runner that includes preprocessing and
// post-processing logic. The runner takes a sequence of multimodal inputs
// (text, images, audio) and emits generated text as output.
#pragma once
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include <executorch/extension/llm/runner/image.h>
#include <executorch/extension/llm/runner/image_prefiller.h>
#include <executorch/extension/llm/runner/io_manager/io_manager.h>
#include <executorch/extension/llm/runner/irunner.h>
#include <executorch/extension/llm/runner/multimodal_decoder_runner.h>
#include <executorch/extension/llm/runner/multimodal_input.h>
#include <executorch/extension/llm/runner/multimodal_prefiller.h>
#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/llm/runner/text_prefiller.h>
#include <executorch/extension/llm/runner/text_token_generator.h>
#include <executorch/extension/llm/sampler/sampler.h>
#include <executorch/extension/module/module.h>
#include <pytorch/tokenizers/tokenizer.h>
// Helper functions now live in llm_runner_helper.h; this include is kept so
// that code relying on this header for them keeps compiling (backward
// compatibility).
#include <executorch/extension/llm/runner/llm_runner_helper.h>
#ifdef CUDA_AVAILABLE
#include <executorch/backends/cuda/runtime/memory_tracker.h>
#endif
namespace executorch {
namespace extension {
namespace llm {
/**
 * MultimodalRunner - A runner for LLMs that take multimodal input and produce
 * text output.
 *
 * This class is designed for Large Language Models that can process multimodal
 * inputs (text, images, audio) and generate text outputs. It supports models
 * like LLaVA, CLIP-based vision-language models, and speech-to-text models.
 *
 * For the supported model architectures, see README.md.
 *
 * Key Features:
 * - Supports mixed multimodal inputs in any order via
 *   std::vector<MultimodalInput>
 * - Encoder handles non-text modalities (images, audio) → embeddings
 * - Text tokenizer converts text tokens → embeddings
 * - Embeddings are stitched together based on input ordering
 * - Text decoder performs autoregressive generation with KV cache
 * - Internal pos_ state tracks KV cache position across calls
 * - GenerationConfig provides comprehensive control over generation parameters
 *
 * Usage:
 *   std::vector<MultimodalInput> inputs;
 *   inputs.emplace_back(make_text_input("Describe this image:"));
 *   inputs.emplace_back(make_image_input(std::move(image)));
 *
 *   GenerationConfig config;
 *   config.max_new_tokens = 100;
 *   config.temperature = 0.7f;
 *
 *   runner->generate(inputs, config, token_callback, stats_callback);
 */
class ET_EXPERIMENTAL MultimodalRunner {
 public:
  /**
   * @brief Constructor for MultimodalRunner with dependency injection
   *
   * Creates a MultimodalRunner instance with all required components for
   * multimodal text generation. The runner does not call into `module` or
   * `text_decoder_runner` directly; it takes ownership of them to manage
   * their lifetimes.
   *
   * @param metadata Key-value pairs containing model metadata (e.g.,
   * vocab_size, context_length)
   * @param tokenizer Tokenizer for converting between text and token IDs
   * @param module The underlying model module that performs inference
   * @param text_decoder_runner Component responsible for running the decoder
   * part of the model
   * @param multimodal_prefiller Component for prefilling multimodal inputs
   * @param io_manager Component for handling I/O operations
   * @param text_token_generator Component for generating tokens during the
   * decode phase
   * @param stats Statistics tracking object for performance monitoring
   */
  explicit MultimodalRunner(
      std::unordered_map<std::string, int64_t> metadata,
      std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
      std::unique_ptr<Module> module,
      std::unique_ptr<MultimodalDecoderRunner> text_decoder_runner,
      std::unique_ptr<MultimodalPrefiller> multimodal_prefiller,
      std::unique_ptr<IOManager> io_manager,
      std::unique_ptr<TextTokenGenerator> text_token_generator,
      std::unique_ptr<Stats> stats);

  /**
   * Reports whether the runner is loaded.
   * NOTE(review): the exact criteria are defined out of line (not visible in
   * this header); presumably true once load() has succeeded.
   */
  virtual bool is_loaded();

  /**
   * Loads the runner's components.
   * @return The error code. The loading steps are defined out of line.
   */
  virtual ::executorch::runtime::Error load();

  /**
   * Generate tokens from the given multimodal inputs using GenerationConfig.
   * @param inputs A vector of MultimodalInput objects containing images and
   * text.
   * @param config Generation configuration parameters.
   * @param token_callback Callback function called for each generated token
   * (defaults to an empty function).
   * @param stats_callback Callback function for generation statistics
   * (defaults to an empty function).
   * @return The error code. KV cache position is tracked internally in pos_.
   */
  virtual ::executorch::runtime::Error generate(
      const std::vector<MultimodalInput>& inputs,
      const GenerationConfig& config,
      std::function<void(const std::string&)> token_callback = {},
      std::function<void(const Stats&)> stats_callback = {});

  /**
   * Prefill multimodal inputs, for example to reload chat history.
   * @param inputs A vector of MultimodalInput objects containing images and
   * text.
   * @return The error code. KV cache position is tracked internally in pos_.
   */
  virtual ::executorch::runtime::Error prefill(
      const std::vector<MultimodalInput>& inputs);

  /**
   * Forwards a stop request to the text token generator. Does nothing if no
   * token generator is owned (e.g. nullptr was passed to the constructor).
   */
  void stop() {
    if (text_token_generator_) {
      text_token_generator_->stop();
    }
  }

  /**
   * Rewinds the tracked KV cache position (pos_) to 0 and resets the
   * statistics object, if one is owned. This only resets this runner's
   * bookkeeping; it does not touch the module or decoder components directly.
   */
  void reset() {
    pos_ = 0;
    if (stats_) {
      stats_->reset();
    }
  }

  virtual ~MultimodalRunner() = default;

 protected:
  // Components (owned; see constructor documentation).
  std::unordered_map<std::string, int64_t> metadata_;
  std::unique_ptr<::tokenizers::Tokenizer> tokenizer_;
  std::unique_ptr<Module> module_;
  std::unique_ptr<MultimodalDecoderRunner> text_decoder_runner_;
  std::unique_ptr<MultimodalPrefiller> multimodal_prefiller_;
  std::unique_ptr<IOManager> io_manager_;
  std::unique_ptr<TextTokenGenerator> text_token_generator_;
  std::unique_ptr<Stats> stats_;
#ifdef CUDA_AVAILABLE
  std::unique_ptr<::executorch::backends::cuda::CudaMemoryTracker>
      cuda_memory_tracker_;
#endif

  // Internal state
  // Current KV cache position carried across prefill()/generate() calls.
  // The in-class initializer guarantees a defined starting value even if a
  // constructor does not set it explicitly (a mem-initializer still wins).
  int64_t pos_{0};
};
} // namespace llm
} // namespace extension
} // namespace executorch