Skip to content

Commit d9b59ff

Browse files
minor fix to output dimension in vector splines
1 parent 88e4bbd commit d9b59ff

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

pina/_src/model/vectorized_spline.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ def __init__(
128128
:raises AssertionError: If ``order`` is not a positive integer.
129129
:raises ValueError: If ``knots`` is neither a torch.Tensor nor a
130130
dictionary, when provided.
131+
:raises ValueError: If ``aggregate_output`` is not None, "mean", or
132+
"sum".
131133
:raises ValueError: If ``control_points`` is not a torch.Tensor,
132134
when provided.
133135
:raises ValueError: If both ``knots`` and ``control_points`` are None.
@@ -155,6 +157,13 @@ def __init__(
155157
if knots is None and control_points is None:
156158
raise ValueError("knots and control_points cannot both be None.")
157159

160+
# Raise error if aggregate_output is not None, "mean", or "sum"
161+
if aggregate_output not in (None, "mean", "sum"):
162+
raise ValueError(
163+
f"aggregate_output must be None, 'mean', or 'sum'."
164+
f" Got {aggregate_output}."
165+
)
166+
158167
# Initialize knots if not provided
159168
if knots is None and control_points is not None:
160169
knots = {
@@ -323,9 +332,11 @@ def forward(self, x):
323332
The input is expected to have shape ``[batch, s]``, where ``s`` is the
324333
number of univariate splines. The output has shape ``[batch, s, o]``,
325334
where ``o`` is the output dimension of each univariate spline, unless an
326-
aggregation method is specified. If ``aggregate_output`` is set to
327-
``"mean"`` or ``"sum"``, the output is aggregated across the last
328-
dimension, resulting in an output of shape ``[batch, s]``.
335+
aggregation method is specified. If both ``s`` and ``o`` are 1, the
336+
output is aggregated across the last dimension, resulting in an output
337+
of shape ``[batch, s]``. If ``aggregate_output`` is set to ``"mean"`` or
338+
``"sum"``, the output is aggregated across the last dimension, resulting
339+
in an output of shape ``[batch, s]``.
329340
330341
:param x: The input tensor.
331342
:type x: torch.Tensor | LabelTensor
@@ -343,6 +354,8 @@ def forward(self, x):
343354
out = out.mean(dim=-1)
344355
elif self.aggregate_output == "sum":
345356
out = out.sum(dim=-1)
357+
elif out.shape[1] == 1 and out.shape[2] == 1:
358+
out = out.squeeze(-1)
346359

347360
return out
348361

@@ -483,7 +496,7 @@ def knots(self, value):
483496
value = value.unsqueeze(0).repeat(n_splines, 1)
484497

485498
# Set knots
486-
self.register_buffer("_knots", value.sort(dim=1).values)
499+
self.register_buffer("_knots", value.sort(dim=-1).values)
487500

488501
# Recompute boundary interval when knots change
489502
if hasattr(self, "_boundary_interval_idx"):

0 commit comments

Comments
 (0)