Skip to content

Commit d40fc4d

Browse files
committed
fix to device for station ids
1 parent 017ab05 commit d40fc4d

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/pyvisgen/simulation/observation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def get_valid_subset(self, num_baselines: int, device: str):
132132
stations_2 = bas_reshaped.st2[:-1][mask].ravel()
133133
baseline_nums = (256 * (stations_1 + 1) + stations_2 + 1).to(device)
134134

135-
st_id_pairs = torch.stack([stations_1, stations_2], dim=1)
135+
st_id_pairs = torch.stack([stations_1, stations_2], dim=1).to(device)
136136

137137
u_start = bas_reshaped.u[:-1][mask].to(device)
138138
v_start = bas_reshaped.v[:-1][mask].to(device)

0 commit comments

Comments
 (0)