Skip to content

Commit b89a411

Browse files
Update octree.h
1 parent e746d72 commit b89a411

1 file changed

Lines changed: 48 additions & 159 deletions

File tree

src/gravity/octree.h

Lines changed: 48 additions & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -16,228 +16,117 @@
1616
#include "dt/softening.h"
1717

1818
struct Octree {
19-
// monopole
20-
real cx, cy, cz; // center of mass
21-
real m; // total mass
22-
23-
// node geometry
24-
real x, y, z; // center of node
19+
real cx, cy, cz; // COM
20+
real m; // mass
21+
real x, y, z; // node center
2522
real size; // half-width
2623
bool leaf = true;
2724
Particle* body = nullptr;
2825
Octree* child[8] = { nullptr };
2926

30-
// symmetric quadrupole tensor (6 independent components)
27+
// Quadrupole tensor
3128
real Qxx = 0, Qyy = 0, Qzz = 0;
3229
real Qxy = 0, Qxz = 0, Qyz = 0;
3330

34-
Octree(real X, real Y, real Z, real S)
35-
: cx(0), cy(0), cz(0), m(0),
36-
x(X), y(Y), z(Z), size(S) {}
31+
Octree(real X, real Y, real Z, real S) : x(X), y(Y), z(Z), size(S), m(0), cx(0), cy(0), cz(0) {}
3732

38-
~Octree() {
39-
for (auto c : child) {
40-
delete c;
41-
}
42-
}
33+
~Octree() { for (auto c : child) delete c; }
4334

4435
int index(const Particle& p) const {
45-
return (p.x > x) * 1
46-
+ (p.y > y) * 2
47-
+ (p.z > z) * 4;
36+
return (p.x > x) * 1 + (p.y > y) * 2 + (p.z > z) * 4;
4837
}
4938

5039
Octree* createChild(int idx) {
5140
real hs = size * real(0.5);
52-
return new Octree(
53-
x + ((idx & 1) ? hs : -hs),
54-
y + ((idx & 2) ? hs : -hs),
55-
z + ((idx & 4) ? hs : -hs),
56-
hs
57-
);
41+
return new Octree(x + ((idx & 1) ? hs : -hs), y + ((idx & 2) ? hs : -hs), z + ((idx & 4) ? hs : -hs), hs);
5842
}
5943

6044
void insert(Particle* p) {
61-
if (leaf && body == nullptr) {
62-
body = p;
63-
return;
64-
}
65-
45+
if (leaf && body == nullptr) { body = p; return; }
6646
if (leaf) {
6747
leaf = false;
68-
Particle* old = body;
69-
body = nullptr;
48+
Particle* old = body; body = nullptr;
7049
int idx = index(*old);
71-
if (!child[idx]) {
72-
child[idx] = createChild(idx);
73-
}
50+
if (!child[idx]) child[idx] = createChild(idx);
7451
child[idx]->insert(old);
7552
}
76-
7753
int idx = index(*p);
78-
if (!child[idx]) {
79-
child[idx] = createChild(idx);
80-
}
54+
if (!child[idx]) child[idx] = createChild(idx);
8155
child[idx]->insert(p);
8256
}
8357

8458
void computeMass() {
8559
if (leaf) {
86-
if (body) {
87-
m = body->m;
88-
cx = body->x;
89-
cy = body->y;
90-
cz = body->z;
91-
} else {
92-
m = 0;
93-
cx = cy = cz = 0;
94-
}
95-
96-
// single particle → no internal quadrupole
97-
Qxx = Qyy = Qzz = 0;
98-
Qxy = Qxz = Qyz = 0;
60+
if (body) { m = body->m; cx = body->x; cy = body->y; cz = body->z; }
61+
else { m = 0; cx = cy = cz = 0; }
62+
Qxx = Qyy = Qzz = Qxy = Qxz = Qyz = 0;
9963
return;
10064
}
10165

102-
m = 0;
103-
cx = cy = cz = 0;
104-
105-
// first: recurse and accumulate mass + COM
66+
m = 0; cx = cy = cz = 0;
10667
for (auto c : child) {
10768
if (!c) continue;
10869
c->computeMass();
10970
if (c->m == 0) continue;
110-
111-
m += c->m;
112-
cx += c->cx * c->m;
113-
cy += c->cy * c->m;
114-
cz += c->cz * c->m;
71+
m += c->m;
72+
cx += c->cx * c->m; cy += c->cy * c->m; cz += c->cz * c->m;
11573
}
74+
if (m > 0) { cx /= m; cy /= m; cz /= m; }
11675

117-
if (m > 0) {
118-
cx /= m;
119-
cy /= m;
120-
cz /= m;
121-
} else {
122-
cx = cy = cz = 0;
123-
}
124-
125-
// second: build quadrupole from children treated as point masses
126-
Qxx = Qyy = Qzz = 0;
127-
Qxy = Qxz = Qyz = 0;
128-
76+
Qxx = Qyy = Qzz = Qxy = Qxz = Qyz = 0;
12977
for (auto c : child) {
13078
if (!c || c->m == 0) continue;
131-
132-
real rx = c->cx - cx;
133-
real ry = c->cy - cy;
134-
real rz = c->cz - cz;
135-
real r2 = rx * rx + ry * ry + rz * rz;
79+
real rx = c->cx - cx; real ry = c->cy - cy; real rz = c->cz - cz;
80+
// Internal node softening to match force calculation
81+
real r2 = rx * rx + ry * ry + rz * rz + (size * size * real(0.01));
13682
real mchild = c->m;
137-
13883
Qxx += mchild * (3 * rx * rx - r2);
13984
Qyy += mchild * (3 * ry * ry - r2);
14085
Qzz += mchild * (3 * rz * rz - r2);
141-
14286
Qxy += mchild * (3 * rx * ry);
14387
Qxz += mchild * (3 * rx * rz);
14488
Qyz += mchild * (3 * ry * rz);
14589
}
14690
}
14791
};
14892

