Skip to content

Commit 7d256af

Browse files
prova 6
1 parent aa91c1f commit 7d256af

1 file changed

Lines changed: 6 additions & 1 deletion

File tree

pina/_src/solver/multi_model_simple_solver.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,13 @@ def forward(self, x):
206206
if self._active_model_idx is not None:
207207
return self.models[self._active_model_idx](x)
208208

209+
# Strip tensor subclasses before stacking so the ensemble dimension
210+
# stays unlabeled and the outer LabelTensor wrapper relabels once.
209211
return torch.stack(
210-
[self.models[idx](x) for idx in range(self.num_models)],
212+
[
213+
self.models[idx](x).as_subclass(torch.Tensor)
214+
for idx in range(self.num_models)
215+
],
211216
)
212217

213218
def training_step(self, batch):

0 commit comments

Comments
 (0)