@@ -128,6 +128,8 @@ def __init__(
128128 :raises AssertionError: If ``order`` is not a positive integer.
129129 :raises ValueError: If ``knots`` is neither a torch.Tensor nor a
130130 dictionary, when provided.
131+ :raises ValueError: If ``aggregate_output`` is not None, "mean", or
132+ "sum".
131133 :raises ValueError: If ``control_points`` is not a torch.Tensor,
132134 when provided.
133135 :raises ValueError: If both ``knots`` and ``control_points`` are None.
@@ -155,6 +157,13 @@ def __init__(
155157 if knots is None and control_points is None :
156158 raise ValueError ("knots and control_points cannot both be None." )
157159
160+ # Raise error if aggregate_output is not None, "mean", or "sum"
161+ if aggregate_output not in (None , "mean" , "sum" ):
162+ raise ValueError (
163+ f"aggregate_output must be None, 'mean', or 'sum'."
164+ f" Got { aggregate_output } ."
165+ )
166+
158167 # Initialize knots if not provided
159168 if knots is None and control_points is not None :
160169 knots = {
@@ -323,9 +332,11 @@ def forward(self, x):
323332 The input is expected to have shape ``[batch, s]``, where ``s`` is the
324333 number of univariate splines. The output has shape ``[batch, s, o]``,
325334 where ``o`` is the output dimension of each univariate spline, unless an
326- aggregation method is specified. If ``aggregate_output`` is set to
327- ``"mean"`` or ``"sum"``, the output is aggregated across the last
328- dimension, resulting in an output of shape ``[batch, s]``.
335+ aggregation method is specified. If both ``s`` and ``o`` are 1, the
336+ output is aggregated across the last dimension, resulting in an output
337+ of shape ``[batch, s]``. If ``aggregate_output`` is set to ``"mean"`` or
338+ ``"sum"``, the output is aggregated across the last dimension, resulting
339+ in an output of shape ``[batch, s]``.
329340
330341 :param x: The input tensor.
331342 :type x: torch.Tensor | LabelTensor
@@ -343,6 +354,8 @@ def forward(self, x):
343354 out = out .mean (dim = - 1 )
344355 elif self .aggregate_output == "sum" :
345356 out = out .sum (dim = - 1 )
357+ elif out .shape [1 ] == 1 and out .shape [2 ] == 1 :
358+ out = out .squeeze (- 1 )
346359
347360 return out
348361
@@ -483,7 +496,7 @@ def knots(self, value):
483496 value = value .unsqueeze (0 ).repeat (n_splines , 1 )
484497
485498 # Set knots
486- self .register_buffer ("_knots" , value .sort (dim = 1 ).values )
499+ self .register_buffer ("_knots" , value .sort (dim = - 1 ).values )
487500
488501 # Recompute boundary interval when knots change
489502 if hasattr (self , "_boundary_interval_idx" ):
0 commit comments