import argparse
import json
import time
-from typing import Any, Dict, List, Tuple
+from typing import List

import torch
-import torch.utils._pytree as pytree

+from executorch.examples.apple.coreml.llama.utils import (
+    create_pte_wrapper,
+    setup_multifunction_managers,
+)
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.runner.generation import next_token
-from executorch.examples.models.llama.static_attention import StaticAttentionIOManager
from executorch.runtime import Runtime
from pytorch_tokenizers import get_tokenizer

@@ -41,170 +43,6 @@ def get_stop_tokens(tokenizer) -> List[int]:
    return [tokenizer.eos_id]


-def create_pte_wrapper(
-    decode_method,
-    prefill_method,
-    mgr: "StaticAttentionIOManager",
-    prefill_seq_len: int,
-    prefill_mask: Dict[str, torch.Tensor],
-):
-    """
-    Create a wrapper function that adapts PTE execution to the interface
-    expected by StaticAttentionIOManager.
-
-    This multifunction version selects between prefill and decode methods
-    based on the input sequence length. Both methods use the SAME cache_len,
-    so the cache buffer is shared directly without any slicing or copying.
-
-    The wrapper:
-    - Takes (tokens, options_dict) like the eager model
-    - Selects prefill or decode method based on token count
-    - Uses the same cache buffer for both methods (no slicing needed)
-    - Flattens inputs using pytree
-    - Executes the appropriate PTE method
-    - Reconstructs outputs to match eager model format: (logits, {"out_cache_state": (k_dict, v_dict)})
-
-    Args:
-        decode_method: The PTE method for decode (seqlen=1)
-        prefill_method: The PTE method for prefill (seqlen=input_len)
-        mgr: StaticAttentionIOManager with caches sized for shared cache_len
-        prefill_seq_len: The sequence length for prefill
-        prefill_mask: Pre-computed mask tensor for prefill method
-    """
-
-    k_cache_keys = list(mgr.k_caches.keys())
-    v_cache_keys = list(mgr.v_caches.keys())
-
-    timing_stats = {
-        "flatten_time": 0.0,
-        "execute_time": 0.0,
-        "reconstruct_time": 0.0,
-        "detection_time": 0.0,
-        "options_build_time": 0.0,
-        "call_count": 0,
-    }
-
-    def wrapper(
-        tokens: torch.Tensor, options: Dict[str, Any]
-    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
-        import time as time_module
-
-        timing_stats["call_count"] += 1
-
-        t0 = time_module.perf_counter()
-
-        # Detect actual sequence length.
-        # StaticAttentionIOManager._run_once pads tokens with zeros on the right.
-        # For decode (1 actual token), positions 1+ are all zeros.
-        padded_seq_len = tokens.shape[1]
-        if padded_seq_len > 1 and (tokens[0, 1:] == 0).all():
-            actual_seq_len = 1
-        else:
-            actual_seq_len = padded_seq_len
-
-        is_prefill = actual_seq_len == prefill_seq_len
-
-        t1 = time_module.perf_counter()
-        timing_stats["detection_time"] += t1 - t0
-
-        t0 = time_module.perf_counter()
-
-        # Get the input cache state from options
-        in_k_caches, in_v_caches = options["in_cache_state"]
-
-        # Both prefill and decode use the same cache_len, so no slicing needed!
-        # Just select the appropriate method and mask.
-        if is_prefill:
-            method = prefill_method
-            adapted_mask = prefill_mask
-        else:
-            method = decode_method
-            adapted_mask = mgr.masks
-
-        adapted_options = {
-            "masks": adapted_mask,
-            "freqs_cos_override": options["freqs_cos_override"],
-            "freqs_sin_override": options["freqs_sin_override"],
-            "in_cache_state": (in_k_caches, in_v_caches),  # Same cache for both!
-        }
-
-        if "last_valid_token_pos" in options:
-            adapted_options["last_valid_token_pos"] = options["last_valid_token_pos"]
-
-        inputs = (tokens, adapted_options)
-
-        t1 = time_module.perf_counter()
-        timing_stats["options_build_time"] += t1 - t0
-
-        t0 = time_module.perf_counter()
-        flat_inputs, _ = pytree.tree_flatten(inputs)
-        t1 = time_module.perf_counter()
-        timing_stats["flatten_time"] += t1 - t0
-
-        t0 = time_module.perf_counter()
-        outputs = method.execute(flat_inputs)
-        t1 = time_module.perf_counter()
-        timing_stats["execute_time"] += t1 - t0
-
-        t0 = time_module.perf_counter()
-
-        logits = outputs[0]
-
-        num_layers = len(k_cache_keys)
-        k_updates = outputs[1 : 1 + num_layers]
-        v_updates = outputs[1 + num_layers : 1 + 2 * num_layers]
-
-        k_cache_dict = dict(zip(k_cache_keys, k_updates))
-        v_cache_dict = dict(zip(v_cache_keys, v_updates))
-
-        attn_updates = {"out_cache_state": (k_cache_dict, v_cache_dict)}
-
-        t1 = time_module.perf_counter()
-        timing_stats["reconstruct_time"] += t1 - t0
-
-        return logits, attn_updates
-
-    def print_timing_stats():
-        n = timing_stats["call_count"]
-        if n > 0:
-            print(f"\n=== Wrapper Timing Stats ({n} calls) ===")
-            print(
-                f"  Detection time: {timing_stats['detection_time']*1000:.2f} ms total, {timing_stats['detection_time']/n*1000:.4f} ms avg"
-            )
-            print(
-                f"  Options build: {timing_stats['options_build_time']*1000:.2f} ms total, {timing_stats['options_build_time']/n*1000:.4f} ms avg"
-            )
-            print(
-                f"  Flatten time: {timing_stats['flatten_time']*1000:.2f} ms total, {timing_stats['flatten_time']/n*1000:.4f} ms avg"
-            )
-            print(
-                f"  Execute time: {timing_stats['execute_time']*1000:.2f} ms total, {timing_stats['execute_time']/n*1000:.3f} ms avg"
-            )
-            print(
-                f"  Reconstruct time: {timing_stats['reconstruct_time']*1000:.2f} ms total, {timing_stats['reconstruct_time']/n*1000:.4f} ms avg"
-            )
-            total = (
-                timing_stats["detection_time"]
-                + timing_stats["options_build_time"]
-                + timing_stats["flatten_time"]
-                + timing_stats["execute_time"]
-                + timing_stats["reconstruct_time"]
-            )
-            print(
-                f"  Total wrapper: {total * 1000:.2f} ms total, {total / n * 1000:.3f} ms avg"
-            )
-            print(
-                f"  Execute is {timing_stats['execute_time']/total*100:.1f}% of wrapper time"
-            )
-            expected_tps = 1000 / (timing_stats["execute_time"] / n * 1000)
-            print(f"  Expected tok/s from execute alone: {expected_tps:.1f}")
-
-    wrapper.print_timing_stats = print_timing_stats
-    wrapper.timing_stats = timing_stats
-
-    return wrapper
-
-
def main():
    parser = argparse.ArgumentParser(
        description="Run multifunction static attention Llama model"
@@ -326,36 +164,16 @@ def main():
    print(f"Prefill: input_len={prefill_input_len}, cache_len={shared_cache_len}")
    print(f"Decode: input_len={decode_input_len}, cache_len={shared_cache_len}")

-    # Create decode manager (input_len=1) - used for decode phase
-    mgr = StaticAttentionIOManager(
-        model_args,
-        input_len=decode_input_len,
-        cache_lens=shared_cache_len,
-        batch_size=1,
-        dtype=torch.float16,
-        style="smart_mask",
-        mask_val=float("-inf"),
-    )
-
-    # Create prefill manager (input_len=64) with the SAME cache_len.
-    # Since both use the same cache_len, we can share the cache buffer directly.
-    prefill_mgr = StaticAttentionIOManager(
+    # Create managers with shared cache buffers
+    mgr, prefill_mgr, prefill_mask = setup_multifunction_managers(
        model_args,
-        input_len=prefill_input_len,
-        cache_lens=shared_cache_len,  # Same cache_len as decode!
-        batch_size=1,
+        prefill_input_len,
+        decode_input_len,
+        shared_cache_len,
        dtype=torch.float16,
-        style="smart_mask",
        mask_val=float("-inf"),
    )

-    # Share cache buffers: point prefill_mgr's caches to mgr's caches.
-    # No copying needed since both managers use the same cache_len!
-    prefill_mgr.k_caches = mgr.k_caches
-    prefill_mgr.v_caches = mgr.v_caches
-
-    prefill_mask = prefill_mgr.masks
-
    # Load PTE model with multifunction support
    print(f"Loading multifunction model from {args.model}...")
    runtime = Runtime.get()
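Note: the helpers imported at the top of the file (create_pte_wrapper, setup_multifunction_managers) come from the new executorch.examples.apple.coreml.llama.utils module, which is not part of this hunk. create_pte_wrapper is presumably the wrapper function removed above relocated to utils, so its (tokens, options) interface and prefill/decode method selection are unchanged. Based on the inline setup removed above and the new call site, setup_multifunction_managers presumably looks roughly like the sketch below; the signature and argument order are inferred from this diff, and the actual utils implementation may differ.

# Sketch only: mirrors the inline manager setup removed in this diff.
import torch

from executorch.examples.models.llama.static_attention import StaticAttentionIOManager


def setup_multifunction_managers(
    model_args,
    prefill_input_len,
    decode_input_len,
    shared_cache_len,
    dtype=torch.float16,
    mask_val=float("-inf"),
):
    # Decode manager (input_len=1); its caches become the shared buffers.
    mgr = StaticAttentionIOManager(
        model_args,
        input_len=decode_input_len,
        cache_lens=shared_cache_len,
        batch_size=1,
        dtype=dtype,
        style="smart_mask",
        mask_val=mask_val,
    )

    # Prefill manager with the same cache_len, so the cache buffers can be
    # shared directly without any slicing or copying.
    prefill_mgr = StaticAttentionIOManager(
        model_args,
        input_len=prefill_input_len,
        cache_lens=shared_cache_len,
        batch_size=1,
        dtype=dtype,
        style="smart_mask",
        mask_val=mask_val,
    )

    # Point the prefill manager at the decode manager's caches; no copying needed.
    prefill_mgr.k_caches = mgr.k_caches
    prefill_mgr.v_caches = mgr.v_caches

    return mgr, prefill_mgr, prefill_mgr.masks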