Commit 6d23e41

Extract shared multifunction PTE utilities to utils.py (#19035)
Differential Revision: D101887672 Pull Request resolved: #19035
1 parent 3ec63f4 commit 6d23e41

2 files changed

Lines changed: 228 additions & 192 deletions

File tree

examples/apple/coreml/llama/run_static_llm_multifunction.py

Lines changed: 10 additions & 192 deletions
@@ -22,14 +22,16 @@
 import argparse
 import json
 import time
-from typing import Any, Dict, List, Tuple
+from typing import List
 
 import torch
-import torch.utils._pytree as pytree
 
+from executorch.examples.apple.coreml.llama.utils import (
+    create_pte_wrapper,
+    setup_multifunction_managers,
+)
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.runner.generation import next_token
-from executorch.examples.models.llama.static_attention import StaticAttentionIOManager
 from executorch.runtime import Runtime
 from pytorch_tokenizers import get_tokenizer
 
@@ -41,170 +43,6 @@ def get_stop_tokens(tokenizer) -> List[int]:
     return [tokenizer.eos_id]
 
 
-def create_pte_wrapper(
-    decode_method,
-    prefill_method,
-    mgr: "StaticAttentionIOManager",
-    prefill_seq_len: int,
-    prefill_mask: Dict[str, torch.Tensor],
-):
-    """
-    Create a wrapper function that adapts PTE execution to the interface
-    expected by StaticAttentionIOManager.
-
-    This multifunction version selects between prefill and decode methods
-    based on the input sequence length. Both methods use the SAME cache_len,
-    so the cache buffer is shared directly without any slicing or copying.
-
-    The wrapper:
-    - Takes (tokens, options_dict) like the eager model
-    - Selects prefill or decode method based on token count
-    - Uses the same cache buffer for both methods (no slicing needed)
-    - Flattens inputs using pytree
-    - Executes the appropriate PTE method
-    - Reconstructs outputs to match eager model format: (logits, {"out_cache_state": (k_dict, v_dict)})
-
-    Args:
-        decode_method: The PTE method for decode (seqlen=1)
-        prefill_method: The PTE method for prefill (seqlen=input_len)
-        mgr: StaticAttentionIOManager with caches sized for shared cache_len
-        prefill_seq_len: The sequence length for prefill
-        prefill_mask: Pre-computed mask tensor for prefill method
-    """
-
-    k_cache_keys = list(mgr.k_caches.keys())
-    v_cache_keys = list(mgr.v_caches.keys())
-
-    timing_stats = {
-        "flatten_time": 0.0,
-        "execute_time": 0.0,
-        "reconstruct_time": 0.0,
-        "detection_time": 0.0,
-        "options_build_time": 0.0,
-        "call_count": 0,
-    }
-
-    def wrapper(
-        tokens: torch.Tensor, options: Dict[str, Any]
-    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
-        import time as time_module
-
-        timing_stats["call_count"] += 1
-
-        t0 = time_module.perf_counter()
-
-        # Detect actual sequence length.
-        # StaticAttentionIOManager._run_once pads tokens with zeros on the right.
-        # For decode (1 actual token), positions 1+ are all zeros.
-        padded_seq_len = tokens.shape[1]
-        if padded_seq_len > 1 and (tokens[0, 1:] == 0).all():
-            actual_seq_len = 1
-        else:
-            actual_seq_len = padded_seq_len
-
-        is_prefill = actual_seq_len == prefill_seq_len
-
-        t1 = time_module.perf_counter()
-        timing_stats["detection_time"] += t1 - t0
-
-        t0 = time_module.perf_counter()
-
-        # Get the input cache state from options
-        in_k_caches, in_v_caches = options["in_cache_state"]
-
-        # Both prefill and decode use the same cache_len, so no slicing needed!
-        # Just select the appropriate method and mask.
-        if is_prefill:
-            method = prefill_method
-            adapted_mask = prefill_mask
-        else:
-            method = decode_method
-            adapted_mask = mgr.masks
-
-        adapted_options = {
-            "masks": adapted_mask,
-            "freqs_cos_override": options["freqs_cos_override"],
-            "freqs_sin_override": options["freqs_sin_override"],
-            "in_cache_state": (in_k_caches, in_v_caches),  # Same cache for both!
-        }
-
-        if "last_valid_token_pos" in options:
-            adapted_options["last_valid_token_pos"] = options["last_valid_token_pos"]
-
-        inputs = (tokens, adapted_options)
-
-        t1 = time_module.perf_counter()
-        timing_stats["options_build_time"] += t1 - t0
-
-        t0 = time_module.perf_counter()
-        flat_inputs, _ = pytree.tree_flatten(inputs)
-        t1 = time_module.perf_counter()
-        timing_stats["flatten_time"] += t1 - t0
-
-        t0 = time_module.perf_counter()
-        outputs = method.execute(flat_inputs)
-        t1 = time_module.perf_counter()
-        timing_stats["execute_time"] += t1 - t0
-
-        t0 = time_module.perf_counter()
-
-        logits = outputs[0]
-
-        num_layers = len(k_cache_keys)
-        k_updates = outputs[1 : 1 + num_layers]
-        v_updates = outputs[1 + num_layers : 1 + 2 * num_layers]
-
-        k_cache_dict = dict(zip(k_cache_keys, k_updates))
-        v_cache_dict = dict(zip(v_cache_keys, v_updates))
-
-        attn_updates = {"out_cache_state": (k_cache_dict, v_cache_dict)}
-
-        t1 = time_module.perf_counter()
-        timing_stats["reconstruct_time"] += t1 - t0
-
-        return logits, attn_updates
-
-    def print_timing_stats():
-        n = timing_stats["call_count"]
-        if n > 0:
-            print(f"\n=== Wrapper Timing Stats ({n} calls) ===")
-            print(
-                f" Detection time: {timing_stats['detection_time']*1000:.2f}ms total, {timing_stats['detection_time']/n*1000:.4f}ms avg"
-            )
-            print(
-                f" Options build: {timing_stats['options_build_time']*1000:.2f}ms total, {timing_stats['options_build_time']/n*1000:.4f}ms avg"
-            )
-            print(
-                f" Flatten time: {timing_stats['flatten_time']*1000:.2f}ms total, {timing_stats['flatten_time']/n*1000:.4f}ms avg"
-            )
-            print(
-                f" Execute time: {timing_stats['execute_time']*1000:.2f}ms total, {timing_stats['execute_time']/n*1000:.3f}ms avg"
-            )
-            print(
-                f" Reconstruct time: {timing_stats['reconstruct_time']*1000:.2f}ms total, {timing_stats['reconstruct_time']/n*1000:.4f}ms avg"
-            )
-            total = (
-                timing_stats["detection_time"]
-                + timing_stats["options_build_time"]
-                + timing_stats["flatten_time"]
-                + timing_stats["execute_time"]
-                + timing_stats["reconstruct_time"]
-            )
-            print(
-                f" Total wrapper: {total*1000:.2f}ms total, {total/n*1000:.3f}ms avg"
-            )
-            print(
-                f" Execute is {timing_stats['execute_time']/total*100:.1f}% of wrapper time"
-            )
-            expected_tps = 1000 / (timing_stats["execute_time"] / n * 1000)
-            print(f" Expected tok/s from execute alone: {expected_tps:.1f}")
-
-    wrapper.print_timing_stats = print_timing_stats
-    wrapper.timing_stats = timing_stats
-
-    return wrapper
-
-
 def main():
     parser = argparse.ArgumentParser(
         description="Run multifunction static attention Llama model"
@@ -326,36 +164,16 @@ def main():
     print(f"Prefill: input_len={prefill_input_len}, cache_len={shared_cache_len}")
     print(f"Decode: input_len={decode_input_len}, cache_len={shared_cache_len}")
 
-    # Create decode manager (input_len=1) - used for decode phase
-    mgr = StaticAttentionIOManager(
-        model_args,
-        input_len=decode_input_len,
-        cache_lens=shared_cache_len,
-        batch_size=1,
-        dtype=torch.float16,
-        style="smart_mask",
-        mask_val=float("-inf"),
-    )
-
-    # Create prefill manager (input_len=64) with the SAME cache_len.
-    # Since both use the same cache_len, we can share the cache buffer directly.
-    prefill_mgr = StaticAttentionIOManager(
+    # Create managers with shared cache buffers
+    mgr, prefill_mgr, prefill_mask = setup_multifunction_managers(
         model_args,
-        input_len=prefill_input_len,
-        cache_lens=shared_cache_len,  # Same cache_len as decode!
-        batch_size=1,
+        prefill_input_len,
+        decode_input_len,
+        shared_cache_len,
         dtype=torch.float16,
-        style="smart_mask",
        mask_val=float("-inf"),
     )
 
-    # Share cache buffers: point prefill_mgr's caches to mgr's caches.
-    # No copying needed since both managers use the same cache_len!
-    prefill_mgr.k_caches = mgr.k_caches
-    prefill_mgr.v_caches = mgr.v_caches
-
-    prefill_mask = prefill_mgr.masks
-
     # Load PTE model with multifunction support
     print(f"Loading multifunction model from {args.model}...")
     runtime = Runtime.get()
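Note: only the run_static_llm_multifunction.py half of the refactor is shown above; the companion utils.py diff, which carries the remaining additions, is not included here. For orientation, the sketch below reconstructs what the extracted setup_multifunction_managers helper plausibly looks like, based purely on the inline code this commit removes and on the new call site in main(); the exact signature, defaults, and internals in utils.py may differ. The removed create_pte_wrapper presumably moved to the same module largely unchanged.

# Hypothetical reconstruction of the helper now imported from
# executorch.examples.apple.coreml.llama.utils. The real utils.py is not
# shown in this diff, so the signature and defaults below are assumptions
# derived from the removed inline code and the new call site.
from typing import Dict, Tuple

import torch

from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.static_attention import StaticAttentionIOManager


def setup_multifunction_managers(
    model_args: ModelArgs,
    prefill_input_len: int,
    decode_input_len: int,
    shared_cache_len: int,
    dtype: torch.dtype = torch.float16,
    mask_val: float = float("-inf"),
) -> Tuple[StaticAttentionIOManager, StaticAttentionIOManager, Dict[str, torch.Tensor]]:
    # Decode manager owns the KV cache buffers.
    mgr = StaticAttentionIOManager(
        model_args,
        input_len=decode_input_len,
        cache_lens=shared_cache_len,
        batch_size=1,
        dtype=dtype,
        style="smart_mask",
        mask_val=mask_val,
    )

    # Prefill manager uses the same cache_len, so its caches can simply
    # alias the decode manager's buffers (no slicing or copying needed).
    prefill_mgr = StaticAttentionIOManager(
        model_args,
        input_len=prefill_input_len,
        cache_lens=shared_cache_len,
        batch_size=1,
        dtype=dtype,
        style="smart_mask",
        mask_val=mask_val,
    )
    prefill_mgr.k_caches = mgr.k_caches
    prefill_mgr.v_caches = mgr.v_caches

    # Pre-computed prefill masks, previously taken from prefill_mgr.masks
    # at the call site and later fed to create_pte_wrapper.
    prefill_mask = prefill_mgr.masks

    return mgr, prefill_mgr, prefill_mask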
