
Commit 00fc22d

fix(lfm2_5_vl): always export vision encoder as fp32, add pt2e quantization support

XNNPACK promotes fp16 FC ops to PFP16 internally on Apple Silicon (SME2),
producing fp16 activations that crash the non-delegated layer_norm ops
(portable kernels only support fp32). The fp32 vision encoder is unaffected:
XNNPACK uses PFP32 and outputs fp32 activations that layer_norm can consume.

Remove the vision_dtype fp16 path entirely. For model size reduction, use
--quantize instead: it mirrors LLaVA's pt2e_quantize approach (quantize on
pre_autograd_graph_module before Edge lowering so aten.linear is still intact
when XNNPACK partitions), resulting in QS8 FC ops with no PFP16 promotion.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 3fef208 commit 00fc22d
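
The pt2e flow the message refers to is wrapped by `LLMEdgeManager.export().pt2e_quantize()` in the diff below. For reference, a minimal standalone sketch of that flow, assuming PyTorch's standard `torch.ao.quantization.quantize_pt2e` API; the helper name and the single-batch calibration are illustrative, not part of the commit:

```python
# Sketch of the pt2e flow that LLMEdgeManager.pt2e_quantize() wraps.
# prepare_pt2e/convert_pt2e are PyTorch's standard pt2e entry points;
# pt2e_quantize_encoder is an illustrative helper name, not from the commit.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)


def pt2e_quantize_encoder(
    encoder: torch.nn.Module, example_pixels: torch.Tensor
) -> torch.export.ExportedProgram:
    # Capture the pre-autograd graph: aten.linear is still a single node
    # here, which is what lets XNNPACK later partition it as a QS8 FC op.
    pre_autograd = torch.export.export_for_training(
        encoder, (example_pixels,)
    ).module()

    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())

    # Insert observers, run one calibration pass with a representative
    # input, then fold the observers into quantize/dequantize nodes.
    prepared = prepare_pt2e(pre_autograd, quantizer)
    prepared(example_pixels)
    quantized = convert_pt2e(prepared)

    # Re-export the quantized graph so to_edge_transform_and_lower can consume it.
    with torch.no_grad():
        return torch.export.export(quantized, (example_pixels,), strict=False)
```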

1 file changed: 41 additions & 16 deletions

examples/models/lfm2_5_vl/export_lfm2_5_vl.py

```diff
--- a/examples/models/lfm2_5_vl/export_lfm2_5_vl.py
+++ b/examples/models/lfm2_5_vl/export_lfm2_5_vl.py
@@ -9,7 +9,7 @@
 generic MultimodalRunner (C++ llava_main).

 Methods:
-    vision_encoder  : [1, 3, 512, 512] f32 NCHW pixels [0,255] -> [1, 256, 2048] f32/f16
+    vision_encoder  : [1, 3, 512, 512] f32 NCHW pixels [0,255] -> [1, 256, 2048] f32
     token_embedding : [1, seq_len] i64 -> [1, seq_len, 2048] f32
     text_decoder    : ([1, seq_len, 2048] f32, [seq_len] i64) -> [1, 65536] f32

@@ -28,6 +28,10 @@
     ConfigPrecisionType,
 )
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    XNNPACKQuantizer,
+    get_symmetric_quantization_config,
+)
 from executorch.examples.models.llama.export_llama_lib import (
     get_quantizer_and_quant_params,
 )
@@ -85,17 +89,20 @@ def export(self) -> "Lfm2p5VlEdgeManager":


 def export_image_encoder(
-    lfm2, dtype: DType = DType.fp32
+    lfm2, quantize: bool = False
 ) -> torch.export.ExportedProgram:
     """Export vision encoder as 'vision_encoder' method.

     Input: [1, 3, 512, 512] float32 NCHW pixels in [0, 255]
-    Output: [1, 256, 2048] f32/f16 image embeddings
+    Output: [1, 256, 2048] float32 image embeddings

     Normalize + patch extraction are baked in so the C++ runner only
     needs to resize to 512x512 and pass the raw pixel buffer.
-    Weights are cast to dtype (fp16 halves the ~1.6 GB vision encoder).
-    The input pixel tensor always stays fp32.
+
+    When quantize=True, mirrors LLaVA's export_image_encoder: uses
+    LLMEdgeManager.export().pt2e_quantize() so quantization happens on
+    the pre-autograd graph (aten.linear still intact), then re-exports
+    the quantized graph for to_edge_transform_and_lower.
     """

     class ImageEncoder(torch.nn.Module):
@@ -107,18 +114,38 @@ def forward(self, images: torch.Tensor) -> torch.Tensor:
             return self.lfm2.image_embedding(images)

     encoder = ImageEncoder(lfm2)
-    if dtype != DType.fp32:
-        # Cast only the vision parts of the HF model, not text_model (KV cache buffers
-        # must stay in the text decoder's dtype, not the vision encoder's dtype).
-        lfm2.model_.model.vision_tower.to(dtype.to_torch_dtype())
-        lfm2.model_.model.multi_modal_projector.to(dtype.to_torch_dtype())
     example_pixels = torch.randint(
         0, 256, (1, 3, IMAGE_SIZE, IMAGE_SIZE), dtype=torch.float32
     )

-    logging.info(f"Exporting vision encoder ({dtype.name})...")
-    with torch.no_grad():
-        ep = torch.export.export(encoder, (example_pixels,), strict=False)
+    if quantize:
+        logging.info("Exporting vision encoder (int8 dynamic quantized)...")
+        quantizer = XNNPACKQuantizer().set_global(
+            get_symmetric_quantization_config()
+        )
+        manager = (
+            Lfm2p5VlEdgeManager(
+                model=encoder,
+                modelname="lfm2_5_vl_image_encoder",
+                max_seq_len=MAX_SEQ_LEN,
+                dtype=DType.fp32,
+                use_kv_cache=False,
+                example_inputs=(example_pixels,),
+            )
+            .export()
+            .pt2e_quantize([quantizer])
+        )
+        with torch.no_grad():
+            ep = torch.export.export(
+                manager.pre_autograd_graph_module,
+                manager.example_inputs,
+                strict=False,
+            )
+    else:
+        logging.info("Exporting vision encoder (fp32)...")
+        with torch.no_grad():
+            ep = torch.export.export(encoder, (example_pixels,), strict=False)
+
     return ep


@@ -257,10 +284,8 @@ def export_all(
     if dtype != DType.fp32:
         lfm2 = lfm2.to(dtype.to_torch_dtype())

-    # Vision encoder: use fp16 when quantizing (halves ~1.6 GB SigLIP2) or when dtype=fp16
-    vision_dtype = DType.fp16 if (quantize or dtype == DType.fp16) else DType.fp32
     logging.info("[1/3] Exporting vision encoder...")
-    vision_ep = export_image_encoder(lfm2, vision_dtype)
+    vision_ep = export_image_encoder(lfm2, quantize)

     # Text decoder MUST come before token embedding (see export_token_embedding docstring)
     logging.info("[2/3] Exporting text decoder...")
```

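After export, the returned ExportedProgram still has to be lowered so XNNPACK can claim the quantized FC subgraphs. A minimal sketch of that step, assuming ExecuTorch's standard to_edge_transform_and_lower flow; this commit only changes the export side, so the exact lowering config and the output filename are assumptions:

```python
# Sketch of the lowering step the docstring points at; the lowering config
# is an assumption, as this commit only touches the export side.
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackPartitioner,
)
from executorch.exir import to_edge_transform_and_lower

vision_ep = export_image_encoder(lfm2, quantize=True)  # lfm2: loaded model

# XnnpackPartitioner claims the quantized FC subgraphs (QS8, so no PFP16
# promotion); unclaimed ops such as layer_norm fall back to the portable
# fp32 kernels, which is why the encoder must stay fp32 end to end.
edge = to_edge_transform_and_lower(vision_ep, partitioner=[XnnpackPartitioner()])
exec_prog = edge.to_executorch()

with open("lfm2_5_vl_vision_encoder.pte", "wb") as f:  # illustrative filename
    f.write(exec_prog.buffer)
```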