Skip to content

Commit 3695b3c

Browse files
committed
Updated and cleaned up the multistate TDSE solver
1 parent a6ad422 commit 3695b3c

1 file changed

Lines changed: 25 additions & 17 deletions

File tree

src/libra_py/dynamics/exact_torch/compute.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,12 @@ def tully_potential_matrix(Q, params):
6868
D = params.get("D", 1.0)
6969

7070

71+
# Diabatic state 1 potential
7172
V11 = torch.where(
7273
x >= 0,
7374
A * (1 - torch.exp(-B * x)),
7475
-A * (1 - torch.exp(B * x))
75-
)
76-
77-
#V11 = A * (1 - torch.exp(-B * x)) # Diabatic state 1 potential
76+
)
7877
V22 = -V11 # Diabatic state 2 potential (mirror)
7978
V12 = C * torch.exp(-D * x**2) # Coupling between diabatic states
8079

@@ -290,12 +289,12 @@ def __init__(self, params):
290289
self.initial_state_index = params.get("initial_state_index", 0)
291290
self.method = params.get("method", "miller-colton").lower()
292291

293-
self.time = []
294-
self.kinetic_energy = []
295-
self.potential_energy = []
296-
self.total_energy = []
297-
self.population_right = []
298-
self.norm = []
292+
#self.time = []
293+
#self.kinetic_energy = []
294+
#self.potential_energy = []
295+
#self.total_energy = []
296+
#self.population_right = []
297+
#self.norm = []
299298

300299
def initialize_grids(self):
301300
"""
@@ -331,7 +330,14 @@ def initialize_grids(self):
331330

332331

333332
# Allocate storage:
334-
self.nsnaps = self.nsteps // self.save_every_n_steps + 1 # how many snapshots to save
333+
self.nsnaps = self.nsteps // self.save_every_n_steps # how many snapshots to save
334+
335+
self.time = torch.zeros( self.nsnaps, dtype=torch.float )
336+
self.kinetic_energy = torch.zeros( self.nsnaps, dtype=torch.float )
337+
self.potential_energy = torch.zeros( self.nsnaps, dtype=torch.float )
338+
self.total_energy = torch.zeros( self.nsnaps, dtype=torch.float )
339+
self.population_right = torch.zeros( self.nsnaps, dtype=torch.float )
340+
self.norm = torch.zeros( self.nsnaps, dtype=torch.float )
335341

336342
# Diabatic properties
337343
self.psi_r_dia = torch.zeros((*self.grid_size, self.Nstates), dtype=torch.cfloat)
@@ -443,7 +449,7 @@ def propagate(self):
443449

444450
if step % self.save_every_n_steps == 0:
445451
istep = int(step / self.save_every_n_steps)
446-
452+
447453
# Diabatic r-space wavefunctions
448454
self.psi_r_dia_all[istep] = self.psi_r_dia
449455

@@ -471,11 +477,11 @@ def propagate(self):
471477
right_mask = self.Q[0] > 0
472478
#pop_right = torch.sum(self.prob_density[right_mask]) * self.dV
473479

474-
self.norm.append( nrm )
475-
self.time.append(step * self.dt)
476-
self.kinetic_energy.append( KE.real.item() )
477-
self.potential_energy.append( PE.real.item() )
478-
self.total_energy.append( KE + PE )
480+
self.norm[istep] = nrm
481+
self.time[istep] = step * self.dt
482+
self.kinetic_energy[istep] = KE.real.item()
483+
self.potential_energy[istep] = PE.real.item()
484+
self.total_energy[istep] = KE + PE
479485
#self.population_right.append(pop_right.item())
480486

481487
print(f"Step {step}: Norm = {nrm:.4f}")
@@ -522,7 +528,7 @@ def save(self):
522528
"ndim":self.ndim,
523529
"q_min":self.q_min, "q_max":self.q_max,
524530
"save_every_n_steps": self.save_every_n_steps,
525-
"dt":self.dt, "nsteps":self.nsteps,
531+
"dt":self.dt, "nsteps":self.nsteps, "nsnaps":self.nsnaps,
526532
"mass":self.mass,
527533
"psi_r_adi":self.psi_r_adi,
528534
"psi_r_dia":self.psi_r_dia,
@@ -534,6 +540,8 @@ def save(self):
534540
"psi_r_adi_all":self.psi_r_adi_all,
535541
"psi_k_dia_all":self.psi_k_dia_all,
536542
"psi_k_adi_all":self.psi_k_adi_all,
543+
"E":self.eigvals,
544+
"U":self.eigvecs,
537545
"time":self.time,
538546
"Q":self.Q, "K":self.K, "dq":self.dq, "dk":self.dk,
539547
"dV":self.dV, "dVk":self.dVk,

0 commit comments

Comments
 (0)