Skip to content

Commit 31897ae

Browse files
committed
Added the exact solver (for adiabatic dynamics so far) written with PyTorch - it is way more efficient
1 parent 6039e60 commit 31897ae

3 files changed

Lines changed: 225 additions & 0 deletions

File tree

src/libra_py/dynamics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
__all__ = ["bohmian",
1111
"exact",
12+
"exact_torch",
1213
"heom",
1314
"qtag",
1415
"tsh",
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# ***********************************************************
2+
# * Copyright (C) 2025 Alexey V. Akimov
3+
# * This file is distributed under the terms of the
4+
# * GNU General Public License as published by the
5+
# * Free Software Foundation; either version 3 of the
6+
# * License, or (at your option) any later version.
7+
# * http://www.gnu.org/copyleft/gpl.txt
8+
# ***********************************************************/
9+
10+
__all__ = ["compute",
11+
]
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
# *********************************************************************************
2+
# * Copyright (C) 2025 Alexey V. Akimov
3+
# *
4+
# * This file is distributed under the terms of the GNU General Public License
5+
# * as published by the Free Software Foundation, either version 3 of
6+
# * the License, or (at your option) any later version.
7+
# * See the file LICENSE in the root directory of this distribution
8+
# * or <http://www.gnu.org/licenses/>.
9+
# ***********************************************************************************
10+
"""
11+
.. module:: compute
12+
:platform: Unix, Windows
13+
:synopsis: This module implements functions for doing exact on-the-grid dynamics with PyTorch
14+
List of functions:
15+
* sech # temporary here
16+
* Martens_model # temporary here
17+
* gaussian_wavepacket
18+
List of classes:
19+
* exact_tdse_solver
20+
21+
.. moduleauthor:: Alexey V. Akimov
22+
23+
"""
24+
25+
__author__ = "Alexey V. Akimov"
26+
__copyright__ = "Copyright 2025 Alexey V. Akimov"
27+
__credits__ = ["Alexey V. Akimov"]
28+
__license__ = "GNU-3"
29+
__version__ = "1.0"
30+
__maintainer__ = "Alexey V. Akimov"
31+
__email__ = "alexvakimov@gmail.com"
32+
__url__ = "https://github.com/Quantum-Dynamics-Hub/libra-code"
33+
34+
35+
36+
import torch
37+
import torch.fft
38+
39+
40+
def sech(x):
41+
return 1 / torch.cosh(x)
42+
43+
def Martens_model(q, params):
44+
"""
45+
q - Tensor(ndof)
46+
47+
Martens_model1 is just this one but with Vc = 0.0
48+
"""
49+
#params = {"Va": 0.00625, "Vb": 0.0106}
50+
Va = params.get("Va", 0.00625)
51+
Vb = params.get("Vb", 0.0106)
52+
Vc = params.get("Vc", 0.0)
53+
return Va * (sech(2.0*q[0]))**2 + 0.5 * Vb * (q[1] + Vc * (q[0]**2 - 1.0 ) )**2
54+
55+
56+
57+
def gaussian_wavepacket(q, params):
58+
"""
59+
q = tensor( [ndof, N_1, N_2, ... N_ndof] )
60+
61+
"""
62+
hbar = 1.0
63+
ndof = q.shape[0]
64+
sz = len(q.shape)
65+
66+
mass = torch.tensor(params.get("mass", [2000.0, 2000.0]) )
67+
omega = torch.tensor(params.get("omega", [0.004, 0.004]) )
68+
sigma = 1.0 / torch.sqrt( 2.0 * mass * omega )
69+
q0 = torch.tensor(params.get("q0", [-1.0, 0.0]) )
70+
p0 = torch.tensor(params.get("p0", [3.0, 0.0]) )
71+
72+
# Reshape q, p, sigma to be compatible in shape with q
73+
sigma = sigma.view(ndof, *[1]*(sz - 1) )
74+
q0 = q0.view(ndof, *[1]*(sz - 1) )
75+
p0 = p0.view(ndof, *[1]*(sz - 1) )
76+
77+
# Do the calculations:
78+
phase = 1j * torch.sum(p0 * q, dim=0, keepdim=False) / hbar
79+
envelope = torch.exp(-torch.sum( 0.25*(q - q0)**2 / sigma**2 , dim=0, keepdim=False) ) # because it is wavefunction
80+
norm = torch.prod(1.0 / ( (sigma * (2 * torch.pi)**0.5 )**0.5 ) , dim=0, keepdim=False)
81+
82+
return (norm * envelope * torch.exp(phase)) # this also reduces the first dimension of q
83+
84+
85+
86+
class exact_tdse_solver:
87+
def __init__(self, params):
88+
self.prefix = params.get("prefix", "exact-solution")
89+
self.grid_size = torch.tensor(params.get("grid_size", [4, 4]))
90+
self.ndim = len(self.grid_size)
91+
self.q_min = torch.tensor(params.get("q_min", [-10.0] * self.ndim))
92+
self.q_max = torch.tensor(params.get("q_max", [10.0] * self.ndim))
93+
self.save_every_n_steps = params.get("save_every_n_steps", 1)
94+
self.dt = params.get("dt", 0.01)
95+
self.nsteps = params.get("nsteps", 500)
96+
self.mass = torch.tensor(params.get("mass", [1.0] * self.ndim))
97+
self.potential_fn = params.get("potential_fn", None)
98+
self.potential_fn_params = params.get("potential_fn_params", None)
99+
self.psi0_fn = params.get("psi0_fn", None)
100+
self.psi0_fn_params = params.get("psi0_fn_params", None)
101+
self.psi = None
102+
self.psi_k = None
103+
self.prob_density = None
104+
self.device = params.get("device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
105+
self.hbar = 1.0
106+
self.psi_all = None
107+
self.time = []
108+
self.kinetic_energy = []
109+
self.potential_energy = []
110+
self.total_energy = []
111+
self.population_right = []
112+
self.norm = []
113+
114+
115+
def initialize_grids(self):
116+
print("Initializing grids")
117+
print("grid_size = ", self.grid_size)
118+
# Real-space grid
119+
q_axes = [torch.linspace(self.q_min[i], self.q_max[i], self.grid_size[i]) for i in range(self.ndim)]
120+
q_grids = torch.meshgrid(*q_axes, indexing="ij")
121+
self.dq = torch.tensor([q_axes[i][1] - q_axes[i][0] for i in range(self.ndim)])
122+
self.Q = torch.stack(q_grids)
123+
print("Q grid = ", self.Q.shape)
124+
print("dq = ", self.dq)
125+
126+
# Momentum-space grid
127+
self.dk = 2 * torch.pi / (self.grid_size * self.dq)
128+
k_axes = [torch.fft.fftshift(torch.arange(-self.grid_size[i] // 2, self.grid_size[i] // 2)) * self.dk[i] for i in range(self.ndim)]
129+
k_grids = torch.meshgrid(*k_axes, indexing="ij")
130+
self.K = torch.stack(k_grids)
131+
print("K grid = ", self.K.shape)
132+
print("dk = ", self.dk)
133+
134+
# Volume elements
135+
self.dV = self.dq.prod()
136+
self.dVk = self.dk.prod() #(self.dq / self.grid_size).prod()
137+
print("dV = ", self.dV)
138+
print("dVk = ", self.dVk)
139+
140+
size = [ int((self.nsteps - self.nsteps % self.save_every_n_steps ) / self.save_every_n_steps)+1 ]
141+
for i in range(self.ndim):
142+
size.append(self.grid_size[i])
143+
self.psi_all = torch.zeros(size, dtype=torch.cfloat)
144+
145+
def initialize_operators(self):
146+
view_shape = [self.ndim] + [1] * self.ndim
147+
self.T = 0.5 * torch.sum((self.hbar * self.K) ** 2 / self.mass.view(view_shape), dim=0)
148+
self.V = self.potential_fn(self.Q, self.potential_fn_params) if self.potential_fn else torch.zeros_like(self.T)
149+
self.psi = self.psi0_fn(self.Q, self.psi0_fn_params) if self.psi0_fn else torch.zeros_like(self.V)
150+
self.psi_k = torch.fft.fftn(self.psi) #, norm='forward')
151+
self.psi_all[0] = self.psi
152+
self.expV_half = torch.exp(-0.5j * self.V * self.dt / self.hbar)
153+
self.expT = torch.exp(-1j * self.T * self.dt / self.hbar)
154+
155+
def propagate(self):
156+
for step in range(self.nsteps):
157+
158+
if step % self.save_every_n_steps == 0:
159+
istep = int(step / self.save_every_n_steps)
160+
self.psi_all[istep] = self.psi
161+
self.prob_density = torch.abs(self.psi) ** 2
162+
self.psi_k = torch.fft.fftn(self.psi) # norm='forward')
163+
KE = torch.sum(torch.abs(self.psi_k) ** 2 * self.T) * (self.dq/self.grid_size).prod()
164+
PE = torch.sum(self.prob_density * self.V) * self.dV
165+
nrm = torch.sum(torch.abs(self.psi) ** 2 ) * self.dV
166+
x_coords = self.Q[0]
167+
right_mask = self.Q[0] > 0
168+
pop_right = torch.sum(self.prob_density[right_mask]) * self.dV
169+
170+
self.norm.append( nrm )
171+
self.time.append(step * self.dt)
172+
self.kinetic_energy.append( KE.real.item() )
173+
self.potential_energy.append( PE.real.item() )
174+
self.total_energy.append( KE + PE )
175+
self.population_right.append(pop_right.item())
176+
177+
print(f"Step {step}: KE = {KE:.4f}, PE = {PE:.4f}, Total = {KE + PE:.4f}, Norm = {nrm:.4f}")
178+
179+
#============== Propagate ===================
180+
self.psi *= self.expV_half
181+
self.psi_k = torch.fft.fftn(self.psi) # norm='forward')
182+
self.psi_k *= self.expT
183+
self.psi = torch.fft.ifftn(self.psi_k) # norm='forward')
184+
self.psi *= self.expV_half
185+
186+
def save(self):
187+
torch.save( {"grid_size":self.grid_size,
188+
"ndim":self.ndim,
189+
"q_min":self.q_min, "q_max":self.q_max,
190+
"save_every_n_steps": self.save_every_n_steps,
191+
"dt":self.dt, "nsteps":self.nsteps,
192+
"mass":self.mass,
193+
"psi":self.psi, "psi_k":self.psi_k,
194+
"prob_density":self.prob_density,
195+
"psi_all":self.psi_all,
196+
"time":self.time,
197+
"Q":self.Q, "K":self.K, "dq":self.dq, "dk":self.dk,
198+
"dV":self.dV, "dVk":self.dVk,
199+
"kinetic_energy":self.kinetic_energy,
200+
"potential_energy":self.potential_energy,
201+
"total_energy":self.total_energy,
202+
"population_right":self.population_right,
203+
"norm":self.norm,
204+
"V":self.V, "T":self.T
205+
}, F"{self.prefix}.pt" )
206+
207+
def solve(self):
208+
self.initialize_grids()
209+
self.initialize_operators()
210+
self.propagate()
211+
self.save()
212+
213+

0 commit comments

Comments
 (0)