Skip to content

Commit f6be985

Browse files
xingguo01zingo
andauthored
LLM support: improve VGF export and calibration pipeline (pytorch#19157)
This is stacked on top of pytorch#19029 - make non-KV-cache example inputs match the static export window - fix PT2E calibration flow for padded prefixes and optional LM-Eval tasks - update SmolLM2 export settings used by the VGF PT2E workflow - Fix rope_theta in 135M_config.json to align with Hugging face model config cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani Signed-off-by: Xingguo Li <xingguo.li@arm.com> Co-authored-by: Zingo Andersen <zingo.andersen@arm.com>
1 parent ea37954 commit f6be985

4 files changed

Lines changed: 183 additions & 73 deletions

File tree

examples/models/llama/eval_llama_lib.py

Lines changed: 70 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -46,9 +47,13 @@ def __init__(
4647
use_kv_cache: bool = False,
4748
generate_full_logits: bool = False,
4849
enable_dynamic_shape: bool = True,
50+
device: Optional[str] = None,
4951
):
5052
super().__init__(
51-
model=model, tokenizer=tokenizer, max_seq_length=max_seq_length
53+
model=model,
54+
tokenizer=tokenizer,
55+
max_seq_length=max_seq_length,
56+
device=device,
5257
)
5358
self._model = model.to(self.device)
5459
self._use_kv_cache = use_kv_cache
@@ -57,30 +62,70 @@ def __init__(
5762

5863
def _model_call(self, inps):
5964
if self._use_kv_cache:
60-
if not self._enable_dynamic_shape:
61-
# graph module exported without dynamic shape won't work with a different shape.
62-
# And we have to do single token prefill here.
63-
result_logits = []
64-
for pos in range(inps.shape[-1]):
65-
pos_tensor = torch.tensor([pos], dtype=torch.int64)
66-
logits = self._model(
67-
inps[:, pos : pos + 1], {"input_pos": pos_tensor}
68-
)
69-
result_logits.append(logits)
70-
if self._generate_full_logits:
71-
return torch.cat(result_logits, dim=1)
72-
else:
73-
return torch.stack(result_logits, dim=1)
74-
else:
75-
pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
76-
# Batch process the whole sequence.
77-
logits = self._model(
78-
inps[:, : self._max_seq_length], {"input_pos": pos_tensor}
79-
)
80-
return logits
65+
return self._model_call_kv_cache(inps)
66+
return self._model_call_no_kv_cache(inps)
8167

82-
else:
83-
return self._model(inps)
68+
def _model_call_kv_cache(self, inps):
69+
if self._enable_dynamic_shape:
70+
pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
71+
return self._model(
72+
inps[:, : self._max_seq_length], {"input_pos": pos_tensor}
73+
)
74+
75+
# graph module exported without dynamic shape won't work with a different shape.
76+
# And we have to do single token prefill here.
77+
result_logits = []
78+
for pos in range(inps.shape[-1]):
79+
pos_tensor = torch.tensor([pos], dtype=torch.int64)
80+
logits = self._model(inps[:, pos : pos + 1], {"input_pos": pos_tensor})
81+
result_logits.append(logits)
82+
if self._generate_full_logits:
83+
return torch.cat(result_logits, dim=1)
84+
return torch.stack(result_logits, dim=1)
85+
86+
def _model_call_no_kv_cache(self, inps):
87+
# lm-eval expects logits shaped [batch, seq, vocab]. In the non-KV path,
88+
# some exported graphs (when generate_full_logits=False) return only
89+
# last-position logits [batch, vocab], so reconstruct per-position
90+
# logits by running prefix calls.
91+
if not self._enable_dynamic_shape and not self._generate_full_logits:
92+
raise ValueError(
93+
"Static non-KV lm-eval requires generate_full_logits=True "
94+
"so logits can be read from the last non-pad token."
95+
)
96+
97+
if self._generate_full_logits:
98+
return self._model(self._pad_to_max_len(inps))
99+
100+
result_logits = []
101+
seq_len = inps.shape[-1]
102+
for pos in range(min(seq_len, self._max_seq_length)):
103+
prefix = self._pad_to_max_len(inps[:, : pos + 1])
104+
logits = self._model(prefix)
105+
if logits.dim() == 3:
106+
logits = logits[:, -1, :]
107+
result_logits.append(logits)
108+
109+
return torch.stack(result_logits, dim=1)
110+
111+
def _pad_to_max_len(self, tokens: torch.Tensor) -> torch.Tensor:
112+
if self._enable_dynamic_shape:
113+
return tokens
114+
token_len = tokens.shape[-1]
115+
if token_len > self._max_seq_length:
116+
return tokens[:, : self._max_seq_length]
117+
if token_len == self._max_seq_length:
118+
return tokens
119+
120+
pad_len = self._max_seq_length - token_len
121+
pad_token = getattr(self._tokenizer, "pad_id", self._tokenizer.eos_id)
122+
pad = torch.full(
123+
(tokens.shape[0], pad_len),
124+
pad_token,
125+
dtype=tokens.dtype,
126+
device=tokens.device,
127+
)
128+
return torch.cat((tokens, pad), dim=-1)
84129

85130
def _model_generate(self, context, max_length, eos_token_id):
86131
raise Exception("unimplemented")
@@ -219,6 +264,7 @@ def gen_eval_wrapper(
219264
tokenizer=tokenizer,
220265
max_seq_length=llm_config.export.max_seq_length,
221266
use_kv_cache=llm_config.model.use_kv_cache,
267+
generate_full_logits=llm_config.debug.generate_full_logits,
222268
enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
223269
)
224270
else:

examples/models/llama/evaluate/eager_eval.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -28,12 +29,13 @@ def __init__(
2829
tokenizer: Union[SentencePieceTokenizer, Tiktoken, HuggingFaceTokenizer],
2930
max_seq_length: Optional[int] = None,
3031
use_kv_cache: bool = False,
32+
device: Optional[str] = None,
3133
):
32-
device = "cuda" if torch.cuda.is_available() else "cpu"
33-
super().__init__(device=device, pretrained="gpt2")
34+
resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")
35+
super().__init__(device=resolved_device, pretrained="gpt2")
3436
self._model = model
3537
self._tokenizer = tokenizer
36-
self._device = torch.device(device)
38+
self._device = torch.device(resolved_device)
3739
self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
3840
self._use_kv_cache = use_kv_cache
3941

examples/models/llama/model.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -285,11 +286,25 @@ def get_example_inputs(self):
285286
if self.use_kv_cache:
286287
return self.get_example_inputs_kvcache_sdpa()
287288
else:
288-
return (
289-
torch.tensor(
290-
[[1, 2, 3]], dtype=torch.long
291-
), # tokens, with kv cache our input token length is always just 1 token.
289+
max_seq_len = getattr(self.llm_config.export, "max_seq_length", 3)
290+
# Preserve the historical three-token example input as the minimum.
291+
max_seq_len = max(3, int(max_seq_len))
292+
max_len = max_seq_len - 1 if self.enable_dynamic_shape else max_seq_len
293+
backend = self.llm_config.backend
294+
token_dtype = (
295+
torch.int32
296+
if (
297+
backend.ethosu.enabled
298+
or backend.tosa.enabled
299+
or backend.vgf.enabled
300+
)
301+
else torch.long
292302
)
303+
example_tokens = torch.arange(max_len, dtype=token_dtype).unsqueeze(0)
304+
vocab_size = int(getattr(self.model_.params, "vocab_size", 0))
305+
if vocab_size > 1:
306+
example_tokens = example_tokens % (vocab_size - 1) + 1
307+
return (example_tokens,)
293308

294309
# assumption is the custom op doesnt support dynamic shape right now. It might but its untested so lets first get static shape working
295310
def get_example_inputs_kvcache_sdpa(self):

extension/llm/export/builder.py

Lines changed: 89 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,35 @@ def run_canonical_optimizations(self):
256256
assert res.graph_module is not None, "Pass returned None"
257257
self.pre_autograd_graph_module = res.graph_module
258258

259+
def _check_calibration_prefix_options(self) -> None:
260+
if (
261+
not self.use_kv_cache
262+
and not self.enable_dynamic_shape
263+
and not self.generate_full_logits
264+
):
265+
raise ValueError(
266+
"Static non-KV calibration with padded prefixes requires "
267+
"generate_full_logits so calibration can sample the last "
268+
"non-pad token position."
269+
)
270+
271+
def _prepare_calibration_prefix(
272+
self, token_list: List[int], pos: int, max_len: int, pad_token: int
273+
) -> Tuple[torch.Tensor, int]:
274+
prefix_tokens = list(token_list[: pos + 1])
275+
logits_token_pos = min(len(prefix_tokens), max_len) - 1
276+
277+
if self.enable_dynamic_shape:
278+
prefix_tokens = prefix_tokens[:max_len]
279+
elif len(prefix_tokens) < max_len:
280+
prefix_tokens.extend([pad_token] * (max_len - len(prefix_tokens)))
281+
else:
282+
prefix_tokens = prefix_tokens[:max_len]
283+
284+
input_dtype = self.example_inputs[0].dtype
285+
prefix = torch.tensor(prefix_tokens, dtype=input_dtype).unsqueeze(0)
286+
return prefix, logits_token_pos
287+
259288
def pt2e_calibrate(
260289
self,
261290
prepared_module,
@@ -266,39 +295,41 @@ def pt2e_calibrate(
266295
tokenizer_path,
267296
):
268297
logging.info("Run calibration...")
269-
try:
270-
from executorch.examples.models.llama.eval_llama_lib import (
271-
GraphModuleEvalWrapper,
272-
)
273-
from lm_eval.evaluator import simple_evaluate
274-
except ImportError:
275-
raise ImportError(
276-
"Please install the llm eval dependency via examples/models/llama/install_requirements.sh"
277-
)
278-
298+
self._check_calibration_prefix_options()
279299
tokenizer = get_tokenizer(tokenizer_path)
280300

281301
def calibrate_template(
282302
module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int
283303
):
284304
# TODO: change criteria & support batch inputs if necessary
285-
pos = torch.tensor(0, dtype=torch.int64)
305+
pos = 0
286306
token_list = tokenizer.encode(prompts, bos=True, eos=False)
287307

308+
pad_token = getattr(tokenizer, "pad_id", tokenizer.eos_id)
309+
288310
with torch.no_grad():
289311
while token_list[-1] != tokenizer.eos_id and pos < max_len:
290-
logits = module(
291-
torch.full((1, 1), token_list[pos]),
292-
{"input_pos": torch.tensor((pos,))},
293-
)
312+
logits_token_pos = -1
313+
if self.use_kv_cache:
314+
logits = module(
315+
torch.full((1, 1), token_list[pos]),
316+
{"input_pos": torch.tensor((pos,))},
317+
)
318+
else:
319+
prefix, logits_token_pos = self._prepare_calibration_prefix(
320+
token_list, pos, max_len, pad_token
321+
)
322+
logits = module(prefix)
323+
294324
pos += 1
295325
if pos >= len(token_list):
296326
if self.generate_full_logits:
297-
token_list.append(
298-
torch.argmax(logits[:, -1], dim=-1).item()
299-
)
327+
next_token = torch.argmax(
328+
logits[:, logits_token_pos], dim=-1
329+
).item()
300330
else:
301-
token_list.append(torch.argmax(logits[:], dim=-1).item())
331+
next_token = torch.argmax(logits[:], dim=-1).item()
332+
token_list.append(next_token)
302333

303334
calibrate_template(
304335
module=prepared_module,
@@ -307,26 +338,41 @@ def calibrate_template(
307338
max_len=calibration_seq_length,
308339
)
309340

310-
eval_wrapper = GraphModuleEvalWrapper(
311-
model=prepared_module,
312-
tokenizer=tokenizer,
313-
max_seq_length=calibration_seq_length,
314-
use_kv_cache=self.use_kv_cache,
315-
generate_full_logits=self.generate_full_logits,
316-
enable_dynamic_shape=self.enable_dynamic_shape,
317-
)
341+
if calibration_tasks:
342+
try:
343+
from executorch.examples.models.llama.eval_llama_lib import (
344+
GraphModuleEvalWrapper,
345+
)
346+
from lm_eval.evaluator import simple_evaluate
347+
except ImportError:
348+
raise ImportError(
349+
"Please install the llm eval dependency via examples/models/llama/install_requirements.sh"
350+
)
318351

319-
# Evaluate the model
320-
with torch.no_grad():
321-
eval_results = simple_evaluate(
322-
model=eval_wrapper,
323-
tasks=calibration_tasks,
324-
limit=calibration_limit,
352+
eval_wrapper = GraphModuleEvalWrapper(
353+
model=prepared_module,
354+
tokenizer=tokenizer,
355+
max_seq_length=calibration_seq_length,
356+
use_kv_cache=self.use_kv_cache,
357+
generate_full_logits=self.generate_full_logits,
358+
enable_dynamic_shape=self.enable_dynamic_shape,
359+
# The exported graph can contain ops like aten.full.default
360+
# without explicit device, which default to CPU and can
361+
# trigger device-mismatch errors when lm_eval runs on CUDA.
362+
# Calibrate on CPU for stability.
363+
device="cpu",
325364
)
326365

327-
for task, res in eval_results["results"].items():
328-
print(f"{task}: {res}")
329-
logging.info("Calibration finish...")
366+
with torch.no_grad():
367+
eval_results = simple_evaluate(
368+
model=eval_wrapper,
369+
tasks=calibration_tasks,
370+
limit=calibration_limit,
371+
)
372+
373+
for task, res in eval_results["results"].items():
374+
print(f"{task}: {res}")
375+
logging.info("Calibration finish...")
330376

331377
def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
332378
"""
@@ -351,18 +397,19 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage
351397
assert (
352398
self.pre_autograd_graph_module is not None
353399
), "Please run export() first"
400+
if self.calibration_tasks and self.calibration_limit is None:
401+
logging.warning(
402+
"calibration_tasks provided without calibration_limit; "
403+
"lm-eval will run the full task dataset during "
404+
"calibration."
405+
)
354406
m = prepare_pt2e(
355407
self.pre_autograd_graph_module, # pyre-ignore[6]
356408
composed_quantizer,
357409
)
358-
logging.info(
359-
f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
360-
)
361410
# Calibrate
362411
if (
363-
self.calibration_tasks is not None
364-
and self.calibration_limit is not None
365-
and self.calibration_seq_length is not None
412+
self.calibration_seq_length is not None
366413
and self.calibration_data is not None
367414
and self.tokenizer_path is not None
368415
):

0 commit comments

Comments
 (0)