Skip to content

Commit 9282be7

Browse files
committed
added test suite and validation
1 parent d392ddb commit 9282be7

7 files changed

Lines changed: 560 additions & 16 deletions

File tree

pdesolvers/main.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ def main():
1515
# solver1 = pde.Heat1DCNSolver(equation1)
1616
# solver2 = pde.Heat1DExplicitSolver(equation1)
1717

18-
#testing Heat 2d
18+
# testing 2d heat equation
19+
1920
xLength = 10 # Lx
2021
yLength = 10 # Ly
2122
maxTime = 0.5 # tmax
@@ -30,12 +31,13 @@ def main():
3031
equation.set_top_boundary_temp(lambda t, x: 20 + 5 * x * (xLength - x) * t**4)
3132
equation.set_bottom_boundary_temp(lambda t, x: 20)
3233

33-
solver = pde.Heat2DExplicitSolver(equation)
34-
solution = solver.solve()
35-
# Save temperature matrix to CSV file
36-
# print("Saving temperature matrix to file...")
37-
# np.savetxt("temperature_data_pkg.csv", solution.result.reshape(numPointsTime, -1), delimiter=",")
38-
solution.animate(export=True)
34+
solver1 = pde.Heat2DExplicitSolver(equation)
35+
solver1 = pde.Heat2DCNSolver(equation)
36+
solution1 = solver1.solve()
37+
solution1.animate(filename="Explicit")
38+
solver2 = pde.Heat2DCNSolver(equation)
39+
solution2 = solver2.solve()
40+
solution2.animate(filename="Crank-Nicolson")
3941

4042
# testing for monte carlo pricing
4143
# ticker = 'AAPL'

pdesolvers/pdes/heat_2d.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,17 @@
22

33
class HeatEquation2D:
44
def __init__(self, time, t_nodes, k, xlength, x_nodes, ylength=None, y_nodes=None):
5+
6+
assert time > 0, "Time must be positive"
7+
assert t_nodes > 1, "Number of time nodes must be greater than 1"
8+
assert k > 0, "Diffusivity constant k must be positive"
9+
assert xlength > 0, "X-length must be positive"
10+
assert x_nodes > 2, "Number of x nodes must be greater than 2"
11+
if ylength is not None:
12+
assert ylength > 0, "Y-length must be positive"
13+
if y_nodes is not None:
14+
assert y_nodes > 2, "Number of y nodes must be greater than 2"
15+
516
self.__time = time
617
self.__t_nodes = t_nodes
718
self.__k = k

pdesolvers/solution/solution.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def __plot_surface(u_k, k, ax, xDomain, yDomain, dt):
196196

197197
return surf
198198

199-
def animate(self, export=False):
199+
def animate(self, export=False, filename="heat_equation_2d_plot.gif"):
200200
print("Creating animation...")
201201
self
202202
fig = plt.figure(figsize=(12, 8))
@@ -208,7 +208,7 @@ def animateFrame(k):
208208
anim = FuncAnimation(fig, animateFrame, interval=100, frames=len(self.t_grid), repeat=True)
209209

210210
if export:
211-
anim.save("heat_equation_2d_plot.gif", writer='pillow', fps=5)
211+
anim.save(filename+'.gif', writer='pillow', fps=5)
212212

213213
plt.show()
214214

pdesolvers/solvers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .heat_solvers import Heat1DExplicitSolver, Heat1DCNSolver
2-
from .heat2d_solvers import Heat2DExplicitSolver
2+
from .heat2d_solvers import Heat2DExplicitSolver, Heat2DCNSolver
33
from .black_scholes_solvers import BlackScholesExplicitSolver, BlackScholesCNSolver

pdesolvers/solvers/heat2d_solvers.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ def solve(self):
2424
logging.info(f"Starting {self.__class__.__name__} with {self.equation.x_nodes+1} spatial nodes and {self.equation.t_nodes+1} time nodes.")
2525
start = time.perf_counter()
2626

