@@ -178,6 +178,13 @@ def forward(
178178 inputs_embeds : torch .FloatTensor ,
179179 attention_mask : torch .Tensor ,
180180 ) -> torch .Tensor :
181+ """
182+ Args:
183+ inputs_embeds (`torch.FloatTensor`):
184+ Lyric token ids of shape `(batch_size, sequence_length)` to embed and encode.
185+ attention_mask (`torch.Tensor`):
186+ Attention mask of shape `(batch_size, sequence_length)` indicating which tokens are valid.
187+ """
181188 inputs_embeds = self .embed_tokens (inputs_embeds )
182189
183190 seq_len = inputs_embeds .shape [1 ]
@@ -317,6 +324,15 @@ def forward(
317324 refer_audio_acoustic_hidden_states_packed : torch .FloatTensor ,
318325 refer_audio_order_mask : torch .LongTensor ,
319326 ) -> Tuple [torch .Tensor , torch .Tensor ]:
327+ """
328+ Args:
329+ refer_audio_acoustic_hidden_states_packed (`torch.FloatTensor`):
330+ Packed reference-audio acoustic hidden states of shape `(total_tokens, hidden_size)` across all
331+ reference samples in the batch.
332+ refer_audio_order_mask (`torch.LongTensor`):
333+ Batch-index assignment of shape `(total_tokens,)` indicating which reference sample each packed token
334+ belongs to.
335+ """
320336 inputs_embeds = self .embed_tokens (refer_audio_acoustic_hidden_states_packed )
321337
322338 seq_len = inputs_embeds .shape [1 ]
@@ -447,6 +463,22 @@ def forward(
447463 refer_audio_acoustic_hidden_states_packed : torch .FloatTensor ,
448464 refer_audio_order_mask : torch .LongTensor ,
449465 ) -> Tuple [torch .Tensor , torch .Tensor ]:
466+ """
467+ Args:
468+ text_hidden_states (`torch.FloatTensor`):
469+ Text encoder hidden states of shape `(batch_size, text_sequence_length, text_hidden_dim)`.
470+ text_attention_mask (`torch.Tensor`):
471+ Attention mask of shape `(batch_size, text_sequence_length)` for the text hidden states.
472+ lyric_hidden_states (`torch.FloatTensor`):
473+ Lyric token ids of shape `(batch_size, lyric_sequence_length)` to be encoded by the lyric encoder.
474+ lyric_attention_mask (`torch.Tensor`):
475+ Attention mask of shape `(batch_size, lyric_sequence_length)` for the lyric tokens.
476+ refer_audio_acoustic_hidden_states_packed (`torch.FloatTensor`):
477+ Packed reference-audio acoustic hidden states of shape `(total_tokens, hidden_size)`.
478+ refer_audio_order_mask (`torch.LongTensor`):
479+ Batch-index assignment of shape `(total_tokens,)` indicating which reference sample each packed token
480+ belongs to.
481+ """
450482 text_hidden_states = self .text_projector (text_hidden_states )
451483
452484 lyric_hidden_states = self .lyric_encoder (
0 commit comments