Skip to content

Commit d7d13e0

Browse files
committed
[Iter 4] Code modification in PDE_D_Durotaxis.py
[Automated commit by Claude]
1 parent 771833b commit d7d13e0

1 file changed

Lines changed: 237 additions & 0 deletions

File tree

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
import torch
2+
import torch_geometric as pyg
3+
import torch_geometric.utils as pyg_utils
4+
from ParticleGraph.utils import to_numpy
5+
6+
class PDE_D_Durotaxis(pyg.nn.MessagePassing):
7+
"""
8+
Gradient-amplified mobility (durotaxis) particle model.
9+
10+
Particles move faster in regions with steep concentration gradients,
11+
analogous to cells migrating preferentially on stiffer substrates.
12+
13+
Literature:
14+
- Lo, C. M. et al. (2000) Biophysical Journal 79:144-152
15+
"Cell movement is guided by the rigidity of the substrate"
16+
- Isenberg, B. C. et al. (2009) Biophysical Journal 97:1313-1322
17+
"Vascular smooth muscle cell durotaxis"
18+
19+
Physics:
20+
M_effective = M * (1 + alpha * clamp(|grad_C1|, max=1.0))
21+
Gradient magnitude is clamped at 1.0 to prevent boundary artifacts.
22+
23+
Per-type params layout: [M1, M2, consumption, production, ar_p1, ar_p2, ar_p3, ar_p4]
24+
"""
25+
26+
PARAMS_DOC = {
27+
"model_name": "Durotaxis",
28+
"literature": "Lo et al. (2000) Biophysical Journal 79:144-152",
29+
"description": "Gradient-amplified mobility: particles move faster at steep gradients",
30+
"equations": {
31+
"field_to_particle": "v = M * (1 + alpha * clamp(|grad_C1|, max=1.0)) * (grad_C1 + grad_C2) * dir",
32+
"particle_to_field": "dC1 = -consumption * w(r), dC2 = production * w(r)",
33+
"particle_to_particle": "f = (p1*exp(-d^(2p2)/(2sigma^2)) - p3*exp(-d^(2p4)/(2sigma^2))) * dir"
34+
},
35+
"params_mesh": [
36+
{
37+
"row": 0, "description": "C1 field parameters (shared with mesh model)",
38+
"slots": [
39+
{"index": 0, "name": "D1", "description": "Diffusion coeff for C1 (mesh model)", "typical_range": [0.01, 0.5]},
40+
{"index": 1, "name": "Da_c", "description": "Damkohler number (mesh model)", "typical_range": [1.0, 50.0]},
41+
{"index": 2, "name": "A", "description": "Brusselator param A (mesh model)", "typical_range": [0.5, 5.0]},
42+
{"index": 3, "name": "B", "description": "Brusselator param B (mesh model)", "typical_range": [1.0, 10.0]},
43+
{"index": 4, "name": "mu", "description": "Morphological parameter (mesh model)", "typical_range": [0.01, 0.1]},
44+
{"index": 5, "name": "M1", "description": "Mobility coefficient for C1 gradients", "typical_range": [-16, 16]}
45+
]
46+
},
47+
{
48+
"row": 1, "description": "C2 field parameters + durotaxis control",
49+
"slots": [
50+
{"index": 0, "name": "D2", "description": "Diffusion coeff for C2 (mesh model)", "typical_range": [0.1, 1.0]},
51+
{"index": 1, "name": "M2", "description": "Mobility coefficient for C2 gradients", "typical_range": [-16, 16]},
52+
{"index": 2, "name": "unused", "description": "Unused (pad)", "typical_range": [0.0, 0.0]},
53+
{"index": 3, "name": "grad_amp_alpha", "description": "Durotaxis gradient amplification (0=off, >0=faster at steep gradients)", "typical_range": [0.0, 2.0]}
54+
]
55+
},
56+
{
57+
"row": 2, "description": "Particle-field coupling",
58+
"slots": [
59+
{"index": 0, "name": "Pe", "description": "Peclet number", "typical_range": [0.5, 2.0]},
60+
{"index": 1, "name": "consumption", "description": "Particle consumption rate of C1", "typical_range": [10, 200]},
61+
{"index": 2, "name": "production", "description": "Particle production rate of C2", "typical_range": [-200, -10]},
62+
{"index": 3, "name": "influence_radius", "description": "Gaussian influence radius for pf coupling", "typical_range": [0.01, 0.1]}
63+
]
64+
}
65+
],
66+
"particle_params": {
67+
"description": "Per-type params from simulation.params (one row per n_particle_types)",
68+
"slots": [
69+
{"index": 0, "name": "M1", "description": "Per-type mobility for C1"},
70+
{"index": 1, "name": "M2", "description": "Per-type mobility for C2"},
71+
{"index": 2, "name": "consumption", "description": "Per-type consumption rate"},
72+
{"index": 3, "name": "production", "description": "Per-type production rate"},
73+
{"index": 4, "name": "ar_p1", "description": "Attraction strength"},
74+
{"index": 5, "name": "ar_p2", "description": "Attraction exponent"},
75+
{"index": 6, "name": "ar_p3", "description": "Repulsion strength"},
76+
{"index": 7, "name": "ar_p4", "description": "Repulsion exponent"}
77+
]
78+
}
79+
}
80+
81+
def __init__(self, aggr_type='mean', p=None, particle_params=None, bc_dpos=None, dimension=2, sigma=0.005):
82+
super(PDE_D_Durotaxis, self).__init__(aggr=aggr_type)
83+
84+
self.p = p
85+
self.particle_params = particle_params
86+
self.bc_dpos = bc_dpos
87+
self.dimension = dimension
88+
self.sigma = sigma
89+
90+
self.M1 = p[0, 5]
91+
self.M2 = p[1, 1]
92+
self.consumption_rate = p[2, 1]
93+
self.production_rate = p[2, 2]
94+
self.influence_radius = p[2, 3]
95+
self.Pe = p[2, 0]
96+
self.repulsion_strength = 50
97+
self.repulsion_range = 0.04
98+
99+
self.grad_amp_alpha = p[1, 3] if p.shape[1] > 3 else 0.0
100+
101+
print(f"initialized PDE_D_Durotaxis with parameters:")
102+
print(f" mobility: M1={self.M1.item()}, M2={self.M2.item()}")
103+
ga_val = self.grad_amp_alpha.item() if hasattr(self.grad_amp_alpha, 'item') else self.grad_amp_alpha
104+
print(f" grad_amp_alpha={ga_val:.3f} (M_eff = M*(1+alpha*clamp(|gradC|,max=1.0)), Lo 2000)")
105+
print(f" Pe={self.Pe.item():.3f}, sigma={self.sigma}")
106+
print(f" particle->field: consumption={self.consumption_rate.item()}, production={self.production_rate.item()}, influence_radius={self.influence_radius.item():.3f}")
107+
if particle_params is not None:
108+
print(f" multi-type support: {particle_params.shape[0]} particle types")
109+
110+
def forward(self, data, direction='fp'):
111+
x, edge_index = data.x, data.edge_index
112+
edge_index, _ = pyg_utils.remove_self_loops(edge_index)
113+
114+
if self.particle_params is not None:
115+
particle_type = x[:, 1 + 2*self.dimension].long()
116+
max_type = particle_type.max().item()
117+
n_param_rows = self.particle_params.shape[0]
118+
if max_type >= n_param_rows:
119+
raise ValueError(
120+
f"PDE_D_Durotaxis: particle_params has {n_param_rows} rows but found "
121+
f"particle type {max_type}. Need {max_type + 1} rows in simulation.params."
122+
)
123+
parameters = self.particle_params[to_numpy(particle_type), :]
124+
else:
125+
parameters = None
126+
127+
if direction == 'interpolate':
128+
result = self.propagate(edge_index, x=x, mode='interpolate', parameters=parameters)
129+
pos = x[:, 1:self.dimension+1]
130+
in_box = ((pos >= 0) & (pos <= 1)).all(dim=1, keepdim=True)
131+
result = result * in_box.float()
132+
return result
133+
elif direction == 'fp':
134+
result = self.propagate(edge_index, x=x, mode='fp', parameters=parameters)
135+
pos = x[:, 1:self.dimension+1]
136+
in_box = ((pos >= 0) & (pos <= 1)).all(dim=1, keepdim=True)
137+
result = result * in_box.float()
138+
return result
139+
elif direction == 'pf':
140+
result = self.propagate(edge_index, x=x, mode='pf', parameters=parameters)
141+
return result
142+
else:
143+
result = self.propagate(edge_index, x=x, mode='pp', parameters=parameters)
144+
return result
145+
146+
def message(self, edge_index_i, edge_index_j, x_i, x_j, mode=None, parameters_i=None):
147+
pos_i = x_i[:, 1:self.dimension+1]
148+
pos_j = x_j[:, 1:self.dimension+1]
149+
150+
d_pos = self.bc_dpos(pos_j - pos_i)
151+
dist = torch.sqrt(torch.sum(d_pos**2, dim=1))
152+
dist_safe = torch.clamp(dist, min=1e-6)
153+
154+
if mode == 'interpolate':
155+
C1_mesh = x_j[:, 6:7]
156+
C2_mesh = x_j[:, 7:8]
157+
weight = torch.exp(-dist / 0.01).unsqueeze(1)
158+
return torch.cat([C1_mesh * weight, C2_mesh * weight, weight], dim=1)
159+
160+
elif mode == 'fp':
161+
fields_i = x_i[:, 6:8]
162+
fields_j = x_j[:, 6:8]
163+
164+
dC1 = fields_j[:, 0:1] - fields_i[:, 0:1]
165+
dC2 = fields_j[:, 1:2] - fields_i[:, 1:2]
166+
167+
kernel = torch.exp(-dist / 0.05)
168+
dir_norm = d_pos / dist_safe.unsqueeze(1)
169+
domain_scale = 32.0
170+
grad_C1 = (dC1 * kernel.unsqueeze(1)) / (dist_safe.unsqueeze(1) * domain_scale)
171+
grad_C2 = (dC2 * kernel.unsqueeze(1)) / (dist_safe.unsqueeze(1) * domain_scale)
172+
173+
if parameters_i is not None:
174+
M1 = parameters_i[:, 0:1]
175+
M2 = parameters_i[:, 1:2]
176+
else:
177+
M1 = self.M1
178+
M2 = self.M2
179+
180+
velocity_raw = (M1 * grad_C1 + M2 * grad_C2) * dir_norm
181+
182+
# Durotaxis: amplify velocity at steep gradients
183+
if self.grad_amp_alpha > 0:
184+
grad_mag = torch.abs(grad_C1)
185+
grad_mag_clamped = torch.clamp(grad_mag, max=1.0)
186+
amp_factor = 1.0 + self.grad_amp_alpha * grad_mag_clamped
187+
velocity_raw = velocity_raw * amp_factor
188+
189+
return velocity_raw
190+
191+
elif mode == 'pf':
192+
weights = torch.exp(-dist**2 / (2 * (self.influence_radius/3)**2))
193+
194+
if parameters_i is not None:
195+
consumption = parameters_i[:, 2]
196+
production = parameters_i[:, 3]
197+
else:
198+
consumption = self.consumption_rate
199+
production = self.production_rate
200+
201+
field_updates = torch.zeros((pos_i.size(0), 2), device=pos_i.device)
202+
field_updates[:, 0] = -consumption * weights
203+
field_updates[:, 1] = production * weights
204+
return field_updates
205+
206+
else: # mode == 'pp'
207+
if parameters_i is not None:
208+
p1 = parameters_i[:, 4]
209+
p2 = parameters_i[:, 5]
210+
p3 = parameters_i[:, 6]
211+
p4 = parameters_i[:, 7]
212+
213+
f = (p1 * torch.exp(-dist ** (2 * p2) / (2 * self.sigma ** 2))
214+
- p3 * torch.exp(-dist ** (2 * p4) / (2 * self.sigma ** 2)))
215+
216+
forces = f[:, None] * d_pos / dist_safe.unsqueeze(1)
217+
else:
218+
forces = torch.zeros_like(pos_i)
219+
in_range = dist < self.repulsion_range
220+
if in_range.any():
221+
dir_norm = d_pos / dist_safe.unsqueeze(1)
222+
repulsion_mag = self.repulsion_strength * torch.exp(
223+
-5.0 * dist[in_range] / self.repulsion_range
224+
)
225+
forces[in_range] = -dir_norm[in_range] * repulsion_mag.unsqueeze(1)
226+
227+
return forces
228+
229+
def update(self, aggr_out, mode=None):
230+
if mode == 'interpolate':
231+
C1_weighted = aggr_out[:, 0:1]
232+
C2_weighted = aggr_out[:, 1:2]
233+
weight_sum = aggr_out[:, 2:3]
234+
weight_sum = torch.clamp(weight_sum, min=1e-10)
235+
return torch.cat([C1_weighted / weight_sum, C2_weighted / weight_sum], dim=1)
236+
else:
237+
return aggr_out

0 commit comments

Comments
 (0)