Skip to content

Commit 52854a3

Browse files
Merge pull request #408 from AI-Hypercomputer:ninatu/fix_text_encoder
PiperOrigin-RevId: 918533638
2 parents 19d4e4d + 46e96da commit 52854a3

7 files changed

Lines changed: 34 additions & 2 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ revision: ''
4141
weights_dtype: 'bfloat16'
4242
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
4343
activations_dtype: 'bfloat16'
44+
# The dtype for text_encoder model during load/compile
45+
text_encoder_dtype: 'float32'
46+
47+
# Whether to compile the text_encoder with torch.compile
48+
compile_text_encoder: False
4449

4550
# Replicates vae across devices instead of using the model's sharding annotations for sharding.
4651
replicate_vae: False

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ revision: ''
4141
weights_dtype: 'bfloat16'
4242
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
4343
activations_dtype: 'bfloat16'
44+
# The dtype for text_encoder model during load/compile
45+
text_encoder_dtype: 'float32'
46+
47+
# Whether to compile the text_encoder with torch.compile
48+
compile_text_encoder: False
4449

4550
# Replicates vae across devices instead of using the model's sharding annotations for sharding.
4651
replicate_vae: False

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ revision: ''
4141
weights_dtype: 'bfloat16'
4242
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
4343
activations_dtype: 'bfloat16'
44+
# The dtype for text_encoder model during load/compile
45+
text_encoder_dtype: 'float32'
46+
47+
# Whether to compile the text_encoder with torch.compile
48+
compile_text_encoder: False
4449

4550
# Replicates vae across devices instead of using the model's sharding annotations for sharding.
4651
replicate_vae: False

src/maxdiffusion/configs/base_wan_animate.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ revision: ''
4141
weights_dtype: 'bfloat16'
4242
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
4343
activations_dtype: 'bfloat16'
44+
# The dtype for text_encoder model during load/compile
45+
text_encoder_dtype: 'float32'
46+
47+
# Whether to compile the text_encoder with torch.compile
48+
compile_text_encoder: False
4449

4550
# Replicates vae across devices instead of using the model's sharding annotations for sharding.
4651
replicate_vae: False

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ revision: ''
4141
weights_dtype: 'bfloat16'
4242
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
4343
activations_dtype: 'bfloat16'
44+
# The dtype for text_encoder model during load/compile
45+
text_encoder_dtype: 'float32'
46+
47+
# Whether to compile the text_encoder with torch.compile
48+
compile_text_encoder: False
4449

4550
# Replicates vae across devices instead of using the model's sharding annotations for sharding.
4651
replicate_vae: False

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ revision: ''
4141
weights_dtype: 'bfloat16'
4242
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
4343
activations_dtype: 'bfloat16'
44+
# The dtype for text_encoder model during load/compile
45+
text_encoder_dtype: 'float32'
46+
47+
# Whether to compile the text_encoder with torch.compile
48+
compile_text_encoder: False
4449

4550
# Replicates vae across devices instead of using the model's sharding annotations for sharding.
4651
replicate_vae: False

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,13 +270,15 @@ def __init__(
270270

271271
@classmethod
272272
def load_text_encoder(cls, config: HyperParameters):
273-
torch_dtype = getattr(torch, str(config.weights_dtype), torch.float32)
273+
text_encoder_dtype = getattr(config, "text_encoder_dtype", "float32")
274+
torch_dtype = getattr(torch, str(text_encoder_dtype), torch.float32)
274275
text_encoder = UMT5EncoderModel.from_pretrained(
275276
config.pretrained_model_name_or_path,
276277
subfolder="text_encoder",
277278
torch_dtype=torch_dtype,
278279
)
279-
text_encoder = torch.compile(text_encoder)
280+
if getattr(config, "compile_text_encoder", True):
281+
text_encoder = torch.compile(text_encoder)
280282
return text_encoder
281283

282284
@classmethod

0 commit comments

Comments
 (0)