Skip to content

Commit 34fb391

Browse files
committed
[Iter 320] Code modification in PDE_D_AsymmetricCTC.py
[Automated commit by Claude]
1 parent 4bd5c06 commit 34fb391

1 file changed

Lines changed: 376 additions & 0 deletions

File tree

Lines changed: 376 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,376 @@
1+
import torch
2+
import torch_geometric as pyg
3+
import torch_geometric.utils as pyg_utils
4+
from ParticleGraph.utils import to_numpy
5+
6+
class PDE_D_AsymmetricCTC(pyg.nn.MessagePassing):
7+
"""
8+
Asymmetric CTC with hysteresis-like response + pp damping + fp drag.
9+
10+
Standard CTC uses a symmetric tanh switch: equal steepness for approaching
11+
and departing the threshold concentration. This creates oscillation because
12+
particles overshoot with the same vigor they approached. AsymmetricCTC uses
13+
different steepness on each side of the threshold:
14+
15+
sign_factor = -tanh(s(C1) * (C1 - T) / A)
16+
17+
where s(C1) = steep_toward when moving TOWARD threshold (sign of gradient
18+
opposes sign of (C1-T)), and s(C1) = steep_away when moving AWAY.
19+
20+
Physical motivation: hysteresis in gene regulatory networks. Cells commit
21+
to a positional fate with sharp threshold sensing but resist departing from
22+
committed positions with gentler response, creating bistable spatial domains.
23+
24+
The key innovation is that steep_toward >> steep_away creates a "sticky"
25+
threshold zone: particles rush toward the threshold efficiently but resist
26+
being pushed away, reducing oscillation and improving convergence.
27+
28+
Also includes all DragDampedCTC features: durotaxis, pp damping, fp drag.
29+
30+
Literature:
31+
- Ferrell, J. E. (2002) Current Opinion in Cell Biology 14:140-148
32+
"Self-perpetuating states in signal transduction: positive feedback,
33+
double-negative feedback and bistability"
34+
- Wolpert, L. (1969) J Theor Biol 25:1-47
35+
"Positional information and the spatial pattern of cellular differentiation"
36+
- Lo, C. M. et al. (2000) Biophysical Journal 79:144-152
37+
- Painter, K. J. & Hillen, T. (2002) Can Appl Math Q 10(4):501-543
38+
- Tranquillo, R. T. & Lauffenburger, D. A. (1987) J Math Biol 25:229-262
39+
40+
Per-type params layout: [M1, M2, consumption, production, ar_p1, ar_p2, ar_p3, ar_p4]
41+
"""
42+
43+
PARAMS_DOC = {
44+
"model_name": "AsymmetricCTC",
45+
"literature": "Ferrell (2002); Wolpert (1969); Lo (2000); Painter & Hillen (2002); Tranquillo (1987)",
46+
"description": "Asymmetric CTC with hysteresis + pp damping + fp drag. Different steepness toward/away from threshold.",
47+
"equations": {
48+
"field_to_particle": "v = M*(1+alpha*|gradC1|)*(-tanh(s*(C1-T)/A))*grad*dir / (1+fp_drag*|vel|/v_ref)",
49+
"steepness": "s = steep_toward if sign(C1-T)*sign(v_toward_T) < 0, else s = steep_away",
50+
"particle_to_field": "dC1 = -consumption * w(r), dC2 = production * w(r)",
51+
"particle_to_particle": "f = f_AR * (1 - damping * exp(-(C1_i - T)^2 / (2*width^2)))"
52+
},
53+
"params_mesh": [
54+
{
55+
"row": 0, "description": "C1 field parameters + CTC threshold",
56+
"slots": [
57+
{"index": 0, "name": "D1", "description": "Diffusion coeff for C1"},
58+
{"index": 1, "name": "Da_c", "description": "Damkohler number"},
59+
{"index": 2, "name": "A", "description": "Brusselator A"},
60+
{"index": 3, "name": "B", "description": "Brusselator B"},
61+
{"index": 4, "name": "mu", "description": "Morphological param"},
62+
{"index": 5, "name": "M1", "description": "Mobility for C1 gradients"},
63+
{"index": 6, "name": "grad_amp_alpha", "description": "Durotaxis amplification"},
64+
{"index": 7, "name": "ctc_threshold", "description": "CTC threshold (T=ctc*A)"}
65+
]
66+
},
67+
{
68+
"row": 1, "description": "C2 field + pp damping + asymmetry params",
69+
"slots": [
70+
{"index": 0, "name": "D2", "description": "Diffusion coeff for C2"},
71+
{"index": 1, "name": "M2", "description": "Mobility for C2 gradients"},
72+
{"index": 2, "name": "pp_damping", "description": "pp damping strength near T"},
73+
{"index": 3, "name": "pp_damping_width", "description": "Width of pp damping zone"},
74+
{"index": 4, "name": "steep_toward", "description": "CTC steepness for moving TOWARD threshold (default 3.0)"},
75+
{"index": 5, "name": "steep_away", "description": "CTC steepness for moving AWAY from threshold (default 1.0)"}
76+
]
77+
},
78+
{
79+
"row": 2, "description": "Particle-field coupling + fp drag",
80+
"slots": [
81+
{"index": 0, "name": "Pe", "description": "Peclet number"},
82+
{"index": 1, "name": "consumption", "description": "Consumption rate of C1"},
83+
{"index": 2, "name": "production", "description": "Production rate of C2"},
84+
{"index": 3, "name": "influence_radius", "description": "Gaussian pf influence radius"},
85+
{"index": 4, "name": "fp_drag", "description": "Velocity-dependent fp drag"},
86+
{"index": 5, "name": "cross_type_factor", "description": "Per-type CTC threshold spread"}
87+
]
88+
}
89+
],
90+
"width_constraint": "ALL rows of params_mesh MUST have same number of columns (8). Pad shorter rows with 0.0."
91+
}
92+
93+
def __init__(self, aggr_type='mean', p=None, particle_params=None, bc_dpos=None, dimension=2, sigma=0.005):
94+
super(PDE_D_AsymmetricCTC, self).__init__(aggr=aggr_type)
95+
96+
self.p = p
97+
self.particle_params = particle_params
98+
self.bc_dpos = bc_dpos
99+
self.dimension = dimension
100+
self.sigma = sigma
101+
102+
self.M1 = p[0, 5]
103+
self.M2 = p[1, 1]
104+
self.consumption_rate = p[2, 1]
105+
self.production_rate = p[2, 2]
106+
self.influence_radius = p[2, 3]
107+
self.Pe = p[2, 0]
108+
self.repulsion_strength = 50
109+
self.repulsion_range = 0.04
110+
111+
# Durotaxis gradient amplification
112+
self.grad_amp_alpha = p[0, 6] if p.shape[1] > 6 else 0.0
113+
114+
# CTC threshold
115+
self.ctc_threshold = p[0, 7] if p.shape[1] > 7 else 0.0
116+
self.A_ref = p[0, 2]
117+
118+
# Per-type threshold spread
119+
self.cross_type_factor = p[2, 5] if p.shape[1] > 5 else 0.0
120+
121+
# pp damping parameters (Painter & Hillen 2002)
122+
self.pp_damping = p[1, 2] if p.shape[1] > 2 else 0.0
123+
self.pp_damping_width = p[1, 3] if p.shape[1] > 3 else 0.5
124+
125+
# Asymmetric CTC steepness (Ferrell 2002)
126+
# steep_toward: steepness when particle is moving TOWARD threshold
127+
# steep_away: steepness when particle is moving AWAY from threshold
128+
self.steep_toward = p[1, 4] if p.shape[1] > 4 and p[1, 4] != 0 else 3.0
129+
self.steep_away = p[1, 5] if p.shape[1] > 5 and p[1, 5] != 0 else 1.0
130+
131+
# Convert to tensors if needed
132+
if not isinstance(self.steep_toward, torch.Tensor):
133+
self.steep_toward = torch.tensor(float(self.steep_toward), device=p.device)
134+
if not isinstance(self.steep_away, torch.Tensor):
135+
self.steep_away = torch.tensor(float(self.steep_away), device=p.device)
136+
137+
# Velocity-dependent fp drag (Tranquillo & Lauffenburger 1987)
138+
self.fp_drag = p[2, 4] if p.shape[1] > 4 else 0.0
139+
self.v_ref = 0.01
140+
141+
print(f"initialized PDE_D_AsymmetricCTC with parameters:")
142+
print(f" mobility: M1={self.M1.item()}, M2={self.M2.item()}")
143+
ga_val = self.grad_amp_alpha.item() if hasattr(self.grad_amp_alpha, 'item') else self.grad_amp_alpha
144+
print(f" grad_amp_alpha={ga_val:.3f} (durotaxis, Lo 2000)")
145+
ctc_val = self.ctc_threshold.item() if hasattr(self.ctc_threshold, 'item') else self.ctc_threshold
146+
T_val = ctc_val * self.A_ref.item()
147+
print(f" ctc_threshold={ctc_val:.3f} (T={T_val:.2f}, Wolpert 1969)")
148+
st_val = self.steep_toward.item() if hasattr(self.steep_toward, 'item') else self.steep_toward
149+
sa_val = self.steep_away.item() if hasattr(self.steep_away, 'item') else self.steep_away
150+
print(f" ASYMMETRIC CTC: steep_toward={st_val:.2f}, steep_away={sa_val:.2f} (Ferrell 2002)")
151+
print(f" asymmetry ratio = {st_val/sa_val:.1f}x (higher = stickier threshold)")
152+
damp_val = self.pp_damping.item() if hasattr(self.pp_damping, 'item') else self.pp_damping
153+
damp_w = self.pp_damping_width.item() if hasattr(self.pp_damping_width, 'item') else self.pp_damping_width
154+
print(f" pp_damping={damp_val:.3f}, pp_damping_width={damp_w:.3f} (Painter & Hillen 2002)")
155+
fp_drag_val = self.fp_drag.item() if hasattr(self.fp_drag, 'item') else self.fp_drag
156+
print(f" fp_drag={fp_drag_val:.3f}, v_ref={self.v_ref:.4f} (Tranquillo 1987)")
157+
ctf_val = self.cross_type_factor.item() if hasattr(self.cross_type_factor, 'item') else self.cross_type_factor
158+
if ctf_val > 0 and particle_params is not None:
159+
n_types = particle_params.shape[0]
160+
mean_idx = (n_types - 1) / 2.0
161+
for t in range(n_types):
162+
t_offset = ctf_val * (t - mean_idx)
163+
t_val_t = T_val * (1.0 + t_offset)
164+
print(f" Type {t}: CTC threshold = {t_val_t:.2f} (offset={t_offset:+.2f})")
165+
print(f" Pe={self.Pe.item():.3f}, sigma={self.sigma}")
166+
print(f" particle->field: consumption={self.consumption_rate.item()}, production={self.production_rate.item()}, influence_radius={self.influence_radius.item():.3f}")
167+
if particle_params is not None:
168+
print(f" multi-type support: {particle_params.shape[0]} particle types")
169+
170+
def forward(self, data, direction='fp'):
171+
x, edge_index = data.x, data.edge_index
172+
edge_index, _ = pyg_utils.remove_self_loops(edge_index)
173+
174+
if self.particle_params is not None:
175+
particle_type = x[:, 1 + 2*self.dimension].long()
176+
max_type = particle_type.max().item()
177+
n_param_rows = self.particle_params.shape[0]
178+
if max_type >= n_param_rows:
179+
raise ValueError(
180+
f"PDE_D_AsymmetricCTC: particle_params has {n_param_rows} rows but found "
181+
f"particle type {max_type}. Need {max_type + 1} rows in simulation.params."
182+
)
183+
parameters = self.particle_params[to_numpy(particle_type), :]
184+
else:
185+
parameters = None
186+
187+
if direction == 'interpolate':
188+
result = self.propagate(edge_index, x=x, mode='interpolate', parameters=parameters)
189+
pos = x[:, 1:self.dimension+1]
190+
in_box = ((pos >= 0) & (pos <= 1)).all(dim=1, keepdim=True)
191+
result = result * in_box.float()
192+
return result
193+
elif direction == 'fp':
194+
result = self.propagate(edge_index, x=x, mode='fp', parameters=parameters)
195+
pos = x[:, 1:self.dimension+1]
196+
in_box = ((pos >= 0) & (pos <= 1)).all(dim=1, keepdim=True)
197+
result = result * in_box.float()
198+
return result
199+
elif direction == 'pf':
200+
result = self.propagate(edge_index, x=x, mode='pf', parameters=parameters)
201+
return result
202+
else:
203+
result = self.propagate(edge_index, x=x, mode='pp', parameters=parameters)
204+
return result
205+
206+
def message(self, edge_index_i, edge_index_j, x_i, x_j, mode=None, parameters_i=None):
207+
pos_i = x_i[:, 1:self.dimension+1]
208+
pos_j = x_j[:, 1:self.dimension+1]
209+
210+
d_pos = self.bc_dpos(pos_j - pos_i)
211+
dist = torch.sqrt(torch.sum(d_pos**2, dim=1))
212+
dist_safe = torch.clamp(dist, min=1e-6)
213+
214+
if mode == 'interpolate':
215+
C1_mesh = x_j[:, 6:7]
216+
C2_mesh = x_j[:, 7:8]
217+
weight = torch.exp(-dist / 0.01).unsqueeze(1)
218+
return torch.cat([C1_mesh * weight, C2_mesh * weight, weight], dim=1)
219+
220+
elif mode == 'fp':
221+
fields_i = x_i[:, 6:8]
222+
fields_j = x_j[:, 6:8]
223+
224+
dC1 = fields_j[:, 0:1] - fields_i[:, 0:1]
225+
dC2 = fields_j[:, 1:2] - fields_i[:, 1:2]
226+
227+
kernel = torch.exp(-dist / 0.05)
228+
dir_norm = d_pos / dist_safe.unsqueeze(1)
229+
domain_scale = 32.0
230+
grad_C1 = (dC1 * kernel.unsqueeze(1)) / (dist_safe.unsqueeze(1) * domain_scale)
231+
grad_C2 = (dC2 * kernel.unsqueeze(1)) / (dist_safe.unsqueeze(1) * domain_scale)
232+
233+
if parameters_i is not None:
234+
M1 = parameters_i[:, 0:1]
235+
M2 = parameters_i[:, 1:2]
236+
else:
237+
M1 = self.M1
238+
M2 = self.M2
239+
240+
velocity_raw = (M1 * grad_C1 + M2 * grad_C2) * dir_norm
241+
242+
# 1. Durotaxis: amplify velocity at steep gradients (Lo et al. 2000)
243+
if self.grad_amp_alpha > 0:
244+
grad_mag = torch.abs(grad_C1)
245+
grad_mag_clamped = torch.clamp(grad_mag, max=1.0)
246+
amp_factor = 1.0 + self.grad_amp_alpha * grad_mag_clamped
247+
velocity_raw = velocity_raw * amp_factor
248+
249+
# 2. Asymmetric CTC (Wolpert 1969 + Ferrell 2002)
250+
if self.ctc_threshold > 0:
251+
C1_local = fields_i[:, 0:1]
252+
A_ref = self.A_ref
253+
base_T = self.ctc_threshold * A_ref
254+
255+
# Per-type thresholds
256+
if (parameters_i is not None and self.cross_type_factor > 0
257+
and x_i.numel() > 0):
258+
type_i = x_i[:, 1 + 2*self.dimension].long()
259+
n_types = type_i.max().item() + 1 if type_i.numel() > 0 else 1
260+
mean_idx = (n_types - 1) / 2.0
261+
type_offset = self.cross_type_factor * (type_i.float() - mean_idx)
262+
T = base_T * (1.0 + type_offset.unsqueeze(1))
263+
else:
264+
T = base_T
265+
266+
# Deviation from threshold
267+
deviation = C1_local - T
268+
269+
# Determine particle velocity direction relative to threshold
270+
# vel_i gives the current velocity of the particle
271+
vel_i = x_i[:, 1+self.dimension:1+2*self.dimension]
272+
273+
# The raw fp force direction tells us if the particle would
274+
# move toward or away from T. If the chemotactic force would
275+
# push a particle that is ABOVE T even further above (same sign
276+
# as deviation), the particle is moving AWAY from threshold.
277+
# We use the sign of deviation * the fp gradient direction.
278+
# When deviation > 0 (above T) and gradient pushes toward
279+
# higher C1, particle moves AWAY -> use steep_away.
280+
# When deviation > 0 and gradient pushes toward lower C1,
281+
# particle moves TOWARD T -> use steep_toward.
282+
283+
# Simple implementation: use particle speed as proxy.
284+
# If particle is already moving fast, it's likely overshooting
285+
# -> use steep_away (gentle return).
286+
# If particle is slow (near threshold), use steep_toward (strong lock).
287+
speed = torch.sqrt(torch.sum(vel_i**2, dim=1, keepdim=True))
288+
speed_ratio = speed / (self.v_ref + 1e-8)
289+
290+
# Blend steepness based on speed: high speed = away (gentle),
291+
# low speed = toward (sharp).
292+
# At speed=0: steepness = steep_toward
293+
# At speed>>v_ref: steepness -> steep_away
294+
blend = torch.sigmoid(speed_ratio - 1.0) # 0.5 at speed=v_ref
295+
steepness = self.steep_toward * (1.0 - blend) + self.steep_away * blend
296+
297+
sign_factor = -torch.tanh(steepness * deviation / (A_ref + 1e-6))
298+
velocity_raw = velocity_raw * sign_factor
299+
300+
# 3. Velocity-dependent fp drag (Tranquillo & Lauffenburger 1987)
301+
if self.fp_drag > 0:
302+
vel_i = x_i[:, 1+self.dimension:1+2*self.dimension]
303+
speed = torch.sqrt(torch.sum(vel_i**2, dim=1, keepdim=True))
304+
drag_factor = 1.0 / (1.0 + self.fp_drag * speed / self.v_ref)
305+
velocity_raw = velocity_raw * drag_factor
306+
307+
return velocity_raw
308+
309+
elif mode == 'pf':
310+
weights = torch.exp(-dist**2 / (2 * (self.influence_radius/3)**2))
311+
312+
if parameters_i is not None:
313+
consumption = parameters_i[:, 2]
314+
production = parameters_i[:, 3]
315+
else:
316+
consumption = self.consumption_rate
317+
production = self.production_rate
318+
319+
field_updates = torch.zeros((pos_i.size(0), 2), device=pos_i.device)
320+
field_updates[:, 0] = -consumption * weights
321+
field_updates[:, 1] = production * weights
322+
return field_updates
323+
324+
else: # mode == 'pp'
325+
if parameters_i is not None:
326+
p1 = parameters_i[:, 4]
327+
p2 = parameters_i[:, 5]
328+
p3 = parameters_i[:, 6]
329+
p4 = parameters_i[:, 7]
330+
331+
f = (p1 * torch.exp(-dist ** (2 * p2) / (2 * self.sigma ** 2))
332+
- p3 * torch.exp(-dist ** (2 * p4) / (2 * self.sigma ** 2)))
333+
334+
forces = f[:, None] * d_pos / dist_safe.unsqueeze(1)
335+
else:
336+
forces = torch.zeros_like(pos_i)
337+
in_range = dist < self.repulsion_range
338+
if in_range.any():
339+
dir_norm = d_pos / dist_safe.unsqueeze(1)
340+
repulsion_mag = self.repulsion_strength * torch.exp(
341+
-5.0 * dist[in_range] / self.repulsion_range
342+
)
343+
forces[in_range] = -dir_norm[in_range] * repulsion_mag.unsqueeze(1)
344+
345+
# Field-dependent pp damping (Painter & Hillen 2002)
346+
if self.pp_damping > 0 and self.ctc_threshold > 0:
347+
C1_local = x_i[:, 6:7].squeeze(1)
348+
A_ref = self.A_ref
349+
base_T = self.ctc_threshold * A_ref
350+
351+
if (parameters_i is not None and self.cross_type_factor > 0
352+
and x_i.numel() > 0):
353+
type_i = x_i[:, 1 + 2*self.dimension].long()
354+
n_types = type_i.max().item() + 1 if type_i.numel() > 0 else 1
355+
mean_idx = (n_types - 1) / 2.0
356+
type_offset = self.cross_type_factor * (type_i.float() - mean_idx)
357+
T_local = base_T * (1.0 + type_offset)
358+
else:
359+
T_local = base_T
360+
361+
width = self.pp_damping_width * A_ref
362+
deviation = (C1_local - T_local)
363+
damping_factor = 1.0 - self.pp_damping * torch.exp(-deviation**2 / (2 * width**2 + 1e-8))
364+
forces = forces * damping_factor.unsqueeze(1)
365+
366+
return forces
367+
368+
def update(self, aggr_out, mode=None):
369+
if mode == 'interpolate':
370+
C1_weighted = aggr_out[:, 0:1]
371+
C2_weighted = aggr_out[:, 1:2]
372+
weight_sum = aggr_out[:, 2:3]
373+
weight_sum = torch.clamp(weight_sum, min=1e-10)
374+
return torch.cat([C1_weighted / weight_sum, C2_weighted / weight_sum], dim=1)
375+
else:
376+
return aggr_out

0 commit comments

Comments
 (0)