@@ -35,29 +35,28 @@ def get_sinusoidal_embeddings(
3535 """Returns the positional encoding (same as Tensor2Tensor).
3636
3737 Args:
38- timesteps: a 1-D Tensor of N indices, one per batch element.
38+ timesteps: a 1-D or 2-D Tensor of indices.
3939 These may be fractional.
4040 embedding_dim: The number of output channels.
4141 min_timescale: The smallest time unit (should probably be 0.0).
4242 max_timescale: The largest time unit.
4343 Returns:
44- a Tensor of timing signals [N, num_channels]
44+ a Tensor of timing signals [B, num_channels] or [B, N, num_channels]
4545 """
46- assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
46+ assert timesteps.ndim <= 2, "Timesteps should be a 1d or 2d-array"
4747 assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
4848 num_timescales = float(embedding_dim // 2)
4949 log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
5050 inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
51- emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
51+ emb = jnp.expand_dims(timesteps, -1) * inv_timescales
5252
5353 # scale embeddings
5454 scaled_time = scale * emb
5555
5656 if flip_sin_to_cos:
57- signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
57+ signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=-1)
5858 else:
59- signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
60- signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
59+ signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1)
6160 return signal
6261
6362
@@ -84,7 +83,7 @@ def __init__(
8483 sample_proj_bias=True,
8584 dtype: jnp.dtype = jnp.float32,
8685 weights_dtype: jnp.dtype = jnp.float32,
87- precision: jax.lax.Precision = None,
86+ precision: jax.lax.Precision | None = None,
8887 ):
8988 self.linear_1 = nnx.Linear(
9089 rngs=rngs,
@@ -221,7 +220,7 @@ def __call__(self, timesteps):
221220
222221 def get_1d_rotary_pos_embed(
223222 dim: int,
224- pos: Union[jnp.array, int],
223+ pos: Union[jnp.ndarray, int],
225224 theta: float = 10000.0,
226225 linear_factor=1.0,
227226 ntk_factor=1.0,
@@ -332,11 +331,11 @@ def __init__(
332331 rngs: nnx.Rngs,
333332 in_features: int,
334333 hidden_size: int,
335- out_features: int = None,
334+ out_features: int | None = None,
336335 act_fn: str = "gelu_tanh",
337336 dtype: jnp.dtype = jnp.float32,
338337 weights_dtype: jnp.dtype = jnp.float32,
339- precision: jax.lax.Precision = None,
338+ precision: jax.lax.Precision | None = None,
340339 ):
341340 if out_features is None:
342341 out_features = hidden_size
@@ -392,11 +391,11 @@ class PixArtAlphaTextProjection(nn.Module):
392391 """
393392
394393 hidden_size: int
395- out_features: int = None
394+ out_features: int | None = None
396395 act_fn: str = "gelu_tanh"
397396 dtype: jnp.dtype = jnp.float32
398397 weights_dtype: jnp.dtype = jnp.float32
399- precision: jax.lax.Precision = None
398+ precision: jax.lax.Precision | None = None
400399
401400 @nn.compact
402401 def __call__(self, caption):
@@ -455,7 +454,7 @@ class CombinedTimestepTextProjEmbeddings(nn.Module):
455454 pooled_projection_dim: int
456455 dtype: jnp.dtype = jnp.float32
457456 weights_dtype: jnp.dtype = jnp.float32
458- precision: jax.lax.Precision = None
457+ precision: jax.lax.Precision | None = None
459458
460459 @nn.compact
461460 def __call__(self, timestep, pooled_projection):
@@ -479,7 +478,7 @@ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
479478 pooled_projection_dim: int
480479 dtype: jnp.dtype = jnp.float32
481480 weights_dtype: jnp.dtype = jnp.float32
482- precision: jax.lax.Precision = None
481+ precision: jax.lax.Precision | None = None
483482
484483 @nn.compact
485484 def __call__(self, timestep, guidance, pooled_projection):
0 commit comments