27+
if (self.equation.left_boundary is None or
28+
self.equation.right_boundary is None or
29+
self.equation.top_boundary is None or
30+
self.equation.bottom_boundary is None):
31+
raise ValueError("All boundary conditions must be set before solving")
32+
if self.equation.initial_temp is None:
33+
raise ValueError("Initial condition must be set before solving")
34+
2735
x = np.linspace(0, self.equation.xlength, self.equation.x_nodes)
2836
y = np.linspace(0, self.equation.ylength, self.equation.y_nodes)
2937
t = np.linspace(0, self.equation.time, self.equation.t_nodes)
@@ -45,7 +53,7 @@ def solve(self):
4553
self.equation.initial_temp,
4654
x, y, t)
4755

48-
print("Calculating temperature evolution...")
56+
print(f"Calculating temperature evolution with {self.equation.t_nodes-1} iterations...", flush=True)
4957
for tau in range(self.equation.t_nodes-1):
5058
for i in range(1, self.equation.x_nodes-1):
5159
for j in range(1, self.equation.y_nodes-1):
@@ -69,13 +77,27 @@ def __init__(self, equation: heat.HeatEquation2D):
6977
def solve(self):
7078
logging.info(f"Starting {self.__class__.__name__} with {self.equation.x_nodes+1} spatial nodes and {self.equation.t_nodes+1} time nodes.")
7179
start = time.perf_counter()
80+
81+
if (self.equation.left_boundary is None or
82+
self.equation.right_boundary is None or
83+
self.equation.top_boundary is None or
84+
self.equation.bottom_boundary is None):
85+
raise ValueError("All boundary conditions must be set before solving")
86+
if self.equation.initial_temp is None:
87+
raise ValueError("Initial condition must be set before solving")
88+
7289
x = np.linspace(0, self.equation.xlength, self.equation.x_nodes)
7390
y = np.linspace(0, self.equation.ylength, self.equation.y_nodes)
7491
t = np.linspace(0, self.equation.time, self.equation.t_nodes)
7592

7693
dx = x[1]-x[0]
7794
dy = y[1]-y[0]
7895
dt = t[1]-t[0]
96+
c = self.equation.k * dt / 2
97+
cx = c / (dx**2)
98+
cy = c / (dy**2)
99+
alpha = 1 + 2*cx + 2*cy
100+
beta = 1 - 2*cx - 2*cy
79101

80102
print("Initializing matrix...")
81103
U = utility.Heat2DHelper.initMatrix(self.equation.t_nodes,
@@ -89,9 +111,33 @@ def solve(self):
89111
x, y, t)
90112

91113
# create sparse matrix
92-
93-
# time-stepping loop
114+
G, n_interior_x, n_interior_y = utility.Heat2DHelper.innitTriDiagMatrix(self.equation.x_nodes, self.equation.y_nodes, cx, cy, alpha)
94115

116+
# time-stepping loop
117+
print(f"Calculating temperature evolution with {self.equation.t_nodes-1} iterations...", flush=True)
118+
for tau in range(self.equation.t_nodes - 1):
119+
rhs = np.zeros(n_interior_x * n_interior_y)
120+
idx = 0
121+
for j in range(1, self.equation.y_nodes-1):
122+
for i in range(1, self.equation.x_nodes-1):
123+
# RHS = β*U_τ + cx*(neighbors_x) + cy*(neighbors_y) + boundary_terms
124+
rhs[idx] = beta * U[tau, i, j]
125+
rhs[idx] += cx * (U[tau, i-1, j] + U[tau, i+1, j])
126+
rhs[idx] += cy * (U[tau, i, j-1] + U[tau, i, j+1])
127+
# Boundary contributions
128+
if i == 1:
129+
rhs[idx] += cx * U[tau+1, 0, j]
130+
if i == self.equation.x_nodes-2:
131+
rhs[idx] += cx * U[tau+1, -1, j]
132+
if j == 1:
133+
rhs[idx] += cy * U[tau+1, i, 0]
134+
if j == self.equation.y_nodes-2:
135+
rhs[idx] += cy * U[tau+1, i, -1]
136+
idx += 1
137+
# Solve G*u_{τ+1} = rhs
138+
u_next_interior = spsolve(G, rhs)
139+
U[tau+1, 1:-1, 1:-1] = u_next_interior.reshape((n_interior_x, n_interior_y))
140+
95141
end = time.perf_counter()
96142
duration = end - start
97143
logging.info(f"Solver completed in {duration} seconds.")

0 commit comments

Comments
 (0)