Skip to content

Commit ca60290

Browse files
authored
Fix docstring nits (#3758)
1 parent 548dd80 commit ca60290

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

python/mlx/nn/layers/linear.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,10 @@ class Bilinear(Module):
115115
116116
.. math::
117117
118-
y_i = x_1^\top W_i x_2 + b_i
118+
y_i = x_2^\top W_i x_1 + b_i
119119
120120
where:
121-
:math:`W` has shape ``[output_dims, input1_dims, input2_dims]``, :math:`b` has shape ``[output_dims ]``,
121+
:math:`W` has shape ``[output_dims, input2_dims, input1_dims]``, :math:`b` has shape ``[output_dims]``,
122122
and :math:`i` indexes the output dimension.
123123
124124
The values are initialized from the uniform distribution :math:`\mathcal{U}(-{k}, {k})`,

python/mlx/nn/losses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ def margin_ranking_loss(
599599
>>> inputs2 = mx.array([0.75596, 0.225763, 0.256995])
600600
>>> loss = nn.losses.margin_ranking_loss(inputs1, inputs2, targets)
601601
>>> loss
602-
array(0.773433, dtype=float32)
602+
array([1.32937, 0.990929, 0], dtype=float32)
603603
"""
604604
if not (inputs1.shape == inputs2.shape == targets.shape):
605605
raise ValueError(

0 commit comments

Comments
 (0)