diff --git a/movement/kinematics/__init__.py b/movement/kinematics/__init__.py index 7216a367d..9d03d5651 100644 --- a/movement/kinematics/__init__.py +++ b/movement/kinematics/__init__.py @@ -1,5 +1,6 @@ """Compute variables derived from ``position`` data.""" +from movement.kinematics.collective import compute_polarization from movement.kinematics.distances import compute_pairwise_distances from movement.kinematics.kinematics import ( compute_acceleration, @@ -32,4 +33,5 @@ "compute_head_direction_vector", "compute_forward_vector_angle", "compute_kinetic_energy", + "compute_polarization", ] diff --git a/movement/kinematics/collective.py b/movement/kinematics/collective.py new file mode 100644 index 000000000..6874723a7 --- /dev/null +++ b/movement/kinematics/collective.py @@ -0,0 +1,345 @@ +# collective.py +"""Compute collective behavior metrics for multi-individual tracking data.""" + +from collections.abc import Hashable +from typing import Any + +import numpy as np +import xarray as xr + +from movement.utils.logging import logger +from movement.utils.vector import ( + compute_norm, + compute_signed_angle_2d, + convert_to_unit, +) +from movement.validators.arrays import validate_dims_coords + +_ANGLE_EPS = 1e-12 + + +def compute_polarization( + data: xr.DataArray, + body_axis_keypoints: tuple[Hashable, Hashable] | None = None, + displacement_frames: int = 1, + return_angle: bool = False, + in_degrees: bool = False, +) -> xr.DataArray | tuple[xr.DataArray, xr.DataArray]: + r"""Compute polarization (group alignment) of individuals. + + Polarization measures how aligned individuals' direction vectors are, + supporting two modes: **orientation polarization** (body-axis mode) for + body orientation alignment, and **heading polarization** (displacement + mode) for movement direction alignment. A value of 1 indicates perfect + alignment, while a value near 0 indicates weak or canceling alignment. + + The polarization is computed as + + .. math:: + + \Phi = \frac{1}{N} \left\| \sum_{i=1}^{N} \hat{u}_i \right\| + + where :math:`\hat{u}_i` is the unit direction vector for individual + :math:`i`, and :math:`N` is the number of valid individuals at that time. + + Parameters + ---------- + data : xarray.DataArray + Position data. Must contain ``time``, ``space``, and ``individuals`` as + dimensions. If ``body_axis_keypoints`` is provided, the array must also + contain a ``keypoints`` dimension. For displacement-based heading, + pre-select a keypoint (e.g., ``data.sel(keypoints="thorax")``) or the + first keypoint (index 0) will be used. + + Spatial coordinates must include ``"x"`` and ``"y"``. If additional + spatial coordinates are present (e.g., ``"z"``), they are ignored; + polarization is computed in the x/y plane. + body_axis_keypoints : tuple[Hashable, Hashable], optional + Pair of keypoint names ``(origin, target)`` used to compute heading as + the vector from origin to target. If omitted, heading is inferred from + displacement over ``displacement_frames``. + displacement_frames : int, default=1 + Number of frames used to compute displacement when + ``body_axis_keypoints`` is not provided. Must be a positive integer. + This parameter is ignored when ``body_axis_keypoints`` is provided. + return_angle : bool, default=False + If True, also return the mean angle. Returns the mean body + orientation angle when using ``body_axis_keypoints``, or the mean + movement direction angle when using displacement-based polarization. + in_degrees : bool, default=False + If True, the mean angle is returned in degrees. Otherwise, the + angle is returned in radians. Only relevant when + ``return_angle=True``. + + Returns + ------- + xarray.DataArray or tuple[xarray.DataArray, xarray.DataArray] + If ``return_angle`` is False, returns a DataArray named + ``"polarization"`` with dimension ``("time",)``. + + If ``return_angle`` is True, returns + ``(polarization, mean_angle)`` where ``mean_angle`` is a DataArray + named ``"mean_angle"`` with dimension ``("time",)``. + + Notes + ----- + Missing data are excluded per individual, per frame. + + Zero-length headings are treated as invalid and excluded from the + calculation. + + The mean angle is defined from the summed unit-heading vector projected + onto the x/y plane. When using ``body_axis_keypoints``, this represents + the mean body orientation; when using displacement, it represents the + mean movement direction. When no valid headings exist, or when the summed + heading vector has zero magnitude (for example exact cancellation), the + returned angle is NaN. + + Examples + -------- + Compute orientation polarization from body-axis keypoints: + + >>> polarization = compute_polarization( + ... ds.position, + ... body_axis_keypoints=("tail_base", "neck"), + ... ) + + Compute heading polarization from displacement (pre-select keypoint): + + >>> polarization = compute_polarization( + ... ds.position.sel(keypoints="thorax") + ... ) + + If multiple keypoints exist and none is selected, the first is used: + + >>> polarization = compute_polarization(ds.position) + + Return orientation polarization with mean body orientation angle: + + >>> polarization, mean_angle = compute_polarization( + ... ds.position, + ... body_axis_keypoints=("tail_base", "neck"), + ... return_angle=True, + ... ) + + Return heading polarization with mean movement direction angle (radians): + + >>> polarization, mean_angle = compute_polarization( + ... ds.position.sel(keypoints="thorax"), + ... return_angle=True, + ... ) + + Return heading polarization with mean movement direction angle (degrees): + + >>> polarization, mean_angle = compute_polarization( + ... ds.position.sel(keypoints="thorax"), + ... return_angle=True, + ... in_degrees=True, + ... ) + + If multiple keypoints exist, first is used; also return mean angle: + + >>> polarization, mean_angle = compute_polarization( + ... ds.position, + ... return_angle=True, + ... ) + + """ + _validate_type_data_array(data) + normalized_keypoints = _validate_position_data( + data=data, + body_axis_keypoints=body_axis_keypoints, + ) + + if normalized_keypoints is not None: + heading_vectors = _compute_heading_from_keypoints( + data=data, + body_axis_keypoints=normalized_keypoints, + ) + else: + heading_vectors = _compute_heading_from_velocity( + data=data, + displacement_frames=displacement_frames, + ) + + heading = _select_space(heading_vectors) + + unit_headings = convert_to_unit(heading) + valid_mask = ~unit_headings.isnull().any(dim="space") + vector_sum = unit_headings.sum(dim="individuals", skipna=True) + sum_magnitude = compute_norm(vector_sum) + n_valid = valid_mask.sum(dim="individuals") + + polarization = xr.where( + n_valid > 0, + sum_magnitude / n_valid, + np.nan, + ).clip(min=0.0, max=1.0) + polarization = polarization.rename("polarization") + + if not return_angle: + return polarization + + # Normalize vector_sum to unit vector for angle computation + mean_unit_vector = vector_sum / sum_magnitude + + # Compute angle from positive x-axis to mean unit vector + reference = np.array([1, 0]) + angle_defined = (n_valid > 0) & (sum_magnitude > _ANGLE_EPS) + mean_angle = xr.where( + angle_defined, + compute_signed_angle_2d( + mean_unit_vector, reference, v_as_left_operand=True + ), + np.nan, + ) + if in_degrees: + mean_angle = np.rad2deg(mean_angle) + mean_angle = mean_angle.rename("mean_angle") + + return polarization, mean_angle + + +def _compute_heading_from_keypoints( + data: xr.DataArray, + body_axis_keypoints: tuple[Hashable, Hashable], +) -> xr.DataArray: + """Compute heading vectors from two keypoints (origin to target).""" + origin, target = body_axis_keypoints + heading = data.sel(keypoints=target, drop=True) - data.sel( + keypoints=origin, + drop=True, + ) + return heading + + +def _compute_heading_from_velocity( + data: xr.DataArray, + displacement_frames: int = 1, +) -> xr.DataArray: + """Compute heading vectors from displacement direction.""" + _validate_displacement_frames(displacement_frames) + + position = data + if "keypoints" in data.dims: + if data.sizes["keypoints"] < 1: + raise ValueError( + "data.keypoints must contain at least one keypoint." + ) + position = data.isel(keypoints=0, drop=True) + + if "keypoints" in data.coords and data.coords["keypoints"].size > 0: + logger.info( + "Using keypoint '%s' for displacement-based heading.", + data.coords["keypoints"].values[0], + ) + else: + logger.info( + "Using keypoint index 0 for displacement-based heading." + ) + + displacement = position - position.shift(time=displacement_frames) + return displacement + + +def _select_space(data: xr.DataArray) -> xr.DataArray: + """Return data with standard dim order, selecting only x and y coords.""" + result = data.sel(space=["x", "y"]) + return result.transpose("time", "space", "individuals") + + +def _validate_position_data( + data: xr.DataArray, + body_axis_keypoints: tuple[Hashable, Hashable] | None, +) -> tuple[Hashable, Hashable] | None: + """Validate the input array and normalize ``body_axis_keypoints``.""" + validate_dims_coords( + data, + { + "time": [], + "space": [], + "individuals": [], + }, + ) + + allowed_dims = {"time", "space", "individuals", "keypoints"} + unexpected_dims = set(data.dims) - allowed_dims + if unexpected_dims: + raise ValueError( + f"data contains unsupported dimension(s): " + f"{sorted(str(d) for d in unexpected_dims)}" + ) + + if "space" not in data.coords: + raise ValueError( + "data must have coordinate labels for the 'space' dimension." + ) + + space_labels = set(data.coords["space"].values.tolist()) + if not {"x", "y"}.issubset(space_labels): + raise ValueError( + "data.space must include coordinate labels 'x' and 'y'." + ) + + if body_axis_keypoints is None: + return None + + origin, target = _normalize_body_axis_keypoints(body_axis_keypoints) + + if "keypoints" not in data.dims: + raise ValueError( + "body_axis_keypoints requires a 'keypoints' dimension in data." + ) + + validate_dims_coords(data, {"keypoints": [origin, target]}) + return origin, target + + +def _normalize_body_axis_keypoints( + body_axis_keypoints: tuple[Hashable, Hashable] | Any, +) -> tuple[Hashable, Hashable]: + """Validate and normalize the keypoint pair.""" + if isinstance(body_axis_keypoints, (str, bytes)): + raise TypeError( + "body_axis_keypoints must be an iterable of exactly two " + "keypoint names." + ) + + try: + origin, target = body_axis_keypoints + except (TypeError, ValueError) as exc: + raise TypeError( + "body_axis_keypoints must be an iterable of exactly two " + "keypoint names." + ) from exc + + for keypoint in (origin, target): + if not isinstance(keypoint, Hashable): + raise TypeError("Each body axis keypoint must be hashable.") + + if origin == target: + raise ValueError( + "body_axis_keypoints must contain two distinct keypoint names." + ) + + return origin, target + + +def _validate_displacement_frames(displacement_frames: int) -> None: + """Validate the displacement window.""" + if isinstance(displacement_frames, (bool, np.bool_)) or not isinstance( + displacement_frames, + (int, np.integer), + ): + raise TypeError("displacement_frames must be a positive integer.") + + if displacement_frames < 1: + raise ValueError("displacement_frames must be >= 1.") + + +def _validate_type_data_array(data: xr.DataArray) -> None: + """Validate that the input is an xarray.DataArray.""" + if not isinstance(data, xr.DataArray): + raise TypeError( + f"Input data must be an xarray.DataArray, but got {type(data)}." + ) diff --git a/tests/test_unit/test_kinematics/test_collective.py b/tests/test_unit/test_kinematics/test_collective.py new file mode 100644 index 000000000..2b6cdc265 --- /dev/null +++ b/tests/test_unit/test_kinematics/test_collective.py @@ -0,0 +1,1130 @@ +# test_collective.py +"""Tests for the collective behavior metrics module.""" + +import numpy as np +import pytest +import xarray as xr + +from movement import kinematics + + +def _get_space_labels(n_space: int, space: list[str] | None) -> list[str]: + """Return space labels, defaulting to ['x', 'y'] for 2D.""" + if space is not None: + return space + if n_space == 2: + return ["x", "y"] + raise ValueError("Provide explicit `space` labels for non-2D data.") + + +def _make_position_dataarray( + data: np.ndarray, + *, + time: list | None = None, + individuals: list | None = None, + keypoints: list[str] | None = None, + space: list[str] | None = None, +) -> xr.DataArray: + """Create a position DataArray for tests.""" + data = np.asarray(data, dtype=float) + n_time, n_space = data.shape[0], data.shape[1] + + if data.ndim == 3: + n_individuals = data.shape[2] + ind = individuals or [f"id_{i}" for i in range(n_individuals)] + return xr.DataArray( + data, + dims=["time", "space", "individuals"], + coords={ + "time": time if time else list(range(n_time)), + "space": _get_space_labels(n_space, space), + "individuals": ind, + }, + name="position", + ) + + if data.ndim == 4: + n_keypoints, n_individuals = data.shape[2], data.shape[3] + kp = keypoints or [f"kp_{i}" for i in range(n_keypoints)] + ind = individuals or [f"id_{i}" for i in range(n_individuals)] + return xr.DataArray( + data, + dims=["time", "space", "keypoints", "individuals"], + coords={ + "time": time if time else list(range(n_time)), + "space": _get_space_labels(n_space, space), + "keypoints": kp, + "individuals": ind, + }, + name="position", + ) + + raise ValueError( + "Expected data with shape (time, space, individuals) or " + "(time, space, keypoints, individuals)." + ) + + +@pytest.fixture +def aligned_positions() -> xr.DataArray: + """Two individuals moving together in +x direction.""" + data = np.array( + [ + [[0, 5], [0, 0]], + [[1, 6], [0, 0]], + [[2, 7], [0, 0]], + [[3, 8], [0, 0]], + ], + dtype=float, + ) + return _make_position_dataarray(data) + + +@pytest.fixture +def opposite_positions() -> xr.DataArray: + """Two individuals moving in opposite x directions (+x and -x).""" + data = np.array( + [ + [[0, 5], [0, 0]], + [[1, 4], [0, 0]], + [[2, 3], [0, 0]], + [[3, 2], [0, 0]], + ], + dtype=float, + ) + return _make_position_dataarray(data) + + +@pytest.fixture +def partial_alignment_positions() -> xr.DataArray: + """Three individuals: two move +x, one moves +y.""" + data = np.array( + [ + [[0, 5, 0], [0, 0, 0]], + [[1, 6, 0], [0, 0, 1]], + [[2, 7, 0], [0, 0, 2]], + [[3, 8, 0], [0, 0, 3]], + ], + dtype=float, + ) + return _make_position_dataarray(data) + + +@pytest.fixture +def cardinal_directions_positions() -> xr.DataArray: + """Four individuals moving in cardinal directions (+x, -x, +y, -y).""" + data = np.array( + [ + [[0, 10, 0, 0], [0, 0, 0, 10]], + [[1, 9, 0, 0], [0, 0, 1, 9]], + [[2, 8, 0, 0], [0, 0, 2, 8]], + [[3, 7, 0, 0], [0, 0, 3, 7]], + ], + dtype=float, + ) + return _make_position_dataarray(data) + + +@pytest.fixture +def keypoint_positions() -> xr.DataArray: + """Two individuals with tail_base/neck keypoints, both facing +x.""" + data = np.array( + [ + [ + [[0.0, 10.0], [1.0, 11.0]], + [[0.0, 0.0], [0.0, 0.0]], + ], + [ + [[0.5, 10.5], [1.5, 11.5]], + [[0.0, 0.0], [0.0, 0.0]], + ], + [ + [[1.0, 11.0], [2.0, 12.0]], + [[0.0, 0.0], [0.0, 0.0]], + ], + ], + dtype=float, + ) + return _make_position_dataarray(data, keypoints=["tail_base", "neck"]) + + +class TestComputePolarizationValidation: + """Tests for input validation in compute_polarization.""" + + def test_requires_dataarray(self): + """Raise TypeError if input is not an xarray.DataArray.""" + with pytest.raises(TypeError, match="xarray.DataArray"): + kinematics.compute_polarization(np.zeros((3, 2, 2))) + + @pytest.mark.parametrize( + "dims", + [ + ("space", "individuals"), + ("time", "individuals"), + ("time", "space"), + ], + ids=["missing_time", "missing_space", "missing_individuals"], + ) + def test_requires_time_space_individuals(self, dims): + """Raise ValueError if required dimensions are missing.""" + data = xr.DataArray(np.zeros((2, 2)), dims=dims) + with pytest.raises(ValueError, match="time|space|individuals"): + kinematics.compute_polarization(data) + + def test_rejects_unexpected_dimensions(self): + """Raise ValueError if data contains unsupported dimensions.""" + data = xr.DataArray( + np.zeros((3, 2, 2, 2)), + dims=["time", "space", "individuals", "batch"], + coords={ + "time": [0, 1, 2], + "space": ["x", "y"], + "individuals": ["a", "b"], + "batch": [0, 1], + }, + ) + with pytest.raises(ValueError, match="unsupported dimension"): + kinematics.compute_polarization(data) + + def test_requires_x_and_y_space_labels(self): + """Raise ValueError if space dimension lacks x and y labels.""" + data = xr.DataArray( + np.zeros((3, 2, 2)), + dims=["time", "space", "individuals"], + coords={ + "time": [0, 1, 2], + "space": ["lat", "lon"], + "individuals": ["a", "b"], + }, + ) + with pytest.raises( + ValueError, match="include coordinate labels 'x' and 'y'" + ): + kinematics.compute_polarization(data) + + @pytest.mark.parametrize( + "body_axis_keypoints", + [ + "neck", + ("tail_base",), + ("tail_base", "neck", "ear"), + 123, + ], + ids=["string", "length_one", "length_three", "non_iterable"], + ) + def test_body_axis_keypoints_must_be_length_two_iterable( + self, + body_axis_keypoints, + keypoint_positions, + ): + """Raise TypeError if body_axis_keypoints is not length-two.""" + with pytest.raises(TypeError, match="exactly two keypoint names"): + kinematics.compute_polarization( + keypoint_positions, + body_axis_keypoints=body_axis_keypoints, + ) + + def test_body_axis_keypoints_must_be_hashable(self, keypoint_positions): + """Raise TypeError if body axis keypoints are not hashable.""" + with pytest.raises(TypeError, match="hashable"): + kinematics.compute_polarization( + keypoint_positions, + body_axis_keypoints=(["tail_base"], "neck"), + ) + + def test_body_axis_keypoints_require_keypoints_dimension( + self, aligned_positions + ): + """Raise ValueError if body_axis_keypoints given without keypoints.""" + with pytest.raises( + ValueError, match="requires a 'keypoints' dimension" + ): + kinematics.compute_polarization( + aligned_positions, + body_axis_keypoints=("tail_base", "neck"), + ) + + def test_body_axis_keypoints_must_exist(self, keypoint_positions): + """Raise ValueError if specified keypoints do not exist in data.""" + with pytest.raises(ValueError, match="snout|keypoints"): + kinematics.compute_polarization( + keypoint_positions, + body_axis_keypoints=("tail_base", "snout"), + ) + + def test_body_axis_keypoints_must_be_distinct(self, keypoint_positions): + """Raise ValueError if origin and target keypoints are identical.""" + with pytest.raises(ValueError, match="two distinct keypoint names"): + kinematics.compute_polarization( + keypoint_positions, + body_axis_keypoints=("tail_base", "tail_base"), + ) + + @pytest.mark.parametrize( + "displacement_frames,expected_exception", + [ + (0, ValueError), + (-1, ValueError), + (1.5, TypeError), + (True, TypeError), + ], + ids=["zero", "negative", "float", "bool"], + ) + def test_displacement_frames_must_be_positive_integer( + self, + aligned_positions, + displacement_frames, + expected_exception, + ): + """Raise error if displacement_frames is not a positive integer.""" + with pytest.raises(expected_exception, match="positive integer|>= 1"): + kinematics.compute_polarization( + aligned_positions, + displacement_frames=displacement_frames, + ) + + def test_invalid_displacement_frames_is_ignored_in_keypoint_mode( + self, + keypoint_positions, + ): + """Invalid displacement_frames is ignored when keypoints are used.""" + polarization = kinematics.compute_polarization( + keypoint_positions, + body_axis_keypoints=("tail_base", "neck"), + displacement_frames=0, + ) + assert np.allclose(polarization.values, 1.0, atol=1e-10) + + def test_requires_space_coordinate_labels_to_exist(self): + """Raise ValueError if the space dimension has no coordinate labels.""" + data = xr.DataArray( + np.zeros((3, 2, 2)), + dims=["time", "space", "individuals"], + coords={ + "time": [0, 1, 2], + "individuals": ["a", "b"], + }, + name="position", + ) + with pytest.raises( + ValueError, + match="coordinate labels for the 'space' dimension", + ): + kinematics.compute_polarization(data) + + def test_empty_keypoints_dimension_raises_in_displacement_mode(self): + """Raise if keypoints dimension exists but contains no entries.""" + data = xr.DataArray( + np.empty((3, 2, 0, 2)), + dims=["time", "space", "keypoints", "individuals"], + coords={ + "time": [0, 1, 2], + "space": ["x", "y"], + "keypoints": [], + "individuals": ["a", "b"], + }, + name="position", + ) + with pytest.raises(ValueError, match="at least one keypoint"): + kinematics.compute_polarization(data) + + +class TestComputePolarizationBehavior: + """Tests for polarization computation behavior.""" + + def test_aligned_motion_gives_one(self, aligned_positions): + """Polarization is 1.0 when all individuals move in same direction.""" + polarization = kinematics.compute_polarization(aligned_positions) + assert np.isnan(polarization.values[0]) + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + + def test_opposite_motion_gives_zero(self, opposite_positions): + """Polarization is 0.0 when individuals move in opposite directions.""" + polarization = kinematics.compute_polarization(opposite_positions) + assert np.allclose(polarization.values[1:], 0.0, atol=1e-10) + + def test_four_cardinal_directions_cancel_to_zero( + self, cardinal_directions_positions + ): + """Polarization is 0.0 when four individuals move in cardinal dirs.""" + polarization = kinematics.compute_polarization( + cardinal_directions_positions + ) + assert np.allclose(polarization.values[1:], 0.0, atol=1e-10) + + def test_partial_alignment_matches_expected_magnitude( + self, + partial_alignment_positions, + ): + """Polarization matches expected value for partial alignment.""" + polarization = kinematics.compute_polarization( + partial_alignment_positions + ) + expected = np.sqrt(5) / 3 + assert np.allclose(polarization.values[1:], expected, atol=1e-10) + + def test_single_individual_gives_one(self): + """Polarization is 1.0 for a single moving individual.""" + data = np.array( + [ + [[0], [0]], + [[1], [0]], + [[2], [0]], + [[3], [0]], + ], + dtype=float, + ) + polarization = kinematics.compute_polarization( + _make_position_dataarray(data) + ) + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + + def test_stationary_individuals_are_excluded(self): + """Stationary individuals produce NaN polarization and angle.""" + data = np.array( + [ + [[0, 10], [0, 0]], + [[0, 10], [0, 0]], + [[0, 10], [0, 0]], + ], + dtype=float, + ) + polarization, mean_angle = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, + ) + assert np.all(np.isnan(polarization.values)) + assert np.all(np.isnan(mean_angle.values)) + + def test_stationary_and_moving_individuals_uses_only_valid_headings(self): + """Only moving individuals contribute to polarization.""" + data = np.array( + [ + [[0, 10], [0, 0]], + [[1, 10], [0, 0]], + [[2, 10], [0, 0]], + [[3, 10], [0, 0]], + ], + dtype=float, + ) + polarization, mean_angle = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, + ) + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + assert np.allclose(mean_angle.values[1:], 0.0, atol=1e-10) + + def test_one_coordinate_nan_excludes_that_individual(self): + """NaN in one coordinate excludes that individual from calculation.""" + data = np.array( + [ + [[0, 10], [0, 0]], + [[1, np.nan], [0, 0]], + [[2, 12], [0, 0]], + ], + dtype=float, + ) + polarization, mean_angle = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, + ) + assert np.isnan(polarization.values[0]) + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + assert np.allclose(mean_angle.values[1:], 0.0, atol=1e-10) + + def test_nan_in_body_axis_heading_excludes_that_individual(self): + """NaN in keypoint position excludes that individual.""" + data = np.array( + [ + [ + [[0.0, 10.0], [1.0, 11.0]], + [[0.0, 0.0], [0.0, 0.0]], + ], + [ + [[1.0, 10.0], [2.0, np.nan]], + [[0.0, 0.0], [0.0, np.nan]], + ], + [ + [[2.0, 12.0], [3.0, 13.0]], + [[0.0, 0.0], [0.0, 0.0]], + ], + ], + dtype=float, + ) + da = _make_position_dataarray(data, keypoints=["tail_base", "neck"]) + polarization = kinematics.compute_polarization( + da, + body_axis_keypoints=("tail_base", "neck"), + ) + assert np.allclose(polarization.values[[0, 2]], 1.0, atol=1e-10) + assert np.allclose(polarization.values[1], 1.0, atol=1e-10) + + def test_empty_individual_axis_returns_all_nan(self): + """Empty individuals axis returns all NaN values.""" + data = _make_position_dataarray( + np.empty((3, 2, 0)), + individuals=[], + space=["x", "y"], + ) + polarization, mean_angle = kinematics.compute_polarization( + data, + return_angle=True, + ) + assert np.all(np.isnan(polarization.values)) + assert np.all(np.isnan(mean_angle.values)) + + def test_empty_time_axis_returns_empty_outputs(self): + """Empty time axis returns empty output arrays.""" + data = xr.DataArray( + np.empty((0, 2, 0)), + dims=["time", "space", "individuals"], + coords={"time": [], "space": ["x", "y"], "individuals": []}, + name="position", + ) + polarization, mean_angle = kinematics.compute_polarization( + data, + return_angle=True, + ) + assert polarization.shape == (0,) + assert mean_angle.shape == (0,) + assert polarization.name == "polarization" + assert mean_angle.name == "mean_angle" + + def test_preserves_non_uniform_time_coordinates(self, aligned_positions): + """Non-uniform time coordinates are preserved in output.""" + time = [0.0, 0.25, 0.75, 1.5] + data = aligned_positions.assign_coords(time=time) + polarization, mean_angle = kinematics.compute_polarization( + data, + return_angle=True, + ) + np.testing.assert_array_equal(polarization.time.values, time) + np.testing.assert_array_equal(mean_angle.time.values, time) + + def test_polarization_is_invariant_to_individual_order(self): + """Polarization is independent of individual ordering.""" + data = np.array( + [ + [[0, 5, 0], [0, 0, 0]], + [[1, 6, 0], [0, 0, 1]], + [[2, 7, 0], [0, 0, 2]], + [[3, 8, 0], [0, 0, 3]], + ], + dtype=float, + ) + da = _make_position_dataarray(data) + da_permuted = da.isel(individuals=[2, 0, 1]) + + pol_original = kinematics.compute_polarization(da) + pol_permuted = kinematics.compute_polarization(da_permuted) + + np.testing.assert_allclose( + pol_original.values, pol_permuted.values, atol=1e-10 + ) + + def test_zero_length_body_axis_vectors_are_excluded(self): + """Zero-length body-axis headings are excluded as invalid.""" + # ind0 has coincident tail_base and neck (zero-length heading) + # ind1 has valid +x body axis heading + data = np.array( + [ + [ + [[0.0, 10.0], [0.0, 11.0]], # x: ind0 zero-length, ind1 +1 + [[0.0, 0.0], [0.0, 0.0]], # y + ], + [ + [[0.0, 10.5], [0.0, 11.5]], + [[0.0, 0.0], [0.0, 0.0]], + ], + ], + dtype=float, + ) + da = _make_position_dataarray(data, keypoints=["tail_base", "neck"]) + + polarization, mean_angle = kinematics.compute_polarization( + da, + body_axis_keypoints=("tail_base", "neck"), + return_angle=True, + ) + + assert np.allclose(polarization.values, 1.0, atol=1e-10) + assert np.allclose(mean_angle.values, 0.0, atol=1e-10) + + def test_polarization_is_invariant_to_translation( + self, + partial_alignment_positions, + ): + """Adding a constant offset does not change polarization.""" + shifted = partial_alignment_positions.copy() + shifted.loc[{"space": "x"}] = shifted.sel(space="x") + 1000.0 + shifted.loc[{"space": "y"}] = shifted.sel(space="y") - 500.0 + + pol_original = kinematics.compute_polarization( + partial_alignment_positions + ) + pol_shifted = kinematics.compute_polarization(shifted) + + np.testing.assert_allclose( + pol_original.values, + pol_shifted.values, + atol=1e-10, + equal_nan=True, + ) + + def test_polarization_is_invariant_to_positive_scaling( + self, + partial_alignment_positions, + ): + """Positive scalar multiplication preserves polarization.""" + scaled = partial_alignment_positions * 7.5 + + pol_original = kinematics.compute_polarization( + partial_alignment_positions + ) + pol_scaled = kinematics.compute_polarization(scaled) + + np.testing.assert_allclose( + pol_original.values, + pol_scaled.values, + atol=1e-10, + equal_nan=True, + ) + + def test_polarization_is_invariant_to_global_rotation( + self, + partial_alignment_positions, + ): + """A global planar rotation preserves polarization magnitude.""" + x = partial_alignment_positions.sel(space="x") + y = partial_alignment_positions.sel(space="y") + + rotated = partial_alignment_positions.copy() + rotated.loc[{"space": "x"}] = -y + rotated.loc[{"space": "y"}] = x + + pol_original = kinematics.compute_polarization( + partial_alignment_positions + ) + pol_rotated = kinematics.compute_polarization(rotated) + + np.testing.assert_allclose( + pol_original.values, + pol_rotated.values, + atol=1e-10, + equal_nan=True, + ) + + def test_body_axis_invariance_to_translation_scaling_rotation( + self, + ): + """Body-axis polarization is invariant to translation/scaling/rotation. + + Mean body angle is invariant to translation and positive scaling, and + rotates by the same amount under global planar rotation. + """ + # Three individuals with body axes: +x, +x, +y. + # This gives a nontrivial baseline: + # vector sum = (2, 1) + # polarization = sqrt(5) / 3 + # mean angle = atan2(1, 2) + # + # Absolute positions differ across frames to ensure we are really + # testing body-axis heading (target - origin), not any accidental + # dependence on absolute location. + data = np.array( + [ + [ + [[0.0, 10.0, -2.0], [1.0, 11.0, -2.0]], # x + [[0.0, 5.0, 3.0], [0.0, 5.0, 4.0]], # y + ], + [ + [[100.0, 50.0, 7.0], [101.0, 51.0, 7.0]], # x + [[-1.0, 20.0, -3.0], [-1.0, 20.0, -2.0]], # y + ], + ], + dtype=float, + ) + da = _make_position_dataarray(data, keypoints=["tail_base", "neck"]) + + pol_base, angle_base = kinematics.compute_polarization( + da, + body_axis_keypoints=("tail_base", "neck"), + return_angle=True, + ) + + expected_pol = np.sqrt(5) / 3 + expected_angle = np.arctan2(1.0, 2.0) + + np.testing.assert_allclose(pol_base.values, expected_pol, atol=1e-10) + np.testing.assert_allclose( + angle_base.values, expected_angle, atol=1e-10 + ) + + # Global translation: should not affect body-axis vectors. + translated = da.copy() + translated.loc[{"space": "x"}] = translated.sel(space="x") + 123.4 + translated.loc[{"space": "y"}] = translated.sel(space="y") - 56.7 + + pol_translated, angle_translated = kinematics.compute_polarization( + translated, + body_axis_keypoints=("tail_base", "neck"), + return_angle=True, + ) + + np.testing.assert_allclose( + pol_translated.values, pol_base.values, atol=1e-10 + ) + np.testing.assert_allclose( + angle_translated.values, angle_base.values, atol=1e-10 + ) + + # Positive scaling: should preserve directions and therefore preserve + # polarization and angle. + scaled = da * 4.2 + + pol_scaled, angle_scaled = kinematics.compute_polarization( + scaled, + body_axis_keypoints=("tail_base", "neck"), + return_angle=True, + ) + + np.testing.assert_allclose( + pol_scaled.values, pol_base.values, atol=1e-10 + ) + np.testing.assert_allclose( + angle_scaled.values, angle_base.values, atol=1e-10 + ) + + # Global 90-degree rotation: polarization magnitude should be + # unchanged, and mean angle should rotate by +pi/2 (with wraparound). + rotated = da.copy() + x = da.sel(space="x") + y = da.sel(space="y") + rotated.loc[{"space": "x"}] = -y + rotated.loc[{"space": "y"}] = x + + pol_rotated, angle_rotated = kinematics.compute_polarization( + rotated, + body_axis_keypoints=("tail_base", "neck"), + return_angle=True, + ) + + np.testing.assert_allclose( + pol_rotated.values, pol_base.values, atol=1e-10 + ) + + expected_rotated_angle = angle_base.values + (np.pi / 2) + expected_rotated_angle = ( + (expected_rotated_angle + np.pi) % (2 * np.pi) + ) - np.pi + + np.testing.assert_allclose( + angle_rotated.values, expected_rotated_angle, atol=1e-10 + ) + + +class TestHeadingSourceSelection: + """Tests for heading computation mode selection.""" + + def test_body_axis_heading_valid_on_first_frame_returns_expected_angle( + self, keypoint_positions + ): + """Body-axis heading is valid from frame 0 and returns angle 0.""" + polarization, mean_angle = kinematics.compute_polarization( + keypoint_positions, + body_axis_keypoints=("tail_base", "neck"), + return_angle=True, + ) + assert np.allclose(polarization.values, 1.0, atol=1e-10) + assert np.allclose(mean_angle.values, 0.0, atol=1e-10) + + def test_displacement_mode_with_keypoints_uses_first_keypoint(self): + """Displacement mode uses first keypoint when multiple exist.""" + data = np.array( + [ + [ + [[0, 10], [0, 10]], + [[0, 0], [0, 0]], + ], + [ + [[1, 11], [1, 9]], + [[0, 0], [0, 0]], + ], + [ + [[2, 12], [2, 8]], + [[0, 0], [0, 0]], + ], + ], + dtype=float, + ) + da = _make_position_dataarray(data, keypoints=["thorax", "head"]) + polarization = kinematics.compute_polarization(da) + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + + def test_explicit_keypoint_selection_with_sel(self): + """Pre-selecting keypoint with .sel() uses that keypoint. + + Data shape: (time, space, keypoints, individuals). + + X-coordinates across frames: + + Keypoint | Individual | Frame 0 | Frame 1 | Displacement + ---------|------------|---------|---------|------------- + thorax | ind0 | 0 | 1 | +1 (right) + thorax | ind1 | 10 | 11 | +1 (right) + head | ind0 | 0 | 1 | +1 (right) + head | ind1 | 10 | 9 | -1 (left) + + Thorax: both individuals move right -> polarization = 1.0 + Head: ind0 moves right, ind1 moves left -> polarization = 0.0 + """ + data = np.array( + [ + [ + [[0, 10], [0, 10]], + [[0, 0], [0, 0]], + ], + [ + [[1, 11], [1, 9]], + [[0, 0], [0, 0]], + ], + ], + dtype=float, + ) + da = _make_position_dataarray(data, keypoints=["thorax", "head"]) + + # Without .sel(): uses thorax -> both move right -> polarization = 1.0 + pol_default = kinematics.compute_polarization(da) + assert np.allclose(pol_default.values[1], 1.0, atol=1e-10) + + # With .sel(): head selected -> ind0 right, ind1 left -> 0.0 + pol_head = kinematics.compute_polarization(da.sel(keypoints="head")) + assert np.allclose(pol_head.values[1], 0.0, atol=1e-10) + + def test_body_axis_heading_overrides_displacement_behavior(self): + """Body-axis heading overrides displacement computation.""" + data = np.array( + [ + [ + [[0.0, 0.0], [1.0, 1.0]], + [[0.0, 2.0], [0.0, 2.0]], + ], + [ + [[0.0, 0.0], [1.0, 1.0]], + [[1.0, 3.0], [1.0, 3.0]], + ], + ], + dtype=float, + ) + da = _make_position_dataarray(data, keypoints=["tail_base", "neck"]) + polarization = kinematics.compute_polarization( + da, + body_axis_keypoints=("tail_base", "neck"), + displacement_frames=1000, + ) + assert np.allclose(polarization.values, 1.0, atol=1e-10) + + def test_extra_spatial_dimensions_are_ignored_for_planar_metrics(self): + """Extra spatial dimensions (z) are ignored; only x/y used.""" + data = np.array( + [ + [[0, 5], [0, 0], [0, 100]], + [[1, 6], [0, 0], [10, -100]], + [[2, 7], [0, 0], [-10, 50]], + [[3, 8], [0, 0], [999, -999]], + ], + dtype=float, + ) + da = _make_position_dataarray(data, space=["x", "y", "z"]) + polarization, mean_angle = kinematics.compute_polarization( + da, + return_angle=True, + ) + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + assert np.allclose(mean_angle.values[1:], 0.0, atol=1e-10) + + +class TestDisplacementFrames: + """Tests for displacement_frames parameter behavior.""" + + def test_first_n_frames_are_nan(self, aligned_positions): + """First N frames are NaN when displacement_frames=N.""" + polarization, mean_angle = kinematics.compute_polarization( + aligned_positions, + displacement_frames=2, + return_angle=True, + ) + assert np.isnan(polarization.values[0]) + assert np.isnan(polarization.values[1]) + assert np.isnan(mean_angle.values[0]) + assert np.isnan(mean_angle.values[1]) + assert np.allclose(polarization.values[2:], 1.0, atol=1e-10) + assert np.allclose(mean_angle.values[2:], 0.0, atol=1e-10) + + def test_nan_in_reference_frame_propagates_to_that_displacement_window( + self, + ): + """NaN in reference frame propagates through displacement window.""" + data = np.array( + [ + [[0, 5], [0, 0]], + [[np.nan, np.nan], [np.nan, np.nan]], + [[2, 7], [0, 0]], + [[3, 8], [0, 0]], + [[4, 9], [0, 0]], + ], + dtype=float, + ) + polarization = kinematics.compute_polarization( + _make_position_dataarray(data), + displacement_frames=2, + ) + assert np.isnan(polarization.values[0]) + assert np.isnan(polarization.values[1]) + assert np.allclose(polarization.values[2], 1.0, atol=1e-10) + assert np.isnan(polarization.values[3]) + assert np.allclose(polarization.values[4], 1.0, atol=1e-10) + + def test_larger_displacement_window_can_change_alignment_estimate(self): + """Larger displacement window smooths jittery movement.""" + data = np.array( + [ + [[0, 10], [0, 0]], + [[2, 9], [0, 0]], + [[1, 11], [0, 0]], + [[3, 10], [0, 0]], + [[2, 12], [0, 0]], + [[4, 11], [0, 0]], + ], + dtype=float, + ) + da = _make_position_dataarray(data) + + pol_1frame = kinematics.compute_polarization(da, displacement_frames=1) + pol_2frame = kinematics.compute_polarization(da, displacement_frames=2) + + assert np.allclose(pol_1frame.values[1:], 0.0, atol=1e-10) + assert np.allclose(pol_2frame.values[2:], 1.0, atol=1e-10) + + def test_displacement_frames_larger_than_time_axis_returns_all_nan( + self, + aligned_positions, + ): + """Oversized displacement windows produce no valid headings.""" + polarization, mean_angle = kinematics.compute_polarization( + aligned_positions, + displacement_frames=10, + return_angle=True, + ) + assert np.all(np.isnan(polarization.values)) + assert np.all(np.isnan(mean_angle.values)) + + +class TestReturnAngle: + """Tests for return_angle parameter behavior.""" + + def test_default_returns_only_polarization(self, aligned_positions): + """Default return is a single polarization DataArray.""" + result = kinematics.compute_polarization(aligned_positions) + assert isinstance(result, xr.DataArray) + assert result.name == "polarization" + assert result.dims == ("time",) + + def test_return_angle_true_returns_named_pair(self, aligned_positions): + """return_angle=True returns (polarization, mean_angle) tuple.""" + polarization, mean_angle = kinematics.compute_polarization( + aligned_positions, + return_angle=True, + ) + assert isinstance(polarization, xr.DataArray) + assert isinstance(mean_angle, xr.DataArray) + assert polarization.name == "polarization" + assert mean_angle.name == "mean_angle" + assert polarization.dims == ("time",) + assert mean_angle.dims == ("time",) + + @pytest.mark.parametrize( + "data,expected_angle,use_abs", + [ + ( + np.array( + [ + [[0, 5], [0, 0]], + [[1, 6], [0, 0]], + [[2, 7], [0, 0]], + ], + dtype=float, + ), + 0.0, + False, + ), + ( + np.array( + [ + [[0, 0], [0, 5]], + [[0, 0], [1, 6]], + [[0, 0], [2, 7]], + ], + dtype=float, + ), + np.pi / 2, + False, + ), + ( + np.array( + [ + [[10, 15], [0, 0]], + [[9, 14], [0, 0]], + [[8, 13], [0, 0]], + ], + dtype=float, + ), + np.pi, + True, + ), + ], + ids=["positive_x", "positive_y", "negative_x"], + ) + def test_mean_angle_matches_cardinal_directions( + self, + data, + expected_angle, + use_abs, + ): + """Mean angle matches expected value for cardinal directions.""" + _, mean_angle = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, + ) + values = mean_angle.values[1:] + if use_abs: + values = np.abs(values) + assert np.allclose(values, expected_angle, atol=1e-10) + + def test_mean_angle_diagonal_motion_is_pi_over_four(self): + """Mean angle is pi/4 for diagonal (+x, +y) motion.""" + data = np.array( + [ + [[0, 5], [0, 5]], + [[1, 6], [1, 6]], + [[2, 7], [2, 7]], + ], + dtype=float, + ) + _, mean_angle = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, + ) + assert np.allclose(mean_angle.values[1:], np.pi / 4, atol=1e-10) + + def test_mean_angle_partial_alignment_matches_vector_average( + self, + partial_alignment_positions, + ): + """Mean angle matches vector average for partial alignment.""" + _, mean_angle = kinematics.compute_polarization( + partial_alignment_positions, + return_angle=True, + ) + expected = np.arctan2(1, 2) + assert np.allclose(mean_angle.values[1:], expected, atol=1e-10) + + def test_mean_angle_is_nan_when_net_vector_cancels( + self, + opposite_positions, + cardinal_directions_positions, + ): + """Mean angle is NaN when heading vectors cancel out.""" + pol_opposite, angle_opposite = kinematics.compute_polarization( + opposite_positions, + return_angle=True, + ) + pol_cardinal, angle_cardinal = kinematics.compute_polarization( + cardinal_directions_positions, + return_angle=True, + ) + assert np.allclose(pol_opposite.values[1:], 0.0, atol=1e-10) + assert np.allclose(pol_cardinal.values[1:], 0.0, atol=1e-10) + assert np.all(np.isnan(angle_opposite.values[1:])) + assert np.all(np.isnan(angle_cardinal.values[1:])) + + def test_mean_angle_rotates_with_global_rotation( + self, + partial_alignment_positions, + ): + """Mean angle shifts by the same amount under global rotation.""" + _, angle_original = kinematics.compute_polarization( + partial_alignment_positions, + return_angle=True, + ) + + x = partial_alignment_positions.sel(space="x") + y = partial_alignment_positions.sel(space="y") + + rotated = partial_alignment_positions.copy() + rotated.loc[{"space": "x"}] = -y + rotated.loc[{"space": "y"}] = x + + _, angle_rotated = kinematics.compute_polarization( + rotated, + return_angle=True, + ) + + expected = angle_original.values[1:] + (np.pi / 2) + expected = (expected + np.pi) % (2 * np.pi) - np.pi + + np.testing.assert_allclose( + angle_rotated.values[1:], + expected, + atol=1e-10, + ) + + def test_mean_angle_wraparound_near_pi_is_handled_correctly(self): + """Headings near +pi and -pi should average leftward, not to zero.""" + # Two individuals moving left with tiny y-offsets in opposite dirs. + # This creates headings very close to +pi and -pi. + data = np.array( + [ + [[0.0, 0.0], [0.0, 0.0]], + [[-1.0, -1.0], [1e-6, -1e-6]], + [[-2.0, -2.0], [2e-6, -2e-6]], + ], + dtype=float, + ) + + polarization, mean_angle = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, + ) + + assert np.allclose(polarization.values[1:], 1.0, atol=1e-10) + assert np.allclose( + np.abs(mean_angle.values[1:]), + np.pi, + atol=1e-6, + ) + + def test_in_degrees_true_returns_degrees(self): + """in_degrees=True returns angle in degrees.""" + # Two individuals moving in +y direction + data = np.array( + [ + [[0, 0], [0, 0]], + [[0, 0], [1, 1]], + [[0, 0], [2, 2]], + ], + dtype=float, + ) + _, mean_angle_rad = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, + in_degrees=False, + ) + _, mean_angle_deg = kinematics.compute_polarization( + _make_position_dataarray(data), + return_angle=True, + in_degrees=True, + ) + # +y direction = 90 degrees = pi/2 radians + assert np.allclose(mean_angle_rad.values[1:], np.pi / 2, atol=1e-10) + assert np.allclose(mean_angle_deg.values[1:], 90.0, atol=1e-10)