Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 73 additions & 4 deletions PyAPD/apds.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
error_tolerance=0.01,
pixel_size_prefactor=2,
seed=-1,
periodic=False,
):
"""
Construct an anisotropic power diagram system.
Expand Down Expand Up @@ -86,6 +87,11 @@ def __init__(
Multiplier applied to the computed pixel count. Default: 2.
seed : int
Manual random seed (< 0 means no seed is set). Default: -1.
periodic : bool
If True, distances between seeds and points use the minimum-image
convention across the (rectilinear) domain, so the power diagram
is periodic and grains may wrap across the domain boundary.
Default: False (fully backward-compatible).
"""

self.N = N
Expand All @@ -103,6 +109,15 @@ def __init__(

self.set_domain(domain)

self.periodic = bool(periodic)
# Per-dimension edge lengths of the (rectilinear) domain. Kept on the
# same device/dtype as the domain so KeOps arithmetic stays homogeneous.
edge = (self.domain[:, 1] - self.domain[:, 0]).to(
device=self.device, dtype=self.dt
)
self._L_tensor = edge # shape (D,), plain torch
self._L = LazyTensor(edge.view(1, 1, self.D)) # broadcast over (N, M, D)

self.set_X(X)

self.set_As(As)
Expand Down Expand Up @@ -387,6 +402,22 @@ def mask_pixels(self, mask):
self.W = initial_guess_heuristic(self.As, self.target_masses, self.D)
self.w = LazyTensor(self.W.view(self.N, 1, 1))

def _displacement(self, y, x):
"""Return ``y - x`` as a KeOps LazyTensor, wrapped to the minimum
image across the periodic domain when ``self.periodic`` is True.

Uses the KeOps-compatible expression ``dy - L * round(dy / L)``,
which rounds ``dy / L`` to the nearest integer per component.
``LazyTensor.floor()`` is unavailable on some KeOps builds and the
Python ``%`` operator is not defined between LazyTensors, so the
``.round()`` form is used.
"""
dy = y - x
if self.periodic:
shift = (dy / self._L).round()
dy = dy - self._L * shift
return dy

def assemble_apd(
self, record_time=False, verbose=False, color_by=None, backend="auto"
):
Expand All @@ -396,7 +427,8 @@ def assemble_apd(
if self.Y is None:
self.assemble_pixels()
start = time.time()
D_ij = ((self.y - self.x) | self.a.matvecmult(self.y - self.x)) - self.w
dy = self._displacement(self.y, self.x)
D_ij = (dy | self.a.matvecmult(dy)) - self.w
# Find which grain each pixel belongs to
grain_indices = D_ij.argmin(dim=0, backend=backend).ravel()
time_taken = time.time() - start
Expand All @@ -410,6 +442,34 @@ def assemble_apd(
else:
return grain_indices

def grain_of(self, points, backend="auto"):
"""Return the grain index that owns each point in ``points``.

Parameters
----------
points : torch.Tensor of shape (M, D)
Query points, in the same coordinate frame as ``self.domain``.
When ``self.periodic`` is True, query points are compared to each
seed using the minimum-image distance across the domain.
backend : str
KeOps reduction backend (forwarded to argmin).

Returns
-------
torch.Tensor of shape (M,), dtype int64
For each row of ``points``, the grain index ``i`` minimising
``(y - x_i) . A_i . (y - x_i) - w_i`` under the active metric.
"""
pts = points.to(device=self.device, dtype=self.dt)
if pts.ndim != 2 or pts.shape[1] != self.D:
raise ValueError(
f"grain_of: expected (M, {self.D}), got {tuple(pts.shape)}"
)
y = LazyTensor(pts.view(1, pts.shape[0], self.D))
dy = self._displacement(y, self.x)
D_ij = (dy | self.a.matvecmult(dy)) - self.w
return D_ij.argmin(dim=0, backend=backend).view(-1)

def plot_apd(
self, color_by=None, mode="auto", alpha=None, ps_scale=False, marker_scale=20.0
):
Expand Down Expand Up @@ -603,7 +663,8 @@ def OT_dual_function(self, W, backend="auto"):
self.W = W
self.w = LazyTensor(self.W.view(self.N, 1, 1))

D_ij = ((self.y - self.x) | self.a.matvecmult(self.y - self.x)) - self.w
dy = self._displacement(self.y, self.x)
D_ij = (dy | self.a.matvecmult(dy)) - self.w
idx = D_ij.argmin(dim=0, backend=backend).view(-1)

ind_select = torch.index_select(self.X, 0, idx) - self.Y
Expand All @@ -623,7 +684,8 @@ def check_optimality(
if self.Y is None:
self.assemble_pixels()

D_ij = ((self.y - self.x) | self.a.matvecmult(self.y - self.x)) - self.w
dy = self._displacement(self.y, self.x)
D_ij = (dy | self.a.matvecmult(dy)) - self.w

grain_indices = D_ij.argmin(dim=0, backend=backend).ravel()
volumes = torch.bincount(grain_indices, self.PS, minlength=self.N)
Expand Down Expand Up @@ -707,7 +769,8 @@ def adjust_X(self, backend="auto"):
if not self.optimality:
print("Find optimal W first!")
else:
D_ij = ((self.y - self.x) | self.a.matvecmult(self.y - self.x)) - self.w
dy = self._displacement(self.y, self.x)
D_ij = (dy | self.a.matvecmult(dy)) - self.w
grain_indices = D_ij.argmin(dim=0, backend=backend).ravel()
normalisation = torch.bincount(grain_indices, self.PS, minlength=self.N)
new_X0 = (
Expand All @@ -730,6 +793,12 @@ def adjust_X(self, backend="auto"):
else:
self.X = torch.stack([new_X0, new_X1], dim=1)

if self.periodic:
# Centroids may fall outside the box when a grain straddles the
# periodic seam; wrap them back into [domain[:, 0], domain[:, 1)).
origin = self.domain[:, 0]
self.X = origin + torch.remainder(self.X - origin, self._L_tensor)

self.optimality = False
self.x = LazyTensor(self.X.view(self.N, 1, self.D))

Expand Down
117 changes: 117 additions & 0 deletions tests/test_apds_periodic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Tests for the periodic (minimum-image) distance support in apd_system."""

import pytest
import torch

from PyAPD import apd_system


def test_periodic_flag_accepted():
"""apd_system must accept periodic=True without raising."""
apd = apd_system(N=4, D=3, periodic=True, seed=0, device="cpu")
assert apd.periodic is True


def test_non_periodic_is_default():
"""periodic defaults to False (backward-compatible)."""
apd = apd_system(N=4, D=2, seed=0, device="cpu")
assert apd.periodic is False


def test_periodic_single_seed_uniform_labels():
"""A single seed in a periodic box must label every voxel as grain 0,
regardless of where the seed sits (corner, centre, edge)."""
L = 1.0
for seed_pos in [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5], [0.99, 0.01, 0.5]]:
X = torch.tensor([seed_pos])
apd = apd_system(
X=X,
D=3,
domain=torch.tensor([[0.0, L], [0.0, L], [0.0, L]]),
periodic=True,
pixel_params=(16, 16, 16),
device="cpu",
)
labels = apd.assemble_apd().cpu()
assert torch.all(labels == 0), (
f"seed at {seed_pos} produced labels {labels.unique()}"
)


def test_periodic_grain_wraps_across_seam():
"""Two seeds at x = 0.0 and x = 0.5 (y, z = 0.5). Under periodicity,
grain 0 (seed at the seam) owns x in [0, 0.25] u [0.75, 1.0] and grain 1
owns x in [0.25, 0.75]. So the columns at x=0 and x=L (column 0 and column
M-1) must BOTH belong to grain 0 -- the grain wraps across the periodic
seam. Without periodicity, column M-1 would be closer to seed 1 (x=0.5)
than to seed 0 (x=0.0), so it would belong to grain 1."""
L = 1.0
M = 32
X = torch.tensor([[0.0, 0.5, 0.5], [0.5, 0.5, 0.5]])
apd = apd_system(
X=X,
D=3,
domain=torch.tensor([[0.0, L], [0.0, L], [0.0, L]]),
periodic=True,
pixel_params=(M, 16, 16),
device="cpu",
)
labels = apd.assemble_apd().reshape(M, 16, 16).cpu()
# column 0 and column M-1 (both ~0.016 from the seam at x=0) must be in
# the same grain under periodicity, and that grain is grain 0.
assert torch.all(labels[0] == 0)
assert torch.all(labels[-1] == 0)
# and the middle column (x ~ 0.5) must be in grain 1
assert torch.all(labels[M // 2] == 1)


def test_grain_of_matches_assemble_apd_on_voxel_centres():
"""grain_of() on the pixel centres must reproduce assemble_apd()'s output
exactly -- same metric, same data, same result."""
L = 1.0
apd = apd_system(
N=5,
D=3,
domain=torch.tensor([[0.0, L], [0.0, L], [0.0, L]]),
periodic=True,
pixel_params=(16, 16, 16),
seed=42,
device="cpu",
)
apd.find_optimal_W(verbose=False)
voxel_labels = apd.assemble_apd().cpu()
# apd.Y holds the voxel centres (filled by assemble_pixels via assemble_apd).
point_labels = apd.grain_of(apd.Y).cpu()
assert torch.equal(voxel_labels, point_labels)


def test_grain_of_continuous_points():
"""grain_of must accept arbitrary (M, 3) torch tensors in the box."""
L = 1.0
apd = apd_system(
N=4,
D=3,
domain=torch.tensor([[0.0, L], [0.0, L], [0.0, L]]),
periodic=True,
seed=7,
device="cpu",
)
apd.find_optimal_W(verbose=False)
pts = torch.tensor(
[
[0.5, 0.5, 0.5],
[0.0, 0.0, 0.0],
[0.99, 0.01, 0.5],
[0.1, 0.9, 0.7],
]
)
labels = apd.grain_of(pts)
assert labels.shape == (4,)
assert int(labels.min()) >= 0 and int(labels.max()) < 4


def test_grain_of_rejects_wrong_shape():
"""grain_of must validate the (M, D) shape of its input."""
apd = apd_system(N=4, D=3, periodic=True, seed=1, device="cpu")
with pytest.raises(ValueError):
apd.grain_of(torch.zeros(5, 2)) # D=2 points into a D=3 system