Skip to content

Commit 8736e74

Browse files
committed
Update DoMINO to enable multi-batch tests to pass.
Deliberately looping an index_select over the batch index, rather than something vectorized, because we don't want to expand that op and it's domain parallelized.
1 parent 90694f8 commit 8736e74

3 files changed

Lines changed: 10 additions & 9 deletions

File tree

physicsnemo/models/domino/encodings.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,11 @@ def forward(
159159
for j in range(encoding_g.shape[1]):
160160
geo_encoding = rearrange(encoding_g[:, j], "b nx ny nz -> b 1 (nx ny nz)")
161161

162-
geo_encoding_sampled = torch.index_select(
163-
geo_encoding, 2, mapping.flatten()
164-
)
165-
geo_encoding_sampled = torch.reshape(geo_encoding_sampled, mask.shape)
162+
sampled = []
163+
for b in range(batch_size):
164+
s = torch.index_select(geo_encoding[b], 1, mapping[b].flatten())
165+
sampled.append(s.reshape(1, *mask.shape[1:]))
166+
geo_encoding_sampled = torch.cat(sampled, dim=0)
166167
geo_encoding_sampled = geo_encoding_sampled * mask
167168

168169
encoding_g_inner.append(geo_encoding_sampled)

physicsnemo/models/domino/geometry_rep.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def forward(
195195
self.grid_resolution[1],
196196
self.grid_resolution[2],
197197
)
198-
grid = grid.reshape(1, nx * ny * nz, 3, 1)
198+
grid = grid.reshape(grid.shape[0], nx * ny * nz, 3, 1)
199199

200200
# Rearrange input to flatten spatial and neighbor dimensions
201201
x = rearrange(

physicsnemo/models/domino/model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -574,8 +574,8 @@ def forward(
574574

575575
# Normalize geometry coordinates based on computational domain
576576
if "volume_min_max" in data_dict.keys():
577-
vol_max = data_dict["volume_min_max"][:, 1]
578-
vol_min = data_dict["volume_min_max"][:, 0]
577+
vol_max = data_dict["volume_min_max"][:, 1].unsqueeze(1)
578+
vol_min = data_dict["volume_min_max"][:, 0].unsqueeze(1)
579579
geo_centers_vol = (
580580
2.0 * (geo_centers - vol_min) / (vol_max - vol_min) - 1
581581
)
@@ -618,8 +618,8 @@ def forward(
618618
# Represent geometry on bounding box
619619
# Normalize geometry coordinates based on surface bounding box
620620
if "surface_min_max" in data_dict.keys():
621-
surf_max = data_dict["surface_min_max"][:, 1]
622-
surf_min = data_dict["surface_min_max"][:, 0]
621+
surf_max = data_dict["surface_min_max"][:, 1].unsqueeze(1)
622+
surf_min = data_dict["surface_min_max"][:, 0].unsqueeze(1)
623623
geo_centers_surf = (
624624
2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1
625625
)

0 commit comments

Comments
 (0)