Skip to content

[MAX] Add UMT5 text encoder for Wan diffusion#14

Draft
jglee-sqbits wants to merge 1 commit into
jglee-sqbits/stack/1from
jglee-sqbits/stack/2
Draft

[MAX] Add UMT5 text encoder for Wan diffusion#14
jglee-sqbits wants to merge 1 commit into
jglee-sqbits/stack/1from
jglee-sqbits/stack/2

Conversation

@jglee-sqbits
Copy link
Copy Markdown
Collaborator

@jglee-sqbits jglee-sqbits commented Apr 1, 2026

Stacked PRs:


[MAX] Add UMT5 text encoder for Wan diffusion

Summary

Add a MAX-native UMT5 text encoder for Wan diffusion pipelines.

Description

  • Implements the UMT5 encoder architecture using max.nn (Module V2 graph API)
  • Supports float32 → bfloat16 weight casting via WeightData.astype() for Wan checkpoints that store text encoder weights in float32
  • Includes T5-style relative position bias and gated GeLU feed-forward
  • Handles diffusers weight key remapping (e.g. shared.weight alias dedup)

UMT5 is the text encoder used by all Wan 2.1/2.2 models. It produces 4096-dim text embeddings consumed by the Wan transformer.

Dependencies

None — can be merged independently.

Checklist

  • PR is small and focused
  • I ran ./bazelw run format to format my changes

Assisted-by: Claude Code

Assisted-by: Claude Code

## Summary

Add a MAX-native UMT5 text encoder for Wan diffusion pipelines.

## Description

- Implements the UMT5 encoder architecture using `max.nn` (Module V2 graph API)
- Supports float32 → bfloat16 weight casting via `WeightData.astype()` for Wan checkpoints that store text encoder weights in float32
- Includes T5-style relative position bias and gated GeLU feed-forward
- Handles diffusers weight key remapping (e.g. `shared.weight` alias dedup)

UMT5 is the text encoder used by all Wan 2.1/2.2 models. It produces 4096-dim text embeddings consumed by the Wan transformer.

## Dependencies

None — can be merged independently.

## Checklist

- [x] PR is small and focused
- [x] I ran `./bazelw run format` to format my changes

Assisted-by: Claude Code

Assisted-by: Claude Code

stack-info: PR: #14, branch: jglee-sqbits/stack/2
Copy link
Copy Markdown

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the UMT5 encoder architecture, providing the necessary components for model configuration, core neural network layers, and weight adaptation. The implementation includes a UMT5Model component designed for the MAX pipeline. Feedback suggests several improvements for better flexibility and performance: avoiding hardcoded data types to support various precisions, pre-calculating constant values like the attention scale and data type minimums outside the forward pass to optimize graph construction, and moving configuration parsing logic into the configuration class to prevent side effects from in-place mutation.

Comment on lines +77 to +78
dtype = DType.bfloat16
self.config.dtype = dtype
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Hardcoding DType.bfloat16 here overrides the encoding and config settings provided during initialization. While the comment explains this is a workaround for specific Wan 2.1 checkpoints, it limits the reusability of the UMT5Model component for other models that might require different precision (e.g., float32 or float16). This precision-forcing logic should ideally be handled during configuration resolution or guarded by a check to ensure it only applies when necessary.

relative_position,
)

scale = (num_buckets - max_exact) / math.log(max_distance / max_exact)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The scale value is constant for a given model configuration and does not depend on the input sequence length. It should be pre-calculated in the __init__ method of UMT5Attention to avoid redundant math.log calls and arithmetic operations during graph construction for every layer.

Comment on lines +479 to +483
_DTYPE_MIN: dict[DType, float] = {
DType.float16: -65504.0,
DType.bfloat16: -3.3895314e38,
DType.float32: -3.4028235e38,
}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The _DTYPE_MIN dictionary is defined inside the __call__ method, which means it is re-allocated on every forward pass during graph construction. This is inefficient. Please move this dictionary to the module level or as a class attribute.

Comment on lines +509 to +517
act_info = config.feed_forward_proj.split("-")
config.dense_act_fn = act_info[-1]
config.is_gated_act = act_info[0] == "gated"
if (len(act_info) > 1 and act_info[0] != "gated") or len(act_info) > 2:
raise ValueError(
f"`feed_forward_proj`: {config.feed_forward_proj} is not valid."
)
if config.feed_forward_proj == "gated-gelu":
config.dense_act_fn = "gelu_new"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The UMT5EncoderModel constructor is mutating the config object in-place to store parsed activation information. This is a side effect that can affect other parts of the application using the same configuration instance. It is recommended to move this parsing logic into the UMT5ConfigBase class (e.g., as properties or using a Pydantic validator) to keep the model implementation clean and side-effect free.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant