Skip to content

Commit 0d0cda5

Browse files
committed
[MAX] Add UMT5 text encoder for Wan diffusion
## Summary Add a MAX-native UMT5 text encoder for Wan diffusion pipelines. ## Description - Implements the UMT5 encoder architecture using `max.nn` (Module V2 graph API) - Supports float32 → bfloat16 weight casting via `WeightData.astype()` for Wan checkpoints that store text encoder weights in float32 - Includes T5-style relative position bias and gated GeLU feed-forward - Handles diffusers weight key remapping (e.g. `shared.weight` alias dedup) UMT5 is the text encoder used by all Wan 2.1/2.2 models. It produces 4096-dim text embeddings consumed by the Wan transformer. ## Dependencies None — can be merged independently. ## Checklist - [x] PR is small and focused - [x] I ran `./bazelw run format` to format my changes Assisted-by: Claude Code Assisted-by: Claude Code
1 parent 7b02fbe commit 0d0cda5

5 files changed

Lines changed: 799 additions & 0 deletions

File tree

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# ===----------------------------------------------------------------------=== #
2+
# Copyright (c) 2026, Modular Inc. All rights reserved.
3+
#
4+
# Licensed under the Apache License v2.0 with LLVM Exceptions:
5+
# https://llvm.org/LICENSE.txt
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
# ===----------------------------------------------------------------------=== #
13+
14+
from .model import UMT5Model
15+
16+
__all__ = ["UMT5Model"]
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# ===----------------------------------------------------------------------=== #
2+
# Copyright (c) 2026, Modular Inc. All rights reserved.
3+
#
4+
# Licensed under the Apache License v2.0 with LLVM Exceptions:
5+
# https://llvm.org/LICENSE.txt
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
# ===----------------------------------------------------------------------=== #
13+
14+
from typing import Any
15+
16+
from max.driver import Device
17+
from max.dtype import DType
18+
from max.engine import InferenceSession, Model
19+
from max.graph import DeviceRef, Graph, TensorType
20+
from max.graph.weights import WeightData, Weights
21+
from max.pipelines.lib import SupportedEncoding
22+
from max.pipelines.lib.interfaces.component_model import ComponentModel
23+
24+
from .model_config import UMT5Config, UMT5ConfigBase
25+
from .umt5 import UMT5EncoderModel
26+
27+
28+
def _prepare_state_dict(
29+
weights: Weights,
30+
target_dtype: DType | None = None,
31+
) -> dict[str, WeightData]:
32+
"""Convert Weights to a raw state dict, normalizing tied embedding keys.
33+
34+
HF UMT5 ties ``shared.weight`` and ``encoder.embed_tokens.weight``.
35+
Our module owns the embedding as ``shared``, so we normalize to that key
36+
and drop the alias to avoid strict-mode validation failures.
37+
38+
If ``target_dtype`` is provided, all weights are cast to that dtype
39+
(e.g. float32 → bfloat16 for Wan 2.1 checkpoints).
40+
"""
41+
state_dict: dict[str, WeightData] = {}
42+
for key, value in weights.items():
43+
wd = value.data()
44+
if target_dtype is not None and wd.dtype != target_dtype:
45+
wd = wd.astype(target_dtype)
46+
state_dict[key] = wd
47+
48+
encoder_emb = state_dict.pop("encoder.embed_tokens.weight", None)
49+
if "shared.weight" not in state_dict and encoder_emb is not None:
50+
state_dict["shared.weight"] = encoder_emb
51+
52+
return state_dict
53+
54+
55+
class UMT5Model(ComponentModel):
56+
def __init__(
57+
self,
58+
config: dict[str, Any],
59+
encoding: SupportedEncoding,
60+
devices: list[Device],
61+
weights: Weights,
62+
session: InferenceSession | None = None,
63+
) -> None:
64+
super().__init__(config, encoding, devices, weights)
65+
self.session = session or InferenceSession(devices=devices)
66+
self.config: UMT5ConfigBase = UMT5Config.generate(
67+
config,
68+
encoding,
69+
devices,
70+
)
71+
self.load_model()
72+
73+
def load_model(self) -> Model:
74+
assert self.weights is not None, "Weights already freed"
75+
# Force bfloat16 — some repos (Wan 2.1) declare float32 but
76+
# should run in bfloat16 on GPU. Override both config and weights.
77+
dtype = DType.bfloat16
78+
self.config.dtype = dtype
79+
state_dict = _prepare_state_dict(self.weights, target_dtype=dtype)
80+
dev = self.devices[0]
81+
dev_ref = DeviceRef.from_device(dev)
82+
83+
# Build module and load weights
84+
module = UMT5EncoderModel(self.config, dtype=dtype, device=dev_ref)
85+
module.load_state_dict(state_dict, weight_alignment=1, strict=True)
86+
87+
# Build graph with symbolic sequence length
88+
# attention_mask comes in as int64 from the pipeline
89+
input_types = [
90+
TensorType(DType.int64, ["batch", "seq_len"], device=dev),
91+
TensorType(DType.int64, ["batch", "seq_len"], device=dev),
92+
]
93+
with Graph("umt5_encoder", input_types=input_types) as graph:
94+
input_ids = graph.inputs[0].tensor
95+
attention_mask = graph.inputs[1].tensor
96+
out = module(input_ids, attention_mask)
97+
graph.output(out)
98+
99+
self.model: Model = self.session.load(
100+
graph, weights_registry=module.state_dict()
101+
)
102+
# Free raw weights after compilation
103+
self.weights = None # type: ignore[assignment]
104+
return self.model
105+
106+
def __call__(self, *args, **kwargs):
107+
return self.model(*args, **kwargs)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# ===----------------------------------------------------------------------=== #
2+
# Copyright (c) 2026, Modular Inc. All rights reserved.
3+
#
4+
# Licensed under the Apache License v2.0 with LLVM Exceptions:
5+
# https://llvm.org/LICENSE.txt
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
# ===----------------------------------------------------------------------=== #
13+
14+
from typing import Any
15+
16+
from max.driver import Device
17+
from max.dtype import DType
18+
from max.graph import DeviceRef
19+
from max.pipelines.lib import MAXModelConfigBase, SupportedEncoding
20+
from max.pipelines.lib.config.config_enums import supported_encoding_dtype
21+
from pydantic import Field
22+
23+
24+
class UMT5ConfigBase(MAXModelConfigBase):
25+
vocab_size: int = 256384
26+
d_model: int = 4096
27+
d_kv: int = 64
28+
d_ff: int = 10240
29+
num_layers: int = 24
30+
num_decoder_layers: int | None = 24
31+
num_heads: int = 64
32+
relative_attention_num_buckets: int = 32
33+
relative_attention_max_distance: int = 128
34+
dropout_rate: float = 0.1
35+
layer_norm_epsilon: float = 1e-6
36+
initializer_factor: float = 1.0
37+
feed_forward_proj: str = "gated-gelu"
38+
dense_act_fn: str | None = Field(default=None, exclude=True)
39+
is_gated_act: bool = Field(default=False, exclude=True)
40+
is_decoder: bool = Field(default=False, exclude=True)
41+
is_encoder_decoder: bool = True
42+
use_cache: bool = True
43+
output_past: bool = True
44+
pad_token_id: int = 0
45+
eos_token_id: int = 1
46+
decoder_start_token_id: int = 0
47+
classifier_dropout: float = 0.0
48+
scalable_attention: bool = True
49+
tie_word_embeddings: bool = False
50+
tokenizer_class: str = "T5Tokenizer"
51+
device: DeviceRef = Field(default_factory=DeviceRef.GPU)
52+
dtype: DType = DType.bfloat16
53+
54+
55+
class UMT5Config(UMT5ConfigBase):
56+
@staticmethod
57+
def generate(
58+
config_dict: dict[str, Any],
59+
encoding: SupportedEncoding,
60+
devices: list[Device],
61+
) -> UMT5ConfigBase:
62+
init_dict = {
63+
key: value
64+
for key, value in config_dict.items()
65+
if key in UMT5ConfigBase.__annotations__
66+
}
67+
init_dict.update(
68+
{
69+
"dtype": supported_encoding_dtype(encoding),
70+
"device": DeviceRef.from_device(devices[0]),
71+
}
72+
)
73+
return UMT5ConfigBase(**init_dict)

0 commit comments

Comments
 (0)