We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent d1ef675 commit 91b2043Copy full SHA for 91b2043
1 file changed
src/train.py
@@ -87,6 +87,10 @@ def forward(self, x: th.Tensor) -> th.Tensor:
87
Returns:
88
th.Tensor: Segmented output.
89
"""
90
+ x = x.permute(
91
+ (0, 1, 4, 2, 3)
92
+ ) # Permute such that x has shape (batch, channels, depth, height, width)
93
+
94
# 3.3 TODO: Implement the forward pass of the UNet using the building blocks
95
# defined in the __init__ function and the upsampling function.
96
return th.tensor(0.0)
0 commit comments