Skip to content

Commit 417218c

Browse files
kirklandsign, facebook-github-bot
authored and committed
Integrate LLM metadata NamedData into export/runner pipeline
Summary: This integrates the metadata storage POC (D104322796) into the real LLM export and runner pipeline. Metadata values (max_seq_len, bos_id, eos_ids, etc.) are now dual-written: both as constant_methods (for backward compatibility) and as NamedData entries (for efficient C++ runtime access).
On the export side:
- Added metadata.py with helpers to encode/decode metadata values into the PTE file's NamedData section.
- Modified builder.py to call add_metadata() after edge creation in both export_to_edge() and to_edge_transform_and_lower().
On the runner side:
- Added metadata.h with typed accessors (get_int, get_string, get_int_list) for reading metadata from NamedDataMap.
- Modified llm_runner_helper.cpp so get_llm_metadata() and get_eos_ids() try NamedData first, falling back to constant_methods for old PTE files.
Key design decisions:
- Backward compatible: constant_methods are NOT removed.
- Dual-write on export; prefer NamedData on read, with fallback to constant_methods.
- Failure to write NamedData is non-fatal (a warning is logged).
- NamedData keys use a dotted namespace, e.g. metadata.context.max_seq_len.
Differential Revision: D104471143
1 parent 9889c7c commit 417218c

6 files changed

Lines changed: 364 additions & 2 deletions

File tree

extension/llm/export/BUCK

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ fbcode_target(_kind = runtime.python_library,
5555
"//executorch/extension/export_util:export_util",
5656
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
5757
"//executorch/extension/llm/custom_ops:custom_ops_aot_py",
58+
"//executorch/extension/llm/export:metadata",
5859
"//pytorch/tokenizers/pytorch_tokenizers:tokenizers",
5960
],
6061
)
@@ -108,3 +109,15 @@ fbcode_target(_kind = runtime.python_test,
108109
":export_lib",
109110
],
110111
)
112+
113+
fbcode_target(_kind = runtime.python_library,
114+
name = "metadata",
115+
srcs = [
116+
"metadata.py",
117+
],
118+
base_module = "executorch.extension.llm.export",
119+
visibility = ["PUBLIC"],
120+
deps = [
121+
"//executorch/exir:lib",
122+
],
123+
)

extension/llm/export/builder.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from executorch.extension.export_util.utils import export_to_edge, save_pte_program
3535

3636
from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes
37+
from executorch.extension.llm.export.metadata import add_metadata
3738
from pytorch_tokenizers import get_tokenizer
3839
from torch.export import export, ExportedProgram
3940
from torch.nn.attention import SDPBackend
@@ -71,6 +72,18 @@ def from_torch_dtype(dtype: torch.dtype):
7172
return mapping[dtype]
7273

7374

75+
_CONSTANT_METHOD_TO_NAMED_DATA = {
76+
"get_bos_id": "tokenizer.bos_id",
77+
"get_eos_ids": "tokenizer.eos_ids",
78+
"get_max_seq_len": "context.max_seq_len",
79+
"get_max_context_len": "context.max_context_len",
80+
"get_vocab_size": "model.vocab_size",
81+
"use_kv_cache": "model.use_kv_cache",
82+
"use_sdpa_with_kv_cache": "model.use_sdpa_with_kv_cache",
83+
"enable_dynamic_shape": "model.enable_dynamic_shape",
84+
}
85+
86+
7487
class LLMEdgeManager:
7588
"""
7689
Host a torch.nn.Module for LLM model and facilitates exporting to ExecuTorch.
@@ -393,6 +406,28 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage
393406
logging.info("No quantizer provided, passing...")
394407
return self
395408

409+
def _write_metadata_to_named_data(self):
410+
"""Write metadata to NamedData for efficient C++ runtime access.
411+
412+
This writes the same metadata values stored as constant_methods
413+
also as NamedData entries, enabling the C++ runner to read them
414+
without loading full ExecutionPlan entries.
415+
"""
416+
if self.edge_manager is None:
417+
return
418+
named_data = {}
419+
for key, value in self.metadata.items():
420+
nd_key = _CONSTANT_METHOD_TO_NAMED_DATA.get(key, key)
421+
named_data[nd_key] = value
422+
try:
423+
add_metadata(self.edge_manager, named_data)
424+
except Exception:
425+
# Don't fail the export if metadata writing fails
426+
logging.warning(
427+
"Failed to write metadata to NamedData, "
428+
"falling back to constant_methods only"
429+
)
430+
396431
def export_to_edge(self) -> "LLMEdgeManager":
397432
"""
398433
Export the model to Edge dialect and retrieve a LLMEdgeManager.
@@ -418,6 +453,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
418453
edge_compile_config=edge_config,
419454
verbose=self.verbose,
420455
)
456+
self._write_metadata_to_named_data()
421457
return self
422458

423459
def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
@@ -470,6 +506,7 @@ def to_edge_transform_and_lower(
470506
constant_methods=self.metadata,
471507
generate_etrecord=self.generate_etrecord,
472508
)
509+
self._write_metadata_to_named_data()
473510
if self.verbose:
474511
logging.info(f"Exported graph:\n{self.edge_manager.exported_program()}")
475512
return self

