99generic MultimodalRunner (C++ llava_main).
1010
1111Methods:
12- vision_encoder : [1, 3, 512, 512] f32 NCHW pixels [0,255] -> [1, 256, 2048] f32
12+ vision_encoder : [1, 3, 512, 512] f32 NCHW pixels [0,255] -> [1, 256, 2048] f32/f16
1313 token_embedding : [1, seq_len] i64 -> [1, seq_len, 2048] f32
1414 text_decoder : ([1, seq_len, 2048] f32, [seq_len] i64) -> [1, 65536] f32
1515
5353)
5454from executorch .exir .passes import MemoryPlanningPass
5555from executorch .exir .passes .quant_fusion_pass import QuantFusionPass
56- from executorch .exir .passes .sym_shape_eval_pass import (
57- ConstraintBasedSymShapeEvalPass ,
58- HintBasedSymShapeEvalPass ,
59- )
56+ from executorch .exir .passes .sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
6057from executorch .extension .llm .export .builder import DType , LLMEdgeManager
6158from executorch .extension .llm .export .config .llm_config import LlmConfig
6259from torch .export import Dim
@@ -87,14 +84,18 @@ def export(self) -> "Lfm2p5VlEdgeManager":
8784 return self
8885
8986
90- def export_image_encoder (lfm2 ) -> torch .export .ExportedProgram :
87+ def export_image_encoder (
88+ lfm2 , dtype : DType = DType .fp32
89+ ) -> torch .export .ExportedProgram :
9190 """Export vision encoder as 'vision_encoder' method.
9291
9392 Input: [1, 3, 512, 512] float32 NCHW pixels in [0, 255]
94- Output: [1, 256, 2048] float32 image embeddings
93+ Output: [1, 256, 2048] f32/f16 image embeddings
9594
9695 Normalize + patch extraction are baked in so the C++ runner only
9796 needs to resize to 512x512 and pass the raw pixel buffer.
97+ Weights are cast to dtype (fp16 halves the ~1.6 GB vision encoder).
98+ The input pixel tensor always stays fp32.
9899 """
99100
100101 class ImageEncoder (torch .nn .Module ):
@@ -106,11 +107,16 @@ def forward(self, images: torch.Tensor) -> torch.Tensor:
106107 return self .lfm2 .image_embedding (images )
107108
108109 encoder = ImageEncoder (lfm2 )
110+ if dtype != DType .fp32 :
111+ # Cast only the vision parts of the HF model, not text_model (KV cache buffers
112+ # must stay in the text decoder's dtype, not the vision encoder's dtype).
113+ lfm2 .model_ .model .vision_tower .to (dtype .to_torch_dtype ())
114+ lfm2 .model_ .model .multi_modal_projector .to (dtype .to_torch_dtype ())
109115 example_pixels = torch .randint (
110116 0 , 256 , (1 , 3 , IMAGE_SIZE , IMAGE_SIZE ), dtype = torch .float32
111117 )
112118
113- logging .info ("Exporting vision encoder..." )
119+ logging .info (f "Exporting vision encoder ( { dtype . name } ) ..." )
114120 with torch .no_grad ():
115121 ep = torch .export .export (encoder , (example_pixels ,), strict = False )
116122 return ep
@@ -251,8 +257,10 @@ def export_all(
251257 if dtype != DType .fp32 :
252258 lfm2 = lfm2 .to (dtype .to_torch_dtype ())
253259
260+ # Vision encoder: use fp16 when quantizing (halves ~1.6 GB SigLIP2) or when dtype=fp16
261+ vision_dtype = DType .fp16 if (quantize or dtype == DType .fp16 ) else DType .fp32
254262 logging .info ("[1/3] Exporting vision encoder..." )
255- vision_ep = export_image_encoder (lfm2 )
263+ vision_ep = export_image_encoder (lfm2 , vision_dtype )
256264
257265 # Text decoder MUST come before token embedding (see export_token_embedding docstring)
258266 logging .info ("[2/3] Exporting text decoder..." )
@@ -304,7 +312,7 @@ def export_all(
304312 memory_planning_pass = MemoryPlanningPass (alloc_graph_input = False ),
305313 sym_shape_eval_pass = {
306314 "vision_encoder" : ConstraintBasedSymShapeEvalPass (),
307- "token_embedding" : HintBasedSymShapeEvalPass (),
315+ "token_embedding" : ConstraintBasedSymShapeEvalPass (),
308316 "text_decoder" : ConstraintBasedSymShapeEvalPass (),
309317 },
310318 )
0 commit comments