@@ -79,7 +79,10 @@ def eval_model(
7979 if spins is not None :
8080 assert isinstance (spins , torch .Tensor ), err_msg
8181 assert isinstance (atom_types , torch .Tensor ) or isinstance (atom_types , list )
82- atom_types = torch .tensor (atom_types , dtype = torch .int32 , device = DEVICE )
82+ if isinstance (atom_types , torch .Tensor ):
83+ atom_types = atom_types .clone ().detach ().to (dtype = torch .int32 , device = DEVICE )
84+ else :
85+ atom_types = torch .tensor (atom_types , dtype = torch .int32 , device = DEVICE )
8386 elif isinstance (coords , np .ndarray ):
8487 if cells is not None :
8588 assert isinstance (cells , np .ndarray ), err_msg
@@ -101,28 +104,40 @@ def eval_model(
101104 else :
102105 natoms = len (atom_types [0 ])
103106
104- coord_input = torch .tensor (
105- coords .reshape ([- 1 , natoms , 3 ]), dtype = GLOBAL_PT_FLOAT_PRECISION , device = DEVICE
106- )
107+ if isinstance (coords , torch .Tensor ):
108+ coord_input = coords .reshape ([- 1 , natoms , 3 ]).clone ().detach ().to (dtype = GLOBAL_PT_FLOAT_PRECISION , device = DEVICE )
109+ else :
110+ coord_input = torch .tensor (
111+ coords .reshape ([- 1 , natoms , 3 ]), dtype = GLOBAL_PT_FLOAT_PRECISION , device = DEVICE
112+ )
107113 spin_input = None
108114 if spins is not None :
109- spin_input = torch .tensor (
110- spins .reshape ([- 1 , natoms , 3 ]),
111- dtype = GLOBAL_PT_FLOAT_PRECISION ,
112- device = DEVICE ,
113- )
115+ if isinstance (spins , torch .Tensor ):
116+ spin_input = spins .reshape ([- 1 , natoms , 3 ]).clone ().detach ().to (dtype = GLOBAL_PT_FLOAT_PRECISION , device = DEVICE )
117+ else :
118+ spin_input = torch .tensor (
119+ spins .reshape ([- 1 , natoms , 3 ]),
120+ dtype = GLOBAL_PT_FLOAT_PRECISION ,
121+ device = DEVICE ,
122+ )
114123 has_spin = getattr (model , "has_spin" , False )
115124 if callable (has_spin ):
116125 has_spin = has_spin ()
117- type_input = torch .tensor (atom_types , dtype = torch .long , device = DEVICE )
126+ if isinstance (atom_types , torch .Tensor ):
127+ type_input = atom_types .clone ().detach ().to (dtype = torch .long , device = DEVICE )
128+ else :
129+ type_input = torch .tensor (atom_types , dtype = torch .long , device = DEVICE )
118130 box_input = None
119131 if cells is None :
120132 pbc = False
121133 else :
122134 pbc = True
123- box_input = torch .tensor (
124- cells .reshape ([- 1 , 3 , 3 ]), dtype = GLOBAL_PT_FLOAT_PRECISION , device = DEVICE
125- )
135+ if isinstance (cells , torch .Tensor ):
136+ box_input = cells .reshape ([- 1 , 3 , 3 ]).clone ().detach ().to (dtype = GLOBAL_PT_FLOAT_PRECISION , device = DEVICE )
137+ else :
138+ box_input = torch .tensor (
139+ cells .reshape ([- 1 , 3 , 3 ]), dtype = GLOBAL_PT_FLOAT_PRECISION , device = DEVICE
140+ )
126141 num_iter = int ((nframes + infer_batch_size - 1 ) / infer_batch_size )
127142
128143 for ii in range (num_iter ):
0 commit comments