Skip to content

Commit b8aa654

Browse files
committed
[Iter 328] Code modification in PDE_D_DeadzoneCTC.py
[Automated commit by Claude]
1 parent ff3d35d commit b8aa654

1 file changed

Lines changed: 362 additions & 0 deletions

File tree

Lines changed: 362 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,362 @@
1+
import torch
2+
import torch_geometric as pyg
3+
import torch_geometric.utils as pyg_utils
4+
from ParticleGraph.utils import to_numpy
5+
6+
class PDE_D_DeadzoneCTC(pyg.nn.MessagePassing):
7+
"""
8+
CTC with receptor-adaptation deadzone + pp damping + fp drag.
9+
10+
Modifies the standard CTC tanh transfer function by introducing a flat
11+
"deadzone" of width epsilon around the threshold concentration T. When
12+
|C1 - T| < epsilon*A, the CTC sign factor is exactly zero, meaning
13+
particles near the threshold experience NO chemotactic driving force.
14+
They settle purely through pp damping. Outside the deadzone, the standard
15+
tanh response applies with shifted argument.
16+
17+
This addresses the 2-type convergence bottleneck: particles near T
18+
oscillate because the tanh function transitions smoothly through zero,
19+
producing small but sign-alternating forces as C1 fluctuates around T.
20+
The deadzone eliminates these oscillations entirely within the
21+
adaptation zone.
22+
23+
Physical motivation: receptor adaptation / perfect adaptation in
24+
chemotaxis. Cells near their target concentration adapt their receptor
25+
sensitivity, becoming insensitive to small fluctuations. Only large
26+
deviations from the adapted state trigger a migration response.
27+
28+
Transfer function:
29+
sign_factor = 0 if |C1-T| < eps*A
30+
sign_factor = -tanh(steep * (C1-T-eps*A) / A) if C1 > T + eps*A
31+
sign_factor = -tanh(steep * (C1-T+eps*A) / A) if C1 < T - eps*A
32+
33+
Also includes all DragDampedCTC features: durotaxis, pp damping, fp drag.
34+
35+
Literature:
36+
- Barkai, N. & Leibler, S. (1997) Nature 387:913-917
37+
"Robustness in simple biochemical networks"
38+
- Wolpert, L. (1969) J Theor Biol 25:1-47
39+
"Positional information and the spatial pattern of cellular differentiation"
40+
- Lo, C. M. et al. (2000) Biophysical Journal 79:144-152
41+
- Painter, K. J. & Hillen, T. (2002) Can Appl Math Q 10(4):501-543
42+
- Tranquillo, R. T. & Lauffenburger, D. A. (1987) J Math Biol 25:229-262
43+
44+
Per-type params layout: [M1, M2, consumption, production, ar_p1, ar_p2, ar_p3, ar_p4]
45+
"""
46+
47+
PARAMS_DOC = {
48+
"model_name": "DeadzoneCTC",
49+
"literature": "Barkai & Leibler (1997) Nature 387:913; Wolpert (1969); Lo (2000); Painter & Hillen (2002); Tranquillo (1987)",
50+
"description": "CTC with receptor-adaptation deadzone near threshold + pp damping + fp drag",
51+
"equations": {
52+
"field_to_particle": "v = M*(1+alpha*|gradC1|)*sign_factor*grad*dir / (1+fp_drag*|vel|/v_ref)",
53+
"deadzone_ctc": "sign_factor = 0 if |C1-T|<eps*A, else -tanh(3*(C1-T∓eps*A)/A)",
54+
"particle_to_field": "dC1 = -consumption * w(r), dC2 = production * w(r)",
55+
"particle_to_particle": "f = f_AR * (1 - damping * exp(-(C1_i - T)^2 / (2*width^2)))"
56+
},
57+
"params_mesh": [
58+
{
59+
"row": 0, "description": "C1 field parameters + CTC threshold",
60+
"slots": [
61+
{"index": 0, "name": "D1", "description": "Diffusion coeff for C1"},
62+
{"index": 1, "name": "Da_c", "description": "Damkohler number"},
63+
{"index": 2, "name": "A", "description": "Brusselator A"},
64+
{"index": 3, "name": "B", "description": "Brusselator B"},
65+
{"index": 4, "name": "mu", "description": "Morphological param"},
66+
{"index": 5, "name": "M1", "description": "Mobility for C1 gradients"},
67+
{"index": 6, "name": "grad_amp_alpha", "description": "Durotaxis amplification"},
68+
{"index": 7, "name": "ctc_threshold", "description": "CTC threshold (T=ctc*A)"}
69+
]
70+
},
71+
{
72+
"row": 1, "description": "C2 field + pp damping + deadzone params",
73+
"slots": [
74+
{"index": 0, "name": "D2", "description": "Diffusion coeff for C2"},
75+
{"index": 1, "name": "M2", "description": "Mobility for C2 gradients"},
76+
{"index": 2, "name": "pp_damping", "description": "pp damping strength near T"},
77+
{"index": 3, "name": "pp_damping_width", "description": "Width of pp damping zone"},
78+
{"index": 4, "name": "deadzone_eps", "description": "Deadzone half-width (units of A). 0=standard CTC."},
79+
{"index": 5, "name": "unused1", "description": "Unused (pad)"}
80+
]
81+
},
82+
{
83+
"row": 2, "description": "Particle-field coupling + fp drag",
84+
"slots": [
85+
{"index": 0, "name": "Pe", "description": "Peclet number"},
86+
{"index": 1, "name": "consumption", "description": "Consumption rate of C1"},
87+
{"index": 2, "name": "production", "description": "Production rate of C2"},
88+
{"index": 3, "name": "influence_radius", "description": "Gaussian pf influence radius"},
89+
{"index": 4, "name": "fp_drag", "description": "Velocity-dependent fp drag"},
90+
{"index": 5, "name": "cross_type_factor", "description": "Per-type CTC threshold spread"}
91+
]
92+
}
93+
],
94+
"width_constraint": "ALL rows of params_mesh MUST have same number of columns (8). Pad shorter rows with 0.0."
95+
}
96+
97+
def __init__(self, aggr_type='mean', p=None, particle_params=None, bc_dpos=None, dimension=2, sigma=0.005):
98+
super(PDE_D_DeadzoneCTC, self).__init__(aggr=aggr_type)
99+
100+
self.p = p
101+
self.particle_params = particle_params
102+
self.bc_dpos = bc_dpos
103+
self.dimension = dimension
104+
self.sigma = sigma
105+
106+
self.M1 = p[0, 5]
107+
self.M2 = p[1, 1]
108+
self.consumption_rate = p[2, 1]
109+
self.production_rate = p[2, 2]
110+
self.influence_radius = p[2, 3]
111+
self.Pe = p[2, 0]
112+
self.repulsion_strength = 50
113+
self.repulsion_range = 0.04
114+
115+
# Durotaxis gradient amplification
116+
self.grad_amp_alpha = p[0, 6] if p.shape[1] > 6 else 0.0
117+
118+
# CTC threshold
119+
self.ctc_threshold = p[0, 7] if p.shape[1] > 7 else 0.0
120+
self.A_ref = p[0, 2]
121+
122+
# Per-type threshold spread
123+
self.cross_type_factor = p[2, 5] if p.shape[1] > 5 else 0.0
124+
125+
# pp damping parameters (Painter & Hillen 2002)
126+
self.pp_damping = p[1, 2] if p.shape[1] > 2 else 0.0
127+
self.pp_damping_width = p[1, 3] if p.shape[1] > 3 else 0.5
128+
129+
# Deadzone half-width (Barkai & Leibler 1997)
130+
self.deadzone_eps = p[1, 4] if p.shape[1] > 4 else 0.0
131+
132+
# Velocity-dependent fp drag (Tranquillo & Lauffenburger 1987)
133+
self.fp_drag = p[2, 4] if p.shape[1] > 4 else 0.0
134+
self.v_ref = 0.01
135+
136+
print(f"initialized PDE_D_DeadzoneCTC with parameters:")
137+
print(f" mobility: M1={self.M1.item()}, M2={self.M2.item()}")
138+
ga_val = self.grad_amp_alpha.item() if hasattr(self.grad_amp_alpha, 'item') else self.grad_amp_alpha
139+
print(f" grad_amp_alpha={ga_val:.3f} (durotaxis, Lo 2000)")
140+
ctc_val = self.ctc_threshold.item() if hasattr(self.ctc_threshold, 'item') else self.ctc_threshold
141+
T_val = ctc_val * self.A_ref.item()
142+
print(f" ctc_threshold={ctc_val:.3f} (T={T_val:.2f}, Wolpert 1969)")
143+
eps_val = self.deadzone_eps.item() if hasattr(self.deadzone_eps, 'item') else self.deadzone_eps
144+
print(f" DEADZONE: eps={eps_val:.3f} (half-width={eps_val*self.A_ref.item():.3f}, Barkai & Leibler 1997)")
145+
print(f" Particles within |C1-T| < {eps_val*self.A_ref.item():.3f} experience ZERO fp force")
146+
damp_val = self.pp_damping.item() if hasattr(self.pp_damping, 'item') else self.pp_damping
147+
damp_w = self.pp_damping_width.item() if hasattr(self.pp_damping_width, 'item') else self.pp_damping_width
148+
print(f" pp_damping={damp_val:.3f}, pp_damping_width={damp_w:.3f} (Painter & Hillen 2002)")
149+
fp_drag_val = self.fp_drag.item() if hasattr(self.fp_drag, 'item') else self.fp_drag
150+
print(f" fp_drag={fp_drag_val:.3f}, v_ref={self.v_ref:.4f} (Tranquillo 1987)")
151+
ctf_val = self.cross_type_factor.item() if hasattr(self.cross_type_factor, 'item') else self.cross_type_factor
152+
if ctf_val > 0 and particle_params is not None:
153+
n_types = particle_params.shape[0]
154+
mean_idx = (n_types - 1) / 2.0
155+
for t in range(n_types):
156+
t_offset = ctf_val * (t - mean_idx)
157+
t_val_t = T_val * (1.0 + t_offset)
158+
print(f" Type {t}: CTC threshold = {t_val_t:.2f} (offset={t_offset:+.2f})")
159+
print(f" Pe={self.Pe.item():.3f}, sigma={self.sigma}")
160+
print(f" particle->field: consumption={self.consumption_rate.item()}, production={self.production_rate.item()}, influence_radius={self.influence_radius.item():.3f}")
161+
if particle_params is not None:
162+
print(f" multi-type support: {particle_params.shape[0]} particle types")
163+
164+
def forward(self, data, direction='fp'):
165+
x, edge_index = data.x, data.edge_index
166+
edge_index, _ = pyg_utils.remove_self_loops(edge_index)
167+
168+
if self.particle_params is not None:
169+
particle_type = x[:, 1 + 2*self.dimension].long()
170+
max_type = particle_type.max().item()
171+
n_param_rows = self.particle_params.shape[0]
172+
if max_type >= n_param_rows:
173+
raise ValueError(
174+
f"PDE_D_DeadzoneCTC: particle_params has {n_param_rows} rows but found "
175+
f"particle type {max_type}. Need {max_type + 1} rows in simulation.params."
176+
)
177+
parameters = self.particle_params[to_numpy(particle_type), :]
178+
else:
179+
parameters = None
180+
181+
if direction == 'interpolate':
182+
result = self.propagate(edge_index, x=x, mode='interpolate', parameters=parameters)
183+
pos = x[:, 1:self.dimension+1]
184+
in_box = ((pos >= 0) & (pos <= 1)).all(dim=1, keepdim=True)
185+
result = result * in_box.float()
186+
return result
187+
elif direction == 'fp':
188+
result = self.propagate(edge_index, x=x, mode='fp', parameters=parameters)
189+
pos = x[:, 1:self.dimension+1]
190+
in_box = ((pos >= 0) & (pos <= 1)).all(dim=1, keepdim=True)
191+
result = result * in_box.float()
192+
return result
193+
elif direction == 'pf':
194+
result = self.propagate(edge_index, x=x, mode='pf', parameters=parameters)
195+
return result
196+
else:
197+
result = self.propagate(edge_index, x=x, mode='pp', parameters=parameters)
198+
return result
199+
200+
def message(self, edge_index_i, edge_index_j, x_i, x_j, mode=None, parameters_i=None):
201+
pos_i = x_i[:, 1:self.dimension+1]
202+
pos_j = x_j[:, 1:self.dimension+1]
203+
204+
d_pos = self.bc_dpos(pos_j - pos_i)
205+
dist = torch.sqrt(torch.sum(d_pos**2, dim=1))
206+
dist_safe = torch.clamp(dist, min=1e-6)
207+
208+
if mode == 'interpolate':
209+
C1_mesh = x_j[:, 6:7]
210+
C2_mesh = x_j[:, 7:8]
211+
weight = torch.exp(-dist / 0.01).unsqueeze(1)
212+
return torch.cat([C1_mesh * weight, C2_mesh * weight, weight], dim=1)
213+
214+
elif mode == 'fp':
215+
fields_i = x_i[:, 6:8]
216+
fields_j = x_j[:, 6:8]
217+
218+
dC1 = fields_j[:, 0:1] - fields_i[:, 0:1]
219+
dC2 = fields_j[:, 1:2] - fields_i[:, 1:2]
220+
221+
kernel = torch.exp(-dist / 0.05)
222+
dir_norm = d_pos / dist_safe.unsqueeze(1)
223+
domain_scale = 32.0
224+
grad_C1 = (dC1 * kernel.unsqueeze(1)) / (dist_safe.unsqueeze(1) * domain_scale)
225+
grad_C2 = (dC2 * kernel.unsqueeze(1)) / (dist_safe.unsqueeze(1) * domain_scale)
226+
227+
if parameters_i is not None:
228+
M1 = parameters_i[:, 0:1]
229+
M2 = parameters_i[:, 1:2]
230+
else:
231+
M1 = self.M1
232+
M2 = self.M2
233+
234+
velocity_raw = (M1 * grad_C1 + M2 * grad_C2) * dir_norm
235+
236+
# 1. Durotaxis: amplify velocity at steep gradients (Lo et al. 2000)
237+
if self.grad_amp_alpha > 0:
238+
grad_mag = torch.abs(grad_C1)
239+
grad_mag_clamped = torch.clamp(grad_mag, max=1.0)
240+
amp_factor = 1.0 + self.grad_amp_alpha * grad_mag_clamped
241+
velocity_raw = velocity_raw * amp_factor
242+
243+
# 2. CTC with deadzone (Wolpert 1969 + Barkai & Leibler 1997)
244+
if self.ctc_threshold > 0:
245+
C1_local = fields_i[:, 0:1]
246+
A_ref = self.A_ref
247+
base_T = self.ctc_threshold * A_ref
248+
steepness = 3.0
249+
250+
# Per-type thresholds
251+
if (parameters_i is not None and self.cross_type_factor > 0
252+
and x_i.numel() > 0):
253+
type_i = x_i[:, 1 + 2*self.dimension].long()
254+
n_types = type_i.max().item() + 1 if type_i.numel() > 0 else 1
255+
mean_idx = (n_types - 1) / 2.0
256+
type_offset = self.cross_type_factor * (type_i.float() - mean_idx)
257+
T = base_T * (1.0 + type_offset.unsqueeze(1))
258+
else:
259+
T = base_T
260+
261+
deviation = C1_local - T
262+
eps_width = self.deadzone_eps * A_ref # deadzone half-width in C1 units
263+
264+
if eps_width > 0:
265+
# Deadzone CTC: zero force within |deviation| < eps_width
266+
# Shifted tanh outside the deadzone
267+
# Use smooth approximation to avoid discontinuity:
268+
# sign_factor = -tanh(steep * (|dev| - eps) * sign(dev) / A)
269+
# * smoothstep(|dev|, eps)
270+
abs_dev = torch.abs(deviation)
271+
shifted_dev = abs_dev - eps_width
272+
shifted_dev_signed = shifted_dev * torch.sign(deviation)
273+
274+
# Smooth transition at deadzone boundary using sigmoid
275+
# transition_width controls sharpness of deadzone edge
276+
transition_width = 0.1 * A_ref # smooth over 10% of A
277+
gate = torch.sigmoid((abs_dev - eps_width) / (transition_width + 1e-8))
278+
279+
sign_factor = -torch.tanh(steepness * shifted_dev_signed / (A_ref + 1e-6)) * gate
280+
else:
281+
# Standard CTC (no deadzone)
282+
sign_factor = -torch.tanh(steepness * deviation / (A_ref + 1e-6))
283+
284+
velocity_raw = velocity_raw * sign_factor
285+
286+
# 3. Velocity-dependent fp drag (Tranquillo & Lauffenburger 1987)
287+
if self.fp_drag > 0:
288+
vel_i = x_i[:, 1+self.dimension:1+2*self.dimension]
289+
speed = torch.sqrt(torch.sum(vel_i**2, dim=1, keepdim=True))
290+
drag_factor = 1.0 / (1.0 + self.fp_drag * speed / self.v_ref)
291+
velocity_raw = velocity_raw * drag_factor
292+
293+
return velocity_raw
294+
295+
elif mode == 'pf':
296+
weights = torch.exp(-dist**2 / (2 * (self.influence_radius/3)**2))
297+
298+
if parameters_i is not None:
299+
consumption = parameters_i[:, 2]
300+
production = parameters_i[:, 3]
301+
else:
302+
consumption = self.consumption_rate
303+
production = self.production_rate
304+
305+
field_updates = torch.zeros((pos_i.size(0), 2), device=pos_i.device)
306+
field_updates[:, 0] = -consumption * weights
307+
field_updates[:, 1] = production * weights
308+
return field_updates
309+
310+
else: # mode == 'pp'
311+
if parameters_i is not None:
312+
p1 = parameters_i[:, 4]
313+
p2 = parameters_i[:, 5]
314+
p3 = parameters_i[:, 6]
315+
p4 = parameters_i[:, 7]
316+
317+
f = (p1 * torch.exp(-dist ** (2 * p2) / (2 * self.sigma ** 2))
318+
- p3 * torch.exp(-dist ** (2 * p4) / (2 * self.sigma ** 2)))
319+
320+
forces = f[:, None] * d_pos / dist_safe.unsqueeze(1)
321+
else:
322+
forces = torch.zeros_like(pos_i)
323+
in_range = dist < self.repulsion_range
324+
if in_range.any():
325+
dir_norm = d_pos / dist_safe.unsqueeze(1)
326+
repulsion_mag = self.repulsion_strength * torch.exp(
327+
-5.0 * dist[in_range] / self.repulsion_range
328+
)
329+
forces[in_range] = -dir_norm[in_range] * repulsion_mag.unsqueeze(1)
330+
331+
# Field-dependent pp damping (Painter & Hillen 2002)
332+
if self.pp_damping > 0 and self.ctc_threshold > 0:
333+
C1_local = x_i[:, 6:7].squeeze(1)
334+
A_ref = self.A_ref
335+
base_T = self.ctc_threshold * A_ref
336+
337+
if (parameters_i is not None and self.cross_type_factor > 0
338+
and x_i.numel() > 0):
339+
type_i = x_i[:, 1 + 2*self.dimension].long()
340+
n_types = type_i.max().item() + 1 if type_i.numel() > 0 else 1
341+
mean_idx = (n_types - 1) / 2.0
342+
type_offset = self.cross_type_factor * (type_i.float() - mean_idx)
343+
T_local = base_T * (1.0 + type_offset)
344+
else:
345+
T_local = base_T
346+
347+
width = self.pp_damping_width * A_ref
348+
deviation = (C1_local - T_local)
349+
damping_factor = 1.0 - self.pp_damping * torch.exp(-deviation**2 / (2 * width**2 + 1e-8))
350+
forces = forces * damping_factor.unsqueeze(1)
351+
352+
return forces
353+
354+
def update(self, aggr_out, mode=None):
355+
if mode == 'interpolate':
356+
C1_weighted = aggr_out[:, 0:1]
357+
C2_weighted = aggr_out[:, 1:2]
358+
weight_sum = aggr_out[:, 2:3]
359+
weight_sum = torch.clamp(weight_sum, min=1e-10)
360+
return torch.cat([C1_weighted / weight_sum, C2_weighted / weight_sum], dim=1)
361+
else:
362+
return aggr_out

0 commit comments

Comments
 (0)