Skip to content

Commit 3fef208

Browse files
fix(lfm2_5_vl): fp16 vision encoder for --quantize/--dtype fp16, fix KV cache dtype mismatch
- Cast only vision_tower + multi_modal_projector to fp16 (not text_model), preventing the update_cache dtype assertion when --quantize is used with an fp32 text decoder.
- Replace the deprecated HintBasedSymShapeEvalPass with ConstraintBasedSymShapeEvalPass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 4a09493 commit 3fef208

1 file changed

Lines changed: 18 additions & 10 deletions

File tree

examples/models/lfm2_5_vl/export_lfm2_5_vl.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
generic MultimodalRunner (C++ llava_main).
1010
1111
Methods:
12-
vision_encoder : [1, 3, 512, 512] f32 NCHW pixels [0,255] -> [1, 256, 2048] f32
12+
vision_encoder : [1, 3, 512, 512] f32 NCHW pixels [0,255] -> [1, 256, 2048] f32/f16
1313
token_embedding : [1, seq_len] i64 -> [1, seq_len, 2048] f32
1414
text_decoder : ([1, seq_len, 2048] f32, [seq_len] i64) -> [1, 65536] f32
1515
@@ -53,10 +53,7 @@
5353
)
5454
from executorch.exir.passes import MemoryPlanningPass
5555
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
56-
from executorch.exir.passes.sym_shape_eval_pass import (
57-
ConstraintBasedSymShapeEvalPass,
58-
HintBasedSymShapeEvalPass,
59-
)
56+
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
6057
from executorch.extension.llm.export.builder import DType, LLMEdgeManager
6158
from executorch.extension.llm.export.config.llm_config import LlmConfig
6259
from torch.export import Dim
@@ -87,14 +84,18 @@ def export(self) -> "Lfm2p5VlEdgeManager":
8784
return self
8885

8986

90-
def export_image_encoder(lfm2) -> torch.export.ExportedProgram:
87+
def export_image_encoder(
88+
lfm2, dtype: DType = DType.fp32
89+
) -> torch.export.ExportedProgram:
9190
"""Export vision encoder as 'vision_encoder' method.
9291
9392
Input: [1, 3, 512, 512] float32 NCHW pixels in [0, 255]
94-
Output: [1, 256, 2048] float32 image embeddings
93+
Output: [1, 256, 2048] f32/f16 image embeddings
9594
9695
Normalize + patch extraction are baked in so the C++ runner only
9796
needs to resize to 512x512 and pass the raw pixel buffer.
97+
Weights are cast to dtype (fp16 halves the ~1.6 GB vision encoder).
98+
The input pixel tensor always stays fp32.
9899
"""
99100

100101
class ImageEncoder(torch.nn.Module):
@@ -106,11 +107,16 @@ def forward(self, images: torch.Tensor) -> torch.Tensor:
106107
return self.lfm2.image_embedding(images)
107108

108109
encoder = ImageEncoder(lfm2)
110+
if dtype != DType.fp32:
111+
# Cast only the vision parts of the HF model, not text_model (KV cache buffers
112+
# must stay in the text decoder's dtype, not the vision encoder's dtype).
113+
lfm2.model_.model.vision_tower.to(dtype.to_torch_dtype())
114+
lfm2.model_.model.multi_modal_projector.to(dtype.to_torch_dtype())
109115
example_pixels = torch.randint(
110116
0, 256, (1, 3, IMAGE_SIZE, IMAGE_SIZE), dtype=torch.float32
111117
)
112118

113-
logging.info("Exporting vision encoder...")
119+
logging.info(f"Exporting vision encoder ({dtype.name})...")
114120
with torch.no_grad():
115121
ep = torch.export.export(encoder, (example_pixels,), strict=False)
116122
return ep
@@ -251,8 +257,10 @@ def export_all(
251257
if dtype != DType.fp32:
252258
lfm2 = lfm2.to(dtype.to_torch_dtype())
253259

260+
# Vision encoder: use fp16 when quantizing (halves ~1.6 GB SigLIP2) or when dtype=fp16
261+
vision_dtype = DType.fp16 if (quantize or dtype == DType.fp16) else DType.fp32
254262
logging.info("[1/3] Exporting vision encoder...")
255-
vision_ep = export_image_encoder(lfm2)
263+
vision_ep = export_image_encoder(lfm2, vision_dtype)
256264

257265
# Text decoder MUST come before token embedding (see export_token_embedding docstring)
258266
logging.info("[2/3] Exporting text decoder...")
@@ -304,7 +312,7 @@ def export_all(
304312
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
305313
sym_shape_eval_pass={
306314
"vision_encoder": ConstraintBasedSymShapeEvalPass(),
307-
"token_embedding": HintBasedSymShapeEvalPass(),
315+
"token_embedding": ConstraintBasedSymShapeEvalPass(),
308316
"text_decoder": ConstraintBasedSymShapeEvalPass(),
309317
},
310318
)

0 commit comments

Comments (0)