Skip to content

Commit 0a83df3

Browse files
committed
[Iter 336] Code modification in PDE_D_DualFieldCTC.py
[Automated commit by Claude]
1 parent b8aa654 commit 0a83df3

1 file changed

Lines changed: 381 additions & 0 deletions

File tree

Lines changed: 381 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,381 @@
1+
import torch
2+
import torch_geometric as pyg
3+
import torch_geometric.utils as pyg_utils
4+
from ParticleGraph.utils import to_numpy
5+
6+
class PDE_D_DualFieldCTC(pyg.nn.MessagePassing):
7+
"""
8+
Dual-field CTC: C2 modulates the CTC response strength based on C1.
9+
10+
Extends DampedCTC by adding a second morphogen channel to the CTC decision.
11+
Instead of sign_factor = -tanh(s*(C1-T)/A), uses:
12+
sign_factor = -tanh(s*(C1-T)/A) * (1 + beta * tanh(s2*(C2-T2)/A))
13+
14+
When C2 is near its own threshold T2, the C1-based CTC response is amplified
15+
(or attenuated if beta < 0). This provides a second channel of positional
16+
information for particle sorting, inspired by morphogen gradient intersection
17+
in developmental biology.
18+
19+
The key difference from all prior CTC variants: the core tanh(C1-T) mechanism
20+
is PRESERVED (not replaced). C2 acts as a MODULATOR, not an alternative signal.
21+
This avoids the anti-convergence problem of deadzone, ratio, and gradient-based
22+
CTC modifications.
23+
24+
Physics:
25+
1. fp: Durotaxis gradient-amplified mobility + dual-field CTC coupling
26+
v = M * (1 + alpha * |gradC1|) * (-tanh(s*(C1-T)/A)) * (1+beta*tanh(s2*(C2-T2)/A)) * grad * dir
27+
2. pf: Standard consumption/production coupling
28+
3. pp: Field-damped attraction-repulsion (same as DampedCTC)
29+
30+
Literature:
31+
- Wolpert, L. (1969) J Theor Biol 25:1-47
32+
"Positional information and the spatial pattern of cellular differentiation"
33+
(Intersecting morphogen gradients for 2D positional specification)
34+
- Green, J.B.A. & Sharpe, J. (2015) Development 142:1203-1211
35+
"Positional information and reaction-diffusion: two big ideas in developmental
36+
biology combine" (Multi-morphogen positional encoding)
37+
- Painter, K.J. & Hillen, T. (2002) CAMQ 10(4):501-543
38+
39+
Per-type params layout: [M1, M2, consumption, production, ar_p1, ar_p2, ar_p3, ar_p4]
40+
"""
41+
42+
PARAMS_DOC = {
43+
"model_name": "DualFieldCTC",
44+
"literature": "Wolpert (1969) J Theor Biol 25:1; Green & Sharpe (2015) Development 142:1203",
45+
"description": "Dual-field CTC: C2 modulates C1-based CTC response strength for 2D positional encoding",
46+
"equations": {
47+
"field_to_particle": "v = M * (1+alpha*|gradC1|) * (-tanh(s*(C1-T1)/A)) * (1+beta*tanh(s2*(C2-T2)/A)) * grad * dir",
48+
"particle_to_field": "dC1 = -consumption * w(r), dC2 = production * w(r)",
49+
"particle_to_particle": "f = f_AR * (1 - damping * exp(-(C1_i - T)^2 / (2*width^2)))"
50+
},
51+
"params_mesh": [
52+
{
53+
"row": 0, "description": "C1 field parameters + CTC threshold",
54+
"slots": [
55+
{"index": 0, "name": "D1", "description": "Diffusion coeff for C1 (mesh model)", "typical_range": [0.01, 0.5]},
56+
{"index": 1, "name": "Da_c", "description": "Damkohler number (mesh model)", "typical_range": [1.0, 50.0]},
57+
{"index": 2, "name": "A", "description": "Brusselator param A (mesh model, also CTC reference)", "typical_range": [0.5, 5.0]},
58+
{"index": 3, "name": "B", "description": "Brusselator param B (mesh model)", "typical_range": [1.0, 10.0]},
59+
{"index": 4, "name": "mu", "description": "Morphological parameter (mesh model)", "typical_range": [0.01, 0.1]},
60+
{"index": 5, "name": "M1", "description": "Mobility coefficient for C1 gradients", "typical_range": [-16, 16]},
61+
{"index": 6, "name": "grad_amp_alpha", "description": "Durotaxis gradient amplification (0=off)", "typical_range": [0.0, 2.0]},
62+
{"index": 7, "name": "ctc_threshold", "description": "CTC threshold for C1 (T1=ctc*A)", "typical_range": [0.5, 3.0]}
63+
]
64+
},
65+
{
66+
"row": 1, "description": "C2 field parameters + pp damping + C2 modulation",
67+
"slots": [
68+
{"index": 0, "name": "D2", "description": "Diffusion coeff for C2 (mesh model)", "typical_range": [0.1, 1.0]},
69+
{"index": 1, "name": "M2", "description": "Mobility coefficient for C2 gradients", "typical_range": [-16, 16]},
70+
{"index": 2, "name": "pp_damping", "description": "pp damping strength at CTC threshold", "typical_range": [0.0, 0.95]},
71+
{"index": 3, "name": "pp_damping_width", "description": "Width of Gaussian damping zone", "typical_range": [0.1, 1.0]},
72+
{"index": 4, "name": "c2_beta", "description": "C2 modulation strength (0=off, >0 amplifies, <0 attenuates)", "typical_range": [-0.5, 0.5]},
73+
{"index": 5, "name": "c2_threshold", "description": "C2 threshold factor (T2=c2_thresh*A for Brusselator equilibrium B/A)", "typical_range": [0.5, 3.0]},
74+
{"index": 6, "name": "c2_steepness", "description": "Steepness of C2 modulation tanh", "typical_range": [1.0, 5.0]},
75+
{"index": 7, "name": "unused", "description": "Pad", "typical_range": [0.0, 0.0]}
76+
]
77+
},
78+
{
79+
"row": 2, "description": "Particle-field coupling + per-type threshold spread",
80+
"slots": [
81+
{"index": 0, "name": "Pe", "description": "Peclet number", "typical_range": [0.5, 2.0]},
82+
{"index": 1, "name": "consumption", "description": "Particle consumption rate of C1", "typical_range": [10, 200]},
83+
{"index": 2, "name": "production", "description": "Particle production rate of C2", "typical_range": [-200, -10]},
84+
{"index": 3, "name": "influence_radius", "description": "Gaussian influence radius for pf coupling", "typical_range": [0.01, 0.1]},
85+
{"index": 4, "name": "fp_drag", "description": "Velocity-dependent drag (0=off)", "typical_range": [0.0, 0.3]},
86+
{"index": 5, "name": "cross_type_factor", "description": "Per-type CTC threshold spread", "typical_range": [0.0, 0.5]},
87+
{"index": 6, "name": "unused2", "description": "Pad", "typical_range": [0.0, 0.0]},
88+
{"index": 7, "name": "unused3", "description": "Pad", "typical_range": [0.0, 0.0]}
89+
]
90+
}
91+
],
92+
"width_constraint": "ALL rows of params_mesh MUST have same number of columns (8). Pad shorter rows.",
93+
"particle_params": {
94+
"description": "Per-type params from simulation.params (one row per n_particle_types)",
95+
"slots": [
96+
{"index": 0, "name": "M1", "description": "Per-type mobility for C1"},
97+
{"index": 1, "name": "M2", "description": "Per-type mobility for C2"},
98+
{"index": 2, "name": "consumption", "description": "Per-type consumption rate"},
99+
{"index": 3, "name": "production", "description": "Per-type production rate"},
100+
{"index": 4, "name": "ar_p1", "description": "Attraction strength"},
101+
{"index": 5, "name": "ar_p2", "description": "Attraction exponent"},
102+
{"index": 6, "name": "ar_p3", "description": "Repulsion strength"},
103+
{"index": 7, "name": "ar_p4", "description": "Repulsion exponent"}
104+
]
105+
}
106+
}
107+
108+
def __init__(self, aggr_type='mean', p=None, particle_params=None, bc_dpos=None, dimension=2, sigma=0.005):
109+
super(PDE_D_DualFieldCTC, self).__init__(aggr=aggr_type)
110+
111+
self.p = p
112+
self.particle_params = particle_params
113+
self.bc_dpos = bc_dpos
114+
self.dimension = dimension
115+
self.sigma = sigma
116+
117+
self.M1 = p[0, 5]
118+
self.M2 = p[1, 1]
119+
self.consumption_rate = p[2, 1]
120+
self.production_rate = p[2, 2]
121+
self.influence_radius = p[2, 3]
122+
self.Pe = p[2, 0]
123+
self.repulsion_strength = 50
124+
self.repulsion_range = 0.04
125+
126+
# Durotaxis gradient amplification
127+
self.grad_amp_alpha = p[0, 6] if p.shape[1] > 6 else 0.0
128+
129+
# CTC threshold for C1
130+
self.ctc_threshold = p[0, 7] if p.shape[1] > 7 else 0.0
131+
self.A_ref = p[0, 2]
132+
133+
# Per-type threshold spread
134+
self.cross_type_factor = p[2, 5] if p.shape[1] > 5 else 0.0
135+
136+
# pp damping parameters (Painter & Hillen 2002)
137+
self.pp_damping = p[1, 2] if p.shape[1] > 2 else 0.0
138+
self.pp_damping_width = p[1, 3] if p.shape[1] > 3 else 0.5
139+
140+
# C2 modulation parameters (Wolpert 1969; Green & Sharpe 2015)
141+
self.c2_beta = p[1, 4] if p.shape[1] > 4 else 0.0
142+
self.c2_threshold = p[1, 5] if p.shape[1] > 5 else 1.0
143+
self.c2_steepness = p[1, 6] if p.shape[1] > 6 else 3.0
144+
145+
# fp drag
146+
self.fp_drag = p[2, 4] if p.shape[1] > 4 else 0.0
147+
148+
print(f"initialized PDE_D_DualFieldCTC with parameters:")
149+
print(f" mobility: M1={self.M1.item()}, M2={self.M2.item()}")
150+
ga_val = self.grad_amp_alpha.item() if hasattr(self.grad_amp_alpha, 'item') else self.grad_amp_alpha
151+
print(f" grad_amp_alpha={ga_val:.3f} (durotaxis, Lo 2000)")
152+
ctc_val = self.ctc_threshold.item() if hasattr(self.ctc_threshold, 'item') else self.ctc_threshold
153+
T_val = ctc_val * self.A_ref.item()
154+
print(f" ctc_threshold={ctc_val:.3f} (T1={T_val:.2f}, Wolpert 1969)")
155+
c2b = self.c2_beta.item() if hasattr(self.c2_beta, 'item') else self.c2_beta
156+
c2t = self.c2_threshold.item() if hasattr(self.c2_threshold, 'item') else self.c2_threshold
157+
c2s = self.c2_steepness.item() if hasattr(self.c2_steepness, 'item') else self.c2_steepness
158+
T2_val = c2t * self.A_ref.item()
159+
print(f" C2 modulation: beta={c2b:.3f}, T2={T2_val:.2f} (c2_thresh={c2t:.2f}), steepness={c2s:.1f} (Green & Sharpe 2015)")
160+
damp_val = self.pp_damping.item() if hasattr(self.pp_damping, 'item') else self.pp_damping
161+
damp_w = self.pp_damping_width.item() if hasattr(self.pp_damping_width, 'item') else self.pp_damping_width
162+
print(f" pp_damping={damp_val:.3f}, pp_damping_width={damp_w:.3f}")
163+
fp_d = self.fp_drag.item() if hasattr(self.fp_drag, 'item') else self.fp_drag
164+
print(f" fp_drag={fp_d:.3f}")
165+
ctf_val = self.cross_type_factor.item() if hasattr(self.cross_type_factor, 'item') else self.cross_type_factor
166+
if ctf_val > 0 and particle_params is not None:
167+
n_types = particle_params.shape[0]
168+
mean_idx = (n_types - 1) / 2.0
169+
for t in range(n_types):
170+
t_offset = ctf_val * (t - mean_idx)
171+
t_val = T_val * (1.0 + t_offset)
172+
print(f" Type {t}: CTC threshold = {t_val:.2f} (offset={t_offset:+.2f})")
173+
print(f" Pe={self.Pe.item():.3f}, sigma={self.sigma}")
174+
print(f" particle->field: consumption={self.consumption_rate.item()}, production={self.production_rate.item()}, influence_radius={self.influence_radius.item():.3f}")
175+
if particle_params is not None:
176+
print(f" multi-type support: {particle_params.shape[0]} particle types")
177+
178+
def forward(self, data, direction='fp'):
179+
x, edge_index = data.x, data.edge_index
180+
edge_index, _ = pyg_utils.remove_self_loops(edge_index)
181+
182+
if self.particle_params is not None:
183+
particle_type = x[:, 1 + 2*self.dimension].long()
184+
max_type = particle_type.max().item()
185+
n_param_rows = self.particle_params.shape[0]
186+
if max_type >= n_param_rows:
187+
raise ValueError(
188+
f"PDE_D_DualFieldCTC: particle_params has {n_param_rows} rows but found "
189+
f"particle type {max_type}. Need {max_type + 1} rows in simulation.params."
190+
)
191+
parameters = self.particle_params[to_numpy(particle_type), :]
192+
else:
193+
parameters = None
194+
195+
if direction == 'interpolate':
196+
result = self.propagate(edge_index, x=x, mode='interpolate', parameters=parameters)
197+
pos = x[:, 1:self.dimension+1]
198+
in_box = ((pos >= 0) & (pos <= 1)).all(dim=1, keepdim=True)
199+
result = result * in_box.float()
200+
return result
201+
elif direction == 'fp':
202+
result = self.propagate(edge_index, x=x, mode='fp', parameters=parameters)
203+
pos = x[:, 1:self.dimension+1]
204+
in_box = ((pos >= 0) & (pos <= 1)).all(dim=1, keepdim=True)
205+
result = result * in_box.float()
206+
return result
207+
elif direction == 'pf':
208+
result = self.propagate(edge_index, x=x, mode='pf', parameters=parameters)
209+
return result
210+
else:
211+
result = self.propagate(edge_index, x=x, mode='pp', parameters=parameters)
212+
return result
213+
214+
def message(self, edge_index_i, edge_index_j, x_i, x_j, mode=None, parameters_i=None):
215+
pos_i = x_i[:, 1:self.dimension+1]
216+
pos_j = x_j[:, 1:self.dimension+1]
217+
218+
d_pos = self.bc_dpos(pos_j - pos_i)
219+
dist = torch.sqrt(torch.sum(d_pos**2, dim=1))
220+
dist_safe = torch.clamp(dist, min=1e-6)
221+
222+
if mode == 'interpolate':
223+
C1_mesh = x_j[:, 6:7]
224+
C2_mesh = x_j[:, 7:8]
225+
weight = torch.exp(-dist / 0.01).unsqueeze(1)
226+
return torch.cat([C1_mesh * weight, C2_mesh * weight, weight], dim=1)
227+
228+
elif mode == 'fp':
229+
fields_i = x_i[:, 6:8]
230+
fields_j = x_j[:, 6:8]
231+
232+
dC1 = fields_j[:, 0:1] - fields_i[:, 0:1]
233+
dC2 = fields_j[:, 1:2] - fields_i[:, 1:2]
234+
235+
kernel = torch.exp(-dist / 0.05)
236+
dir_norm = d_pos / dist_safe.unsqueeze(1)
237+
domain_scale = 32.0
238+
grad_C1 = (dC1 * kernel.unsqueeze(1)) / (dist_safe.unsqueeze(1) * domain_scale)
239+
grad_C2 = (dC2 * kernel.unsqueeze(1)) / (dist_safe.unsqueeze(1) * domain_scale)
240+
241+
if parameters_i is not None:
242+
M1 = parameters_i[:, 0:1]
243+
M2 = parameters_i[:, 1:2]
244+
else:
245+
M1 = self.M1
246+
M2 = self.M2
247+
248+
velocity_raw = (M1 * grad_C1 + M2 * grad_C2) * dir_norm
249+
250+
# 1. Durotaxis: amplify velocity at steep gradients (Lo et al. 2000)
251+
if self.grad_amp_alpha > 0:
252+
grad_mag = torch.abs(grad_C1)
253+
grad_mag_clamped = torch.clamp(grad_mag, max=1.0)
254+
amp_factor = 1.0 + self.grad_amp_alpha * grad_mag_clamped
255+
velocity_raw = velocity_raw * amp_factor
256+
257+
# 2. Concentration-threshold coupling on C1 (Wolpert 1969)
258+
if self.ctc_threshold > 0:
259+
C1_local = fields_i[:, 0:1]
260+
C2_local = fields_i[:, 1:2]
261+
A_ref = self.A_ref
262+
base_T = self.ctc_threshold * A_ref
263+
steepness = 3.0
264+
265+
# Per-type thresholds when multi-type + cross_type_factor > 0
266+
if (parameters_i is not None and self.cross_type_factor > 0
267+
and x_i.numel() > 0):
268+
type_i = x_i[:, 1 + 2*self.dimension].long()
269+
n_types = type_i.max().item() + 1 if type_i.numel() > 0 else 1
270+
mean_idx = (n_types - 1) / 2.0
271+
type_offset = self.cross_type_factor * (type_i.float() - mean_idx)
272+
T = base_T * (1.0 + type_offset.unsqueeze(1))
273+
else:
274+
T = base_T
275+
276+
# Core CTC on C1 — PRESERVED exactly as DampedCTC
277+
sign_factor = -torch.tanh(steepness * (C1_local - T) / (A_ref + 1e-6))
278+
279+
# 3. C2 modulation (Green & Sharpe 2015)
280+
# C2 provides a second channel of positional information
281+
# When beta > 0: C2 near T2 amplifies CTC response
282+
# When beta < 0: C2 near T2 attenuates CTC response
283+
c2_beta_val = self.c2_beta
284+
if hasattr(c2_beta_val, 'item'):
285+
c2_beta_check = c2_beta_val.item()
286+
else:
287+
c2_beta_check = float(c2_beta_val)
288+
289+
if abs(c2_beta_check) > 1e-6:
290+
T2 = self.c2_threshold * A_ref
291+
c2_steep = self.c2_steepness
292+
if hasattr(c2_steep, 'item'):
293+
c2_steep = c2_steep
294+
c2_mod = 1.0 + c2_beta_val * torch.tanh(c2_steep * (C2_local - T2) / (A_ref + 1e-6))
295+
sign_factor = sign_factor * c2_mod
296+
297+
velocity_raw = velocity_raw * sign_factor
298+
299+
# 4. Velocity-dependent drag (Tranquillo 1987)
300+
fp_drag_val = self.fp_drag
301+
if hasattr(fp_drag_val, 'item'):
302+
fp_drag_check = fp_drag_val.item()
303+
else:
304+
fp_drag_check = float(fp_drag_val)
305+
306+
if fp_drag_check > 0:
307+
vel_i = x_i[:, self.dimension+1:2*self.dimension+1]
308+
speed = torch.sqrt(torch.sum(vel_i**2, dim=1, keepdim=True) + 1e-10)
309+
drag = 1.0 / (1.0 + fp_drag_check * speed)
310+
velocity_raw = velocity_raw * drag
311+
312+
return velocity_raw
313+
314+
elif mode == 'pf':
315+
weights = torch.exp(-dist**2 / (2 * (self.influence_radius/3)**2))
316+
317+
if parameters_i is not None:
318+
consumption = parameters_i[:, 2]
319+
production = parameters_i[:, 3]
320+
else:
321+
consumption = self.consumption_rate
322+
production = self.production_rate
323+
324+
field_updates = torch.zeros((pos_i.size(0), 2), device=pos_i.device)
325+
field_updates[:, 0] = -consumption * weights
326+
field_updates[:, 1] = production * weights
327+
return field_updates
328+
329+
else: # mode == 'pp'
330+
if parameters_i is not None:
331+
p1 = parameters_i[:, 4]
332+
p2 = parameters_i[:, 5]
333+
p3 = parameters_i[:, 6]
334+
p4 = parameters_i[:, 7]
335+
336+
f = (p1 * torch.exp(-dist ** (2 * p2) / (2 * self.sigma ** 2))
337+
- p3 * torch.exp(-dist ** (2 * p4) / (2 * self.sigma ** 2)))
338+
339+
forces = f[:, None] * d_pos / dist_safe.unsqueeze(1)
340+
else:
341+
forces = torch.zeros_like(pos_i)
342+
in_range = dist < self.repulsion_range
343+
if in_range.any():
344+
dir_norm = d_pos / dist_safe.unsqueeze(1)
345+
repulsion_mag = self.repulsion_strength * torch.exp(
346+
-5.0 * dist[in_range] / self.repulsion_range
347+
)
348+
forces[in_range] = -dir_norm[in_range] * repulsion_mag.unsqueeze(1)
349+
350+
# Field-dependent pp damping (Painter & Hillen 2002)
351+
if self.pp_damping > 0 and self.ctc_threshold > 0:
352+
C1_local = x_i[:, 6:7].squeeze(1)
353+
A_ref = self.A_ref
354+
base_T = self.ctc_threshold * A_ref
355+
356+
if (parameters_i is not None and self.cross_type_factor > 0
357+
and x_i.numel() > 0):
358+
type_i = x_i[:, 1 + 2*self.dimension].long()
359+
n_types = type_i.max().item() + 1 if type_i.numel() > 0 else 1
360+
mean_idx = (n_types - 1) / 2.0
361+
type_offset = self.cross_type_factor * (type_i.float() - mean_idx)
362+
T_local = base_T * (1.0 + type_offset)
363+
else:
364+
T_local = base_T
365+
366+
width = self.pp_damping_width * A_ref
367+
deviation = (C1_local - T_local)
368+
damping_factor = 1.0 - self.pp_damping * torch.exp(-deviation**2 / (2 * width**2 + 1e-8))
369+
forces = forces * damping_factor.unsqueeze(1)
370+
371+
return forces
372+
373+
def update(self, aggr_out, mode=None):
374+
if mode == 'interpolate':
375+
C1_weighted = aggr_out[:, 0:1]
376+
C2_weighted = aggr_out[:, 1:2]
377+
weight_sum = aggr_out[:, 2:3]
378+
weight_sum = torch.clamp(weight_sum, min=1e-10)
379+
return torch.cat([C1_weighted / weight_sum, C2_weighted / weight_sum], dim=1)
380+
else:
381+
return aggr_out

0 commit comments

Comments
 (0)