Commit baf48bb

b8zhong authored and vincentzed committed
Add LFM2.5-VL export with CUDA/AOTI backend
Export LFM2.5-VL (450M and 1.6B) as a multi-method PTE with three methods: vision_encoder, token_embedding, and text_decoder, all delegated to the CUDA/AOTI backend.

Key changes:

- examples/models/lfm2_5_vl/: new model, weight converter, and export script for LFM2.5-VL on CUDA
- examples/models/lfm2/short_conv.py: dual state management, state-as-IO for CUDA/AOTI (via attn_options["conv_states"]) with a register_buffer fallback for XNNPack/portable backends
- examples/models/llama/llama_transformer.py: pass layer_idx to ShortConvBlock for per-layer conv state keying
- exir/emit/_emitter.py: copy CUDA tensor storage to CPU before the ctypes pointer read to prevent a segfault during serialization (sketched below)

Tested on NVIDIA B300: 333-400 decode tok/s, 435-454 prefill tok/s, correct, coherent generation on text-only and vision-language prompts. Also compatible with the llama_main C++ runner.
1 parent 273aee9 commit baf48bb

9 files changed

Lines changed: 626 additions & 77 deletions
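The exir/emit/_emitter.py diff is not shown in the excerpt below, but the pattern behind the segfault fix is worth spelling out. A minimal sketch, assuming the emitter reads raw tensor bytes through ctypes during serialization; the helper name `tensor_bytes` is hypothetical, not the actual emitter code:

```python
import ctypes

import torch


def tensor_bytes(t: torch.Tensor) -> bytes:
    """Hypothetical helper: read a tensor's raw storage through ctypes.

    ctypes reads host memory, so a CUDA tensor's storage must be copied
    to CPU first; dereferencing a device pointer from the host is the
    segfault the commit message describes.
    """
    t = t.contiguous()
    if t.device.type != "cpu":
        t = t.cpu()
    storage = t.untyped_storage()
    return bytes((ctypes.c_char * storage.nbytes()).from_address(storage.data_ptr()))
```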

File tree

examples/models/lfm2/short_conv.py

Lines changed: 66 additions & 76 deletions
```diff
@@ -1,112 +1,102 @@
-from typing import Optional
+from __future__ import annotations
 
 import torch
 from executorch.examples.models.llama.attention import ForwardOptions
 from executorch.examples.models.llama.feed_forward import FeedForward
-
 from executorch.examples.models.llama.norm import RMSNorm
 from torch import nn
 
 
 class ShortConv(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        L_cache: int = 3,
-        bias: bool = False,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-    ):
+    """Depthwise short convolution with dual state management.
+
+    Supports two modes:
+    1. State-as-IO: caller passes conv_state in and receives new state back.
+       Required for AOTI which cannot re-trace mutable buffer mutations.
+    2. Internal buffer: uses register_buffer + copy_() for XNNPack/portable
+       backends where mutable buffers are handled natively.
+    """
+
+    def __init__(self, dim: int, L_cache: int = 3, *, bias: bool = False) -> None:
         super().__init__()
+        assert L_cache == 3, f"Manual depthwise conv only supports L_cache=3, got {L_cache}"
         self.dim = dim
         self.L_cache = L_cache
-        self.device = device
-        self.dtype = dtype
-        self.bias = bias
-
-        self.conv = nn.Conv1d(
-            dim,
-            dim,
-            kernel_size=L_cache,
-            padding=0,  ## we don't need padding since we handle it manually
-            groups=dim,
-            bias=bias,
-        )
-
-        conv_state = torch.zeros(
-            1,  ## batch size is assumed to be 1 for now
-            dim,
-            L_cache - 1,
-            device="cpu",
-        )
-        self.register_buffer("conv_state", conv_state)
 
-        ## better performance in Executorch with separate projections
+        self.conv = nn.Conv1d(dim, dim, kernel_size=L_cache, padding=0, groups=dim, bias=bias)
         self.B_proj = nn.Linear(dim, dim, bias=bias)
         self.C_proj = nn.Linear(dim, dim, bias=bias)
         self.x_proj = nn.Linear(dim, dim, bias=bias)
-
         self.out_proj = nn.Linear(dim, dim, bias=bias)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        batch_size, seqlen, dim = x.size()
-        assert batch_size == 1, "batch_size must be 1"
-
-        B = self.B_proj(x).transpose(-1, -2)  # (batch_size, dim, seq_len)
-        C = self.C_proj(x).transpose(-1, -2)  # (batch_size, dim, seq_len)
-        x = self.x_proj(x).transpose(-1, -2)  # (batch_size, dim, seq_len)
-
-        Bx = B * x  # (batch_size, dim, seq_len)
+        self.register_buffer(
+            "conv_state",
+            torch.zeros(1, dim, L_cache - 1),
+        )
 
-        ## This is where we handle padding
-        ## By default, the conv_state is initialized to 0.
-        ## So, assuming prefill is done on an empty cache, concatenating conv_state to the beginning of the sequence acts similarly to
-        ## using nn.Conv1d(padding=L_cache-1) (for prefill) without manual padding.
-        ## However, the manual padding has the added benefit of being correct during decode, when the cache is not initialized to 0.
-        Bx = torch.cat(
-            [self.conv_state, Bx], dim=-1
-        )  # (batch_size, dim, seq_len + L_cache - 1)
+    def forward(
+        self, x: torch.Tensor, conv_state: torch.Tensor | None = None
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if conv_state is None:
+            conv_state = self.conv_state
 
-        ## Update the conv_state
-        new_conv_state = Bx[
-            ..., -(self.L_cache - 1) :
-        ]  # (batch_size, dim, L_cache - 1)
-        with torch.no_grad():
-            self.conv_state.copy_(new_conv_state)
+        B = self.B_proj(x).transpose(-1, -2)
+        C = self.C_proj(x).transpose(-1, -2)
+        x = self.x_proj(x).transpose(-1, -2)
 
-        conv_out = self.conv(Bx)[..., : x.size(-1)]  # (batch_size, dim, seq_len)
-        y = C * conv_out  # (batch_size, dim, seq_len)
+        Bx = torch.cat([conv_state, B * x], dim=-1)
+        new_conv_state = Bx[..., -(self.L_cache - 1) :]
 
-        y = y.transpose(-1, -2)  # (batch_size, seq_len, dim)
-        y = y.contiguous()  # (batch_size, seq_len, dim)
-        y = self.out_proj(y)  # (batch_size, seq_len, dim)
-        return y
+        # Manual depthwise conv — Triton has no template for nn.Conv1d
+        # with groups=dim and dynamic sequence length.
+        w = self.conv.weight[:, 0, :]
+        conv_out = Bx[..., :-2] * w[:, 0:1] + Bx[..., 1:-1] * w[:, 1:2] + Bx[..., 2:] * w[:, 2:3]
 
-    def reset_cache(self):
-        self.conv_state.zero_()
+        y = self.out_proj((C * conv_out).transpose(-1, -2).contiguous())
+        return y, new_conv_state
 
 
 class ShortConvBlock(nn.Module):
-    def __init__(self, dim: int, hidden_dim: int, norm_eps: float):
+    def __init__(self, dim: int, hidden_dim: int, norm_eps: float, layer_idx: int = -1) -> None:
         super().__init__()
-        self.L_cache = 3  # hardcode 3 for now
-        self.conv = ShortConv(dim, self.L_cache, bias=False)
+        self.layer_idx = layer_idx
+        self.conv = ShortConv(dim, L_cache=3, bias=False)
         self.feed_forward = FeedForward(dim, hidden_dim)
         self.ffn_norm = RMSNorm(dim, norm_eps)
-        # use attention_norm instead of operator_norm to unify with TransformerBlock
        self.attention_norm = RMSNorm(dim, norm_eps)
 
     def forward(
         self,
-        x,
-        freqs_cos=None,
-        freqs_sin=None,
-        _unused_attn_options: Optional[ForwardOptions] = None,
-    ):  # x: 1xN
-        h = self.conv.forward(self.attention_norm(x))
+        x: torch.Tensor,
+        freqs_cos: torch.Tensor | None = None,
+        freqs_sin: torch.Tensor | None = None,
+        attn_options: ForwardOptions | None = None,
+    ) -> tuple[torch.Tensor, dict]:
+        # State-as-IO: read from attn_options if provided (CUDA/AOTI path)
+        conv_state = None
+        if attn_options is not None:
+            conv_states = attn_options.get("conv_states")
+            if conv_states is not None:
+                conv_state = conv_states.get(self.layer_idx)
+
+        h, new_conv_state = self.conv(self.attention_norm(x), conv_state)
         h = x + h
         out = h + self.feed_forward(self.ffn_norm(h))
-        return out, None
 
-    def reset_cache(self):
-        self.conv.reset_cache()
+        # Write back state
+        update: dict = {}
+        if attn_options is not None and "conv_states" in attn_options:
+            if conv_state is not None:
+                conv_state.copy_(new_conv_state)
+            states = dict(attn_options["conv_states"])
+            states[self.layer_idx] = new_conv_state
+            update["conv_states"] = states
+        else:
+            # XNNPack/portable path: persist via internal buffer
+            with torch.no_grad():
+                self.conv.conv_state.copy_(new_conv_state)
+
+        return out, update
+
+    def reset_cache(self) -> None:
+        self.conv.conv_state.zero_()
```
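To make the dual state management concrete, here is a small driver sketch for the interface above. The toy sizes and the zero-initialized state dict are illustrative assumptions, and `attn_options` is treated as a plain dict for brevity:

```python
import torch

from executorch.examples.models.lfm2.short_conv import ShortConvBlock

dim, hidden_dim, n_layers = 64, 256, 2  # toy sizes
blocks = [
    ShortConvBlock(dim, hidden_dim, norm_eps=1e-5, layer_idx=i) for i in range(n_layers)
]

# CUDA/AOTI path: state-as-IO. Each layer's (1, dim, L_cache - 1) state is
# passed in via attn_options["conv_states"] and an updated dict comes back,
# so AOTI never has to trace a mutable buffer mutation.
attn_options = {"conv_states": {i: torch.zeros(1, dim, 2) for i in range(n_layers)}}
x = torch.randn(1, 5, dim)
for block in blocks:
    x, update = block(x, attn_options=attn_options)
    attn_options.update(update)  # thread the new states forward

# XNNPack/portable path: with no attn_options, each block persists state in
# the conv_state buffer registered inside ShortConv instead.
y, _ = blocks[0](torch.randn(1, 5, dim))
```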
examples/models/lfm2_5_vl/__init__.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.examples.models.lfm2_5_vl.convert_weights import convert_weights
+from executorch.examples.models.lfm2_5_vl.model import Lfm2p5VlModel
+
+__all__ = [
+    "convert_weights",
+    "Lfm2p5VlModel",
+]
```
Lines changed: 33 additions & 0 deletions
```diff
@@ -0,0 +1,33 @@
+{
+  "dim": 2048,
+  "ffn_dim_multiplier": 1,
+  "hidden_dim": 8192,
+  "n_heads": 32,
+  "n_kv_heads": 8,
+  "n_layers": 16,
+  "norm_eps": 1e-5,
+  "rope_theta": 1000000.0,
+  "use_scaled_rope": false,
+  "vocab_size": 65536,
+  "use_hf_rope": true,
+  "use_qk_norm": true,
+  "qk_norm_before_rope": true,
+  "layer_types": [
+    "conv",
+    "conv",
+    "full_attention",
+    "conv",
+    "conv",
+    "full_attention",
+    "conv",
+    "conv",
+    "full_attention",
+    "conv",
+    "full_attention",
+    "conv",
+    "full_attention",
+    "conv",
+    "full_attention",
+    "conv"
+  ]
+}
```
Lines changed: 33 additions & 0 deletions
```diff
@@ -0,0 +1,33 @@
+{
+  "dim": 1024,
+  "ffn_dim_multiplier": 1,
+  "hidden_dim": 4608,
+  "n_heads": 16,
+  "n_kv_heads": 8,
+  "n_layers": 16,
+  "norm_eps": 1e-5,
+  "rope_theta": 1000000.0,
+  "use_scaled_rope": false,
+  "vocab_size": 65536,
+  "use_hf_rope": true,
+  "use_qk_norm": true,
+  "qk_norm_before_rope": true,
+  "layer_types": [
+    "conv",
+    "conv",
+    "full_attention",
+    "conv",
+    "conv",
+    "full_attention",
+    "conv",
+    "conv",
+    "full_attention",
+    "conv",
+    "full_attention",
+    "conv",
+    "full_attention",
+    "conv",
+    "full_attention",
+    "conv"
+  ]
+}
```
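The layer_types list is what makes per-layer conv state keying necessary: only "conv" entries own a ShortConvBlock and therefore a state tensor. A sketch of deriving the state-as-IO inputs from a params file like the two above (the path is a placeholder, and L_cache = 3 is fixed by ShortConv):

```python
import json

import torch

with open("params.json") as f:  # placeholder path to one of the configs above
    params = json.load(f)

# Each conv layer carries a (1, dim, L_cache - 1) state, keyed by layer index.
conv_states = {
    idx: torch.zeros(1, params["dim"], 2)
    for idx, kind in enumerate(params["layer_types"])
    if kind == "conv"
}
print(f"{len(conv_states)} conv layers out of {params['n_layers']} total")
```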
examples/models/lfm2_5_vl/convert_weights.py

Lines changed: 81 additions & 0 deletions
```diff
@@ -0,0 +1,81 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Convert LFM2.5-VL text decoder weights from HuggingFace to ET format."""
+
+from __future__ import annotations
+
+import argparse
+from pathlib import Path
+
+import torch
+from executorch.examples.models.checkpoint import get_mapped_key
+from safetensors.torch import load_file
+
+_LFM2_5_VL_TO_META: dict[str, str] = {
+    "model.language_model.embed_tokens.weight": "tok_embeddings.weight",
+    "model.language_model.embedding_norm.weight": "norm.weight",
+    "model.language_model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+    "model.language_model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+    "model.language_model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+    "model.language_model.layers.{}.self_attn.out_proj.weight": "layers.{}.attention.wo.weight",
+    "model.language_model.layers.{}.self_attn.q_layernorm.weight": "layers.{}.attention.q_norm_fn.weight",
+    "model.language_model.layers.{}.self_attn.k_layernorm.weight": "layers.{}.attention.k_norm_fn.weight",
+    "model.language_model.layers.{}.operator_norm.weight": "layers.{}.attention_norm.weight",
+    "model.language_model.layers.{}.ffn_norm.weight": "layers.{}.ffn_norm.weight",
+    "model.language_model.layers.{}.feed_forward.w1.weight": "layers.{}.feed_forward.w1.weight",
+    "model.language_model.layers.{}.feed_forward.w2.weight": "layers.{}.feed_forward.w2.weight",
+    "model.language_model.layers.{}.feed_forward.w3.weight": "layers.{}.feed_forward.w3.weight",
+    "model.language_model.layers.{}.conv.conv.weight": "layers.{}.conv.conv.weight",
+    "model.language_model.layers.{}.conv.out_proj.weight": "layers.{}.conv.out_proj.weight",
+    "model.language_model.lm_head.weight": "output.weight",
+}
+
+_IN_PROJ_SPLITS = ("B_proj", "C_proj", "x_proj")
+
+
+def lfm2_5_vl_to_meta(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+    """Extract and remap language model weights from a full VL state dict."""
+    converted: dict[str, torch.Tensor] = {}
+
+    for key, value in state_dict.items():
+        if not key.startswith("model.language_model."):
+            continue
+
+        try:
+            new_key = get_mapped_key(key, _LFM2_5_VL_TO_META)
+        except Exception:
+            new_key = key.removeprefix("model.language_model.")
+
+        if new_key.endswith(".conv.in_proj.weight"):
+            for name, chunk in zip(_IN_PROJ_SPLITS, torch.chunk(value, 3, dim=0)):
+                converted[new_key.replace("in_proj", name)] = chunk
+        else:
+            converted[new_key] = value
+
+    if "output.weight" not in converted:
+        converted["output.weight"] = converted["tok_embeddings.weight"]
+
+    return converted
+
+
+def convert_weights(input_dir: str, output_file: str) -> None:
+    sd = load_file(str(Path(input_dir) / "model.safetensors"))
+    sd = lfm2_5_vl_to_meta(sd)
+    torch.save(sd, output_file)
+    print(f"Saved {len(sd)} tensors to {output_file}")
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Convert LFM2.5-VL weights to ET format.")
+    parser.add_argument("input_dir", help="Directory containing model.safetensors.")
+    parser.add_argument("output", help="Output .pt checkpoint path.")
+    args = parser.parse_args()
+    convert_weights(args.input_dir, args.output)
+
+
+if __name__ == "__main__":
+    main()
```
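A usage sketch for the converter; both paths are placeholders, and the input directory is assumed to hold a HuggingFace snapshot containing model.safetensors:

```python
from executorch.examples.models.lfm2_5_vl import convert_weights

# Fused conv.in_proj weights are split row-wise into B_proj/C_proj/x_proj
# during conversion to match ShortConv's separate projections.
convert_weights("checkpoints/LFM2.5-VL-450M", "lfm2_5_vl_450m.pt")
```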
