@@ -68,13 +68,12 @@ def tully_potential_matrix(Q, params):
6868 D = params .get ("D" , 1.0 )
6969
7070
71+ # Diabatic state 1 potential
7172 V11 = torch .where (
7273 x >= 0 ,
7374 A * (1 - torch .exp (- B * x )),
7475 - A * (1 - torch .exp (B * x ))
75- )
76-
77- #V11 = A * (1 - torch.exp(-B * x)) # Diabatic state 1 potential
76+ )
7877 V22 = - V11 # Diabatic state 2 potential (mirror)
7978 V12 = C * torch .exp (- D * x ** 2 ) # Coupling between diabatic states
8079
@@ -290,12 +289,12 @@ def __init__(self, params):
290289 self .initial_state_index = params .get ("initial_state_index" , 0 )
291290 self .method = params .get ("method" , "miller-colton" ).lower ()
292291
293- self .time = []
294- self .kinetic_energy = []
295- self .potential_energy = []
296- self .total_energy = []
297- self .population_right = []
298- self .norm = []
292+ # self.time = []
293+ # self.kinetic_energy = []
294+ # self.potential_energy = []
295+ # self.total_energy = []
296+ # self.population_right = []
297+ # self.norm = []
299298
300299 def initialize_grids (self ):
301300 """
@@ -331,7 +330,14 @@ def initialize_grids(self):
331330
332331
333332 # Allocate storage:
334- self .nsnaps = self .nsteps // self .save_every_n_steps + 1 # how many snapshots to save
333+ self .nsnaps = self .nsteps // self .save_every_n_steps # how many snapshots to save
334+
335+ self .time = torch .zeros ( self .nsnaps , dtype = torch .float )
336+ self .kinetic_energy = torch .zeros ( self .nsnaps , dtype = torch .float )
337+ self .potential_energy = torch .zeros ( self .nsnaps , dtype = torch .float )
338+ self .total_energy = torch .zeros ( self .nsnaps , dtype = torch .float )
339+ self .population_right = torch .zeros ( self .nsnaps , dtype = torch .float )
340+ self .norm = torch .zeros ( self .nsnaps , dtype = torch .float )
335341
336342 # Diabatic properties
337343 self .psi_r_dia = torch .zeros ((* self .grid_size , self .Nstates ), dtype = torch .cfloat )
@@ -443,7 +449,7 @@ def propagate(self):
443449
444450 if step % self .save_every_n_steps == 0 :
445451 istep = int (step / self .save_every_n_steps )
446-
452+
447453 # Diabatic r-space wavefunctions
448454 self .psi_r_dia_all [istep ] = self .psi_r_dia
449455
@@ -471,11 +477,11 @@ def propagate(self):
471477 right_mask = self .Q [0 ] > 0
472478 #pop_right = torch.sum(self.prob_density[right_mask]) * self.dV
473479
474- self .norm . append ( nrm )
475- self .time . append ( step * self .dt )
476- self .kinetic_energy . append ( KE .real .item () )
477- self .potential_energy . append ( PE .real .item () )
478- self .total_energy . append ( KE + PE )
480+ self .norm [ istep ] = nrm
481+ self .time [ istep ] = step * self .dt
482+ self .kinetic_energy [ istep ] = KE .real .item ()
483+ self .potential_energy [ istep ] = PE .real .item ()
484+ self .total_energy [ istep ] = KE + PE
479485 #self.population_right.append(pop_right.item())
480486
481487 print (f"Step { step } : Norm = { nrm :.4f} " )
@@ -522,7 +528,7 @@ def save(self):
522528 "ndim" :self .ndim ,
523529 "q_min" :self .q_min , "q_max" :self .q_max ,
524530 "save_every_n_steps" : self .save_every_n_steps ,
525- "dt" :self .dt , "nsteps" :self .nsteps ,
531+ "dt" :self .dt , "nsteps" :self .nsteps , "nsnaps" : self . nsnaps ,
526532 "mass" :self .mass ,
527533 "psi_r_adi" :self .psi_r_adi ,
528534 "psi_r_dia" :self .psi_r_dia ,
@@ -534,6 +540,8 @@ def save(self):
534540 "psi_r_adi_all" :self .psi_r_adi_all ,
535541 "psi_k_dia_all" :self .psi_k_dia_all ,
536542 "psi_k_adi_all" :self .psi_k_adi_all ,
543+ "E" :self .eigvals ,
544+ "U" :self .eigvecs ,
537545 "time" :self .time ,
538546 "Q" :self .Q , "K" :self .K , "dq" :self .dq , "dk" :self .dk ,
539547 "dV" :self .dV , "dVk" :self .dVk ,
0 commit comments