Skip to content

Commit 23dd6f4

Browse files
committed
[Iter 320] Code modification in PDE_D_DensityDragCIL.py
[Automated commit by Claude]
1 parent 34fb391 commit 23dd6f4

1 file changed

Lines changed: 308 additions & 0 deletions

File tree

Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
import torch
2+
import torch_geometric as pyg
3+
import torch_geometric.utils as pyg_utils
4+
from ParticleGraph.utils import to_numpy
5+
6+
class PDE_D_DensityDragCIL(pyg.nn.MessagePassing):
7+
"""
8+
Density-dependent mobility (CIL) + velocity-dependent fp drag.
9+
10+
Combines two proven convergence mechanisms that have NEVER been used together:
11+
1. Contact Inhibition of Locomotion (Mayor & Carmona-Fontaine 2010):
12+
Particles reduce mobility when local density is high.
13+
f(rho) = 1 / (1 + (rho/rho_0)^n)
14+
2. Velocity-dependent fp drag (Tranquillo & Lauffenburger 1987):
15+
Fast-moving particles reduce chemotactic sensitivity.
16+
drag = 1 / (1 + fp_drag * |v| / v_ref)
17+
18+
Rationale: CIL provides density-based self-limiting aggregation, but
19+
clusters can oscillate as particles overshoot the density equilibrium.
20+
Adding velocity drag penalizes fast oscillatory motion, potentially
21+
stabilizing the CIL mechanism and improving convergence.
22+
23+
Physics:
24+
1. fp: v = M * f(rho) * nabla_C / (1 + fp_drag * |vel|/v_ref)
25+
- f(rho) = 1/(1+(rho/rho_0)^n) is Hill function CIL
26+
- drag attenuates response at high speed
27+
2. pf: Standard consumption/production
28+
3. pp: Standard attraction-repulsion (density computed from pp graph)
29+
30+
Literature:
31+
- Mayor, R. & Carmona-Fontaine, C. (2010) Trends in Cell Biology 20:319-328
32+
"Keeping in touch with contact inhibition of locomotion"
33+
- Cates, M. E. & Tailleur, J. (2015) ARCMP 6:219-244
34+
"Motility-induced phase separation"
35+
- Tranquillo, R. T. & Lauffenburger, D. A. (1987) J Math Biol 25:229-262
36+
"Stochastic model of leukocyte chemosensory movement"
37+
38+
Per-type params layout: [M1, M2, consumption, production, ar_p1, ar_p2, ar_p3, ar_p4]
39+
"""
40+
41+
PARAMS_DOC = {
42+
"model_name": "DensityDragCIL",
43+
"literature": "Mayor & Carmona-Fontaine (2010); Cates & Tailleur (2015); Tranquillo (1987)",
44+
"description": "Density-dependent CIL + velocity drag: density limits aggregation, speed limits oscillation",
45+
"equations": {
46+
"field_to_particle": "v = M * f(rho) * nabla_C / (1 + fp_drag * |vel|/v_ref)",
47+
"density_function": "f(rho) = 1 / (1 + (rho/rho_0)^n)",
48+
"particle_to_field": "dC1 = -consumption * w(r), dC2 = production * w(r)",
49+
"particle_to_particle": "f = (p1*exp(-d^(2p2)/(2sigma^2)) - p3*exp(-d^(2p4)/(2sigma^2))) * dir"
50+
},
51+
"params_mesh": [
52+
{
53+
"row": 0, "description": "C1 field parameters",
54+
"slots": [
55+
{"index": 0, "name": "D1", "description": "Diffusion coeff for C1"},
56+
{"index": 1, "name": "Da_c", "description": "Damkohler number"},
57+
{"index": 2, "name": "A", "description": "Brusselator A"},
58+
{"index": 3, "name": "B", "description": "Brusselator B"},
59+
{"index": 4, "name": "mu", "description": "Morphological param"},
60+
{"index": 5, "name": "M1", "description": "Mobility for C1 gradients"},
61+
{"index": 6, "name": "unused_6", "description": "Unused (pad)"},
62+
{"index": 7, "name": "unused_7", "description": "Unused (pad)"}
63+
]
64+
},
65+
{
66+
"row": 1, "description": "C2 field parameters",
67+
"slots": [
68+
{"index": 0, "name": "D2", "description": "Diffusion coeff for C2"},
69+
{"index": 1, "name": "M2", "description": "Mobility for C2 gradients"},
70+
{"index": 2, "name": "unused_2", "description": "Unused (pad)"},
71+
{"index": 3, "name": "unused_3", "description": "Unused (pad)"},
72+
{"index": 4, "name": "unused_4", "description": "Unused (pad)"},
73+
{"index": 5, "name": "unused_5", "description": "Unused (pad)"},
74+
{"index": 6, "name": "unused_6", "description": "Unused (pad)"},
75+
{"index": 7, "name": "unused_7", "description": "Unused (pad)"}
76+
]
77+
},
78+
{
79+
"row": 2, "description": "Particle-field coupling + CIL + drag params",
80+
"slots": [
81+
{"index": 0, "name": "Pe", "description": "Peclet number"},
82+
{"index": 1, "name": "consumption", "description": "Consumption rate of C1"},
83+
{"index": 2, "name": "production", "description": "Production rate of C2"},
84+
{"index": 3, "name": "influence_radius", "description": "Gaussian pf influence radius"},
85+
{"index": 4, "name": "rho_0", "description": "CIL critical density threshold"},
86+
{"index": 5, "name": "hill_n", "description": "CIL Hill coefficient"},
87+
{"index": 6, "name": "fp_drag", "description": "Velocity-dependent fp drag (0=off)"},
88+
{"index": 7, "name": "unused_7", "description": "Unused (pad)"}
89+
]
90+
}
91+
],
92+
"width_constraint": "ALL rows of params_mesh MUST have same number of columns (8)."
93+
}
94+
95+
def __init__(self, aggr_type='mean', p=None, particle_params=None, bc_dpos=None, dimension=2, sigma=0.005):
96+
super(PDE_D_DensityDragCIL, self).__init__(aggr=aggr_type)
97+
98+
self.p = p
99+
self.particle_params = particle_params
100+
self.bc_dpos = bc_dpos
101+
self.dimension = dimension
102+
self.sigma = sigma
103+
104+
# Global parameters from mesh
105+
self.M1 = p[0, 5]
106+
self.M2 = p[1, 1]
107+
self.consumption_rate = p[2, 1]
108+
self.production_rate = p[2, 2]
109+
self.influence_radius = p[2, 3]
110+
self.Pe = p[2, 0]
111+
self.repulsion_strength = 50
112+
self.repulsion_range = 0.04
113+
114+
# CIL parameters (Mayor & Carmona-Fontaine 2010)
115+
self.rho_0 = p[2, 4] if p.shape[1] > 4 and p[2, 4] != 0 else 34.0
116+
self.hill_n = p[2, 5] if p.shape[1] > 5 and p[2, 5] != 0 else 2.0
117+
self.sensing_radius = 0.05
118+
119+
# Velocity-dependent fp drag (Tranquillo & Lauffenburger 1987)
120+
self.fp_drag = p[2, 6] if p.shape[1] > 6 and p[2, 6] != 0 else 0.0
121+
self.v_ref = 0.01
122+
123+
# Convert to proper tensors if needed
124+
if not isinstance(self.rho_0, torch.Tensor):
125+
self.rho_0 = torch.tensor(float(self.rho_0), device=p.device)
126+
if not isinstance(self.hill_n, torch.Tensor):
127+
self.hill_n = torch.tensor(float(self.hill_n), device=p.device)
128+
if not isinstance(self.fp_drag, torch.Tensor):
129+
self.fp_drag = torch.tensor(float(self.fp_drag), device=p.device)
130+
131+
# Storage for local density (computed in pp pass, used in fp pass)
132+
self.local_density = None
133+
134+
# Report configuration
135+
rho0_val = self.rho_0.item() if hasattr(self.rho_0, 'item') else self.rho_0
136+
hill_val = self.hill_n.item() if hasattr(self.hill_n, 'item') else self.hill_n
137+
drag_val = self.fp_drag.item() if hasattr(self.fp_drag, 'item') else self.fp_drag
138+
print(f"initialized PDE_D_DensityDragCIL with parameters:")
139+
print(f" mobility: M1={self.M1.item()}, M2={self.M2.item()}")
140+
print(f" CIL: rho_0={rho0_val}, hill_n={hill_val}, sensing_radius={self.sensing_radius} (Mayor 2010)")
141+
print(f" fp_drag={drag_val:.3f}, v_ref={self.v_ref:.4f} (Tranquillo 1987)")
142+
print(f" Pe={self.Pe.item():.3f}, sigma={self.sigma}")
143+
print(f" particle->field: consumption={self.consumption_rate.item()}, production={self.production_rate.item()}, influence_radius={self.influence_radius.item():.3f}")
144+
if particle_params is not None:
145+
print(f" multi-type support: {particle_params.shape[0]} particle types")
146+
147+
def forward(self, data, direction='fp'):
148+
x, edge_index = data.x, data.edge_index
149+
edge_index, _ = pyg_utils.remove_self_loops(edge_index)
150+
151+
if self.particle_params is not None:
152+
particle_type = x[:, 1 + 2*self.dimension].long()
153+
max_type = particle_type.max().item()
154+
n_param_rows = self.particle_params.shape[0]
155+
if max_type >= n_param_rows:
156+
raise ValueError(
157+
f"PDE_D_DensityDragCIL: particle_params has {n_param_rows} rows but found "
158+
f"particle type {max_type}. Need {max_type + 1} rows in simulation.params."
159+
)
160+
parameters = self.particle_params[to_numpy(particle_type), :]
161+
else:
162+
parameters = None
163+
164+
if direction == 'interpolate':
165+
result = self.propagate(edge_index, x=x, mode='interpolate', parameters=parameters)
166+
pos = x[:, 1:self.dimension+1]
167+
in_box = ((pos >= 0) & (pos <= 1)).all(dim=1, keepdim=True)
168+
result = result * in_box.float()
169+
return result
170+
elif direction == 'fp':
171+
result = self.propagate(edge_index, x=x, mode='fp', parameters=parameters)
172+
173+
# Apply density-dependent + velocity drag modulation
174+
if self.local_density is not None:
175+
n_total = x.size(0)
176+
n_particles = self.local_density.size(0)
177+
n_nodes = n_total - n_particles
178+
179+
# CIL Hill function modulation
180+
ratio = self.local_density / self.rho_0
181+
cil_modulation = 1.0 / (1.0 + ratio ** self.hill_n)
182+
183+
# Apply to particle portion only
184+
mod_full = torch.ones(n_total, 1, device=x.device)
185+
mod_full[n_nodes:, 0] = cil_modulation
186+
result = result * mod_full
187+
188+
# Apply velocity-dependent drag at the aggregate level
189+
if self.fp_drag > 0:
190+
n_total = x.size(0)
191+
vel = x[:, 1+self.dimension:1+2*self.dimension]
192+
speed = torch.sqrt(torch.sum(vel**2, dim=1, keepdim=True))
193+
drag_factor = 1.0 / (1.0 + self.fp_drag * speed / self.v_ref)
194+
result = result * drag_factor
195+
196+
pos = x[:, 1:self.dimension+1]
197+
in_box = ((pos >= 0) & (pos <= 1)).all(dim=1, keepdim=True)
198+
result = result * in_box.float()
199+
return result
200+
elif direction == 'pf':
201+
result = self.propagate(edge_index, x=x, mode='pf', parameters=parameters)
202+
return result
203+
else: # direction == 'pp'
204+
self._compute_local_density(x, edge_index)
205+
result = self.propagate(edge_index, x=x, mode='pp', parameters=parameters)
206+
return result
207+
208+
def _compute_local_density(self, x, edge_index):
209+
"""Count particle neighbors within sensing_radius for CIL."""
210+
n_particles = x.size(0)
211+
target_nodes = edge_index[1]
212+
213+
pos_i = x[edge_index[1], 1:self.dimension+1]
214+
pos_j = x[edge_index[0], 1:self.dimension+1]
215+
d_pos = self.bc_dpos(pos_j - pos_i)
216+
dist = torch.sqrt(torch.sum(d_pos**2, dim=1))
217+
218+
within_radius = dist < self.sensing_radius
219+
counts = torch.zeros(n_particles, device=x.device)
220+
counts.scatter_add_(0, target_nodes[within_radius],
221+
torch.ones(within_radius.sum(), device=x.device))
222+
223+
self.local_density = counts
224+
225+
def message(self, edge_index_i, edge_index_j, x_i, x_j, mode=None, parameters_i=None):
226+
pos_i = x_i[:, 1:self.dimension+1]
227+
pos_j = x_j[:, 1:self.dimension+1]
228+
229+
d_pos = self.bc_dpos(pos_j - pos_i)
230+
dist = torch.sqrt(torch.sum(d_pos**2, dim=1))
231+
dist_safe = torch.clamp(dist, min=1e-6)
232+
233+
if mode == 'interpolate':
234+
C1_mesh = x_j[:, 6:7]
235+
C2_mesh = x_j[:, 7:8]
236+
weight = torch.exp(-dist / 0.01).unsqueeze(1)
237+
return torch.cat([C1_mesh * weight, C2_mesh * weight, weight], dim=1)
238+
239+
elif mode == 'fp':
240+
fields_i = x_i[:, 6:8]
241+
fields_j = x_j[:, 6:8]
242+
243+
dC1 = fields_j[:, 0:1] - fields_i[:, 0:1]
244+
dC2 = fields_j[:, 1:2] - fields_i[:, 1:2]
245+
246+
kernel = torch.exp(-dist / 0.05)
247+
dir_norm = d_pos / dist_safe.unsqueeze(1)
248+
domain_scale = 32.0
249+
grad_C1 = (dC1 * kernel.unsqueeze(1)) / (dist_safe.unsqueeze(1) * domain_scale)
250+
grad_C2 = (dC2 * kernel.unsqueeze(1)) / (dist_safe.unsqueeze(1) * domain_scale)
251+
252+
if parameters_i is not None:
253+
M1 = parameters_i[:, 0:1]
254+
M2 = parameters_i[:, 1:2]
255+
else:
256+
M1 = self.M1
257+
M2 = self.M2
258+
259+
velocities = (M1 * grad_C1 + M2 * grad_C2) * dir_norm
260+
return velocities
261+
262+
elif mode == 'pf':
263+
weights = torch.exp(-dist**2 / (2 * (self.influence_radius/3)**2))
264+
265+
if parameters_i is not None:
266+
consumption = parameters_i[:, 2]
267+
production = parameters_i[:, 3]
268+
else:
269+
consumption = self.consumption_rate
270+
production = self.production_rate
271+
272+
field_updates = torch.zeros((pos_i.size(0), 2), device=pos_i.device)
273+
field_updates[:, 0] = -consumption * weights
274+
field_updates[:, 1] = production * weights
275+
return field_updates
276+
277+
else: # mode == 'pp'
278+
if parameters_i is not None:
279+
p1 = parameters_i[:, 4]
280+
p2 = parameters_i[:, 5]
281+
p3 = parameters_i[:, 6]
282+
p4 = parameters_i[:, 7]
283+
284+
f = (p1 * torch.exp(-dist ** (2 * p2) / (2 * self.sigma ** 2))
285+
- p3 * torch.exp(-dist ** (2 * p4) / (2 * self.sigma ** 2)))
286+
287+
forces = f[:, None] * d_pos / dist_safe.unsqueeze(1)
288+
else:
289+
forces = torch.zeros_like(pos_i)
290+
in_range = dist < self.repulsion_range
291+
if in_range.any():
292+
dir_norm = d_pos / dist_safe.unsqueeze(1)
293+
repulsion_mag = self.repulsion_strength * torch.exp(
294+
-5.0 * dist[in_range] / self.repulsion_range
295+
)
296+
forces[in_range] = -dir_norm[in_range] * repulsion_mag.unsqueeze(1)
297+
298+
return forces
299+
300+
def update(self, aggr_out, mode=None):
301+
if mode == 'interpolate':
302+
C1_weighted = aggr_out[:, 0:1]
303+
C2_weighted = aggr_out[:, 1:2]
304+
weight_sum = aggr_out[:, 2:3]
305+
weight_sum = torch.clamp(weight_sum, min=1e-10)
306+
return torch.cat([C1_weighted / weight_sum, C2_weighted / weight_sum], dim=1)
307+
else:
308+
return aggr_out

0 commit comments

Comments
 (0)