@@ -41,17 +41,19 @@ def __init__(self, params):
4141 self .prefix = params .get ("prefix" , "ldr-solution" )
4242 self .device = params .get ("device" , torch .device ("cuda" if torch .cuda .is_available () else "cpu" ))
4343 self .hbar = 1.0
44- self .Hamiltonian_scheme = "symmetrized"
44+ self .hamiltonian_scheme = "symmetrized"
4545 self .q0 = torch .tensor (params .get ("q0" , [0.0 ]), dtype = torch .float64 , device = self .device )
4646 self .p0 = torch .tensor (params .get ("p0" , [0.0 ]), dtype = torch .float64 , device = self .device )
4747 self .k = torch .tensor (params .get ("k" , [0.001 ]), dtype = torch .float64 , device = self .device )
4848 self .mass = torch .tensor (params .get ("mass" , [2000.0 ]), dtype = torch .float64 , device = self .device )
4949 self .alpha = torch .tensor (params .get ("alpha" , [18.0 ]), dtype = torch .float64 , device = self .device )
5050 self .qgrid = torch .tensor (params .get ("qgrid" , [[- 10 + i * 0.1 ] for i in range (int ((10 - (- 10 )) / 0.1 ) + 1 )] ), dtype = torch .float64 , device = self .device ) #(N, D)
5151 self .ngrids = len (self .qgrid ) # N
52+ self .ndof = self .qgrid .shape [1 ]
5253 self .nstates = params .get ("nstates" , 2 )
5354 self .istate = params .get ("istate" , 0 )
54-
55+ self .elec_ampl = params .get ("elec_ampl" , torch .tensor ([1.0 + 0.j ]* self .ngrids , dtype = torch .cdouble ))
56+
5557 self .save_every_n_steps = params .get ("save_every_n_steps" , 1 )
5658 self .properties_to_save = params .get ("properties_to_save" , ["time" , "population_right" ])
5759 self .dt = params .get ("dt" , 0.01 )
@@ -60,99 +62,103 @@ def __init__(self, params):
6062
6163 self .E = params .get ("E" , torch .zeros (self .nstates , self .ngrids , device = self .device ) )
6264
63- Selec_default = torch .zeros (self .ndim , self .ndim , dtype = torch .cdouble , device = self .device )
65+ s_elec_default = torch .zeros (self .ndim , self .ndim , dtype = torch .cdouble , device = self .device )
6466 for i in range (self .nstates ):
6567 start , end = i * self .ngrids , (i + 1 ) * self .ngrids
66- Selec_default [start :end , start :end ] = torch .eye (self .ngrids , device = self .device )
67- self .Selec = params .get ("Selec " , Selec_default )
68+ s_elec_default [start :end , start :end ] = torch .eye (self .ngrids , device = self .device )
69+ self .s_elec = params .get ("s_elec " , s_elec_default )
6870
6971 # Computed with LDR methods
7072 self .C0 = torch .zeros (self .ndim , dtype = torch .cdouble , device = self .device )
71- self .Ccurr = torch .zeros (self .ndim , dtype = torch .cdouble , device = self .device )
73+ self .C_curr = torch .zeros (self .ndim , dtype = torch .cdouble , device = self .device )
7274
73- self .Snucl = torch .eye (self .ngrids , dtype = torch .cdouble , device = self .device )
74- self .Tnucl = torch .zeros (self .ngrids , self .ngrids , dtype = torch .cdouble , device = self .device )
75+ self .s_nucl = torch .eye (self .ngrids , dtype = torch .cdouble , device = self .device )
76+ self .t_nucl = torch .zeros (self .ngrids , self .ngrids , dtype = torch .cdouble , device = self .device )
7577
7678 self .S , self .H = torch .zeros (self .ndim , self .ndim , dtype = torch .cdouble , device = self .device ), torch .zeros (self .ndim , self .ndim , dtype = torch .cdouble , device = self .device )
7779 self .U = torch .zeros (self .ndim , self .ndim , dtype = torch .cdouble , device = self .device )
80+ self .S_half = torch .zeros (self .ndim , self .ndim , dtype = torch .cdouble , device = self .device )
7881
7982 self .time = []
8083 self .kinetic_energy = []
8184 self .potential_energy = []
8285 self .total_energy = []
86+ self .average_pos = []
8387 self .population_right = []
88+ self .denmat = []
8489 self .norm = []
8590 self .C_save = []
8691
8792 def chi_overlap (self ):
8893 """
89- Compute nuclear overlap matrix Snucl[i, j] for the mesh qmesh.
94+ Compute nuclear overlap matrix s_nucl[i, j] for the mesh qmesh
95+ from the Gaussian basis, g(x; q) = \exp(-\a lpha * (x-q)**2).
9096 """
9197 delta = self .qgrid [:, None , :] - self .qgrid [None , :, :] # (N, N, D)
9298 exponent = - 0.5 * torch .sum (self .alpha * delta ** 2 , dim = 2 ) # (N, N)
93- self .Snucl = torch .exp (exponent )
99+ self .s_nucl = torch .exp (exponent )
94100
95101 def chi_kinetic (self ):
96- r """
97- Compute nuclear kinetic energy matrix Tnucl [i,j] = <g(x; qgrid[i]) | T | g(x; qgrid[j])>,
98- with T = Σ_ν -½ m_ν^{-1} ∂²/∂x_ν² .
102+ """
103+ Compute nuclear kinetic energy matrix t_nucl [i,j] = <g(x; qgrid[i]) | T | g(x; qgrid[j])>,
104+ with T = \sum_{ \n u} -0.5* m_ν^{-1} \partial^{2}/\partial x_{ \n u}^2 .
99105 """
100106 delta = self .qgrid [:, None , :] - self .qgrid [None , :, :] # (N, N, D)
101107 tau = self .alpha / (2.0 * self .mass ) * (1.0 - self .alpha * delta ** 2 ) # (N, N, D)
102108 tau_sum = torch .sum (tau , dim = 2 ) # (N, N)
103109
104- self .Tnucl = self .Snucl * tau_sum # (N, N)
105-
110+ self .t_nucl = self .s_nucl * tau_sum # (N, N)
111+
106112 def build_compound_overlap (self ):
107113 """
108114 Build the compound nuclear-electronic overlap matrix self.S (ndim, ndim)
109115 """
110116 N , s , ndim = self .ngrids , self .nstates , self .ndim
111117
112- # Reshape Selec [a, b] -> (i, n, j, m) with:
118+ # Reshape s_elec [a, b] -> (i, n, j, m) with:
113119 # a = i * N + n
114120 # b = j * N + m
115- Selec4D = self .Selec .view (s , N , s , N ) # (i, n, j, m)
121+ s_elec_4d = self .s_elec .view (s , N , s , N ) # (i, n, j, m)
116122
117- Snucl4D = self .Snucl . unsqueeze ( 0 ). unsqueeze ( 2 ) # (1, n, 1, m)
123+ s_nucl_4d = self .s_nucl [ None , :, None , :] # (1, n, 1, m)
118124
119- S4D = Selec4D * Snucl4D
125+ S_4d = s_elec_4d * s_nucl_4d
120126
121127 # Reshape back to (ndim, ndim) with compound indices
122- self .S = S4D . permute ( 0 , 1 , 2 , 3 ) .reshape (ndim , ndim )
128+ self .S = S_4d .reshape (ndim , ndim )
123129
124130 def build_compound_hamiltonian (self ):
125131 """
126132 Build the compound nuclear-electronic Hamiltonian self.H (ndim, ndim) using different schemes.
127133 """
128134 N , s , ndim = self .ngrids , self .nstates , self .ndim
129- scheme = self .Hamiltonian_scheme
130- Selec4D = self .Selec .view (s , N , s , N ) # (s, N, s, N)
131- T4D = self .Tnucl . unsqueeze ( 0 ). unsqueeze ( 2 ) # (1, N, 1, N)
132- S4D = self .Snucl . unsqueeze ( 0 ). unsqueeze ( 2 ) # (1, N, 1, N)
135+ scheme = self .hamiltonian_scheme
136+ s_elec_4d = self .s_elec .view (s , N , s , N ) # (s, N, s, N)
137+ T_4d = self .t_nucl [ None , :, None , :] # (1, N, 1, N)
138+ S_4d = self .s_nucl [ None , :, None , :] # (1, N, 1, N)
133139
134140 if scheme == 'as_is' : # For showing the original non-Hermitian form, not intended to use
135- Ej4D = self .E [None , None , :, :] # (1, 1, s, N)
136- bracket4D = T4D + Ej4D * S4D
141+ E_j_4d = self .E [None , None , :, :] # (1, 1, s, N)
142+ bracket_4d = T_4d + E_j_4d * S_4d
137143 elif scheme == 'symmetrized' :
138- Ei4D = self .E [:, :, None , None ] # (s, N, 1, 1)
139- Ej4D = self .E [None , None , :, :] # (1, 1, s, N)
140- Eavg4D = 0.5 * (Ei4D + Ej4D ) # (s, N, s, N)
141- bracket4D = T4D + Eavg4D * S4D
144+ E_i_4d = self .E [:, :, None , None ] # (s, N, 1, 1)
145+ E_j_4d = self .E [None , None , :, :] # (1, 1, s, N)
146+ E_avg_4d = 0.5 * (E_i_4d + E_j_4d ) # (s, N, s, N)
147+ bracket_4d = T_4d + E_avg_4d * S_4d
142148 elif scheme == 'diagonal' :
143149 # Build Kronecker deltas for electronic and nuclear indices
144- delta_ij = torch .eye (s , device = self .device ). unsqueeze ( 1 ). unsqueeze ( 3 ) # (s, 1, s, 1)
145- delta_nm = torch .eye (N , device = self .device ). unsqueeze ( 0 ). unsqueeze ( 2 ) # (1, N, 1, N)
146- delta4D = delta_ij * delta_nm
150+ delta_ij = torch .eye (s , device = self .device )[:, None , :, None ] # (s, 1, s, 1)
151+ delta_nm = torch .eye (N , device = self .device )[ None , :, None , :] # (1, N, 1, N)
152+ delta_4d = delta_ij * delta_nm
147153
148- Ej4D = self .E [None , None , :, :] # (1, 1, s, N)
149- bracket4D = T4D + Ej4D * S4D * delta4D
154+ E_j_4d = self .E [None , None , :, :] # (1, 1, s, N)
155+ bracket_4d = T_4d + E_j_4d * S_4d * delta_4d
150156
151157 else :
152158 raise ValueError (f"Unknown Hamiltonian scheme: { scheme } " )
153159
154- H4D = Selec4D * bracket4D
155- self .H = H4D .reshape (ndim , ndim )
160+ H_4d = s_elec_4d * bracket_4d
161+ self .H = H_4d .reshape (ndim , ndim )
156162
157163 def compute_propagator (self ):
158164 """
@@ -166,7 +172,7 @@ def compute_propagator(self):
166172
167173 evals_S , evecs_S = torch .linalg .eigh (S )
168174
169- S_half = (evecs_S @ torch .diag (evals_S .sqrt ().to (dtype = torch .cdouble )) @ evecs_S .T ).to (dtype = torch .cdouble )
175+ self . S_half = (evecs_S @ torch .diag (evals_S .sqrt ().to (dtype = torch .cdouble )) @ evecs_S .T ).to (dtype = torch .cdouble )
170176 S_invhalf = (evecs_S @ torch .diag ((1.0 / evals_S ).sqrt ().to (dtype = torch .cdouble )) @ evecs_S .T ).to (dtype = torch .cdouble )
171177
172178 H_ortho = S_invhalf @ H @ S_invhalf
@@ -176,7 +182,7 @@ def compute_propagator(self):
176182 exp_diag = torch .diag (torch .exp (- 1j * evals_H * dt ))
177183 U_ortho = evecs_H @ exp_diag @ evecs_H .conj ().T
178184
179- self .U = S_invhalf @ U_ortho @ S_half
185+ self .U = S_invhalf @ U_ortho @ self . S_half
180186
181187
182188 def initialize_C (self ):
@@ -217,7 +223,7 @@ def initialize_C(self):
217223 delta_eta = - 0.5 * torch .dot (xi0 + p0 , q0 ) + 0.5 * torch .dot (xig , qgrid [n ]).conj ()
218224 exponent = - 1.j * 0.5 * torch .dot (delta_xi , torch .matmul (delta_A_inv , delta_xi )) + 1.j * delta_eta
219225
220- self .C0 [index ] = torch .exp (exponent )
226+ self .C0 [index ] = self . elec_ampl [ n ] * torch .exp (exponent )
221227
222228 # Normalize
223229 overlap = torch .matmul (self .S , self .C0 )
@@ -230,14 +236,14 @@ def propagate(self):
230236 Propagate coefficient.
231237 """
232238 # Initialize first step with normalized initial wavefunction
233- self .Ccurr = self .C0 .clone ()
239+ self .C_curr = self .C0 .clone ()
234240
235241 print (F"step = 0" )
236242 self .save_results (0 )
237243
238244 for step in range (1 , self .nsteps ):
239- Cvec = self .Ccurr .clone ()
240- self .Ccurr = self .U @ Cvec
245+ C_vec = self .C_curr .clone ()
246+ self .C_curr = self .U @ C_vec
241247
242248 if step % self .save_every_n_steps == 0 :
243249 print (F"step = { step } " )
@@ -247,53 +253,72 @@ def save_results(self, step):
247253 if "time" in self .properties_to_save :
248254 self .time .append (step * self .dt )
249255 if "norm" in self .properties_to_save :
250- overlap = torch .matmul (self .S , self .Ccurr )
251- self .norm .append (torch .sqrt (torch .vdot (self .Ccurr , overlap )))
256+ overlap = torch .matmul (self .S , self .C_curr )
257+ self .norm .append (torch .sqrt (torch .vdot (self .C_curr , overlap )))
252258 if "population_right" in self .properties_to_save :
253259 self .population_right .append (self .compute_populations ())
260+ if "denmat" in self .properties_to_save :
261+ self .denmat .append (self .compute_denmat ())
254262 if "kinetic_energy" in self .properties_to_save :
255263 self .kinetic_energy .append (self .compute_kinetic_energy ())
256264 if "potential_energy" in self .properties_to_save :
257265 self .potential_energy .append (self .compute_potential_energy ())
258266 if "total_energy" in self .properties_to_save :
259267 self .total_energy .append (self .compute_total_energy ())
268+ if "average_pos" in self .properties_to_save :
269+ self .average_pos .append (self .compute_average_pos ())
260270 if "C_save" in self .properties_to_save :
261- self .C_save .append (self .Ccurr )
271+ self .C_save .append (self .C_curr )
262272
263273 def compute_populations (self ):
264274 """
265275 Compute electronic state population for a single step.
266276 """
267277 N , s = self .ngrids , self .nstates
268- Cvec = self .Ccurr
278+ C_vec = self .C_curr
269279
270280 # Compute SC once: shape (ndim,)
271- SC = self .S @ Cvec
281+ SC = self .S @ C_vec
272282
273- C_blocks = Cvec .view (s , N )
283+ C_blocks = C_vec .view (s , N )
274284 SC_blocks = SC .view (s , N )
275285
276286 # Compute P[i] = sum_j <C_j|S_{ji}|C_i> = Re[ sum_N (C_j*) * SC_j ]
277287 P = torch .sum (C_blocks .conj () * SC_blocks , dim = 1 ).real
278288
279289 return P
280290
291+ def compute_denmat (self ):
292+ """
293+ Compute electronic density matrix for a single step using the orthogonalization.
294+ """
295+ N , s = self .ngrids , self .nstates
296+ C_vec = self .C_curr
297+
298+ # Orthogonalize coefficients: C_ortho = S^{1/2} C
299+ C_ortho = self .S_half @ C_vec
300+
301+ C_blocks = C_ortho .view (s , N )
302+
303+ rho = C_blocks @ C_blocks .conj ().T # (s, s)
304+
305+ return rho
306+
281307 def compute_kinetic_energy (self ):
282308 """
283309 Compute nuclear kinetic energy as C^+ T C / C^+ S C for a single step.
284310 """
285311 N , s , ndim = self .ngrids , self .nstates , self .ndim
286312
287- # Rebuild compound kinetic matrix: T4D * Selec4D
288- Selec4D = self .Selec .view (s , N , s , N )
289- T4D = self .Tnucl .unsqueeze (0 ).unsqueeze (2 ) # (1, n, 1, m)
290- T4D_compound = Selec4D * T4D
291- T_compound = T4D_compound .permute (0 , 1 , 2 , 3 ).reshape (ndim , ndim )
313+ # Rebuild compound kinetic matrix: T_4d * s_elec_4d
314+ s_elec_4d = self .s_elec .view (s , N , s , N )
315+ T_4d = self .t_nucl [None , :, None , :]
316+ T_compound = (s_elec_4d * T_4d ).reshape (ndim , ndim )
292317
293- Cvec = self .Ccurr
318+ C_vec = self .C_curr
294319
295- numer = torch .vdot (Cvec , T_compound @ Cvec ).real
296- denom = torch .vdot (Cvec , self .S @ Cvec ).real
320+ numer = torch .vdot (C_vec , T_compound @ C_vec ).real
321+ denom = torch .vdot (C_vec , self .S @ C_vec ).real
297322
298323 return numer / denom
299324
@@ -304,17 +329,16 @@ def compute_potential_energy(self):
304329 """
305330 N , s , ndim = self .ngrids , self .nstates , self .ndim
306331
307- Selec4D = self .Selec .view (s , N , s , N )
308- S4D = self .Snucl . unsqueeze ( 0 ). unsqueeze ( 2 ) # (1, n, 1, m)
309- Ej4D = self .E [None , None , :, :] # (1,1,j,m)
332+ s_elec_4d = self .s_elec .view (s , N , s , N )
333+ S_4d = self .s_nucl [ None , :, None , :]
334+ E_j_4d = self .E [None , None , :, :] # (1,1,j,m)
310335
311- V4D_compound = Selec4D * (Ej4D * S4D )
312- V_compound = V4D_compound .permute (0 , 1 , 2 , 3 ).reshape (ndim , ndim )
336+ V_compound = (s_elec_4d * (E_j_4d * S_4d )).reshape (ndim , ndim )
313337
314- Cvec = self .Ccurr
338+ C_vec = self .C_curr
315339
316- numer = torch .vdot (Cvec , V_compound @ Cvec ).real
317- denom = torch .vdot (Cvec , self .S @ Cvec ).real
340+ numer = torch .vdot (C_vec , V_compound @ C_vec ).real
341+ denom = torch .vdot (C_vec , self .S @ C_vec ).real
318342
319343 return numer / denom
320344
@@ -323,13 +347,37 @@ def compute_total_energy(self):
323347 """
324348 Compute total energy as C^+ H C / C^+ S C for a single step.
325349 """
326- Cvec = self .Ccurr
350+ C_vec = self .C_curr
327351
328- numer = torch .vdot (Cvec , self .H @ Cvec ).real
329- denom = torch .vdot (Cvec , self .S @ Cvec ).real
352+ numer = torch .vdot (C_vec , self .H @ C_vec ).real
353+ denom = torch .vdot (C_vec , self .S @ C_vec ).real
330354
331355 return numer / denom
356+
357+ def compute_average_pos (self ):
358+ """
359+ Compute average position as <q_i> = \sum_i C^+ Q C / C^+ S C for a single step.
360+ """
361+ N , s , ndim = self .ngrids , self .nstates , self .ndim
332362
363+ C_vec = self .C_curr
364+
365+ denom = torch .vdot (C_vec , self .S @ C_vec ).real
366+ s_elec_4d = self .s_elec .view (s , N , s , N )
367+
368+ avg_q = []
369+ for idof in range (self .ndof ):
370+ q_med = 0.5 * (self .qgrid [:, None , idof ] + self .qgrid [None ,:,idof ])
371+ q_nucl = self .s_nucl * q_med
372+ Q_4d = q_nucl [None , :, None , :]
373+ Q_4d_compound = s_elec_4d * Q_4d
374+ Q_compound = Q_4d_compound .reshape (ndim , ndim )
375+
376+ numer = torch .vdot (C_vec , Q_compound @ C_vec ).real
377+ avg_q .append (numer / denom )
378+
379+ return avg_q
380+
333381 def save (self ):
334382 torch .save ( {"q0" :self .q0 ,
335383 "p0" :self .p0 ,
@@ -339,22 +387,24 @@ def save(self):
339387 "qgrid" :self .qgrid ,
340388 "nstates" :self .nstates ,
341389 "istate" :self .istate ,
342- "Snucl " :self .Snucl ,
343- "Tnucl " :self .Tnucl ,
390+ "s_nucl " :self .s_nucl ,
391+ "t_nucl " :self .t_nucl ,
344392 "E" :self .E ,
345- "Selec " :self .Selec ,
393+ "s_elec " :self .s_elec ,
346394 "S" :self .S ,
347395 "H" :self .H ,
348396 "U" :self .U ,
349397 "C_save" :self .C_save ,
350398 "save_every_n_steps" :self .save_every_n_steps ,
351- "Hamiltonian_scheme " : self .Hamiltonian_scheme ,
399+ "hamiltonian_scheme " : self .hamiltonian_scheme ,
352400 "dt" :self .dt , "nsteps" :self .nsteps ,
353401 "time" :self .time ,
354402 "kinetic_energy" :self .kinetic_energy ,
355403 "potential_energy" :self .potential_energy ,
356404 "total_energy" :self .total_energy ,
405+ "average_pos" :self .average_pos ,
357406 "population_right" :self .population_right ,
407+ "denmat" :self .denmat ,
358408 "norm" :self .norm
359409 }, F"{ self .prefix } .pt" )
360410
0 commit comments