 generic MultimodalRunner (C++ llava_main).
 
 Methods:
-    vision_encoder:  [1, 3, 512, 512] f32 NCHW pixels [0,255] -> [1, 256, 2048] f32/f16
+    vision_encoder:  [1, 3, 512, 512] f32 NCHW pixels [0,255] -> [1, 256, 2048] f32
     token_embedding: [1, seq_len] i64 -> [1, seq_len, 2048] f32
     text_decoder:    ([1, seq_len, 2048] f32, [seq_len] i64) -> [1, 65536] f32
 
@@ -28,6 +28,10 @@
     ConfigPrecisionType,
 )
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    XNNPACKQuantizer,
+    get_symmetric_quantization_config,
+)
 from executorch.examples.models.llama.export_llama_lib import (
     get_quantizer_and_quant_params,
 )
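
The three method signatures in the module docstring above map directly onto ExecuTorch runtime calls. As an illustrative sketch only (the pybindings import path is real, but the .pte filename and the unpacking of outputs are assumptions, not part of this diff), exercising the exported methods from Python could look like:

import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Hypothetical .pte produced by this exporter; the filename is an assumption.
module = _load_for_executorch("lfm2_5_vl.pte")

pixels = torch.randint(0, 256, (1, 3, 512, 512), dtype=torch.float32)
(img_emb,) = module.run_method("vision_encoder", (pixels,))   # [1, 256, 2048] f32

tokens = torch.tensor([[10, 20, 30]], dtype=torch.int64)
(tok_emb,) = module.run_method("token_embedding", (tokens,))  # [1, 3, 2048] f32
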
@@ -85,17 +89,20 @@ def export(self) -> "Lfm2p5VlEdgeManager":
 
 
 def export_image_encoder(
-    lfm2, dtype: DType = DType.fp32
+    lfm2, quantize: bool = False
 ) -> torch.export.ExportedProgram:
     """Export vision encoder as 'vision_encoder' method.
 
     Input: [1, 3, 512, 512] float32 NCHW pixels in [0, 255]
-    Output: [1, 256, 2048] f32/f16 image embeddings
+    Output: [1, 256, 2048] float32 image embeddings
 
     Normalize + patch extraction are baked in so the C++ runner only
     needs to resize to 512x512 and pass the raw pixel buffer.
-    Weights are cast to dtype (fp16 halves the ~1.6 GB vision encoder).
-    The input pixel tensor always stays fp32.
+
+    When quantize=True, mirrors LLaVA's export_image_encoder: uses
+    LLMEdgeManager.export().pt2e_quantize() so quantization happens on
+    the pre-autograd graph (aten.linear still intact), then re-exports
+    the quantized graph for to_edge_transform_and_lower.
     """
 
     class ImageEncoder(torch.nn.Module):
@@ -107,18 +114,38 @@ def forward(self, images: torch.Tensor) -> torch.Tensor:
             return self.lfm2.image_embedding(images)
 
     encoder = ImageEncoder(lfm2)
-    if dtype != DType.fp32:
-        # Cast only the vision parts of the HF model, not text_model (KV cache buffers
-        # must stay in the text decoder's dtype, not the vision encoder's dtype).
-        lfm2.model_.model.vision_tower.to(dtype.to_torch_dtype())
-        lfm2.model_.model.multi_modal_projector.to(dtype.to_torch_dtype())
     example_pixels = torch.randint(
         0, 256, (1, 3, IMAGE_SIZE, IMAGE_SIZE), dtype=torch.float32
     )
 
-    logging.info(f"Exporting vision encoder ({dtype.name})...")
-    with torch.no_grad():
-        ep = torch.export.export(encoder, (example_pixels,), strict=False)
+    if quantize:
+        logging.info("Exporting vision encoder (int8 dynamic quantized)...")
+        quantizer = XNNPACKQuantizer().set_global(
+            get_symmetric_quantization_config()
+        )
+        manager = (
+            Lfm2p5VlEdgeManager(
+                model=encoder,
+                modelname="lfm2_5_vl_image_encoder",
+                max_seq_len=MAX_SEQ_LEN,
+                dtype=DType.fp32,
+                use_kv_cache=False,
+                example_inputs=(example_pixels,),
+            )
+            .export()
+            .pt2e_quantize([quantizer])
+        )
+        with torch.no_grad():
+            ep = torch.export.export(
+                manager.pre_autograd_graph_module,
+                manager.example_inputs,
+                strict=False,
+            )
+    else:
+        logging.info("Exporting vision encoder (fp32)...")
+        with torch.no_grad():
+            ep = torch.export.export(encoder, (example_pixels,), strict=False)
+
     return ep
 
 
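For context on the pt2e_quantize path above: it wraps the standard PT2E flow from torch.ao. A minimal sketch of that underlying recipe, assuming the torch.ao.quantization.quantize_pt2e API and a recent torch.export.export_for_training (this is illustrative, not the LLMEdgeManager implementation):

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

def pt2e_quantize_sketch(model: torch.nn.Module, example_inputs):
    # Capture the pre-autograd graph: aten.linear is still a single node
    # here, which is what lets the quantizer annotate whole linear layers.
    captured = torch.export.export_for_training(model, example_inputs).module()
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    prepared = prepare_pt2e(captured, quantizer)  # insert observers
    prepared(*example_inputs)                     # calibrate on example data
    return convert_pt2e(prepared)                 # fold observers into q/dq nodes
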
@@ -257,10 +284,8 @@ def export_all(
     if dtype != DType.fp32:
         lfm2 = lfm2.to(dtype.to_torch_dtype())
 
-    # Vision encoder: use fp16 when quantizing (halves ~1.6 GB SigLIP2) or when dtype=fp16
-    vision_dtype = DType.fp16 if (quantize or dtype == DType.fp16) else DType.fp32
     logging.info("[1/3] Exporting vision encoder...")
-    vision_ep = export_image_encoder(lfm2, vision_dtype)
+    vision_ep = export_image_encoder(lfm2, quantize)
 
     # Text decoder MUST come before token embedding (see export_token_embedding docstring)
     logging.info("[2/3] Exporting text decoder...")
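
After all three exports, the programs are presumably lowered together into one multi-method .pte. A hedged sketch of that step, assuming to_edge_transform_and_lower accepts a dict of method names to ExportedPrograms (as in ExecuTorch's LLaVA example); text_decoder_ep, token_embedding_ep, and the output path are illustrative names, not taken from this diff:

from executorch.exir import to_edge_transform_and_lower

edge = to_edge_transform_and_lower(
    {
        "vision_encoder": vision_ep,
        "token_embedding": token_embedding_ep,  # illustrative name
        "text_decoder": text_decoder_ep,        # illustrative name
    },
    partitioner=[XnnpackPartitioner()],
)
with open("lfm2_5_vl.pte", "wb") as f:  # output path is an assumption
    f.write(edge.to_executorch().buffer)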