
Commit 6541f23 ("up")

1 parent 3332741 commit 6541f23

7 files changed: 608 additions & 359 deletions


Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
#!/usr/bin/env python3
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Run an exported Llama model (from HuggingFace) using the ExecuTorch pybindings.

This script runs models exported with export_llama_hf.py. It loads the tokenizer
directly from HuggingFace, using the same model ID that was used during export.

Usage:
    python -m executorch.backends.apple.mlx.examples.llama.run_llama_hf \
        --pte llama_hf.pte \
        --model-id unsloth/Llama-3.2-1B-Instruct \
        --prompt "Hello, world!"
"""

import argparse
import logging
import time

import torch
from executorch.runtime import Runtime, Verification
from transformers import AutoTokenizer

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)


def run_inference(
    pte_path: str,
    model_id: str,
    prompt: str,
    max_new_tokens: int = 50,
) -> str:
    """Run inference on the exported HuggingFace model."""
    logger.info(f"Loading tokenizer from HuggingFace: {model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    logger.info(f"Loading model from {pte_path}...")
    et_runtime = Runtime.get()
    program = et_runtime.load_program(pte_path, verification=Verification.Minimal)
    forward = program.load_method("forward")

    logger.info(f"Encoding prompt: {prompt!r}")
    # Apply chat template for instruct models
    messages = [{"role": "user", "content": prompt}]
    formatted_prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    logger.info(f"Formatted prompt: {formatted_prompt!r}")
    input_ids = tokenizer.encode(formatted_prompt, return_tensors="pt")
    logger.info(f"Input shape: {input_ids.shape}")

    generated_tokens = input_ids[0].tolist()
    seq_len = input_ids.shape[1]

    # Prefill: process all input tokens at once
    logger.info("Running prefill...")
    start_time = time.time()

    # cache_position must match the sequence length of input_ids
    # For prefill with N tokens, cache_position = [0, 1, 2, ..., N-1]
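    # Example: a 5-token prompt yields cache_position = tensor([0, 1, 2, 3, 4])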
    cache_position = torch.arange(seq_len, dtype=torch.long)
    logger.info(
        f"Prefill: input_ids shape={input_ids.shape}, cache_position shape={cache_position.shape}"
    )
    outputs = forward.execute([input_ids, cache_position])
    logits = outputs[0]

    prefill_time = time.time() - start_time
    logger.info(f"Prefill time: {prefill_time:.3f}s")
    logger.info(f"Output logits shape: {logits.shape}")

    # Get the next token from the last position
    next_token_logits = logits[0, -1, :]
    next_token = torch.argmax(next_token_logits).item()
    generated_tokens.append(next_token)

    # Decode: generate tokens one at a time
    logger.info(f"Generating {max_new_tokens} tokens...")
    decode_start = time.time()

    for i in range(max_new_tokens - 1):
        # Position for the token we're about to process
        # After prefill of N tokens and generating 1 token, generated_tokens has N+1 items
        # The token we're processing (next_token) is at position len(generated_tokens)-1
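        # Example: with a 5-token prompt, generated_tokens holds 6 entries after
        # prefill, so the first decode step runs with pos = 5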
        pos = len(generated_tokens) - 1
        input_pos = torch.tensor([pos], dtype=torch.long)
        # Input is just the last generated token
        token_input = torch.tensor([[next_token]], dtype=torch.long)

        outputs = forward.execute([token_input, input_pos])
        logits = outputs[0]

        next_token_logits = logits[0, -1, :]
        next_token = torch.argmax(next_token_logits).item()
        generated_tokens.append(next_token)

        # Check for EOS
        if next_token == tokenizer.eos_token_id:
            logger.info(f"EOS token reached at position {i + 1}")
            break

    decode_time = time.time() - decode_start
    tokens_per_sec = (len(generated_tokens) - input_ids.shape[1]) / decode_time
    logger.info(f"Decode time: {decode_time:.3f}s ({tokens_per_sec:.1f} tokens/sec)")

    # Decode only the newly generated tokens (not the input prompt)
    new_tokens = generated_tokens[input_ids.shape[1]:]
    generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return generated_text


def main():
    parser = argparse.ArgumentParser(
        description="Run exported HuggingFace Llama model"
    )
    parser.add_argument(
        "--pte",
        type=str,
        default="llama_hf.pte",
        help="Path to the .pte file",
    )
    parser.add_argument(
        "--model-id",
        type=str,
        default="unsloth/Llama-3.2-1B-Instruct",
        help="HuggingFace model ID (used to load tokenizer)",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="The quick brown fox",
        help="Input prompt",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=50,
        help="Maximum number of new tokens to generate",
    )

    args = parser.parse_args()

    generated_text = run_inference(
        pte_path=args.pte,
        model_id=args.model_id,
        prompt=args.prompt,
        max_new_tokens=args.max_new_tokens,
    )

    print("\n" + "=" * 60)
    print("Generated text:")
    print("=" * 60)
    print(generated_text)
    print("=" * 60)


if __name__ == "__main__":
    main()
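
The script is meant to be run from the command line as shown in the module docstring, but the same entry point can also be called from Python. A minimal sketch, assuming the module path given in the usage string above and the default values from main() (both are assumptions to adjust for your setup):

    # Hypothetical direct use of run_inference(); the import path is taken from the
    # usage string above and may differ in your checkout.
    from executorch.backends.apple.mlx.examples.llama.run_llama_hf import run_inference

    text = run_inference(
        pte_path="llama_hf.pte",                   # .pte produced by export_llama_hf.py
        model_id="unsloth/Llama-3.2-1B-Instruct",  # must match the model ID used at export
        prompt="Hello, world!",
        max_new_tokens=50,
    )
    print(text)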
