Commit 4b562ab
Parent: 0c5e547
docs: fix docstrings

1 file changed, 55 additions and 3 deletions

File tree

src/pruna/algorithms/token_merging.py

@@ -222,6 +222,11 @@ class ToMeViTSelfAttention(_HFViTSelfAttention):
     - Stores the mean of *k* over heads in ``self._tome_info["metric"]`` so that
       the enclosing ``ToMeViTLayer`` can use it for bipartite matching without
       requiring changes to the intermediate ``ViTAttention`` wrapper.
+
+    Parameters
+    ----------
+    config : object
+        The ViT model configuration.
     """
 
     _tome_info: dict[str, Any]
@@ -231,7 +236,21 @@ def forward(
         hidden_states: torch.Tensor,
         head_mask: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Forward pass with proportional attention and key-metric storage."""
+        """
+        Forward pass with proportional attention and key-metric storage.
+
+        Parameters
+        ----------
+        hidden_states : torch.Tensor
+            Input token tensor of shape ``[batch, tokens, channels]``.
+        head_mask : torch.Tensor, optional
+            Mask for attention heads.
+
+        Returns
+        -------
+        Tuple[torch.Tensor, torch.Tensor]
+            Context layer and attention probabilities.
+        """
         batch_size = hidden_states.shape[0]
         new_shape = (batch_size, -1, self.num_attention_heads, self.attention_head_size)
 
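For readers unfamiliar with the behaviour this docstring refers to, here is a minimal, self-contained sketch of how proportional attention and key-metric storage fit together in the ToMe recipe. It is not the pruna implementation; the function name and the single-head simplification are invented for illustration.

```python
import torch

def tome_self_attention_sketch(q, k, v, size=None):
    """Single-head toy version: weight attention logits by token size
    (proportional attention) and hand back the key tensor as the merging metric."""
    # q, k, v: [batch, tokens, dim]; size: [batch, tokens, 1] counting how many
    # original tokens each current token already represents (None on the first call).
    scale = k.shape[-1] ** -0.5
    logits = (q @ k.transpose(-2, -1)) * scale          # [batch, tokens, tokens]
    if size is not None:
        # Tokens that stand in for several merged tokens get a log(size) bonus,
        # so they keep roughly the attention mass of the tokens they absorbed.
        logits = logits + size.log().transpose(-2, -1)
    attn = logits.softmax(dim=-1)
    context = attn @ v
    metric = k  # multi-head code would store k.mean over heads instead
    return context, attn, metric

# One of the four tokens behaves as if a merge already happened (size = 2).
q, k, v = (torch.randn(1, 4, 8) for _ in range(3))
size = torch.tensor([[[1.0], [2.0], [1.0], [1.0]]])
context, attn, metric = tome_self_attention_sketch(q, k, v, size)
```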

@@ -268,6 +287,11 @@ class ToMeViTLayer(_HFViTLayer):
       performs bipartite soft matching on the key-metric stored in
       ``self._tome_info["metric"]`` and merges the ``r`` most similar token
       pairs before proceeding to the MLP sub-layer.
+
+    Parameters
+    ----------
+    config : object
+        The ViT model configuration.
     """
 
     _tome_info: dict[str, Any]
@@ -277,7 +301,21 @@ def forward(
         hidden_states: torch.Tensor,
         head_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        """Forward pass with token merging between attention and MLP."""
+        """
+        Forward pass with token merging between attention and MLP.
+
+        Parameters
+        ----------
+        hidden_states : torch.Tensor
+            Input token tensor of shape ``[batch, tokens, channels]``.
+        head_mask : torch.Tensor, optional
+            Mask for attention heads.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor after attention, token merging, and MLP.
+        """
         # --- self-attention + first residual ---
         attention_output = self.attention(
             self.layernorm_before(hidden_states),
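The merge step this docstring describes is bipartite soft matching. As a rough, self-contained illustration (simplified from the ToMe paper, not the pruna code: the real layer also protects the CLS token, tracks token sizes, and merges the hidden states rather than the metric itself), it could look like this:

```python
import torch

def bipartite_soft_matching_sketch(metric: torch.Tensor, r: int) -> torch.Tensor:
    """Merge the r most similar token pairs of `metric` ([batch, tokens, dim])
    and return the reduced sequence ([batch, tokens - r, dim])."""
    metric = metric / metric.norm(dim=-1, keepdim=True)        # compare by cosine similarity
    a, b = metric[:, 0::2], metric[:, 1::2]                    # alternate tokens into two sets
    scores = a @ b.transpose(-2, -1)                           # [batch, |a|, |b|]
    node_max, node_idx = scores.max(dim=-1)                    # best partner in b for each a-token
    edge_order = node_max.argsort(dim=-1, descending=True)     # most similar pairs first
    merged_idx, kept_idx = edge_order[:, :r], edge_order[:, r:]

    dim = metric.shape[-1]
    kept_a = a.gather(1, kept_idx.unsqueeze(-1).expand(-1, -1, dim))
    src = a.gather(1, merged_idx.unsqueeze(-1).expand(-1, -1, dim))
    dst_idx = node_idx.gather(1, merged_idx)                   # partner of each merged a-token
    # Fold each merged a-token into its partner in b by averaging.
    b = b.scatter_reduce(
        1, dst_idx.unsqueeze(-1).expand(-1, -1, dim), src, reduce="mean", include_self=True
    )
    return torch.cat([kept_a, b], dim=1)

tokens = torch.randn(2, 10, 16)
print(bipartite_soft_matching_sketch(tokens, r=3).shape)       # torch.Size([2, 7, 16])
```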
@@ -354,7 +392,21 @@ def __init__(
         self.parsed_r = _parse_r(self.num_layers, self.r)
 
     def forward(self, *args: Any, **kwargs: Any) -> Any:
-        """Initialise ToMe state and forward through the wrapped model."""
+        """
+        Initialise ToMe state and forward through the wrapped model.
+
+        Parameters
+        ----------
+        *args : Any
+            Positional arguments forwarded to the wrapped model.
+        **kwargs : Any
+            Keyword arguments forwarded to the wrapped model.
+
+        Returns
+        -------
+        Any
+            The output of the wrapped model's forward pass.
+        """
         # Make a copy of the list to avoid modifying the original
         self._tome_info["r"] = list(self.parsed_r)
         self._tome_info["size"] = None
