Skip to content

Commit b6b4529

Browse files
committed
fix ut
1 parent 0738f93 commit b6b4529

2 files changed

Lines changed: 4 additions & 2 deletions

File tree

source/tests/pt/model/test_saveload_dpa1.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ def create_wrapper(self, read: bool):
115115
return ModelWrapper(model, self.loss)
116116

117117
def get_data(self):
118-
batch_data = next(self.training_data)
118+
with torch.device("cpu"):
119+
batch_data = next(self.training_data)
119120
input_dict = {}
120121
for item in ["coord", "atype", "box"]:
121122
if item in batch_data:

source/tests/pt/model/test_saveload_se_e2_a.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ def create_wrapper(self):
109109
return ModelWrapper(model, self.loss)
110110

111111
def get_data(self):
112-
batch_data = next(self.training_data)
112+
with torch.device("cpu"):
113+
batch_data = next(self.training_data)
113114
input_dict = {}
114115
for item in ["coord", "atype", "box"]:
115116
if item in batch_data:

0 commit comments

Comments
 (0)