[MAX] Add UMT5 text encoder for Wan diffusion #14
## Summary

Add a MAX-native UMT5 text encoder for Wan diffusion pipelines.

## Description

- Implements the UMT5 encoder architecture using `max.nn` (Module V2 graph API)
- Supports float32 → bfloat16 weight casting via `WeightData.astype()` for Wan checkpoints that store text encoder weights in float32
- Includes T5-style relative position bias and gated GeLU feed-forward
- Handles diffusers weight key remapping (e.g. `shared.weight` alias dedup)

UMT5 is the text encoder used by all Wan 2.1/2.2 models. It produces 4096-dim text embeddings consumed by the Wan transformer.

## Dependencies

None. Can be merged independently.

## Checklist

- [x] PR is small and focused
- [x] I ran `./bazelw run format` to format my changes

Assisted-by: Claude Code
stack-info: PR: #14, branch: jglee-sqbits/stack/2
0d0cda5 to ea94825
7b02fbe to aca5cea
Code Review
This pull request introduces the UMT5 encoder architecture, providing the necessary components for model configuration, core neural network layers, and weight adaptation. The implementation includes a UMT5Model component designed for the MAX pipeline. Feedback suggests several improvements for better flexibility and performance: avoiding hardcoded data types to support various precisions, pre-calculating constant values like the attention scale and data type minimums outside the forward pass to optimize graph construction, and moving configuration parsing logic into the configuration class to prevent side effects from in-place mutation.
    dtype = DType.bfloat16
    self.config.dtype = dtype
Hardcoding DType.bfloat16 here overrides the encoding and config settings provided during initialization. While the comment explains this is a workaround for specific Wan 2.1 checkpoints, it limits the reusability of the UMT5Model component for other models that might require different precision (e.g., float32 or float16). This precision-forcing logic should ideally be handled during configuration resolution or guarded by a check to ensure it only applies when necessary.
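One way to read the suggestion above is as a small resolution step that runs before the model is built. The sketch below is illustrative only: `resolve_encoder_dtype` and the `downcast_f32_checkpoint` flag are hypothetical names, and the `DType` enum is a local stand-in for the MAX type so the snippet runs without MAX installed.

```python
from enum import Enum


class DType(Enum):
    """Local stand-in for the MAX DType enum (assumption, for illustration)."""

    float16 = "float16"
    bfloat16 = "bfloat16"
    float32 = "float32"


def resolve_encoder_dtype(
    requested: DType,
    checkpoint: DType,
    downcast_f32_checkpoint: bool = False,
) -> DType:
    """Pick the dtype the encoder should be built with.

    The bfloat16 override applies only when the caller opts in AND the
    checkpoint really stores float32 weights. In every other case the
    dtype requested at initialization is honored.
    """
    if downcast_f32_checkpoint and checkpoint is DType.float32:
        return DType.bfloat16
    return requested
```

With a guard like this, a Wan pipeline could opt in to the downcast while other users of the component keep whatever precision they configured.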
    relative_position,
    )

    scale = (num_buckets - max_exact) / math.log(max_distance / max_exact)
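For context, the `scale` expression depends only on configuration values, so it can be computed once rather than on every call. The following is a minimal pure-Python sketch of the widely used T5 bidirectional bucketing scheme with that constant hoisted out. It assumes the PR follows the same scheme. `make_bucket_fn` is a hypothetical helper, and the defaults of 32 buckets and a max distance of 128 are the common T5 values, not numbers taken from this PR.

```python
import math
from typing import Callable


def make_bucket_fn(
    num_buckets: int = 32, max_distance: int = 128
) -> Callable[[int], int]:
    """Return a bidirectional relative-position bucketing function."""
    half = num_buckets // 2  # half of the buckets per direction
    max_exact = half // 2  # distances below this get their own bucket
    # Constant for a given config: computed once here, not per call.
    scale = (half - max_exact) / math.log(max_distance / max_exact)

    def bucket(relative_position: int) -> int:
        # Positive offsets use the upper half of the bucket range.
        offset = half if relative_position > 0 else 0
        distance = abs(relative_position)
        if distance < max_exact:
            return offset + distance
        # Larger distances share logarithmically spaced buckets.
        large = max_exact + int(math.log(distance / max_exact) * scale)
        return offset + min(large, half - 1)

    return bucket
```

Calling `make_bucket_fn()` once and reusing the returned function keeps the logarithm of the constant ratio out of the per-position work.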
    _DTYPE_MIN: dict[DType, float] = {
        DType.float16: -65504.0,
        DType.bfloat16: -3.3895314e38,
        DType.float32: -3.4028235e38,
    }
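The PR excerpt does not show where this table is consumed. A plausible use, assumed here, is building an additive attention mask in which padded positions receive the most negative finite value for the active dtype. The sketch uses a local `DType` stand-in and a hypothetical `additive_mask` helper, and looks the minimum up once per call rather than per element.

```python
from enum import Enum


class DType(Enum):
    """Local stand-in for the MAX DType enum (assumption, for illustration)."""

    float16 = "float16"
    bfloat16 = "bfloat16"
    float32 = "float32"


# Most negative finite value per dtype, copied from the table above.
_DTYPE_MIN: dict[DType, float] = {
    DType.float16: -65504.0,
    DType.bfloat16: -3.3895314e38,
    DType.float32: -3.4028235e38,
}


def additive_mask(attention_mask: list[int], dtype: DType) -> list[float]:
    """Map a 1/0 keep mask to 0.0 / dtype-min additive bias values."""
    if dtype not in _DTYPE_MIN:
        raise ValueError(f"no minimum registered for dtype {dtype}")
    min_value = _DTYPE_MIN[dtype]  # single lookup, reused for every token
    return [0.0 if keep else min_value for keep in attention_mask]
```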
    act_info = config.feed_forward_proj.split("-")
    config.dense_act_fn = act_info[-1]
    config.is_gated_act = act_info[0] == "gated"
    if (len(act_info) > 1 and act_info[0] != "gated") or len(act_info) > 2:
        raise ValueError(
            f"`feed_forward_proj`: {config.feed_forward_proj} is not valid."
        )
    if config.feed_forward_proj == "gated-gelu":
        config.dense_act_fn = "gelu_new"
The UMT5EncoderModel constructor is mutating the config object in-place to store parsed activation information. This is a side effect that can affect other parts of the application using the same configuration instance. It is recommended to move this parsing logic into the UMT5ConfigBase class (e.g., as properties or using a Pydantic validator) to keep the model implementation clean and side-effect free.
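A rough sketch of what that could look like follows, using a frozen dataclass as a stand-in for the real configuration class. `UMT5ConfigSketch` is a hypothetical name, and only the `feed_forward_proj` field is modeled. The validation and derived values mirror the constructor logic quoted above, but nothing is written back to the config.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class UMT5ConfigSketch:
    """Hypothetical stand-in for the config class; one field only."""

    feed_forward_proj: str = "gated-gelu"

    def __post_init__(self) -> None:
        # Same validity rule as the constructor code, checked at creation.
        parts = self.feed_forward_proj.split("-")
        if (len(parts) > 1 and parts[0] != "gated") or len(parts) > 2:
            raise ValueError(
                f"`feed_forward_proj`: {self.feed_forward_proj} is not valid."
            )

    @property
    def is_gated_act(self) -> bool:
        # Derived on read, so the instance is never mutated.
        return self.feed_forward_proj.split("-")[0] == "gated"

    @property
    def dense_act_fn(self) -> str:
        # "gated-gelu" maps to the "gelu_new" variant, as in the original.
        if self.feed_forward_proj == "gated-gelu":
            return "gelu_new"
        return self.feed_forward_proj.split("-")[-1]
```

Because the dataclass is frozen, two models sharing one config instance cannot affect each other through it, which is the side effect the comment is concerned about.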