Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions PyAPD/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,36 @@
from .log_res_utils import (
reorder_variables as reorder_variables,
)
from .polycrystal_atomistic import (
bulk_bcc_positions as bulk_bcc_positions,
)
from .polycrystal_atomistic import (
bulk_lattice_positions as bulk_lattice_positions,
)
from .polycrystal_atomistic import (
generate_bcc_polycrystal as generate_bcc_polycrystal,
)
from .polycrystal_atomistic import (
generate_fcc_polycrystal as generate_fcc_polycrystal,
)
from .polycrystal_atomistic import (
generate_hcp_polycrystal as generate_hcp_polycrystal,
)
from .polycrystal_atomistic import (
generate_polycrystal as generate_polycrystal,
)
from .polycrystal_atomistic import (
generate_sc_polycrystal as generate_sc_polycrystal,
)
from .polycrystal_atomistic import (
min_interatomic_distance as min_interatomic_distance,
)
from .polycrystal_atomistic import (
pbc_pair_cull as pbc_pair_cull,
)
from .polycrystal_atomistic import (
shoemake_uniform_so3 as shoemake_uniform_so3,
)
from .polycrystal_atomistic import (
write_lammps_data as write_lammps_data,
)
77 changes: 73 additions & 4 deletions PyAPD/apds.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
error_tolerance=0.01,
pixel_size_prefactor=2,
seed=-1,
periodic=False,
):
"""
Construct an anisotropic power diagram system.
Expand Down Expand Up @@ -86,6 +87,11 @@ def __init__(
Multiplier applied to the computed pixel count. Default: 2.
seed : int
Manual random seed (< 0 means no seed is set). Default: -1.
periodic : bool
If True, distances between seeds and points use the minimum-image
convention across the (rectilinear) domain, so the power diagram
is periodic and grains may wrap across the domain boundary.
Default: False (fully backward-compatible).
"""

self.N = N
Expand All @@ -103,6 +109,15 @@ def __init__(

self.set_domain(domain)

self.periodic = bool(periodic)
# Per-dimension edge lengths of the (rectilinear) domain. Kept on the
# same device/dtype as the domain so KeOps arithmetic stays homogeneous.
edge = (self.domain[:, 1] - self.domain[:, 0]).to(
device=self.device, dtype=self.dt
)
self._L_tensor = edge # shape (D,), plain torch
self._L = LazyTensor(edge.view(1, 1, self.D)) # broadcast over (N, M, D)

self.set_X(X)

self.set_As(As)
Expand Down Expand Up @@ -387,6 +402,22 @@ def mask_pixels(self, mask):
self.W = initial_guess_heuristic(self.As, self.target_masses, self.D)
self.w = LazyTensor(self.W.view(self.N, 1, 1))

def _displacement(self, y, x):
"""Return ``y - x`` as a KeOps LazyTensor, wrapped to the minimum
image across the periodic domain when ``self.periodic`` is True.

Uses the KeOps-compatible expression ``dy - L * round(dy / L)``,
which rounds ``dy / L`` to the nearest integer per component.
``LazyTensor.floor()`` is unavailable on some KeOps builds and the
Python ``%`` operator is not defined between LazyTensors, so the
``.round()`` form is used.
"""
dy = y - x
if self.periodic:
shift = (dy / self._L).round()
dy = dy - self._L * shift
return dy

def assemble_apd(
self, record_time=False, verbose=False, color_by=None, backend="auto"
):
Expand All @@ -396,7 +427,8 @@ def assemble_apd(
if self.Y is None:
self.assemble_pixels()
start = time.time()
D_ij = ((self.y - self.x) | self.a.matvecmult(self.y - self.x)) - self.w
dy = self._displacement(self.y, self.x)
D_ij = (dy | self.a.matvecmult(dy)) - self.w
# Find which grain each pixel belongs to
grain_indices = D_ij.argmin(dim=0, backend=backend).ravel()
time_taken = time.time() - start
Expand All @@ -410,6 +442,34 @@ def assemble_apd(
else:
return grain_indices

def grain_of(self, points, backend="auto"):
"""Return the grain index that owns each point in ``points``.

Parameters
----------
points : torch.Tensor of shape (M, D)
Query points, in the same coordinate frame as ``self.domain``.
When ``self.periodic`` is True, query points are compared to each
seed using the minimum-image distance across the domain.
backend : str
KeOps reduction backend (forwarded to argmin).

Returns
-------
torch.Tensor of shape (M,), dtype int64
For each row of ``points``, the grain index ``i`` minimising
``(y - x_i) . A_i . (y - x_i) - w_i`` under the active metric.
"""
pts = points.to(device=self.device, dtype=self.dt)
if pts.ndim != 2 or pts.shape[1] != self.D:
raise ValueError(
f"grain_of: expected (M, {self.D}), got {tuple(pts.shape)}"
)
y = LazyTensor(pts.view(1, pts.shape[0], self.D))
dy = self._displacement(y, self.x)
D_ij = (dy | self.a.matvecmult(dy)) - self.w
return D_ij.argmin(dim=0, backend=backend).view(-1)

def plot_apd(
self, color_by=None, mode="auto", alpha=None, ps_scale=False, marker_scale=20.0
):
Expand Down Expand Up @@ -603,7 +663,8 @@ def OT_dual_function(self, W, backend="auto"):
self.W = W
self.w = LazyTensor(self.W.view(self.N, 1, 1))

D_ij = ((self.y - self.x) | self.a.matvecmult(self.y - self.x)) - self.w
dy = self._displacement(self.y, self.x)
D_ij = (dy | self.a.matvecmult(dy)) - self.w
idx = D_ij.argmin(dim=0, backend=backend).view(-1)

ind_select = torch.index_select(self.X, 0, idx) - self.Y
Expand All @@ -623,7 +684,8 @@ def check_optimality(
if self.Y is None:
self.assemble_pixels()

D_ij = ((self.y - self.x) | self.a.matvecmult(self.y - self.x)) - self.w
dy = self._displacement(self.y, self.x)
D_ij = (dy | self.a.matvecmult(dy)) - self.w

grain_indices = D_ij.argmin(dim=0, backend=backend).ravel()
volumes = torch.bincount(grain_indices, self.PS, minlength=self.N)
Expand Down Expand Up @@ -707,7 +769,8 @@ def adjust_X(self, backend="auto"):
if not self.optimality:
print("Find optimal W first!")
else:
D_ij = ((self.y - self.x) | self.a.matvecmult(self.y - self.x)) - self.w
dy = self._displacement(self.y, self.x)
D_ij = (dy | self.a.matvecmult(dy)) - self.w
grain_indices = D_ij.argmin(dim=0, backend=backend).ravel()
normalisation = torch.bincount(grain_indices, self.PS, minlength=self.N)
new_X0 = (
Expand All @@ -730,6 +793,12 @@ def adjust_X(self, backend="auto"):
else:
self.X = torch.stack([new_X0, new_X1], dim=1)

if self.periodic:
# Centroids may fall outside the box when a grain straddles the
# periodic seam; wrap them back into [domain[:, 0], domain[:, 1)).
origin = self.domain[:, 0]
self.X = origin + torch.remainder(self.X - origin, self._L_tensor)

self.optimality = False
self.x = LazyTensor(self.X.view(self.N, 1, self.D))

Expand Down
Loading