Skip to content

Commit 2bf45e1

Browse files
committed
[Iter 4] Code modification in PDE_D_FieldModulated.py
[Automated commit by Claude]
1 parent d7d13e0 commit 2bf45e1

1 file changed

Lines changed: 269 additions & 0 deletions

File tree

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
import torch
2+
import torch_geometric as pyg
3+
import torch_geometric.utils as pyg_utils
4+
from ParticleGraph.utils import to_numpy
5+
6+
class PDE_D_FieldModulated(pyg.nn.MessagePassing):
7+
"""
8+
Field-dependent mobility and field-modulated particle-particle adhesion.
9+
10+
Combines two related field-concentration-dependent features:
11+
1. FDM: Mobility depends on local C1 deviation from Brusselator steady state A
12+
2. Field-modulated pp: Adhesion strength scales with local C1 concentration
13+
14+
Literature:
15+
- Hillen, T. & Painter, K. J. (2009) J Math Biol 58:183-217
16+
"A user's guide to PDE models for chemotaxis"
17+
- Hynes, R. O. (2002) Cell 110:673-687
18+
"Integrins: bidirectional, allosteric signaling machines"
19+
- Schwartz, M. A. & Ginsberg, M. H. (2002) Nat Cell Biol 4:E65-E68
20+
"Networks and crosstalk: integrin signaling spreads"
21+
22+
Physics:
23+
FDM (positive alpha): M_eff = M * (1 + alpha * clamp((C1-A)^2/A^2, max=4))
24+
FDM (negative alpha): M_eff = M / (1 + |alpha| * clamp((C1-A)^2/A^2, max=4))
25+
Field-modulated pp: f_eff = f * (1 + alpha * clamp(C1/C1_ref, 0, 2))
26+
27+
Per-type params layout: [M1, M2, consumption, production, ar_p1, ar_p2, ar_p3, ar_p4]
28+
"""
29+
30+
PARAMS_DOC = {
31+
"model_name": "FieldModulated",
32+
"literature": "Hillen & Painter (2009) J Math Biol 58:183-217; Hynes (2002) Cell 110:673-687",
33+
"description": "Field-dependent mobility + field-modulated particle-particle adhesion",
34+
"equations": {
35+
"field_to_particle": "v = M * fdm_factor * (grad_C1 + grad_C2) * dir; fdm_factor depends on (C1-A)^2/A^2",
36+
"particle_to_field": "dC1 = -consumption * w(r), dC2 = production * w(r)",
37+
"particle_to_particle": "f = AR_force * (1 + pp_field_mod * clamp(C1/C1_ref, 0, 2))"
38+
},
39+
"params_mesh": [
40+
{
41+
"row": 0, "description": "C1 field parameters (shared with mesh model) + FDM control",
42+
"slots": [
43+
{"index": 0, "name": "D1", "description": "Diffusion coeff for C1 (mesh model)", "typical_range": [0.01, 0.5]},
44+
{"index": 1, "name": "Da_c", "description": "Damkohler number (mesh model)", "typical_range": [1.0, 50.0]},
45+
{"index": 2, "name": "A", "description": "Brusselator param A (mesh model, also FDM reference)", "typical_range": [0.5, 5.0]},
46+
{"index": 3, "name": "B", "description": "Brusselator param B (mesh model)", "typical_range": [1.0, 10.0]},
47+
{"index": 4, "name": "mu", "description": "Morphological parameter (mesh model)", "typical_range": [0.01, 0.1]},
48+
{"index": 5, "name": "M1", "description": "Mobility coefficient for C1 gradients", "typical_range": [-16, 16]},
49+
{"index": 6, "name": "fdm_alpha", "description": "Field-dependent mobility (0=off, >0=faster at peaks, <0=slower at peaks)", "typical_range": [-2.0, 2.0]}
50+
]
51+
},
52+
{
53+
"row": 1, "description": "C2 field parameters",
54+
"slots": [
55+
{"index": 0, "name": "D2", "description": "Diffusion coeff for C2 (mesh model)", "typical_range": [0.1, 1.0]},
56+
{"index": 1, "name": "M2", "description": "Mobility coefficient for C2 gradients", "typical_range": [-16, 16]}
57+
]
58+
},
59+
{
60+
"row": 2, "description": "Particle-field coupling + field-modulated pp control",
61+
"slots": [
62+
{"index": 0, "name": "Pe", "description": "Peclet number", "typical_range": [0.5, 2.0]},
63+
{"index": 1, "name": "consumption", "description": "Particle consumption rate of C1", "typical_range": [10, 200]},
64+
{"index": 2, "name": "production", "description": "Particle production rate of C2", "typical_range": [-200, -10]},
65+
{"index": 3, "name": "influence_radius", "description": "Gaussian influence radius for pf coupling", "typical_range": [0.01, 0.1]},
66+
{"index": 4, "name": "unused", "description": "Unused (pad)", "typical_range": [0.0, 0.0]},
67+
{"index": 5, "name": "unused", "description": "Unused (pad)", "typical_range": [0.0, 0.0]},
68+
{"index": 6, "name": "pp_field_mod", "description": "Field-modulated pp adhesion (0=off, >0=stronger at peaks)", "typical_range": [0.0, 1.0]}
69+
]
70+
}
71+
],
72+
"width_constraint": "ALL rows of params_mesh MUST have same number of columns (7). Pad shorter rows.",
73+
"particle_params": {
74+
"description": "Per-type params from simulation.params (one row per n_particle_types)",
75+
"slots": [
76+
{"index": 0, "name": "M1", "description": "Per-type mobility for C1"},
77+
{"index": 1, "name": "M2", "description": "Per-type mobility for C2"},
78+
{"index": 2, "name": "consumption", "description": "Per-type consumption rate"},
79+
{"index": 3, "name": "production", "description": "Per-type production rate"},
80+
{"index": 4, "name": "ar_p1", "description": "Attraction strength"},
81+
{"index": 5, "name": "ar_p2", "description": "Attraction exponent"},
82+
{"index": 6, "name": "ar_p3", "description": "Repulsion strength"},
83+
{"index": 7, "name": "ar_p4", "description": "Repulsion exponent"}
84+
]
85+
}
86+
}
87+
88+
def __init__(self, aggr_type='mean', p=None, particle_params=None, bc_dpos=None, dimension=2, sigma=0.005):
89+
super(PDE_D_FieldModulated, self).__init__(aggr=aggr_type)
90+
91+
self.p = p
92+
self.particle_params = particle_params
93+
self.bc_dpos = bc_dpos
94+
self.dimension = dimension
95+
self.sigma = sigma
96+
97+
self.M1 = p[0, 5]
98+
self.M2 = p[1, 1]
99+
self.consumption_rate = p[2, 1]
100+
self.production_rate = p[2, 2]
101+
self.influence_radius = p[2, 3]
102+
self.Pe = p[2, 0]
103+
self.repulsion_strength = 50
104+
self.repulsion_range = 0.04
105+
106+
# FDM: field-dependent mobility
107+
self.fdm_alpha = p[0, 6] if p.shape[1] > 6 else 0.0
108+
self.A_ref = p[0, 2]
109+
110+
# Field-modulated pp adhesion
111+
if p.shape[0] > 2 and p.shape[1] > 6:
112+
self.pp_field_mod = p[2, 6]
113+
else:
114+
self.pp_field_mod = 0.0
115+
116+
print(f"initialized PDE_D_FieldModulated with parameters:")
117+
print(f" mobility: M1={self.M1.item()}, M2={self.M2.item()}")
118+
fdm_val = self.fdm_alpha.item() if hasattr(self.fdm_alpha, 'item') else self.fdm_alpha
119+
print(f" fdm_alpha={fdm_val:.3f} (M_eff depends on (C1-A)^2/A^2, Hillen & Painter 2009)")
120+
ppfm_val = self.pp_field_mod.item() if hasattr(self.pp_field_mod, 'item') else self.pp_field_mod
121+
print(f" pp_field_mod={ppfm_val:.3f} (f_eff = f*(1+alpha*C1_norm), Hynes 2002)")
122+
print(f" Pe={self.Pe.item():.3f}, sigma={self.sigma}")
123+
print(f" particle->field: consumption={self.consumption_rate.item()}, production={self.production_rate.item()}, influence_radius={self.influence_radius.item():.3f}")
124+
if particle_params is not None:
125+
print(f" multi-type support: {particle_params.shape[0]} particle types")
126+
127+
def forward(self, data, direction='fp'):
128+
x, edge_index = data.x, data.edge_index
129+
edge_index, _ = pyg_utils.remove_self_loops(edge_index)
130+
131+
if self.particle_params is not None:
132+
particle_type = x[:, 1 + 2*self.dimension].long()
133+
max_type = particle_type.max().item()
134+
n_param_rows = self.particle_params.shape[0]
135+
if max_type >= n_param_rows:
136+
raise ValueError(
137+
f"PDE_D_FieldModulated: particle_params has {n_param_rows} rows but found "
138+
f"particle type {max_type}. Need {max_type + 1} rows in simulation.params."
139+
)
140+
parameters = self.particle_params[to_numpy(particle_type), :]
141+
else:
142+
parameters = None
143+
144+
if direction == 'interpolate':
145+
result = self.propagate(edge_index, x=x, mode='interpolate', parameters=parameters)
146+
pos = x[:, 1:self.dimension+1]
147+
in_box = ((pos >= 0) & (pos <= 1)).all(dim=1, keepdim=True)
148+
result = result * in_box.float()
149+
return result
150+
elif direction == 'fp':
151+
result = self.propagate(edge_index, x=x, mode='fp', parameters=parameters)
152+
pos = x[:, 1:self.dimension+1]
153+
in_box = ((pos >= 0) & (pos <= 1)).all(dim=1, keepdim=True)
154+
result = result * in_box.float()
155+
return result
156+
elif direction == 'pf':
157+
result = self.propagate(edge_index, x=x, mode='pf', parameters=parameters)
158+
return result
159+
else:
160+
result = self.propagate(edge_index, x=x, mode='pp', parameters=parameters)
161+
return result
162+
163+
def message(self, edge_index_i, edge_index_j, x_i, x_j, mode=None, parameters_i=None):
164+
pos_i = x_i[:, 1:self.dimension+1]
165+
pos_j = x_j[:, 1:self.dimension+1]
166+
167+
d_pos = self.bc_dpos(pos_j - pos_i)
168+
dist = torch.sqrt(torch.sum(d_pos**2, dim=1))
169+
dist_safe = torch.clamp(dist, min=1e-6)
170+
171+
if mode == 'interpolate':
172+
C1_mesh = x_j[:, 6:7]
173+
C2_mesh = x_j[:, 7:8]
174+
weight = torch.exp(-dist / 0.01).unsqueeze(1)
175+
return torch.cat([C1_mesh * weight, C2_mesh * weight, weight], dim=1)
176+
177+
elif mode == 'fp':
178+
fields_i = x_i[:, 6:8]
179+
fields_j = x_j[:, 6:8]
180+
181+
dC1 = fields_j[:, 0:1] - fields_i[:, 0:1]
182+
dC2 = fields_j[:, 1:2] - fields_i[:, 1:2]
183+
184+
kernel = torch.exp(-dist / 0.05)
185+
dir_norm = d_pos / dist_safe.unsqueeze(1)
186+
domain_scale = 32.0
187+
grad_C1 = (dC1 * kernel.unsqueeze(1)) / (dist_safe.unsqueeze(1) * domain_scale)
188+
grad_C2 = (dC2 * kernel.unsqueeze(1)) / (dist_safe.unsqueeze(1) * domain_scale)
189+
190+
if parameters_i is not None:
191+
M1 = parameters_i[:, 0:1]
192+
M2 = parameters_i[:, 1:2]
193+
else:
194+
M1 = self.M1
195+
M2 = self.M2
196+
197+
velocity_raw = (M1 * grad_C1 + M2 * grad_C2) * dir_norm
198+
199+
# Field-dependent mobility (FDM)
200+
if self.fdm_alpha != 0:
201+
C1_local = fields_i[:, 0:1]
202+
A_ref = self.A_ref
203+
deviation_sq = (C1_local - A_ref) ** 2 / (A_ref ** 2 + 1e-6)
204+
deviation_sq = torch.clamp(deviation_sq, max=4.0)
205+
206+
if self.fdm_alpha > 0:
207+
fdm_factor = 1.0 + self.fdm_alpha * deviation_sq
208+
else:
209+
fdm_factor = 1.0 / (1.0 + torch.abs(self.fdm_alpha) * deviation_sq)
210+
211+
velocity_raw = velocity_raw * fdm_factor
212+
213+
return velocity_raw
214+
215+
elif mode == 'pf':
216+
weights = torch.exp(-dist**2 / (2 * (self.influence_radius/3)**2))
217+
218+
if parameters_i is not None:
219+
consumption = parameters_i[:, 2]
220+
production = parameters_i[:, 3]
221+
else:
222+
consumption = self.consumption_rate
223+
production = self.production_rate
224+
225+
field_updates = torch.zeros((pos_i.size(0), 2), device=pos_i.device)
226+
field_updates[:, 0] = -consumption * weights
227+
field_updates[:, 1] = production * weights
228+
return field_updates
229+
230+
else: # mode == 'pp'
231+
if parameters_i is not None:
232+
p1 = parameters_i[:, 4]
233+
p2 = parameters_i[:, 5]
234+
p3 = parameters_i[:, 6]
235+
p4 = parameters_i[:, 7]
236+
237+
f = (p1 * torch.exp(-dist ** (2 * p2) / (2 * self.sigma ** 2))
238+
- p3 * torch.exp(-dist ** (2 * p4) / (2 * self.sigma ** 2)))
239+
240+
# Field-modulated pp adhesion
241+
if self.pp_field_mod > 0:
242+
C1_local = x_i[:, 6]
243+
C1_ref = torch.clamp(torch.abs(C1_local).mean(), min=1.0)
244+
C1_norm = torch.clamp(C1_local / C1_ref, min=0.0, max=2.0)
245+
field_factor = 1.0 + self.pp_field_mod * C1_norm
246+
f = f * field_factor
247+
248+
forces = f[:, None] * d_pos / dist_safe.unsqueeze(1)
249+
else:
250+
forces = torch.zeros_like(pos_i)
251+
in_range = dist < self.repulsion_range
252+
if in_range.any():
253+
dir_norm = d_pos / dist_safe.unsqueeze(1)
254+
repulsion_mag = self.repulsion_strength * torch.exp(
255+
-5.0 * dist[in_range] / self.repulsion_range
256+
)
257+
forces[in_range] = -dir_norm[in_range] * repulsion_mag.unsqueeze(1)
258+
259+
return forces
260+
261+
def update(self, aggr_out, mode=None):
262+
if mode == 'interpolate':
263+
C1_weighted = aggr_out[:, 0:1]
264+
C2_weighted = aggr_out[:, 1:2]
265+
weight_sum = aggr_out[:, 2:3]
266+
weight_sum = torch.clamp(weight_sum, min=1e-10)
267+
return torch.cat([C1_weighted / weight_sum, C2_weighted / weight_sum], dim=1)
268+
else:
269+
return aggr_out

0 commit comments

Comments
 (0)