Skip to content

Commit 811d2e6

Browse files
committed
docstrings
1 parent c90a294 commit 811d2e6

11 files changed

Lines changed: 124 additions & 0 deletions

src/diffusers/models/autoencoders/audio_tokenizer_ace_step.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,12 @@ def __init__(
341341
self.gradient_checkpointing = False
342342

343343
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
344+
"""
345+
Args:
346+
hidden_states (`torch.Tensor`):
347+
Input audio tokens of shape `(batch_size, num_tokens, hidden_size)` to be unpooled back to the 25 Hz
348+
acoustic-latent rate.
349+
"""
344350
batch_size, num_tokens, _ = hidden_states.shape
345351
hidden_states = self.embed_tokens(hidden_states)
346352
hidden_states = hidden_states.unsqueeze(2).expand(-1, -1, self.pool_window_size, -1)
@@ -436,6 +442,12 @@ def __init__(
436442
self.pool_window_size = pool_window_size
437443

438444
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
445+
"""
446+
Args:
447+
hidden_states (`torch.Tensor`):
448+
Input acoustic latents of shape `(batch_size, latent_length, audio_acoustic_hidden_dim)` to be
449+
quantized into ACE-Step 5 Hz audio tokens.
450+
"""
439451
input_dtype = hidden_states.dtype
440452
hidden_states = self.audio_acoustic_proj(hidden_states)
441453
hidden_states = self.attention_pooler(hidden_states)

src/diffusers/models/autoencoders/latent_upsampler_ltx.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,12 @@ def __init__(
144144
self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1)
145145

146146
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
147+
"""
148+
Args:
149+
hidden_states (`torch.Tensor`):
150+
Input latents of shape `(batch_size, num_channels, num_frames, height, width)` to spatially or
151+
temporally upsample.
152+
"""
147153
batch_size, num_channels, num_frames, height, width = hidden_states.shape
148154

149155
if self.dims == 2:

src/diffusers/models/autoencoders/latent_upsampler_ltx2.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,12 @@ def __init__(
243243
self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1)
244244

245245
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
246+
"""
247+
Args:
248+
hidden_states (`torch.Tensor`):
249+
Input latents of shape `(batch_size, num_channels, num_frames, height, width)` to spatially or
250+
temporally upsample.
251+
"""
246252
batch_size, num_channels, num_frames, height, width = hidden_states.shape
247253

248254
if self.dims == 2:

src/diffusers/models/autoencoders/vocoder_ltx2.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,11 @@ def __init__(
572572
)
573573

574574
def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
575+
"""
576+
Args:
577+
mel_spec (`torch.Tensor`):
578+
Input mel spectrogram of shape `(batch_size, num_channels, num_frames, num_mel_bins)`.
579+
"""
575580
# 1. Run stage 1 vocoder to get low sampling rate waveform
576581
x = self.vocoder(mel_spec)
577582
batch_size, num_channels, num_samples = x.shape

src/diffusers/models/condition_embedders/condition_encoder_ace_step.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,13 @@ def forward(
178178
inputs_embeds: torch.FloatTensor,
179179
attention_mask: torch.Tensor,
180180
) -> torch.Tensor:
181+
"""
182+
Args:
183+
inputs_embeds (`torch.FloatTensor`):
184+
Lyric embeddings of shape `(batch_size, sequence_length, embedding_dim)` to project and encode.
185+
attention_mask (`torch.Tensor`):
186+
Attention mask of shape `(batch_size, sequence_length)` indicating which tokens are valid.
187+
"""
181188
inputs_embeds = self.embed_tokens(inputs_embeds)
182189

183190
seq_len = inputs_embeds.shape[1]
@@ -317,6 +324,15 @@ def forward(
317324
refer_audio_acoustic_hidden_states_packed: torch.FloatTensor,
318325
refer_audio_order_mask: torch.LongTensor,
319326
) -> Tuple[torch.Tensor, torch.Tensor]:
327+
"""
328+
Args:
329+
refer_audio_acoustic_hidden_states_packed (`torch.FloatTensor`):
330+
Packed reference-audio acoustic hidden states of shape `(num_references, sequence_length, feature_dim)`,
331+
stacked across all reference samples in the batch.
332+
refer_audio_order_mask (`torch.LongTensor`):
333+
Batch-index assignment of shape `(num_references,)` indicating which batch sample each packed reference
334+
belongs to.
335+
"""
320336
inputs_embeds = self.embed_tokens(refer_audio_acoustic_hidden_states_packed)
321337

322338
seq_len = inputs_embeds.shape[1]
@@ -447,6 +463,22 @@ def forward(
447463
refer_audio_acoustic_hidden_states_packed: torch.FloatTensor,
448464
refer_audio_order_mask: torch.LongTensor,
449465
) -> Tuple[torch.Tensor, torch.Tensor]:
466+
"""
467+
Args:
468+
text_hidden_states (`torch.FloatTensor`):
469+
Text encoder hidden states of shape `(batch_size, text_sequence_length, text_hidden_dim)`.
470+
text_attention_mask (`torch.Tensor`):
471+
Attention mask of shape `(batch_size, text_sequence_length)` for the text hidden states.
472+
lyric_hidden_states (`torch.FloatTensor`):
473+
Lyric embeddings of shape `(batch_size, lyric_sequence_length, embedding_dim)` to be encoded by the lyric encoder.
474+
lyric_attention_mask (`torch.Tensor`):
475+
Attention mask of shape `(batch_size, lyric_sequence_length)` for the lyric tokens.
476+
refer_audio_acoustic_hidden_states_packed (`torch.FloatTensor`):
477+
Packed reference-audio acoustic hidden states of shape `(num_references, sequence_length, feature_dim)`.
478+
refer_audio_order_mask (`torch.LongTensor`):
479+
Batch-index assignment of shape `(num_references,)` indicating which batch sample each packed reference
480+
belongs to.
481+
"""
450482
text_hidden_states = self.text_projector(text_hidden_states)
451483

452484
lyric_hidden_states = self.lyric_encoder(

src/diffusers/models/condition_embedders/image_encoder_redux.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ def __init__(
4141
self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features)
4242

4343
def forward(self, x: torch.Tensor) -> ReduxImageEncoderOutput:
44+
"""
45+
Args:
46+
x (`torch.Tensor`):
47+
Image embeddings of shape `(batch_size, sequence_length, redux_dim)` produced by the SigLIP image
48+
encoder.
49+
"""
4450
projected_x = self.redux_down(nn.functional.silu(self.redux_up(x)))
4551

4652
return ReduxImageEncoderOutput(image_embeds=projected_x)

src/diffusers/models/condition_embedders/projection_audioldm2.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,18 @@ def forward(
109109
attention_mask: torch.LongTensor | None = None,
110110
attention_mask_1: torch.LongTensor | None = None,
111111
):
112+
"""
113+
Args:
114+
hidden_states (`torch.Tensor`, *optional*):
115+
Hidden states from the first text encoder of shape `(batch_size, sequence_length, text_encoder_dim)`.
116+
hidden_states_1 (`torch.Tensor`, *optional*):
117+
Hidden states from the second text encoder of shape `(batch_size, sequence_length_1,
118+
text_encoder_1_dim)`.
119+
attention_mask (`torch.LongTensor`, *optional*):
120+
Attention mask of shape `(batch_size, sequence_length)` for `hidden_states`.
121+
attention_mask_1 (`torch.LongTensor`, *optional*):
122+
Attention mask of shape `(batch_size, sequence_length_1)` for `hidden_states_1`.
123+
"""
112124
hidden_states = self.projection(hidden_states)
113125
hidden_states, attention_mask = add_special_tokens(
114126
hidden_states, attention_mask, sos_token=self.sos_embed, eos_token=self.eos_embed

src/diffusers/models/condition_embedders/projection_clip_image.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,9 @@ def __init__(self, hidden_size: int = 768):
2626
self.project = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
2727

2828
def forward(self, x):
29+
"""
30+
Args:
31+
x (`torch.Tensor`):
32+
Input CLIP image embeddings of shape `(batch_size, hidden_size)`.
33+
"""
2934
return self.project(x)

src/diffusers/models/condition_embedders/projection_stable_audio.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,15 @@ def forward(
141141
start_seconds: torch.Tensor | None = None,
142142
end_seconds: torch.Tensor | None = None,
143143
):
144+
"""
145+
Args:
146+
text_hidden_states (`torch.Tensor`, *optional*):
147+
Hidden states from the text encoder of shape `(batch_size, sequence_length, text_encoder_dim)`.
148+
start_seconds (`torch.Tensor`, *optional*):
149+
Start-time-in-seconds conditioning values of shape `(batch_size,)`.
150+
end_seconds (`torch.Tensor`, *optional*):
151+
End-time-in-seconds conditioning values of shape `(batch_size,)`.
152+
"""
144153
text_hidden_states = (
145154
text_hidden_states if text_hidden_states is None else self.text_projection(text_hidden_states)
146155
)

src/diffusers/models/others/renderer_shap_e.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,21 @@ def map_indices_to_keys(self, output):
659659
return mapped_output
660660

661661
def forward(self, *, position, direction, ts, nerf_level="coarse", rendering_mode="nerf"):
662+
"""
663+
Args:
664+
position (`torch.Tensor`):
665+
3D query positions of shape `(batch_size, ..., 3)` to evaluate the NeRSTF MLP at.
666+
direction (`torch.Tensor`):
667+
Viewing directions of shape `(batch_size, ..., 3)` used for view-dependent color prediction.
668+
ts (`torch.Tensor`):
669+
Per-ray sample distances of shape `(batch_size, ..., 1)` passed through to the output for downstream
670+
integration.
671+
nerf_level (`str`, *optional*, defaults to `"coarse"`):
672+
Which density/color head to read from — `"coarse"` or `"fine"`.
673+
rendering_mode (`str`, *optional*, defaults to `"nerf"`):
674+
Output head to use: `"nerf"` for radiance-field colors or `"stf"` for the signed-distance/texture
675+
field.
676+
"""
662677
h = encode_position(position)
663678

664679
h_preact = h
@@ -769,6 +784,12 @@ def __init__(
769784
)
770785

771786
def forward(self, x: torch.Tensor):
787+
"""
788+
Args:
789+
x (`torch.Tensor`):
790+
Latent representation of a 3D asset of shape `(batch_size, total_vectors, d_latent)`, sliced per
791+
`param_name` and projected to each MLP weight tensor.
792+
"""
772793
out = {}
773794
start = 0
774795
for k, shape in zip(self.config.param_names, self.config.param_shapes):

0 commit comments

Comments
 (0)