Skip to content

Commit 91b2043

Browse files
committed
add permutation in forward function of UNet
1 parent d1ef675 commit 91b2043

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

src/train.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ def forward(self, x: th.Tensor) -> th.Tensor:
8787
Returns:
8888
th.Tensor: Segmented output.
8989
"""
90+
x = x.permute(
91+
(0, 1, 4, 2, 3)
92+
) # Permute such that x has shape (batch, channels, depth, height, width)
93+
9094
# 3.3 TODO: Implement the forward pass of the UNet using the building blocks
9195
# defined in the __init__ function and the upsampling function.
9296
return th.tensor(0.0)

0 commit comments

Comments
 (0)