Skip to content

Commit ff3d35d

Browse files
committed
[Iter 328] Code modification in PDE_D_AdaptivePF.py
[Automated commit by Claude]
1 parent 23dd6f4 commit ff3d35d

1 file changed

Lines changed: 363 additions & 0 deletions

File tree

Lines changed: 363 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,363 @@
1+
import torch
2+
import torch_geometric as pyg
3+
import torch_geometric.utils as pyg_utils
4+
from ParticleGraph.utils import to_numpy
5+
6+
class PDE_D_AdaptivePF(pyg.nn.MessagePassing):
7+
"""
8+
CTC with adaptive particle-to-field coupling + pp damping + fp drag.
9+
10+
Extends DragDampedCTC by making the particle-to-field (pf) coupling
11+
dependent on how close the particle is to its CTC threshold. When a
12+
particle is near T (|C1-T| < width*A), its consumption and production
13+
rates are attenuated by a Gaussian factor. This creates a self-dampening
14+
negative feedback loop:
15+
16+
1. Particle near T reduces its field perturbation
17+
2. Local field stabilizes (less oscillation in C1)
18+
3. CTC sign factor stabilizes
19+
4. Particle velocity converges
20+
21+
Physical motivation: metabolic regulation in cells. Cells that have
22+
reached their target position (positional identity established) reduce
23+
metabolic activity. This is observed in morphogen-responsive cells that
24+
downregulate ligand processing after fate commitment.
25+
26+
pf coupling modulation:
27+
consumption_eff = consumption * (1 - pf_adapt * exp(-(C1-T)^2 / (2*(pf_width*A)^2)))
28+
production_eff = production * (1 - pf_adapt * exp(-(C1-T)^2 / (2*(pf_width*A)^2)))
29+
30+
Also includes all DragDampedCTC features: durotaxis, CTC, pp damping, fp drag.
31+
32+
Literature:
33+
- Dessaud, E. et al. (2008) Development 135:2903-2913
34+
"Dynamic assignment and maintenance of positional identity in the
35+
ventral neural tube by the morphogen sonic hedgehog"
36+
- Wolpert, L. (1969) J Theor Biol 25:1-47
37+
- Lo, C. M. et al. (2000) Biophysical Journal 79:144-152
38+
- Painter, K. J. & Hillen, T. (2002) Can Appl Math Q 10(4):501-543
39+
- Tranquillo, R. T. & Lauffenburger, D. A. (1987) J Math Biol 25:229-262
40+
41+
Per-type params layout: [M1, M2, consumption, production, ar_p1, ar_p2, ar_p3, ar_p4]
42+
"""
43+
44+
PARAMS_DOC = {
45+
"model_name": "AdaptivePF",
46+
"literature": "Dessaud (2008) Development 135:2903; Wolpert (1969); Lo (2000); Painter & Hillen (2002); Tranquillo (1987)",
47+
"description": "CTC + adaptive pf coupling (reduced near threshold) + pp damping + fp drag",
48+
"equations": {
49+
"field_to_particle": "v = M*(1+alpha*|gradC1|)*(-tanh(3*(C1-T)/A))*grad*dir / (1+fp_drag*|vel|/v_ref)",
50+
"particle_to_field": "dC1 = -consumption*(1-pf_adapt*G(C1,T)) * w(r), dC2 = production*(1-pf_adapt*G(C1,T)) * w(r)",
51+
"pf_adaptation": "G(C1,T) = exp(-(C1-T)^2 / (2*(pf_width*A)^2))",
52+
"particle_to_particle": "f = f_AR * (1 - damping * exp(-(C1_i - T)^2 / (2*width^2)))"
53+
},
54+
"params_mesh": [
55+
{
56+
"row": 0, "description": "C1 field parameters + CTC threshold",
57+
"slots": [
58+
{"index": 0, "name": "D1", "description": "Diffusion coeff for C1"},
59+
{"index": 1, "name": "Da_c", "description": "Damkohler number"},
60+
{"index": 2, "name": "A", "description": "Brusselator A"},
61+
{"index": 3, "name": "B", "description": "Brusselator B"},
62+
{"index": 4, "name": "mu", "description": "Morphological param"},
63+
{"index": 5, "name": "M1", "description": "Mobility for C1 gradients"},
64+
{"index": 6, "name": "grad_amp_alpha", "description": "Durotaxis amplification"},
65+
{"index": 7, "name": "ctc_threshold", "description": "CTC threshold (T=ctc*A)"}
66+
]
67+
},
68+
{
69+
"row": 1, "description": "C2 field + pp damping + pf adaptation params",
70+
"slots": [
71+
{"index": 0, "name": "D2", "description": "Diffusion coeff for C2"},
72+
{"index": 1, "name": "M2", "description": "Mobility for C2 gradients"},
73+
{"index": 2, "name": "pp_damping", "description": "pp damping strength near T"},
74+
{"index": 3, "name": "pp_damping_width", "description": "Width of pp damping zone"},
75+
{"index": 4, "name": "pf_adapt", "description": "pf adaptation strength (0=off, 0.5=half, 0.9=strong)"},
76+
{"index": 5, "name": "pf_adapt_width", "description": "Width of pf adaptation zone (units of A)"}
77+
]
78+
},
79+
{
80+
"row": 2, "description": "Particle-field coupling + fp drag",
81+
"slots": [
82+
{"index": 0, "name": "Pe", "description": "Peclet number"},
83+
{"index": 1, "name": "consumption", "description": "Consumption rate of C1"},
84+
{"index": 2, "name": "production", "description": "Production rate of C2"},
85+
{"index": 3, "name": "influence_radius", "description": "Gaussian pf influence radius"},
86+
{"index": 4, "name": "fp_drag", "description": "Velocity-dependent fp drag"},
87+
{"index": 5, "name": "cross_type_factor", "description": "Per-type CTC threshold spread"}
88+
]
89+
}
90+
],
91+
"width_constraint": "ALL rows of params_mesh MUST have same number of columns (8). Pad shorter rows with 0.0."
92+
}
93+
94+
def __init__(self, aggr_type='mean', p=None, particle_params=None, bc_dpos=None, dimension=2, sigma=0.005):
95+
super(PDE_D_AdaptivePF, self).__init__(aggr=aggr_type)
96+
97+
self.p = p
98+
self.particle_params = particle_params
99+
self.bc_dpos = bc_dpos
100+
self.dimension = dimension
101+
self.sigma = sigma
102+
103+
self.M1 = p[0, 5]
104+
self.M2 = p[1, 1]
105+
self.consumption_rate = p[2, 1]
106+
self.production_rate = p[2, 2]
107+
self.influence_radius = p[2, 3]
108+
self.Pe = p[2, 0]
109+
self.repulsion_strength = 50
110+
self.repulsion_range = 0.04
111+
112+
# Durotaxis gradient amplification
113+
self.grad_amp_alpha = p[0, 6] if p.shape[1] > 6 else 0.0
114+
115+
# CTC threshold
116+
self.ctc_threshold = p[0, 7] if p.shape[1] > 7 else 0.0
117+
self.A_ref = p[0, 2]
118+
119+
# Per-type threshold spread
120+
self.cross_type_factor = p[2, 5] if p.shape[1] > 5 else 0.0
121+
122+
# pp damping parameters (Painter & Hillen 2002)
123+
self.pp_damping = p[1, 2] if p.shape[1] > 2 else 0.0
124+
self.pp_damping_width = p[1, 3] if p.shape[1] > 3 else 0.5
125+
126+
# pf adaptation parameters (Dessaud 2008)
127+
self.pf_adapt = p[1, 4] if p.shape[1] > 4 else 0.0
128+
self.pf_adapt_width = p[1, 5] if p.shape[1] > 5 else 0.5
129+
130+
# Velocity-dependent fp drag (Tranquillo & Lauffenburger 1987)
131+
self.fp_drag = p[2, 4] if p.shape[1] > 4 else 0.0
132+
self.v_ref = 0.01
133+
134+
print(f"initialized PDE_D_AdaptivePF with parameters:")
135+
print(f" mobility: M1={self.M1.item()}, M2={self.M2.item()}")
136+
ga_val = self.grad_amp_alpha.item() if hasattr(self.grad_amp_alpha, 'item') else self.grad_amp_alpha
137+
print(f" grad_amp_alpha={ga_val:.3f} (durotaxis, Lo 2000)")
138+
ctc_val = self.ctc_threshold.item() if hasattr(self.ctc_threshold, 'item') else self.ctc_threshold
139+
T_val = ctc_val * self.A_ref.item()
140+
print(f" ctc_threshold={ctc_val:.3f} (T={T_val:.2f}, Wolpert 1969)")
141+
damp_val = self.pp_damping.item() if hasattr(self.pp_damping, 'item') else self.pp_damping
142+
damp_w = self.pp_damping_width.item() if hasattr(self.pp_damping_width, 'item') else self.pp_damping_width
143+
print(f" pp_damping={damp_val:.3f}, pp_damping_width={damp_w:.3f} (Painter & Hillen 2002)")
144+
pf_adapt_val = self.pf_adapt.item() if hasattr(self.pf_adapt, 'item') else self.pf_adapt
145+
pf_adapt_w = self.pf_adapt_width.item() if hasattr(self.pf_adapt_width, 'item') else self.pf_adapt_width
146+
print(f" ADAPTIVE PF: pf_adapt={pf_adapt_val:.3f}, pf_adapt_width={pf_adapt_w:.3f} (Dessaud 2008)")
147+
print(f" Particles near T reduce consumption/production by {pf_adapt_val*100:.0f}%")
148+
fp_drag_val = self.fp_drag.item() if hasattr(self.fp_drag, 'item') else self.fp_drag
149+
print(f" fp_drag={fp_drag_val:.3f}, v_ref={self.v_ref:.4f} (Tranquillo 1987)")
150+
ctf_val = self.cross_type_factor.item() if hasattr(self.cross_type_factor, 'item') else self.cross_type_factor
151+
if ctf_val > 0 and particle_params is not None:
152+
n_types = particle_params.shape[0]
153+
mean_idx = (n_types - 1) / 2.0
154+
for t in range(n_types):
155+
t_offset = ctf_val * (t - mean_idx)
156+
t_val_t = T_val * (1.0 + t_offset)
157+
print(f" Type {t}: CTC threshold = {t_val_t:.2f} (offset={t_offset:+.2f})")
158+
print(f" Pe={self.Pe.item():.3f}, sigma={self.sigma}")
159+
print(f" particle->field: consumption={self.consumption_rate.item()}, production={self.production_rate.item()}, influence_radius={self.influence_radius.item():.3f}")
160+
if particle_params is not None:
161+
print(f" multi-type support: {particle_params.shape[0]} particle types")
162+
163+
def forward(self, data, direction='fp'):
164+
x, edge_index = data.x, data.edge_index
165+
edge_index, _ = pyg_utils.remove_self_loops(edge_index)
166+
167+
if self.particle_params is not None:
168+
particle_type = x[:, 1 + 2*self.dimension].long()
169+
max_type = particle_type.max().item()
170+
n_param_rows = self.particle_params.shape[0]
171+
if max_type >= n_param_rows:
172+
raise ValueError(
173+
f"PDE_D_AdaptivePF: particle_params has {n_param_rows} rows but found "
174+
f"particle type {max_type}. Need {max_type + 1} rows in simulation.params."
175+
)
176+
parameters = self.particle_params[to_numpy(particle_type), :]
177+
else:
178+
parameters = None
179+
180+
if direction == 'interpolate':
181+
result = self.propagate(edge_index, x=x, mode='interpolate', parameters=parameters)
182+
pos = x[:, 1:self.dimension+1]
183+
in_box = ((pos >= 0) & (pos <= 1)).all(dim=1, keepdim=True)
184+
result = result * in_box.float()
185+
return result
186+
elif direction == 'fp':
187+
result = self.propagate(edge_index, x=x, mode='fp', parameters=parameters)
188+
pos = x[:, 1:self.dimension+1]
189+
in_box = ((pos >= 0) & (pos <= 1)).all(dim=1, keepdim=True)
190+
result = result * in_box.float()
191+
return result
192+
elif direction == 'pf':
193+
result = self.propagate(edge_index, x=x, mode='pf', parameters=parameters)
194+
return result
195+
else:
196+
result = self.propagate(edge_index, x=x, mode='pp', parameters=parameters)
197+
return result
198+
199+
def message(self, edge_index_i, edge_index_j, x_i, x_j, mode=None, parameters_i=None):
200+
pos_i = x_i[:, 1:self.dimension+1]
201+
pos_j = x_j[:, 1:self.dimension+1]
202+
203+
d_pos = self.bc_dpos(pos_j - pos_i)
204+
dist = torch.sqrt(torch.sum(d_pos**2, dim=1))
205+
dist_safe = torch.clamp(dist, min=1e-6)
206+
207+
if mode == 'interpolate':
208+
C1_mesh = x_j[:, 6:7]
209+
C2_mesh = x_j[:, 7:8]
210+
weight = torch.exp(-dist / 0.01).unsqueeze(1)
211+
return torch.cat([C1_mesh * weight, C2_mesh * weight, weight], dim=1)
212+
213+
elif mode == 'fp':
214+
fields_i = x_i[:, 6:8]
215+
fields_j = x_j[:, 6:8]
216+
217+
dC1 = fields_j[:, 0:1] - fields_i[:, 0:1]
218+
dC2 = fields_j[:, 1:2] - fields_i[:, 1:2]
219+
220+
kernel = torch.exp(-dist / 0.05)
221+
dir_norm = d_pos / dist_safe.unsqueeze(1)
222+
domain_scale = 32.0
223+
grad_C1 = (dC1 * kernel.unsqueeze(1)) / (dist_safe.unsqueeze(1) * domain_scale)
224+
grad_C2 = (dC2 * kernel.unsqueeze(1)) / (dist_safe.unsqueeze(1) * domain_scale)
225+
226+
if parameters_i is not None:
227+
M1 = parameters_i[:, 0:1]
228+
M2 = parameters_i[:, 1:2]
229+
else:
230+
M1 = self.M1
231+
M2 = self.M2
232+
233+
velocity_raw = (M1 * grad_C1 + M2 * grad_C2) * dir_norm
234+
235+
# 1. Durotaxis: amplify velocity at steep gradients (Lo et al. 2000)
236+
if self.grad_amp_alpha > 0:
237+
grad_mag = torch.abs(grad_C1)
238+
grad_mag_clamped = torch.clamp(grad_mag, max=1.0)
239+
amp_factor = 1.0 + self.grad_amp_alpha * grad_mag_clamped
240+
velocity_raw = velocity_raw * amp_factor
241+
242+
# 2. Concentration-threshold coupling (Wolpert 1969)
243+
if self.ctc_threshold > 0:
244+
C1_local = fields_i[:, 0:1]
245+
A_ref = self.A_ref
246+
base_T = self.ctc_threshold * A_ref
247+
steepness = 3.0
248+
249+
# Per-type thresholds
250+
if (parameters_i is not None and self.cross_type_factor > 0
251+
and x_i.numel() > 0):
252+
type_i = x_i[:, 1 + 2*self.dimension].long()
253+
n_types = type_i.max().item() + 1 if type_i.numel() > 0 else 1
254+
mean_idx = (n_types - 1) / 2.0
255+
type_offset = self.cross_type_factor * (type_i.float() - mean_idx)
256+
T = base_T * (1.0 + type_offset.unsqueeze(1))
257+
else:
258+
T = base_T
259+
260+
sign_factor = -torch.tanh(steepness * (C1_local - T) / (A_ref + 1e-6))
261+
velocity_raw = velocity_raw * sign_factor
262+
263+
# 3. Velocity-dependent fp drag (Tranquillo & Lauffenburger 1987)
264+
if self.fp_drag > 0:
265+
vel_i = x_i[:, 1+self.dimension:1+2*self.dimension]
266+
speed = torch.sqrt(torch.sum(vel_i**2, dim=1, keepdim=True))
267+
drag_factor = 1.0 / (1.0 + self.fp_drag * speed / self.v_ref)
268+
velocity_raw = velocity_raw * drag_factor
269+
270+
return velocity_raw
271+
272+
elif mode == 'pf':
273+
weights = torch.exp(-dist**2 / (2 * (self.influence_radius/3)**2))
274+
275+
if parameters_i is not None:
276+
consumption = parameters_i[:, 2]
277+
production = parameters_i[:, 3]
278+
else:
279+
consumption = self.consumption_rate
280+
production = self.production_rate
281+
282+
# Adaptive pf: reduce consumption/production near CTC threshold (Dessaud 2008)
283+
if self.pf_adapt > 0 and self.ctc_threshold > 0:
284+
C1_local = x_i[:, 6] # C1 at particle position
285+
A_ref = self.A_ref
286+
base_T = self.ctc_threshold * A_ref
287+
288+
# Per-type threshold for pf adaptation zone
289+
if (parameters_i is not None and self.cross_type_factor > 0
290+
and x_i.numel() > 0):
291+
type_i = x_i[:, 1 + 2*self.dimension].long()
292+
n_types = type_i.max().item() + 1 if type_i.numel() > 0 else 1
293+
mean_idx = (n_types - 1) / 2.0
294+
type_offset = self.cross_type_factor * (type_i.float() - mean_idx)
295+
T_local = base_T * (1.0 + type_offset)
296+
else:
297+
T_local = base_T
298+
299+
pf_width = self.pf_adapt_width * A_ref
300+
deviation = (C1_local - T_local)
301+
# Gaussian attenuation: strongest reduction when particle is at threshold
302+
pf_factor = 1.0 - self.pf_adapt * torch.exp(-deviation**2 / (2 * pf_width**2 + 1e-8))
303+
consumption = consumption * pf_factor
304+
production = production * pf_factor
305+
306+
field_updates = torch.zeros((pos_i.size(0), 2), device=pos_i.device)
307+
field_updates[:, 0] = -consumption * weights
308+
field_updates[:, 1] = production * weights
309+
return field_updates
310+
311+
else: # mode == 'pp'
312+
if parameters_i is not None:
313+
p1 = parameters_i[:, 4]
314+
p2 = parameters_i[:, 5]
315+
p3 = parameters_i[:, 6]
316+
p4 = parameters_i[:, 7]
317+
318+
f = (p1 * torch.exp(-dist ** (2 * p2) / (2 * self.sigma ** 2))
319+
- p3 * torch.exp(-dist ** (2 * p4) / (2 * self.sigma ** 2)))
320+
321+
forces = f[:, None] * d_pos / dist_safe.unsqueeze(1)
322+
else:
323+
forces = torch.zeros_like(pos_i)
324+
in_range = dist < self.repulsion_range
325+
if in_range.any():
326+
dir_norm = d_pos / dist_safe.unsqueeze(1)
327+
repulsion_mag = self.repulsion_strength * torch.exp(
328+
-5.0 * dist[in_range] / self.repulsion_range
329+
)
330+
forces[in_range] = -dir_norm[in_range] * repulsion_mag.unsqueeze(1)
331+
332+
# Field-dependent pp damping (Painter & Hillen 2002)
333+
if self.pp_damping > 0 and self.ctc_threshold > 0:
334+
C1_local = x_i[:, 6:7].squeeze(1)
335+
A_ref = self.A_ref
336+
base_T = self.ctc_threshold * A_ref
337+
338+
if (parameters_i is not None and self.cross_type_factor > 0
339+
and x_i.numel() > 0):
340+
type_i = x_i[:, 1 + 2*self.dimension].long()
341+
n_types = type_i.max().item() + 1 if type_i.numel() > 0 else 1
342+
mean_idx = (n_types - 1) / 2.0
343+
type_offset = self.cross_type_factor * (type_i.float() - mean_idx)
344+
T_local = base_T * (1.0 + type_offset)
345+
else:
346+
T_local = base_T
347+
348+
width = self.pp_damping_width * A_ref
349+
deviation = (C1_local - T_local)
350+
damping_factor = 1.0 - self.pp_damping * torch.exp(-deviation**2 / (2 * width**2 + 1e-8))
351+
forces = forces * damping_factor.unsqueeze(1)
352+
353+
return forces
354+
355+
def update(self, aggr_out, mode=None):
356+
if mode == 'interpolate':
357+
C1_weighted = aggr_out[:, 0:1]
358+
C2_weighted = aggr_out[:, 1:2]
359+
weight_sum = aggr_out[:, 2:3]
360+
weight_sum = torch.clamp(weight_sum, min=1e-10)
361+
return torch.cat([C1_weighted / weight_sum, C2_weighted / weight_sum], dim=1)
362+
else:
363+
return aggr_out

0 commit comments

Comments
 (0)