Skip to content

Commit 49abcb0

Browse files
authored
Merge pull request #265 from DaehoHan/ldr_rev
Add more analyses in the LDR dynamics
2 parents 32a7c87 + cf1b912 commit 49abcb0

1 file changed

Lines changed: 122 additions & 72 deletions

File tree

src/libra_py/dynamics/ldr_torch/compute.py

Lines changed: 122 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,19 @@ def __init__(self, params):
4141
self.prefix = params.get("prefix", "ldr-solution")
4242
self.device = params.get("device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
4343
self.hbar = 1.0
44-
self.Hamiltonian_scheme = "symmetrized"
44+
self.hamiltonian_scheme = "symmetrized"
4545
self.q0 = torch.tensor(params.get("q0", [0.0]), dtype=torch.float64, device=self.device)
4646
self.p0 = torch.tensor(params.get("p0", [0.0]), dtype=torch.float64, device=self.device)
4747
self.k = torch.tensor(params.get("k", [0.001]), dtype=torch.float64, device=self.device)
4848
self.mass = torch.tensor(params.get("mass", [2000.0]), dtype=torch.float64, device=self.device)
4949
self.alpha = torch.tensor(params.get("alpha", [18.0]), dtype=torch.float64, device=self.device)
5050
self.qgrid = torch.tensor(params.get("qgrid", [[-10 + i * 0.1] for i in range(int((10 - (-10)) / 0.1) + 1)] ), dtype=torch.float64, device=self.device) #(N, D)
5151
self.ngrids = len(self.qgrid) # N
52+
self.ndof = self.qgrid.shape[1]
5253
self.nstates = params.get("nstates", 2)
5354
self.istate = params.get("istate", 0)
54-
55+
self.elec_ampl = params.get("elec_ampl", torch.tensor([1.0+0.j]*self.ngrids, dtype=torch.cdouble))
56+
5557
self.save_every_n_steps = params.get("save_every_n_steps", 1)
5658
self.properties_to_save = params.get("properties_to_save", ["time", "population_right"])
5759
self.dt = params.get("dt", 0.01)
@@ -60,99 +62,103 @@ def __init__(self, params):
6062

6163
self.E = params.get("E", torch.zeros(self.nstates, self.ngrids, device=self.device) )
6264

63-
Selec_default = torch.zeros(self.ndim, self.ndim, dtype=torch.cdouble, device=self.device)
65+
s_elec_default = torch.zeros(self.ndim, self.ndim, dtype=torch.cdouble, device=self.device)
6466
for i in range(self.nstates):
6567
start, end = i * self.ngrids, (i + 1) * self.ngrids
66-
Selec_default[start:end, start:end] = torch.eye(self.ngrids, device=self.device)
67-
self.Selec = params.get("Selec", Selec_default )
68+
s_elec_default[start:end, start:end] = torch.eye(self.ngrids, device=self.device)
69+
self.s_elec = params.get("s_elec", s_elec_default )
6870

6971
# Computed with LDR methods
7072
self.C0 = torch.zeros(self.ndim, dtype=torch.cdouble, device=self.device)
71-
self.Ccurr = torch.zeros(self.ndim, dtype=torch.cdouble, device=self.device)
73+
self.C_curr = torch.zeros(self.ndim, dtype=torch.cdouble, device=self.device)
7274

73-
self.Snucl = torch.eye(self.ngrids, dtype=torch.cdouble, device=self.device)
74-
self.Tnucl = torch.zeros(self.ngrids, self.ngrids, dtype=torch.cdouble, device=self.device)
75+
self.s_nucl = torch.eye(self.ngrids, dtype=torch.cdouble, device=self.device)
76+
self.t_nucl = torch.zeros(self.ngrids, self.ngrids, dtype=torch.cdouble, device=self.device)
7577

7678
self.S, self.H = torch.zeros(self.ndim, self.ndim, dtype=torch.cdouble, device=self.device), torch.zeros(self.ndim, self.ndim, dtype=torch.cdouble, device=self.device)
7779
self.U = torch.zeros(self.ndim, self.ndim, dtype=torch.cdouble, device=self.device)
80+
self.S_half = torch.zeros(self.ndim, self.ndim, dtype=torch.cdouble, device=self.device)
7881

7982
self.time = []
8083
self.kinetic_energy = []
8184
self.potential_energy = []
8285
self.total_energy = []
86+
self.average_pos = []
8387
self.population_right = []
88+
self.denmat = []
8489
self.norm = []
8590
self.C_save = []
8691

8792
def chi_overlap(self):
8893
"""
89-
Compute nuclear overlap matrix Snucl[i, j] for the mesh qmesh.
94+
Compute nuclear overlap matrix s_nucl[i, j] for the mesh qmesh
95+
from the Gaussian basis, g(x; q) = \exp(-\alpha * (x-q)**2).
9096
"""
9197
delta = self.qgrid[:, None, :] - self.qgrid[None, :, :] # (N, N, D)
9298
exponent = -0.5 * torch.sum(self.alpha * delta**2, dim=2) # (N, N)
93-
self.Snucl = torch.exp(exponent)
99+
self.s_nucl = torch.exp(exponent)
94100

95101
def chi_kinetic(self):
96-
r"""
97-
Compute nuclear kinetic energy matrix Tnucl[i,j] = <g(x; qgrid[i]) | T | g(x; qgrid[j])>,
98-
with T = Σ_ν -½ m_ν^{-1} ∂²/∂x_ν².
102+
"""
103+
Compute nuclear kinetic energy matrix t_nucl[i,j] = <g(x; qgrid[i]) | T | g(x; qgrid[j])>,
104+
with T = \sum_{\nu} -0.5* m_ν^{-1} \partial^{2}/\partial x_{\nu}^2.
99105
"""
100106
delta = self.qgrid[:, None, :] - self.qgrid[None, :, :] # (N, N, D)
101107
tau = self.alpha / (2.0 * self.mass) * (1.0 - self.alpha * delta**2) # (N, N, D)
102108
tau_sum = torch.sum(tau, dim=2) # (N, N)
103109

104-
self.Tnucl = self.Snucl * tau_sum # (N, N)
105-
110+
self.t_nucl = self.s_nucl * tau_sum # (N, N)
111+
106112
def build_compound_overlap(self):
107113
"""
108114
Build the compound nuclear-electronic overlap matrix self.S (ndim, ndim)
109115
"""
110116
N, s, ndim = self.ngrids, self.nstates, self.ndim
111117

112-
# Reshape Selec[a, b] -> (i, n, j, m) with:
118+
# Reshape s_elec[a, b] -> (i, n, j, m) with:
113119
# a = i * N + n
114120
# b = j * N + m
115-
Selec4D = self.Selec.view(s, N, s, N) # (i, n, j, m)
121+
s_elec_4d = self.s_elec.view(s, N, s, N) # (i, n, j, m)
116122

117-
Snucl4D = self.Snucl.unsqueeze(0).unsqueeze(2) # (1, n, 1, m)
123+
s_nucl_4d = self.s_nucl[None, :, None, :] # (1, n, 1, m)
118124

119-
S4D = Selec4D * Snucl4D
125+
S_4d = s_elec_4d * s_nucl_4d
120126

121127
# Reshape back to (ndim, ndim) with compound indices
122-
self.S = S4D.permute(0, 1, 2, 3).reshape(ndim, ndim)
128+
self.S = S_4d.reshape(ndim, ndim)
123129

124130
def build_compound_hamiltonian(self):
125131
"""
126132
Build the compound nuclear-electronic Hamiltonian self.H (ndim, ndim) using different schemes.
127133
"""
128134
N, s, ndim = self.ngrids, self.nstates, self.ndim
129-
scheme = self.Hamiltonian_scheme
130-
Selec4D = self.Selec.view(s, N, s, N) # (s, N, s, N)
131-
T4D = self.Tnucl.unsqueeze(0).unsqueeze(2) # (1, N, 1, N)
132-
S4D = self.Snucl.unsqueeze(0).unsqueeze(2) # (1, N, 1, N)
135+
scheme = self.hamiltonian_scheme
136+
s_elec_4d = self.s_elec.view(s, N, s, N) # (s, N, s, N)
137+
T_4d = self.t_nucl[None, :, None, :] # (1, N, 1, N)
138+
S_4d = self.s_nucl[None, :, None, :] # (1, N, 1, N)
133139

134140
if scheme == 'as_is': # For showing the original non-Hermitian form, not intended to use
135-
Ej4D = self.E[None, None, :, :] # (1, 1, s, N)
136-
bracket4D = T4D + Ej4D * S4D
141+
E_j_4d = self.E[None, None, :, :] # (1, 1, s, N)
142+
bracket_4d = T_4d + E_j_4d * S_4d
137143
elif scheme == 'symmetrized':
138-
Ei4D = self.E[:, :, None, None] # (s, N, 1, 1)
139-
Ej4D = self.E[None, None, :, :] # (1, 1, s, N)
140-
Eavg4D = 0.5 * (Ei4D + Ej4D) # (s, N, s, N)
141-
bracket4D = T4D + Eavg4D * S4D
144+
E_i_4d = self.E[:, :, None, None] # (s, N, 1, 1)
145+
E_j_4d = self.E[None, None, :, :] # (1, 1, s, N)
146+
E_avg_4d = 0.5 * (E_i_4d + E_j_4d) # (s, N, s, N)
147+
bracket_4d = T_4d + E_avg_4d * S_4d
142148
elif scheme == 'diagonal':
143149
# Build Kronecker deltas for electronic and nuclear indices
144-
delta_ij = torch.eye(s, device=self.device).unsqueeze(1).unsqueeze(3) # (s, 1, s, 1)
145-
delta_nm = torch.eye(N, device=self.device).unsqueeze(0).unsqueeze(2) # (1, N, 1, N)
146-
delta4D = delta_ij * delta_nm
150+
delta_ij = torch.eye(s, device=self.device)[:, None, :, None] # (s, 1, s, 1)
151+
delta_nm = torch.eye(N, device=self.device)[None, :, None, :] # (1, N, 1, N)
152+
delta_4d = delta_ij * delta_nm
147153

148-
Ej4D = self.E[None, None, :, :] # (1, 1, s, N)
149-
bracket4D = T4D + Ej4D * S4D * delta4D
154+
E_j_4d = self.E[None, None, :, :] # (1, 1, s, N)
155+
bracket_4d = T_4d + E_j_4d * S_4d * delta_4d
150156

151157
else:
152158
raise ValueError(f"Unknown Hamiltonian scheme: {scheme}")
153159

154-
H4D = Selec4D * bracket4D
155-
self.H = H4D.reshape(ndim, ndim)
160+
H_4d = s_elec_4d * bracket_4d
161+
self.H = H_4d.reshape(ndim, ndim)
156162

157163
def compute_propagator(self):
158164
"""
@@ -166,7 +172,7 @@ def compute_propagator(self):
166172

167173
evals_S, evecs_S = torch.linalg.eigh(S)
168174

169-
S_half = (evecs_S @ torch.diag(evals_S.sqrt().to(dtype=torch.cdouble)) @ evecs_S.T).to(dtype=torch.cdouble)
175+
self.S_half = (evecs_S @ torch.diag(evals_S.sqrt().to(dtype=torch.cdouble)) @ evecs_S.T).to(dtype=torch.cdouble)
170176
S_invhalf = (evecs_S @ torch.diag((1.0 / evals_S).sqrt().to(dtype=torch.cdouble)) @ evecs_S.T).to(dtype=torch.cdouble)
171177

172178
H_ortho = S_invhalf @ H @ S_invhalf
@@ -176,7 +182,7 @@ def compute_propagator(self):
176182
exp_diag = torch.diag(torch.exp(-1j * evals_H * dt))
177183
U_ortho = evecs_H @ exp_diag @ evecs_H.conj().T
178184

179-
self.U = S_invhalf @ U_ortho @ S_half
185+
self.U = S_invhalf @ U_ortho @ self.S_half
180186

181187

182188
def initialize_C(self):
@@ -217,7 +223,7 @@ def initialize_C(self):
217223
delta_eta = -0.5 * torch.dot(xi0 + p0, q0) + 0.5 * torch.dot(xig, qgrid[n]).conj()
218224
exponent = -1.j * 0.5 * torch.dot(delta_xi, torch.matmul(delta_A_inv, delta_xi)) + 1.j * delta_eta
219225

220-
self.C0[index] = torch.exp(exponent)
226+
self.C0[index] = self.elec_ampl[n] * torch.exp(exponent)
221227

222228
# Normalize
223229
overlap = torch.matmul(self.S, self.C0)
@@ -230,14 +236,14 @@ def propagate(self):
230236
Propagate coefficient.
231237
"""
232238
# Initialize first step with normalized initial wavefunction
233-
self.Ccurr = self.C0.clone()
239+
self.C_curr = self.C0.clone()
234240

235241
print(F"step = 0")
236242
self.save_results(0)
237243

238244
for step in range(1, self.nsteps):
239-
Cvec = self.Ccurr.clone()
240-
self.Ccurr = self.U @ Cvec
245+
C_vec = self.C_curr.clone()
246+
self.C_curr = self.U @ C_vec
241247

242248
if step % self.save_every_n_steps == 0:
243249
print(F"step = {step}")
@@ -247,53 +253,72 @@ def save_results(self, step):
247253
if "time" in self.properties_to_save:
248254
self.time.append(step*self.dt)
249255
if "norm" in self.properties_to_save:
250-
overlap = torch.matmul(self.S, self.Ccurr)
251-
self.norm.append(torch.sqrt(torch.vdot(self.Ccurr, overlap)))
256+
overlap = torch.matmul(self.S, self.C_curr)
257+
self.norm.append(torch.sqrt(torch.vdot(self.C_curr, overlap)))
252258
if "population_right" in self.properties_to_save:
253259
self.population_right.append(self.compute_populations())
260+
if "denmat" in self.properties_to_save:
261+
self.denmat.append(self.compute_denmat())
254262
if "kinetic_energy" in self.properties_to_save:
255263
self.kinetic_energy.append(self.compute_kinetic_energy())
256264
if "potential_energy" in self.properties_to_save:
257265
self.potential_energy.append(self.compute_potential_energy())
258266
if "total_energy" in self.properties_to_save:
259267
self.total_energy.append(self.compute_total_energy())
268+
if "average_pos" in self.properties_to_save:
269+
self.average_pos.append(self.compute_average_pos())
260270
if "C_save" in self.properties_to_save:
261-
self.C_save.append(self.Ccurr)
271+
self.C_save.append(self.C_curr)
262272

263273
def compute_populations(self):
264274
"""
265275
Compute electronic state population for a single step.
266276
"""
267277
N, s = self.ngrids, self.nstates
268-
Cvec = self.Ccurr
278+
C_vec = self.C_curr
269279

270280
# Compute SC once: shape (ndim,)
271-
SC = self.S @ Cvec
281+
SC = self.S @ C_vec
272282

273-
C_blocks = Cvec.view(s, N)
283+
C_blocks = C_vec.view(s, N)
274284
SC_blocks = SC.view(s, N)
275285

276286
# Compute P[i] = sum_j <C_j|S_{ji}|C_i> = Re[ sum_N (C_j*) * SC_j ]
277287
P = torch.sum(C_blocks.conj() * SC_blocks, dim=1).real
278288

279289
return P
280290

291+
def compute_denmat(self):
292+
"""
293+
Compute electronic density matrix for a single step using the orthogonalization.
294+
"""
295+
N, s = self.ngrids, self.nstates
296+
C_vec = self.C_curr
297+
298+
# Orthogonalize coefficients: C_ortho = S^{1/2} C
299+
C_ortho = self.S_half @ C_vec
300+
301+
C_blocks = C_ortho.view(s, N)
302+
303+
rho = C_blocks @ C_blocks.conj().T # (s, s)
304+
305+
return rho
306+
281307
def compute_kinetic_energy(self):
282308
"""
283309
Compute nuclear kinetic energy as C^+ T C / C^+ S C for a single step.
284310
"""
285311
N, s, ndim = self.ngrids, self.nstates, self.ndim
286312

287-
# Rebuild compound kinetic matrix: T4D * Selec4D
288-
Selec4D = self.Selec.view(s, N, s, N)
289-
T4D = self.Tnucl.unsqueeze(0).unsqueeze(2) # (1, n, 1, m)
290-
T4D_compound = Selec4D * T4D
291-
T_compound = T4D_compound.permute(0, 1, 2, 3).reshape(ndim, ndim)
313+
# Rebuild compound kinetic matrix: T_4d * s_elec_4d
314+
s_elec_4d = self.s_elec.view(s, N, s, N)
315+
T_4d = self.t_nucl[None, :, None, :]
316+
T_compound = (s_elec_4d * T_4d).reshape(ndim, ndim)
292317

293-
Cvec = self.Ccurr
318+
C_vec = self.C_curr
294319

295-
numer = torch.vdot(Cvec, T_compound @ Cvec).real
296-
denom = torch.vdot(Cvec, self.S @ Cvec).real
320+
numer = torch.vdot(C_vec, T_compound @ C_vec).real
321+
denom = torch.vdot(C_vec, self.S @ C_vec).real
297322

298323
return numer / denom
299324

@@ -304,17 +329,16 @@ def compute_potential_energy(self):
304329
"""
305330
N, s, ndim = self.ngrids, self.nstates, self.ndim
306331

307-
Selec4D = self.Selec.view(s, N, s, N)
308-
S4D = self.Snucl.unsqueeze(0).unsqueeze(2) # (1, n, 1, m)
309-
Ej4D = self.E[None, None, :, :] # (1,1,j,m)
332+
s_elec_4d = self.s_elec.view(s, N, s, N)
333+
S_4d = self.s_nucl[None, :, None, :]
334+
E_j_4d = self.E[None, None, :, :] # (1,1,j,m)
310335

311-
V4D_compound = Selec4D * (Ej4D * S4D)
312-
V_compound = V4D_compound.permute(0, 1, 2, 3).reshape(ndim, ndim)
336+
V_compound = (s_elec_4d * (E_j_4d * S_4d)).reshape(ndim, ndim)
313337

314-
Cvec = self.Ccurr
338+
C_vec = self.C_curr
315339

316-
numer = torch.vdot(Cvec, V_compound @ Cvec).real
317-
denom = torch.vdot(Cvec, self.S @ Cvec).real
340+
numer = torch.vdot(C_vec, V_compound @ C_vec).real
341+
denom = torch.vdot(C_vec, self.S @ C_vec).real
318342

319343
return numer / denom
320344

@@ -323,13 +347,37 @@ def compute_total_energy(self):
323347
"""
324348
Compute total energy as C^+ H C / C^+ S C for a single step.
325349
"""
326-
Cvec = self.Ccurr
350+
C_vec = self.C_curr
327351

328-
numer = torch.vdot(Cvec, self.H @ Cvec).real
329-
denom = torch.vdot(Cvec, self.S @ Cvec).real
352+
numer = torch.vdot(C_vec, self.H @ C_vec).real
353+
denom = torch.vdot(C_vec, self.S @ C_vec).real
330354

331355
return numer / denom
356+
357+
def compute_average_pos(self):
358+
"""
359+
Compute average position as <q_i> = \sum_i C^+ Q C / C^+ S C for a single step.
360+
"""
361+
N, s, ndim = self.ngrids, self.nstates, self.ndim
332362

363+
C_vec = self.C_curr
364+
365+
denom = torch.vdot(C_vec, self.S @ C_vec).real
366+
s_elec_4d = self.s_elec.view(s, N, s, N)
367+
368+
avg_q = []
369+
for idof in range(self.ndof):
370+
q_med = 0.5 * (self.qgrid[:, None, idof] + self.qgrid[None,:,idof])
371+
q_nucl = self.s_nucl * q_med
372+
Q_4d = q_nucl[None, :, None, :]
373+
Q_4d_compound = s_elec_4d * Q_4d
374+
Q_compound = Q_4d_compound.reshape(ndim, ndim)
375+
376+
numer = torch.vdot(C_vec, Q_compound @ C_vec).real
377+
avg_q.append(numer / denom)
378+
379+
return avg_q
380+
333381
def save(self):
334382
torch.save( {"q0":self.q0,
335383
"p0":self.p0,
@@ -339,22 +387,24 @@ def save(self):
339387
"qgrid":self.qgrid,
340388
"nstates":self.nstates,
341389
"istate":self.istate,
342-
"Snucl":self.Snucl,
343-
"Tnucl":self.Tnucl,
390+
"s_nucl":self.s_nucl,
391+
"t_nucl":self.t_nucl,
344392
"E":self.E,
345-
"Selec":self.Selec,
393+
"s_elec":self.s_elec,
346394
"S":self.S,
347395
"H":self.H,
348396
"U":self.U,
349397
"C_save":self.C_save,
350398
"save_every_n_steps":self.save_every_n_steps,
351-
"Hamiltonian_scheme": self.Hamiltonian_scheme,
399+
"hamiltonian_scheme": self.hamiltonian_scheme,
352400
"dt":self.dt, "nsteps":self.nsteps,
353401
"time":self.time,
354402
"kinetic_energy":self.kinetic_energy,
355403
"potential_energy":self.potential_energy,
356404
"total_energy":self.total_energy,
405+
"average_pos":self.average_pos,
357406
"population_right":self.population_right,
407+
"denmat":self.denmat,
358408
"norm":self.norm
359409
}, F"{self.prefix}.pt" )
360410

0 commit comments

Comments
 (0)