Skip to content

Commit 44a91bf

Browse files
Arm backend: Enable and support KV cache on Llama (pytorch#20026)
- Run llama with use_kv_cache option - Add LlamaPositionalAdapter to handle input_pos mismatch - Extract USER_OUTPUT in arm test pipeline in order to avoid irrelevant cache data being accidentally analysed against the ref model cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani Signed-off-by: Christoffer J.L <christoffer.johanssonlundqvist@arm.com>
1 parent e0b6574 commit 44a91bf

2 files changed

Lines changed: 28 additions & 1 deletion

File tree

backends/arm/test/models/test_llama.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from transformers import GenerationConfig, LlamaConfig, LlamaForCausalLM
3535
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
3636

37-
input_t = Tuple[torch.Tensor]
37+
input_t = Tuple[torch.Tensor, ...]
3838
input_th = Tuple[torch.Tensor, torch.Tensor]
3939

4040
# Add project dir to sys path to workaround importlib.import_module() conditions in model_factory.py
@@ -61,6 +61,15 @@ def forward(self, input_ids, cache_position):
6161
return self.inner(input_ids=input_ids, cache_position=cp)
6262

6363

64+
class LlamaPositionalAdapter(torch.nn.Module):
65+
def __init__(self, model):
66+
super().__init__()
67+
self.model = model
68+
69+
def forward(self, tokens, input_pos):
70+
return self.model(tokens, {"input_pos": input_pos})
71+
72+
6473
class TestLlama:
6574
"""Test class of Llama models.
6675
@@ -154,6 +163,7 @@ def prepare_model(self):
154163
params_file,
155164
"--model",
156165
model_name,
166+
"--use_kv_cache",
157167
]
158168

159169
parser = build_args_parser()
@@ -162,6 +172,11 @@ def prepare_model(self):
162172

163173
llama_model, llama_inputs, llama_meta = get_llama_model(llm_config)
164174

175+
if llm_config.model.use_kv_cache:
176+
tokens, attn_options = llama_inputs
177+
llama_model = LlamaPositionalAdapter(llama_model).eval()
178+
llama_inputs = (tokens, attn_options["input_pos"])
179+
165180
return llama_model, llama_inputs, llama_meta
166181

167182

backends/arm/test/tester/arm_tester.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,18 @@ def run_method_and_compare_outputs(
641641
test_stage.run_artifact(test_input)
642642
)
643643

644+
# When we run with KV cache enabled, the model returns cache data in the results. This we need to strip away by extracting only USER_OUTPUT.
645+
if hasattr(test_stage.artifact, "exported_program"):
646+
output_specs = (
647+
test_stage.artifact.exported_program().graph_signature.output_specs
648+
)
649+
user_outputs = [
650+
output
651+
for output, spec in zip(test_outputs, output_specs)
652+
if spec.kind == OutputKind.USER_OUTPUT
653+
]
654+
test_outputs = user_outputs
655+
644656
logger.info(f"\n Input: {original_input}")
645657
logger.info(f"\n Ref output: {reference_outputs}")
646658
logger.info(f"\nTest output: {test_outputs}")

0 commit comments

Comments
 (0)