|
10 | 10 |
|
11 | 11 | #include <executorch/extension/llm/runner/image_prefiller.h> |
12 | 12 | #include <executorch/extension/llm/runner/llm_runner_helper.h> |
| 13 | +#include <executorch/extension/llm/runner/metadata.h> |
13 | 14 | #include <executorch/extension/llm/runner/multimodal_decoder_runner.h> |
14 | 15 | #include <executorch/extension/llm/runner/multimodal_prefiller.h> |
15 | 16 | #include <executorch/extension/llm/runner/multimodal_runner.h> |
@@ -99,7 +100,84 @@ get_llm_metadata(tokenizers::Tokenizer* tokenizer, Module* module) { |
99 | 100 | {llm::kUseSDPAWithKVCache, false}, |
100 | 101 | }); |
101 | 102 |
|
102 | | - // Read metadata from the model |
| 103 | + // Try reading from NamedDataMap first (new format) |
| 104 | + auto program = module->program(); |
| 105 | + if (program) { |
| 106 | + auto ndm_result = program->get_named_data_map(); |
| 107 | + if (ndm_result.ok() && ndm_result.get() != nullptr) { |
| 108 | + const auto* named_data_map = ndm_result.get(); |
| 109 | + |
| 110 | + // Map from runtime keys to NamedData keys |
| 111 | + struct KeyMapping { |
| 112 | + const char* runtime_key; |
| 113 | + const char* named_data_key; |
| 114 | + }; |
| 115 | + static const KeyMapping mappings[] = { |
| 116 | + {llm::kMaxSeqLen, metadata::kMaxSeqLen}, |
| 117 | + {llm::kMaxContextLen, metadata::kMaxContextLen}, |
| 118 | + {llm::kUseKVCache, metadata::kUseKVCache}, |
| 119 | + {llm::kEnableDynamicShape, metadata::kEnableDynamicShape}, |
| 120 | + {llm::kUseSDPAWithKVCache, metadata::kUseSDPAWithKVCache}, |
| 121 | + }; |
| 122 | + |
| 123 | + // Probe for kMaxSeqLen (required key): if present, read all metadata from NamedData and return; otherwise fall through to the legacy path |
| 124 | + auto max_seq_result = |
| 125 | + metadata::get_int(*named_data_map, metadata::kMaxSeqLen); |
| 126 | + if (max_seq_result.ok()) { |
| 127 | + ET_LOG(Info, "Reading metadata from NamedData"); |
| 128 | + |
| 129 | + for (const auto& mapping : mappings) { |
| 130 | + auto val = |
| 131 | + metadata::get_int(*named_data_map, mapping.named_data_key); |
| 132 | + if (val.ok()) { |
| 133 | + metadata[mapping.runtime_key] = val.get(); |
| 134 | + ET_LOG( |
| 135 | + Info, |
| 136 | + "NamedData: %s = %" PRId64, |
| 137 | + mapping.runtime_key, |
| 138 | + val.get()); |
| 139 | + } |
| 140 | + } |
| 141 | + |
| 142 | + // Read bos_id from NamedData, falling back to the tokenizer's BOS token if absent |
| 143 | + auto bos_result = |
| 144 | + metadata::get_int(*named_data_map, metadata::kBosId); |
| 145 | + if (bos_result.ok()) { |
| 146 | + metadata[llm::kBosId] = bos_result.get(); |
| 147 | + } else { |
| 148 | + metadata[llm::kBosId] = tokenizer->bos_tok(); |
| 149 | + } |
| 150 | + |
| 151 | + // Read vocab_size from NamedData, falling back to the tokenizer's vocab size if absent |
| 152 | + auto vocab_result = |
| 153 | + metadata::get_int(*named_data_map, metadata::kVocabSize); |
| 154 | + if (vocab_result.ok()) { |
| 155 | + metadata[llm::kVocabSize] = vocab_result.get(); |
| 156 | + } else { |
| 157 | + metadata[llm::kVocabSize] = tokenizer->vocab_size(); |
| 158 | + } |
| 159 | + |
| 160 | + // Handle kMaxContextLen default: if not explicitly set, |
| 161 | + // default to kMaxSeqLen |
| 162 | + if (metadata.find(llm::kMaxContextLen) == metadata.end() || |
| 163 | + metadata[llm::kMaxContextLen] == 128) { |
| 164 | + auto ctx_result = |
| 165 | + metadata::get_int(*named_data_map, metadata::kMaxContextLen); |
| 166 | + if (!ctx_result.ok()) { |
| 167 | + metadata[llm::kMaxContextLen] = metadata[llm::kMaxSeqLen]; |
| 168 | + } |
| 169 | + } |
| 170 | + |
| 171 | + for (auto& pair : metadata) { |
| 172 | + ET_LOG( |
| 173 | + Info, "Metadata: %s = %" PRId64, pair.first.c_str(), pair.second); |
| 174 | + } |
| 175 | + return metadata; |
| 176 | + } |
| 177 | + } |
| 178 | + } |
| 179 | + |
| 180 | + // Fallback: Read metadata from constant_methods (legacy format) |
103 | 181 | auto method_names_result = module->method_names(); |
104 | 182 | if (method_names_result.error() != Error::Ok) { |
105 | 183 | ET_LOG(Error, "Failed reading method names"); |
@@ -158,7 +236,26 @@ std::unordered_set<uint64_t> get_eos_ids( |
158 | 236 | tokenizers::Tokenizer* tokenizer, |
159 | 237 | Module* module) { |
160 | 238 | std::unordered_set<uint64_t> eos_ids = {tokenizer->eos_tok()}; |
161 | | - // Get EOS IDs if available |
| 239 | + |
| 240 | + // Try NamedData first (new format) |
| 241 | + auto program = module->program(); |
| 242 | + if (program) { |
| 243 | + auto ndm_result = program->get_named_data_map(); |
| 244 | + if (ndm_result.ok() && ndm_result.get() != nullptr) { |
| 245 | + auto eos_result = |
| 246 | + metadata::get_int_list(*ndm_result.get(), metadata::kEosIds); |
| 247 | + if (eos_result.ok()) { |
| 248 | + eos_ids.clear(); |
| 249 | + for (auto id : eos_result.get()) { |
| 250 | + eos_ids.emplace(static_cast<uint64_t>(id)); |
| 251 | + ET_LOG(Info, "NamedData eos_id = %" PRId64, id); |
| 252 | + } |
| 253 | + return eos_ids; |
| 254 | + } |
| 255 | + } |
| 256 | + } |
| 257 | + |
| 258 | + // Fallback: Get EOS IDs from constant_methods (legacy format) |
162 | 259 | auto method_names_result = module->method_names(); |
163 | 260 | if (method_names_result.error() != Error::Ok) { |
164 | 261 | ET_LOG(Error, "Failed reading method names"); |
|
0 commit comments