Skip to content

Commit ab10595

Browse files
committed
add support for wan vae 2.2 & fix hacky wan vae 2.1
1 parent 18f6f0f commit ab10595

4 files changed

Lines changed: 1494 additions & 45 deletions

File tree

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,29 +35,28 @@ def get_sinusoidal_embeddings(
3535
"""Returns the positional encoding (same as Tensor2Tensor).
3636
3737
Args:
38-
timesteps: a 1-D Tensor of N indices, one per batch element.
38+
timesteps: a 1-D or 2-D Tensor of indices.
3939
These may be fractional.
4040
embedding_dim: The number of output channels.
4141
min_timescale: The smallest time unit (should probably be 0.0).
4242
max_timescale: The largest time unit.
4343
Returns:
44-
a Tensor of timing signals [N, num_channels]
44+
a Tensor of timing signals [B, num_channels] or [B, N, num_channels]
4545
"""
46-
assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
46+
assert timesteps.ndim <= 2, "Timesteps should be a 1d or 2d-array"
4747
assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
4848
num_timescales = float(embedding_dim // 2)
4949
log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
5050
inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
51-
emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
51+
emb = jnp.expand_dims(timesteps, -1) * inv_timescales
5252

5353
# scale embeddings
5454
scaled_time = scale * emb
5555

5656
if flip_sin_to_cos:
57-
signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
57+
signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=-1)
5858
else:
59-
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
60-
signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
59+
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1)
6160
return signal
6261

6362

@@ -84,7 +83,7 @@ def __init__(
8483
sample_proj_bias=True,
8584
dtype: jnp.dtype = jnp.float32,
8685
weights_dtype: jnp.dtype = jnp.float32,
87-
precision: jax.lax.Precision = None,
86+
precision: jax.lax.Precision | None = None,
8887
):
8988
self.linear_1 = nnx.Linear(
9089
rngs=rngs,
@@ -221,7 +220,7 @@ def __call__(self, timesteps):
221220

222221
def get_1d_rotary_pos_embed(
223222
dim: int,
224-
pos: Union[jnp.array, int],
223+
pos: Union[jnp.ndarray, int],
225224
theta: float = 10000.0,
226225
linear_factor=1.0,
227226
ntk_factor=1.0,
@@ -332,11 +331,11 @@ def __init__(
332331
rngs: nnx.Rngs,
333332
in_features: int,
334333
hidden_size: int,
335-
out_features: int = None,
334+
out_features: int | None = None,
336335
act_fn: str = "gelu_tanh",
337336
dtype: jnp.dtype = jnp.float32,
338337
weights_dtype: jnp.dtype = jnp.float32,
339-
precision: jax.lax.Precision = None,
338+
precision: jax.lax.Precision | None = None,
340339
):
341340
if out_features is None:
342341
out_features = hidden_size
@@ -392,11 +391,11 @@ class PixArtAlphaTextProjection(nn.Module):
392391
"""
393392

394393
hidden_size: int
395-
out_features: int = None
394+
out_features: int | None = None
396395
act_fn: str = "gelu_tanh"
397396
dtype: jnp.dtype = jnp.float32
398397
weights_dtype: jnp.dtype = jnp.float32
399-
precision: jax.lax.Precision = None
398+
precision: jax.lax.Precision | None = None
400399

401400
@nn.compact
402401
def __call__(self, caption):
@@ -455,7 +454,7 @@ class CombinedTimestepTextProjEmbeddings(nn.Module):
455454
pooled_projection_dim: int
456455
dtype: jnp.dtype = jnp.float32
457456
weights_dtype: jnp.dtype = jnp.float32
458-
precision: jax.lax.Precision = None
457+
precision: jax.lax.Precision | None = None
459458

460459
@nn.compact
461460
def __call__(self, timestep, pooled_projection):
@@ -479,7 +478,7 @@ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
479478
pooled_projection_dim: int
480479
dtype: jnp.dtype = jnp.float32
481480
weights_dtype: jnp.dtype = jnp.float32
482-
precision: jax.lax.Precision = None
481+
precision: jax.lax.Precision | None = None
483482

484483
@nn.compact
485484
def __call__(self, timestep, guidance, pooled_projection):

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,8 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
360360
feat_cache = _update_cache(feat_cache, idx, cache_x)
361361
feat_idx += 1
362362
x = x.reshape(b, t, h, w, 2, c)
363-
x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
363+
# x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
364+
x = x.transpose(0, 1, 4, 2, 3, 5)
364365
x = x.reshape(b, t * 2, h, w, c)
365366
t = x.shape[1]
366367
x = x.reshape(b * t, h, w, c)
@@ -1160,23 +1161,7 @@ def _decode(
11601161
out, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx)
11611162
else:
11621163
out_, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx)
1163-
1164-
# This is to bypass an issue where frame[1] should be frame[2] and vise versa.
1165-
# Ideally shouldn't need to do this however, can't find where the frame is going out of sync.
1166-
# Most likely due to an incorrect reshaping in the decoder.
1167-
fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :]
1168-
# When batch_size is 0, expand batch dim for concatenation
1169-
# else, expand frame dim for concatenation so that batch dim stays intact.
1170-
axis = 0
1171-
if fm1.shape[0] > 1:
1172-
axis = 1
1173-
1174-
if len(fm1.shape) == 4:
1175-
fm1 = jnp.expand_dims(fm1, axis=axis)
1176-
fm2 = jnp.expand_dims(fm2, axis=axis)
1177-
fm3 = jnp.expand_dims(fm3, axis=axis)
1178-
fm4 = jnp.expand_dims(fm4, axis=axis)
1179-
out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1)
1164+
out = jnp.concatenate([out, out_], axis=1)
11801165

11811166
feat_cache._feat_map = dec_feat_map
11821167

0 commit comments

Comments (0)