149-
150-
inline void bhAccel(Octree* node,
151-
const Particle& p,
152-
real theta,
153-
real& ax,
154-
real& ay,
155-
real& az)
156-
{
157-
if (!node || node->m == 0)
158-
return;
159-
160-
// Skip self-force
161-
if (node->leaf && node->body == &p)
162-
return;
93+
inline void bhAccel(Octree* node, const Particle& p, real theta, real& ax, real& ay, real& az) {
94+
if (!node || node->m == 0) return;
95+
if (node->leaf && node->body == &p) return;
16396

16497
constexpr real G = real(1.0);
165-
166-
// Geometric separation
167-
real dx = node->cx - p.x;
168-
real dy = node->cy - p.y;
169-
real dz = node->cz - p.z;
170-
171-
// Physical distance (unsmoothed)
98+
real dx = node->cx - p.x; real dy = node->cy - p.y; real dz = node->cz - p.z;
17299
real r2 = dx*dx + dy*dy + dz*dz;
173-
real dist = std::sqrt(r2 + real(1e-20)); // tiny floor to avoid NaN
100+
real dist = std::sqrt(r2 + real(1e-20));
174101

175-
// Adaptive softening (returns epsilon)
176102
real eps = nextSoftening(node->size, node->m, dist);
103+
if (p.type == 1) eps *= real(2.0);
177104

178-
// Softened distance for force
179105
real r2_soft = r2 + eps*eps;
180106
real dist_soft = std::sqrt(r2_soft);
181107

182-
// BH acceptance criterion (use geometric distance)
183108
if (node->leaf || (node->size / dist) < theta) {
184-
// Monopole term with softened distance
185-
real invDist = real(1.0) / dist_soft;
186-
real invDist2 = invDist * invDist;
187-
real invDist3 = invDist * invDist2;
188-
109+
real invDist = real(1.0) / dist_soft;
110+
real invDist3 = invDist * invDist * invDist;
189111
real fac = G * node->m * invDist3;
190112

191-
real ax_m = dx * fac;
192-
real ay_m = dy * fac;
193-
real az_m = dz * fac;
194-
195-
// Quadrupole: use geometric r (no softening) for shape
196-
real rx = dx;
197-
real ry = dy;
198-
real rz = dz;
199-
real r2_q = rx*rx + ry*ry + rz*rz + real(1e-12); // avoid zero
200-
real r_q = std::sqrt(r2_q);
201-
real invr = real(1.0) / r_q;
202-
real invr2 = invr * invr;
203-
real invr3 = invr * invr2;
204-
real invr5 = invr3 * invr2;
205-
real invr7 = invr5 * invr2;
206-
207-
// q = r_i Q_ij r_j
208-
real q =
209-
node->Qxx * rx * rx +
210-
node->Qyy * ry * ry +
211-
node->Qzz * rz * rz +
212-
2 * (node->Qxy * rx * ry +
213-
node->Qxz * rx * rz +
214-
node->Qyz * ry * rz);
215-
216-
// ∇q = 2 Q r
217-
real Qrx = 2 * (node->Qxx * rx + node->Qxy * ry + node->Qxz * rz);
218-
real Qry = 2 * (node->Qxy * rx + node->Qyy * ry + node->Qyz * rz);
219-
real Qrz = 2 * (node->Qxz * rx + node->Qyz * ry + node->Qzz * rz);
220-
221-
// ∇(r^-5) = -5 r^-7 r
222-
real grad_r5_x = -5 * invr7 * rx;
223-
real grad_r5_y = -5 * invr7 * ry;
224-
real grad_r5_z = -5 * invr7 * rz;
225-
226-
// a_Q = (G/2) [ (∇q) r^-5 + q ∇(r^-5) ]
227-
real ax_q = (G * real(0.5)) * (Qrx * invr5 + q * grad_r5_x);
228-
real ay_q = (G * real(0.5)) * (Qry * invr5 + q * grad_r5_y);
229-
real az_q = (G * real(0.5)) * (Qrz * invr5 + q * grad_r5_z);
230-
231-
ax += ax_m + ax_q;
232-
ay += ay_m + ay_q;
233-
az += az_m + az_q;
113+
ax += dx * fac; ay += dy * fac; az += dz * fac;
114+
115+
real invr5 = invDist3 * (invDist * invDist);
116+
real invr7 = invr5 * (invDist * invDist);
117+
118+
real q = node->Qxx*dx*dx + node->Qyy*dy*dy + node->Qzz*dz*dz +
119+
2*(node->Qxy*dx*dy + node->Qxz*dx*dz + node->Qyz*dy*dz);
120+
121+
real Qrx = 2*(node->Qxx*dx + node->Qxy*dy + node->Qxz*dz);
122+
real Qry = 2*(node->Qxy*dx + node->Qyy*dy + node->Qyz*dz);
123+
real Qrz = 2*(node->Qxz*dx + node->Qyz*dy + node->Qzz*dz);
124+
125+
ax += (G * real(0.5)) * (Qrx * invr5 - 5 * q * invr7 * dx);
126+
ay += (G * real(0.5)) * (Qry * invr5 - 5 * q * invr7 * dy);
127+
az += (G * real(0.5)) * (Qrz * invr5 - 5 * q * invr7 * dz);
234128
return;
235129
}
236130

237-
// Recurse
238-
for (int i = 0; i < 8; ++i) {
239-
if (node->child[i]) {
240-
bhAccel(node->child[i], p, theta, ax, ay, az);
241-
}
242-
}
243-
}
131+
for (auto c : node->child) if (c) bhAccel(c, p, theta, ax, ay, az);
132+
}

0 commit comments

Comments
 (0)