@@ -222,6 +222,11 @@ class ToMeViTSelfAttention(_HFViTSelfAttention):
     - Stores the mean of *k* over heads in ``self._tome_info["metric"]`` so that
       the enclosing ``ToMeViTLayer`` can use it for bipartite matching without
       requiring changes to the intermediate ``ViTAttention`` wrapper.
+
+    Parameters
+    ----------
+    config : object
+        The ViT model configuration.
     """
 
     _tome_info: dict[str, Any]
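The metric referred to above is simply the key tensor averaged over the head axis. A minimal sketch of that reduction, assuming the usual ``[batch, heads, tokens, head_dim]`` layout (the shapes are illustrative, not taken from this repo):

```python
import torch

key_layer = torch.randn(2, 12, 197, 64)  # [batch, heads, tokens, head_dim]
metric = key_layer.mean(dim=1)           # [batch, tokens, head_dim]
print(metric.shape)                      # torch.Size([2, 197, 64])
```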
@@ -231,7 +236,21 @@ def forward(
         hidden_states: torch.Tensor,
         head_mask: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Forward pass with proportional attention and key-metric storage."""
+        """
+        Forward pass with proportional attention and key-metric storage.
+
+        Parameters
+        ----------
+        hidden_states : torch.Tensor
+            Input token tensor of shape ``[batch, tokens, channels]``.
+        head_mask : torch.Tensor, optional
+            Mask for attention heads.
+
+        Returns
+        -------
+        Tuple[torch.Tensor, torch.Tensor]
+            Context layer and attention probabilities.
+        """
         batch_size = hidden_states.shape[0]
         new_shape = (batch_size, -1, self.num_attention_heads, self.attention_head_size)
 
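This hunk does not show how the "proportional attention" is realised; in the ToMe paper it means adding the log of each token's size (the number of original tokens it represents) to the pre-softmax scores. A hedged sketch of that bias, with made-up shapes:

```python
import torch

batch, heads, tokens = 2, 12, 150
scores = torch.randn(batch, heads, tokens, tokens)  # raw q @ k^T / sqrt(d)
size = torch.ones(batch, tokens, 1)                 # original tokens merged into each position
scores = scores + size.log()[:, None, None, :, 0]   # bias each key by log(size)
probs = scores.softmax(dim=-1)
```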
@@ -268,6 +287,11 @@ class ToMeViTLayer(_HFViTLayer):
     performs bipartite soft matching on the key-metric stored in
     ``self._tome_info["metric"]`` and merges the ``r`` most similar token
     pairs before proceeding to the MLP sub-layer.
+
+    Parameters
+    ----------
+    config : object
+        The ViT model configuration.
     """
 
     _tome_info: dict[str, Any]
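For orientation, the bipartite soft matching used here follows the ToMe recipe: split the tokens into two alternating sets, score every token in one set against the other, and merge the ``r`` highest-scoring pairs. A simplified sketch under stated assumptions (no class-token protection, a plain mean merge rather than a size-weighted one; ``bipartite_merge_sketch`` is an illustrative name, not a helper from this repo):

```python
import torch

def bipartite_merge_sketch(x: torch.Tensor, metric: torch.Tensor, r: int) -> torch.Tensor:
    """Merge the r most similar token pairs (no class token, plain mean merge)."""
    # Normalise the metric so the dot product is a cosine similarity.
    metric = metric / metric.norm(dim=-1, keepdim=True)
    a, b = metric[:, ::2], metric[:, 1::2]             # alternating partition A / B
    scores = a @ b.transpose(-1, -2)                   # similarity of every A token to every B token
    node_max, node_idx = scores.max(dim=-1)            # best B partner for each A token
    edge_order = node_max.argsort(dim=-1, descending=True)
    src_idx = edge_order[:, :r]                        # the r A tokens to merge away
    dst_idx = node_idx.gather(1, src_idx)              # their B destinations

    src, dst = x[:, ::2], x[:, 1::2]
    channels = x.shape[-1]
    merged = src.gather(1, src_idx[..., None].expand(-1, -1, channels))
    dst = dst.clone()                                  # avoid writing into a view of x
    dst.scatter_reduce_(1, dst_idx[..., None].expand(-1, -1, channels),
                        merged, reduce="mean", include_self=True)
    kept = src.gather(1, edge_order[:, r:][..., None].expand(-1, -1, channels))
    return torch.cat([kept, dst], dim=1)               # r fewer tokens than the input


x = torch.randn(2, 196, 768)
out = bipartite_merge_sketch(x, metric=x, r=16)        # in practice the metric is the key mean
print(out.shape)                                       # torch.Size([2, 180, 768])
```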
@@ -277,7 +301,21 @@ def forward(
         hidden_states: torch.Tensor,
         head_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        """Forward pass with token merging between attention and MLP."""
+        """
+        Forward pass with token merging between attention and MLP.
+
+        Parameters
+        ----------
+        hidden_states : torch.Tensor
+            Input token tensor of shape ``[batch, tokens, channels]``.
+        head_mask : torch.Tensor, optional
+            Mask for attention heads.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor after attention, token merging, and MLP.
+        """
         # --- self-attention + first residual ---
         attention_output = self.attention(
             self.layernorm_before(hidden_states),
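The ordering this docstring describes can be summarised as attention, then merge, then MLP. A hedged outline assuming the standard Hugging Face ``ViTLayer`` sub-modules; ``merge_fn`` is a placeholder for whatever merge the layer applies, not this repo's actual helper:

```python
import torch

def tome_layer_flow(layer, hidden_states: torch.Tensor, merge_fn=None) -> torch.Tensor:
    # Pre-norm self-attention plus the first residual (standard HF ViTLayer layout).
    attn_out = layer.attention(layer.layernorm_before(hidden_states))[0]
    hidden_states = hidden_states + attn_out
    # Token merging sits here, so the MLP only sees the reduced sequence.
    if merge_fn is not None:
        hidden_states = merge_fn(hidden_states)
    # MLP sub-layer; in HF's ViT, `layer.output` adds the second residual internally.
    mlp_out = layer.intermediate(layer.layernorm_after(hidden_states))
    return layer.output(mlp_out, hidden_states)
```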
@@ -354,7 +392,21 @@ def __init__(
         self.parsed_r = _parse_r(self.num_layers, self.r)
 
     def forward(self, *args: Any, **kwargs: Any) -> Any:
-        """Initialise ToMe state and forward through the wrapped model."""
+        """
+        Initialise ToMe state and forward through the wrapped model.
+
+        Parameters
+        ----------
+        *args : Any
+            Positional arguments forwarded to the wrapped model.
+        **kwargs : Any
+            Keyword arguments forwarded to the wrapped model.
+
+        Returns
+        -------
+        Any
+            The output of the wrapped model's forward pass.
+        """
         # Make a copy of the list to avoid modifying the original
         self._tome_info["r"] = list(self.parsed_r)
         self._tome_info["size"] = None
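A plausible reading of why the parsed schedule is copied on every call: if each layer pops its own entry from the shared ``r`` list during the forward pass (an assumption, not shown in this hunk), reusing the original list would exhaust it after the first call. A toy illustration:

```python
parsed_r = [2, 2, 2]                       # hypothetical per-layer reduction schedule
tome_info = {"r": list(parsed_r), "size": None}

for _ in range(len(parsed_r)):
    r = tome_info["r"].pop(0)              # each layer consumes the next value
    # ... merge r token pairs in this layer ...

assert parsed_r == [2, 2, 2]               # untouched, ready for the next forward pass
```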