Skip to content

Commit 9290105

Browse files
committed
[Iter 16] Code modification in PDE_D_DurotaxisThreshold.py
[Automated commit by Claude]
1 parent fa0d149 commit 9290105

1 file changed

Lines changed: 288 additions & 0 deletions

File tree

Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
import torch
2+
import torch_geometric as pyg
3+
import torch_geometric.utils as pyg_utils
4+
from ParticleGraph.utils import to_numpy
5+
6+
class PDE_D_DurotaxisThreshold(pyg.nn.MessagePassing):
7+
"""
8+
Hybrid particle model combining durotaxis gradient amplification with
9+
concentration-threshold bistable coupling.
10+
11+
This model combines two mechanisms:
12+
1. Durotaxis (Lo et al. 2000): Particles move faster at steep gradients
13+
M_eff = M * (1 + alpha * clamp(|grad_C1|, max=1.0))
14+
2. ThresholdCoupling (Wolpert 1969): Bistable response reverses mobility
15+
sign at concentration threshold T
16+
sign_factor = -tanh(steepness * (C1 - T) / A)
17+
18+
The combination enables rich morphological dynamics (from gradient
19+
amplification) while maintaining convergent stability (from threshold
20+
coupling). Particles are attracted toward the C1=T isoline with
21+
velocity that increases in steep-gradient regions.
22+
23+
Literature:
24+
- Lo, C. M. et al. (2000) Biophysical Journal 79:144-152
25+
"Cell movement is guided by the rigidity of the substrate"
26+
- Wolpert, L. (1969) J Theor Biol 25:1-47
27+
"Positional information and the spatial pattern of cellular differentiation"
28+
- Isenberg et al. (2009) Biophys J 97:1313-1322
29+
"Vascular smooth muscle cell durotaxis"
30+
31+
Per-type params layout: [M1, M2, consumption, production, ar_p1, ar_p2, ar_p3, ar_p4]
32+
"""
33+
34+
PARAMS_DOC = {
35+
"model_name": "DurotaxisThreshold",
36+
"literature": "Lo et al. (2000) Biophys J 79:144; Wolpert (1969) J Theor Biol 25:1",
37+
"description": "Gradient-amplified mobility + bistable concentration-threshold coupling",
38+
"equations": {
39+
"field_to_particle": "v = M * (1+alpha*clamp(|gradC1|,max=1)) * (-tanh(3*(C1-T)/A)) * (grad_C1+grad_C2) * dir",
40+
"particle_to_field": "dC1 = -consumption * w(r), dC2 = production * w(r)",
41+
"particle_to_particle": "f = (p1*exp(-d^(2p2)/(2sigma^2)) - p3*exp(-d^(2p4)/(2sigma^2))) * dir"
42+
},
43+
"params_mesh": [
44+
{
45+
"row": 0, "description": "C1 field parameters + CTC threshold",
46+
"slots": [
47+
{"index": 0, "name": "D1", "description": "Diffusion coeff for C1 (mesh model)", "typical_range": [0.01, 0.5]},
48+
{"index": 1, "name": "Da_c", "description": "Damkohler number (mesh model)", "typical_range": [1.0, 50.0]},
49+
{"index": 2, "name": "A", "description": "Brusselator param A (mesh model, also CTC reference)", "typical_range": [0.5, 5.0]},
50+
{"index": 3, "name": "B", "description": "Brusselator param B (mesh model)", "typical_range": [1.0, 10.0]},
51+
{"index": 4, "name": "mu", "description": "Morphological parameter (mesh model)", "typical_range": [0.01, 0.1]},
52+
{"index": 5, "name": "M1", "description": "Mobility coefficient for C1 gradients", "typical_range": [-16, 16]},
53+
{"index": 6, "name": "grad_amp_alpha", "description": "Durotaxis gradient amplification (0=off, >0=faster at steep gradients)", "typical_range": [0.0, 2.0]},
54+
{"index": 7, "name": "ctc_threshold", "description": "CTC threshold (T=ctc*A; reversal at C1=T)", "typical_range": [0.5, 3.0]}
55+
]
56+
},
57+
{
58+
"row": 1, "description": "C2 field parameters",
59+
"slots": [
60+
{"index": 0, "name": "D2", "description": "Diffusion coeff for C2 (mesh model)", "typical_range": [0.1, 1.0]},
61+
{"index": 1, "name": "M2", "description": "Mobility coefficient for C2 gradients", "typical_range": [-16, 16]}
62+
]
63+
},
64+
{
65+
"row": 2, "description": "Particle-field coupling + per-type threshold spread",
66+
"slots": [
67+
{"index": 0, "name": "Pe", "description": "Peclet number", "typical_range": [0.5, 2.0]},
68+
{"index": 1, "name": "consumption", "description": "Particle consumption rate of C1", "typical_range": [10, 200]},
69+
{"index": 2, "name": "production", "description": "Particle production rate of C2", "typical_range": [-200, -10]},
70+
{"index": 3, "name": "influence_radius", "description": "Gaussian influence radius for pf coupling", "typical_range": [0.01, 0.1]},
71+
{"index": 4, "name": "unused", "description": "Unused (pad)", "typical_range": [0.0, 0.0]},
72+
{"index": 5, "name": "cross_type_factor", "description": "Per-type CTC threshold spread (0=same threshold, 0.3=+-30% spread)", "typical_range": [0.0, 0.5]}
73+
]
74+
}
75+
],
76+
"width_constraint": "ALL rows of params_mesh MUST have same number of columns (8). Pad shorter rows.",
77+
"particle_params": {
78+
"description": "Per-type params from simulation.params (one row per n_particle_types)",
79+
"slots": [
80+
{"index": 0, "name": "M1", "description": "Per-type mobility for C1"},
81+
{"index": 1, "name": "M2", "description": "Per-type mobility for C2"},
82+
{"index": 2, "name": "consumption", "description": "Per-type consumption rate"},
83+
{"index": 3, "name": "production", "description": "Per-type production rate"},
84+
{"index": 4, "name": "ar_p1", "description": "Attraction strength"},
85+
{"index": 5, "name": "ar_p2", "description": "Attraction exponent"},
86+
{"index": 6, "name": "ar_p3", "description": "Repulsion strength"},
87+
{"index": 7, "name": "ar_p4", "description": "Repulsion exponent"}
88+
]
89+
}
90+
}
91+
92+
def __init__(self, aggr_type='mean', p=None, particle_params=None, bc_dpos=None, dimension=2, sigma=0.005):
93+
super(PDE_D_DurotaxisThreshold, self).__init__(aggr=aggr_type)
94+
95+
self.p = p
96+
self.particle_params = particle_params
97+
self.bc_dpos = bc_dpos
98+
self.dimension = dimension
99+
self.sigma = sigma
100+
101+
self.M1 = p[0, 5]
102+
self.M2 = p[1, 1]
103+
self.consumption_rate = p[2, 1]
104+
self.production_rate = p[2, 2]
105+
self.influence_radius = p[2, 3]
106+
self.Pe = p[2, 0]
107+
self.repulsion_strength = 50
108+
self.repulsion_range = 0.04
109+
110+
# Durotaxis gradient amplification
111+
self.grad_amp_alpha = p[0, 6] if p.shape[1] > 6 else 0.0
112+
113+
# CTC threshold
114+
self.ctc_threshold = p[0, 7] if p.shape[1] > 7 else 0.0
115+
self.A_ref = p[0, 2]
116+
117+
# Per-type threshold spread
118+
self.cross_type_factor = p[2, 5] if p.shape[1] > 5 else 0.0
119+
120+
print(f"initialized PDE_D_DurotaxisThreshold with parameters:")
121+
print(f" mobility: M1={self.M1.item()}, M2={self.M2.item()}")
122+
ga_val = self.grad_amp_alpha.item() if hasattr(self.grad_amp_alpha, 'item') else self.grad_amp_alpha
123+
print(f" grad_amp_alpha={ga_val:.3f} (durotaxis: M_eff=M*(1+alpha*|gradC|), Lo 2000)")
124+
ctc_val = self.ctc_threshold.item() if hasattr(self.ctc_threshold, 'item') else self.ctc_threshold
125+
T_val = ctc_val * self.A_ref.item()
126+
print(f" ctc_threshold={ctc_val:.3f} (T={T_val:.2f}, reversal at C1=T*A, Wolpert 1969)")
127+
ctf_val = self.cross_type_factor.item() if hasattr(self.cross_type_factor, 'item') else self.cross_type_factor
128+
if ctf_val > 0 and particle_params is not None:
129+
n_types = particle_params.shape[0]
130+
mean_idx = (n_types - 1) / 2.0
131+
for t in range(n_types):
132+
t_offset = ctf_val * (t - mean_idx)
133+
t_val = T_val * (1.0 + t_offset)
134+
print(f" Type {t}: CTC threshold = {t_val:.2f} (offset={t_offset:+.2f})")
135+
print(f" Pe={self.Pe.item():.3f}, sigma={self.sigma}")
136+
print(f" particle->field: consumption={self.consumption_rate.item()}, production={self.production_rate.item()}, influence_radius={self.influence_radius.item():.3f}")
137+
if particle_params is not None:
138+
print(f" multi-type support: {particle_params.shape[0]} particle types")
139+
140+
def forward(self, data, direction='fp'):
141+
x, edge_index = data.x, data.edge_index
142+
edge_index, _ = pyg_utils.remove_self_loops(edge_index)
143+
144+
if self.particle_params is not None:
145+
particle_type = x[:, 1 + 2*self.dimension].long()
146+
max_type = particle_type.max().item()
147+
n_param_rows = self.particle_params.shape[0]
148+
if max_type >= n_param_rows:
149+
raise ValueError(
150+
f"PDE_D_DurotaxisThreshold: particle_params has {n_param_rows} rows but found "
151+
f"particle type {max_type}. Need {max_type + 1} rows in simulation.params."
152+
)
153+
parameters = self.particle_params[to_numpy(particle_type), :]
154+
else:
155+
parameters = None
156+
157+
if direction == 'interpolate':
158+
result = self.propagate(edge_index, x=x, mode='interpolate', parameters=parameters)
159+
pos = x[:, 1:self.dimension+1]
160+
in_box = ((pos >= 0) & (pos <= 1)).all(dim=1, keepdim=True)
161+
result = result * in_box.float()
162+
return result
163+
elif direction == 'fp':
164+
result = self.propagate(edge_index, x=x, mode='fp', parameters=parameters)
165+
pos = x[:, 1:self.dimension+1]
166+
in_box = ((pos >= 0) & (pos <= 1)).all(dim=1, keepdim=True)
167+
result = result * in_box.float()
168+
return result
169+
elif direction == 'pf':
170+
result = self.propagate(edge_index, x=x, mode='pf', parameters=parameters)
171+
return result
172+
else:
173+
result = self.propagate(edge_index, x=x, mode='pp', parameters=parameters)
174+
return result
175+
176+
def message(self, edge_index_i, edge_index_j, x_i, x_j, mode=None, parameters_i=None):
177+
pos_i = x_i[:, 1:self.dimension+1]
178+
pos_j = x_j[:, 1:self.dimension+1]
179+
180+
d_pos = self.bc_dpos(pos_j - pos_i)
181+
dist = torch.sqrt(torch.sum(d_pos**2, dim=1))
182+
dist_safe = torch.clamp(dist, min=1e-6)
183+
184+
if mode == 'interpolate':
185+
C1_mesh = x_j[:, 6:7]
186+
C2_mesh = x_j[:, 7:8]
187+
weight = torch.exp(-dist / 0.01).unsqueeze(1)
188+
return torch.cat([C1_mesh * weight, C2_mesh * weight, weight], dim=1)
189+
190+
elif mode == 'fp':
191+
fields_i = x_i[:, 6:8]
192+
fields_j = x_j[:, 6:8]
193+
194+
dC1 = fields_j[:, 0:1] - fields_i[:, 0:1]
195+
dC2 = fields_j[:, 1:2] - fields_i[:, 1:2]
196+
197+
kernel = torch.exp(-dist / 0.05)
198+
dir_norm = d_pos / dist_safe.unsqueeze(1)
199+
domain_scale = 32.0
200+
grad_C1 = (dC1 * kernel.unsqueeze(1)) / (dist_safe.unsqueeze(1) * domain_scale)
201+
grad_C2 = (dC2 * kernel.unsqueeze(1)) / (dist_safe.unsqueeze(1) * domain_scale)
202+
203+
if parameters_i is not None:
204+
M1 = parameters_i[:, 0:1]
205+
M2 = parameters_i[:, 1:2]
206+
else:
207+
M1 = self.M1
208+
M2 = self.M2
209+
210+
velocity_raw = (M1 * grad_C1 + M2 * grad_C2) * dir_norm
211+
212+
# 1. Durotaxis: amplify velocity at steep gradients (Lo et al. 2000)
213+
if self.grad_amp_alpha > 0:
214+
grad_mag = torch.abs(grad_C1)
215+
grad_mag_clamped = torch.clamp(grad_mag, max=1.0)
216+
amp_factor = 1.0 + self.grad_amp_alpha * grad_mag_clamped
217+
velocity_raw = velocity_raw * amp_factor
218+
219+
# 2. Concentration-threshold coupling (Wolpert 1969)
220+
if self.ctc_threshold > 0:
221+
C1_local = fields_i[:, 0:1]
222+
A_ref = self.A_ref
223+
base_T = self.ctc_threshold * A_ref
224+
steepness = 3.0
225+
226+
# Per-type thresholds when multi-type + cross_type_factor > 0
227+
if (parameters_i is not None and self.cross_type_factor > 0
228+
and x_i.numel() > 0):
229+
type_i = x_i[:, 1 + 2*self.dimension].long()
230+
n_types = type_i.max().item() + 1 if type_i.numel() > 0 else 1
231+
mean_idx = (n_types - 1) / 2.0
232+
type_offset = self.cross_type_factor * (type_i.float() - mean_idx)
233+
T = base_T * (1.0 + type_offset.unsqueeze(1))
234+
else:
235+
T = base_T
236+
237+
sign_factor = -torch.tanh(steepness * (C1_local - T) / (A_ref + 1e-6))
238+
velocity_raw = velocity_raw * sign_factor
239+
240+
return velocity_raw
241+
242+
elif mode == 'pf':
243+
weights = torch.exp(-dist**2 / (2 * (self.influence_radius/3)**2))
244+
245+
if parameters_i is not None:
246+
consumption = parameters_i[:, 2]
247+
production = parameters_i[:, 3]
248+
else:
249+
consumption = self.consumption_rate
250+
production = self.production_rate
251+
252+
field_updates = torch.zeros((pos_i.size(0), 2), device=pos_i.device)
253+
field_updates[:, 0] = -consumption * weights
254+
field_updates[:, 1] = production * weights
255+
return field_updates
256+
257+
else: # mode == 'pp'
258+
if parameters_i is not None:
259+
p1 = parameters_i[:, 4]
260+
p2 = parameters_i[:, 5]
261+
p3 = parameters_i[:, 6]
262+
p4 = parameters_i[:, 7]
263+
264+
f = (p1 * torch.exp(-dist ** (2 * p2) / (2 * self.sigma ** 2))
265+
- p3 * torch.exp(-dist ** (2 * p4) / (2 * self.sigma ** 2)))
266+
267+
forces = f[:, None] * d_pos / dist_safe.unsqueeze(1)
268+
else:
269+
forces = torch.zeros_like(pos_i)
270+
in_range = dist < self.repulsion_range
271+
if in_range.any():
272+
dir_norm = d_pos / dist_safe.unsqueeze(1)
273+
repulsion_mag = self.repulsion_strength * torch.exp(
274+
-5.0 * dist[in_range] / self.repulsion_range
275+
)
276+
forces[in_range] = -dir_norm[in_range] * repulsion_mag.unsqueeze(1)
277+
278+
return forces
279+
280+
def update(self, aggr_out, mode=None):
281+
if mode == 'interpolate':
282+
C1_weighted = aggr_out[:, 0:1]
283+
C2_weighted = aggr_out[:, 1:2]
284+
weight_sum = aggr_out[:, 2:3]
285+
weight_sum = torch.clamp(weight_sum, min=1e-10)
286+
return torch.cat([C1_weighted / weight_sum, C2_weighted / weight_sum], dim=1)
287+
else:
288+
return aggr_out

0 commit comments

Comments
 (0)