Skip to content

Commit c8bce43

Browse files
Update octree.h
1 parent b124d06 commit c8bce43

1 file changed

Lines changed: 79 additions & 54 deletions

File tree

src/gravity/octree.h

Lines changed: 79 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -10,38 +10,48 @@
1010
// along with this program. If not, see <http://www.gnu.org/licenses/>.
1111

1212
#pragma once
13-
#include "../struct/particle.h"
13+
#include "floatdef.h"
14+
#include "dt/softening.h"
1415
#include <vector>
1516
#include <cmath>
16-
#include <memory> // Required for unique_ptr
17-
#include "dt/softening.h"
18-
#include "floatdef.h"
17+
#include <memory>
18+
#include <algorithm>
1919

20+
/**
21+
* @brief Octree node structure redesigned for Structure of Arrays (SoA).
22+
* Instead of storing Particle pointers, it stores indices into the ParticleSystem.
23+
*/
2024
struct Octree {
21-
real cx, cy, cz; // COM
22-
real m; // mass
23-
real x, y, z; // node center
24-
real size; // half-width
25+
real cx, cy, cz; // Center of Mass
26+
real m; // Total Mass
27+
real x, y, z; // Geometric center of node
28+
real size; // Half-width of node
2529
bool leaf = true;
26-
Particle* body = nullptr;
30+
31+
// Index of the particle in the ParticleSystem. -1 means empty.
32+
int bodyIdx = -1;
2733

2834
// Ownership: unique_ptr handles memory automatically
29-
std::unique_ptr<Octree> child[8] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
35+
std::unique_ptr<Octree> child[8] = { nullptr };
3036

31-
// Quadrupole tensor
37+
// Quadrupole tensor for higher-order gravity approximation
3238
real Qxx = 0, Qyy = 0, Qzz = 0;
3339
real Qxy = 0, Qxz = 0, Qyz = 0;
3440

3541
Octree(real X, real Y, real Z, real S) : x(X), y(Y), z(Z), size(S), m(0), cx(0), cy(0), cz(0) {}
3642

37-
// Destructor is now empty; unique_ptr cleans up children automatically
3843
~Octree() = default;
3944

40-
int index(const Particle& p) const {
41-
return (p.x > x) * 1 + (p.y > y) * 2 + (p.z > z) * 4;
45+
/**
46+
* @brief Determines which octant a particle belongs to.
47+
*/
48+
int getOctant(real px, real py, real pz) const {
49+
return (px > x) * 1 + (py > y) * 2 + (pz > z) * 4;
4250
}
4351

44-
// Returns unique_ptr to take ownership
52+
/**
53+
* @brief Creates a new child node in the specified octant.
54+
*/
4555
std::unique_ptr<Octree> createChild(int idx) {
4656
real hs = size * real(0.5);
4757
return std::make_unique<Octree>(
@@ -52,38 +62,48 @@ struct Octree {
5262
);
5363
}
5464

55-
void insert(Particle* p) {
56-
if (leaf && body == nullptr) {
57-
body = p;
65+
/**
66+
* @brief Inserts a particle index into the tree.
67+
*/
68+
void insert(int idx, const ParticleSystem& ps) {
69+
if (leaf && bodyIdx == -1) {
70+
bodyIdx = idx;
5871
return;
5972
}
6073

6174
if (leaf) {
6275
leaf = false;
63-
Particle* old = body;
64-
body = nullptr;
65-
int idx = index(*old);
66-
if (!child[idx]) child[idx] = createChild(idx);
67-
child[idx]->insert(old);
76+
int oldIdx = bodyIdx;
77+
bodyIdx = -1;
78+
int oct = getOctant(ps.x[oldIdx], ps.y[oldIdx], ps.z[oldIdx]);
79+
if (!child[oct]) child[oct] = createChild(oct);
80+
child[oct]->insert(oldIdx, ps);
6881
}
6982

70-
int idx = index(*p);
71-
if (!child[idx]) child[idx] = createChild(idx);
72-
child[idx]->insert(p);
83+
int oct = getOctant(ps.x[idx], ps.y[idx], ps.z[idx]);
84+
if (!child[oct]) child[oct] = createChild(oct);
85+
child[oct]->insert(idx, ps);
7386
}
7487

75-
void computeMass() {
88+
/**
89+
* @brief Recursively computes mass properties and quadrupole moments.
90+
*/
91+
void computeMass(const ParticleSystem& ps) {
7692
if (leaf) {
77-
if (body) { m = body->m; cx = body->x; cy = body->y; cz = body->z; }
78-
else { m = 0; cx = cy = cz = 0; }
93+
if (bodyIdx != -1) {
94+
m = ps.m[bodyIdx];
95+
cx = ps.x[bodyIdx]; cy = ps.y[bodyIdx]; cz = ps.z[bodyIdx];
96+
} else {
97+
m = 0; cx = cy = cz = 0;
98+
}
7999
Qxx = Qyy = Qzz = Qxy = Qxz = Qyz = 0;
80100
return;
81101
}
82102

83103
m = 0; cx = cy = cz = 0;
84-
for (auto& c : child) { // Use reference to unique_ptr
104+
for (auto& c : child) {
85105
if (!c) continue;
86-
c->computeMass();
106+
c->computeMass(ps);
87107
if (c->m == 0) continue;
88108
m += c->m;
89109
cx += c->cx * c->m; cy += c->cy * c->m; cz += c->cz * c->m;
@@ -95,44 +115,49 @@ struct Octree {
95115
if (!c || c->m == 0) continue;
96116
real rx = c->cx - cx; real ry = c->cy - cy; real rz = c->cz - cz;
97117
real r2 = rx * rx + ry * ry + rz * rz + (size * size * real(0.01));
98-
real mchild = c->m;
99-
Qxx += mchild * (3 * rx * rx - r2);
100-
Qyy += mchild * (3 * ry * ry - r2);
101-
Qzz += mchild * (3 * rz * rz - r2);
102-
Qxy += mchild * (3 * rx * ry);
103-
Qxz += mchild * (3 * rx * rz);
104-
Qyz += mchild * (3 * ry * rz);
118+
real mc = c->m;
119+
Qxx += mc * (3 * rx * rx - r2);
120+
Qyy += mc * (3 * ry * ry - r2);
121+
Qzz += mc * (3 * rz * rz - r2);
122+
Qxy += mc * (3 * rx * ry);
123+
Qxz += mc * (3 * rx * rz);
124+
Qyz += mc * (3 * ry * rz);
105125
}
106126
}
107127
};
108128

109-
// Traverse using raw pointers (non-owning observer)
110-
inline void bhAccel(Octree* node, const Particle& p, real theta, real& ax, real& ay, real& az) {
129+
/**
130+
* @brief Barnes-Hut acceleration calculation for a target particle at index 'i'.
131+
*/
132+
inline void bhAccel(Octree* node, int i, const ParticleSystem& ps, real theta, real& ax, real& ay, real& az) {
111133
if (!node || node->m == 0) return;
112-
if (node->leaf && node->body == &p) return;
134+
if (node->leaf && node->bodyIdx == i) return;
113135

114136
constexpr real G = real(1.0);
115-
real dx = node->cx - p.x; real dy = node->cy - p.y; real dz = node->cz - p.z;
137+
real dx = node->cx - ps.x[i];
138+
real dy = node->cy - ps.y[i];
139+
real dz = node->cz - ps.z[i];
116140
real r2 = dx*dx + dy*dy + dz*dz;
117141
real dist = std::sqrt(r2 + real(1e-20));
118142

143+
// Adaptive softening for Dark Matter (type 1) vs Stars (type 0)
119144
real eps = nextSoftening(node->size, node->m, dist);
120-
if (p.type == 1) {
121-
eps = std::max(eps, 2.0 * node->size / std::pow(node->m / p.m, 1.0/3));
145+
if (ps.type[i] == 1) {
146+
eps = std::max(eps, 2.0 * node->size / std::pow(node->m / ps.m[i], 0.333333333));
122147
}
123148

124149
real r2_soft = r2 + eps*eps;
125-
real dist_soft = std::sqrt(r2_soft);
150+
real dist_inv = real(1.0) / std::sqrt(r2_soft);
126151

127152
if (node->leaf || (node->size / dist) < theta) {
128-
real invDist = real(1.0) / dist_soft;
129-
real invDist3 = invDist * invDist * invDist;
130-
real fac = G * node->m * invDist3;
153+
real inv3 = dist_inv * dist_inv * dist_inv;
154+
real fac = G * node->m * inv3;
131155

132156
ax += dx * fac; ay += dy * fac; az += dz * fac;
133157

134-
real invr5 = invDist3 * (invDist * invDist);
135-
real invr7 = invr5 * (invDist * invDist);
158+
// Quadrupole contributions
159+
real inv5 = inv3 * (dist_inv * dist_inv);
160+
real inv7 = inv5 * (dist_inv * dist_inv);
136161

137162
real q = node->Qxx*dx*dx + node->Qyy*dy*dy + node->Qzz*dz*dz +
138163
2*(node->Qxy*dx*dy + node->Qxz*dx*dz + node->Qyz*dy*dz);
@@ -141,13 +166,13 @@ inline void bhAccel(Octree* node, const Particle& p, real theta, real& ax, real&
141166
real Qry = 2*(node->Qxy*dx + node->Qyy*dy + node->Qyz*dz);
142167
real Qrz = 2*(node->Qxz*dx + node->Qyz*dy + node->Qzz*dz);
143168

144-
ax += (G * real(0.5)) * (Qrx * invr5 - 5 * q * invr7 * dx);
145-
ay += (G * real(0.5)) * (Qry * invr5 - 5 * q * invr7 * dy);
146-
az += (G * real(0.5)) * (Qrz * invr5 - 5 * q * invr7 * dz);
169+
ax += (G * real(0.5)) * (Qrx * inv5 - 5 * q * inv7 * dx);
170+
ay += (G * real(0.5)) * (Qry * inv5 - 5 * q * inv7 * dy);
171+
az += (G * real(0.5)) * (Qrz * inv5 - 5 * q * inv7 * dz);
147172
return;
148173
}
149174

150175
for (auto& c : node->child) {
151-
if (c) bhAccel(c.get(), p, theta, ax, ay, az); // Use .get() to pass raw pointer
176+
if (c) bhAccel(c.get(), i, ps, theta, ax, ay, az);
152177
}
153178
}

0 commit comments

Comments
 (0)