Skip to content

Commit 6c2f24b

Browse files
Promotes cluster_tree up to physicsnemo.mesh.spatial for re-use. (#1710)
* Promotes `cluster_tree` up to `physicsnemo.mesh.spatial` for re-use. * lints --------- Co-authored-by: Corey adams <6619961+coreyjadams@users.noreply.github.com>
1 parent bcd10ce commit 6c2f24b

6 files changed

Lines changed: 148 additions & 99 deletions

File tree

physicsnemo/experimental/models/globe/__init__.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,6 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
from physicsnemo.experimental.models.globe.cluster_tree import (
18-
ClusterTree,
19-
DualInteractionPlan,
20-
SourceAggregates,
21-
)
2217
from physicsnemo.experimental.models.globe.field_kernel import (
2318
BarnesHutKernel,
2419
Kernel,
@@ -31,7 +26,4 @@
3126
"Kernel",
3227
"BarnesHutKernel",
3328
"MultiscaleKernel",
34-
"ClusterTree",
35-
"DualInteractionPlan",
36-
"SourceAggregates",
3729
]

physicsnemo/experimental/models/globe/field_kernel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
logger = logging.getLogger("globe.field_kernel")
5252

5353
if TYPE_CHECKING:
54-
from physicsnemo.experimental.models.globe.cluster_tree import (
54+
from physicsnemo.mesh.spatial.cluster_tree import (
5555
ClusterTree,
5656
DualInteractionPlan,
5757
SourceAggregates,
@@ -919,7 +919,7 @@ def forward(
919919
TensorDict[str, Float[torch.Tensor, "n_targets ..."]]
920920
Kernel output fields at target points.
921921
"""
922-
from physicsnemo.experimental.models.globe.cluster_tree import (
922+
from physicsnemo.mesh.spatial.cluster_tree import (
923923
ClusterTree,
924924
DualInteractionPlan,
925925
SourceAggregates,
@@ -1757,7 +1757,7 @@ def forward(
17571757
TensorDict[str, Float[torch.Tensor, "n_targets ..."]]
17581758
Summed results from all kernel branches.
17591759
"""
1760-
from physicsnemo.experimental.models.globe.cluster_tree import ClusterTree
1760+
from physicsnemo.mesh.spatial.cluster_tree import ClusterTree
17611761

17621762
n_sources: int = len(source_points)
17631763
device = source_points.device

physicsnemo/experimental/models/globe/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
from physicsnemo.core.meta import ModelMetaData
2929
from physicsnemo.core.module import Module
30-
from physicsnemo.experimental.models.globe.cluster_tree import (
30+
from physicsnemo.mesh.spatial.cluster_tree import (
3131
ClusterTree,
3232
DualInteractionPlan,
3333
)
@@ -428,7 +428,7 @@ def _build_trees_and_plans(
428428
comm_plans : dict[str, dict[str, DualInteractionPlan]]
429429
Communication plans indexed as ``comm_plans[dst_bc][src_bc]``.
430430
"""
431-
from physicsnemo.experimental.models.globe.cluster_tree import ClusterTree
431+
from physicsnemo.mesh.spatial.cluster_tree import ClusterTree
432432

433433
### ``no_grad`` is safe: tree inputs (centroids, areas) carry no grad
434434
### and the outputs are consumed downstream as integer indices and as
@@ -515,7 +515,7 @@ def _build_prediction_plans(
515515
Plans indexed by source BC type, each computed from that source
516516
BC's tree to ``pred_target_tree``.
517517
"""
518-
from physicsnemo.experimental.models.globe.cluster_tree import ClusterTree
518+
from physicsnemo.mesh.spatial.cluster_tree import ClusterTree
519519

520520
### See ``_build_trees_and_plans`` for the ``no_grad`` + build-device
521521
### rationale. ``cluster_trees`` arrive on the caller's device from

physicsnemo/mesh/spatial/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@
1818
1919
This module provides data structures and algorithms for fast spatial queries:
2020
- BVH (Bounding Volume Hierarchy) for point-in-cell queries
21+
- ClusterTree for dual-tree Barnes-Hut acceleration of kernel/attention operators
2122
"""
2223

2324
from physicsnemo.mesh.spatial.bvh import BVH
25+
from physicsnemo.mesh.spatial.cluster_tree import (
26+
ClusterTree,
27+
DualInteractionPlan,
28+
SourceAggregates,
29+
)

0 commit comments

Comments
 (0)