extension/llm/export/metadata.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Model metadata storage for PTE files.
8+
9+
Embeds model metadata (tokenizer config, chat templates, architecture info)
10+
directly in PTE files via the NamedData mechanism. Complements the current
11+
constant_methods approach (which creates full ExecutionPlan entries for
12+
simple constant values); both are written so that older runners keep working.
13+
14+
Keys use a dotted namespace.field convention:
15+
tokenizer.bos_id, tokenizer.eos_ids, context.max_seq_len, etc.
16+
"""
17+
18+
import struct
19+
from typing import Dict, List, Sequence, Union
20+
21+
METADATA_PREFIX = "metadata."
22+
23+
MetadataValue = Union[str, int, float, bytes, Sequence[int]]
24+
25+
26+
def _encode_value(key: str, value: MetadataValue) -> bytes:
27+
if isinstance(value, str):
28+
return value.encode("utf-8")
29+
elif isinstance(value, (list, tuple)):
30+
return struct.pack(f"<I{len(value)}q", len(value), *value)
31+
elif isinstance(value, int):
32+
return struct.pack("<q", value)
33+
elif isinstance(value, float):
34+
return struct.pack("<d", value)
35+
elif isinstance(value, bytes):
36+
return value
37+
raise TypeError(f"Unsupported metadata type {type(value)} for key \'{key}\'")
38+
39+
40+
def add_metadata(
    edge_manager,  # EdgeProgramManager
    metadata: Dict[str, "MetadataValue"],
) -> None:
    """Add metadata KV pairs to a PTE file during export.

    Call BEFORE edge_manager.to_executorch().

    All values are encoded before the first entry is written, so a value of
    an unsupported type raises without leaving a partial set of metadata
    entries in the store.

    Args:
        edge_manager: The EdgeProgramManager from to_edge() or
            to_edge_transform_and_lower().
        metadata: Dict mapping string keys to values (str, int, float, bytes,
            or a list/tuple of ints). Keys are automatically prefixed with
            "metadata." to avoid collision with backend named data.

    Raises:
        TypeError: If a value has an unsupported type (nothing is written).
    """
    # Phase 1: encode everything; any encoding error surfaces here, before
    # the named data store has been modified.
    encoded = {
        f"{METADATA_PREFIX}{key}": _encode_value(key, value)
        for key, value in metadata.items()
    }

    # Phase 2: write the already-encoded entries.
    for prefixed_key, payload in encoded.items():
        edge_manager._named_data_store.add_named_data(
            key=prefixed_key,
            data=payload,
        )
60+
61+
62+
def read_metadata(pte_path: str) -> Dict[str, bytes]:
    """Collect every metadata entry stored in the PTE file at ``pte_path``.

    Only named-data keys starting with the "metadata." prefix are returned,
    and the prefix is stripped from the returned keys. Values are the raw
    buffers; use get_string/get_int/get_float/get_int_list for typed access.
    """
    from executorch.exir._serialize._program import deserialize_pte_binary

    with open(pte_path, "rb") as pte_stream:
        raw_pte = pte_stream.read()

    pte_file = deserialize_pte_binary(raw_pte)

    if pte_file.named_data is None:
        # No named-data section at all -> no metadata.
        return {}

    prefix_len = len(METADATA_PREFIX)
    return {
        full_key[prefix_len:]: pte_file.named_data.buffers[entry.buffer_index]
        for full_key, entry in pte_file.named_data.pte_data.items()
        if full_key.startswith(METADATA_PREFIX)
    }
83+
84+
85+
def get_string(metadata: Dict[str, bytes], key: str) -> str:
    """Return the entry stored under ``key`` decoded as UTF-8 text."""
    raw = metadata[key]
    return raw.decode("utf-8")
87+
88+
89+
def get_int(metadata: Dict[str, bytes], key: str) -> int:
    """Return the entry stored under ``key`` decoded as a little-endian int64."""
    (number,) = struct.unpack("<q", metadata[key])
    return number
91+
92+
93+
def get_float(metadata: Dict[str, bytes], key: str) -> float:
    """Return the entry stored under ``key`` decoded as a little-endian float64."""
    (number,) = struct.unpack("<d", metadata[key])
    return number
95+
96+
97+
def get_int_list(metadata: Dict[str, bytes], key: str) -> List[int]:
    """Return the entry stored under ``key`` decoded as a list of int64 values.

    The payload is a little-endian uint32 element count followed by that many
    little-endian int64 values.
    """
    payload = metadata[key]
    count = struct.unpack_from("<I", payload, 0)[0]
    # Elements start right after the 4-byte count header.
    return [*struct.unpack_from(f"<{count}q", payload, 4)]

extension/llm/runner/llm_runner_helper.cpp

Lines changed: 99 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <executorch/extension/llm/runner/image_prefiller.h>
1212
#include <executorch/extension/llm/runner/llm_runner_helper.h>
13+
#include <executorch/extension/llm/runner/metadata.h>
1314
#include <executorch/extension/llm/runner/multimodal_decoder_runner.h>
1415
#include <executorch/extension/llm/runner/multimodal_prefiller.h>
1516
#include <executorch/extension/llm/runner/multimodal_runner.h>
@@ -99,7 +100,84 @@ get_llm_metadata(tokenizers::Tokenizer* tokenizer, Module* module) {
99100
{llm::kUseSDPAWithKVCache, false},
100101
});
101102

102-
// Read metadata from the model
103+
// Try reading from NamedDataMap first (new format)
104+
auto program = module->program();
105+
if (program) {
106+
auto ndm_result = program->get_named_data_map();
107+
if (ndm_result.ok() && ndm_result.get() != nullptr) {
108+
const auto* named_data_map = ndm_result.get();
109+
110+
// Map from runtime keys to NamedData keys
111+
struct KeyMapping {
112+
const char* runtime_key;
113+
const char* named_data_key;
114+
};
115+
static const KeyMapping mappings[] = {
116+
{llm::kMaxSeqLen, metadata::kMaxSeqLen},
117+
{llm::kMaxContextLen, metadata::kMaxContextLen},
118+
{llm::kUseKVCache, metadata::kUseKVCache},
119+
{llm::kEnableDynamicShape, metadata::kEnableDynamicShape},
120+
{llm::kUseSDPAWithKVCache, metadata::kUseSDPAWithKVCache},
121+
};
122+
123+
// Check if kMaxSeqLen exists in NamedData (required key)
124+
auto max_seq_result =
125+
metadata::get_int(*named_data_map, metadata::kMaxSeqLen);
126+
if (max_seq_result.ok()) {
127+
ET_LOG(Info, "Reading metadata from NamedData");
128+
129+
for (const auto& mapping : mappings) {
130+
auto val =
131+
metadata::get_int(*named_data_map, mapping.named_data_key);
132+
if (val.ok()) {
133+
metadata[mapping.runtime_key] = val.get();
134+
ET_LOG(
135+
Info,
136+
"NamedData: %s = %" PRId64,
137+
mapping.runtime_key,
138+
val.get());
139+
}
140+
}
141+
142+
// Read bos_id from NamedData
143+
auto bos_result =
144+
metadata::get_int(*named_data_map, metadata::kBosId);
145+
if (bos_result.ok()) {
146+
metadata[llm::kBosId] = bos_result.get();
147+
} else {
148+
metadata[llm::kBosId] = tokenizer->bos_tok();
149+
}
150+
151+
// Read vocab_size from NamedData
152+
auto vocab_result =
153+
metadata::get_int(*named_data_map, metadata::kVocabSize);
154+
if (vocab_result.ok()) {
155+
metadata[llm::kVocabSize] = vocab_result.get();
156+
} else {
157+
metadata[llm::kVocabSize] = tokenizer->vocab_size();
158+
}
159+
160+
// Handle kMaxContextLen default: if not explicitly set,
161+
// default to kMaxSeqLen
162+
if (metadata.find(llm::kMaxContextLen) == metadata.end() ||
163+
metadata[llm::kMaxContextLen] == 128) {
164+
auto ctx_result =
165+
metadata::get_int(*named_data_map, metadata::kMaxContextLen);
166+
if (!ctx_result.ok()) {
167+
metadata[llm::kMaxContextLen] = metadata[llm::kMaxSeqLen];
168+
}
169+
}
170+
171+
for (auto& pair : metadata) {
172+
ET_LOG(
173+
Info, "Metadata: %s = %" PRId64, pair.first.c_str(), pair.second);
174+
}
175+
return metadata;
176+
}
177+
}
178+
}
179+
180+
// Fallback: Read metadata from constant_methods (legacy format)
103181
auto method_names_result = module->method_names();
104182
if (method_names_result.error() != Error::Ok) {
105183
ET_LOG(Error, "Failed reading method names");
@@ -158,7 +236,26 @@ std::unordered_set<uint64_t> get_eos_ids(
158236
tokenizers::Tokenizer* tokenizer,
159237
Module* module) {
160238
std::unordered_set<uint64_t> eos_ids = {tokenizer->eos_tok()};
161-
// Get EOS IDs if available
239+
240+
// Try NamedData first (new format)
241+
auto program = module->program();
242+
if (program) {
243+
auto ndm_result = program->get_named_data_map();
244+
if (ndm_result.ok() && ndm_result.get() != nullptr) {
245+
auto eos_result =
246+
metadata::get_int_list(*ndm_result.get(), metadata::kEosIds);
247+
if (eos_result.ok()) {
248+
eos_ids.clear();
249+
for (auto id : eos_result.get()) {
250+
eos_ids.emplace(static_cast<uint64_t>(id));
251+
ET_LOG(Info, "NamedData eos_id = %" PRId64, id);
252+
}
253+
return eos_ids;
254+
}
255+
}
256+
}
257+
258+
// Fallback: Get EOS IDs from constant_methods (legacy format)
162259
auto method_names_result = module->method_names();
163260
if (method_names_result.error() != Error::Ok) {
164261
ET_LOG(Error, "Failed reading method names");

0 commit comments

Comments
 (0)