diff --git a/docs/useq-schema-v2-migration-guide.md b/docs/useq-schema-v2-migration-guide.md new file mode 100644 index 00000000..2ed1320f --- /dev/null +++ b/docs/useq-schema-v2-migration-guide.md @@ -0,0 +1,509 @@ +# useq-schema v2: Major Architecture Overhaul + +## Overview + +The v2 version of `useq-schema` represents a fundamental architectural redesign +that generalizes the multi-dimensional axis iteration pattern to support +arbitrary dimensions while preserving the complex event building, nesting, and +skipping capabilities of the original implementation. This document explains the +new features, how to use and extend them, and the breaking changes from v1. + +## Key Architectural Changes + +### From Fixed Axes to Extensible Axis System + +**v1 Approach**: Hard-coded support for specific axes (`time`, `position`, +`grid`, `channel`, `z`) with bespoke iteration logic in `_iter_sequence.py`. + +**v2 Approach**: Generic, protocol-based system where any object implementing +`AxisIterable` can participate in multi-dimensional iteration. + +### Core Concepts + +#### 1. `AxisIterable[V]` Protocol + +The foundation of v2 is the `AxisIterable` protocol, which defines how any axis +should behave: + +```python +class AxisIterable(BaseModel, Generic[V]): + axis_key: str # Unique identifier for this axis + + @abstractmethod + def __iter__(self) -> Iterator[V]: + """Iterate over axis values""" + + def should_skip(self, prefix: AxesIndex) -> bool: + """Return True to skip this combination""" + return False + + def contribute_to_mda_event( + self, value: V, index: Mapping[str, int] + ) -> MDAEvent.Kwargs: + """Contribute data to the event being built""" + return {} +``` + +#### 2. `SimpleValueAxis[V]` - Basic Implementation + +For simple cases where you just want to iterate over a list of values: + +```python +class SimpleValueAxis(AxisIterable[V]): + values: list[V] = Field(default_factory=list) + + def __iter__(self) -> Iterator[V | MultiAxisSequence]: + yield from self.values +``` + +#### 3. `MultiAxisSequence[EventT]` - The New Sequence Container + +Replaces the old `MDASequence` as the core container, but with generic event +support: + +```python +class MultiAxisSequence(MutableModel, Generic[EventTco]): + axes: tuple[AxisIterable, ...] = () + axis_order: Optional[tuple[str, ...]] = None + value: Any = None # Used when this sequence is nested + event_builder: Optional[EventBuilder[EventTco]] = None + transforms: tuple[EventTransform, ...] = () +``` + +## New Features + +### 1. **Arbitrary Custom Axes** + +You can now define completely custom axes for any dimension: + +```python +# Custom axis for laser power +class LaserPowerAxis(SimpleValueAxis[float]): + axis_key: str = "laser_power" + + def contribute_to_mda_event(self, value: float, index: Mapping[str, int]) -> MDAEvent.Kwargs: + return {"metadata": {"laser_power": value}} + +# Custom axis for temperature +class TemperatureAxis(AxisIterable[float]): + axis_key: str = "temperature" + min_temp: float + max_temp: float + step: float + + def __iter__(self) -> Iterator[float]: + temp = self.min_temp + while temp <= self.max_temp: + yield temp + temp += self.step + + def contribute_to_mda_event(self, value: float, index: Mapping[str, int]) -> MDAEvent.Kwargs: + return {"metadata": {"temperature": value}} +``` + +### 2. **Conditional Skipping with `should_skip`** + +Implement complex conditional logic to skip certain axis combinations: + +```python +class FilteredChannelAxis(SimpleValueAxis[Channel]): + def should_skip(self, prefix: AxesIndex) -> bool: + # Skip FITC channel for even numbered Z positions + z_idx = prefix.get("z", (None, None, None))[0] + current_channel = prefix.get("c", (None, None, None))[1] + + if z_idx is not None and z_idx % 2 == 0: + return current_channel.config == "FITC" + return False +``` + +### 3. **Hierarchical Nested Sequences** + +The new system supports arbitrarily nested sequences that can override or extend +parent axes: + +```python +# Position with custom sub-sequence +sub_sequence = v2.MultiAxisSequence( + value=Position(x=10, y=20), # The value represents this position + axes=( + CustomTemperatureAxis(values=[20, 25, 30]), # Add temperature dimension + v2.ZRangeAround(range=2, step=0.5), # Override parent Z plan + ), + axis_order=("temperature", "z") +) + +main_sequence = v2.MDASequence( + axes=( + v2.TIntervalLoops(interval=1.0, loops=5), + v2.StagePositions([sub_sequence, Position(x=0, y=0)]), + v2.ZRangeAround(range=4, step=1.0), # This gets overridden for the first position + ) +) +``` + +### 4. **Event Transform Pipeline** + +Replace the old hardcoded event modifications with a composable transform +pipeline: + +```python +class CustomTransform(EventTransform[MDAEvent]): + def __call__( + self, + event: MDAEvent, + *, + prev_event: MDAEvent | None, + make_next_event: Callable[[], MDAEvent | None], + ) -> Iterable[MDAEvent]: + # Modify event + if event.index.get("c") == 0: # First channel + event = event.model_copy(update={"exposure": 100}) + + # Can return multiple events, no events, or modify the event + return [event] + +seq = v2.MDASequence( + axes=(...), + transforms=(CustomTransform(), v2.KeepShutterOpenTransform(("z",))) +) +``` + +#### 4.1 **Built-in Transforms** + +v2 provides several built-in transforms that replicate v1 behavior: + +```python +# Autofocus transform - inserts hardware autofocus events +v2.AutoFocusTransform(autofocus_plan) + +# Shutter management - keeps shutter open across specified axes +v2.KeepShutterOpenTransform(("z", "c")) + +# Event timing - marks first frame of each timepoint for timer reset +v2.ResetEventTimerTransform() +``` + +#### 4.2 **Non-Imaging Events with Transforms** + +A key innovation in v2 is the ability to use transforms to insert **non-imaging +events** that don't contribute to the sequence shape. This addresses GitHub +issue [#41](https://github.com/pymmcore-plus/useq-schema/issues/41) for use +cases like laser measurements and Raman spectroscopy: + +```python +class LaserMeasurementTransform(EventTransform[MDAEvent]): + """Insert laser measurement events after BF z-stacks.""" + + def __call__( + self, + event: MDAEvent, + *, + prev_event: MDAEvent | None, + make_next_event: Callable[[], MDAEvent | None], + ) -> Iterable[MDAEvent]: + # Yield the original imaging event + yield event + + # If this is the last event in a BF z-stack, add laser measurements + if (event.channel and event.channel.config == "BF" and + self._is_last_z_event(event, make_next_event)): + + # Insert 5 laser measurement events at different points + for i, (x_offset, y_offset) in enumerate([(0, 0), (10, 0), (0, 10), (-10, 0), (0, -10)]): + laser_event = MDAEvent( + index={"t": event.index.get("t", 0), "laser": i}, + x_pos=(event.x_pos or 0) + x_offset, + y_pos=(event.y_pos or 0) + y_offset, + action=CustomAction(type="laser_measurement", data={"laser_power": 75}) + ) + yield laser_event + + def _is_last_z_event(self, event: MDAEvent, make_next_event: Callable) -> bool: + next_event = make_next_event() + return (next_event is None or + next_event.channel is None or + next_event.channel.config != "BF") + +# Usage for the GitHub issue #41 use case: +# 1. Collect BF z-stack → 2. Laser measurements → 3. GFP z-stack +seq = v2.MDASequence( + channels=["BF", "GFP"], + z_plan=v2.ZRangeAround(range=2, step=0.5), + transforms=(LaserMeasurementTransform(),) +) + +# This generates: +# - BF z-stack events (contribute to shape) +# - 5 laser measurement events (inserted by transform, don't affect shape) +# - GFP z-stack events (contribute to shape) +``` + +### 5. **Pluggable Event Builders** + +Customize how raw axis data gets converted into events: + +```python +class CustomEventBuilder(EventBuilder[MyCustomEvent]): + def __call__( + self, axes_index: AxesIndex, context: tuple[MultiAxisSequence, ...] + ) -> MyCustomEvent: + # Build your custom event type + return MyCustomEvent(...) + +seq = v2.MultiAxisSequence( + axes=(...), + event_builder=CustomEventBuilder() +) +``` + +### 6. **Infinite Axes Support** + +Unlike v1, v2 supports infinite sequences: + +```python +class InfiniteTimeAxis(AxisIterable[float]): + axis_key: str = "t" + interval: float = 1.0 + + def __iter__(self) -> Iterator[float]: + time = 0.0 + while True: + yield time + time += self.interval +``` + +## Migration from v1 to v2 + +### Backward Compatibility + +v2 `MDASequence` accepts the same constructor parameters as v1 through automatic +conversion: + +```python +# This v1 style still works +seq = v2.MDASequence( + time_plan={"interval": 1.0, "loops": 5}, + z_plan={"range": 4, "step": 1}, + channels=["DAPI", "FITC"], + stage_positions=[(10, 20, 5)], +) + +# Internally converted to: +seq = v2.MDASequence( + axes=( + v2.TIntervalLoops(interval=1.0, loops=5), + v2.StagePositions([v2.Position(x=10, y=20, z=5)]), + v2.ZRangeAround(range=4, step=1), + v2.ChannelsPlan(values=[Channel(config="DAPI"), Channel(config="FITC")]), + ), + axis_order=("t", "p", "z", "c") # Derived from AXES constant +) +``` + +### Breaking Changes + +#### 1. **Event Building Architecture** + +**v1**: Monolithic `_iter_sequence` function with hardcoded event building +logic. + +**v2**: Separation of concerns: + +- Axis iteration handled by `iterate_multi_dim_sequence` +- Event building handled by `EventBuilder` +- Event modification handled by `EventTransform` pipeline + +#### 2. **Shape and Sizes Properties** + +```python +# v1 +seq.shape # Returns tuple of sizes +seq.sizes # Returns mapping of axis -> size + +# v2 - DEPRECATED +seq.shape # Deprecated - raises FutureWarning +seq.sizes # Deprecated - raises FutureWarning + +# v2 - New approach +len(axis) for axis in seq.axes # Get size per axis +seq.is_finite() # Check if sequence is finite +``` + +#### 3. **Axis Access** + +```python +# v1 +seq.time_plan +seq.z_plan +seq.channels +seq.stage_positions +seq.grid_plan + +# v2 - Legacy properties still work but deprecated +seq.time_plan # Returns the time axis or None +seq.z_plan # Returns the z axis or None + +# v2 - New approach +time_axis = next((ax for ax in seq.axes if ax.axis_key == "t"), None) +z_axis = next((ax for ax in seq.axes if ax.axis_key == "z"), None) +``` + +#### 4. **Custom Skip Logic** + +**v1**: Hardcoded in `_should_skip` function within `_iter_sequence.py` + +**v2**: Implemented per-axis via `should_skip` method: + +```python +class CustomZAxis(v2.ZRangeAround): + def should_skip(self, prefix: AxesIndex) -> bool: + # Custom logic here + return super().should_skip(prefix) +``` + +#### Z. **Z-Plans yield Positions, not floats** + +**v1**: Z plans yielded floats representing Z positions. + +**v2**: Z plans yield `Position` objects that (usually) include only z +coordinates: + +## Built-in Axes in v2 + +All the original v1 plans are now `AxisIterable` implementations: + +### Time Axes + +- `TIntervalLoops` +- `TIntervalDuration` +- `TDurationLoops` +- `MultiPhaseTimePlan` + +### Z Axes + +- `ZRangeAround` +- `ZTopBottom` +- `ZAboveBelow` +- `ZAbsolutePositions` +- `ZRelativePositions` + +### Channel Axes + +- `ChannelsPlan` (wraps list of `Channel` objects) + +### Position Axes + +- `StagePositions` (wraps list of `Position` objects) + +### Grid Axes + +- `GridRowsColumns` +- `GridFromEdges` +- `GridWidthHeight` +- `RandomPoints` + +## Extension Examples + +### Creating a Custom Scientific Axis + +```python +class PHAxis(AxisIterable[float]): + """Axis for pH titration experiments.""" + axis_key: str = "ph" + start_ph: float = 6.0 + end_ph: float = 8.0 + steps: int = 10 + + def __iter__(self) -> Iterator[float]: + step_size = (self.end_ph - self.start_ph) / (self.steps - 1) + for i in range(self.steps): + yield self.start_ph + i * step_size + + def contribute_to_mda_event(self, value: float, index: Mapping[str, int]) -> MDAEvent.Kwargs: + return { + "metadata": {"ph": value}, + "properties": [("pH_Controller", "target_ph", value)] + } + + def should_skip(self, prefix: AxesIndex) -> bool: + # Skip pH 7.5+ for channel index > 2 + channel_idx = prefix.get("c", (None, None, None))[0] + return channel_idx is not None and channel_idx > 2 and value >= 7.5 +``` + +### Complex Nested Workflow + +```python +# Different regions with different imaging parameters +region1 = v2.MultiAxisSequence( + value=v2.Position(x=0, y=0, name="Region1"), + axes=( + v2.ZRangeAround(range=10, step=0.2), # High-res Z + v2.ChannelsPlan(["DAPI", "FITC", "Cy3"]), # 3 channels + ) +) + +region2 = v2.MultiAxisSequence( + value=v2.Position(x=100, y=100, name="Region2"), + axes=( + v2.ZRangeAround(range=20, step=0.5), # Lower-res Z + v2.ChannelsPlan(["DAPI", "Cy5"]), # Only 2 channels + PHAxis(start_ph=6.5, end_ph=7.5, steps=5), # pH titration + ) +) + +main_seq = v2.MDASequence( + axes=( + v2.TIntervalLoops(interval=60, loops=10), # Every minute for 10 minutes + v2.StagePositions([region1, region2]), + ), + transforms=( + CustomExposureTransform(), # Adjust exposure per region + v2.KeepShutterOpenTransform(("z", "c")), # Keep shutter open for Z and C + ) +) +``` + +## Performance and Design Benefits + +### Separation of Concerns + +- **Axis logic**: Isolated in individual `AxisIterable` implementations +- **Event building**: Centralized in `EventBuilder` +- **Event modification**: Composable `EventTransform` pipeline + +### Extensibility + +- Add new dimensions without modifying core code +- Custom skip logic per axis +- Pluggable event builders for different event types +- Composable transform pipeline + +### Type Safety + +- Generic types ensure type safety across the pipeline +- Protocol-based design enables duck typing +- Clear interfaces for each component + +### Maintainability + +- Individual axis implementations are easier to test and debug +- Transform pipeline is easier to reason about than monolithic logic +- Clear separation between axis iteration and event building + +## Summary + +useq-schema v2 transforms the library from a fixed-axis system to a fully +extensible, protocol-based architecture that supports: + +- **Arbitrary custom axes** with their own iteration and contribution logic +- **Conditional skipping** per axis with full context awareness +- **Hierarchical nesting** with axis override capabilities +- **Composable transforms** for event modification +- **Pluggable event builders** for different event types +- **Type-safe extensibility** through generic protocols + +While maintaining full backward compatibility with v1 API patterns, v2 opens up +useq-schema for complex, multi-dimensional experimental workflows that were +impossible to express in the original architecture. diff --git a/pyproject.toml b/pyproject.toml index d564ed92..2e840f93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ classifiers = [ ] dynamic = ["version"] dependencies = [ - "pydantic >=2.6", + "pydantic >=2.10", "numpy >=2.1.0; python_version >= '3.13'", "numpy >=1.26.0; python_version >= '3.12'", "numpy >=1.25.2", @@ -123,15 +123,14 @@ keep-runtime-typing = true [tool.ruff.lint.per-file-ignores] "tests/*.py" = ["D", "S101", "E501", "SLF"] -[tool.ruff.lint.flake8-tidy-imports] -# Disallow all relative imports. -ban-relative-imports = "all" - # https://docs.pytest.org/en/6.2.x/customize.html [tool.pytest.ini_options] minversion = "6.0" testpaths = ["tests"] -filterwarnings = ["error"] +filterwarnings = [ + "error", + "ignore:.*Positions no longer have a sequence attribute", +] # https://mypy.readthedocs.io/en/stable/config_file.html [tool.mypy] @@ -143,6 +142,10 @@ show_error_codes = true pretty = true plugins = ["pydantic.mypy"] +[tool.pyright] +include = ["src", "tests/v2"] +reportArgumentType = false + # https://coverage.readthedocs.io/en/6.4/config.html [tool.coverage.run] source = ["useq"] diff --git a/src/useq/__init__.py b/src/useq/__init__.py index 4b516627..f775c8f0 100644 --- a/src/useq/__init__.py +++ b/src/useq/__init__.py @@ -5,6 +5,7 @@ from useq._actions import AcquireImage, Action, CustomAction, HardwareAutofocus from useq._channel import Channel +from useq._enums import Axis, RelativeTo, Shape from useq._grid import ( GridFromEdges, GridRowsColumns, @@ -12,11 +13,10 @@ MultiPointPlan, RandomPoints, RelativeMultiPointPlan, - Shape, ) from useq._hardware_autofocus import AnyAutofocusPlan, AutoFocusPlan, AxesBasedAF from useq._mda_event import Channel as EventChannel -from useq._mda_event import MDAEvent, PropertyTuple, SLMImage +from useq._mda_event import MDAEvent, MutableMDAEvent, PropertyTuple, SLMImage from useq._mda_sequence import MDASequence from useq._plate import WellPlate, WellPlatePlan from useq._plate_registry import register_well_plates, registered_well_plate_keys @@ -29,7 +29,6 @@ TIntervalDuration, TIntervalLoops, ) -from useq._utils import Axis from useq._z import ( AnyZPlan, ZAboveBelow, @@ -65,12 +64,14 @@ "MDASequence", "MultiPhaseTimePlan", "MultiPointPlan", + "MutableMDAEvent", "OrderMode", "Position", # alias for AbsolutePosition "PropertyTuple", "RandomPoints", "RelativeMultiPointPlan", "RelativePosition", + "RelativeTo", "SLMImage", "Shape", "TDurationLoops", diff --git a/src/useq/_base_model.py b/src/useq/_base_model.py index be2cae18..8681f1dd 100644 --- a/src/useq/_base_model.py +++ b/src/useq/_base_model.py @@ -1,5 +1,4 @@ from pathlib import Path -from re import findall from types import MappingProxyType from typing import ( IO, @@ -27,10 +26,11 @@ _T = TypeVar("_T", bound="FrozenModel") _Y = TypeVar("_Y", bound="UseqModel") -PYDANTIC_VERSION = tuple(int(x) for x in findall(r"\d_", pydantic.__version__)[:3]) -GET_DEFAULT_KWARGS: dict = {} +PYDANTIC_VERSION = tuple(int(x) for x in pydantic.__version__.split(".")[:2]) if PYDANTIC_VERSION >= (2, 10): - GET_DEFAULT_KWARGS = {"validated_data": {}} + GET_DEFAULT_KWARGS: dict = {"validated_data": {}} +else: + GET_DEFAULT_KWARGS = {} class _ReplaceableModel(BaseModel): @@ -87,9 +87,9 @@ class MutableModel(_ReplaceableModel): ) -class UseqModel(FrozenModel): +class IOMixin(BaseModel): @classmethod - def from_file(cls: type[_Y], path: Union[str, Path]) -> _Y: + def from_file(cls, path: Union[str, Path]) -> "Self": """Return an instance of this class from a file. Supports JSON and YAML.""" path = Path(path) if path.suffix in {".yaml", ".yml"}: @@ -148,3 +148,9 @@ def yaml( exclude_none=exclude_none, ) return yaml.safe_dump(data, stream=stream) + + +class MutableUseqModel(IOMixin, MutableModel): ... + + +class UseqModel(FrozenModel, IOMixin): ... diff --git a/src/useq/_enums.py b/src/useq/_enums.py new file mode 100644 index 00000000..3a0a10b3 --- /dev/null +++ b/src/useq/_enums.py @@ -0,0 +1,69 @@ +from enum import Enum +from typing import Final, Literal + + +class Axis(str, Enum): + """Recognized useq-schema axis keys. + + Attributes + ---------- + TIME : Literal["t"] + Time axis. + POSITION : Literal["p"] + XY Stage Position axis. + GRID : Literal["g"] + Grid axis (usually an additional row/column iteration around a position). + CHANNEL : Literal["c"] + Channel axis. + Z : Literal["z"] + Z axis. + """ + + TIME = "t" + POSITION = "p" + GRID = "g" + CHANNEL = "c" + Z = "z" + + def __str__(self) -> Literal["t", "p", "g", "c", "z"]: + return self.value + + +# note: order affects the default axis_order in MDASequence +AXES: Final[tuple[Axis, ...]] = ( + Axis.TIME, + Axis.POSITION, + Axis.GRID, + Axis.CHANNEL, + Axis.Z, +) + + +class RelativeTo(Enum): + """Where the coordinates of the grid are relative to. + + Attributes + ---------- + center : Literal['center'] + Grid is centered around the origin. + top_left : Literal['top_left'] + Grid is positioned such that the top left corner is at the origin. + """ + + center = "center" + top_left = "top_left" + + +class Shape(Enum): + """Shape of the bounding box for random points. + + Attributes + ---------- + ELLIPSE : Literal['ellipse'] + The bounding box is an ellipse. + RECTANGLE : Literal['rectangle'] + The bounding box is a rectangle. + """ + + ELLIPSE = "ellipse" + RECTANGLE = "rectangle" diff --git a/src/useq/_grid.py b/src/useq/_grid.py index 46a9366d..c7ef7577 100644 --- a/src/useq/_grid.py +++ b/src/useq/_grid.py @@ -4,21 +4,22 @@ import math import warnings from collections.abc import Iterable, Iterator, Sequence -from enum import Enum from typing import ( TYPE_CHECKING, Annotated, Any, Callable, + Generic, Optional, Union, ) import numpy as np from annotated_types import Ge, Gt -from pydantic import Field, field_validator, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator from typing_extensions import Self, TypeAlias +from useq._enums import RelativeTo, Shape from useq._point_visiting import OrderMode, TraversalOrder from useq._position import ( AbsolutePosition, @@ -37,47 +38,15 @@ MIN_RANDOM_POINTS = 10000 -class RelativeTo(Enum): - """Where the coordinates of the grid are relative to. - - Attributes - ---------- - center : Literal['center'] - Grid is centered around the origin. - top_left : Literal['top_left'] - Grid is positioned such that the top left corner is at the origin. - """ - - center = "center" - top_left = "top_left" - - -# used in iter_indices below, to determine the order in which indices are yielded -class _GridPlan(_MultiPointPlan[PositionT]): - """Base class for all grid plans. - - Attributes - ---------- - overlap : float | Tuple[float, float] - Overlap between grid positions in percent. If a single value is provided, it is - used for both x and y. If a tuple is provided, the first value is used - for x and the second for y. - mode : OrderMode - Define the ways of ordering the grid positions. Options are - row_wise, column_wise, row_wise_snake, column_wise_snake and spiral. - By default, row_wise_snake. - fov_width : Optional[float] - Width of the field of view in microns. If not provided, acquisition engines - should use current width of the FOV based on the current objective and camera. - Engines MAY override this even if provided. - fov_height : Optional[float] - Height of the field of view in microns. If not provided, acquisition engines - should use current height of the FOV based on the current objective and camera. - Engines MAY override this even if provided. - """ - +class _GridMixin(BaseModel, Generic[PositionT]): overlap: tuple[float, float] = Field(default=(0.0, 0.0), frozen=True) mode: OrderMode = Field(default=OrderMode.row_wise_snake, frozen=True) + fov_width: Optional[float] = None + fov_height: Optional[float] = None + + @property + def is_relative(self) -> bool: + return True @field_validator("overlap", mode="before") def _validate_overlap(cls, v: Any) -> tuple[float, float]: @@ -105,24 +74,18 @@ def _ncolumns(self, dx: float) -> int: """Return the number of columns, given a grid step size.""" raise NotImplementedError - def num_positions(self) -> int: - """Return the number of individual positions in the grid. + def __iter__(self) -> Iterator[PositionT]: # type: ignore [override] + yield from self.iter_grid_positions() - Note: For GridFromEdges and GridWidthHeight, this will depend on field of view - size. If no field of view size is provided, the number of positions will be 1. - """ - if isinstance(self, (GridFromEdges, GridWidthHeight)) and ( - self.fov_width is None or self.fov_height is None - ): - raise ValueError( - "Retrieving the number of positions in a GridFromEdges or " - "GridWidthHeight plan requires the field of view size to be set." - ) + def _step_size(self, fov_width: float, fov_height: float) -> tuple[float, float]: + dx = fov_width - (fov_width * self.overlap[0]) / 100 + dy = fov_height - (fov_height * self.overlap[1]) / 100 + return dx, dy - dx, dy = self._step_size(self.fov_width or 1, self.fov_height or 1) - rows = self._nrows(dy) - cols = self._ncolumns(dx) - return rows * cols + def _build_position(self, **kwargs: Any) -> PositionT: + """Build a position object for this grid plan.""" + pos_cls = RelativePosition if self.is_relative else AbsolutePosition + return pos_cls(**kwargs) # type: ignore def iter_grid_positions( self, @@ -142,9 +105,8 @@ def iter_grid_positions( x0 = self._offset_x(dx) y0 = self._offset_y(dy) - pos_cls = RelativePosition if self.is_relative else AbsolutePosition for idx, (r, c) in enumerate(order.generate_indices(rows, cols)): - yield pos_cls( # type: ignore [misc] + yield self._build_position( x=x0 + c * dx, y=y0 - r * dy, row=r, @@ -152,16 +114,27 @@ def iter_grid_positions( name=f"{str(idx).zfill(4)}", ) - def __iter__(self) -> Iterator[PositionT]: # type: ignore [override] - yield from self.iter_grid_positions() + def num_positions(self) -> int: + """Return the number of individual positions in the grid. - def _step_size(self, fov_width: float, fov_height: float) -> tuple[float, float]: - dx = fov_width - (fov_width * self.overlap[0]) / 100 - dy = fov_height - (fov_height * self.overlap[1]) / 100 - return dx, dy + Note: For GridFromEdges and GridWidthHeight, this will depend on field of view + size. If no field of view size is provided, the number of positions will be 1. + """ + if isinstance(self, (GridFromEdges, GridWidthHeight)) and ( + self.fov_width is None or self.fov_height is None + ): + raise ValueError( + "Retrieving the number of positions in a GridFromEdges or " + "GridWidthHeight plan requires the field of view size to be set." + ) + + dx, dy = self._step_size(self.fov_width or 1, self.fov_height or 1) + rows = self._nrows(dy) + cols = self._ncolumns(dx) + return rows * cols -class GridFromEdges(_GridPlan[AbsolutePosition]): +class GridFromEdges(_GridMixin, _MultiPointPlan[AbsolutePosition]): """Yield absolute stage positions to cover a bounded area. The bounded area is defined by top, left, bottom and right edges in @@ -253,7 +226,7 @@ def plot(self, *, show: bool = True) -> Axes: ) -class GridRowsColumns(_GridPlan[RelativePosition]): +class GridRowsColumns(_GridMixin, _MultiPointPlan[RelativePosition]): """Grid plan based on number of rows and columns. Attributes @@ -311,7 +284,7 @@ def _offset_y(self, dy: float) -> float: GridRelative = GridRowsColumns -class GridWidthHeight(_GridPlan[RelativePosition]): +class GridWidthHeight(_GridMixin, _MultiPointPlan[RelativePosition]): """Grid plan based on total width and height. Attributes @@ -371,21 +344,6 @@ def _offset_y(self, dy: float) -> float: # ------------------------ RANDOM ------------------------ -class Shape(Enum): - """Shape of the bounding box for random points. - - Attributes - ---------- - ELLIPSE : Literal['ellipse'] - The bounding box is an ellipse. - RECTANGLE : Literal['rectangle'] - The bounding box is a rectangle. - """ - - ELLIPSE = "ellipse" - RECTANGLE = "rectangle" - - class RandomPoints(_MultiPointPlan[RelativePosition]): """Yield random points in a specified geometric shape. diff --git a/src/useq/_iter_sequence.py b/src/useq/_iter_sequence.py index a57baddd..f89fb68c 100644 --- a/src/useq/_iter_sequence.py +++ b/src/useq/_iter_sequence.py @@ -7,9 +7,10 @@ from typing_extensions import TypedDict from useq._channel import Channel # noqa: TC001 # noqa: TCH001 +from useq._enums import AXES, Axis from useq._mda_event import Channel as EventChannel from useq._mda_event import MDAEvent, ReadOnlyDict -from useq._utils import AXES, Axis, _has_axes +from useq._utils import _has_axes from useq._z import AnyZPlan # noqa: TC001 # noqa: TCH001 if TYPE_CHECKING: diff --git a/src/useq/_mda_event.py b/src/useq/_mda_event.py index 76151cf5..ce477f26 100644 --- a/src/useq/_mda_event.py +++ b/src/useq/_mda_event.py @@ -5,6 +5,7 @@ from typing import ( TYPE_CHECKING, Any, + ClassVar, NamedTuple, Optional, TypedDict, @@ -12,21 +13,23 @@ import numpy as np import numpy.typing as npt -from pydantic import Field, GetCoreSchemaHandler, field_validator, model_validator +from pydantic import ( + ConfigDict, + Field, + GetCoreSchemaHandler, + field_serializer, + model_validator, +) from pydantic_core import core_schema from useq._actions import AcquireImage, AnyAction -from useq._base_model import UseqModel - -try: - from pydantic import field_serializer -except ImportError: - field_serializer = None # type: ignore +from useq._base_model import MutableUseqModel, UseqModel if TYPE_CHECKING: from collections.abc import Sequence from useq._mda_sequence import MDASequence + from useq.v2 import MultiAxisSequence ReprArgs = Sequence[tuple[Optional[str], Any]] @@ -52,6 +55,12 @@ def __eq__(self, _value: object) -> bool: return self.config == _value return super().__eq__(_value) + @model_validator(mode="before") + def _cast_config(cls, v: Any) -> Any: + if isinstance(v, str): + return {"config": v} + return v + if TYPE_CHECKING: class Kwargs(TypedDict, total=False): @@ -165,7 +174,76 @@ def __get_pydantic_core_schema__( ) -class MDAEvent(UseqModel): +class MutableMDAEvent(MutableUseqModel): + index: ReadOnlyDict = Field(default_factory=ReadOnlyDict) + channel: Optional[Channel] = None + exposure: Optional[float] = Field(default=None, gt=0.0) + min_start_time: Optional[float] = None # time in sec + pos_name: Optional[str] = None + x_pos: Optional[float] = None + y_pos: Optional[float] = None + z_pos: Optional[float] = None + slm_image: Optional[SLMImage] = None + sequence: Any = Field(default=None, repr=False) + properties: Optional[list[PropertyTuple]] = None + metadata: dict[str, Any] = Field(default_factory=dict) + action: AnyAction = Field(default_factory=AcquireImage, discriminator="type") + keep_shutter_open: bool = False + reset_event_timer: bool = False + + def freeze(self) -> "MDAEvent": + """Return a frozen version of this event.""" + return MDAEvent.model_construct(**self.model_dump(exclude_unset=True)) + + def __eq__(self, other: object) -> bool: + # exclude sequence from equality check + if not isinstance(other, MDAEvent): + return NotImplemented + return ( + self.index == other.index + and self.channel == other.channel + and self.exposure == other.exposure + and self.min_start_time == other.min_start_time + and self.pos_name == other.pos_name + and self.x_pos == other.x_pos + and self.y_pos == other.y_pos + and self.z_pos == other.z_pos + and self.slm_image == other.slm_image + and self.properties == other.properties + and self.metadata == other.metadata + and self.action == other.action + and self.keep_shutter_open == other.keep_shutter_open + and self.reset_event_timer == other.reset_event_timer + ) + + _si = field_serializer("index", mode="plain")(lambda v: dict(v)) + _sx = field_serializer("x_pos", mode="plain")(_float_or_none) + _sy = field_serializer("y_pos", mode="plain")(_float_or_none) + _sz = field_serializer("z_pos", mode="plain")(_float_or_none) + + if TYPE_CHECKING: + + class Kwargs(TypedDict, total=False): + """Type for the kwargs passed to the MDA event.""" + + index: dict[str, int] + channel: Channel | Channel.Kwargs + exposure: float + min_start_time: float + pos_name: str + x_pos: float + y_pos: float + z_pos: float + slm_image: SLMImage | SLMImage.Kwargs | npt.ArrayLike + sequence: MDASequence | MultiAxisSequence | dict + properties: list[tuple[str, str, Any]] + metadata: dict + action: AnyAction + keep_shutter_open: bool + reset_event_timer: bool + + +class MDAEvent(MutableMDAEvent): """Define a single event in a [`MDASequence`][useq.MDASequence]. Usually, this object will be generator by iterating over a @@ -237,49 +315,7 @@ class MDAEvent(UseqModel): `False`. """ - index: ReadOnlyDict = Field(default_factory=ReadOnlyDict) - channel: Optional[Channel] = None - exposure: Optional[float] = Field(default=None, gt=0.0) - min_start_time: Optional[float] = None # time in sec - pos_name: Optional[str] = None - x_pos: Optional[float] = None - y_pos: Optional[float] = None - z_pos: Optional[float] = None - slm_image: Optional[SLMImage] = None - sequence: Optional["MDASequence"] = Field(default=None, repr=False) - properties: Optional[list[PropertyTuple]] = None - metadata: dict[str, Any] = Field(default_factory=dict) - action: AnyAction = Field(default_factory=AcquireImage, discriminator="type") - keep_shutter_open: bool = False - reset_event_timer: bool = False + model_config: ClassVar["ConfigDict"] = ConfigDict(frozen=True) - @field_validator("channel", mode="before") - def _validate_channel(cls, val: Any) -> Any: - return Channel(config=val) if isinstance(val, str) else val - if field_serializer is not None: - _si = field_serializer("index", mode="plain")(lambda v: dict(v)) - _sx = field_serializer("x_pos", mode="plain")(_float_or_none) - _sy = field_serializer("y_pos", mode="plain")(_float_or_none) - _sz = field_serializer("z_pos", mode="plain")(_float_or_none) - - if TYPE_CHECKING: - - class Kwargs(TypedDict, total=False): - """Type for the kwargs passed to the MDA event.""" - - index: dict[str, int] - channel: Channel | Channel.Kwargs - exposure: float - min_start_time: float - pos_name: str - x_pos: float - y_pos: float - z_pos: float - slm_image: SLMImage | SLMImage.Kwargs | npt.ArrayLike - sequence: MDASequence | dict - properties: list[tuple[str, str, Any]] - metadata: dict - action: AnyAction - keep_shutter_open: bool - reset_event_timer: bool +MutableMDAEvent.__doc__ = MDAEvent.__doc__ diff --git a/src/useq/_mda_sequence.py b/src/useq/_mda_sequence.py index dc4d653d..93fed162 100644 --- a/src/useq/_mda_sequence.py +++ b/src/useq/_mda_sequence.py @@ -16,13 +16,14 @@ from useq._base_model import UseqModel from useq._channel import Channel +from useq._enums import AXES, Axis from useq._grid import MultiPointPlan # noqa: TC001 from useq._hardware_autofocus import AnyAutofocusPlan, AxesBasedAF from useq._iter_sequence import iter_sequence from useq._plate import WellPlatePlan from useq._position import Position, PositionBase from useq._time import AnyTimePlan # noqa: TC001 -from useq._utils import AXES, Axis, TimeEstimate, estimate_sequence_duration +from useq._utils import TimeEstimate, estimate_sequence_duration from useq._z import AnyZPlan # noqa: TC001 if TYPE_CHECKING: diff --git a/src/useq/_plate.py b/src/useq/_plate.py index ef99796c..591d1531 100644 --- a/src/useq/_plate.py +++ b/src/useq/_plate.py @@ -22,7 +22,8 @@ ) from useq._base_model import FrozenModel, UseqModel -from useq._grid import RandomPoints, RelativeMultiPointPlan, Shape +from useq._enums import Shape +from useq._grid import RandomPoints, RelativeMultiPointPlan from useq._plate_registry import _PLATE_REGISTRY from useq._position import Position, PositionBase, RelativePosition diff --git a/src/useq/_plate_registry.py b/src/useq/_plate_registry.py index e16fcbc7..a3976714 100644 --- a/src/useq/_plate_registry.py +++ b/src/useq/_plate_registry.py @@ -4,7 +4,9 @@ if TYPE_CHECKING: from collections.abc import Iterable, Mapping - from typing import Required, TypeAlias, TypedDict + from typing import TypeAlias, TypedDict + + from typing_extensions import Required from useq._plate import WellPlate diff --git a/src/useq/_time.py b/src/useq/_time.py index 84c7495a..0ef9d2e7 100644 --- a/src/useq/_time.py +++ b/src/useq/_time.py @@ -1,16 +1,35 @@ from collections.abc import Iterator, Sequence from datetime import timedelta -from typing import Annotated, Any, Union +from typing import Annotated, Any, Optional, Union -from pydantic import BeforeValidator, Field, PlainSerializer, model_validator +from pydantic import ( + BeforeValidator, + Field, + PlainSerializer, + model_validator, +) from useq._base_model import FrozenModel + +def _validate_delta(v: Any) -> timedelta: + if isinstance(v, dict): + v = timedelta(**v) + elif isinstance(v, (str, int, float)): + v = timedelta(seconds=float(v)) # assuming ISO 8601 or similar + + if not isinstance(v, timedelta): + raise TypeError(f"Expected timedelta, str, int, or dict, got {type(v)}") + if v.total_seconds() < 0: + raise ValueError("Duration must be non-negative") + return v + + # slightly modified so that we can accept dict objects as input # and serialize to total_seconds -TimeDelta = Annotated[ +NonNegativeTimeDelta = Annotated[ timedelta, - BeforeValidator(lambda v: timedelta(**v) if isinstance(v, dict) else v), + BeforeValidator(_validate_delta), PlainSerializer(lambda td: td.total_seconds()), ] @@ -24,6 +43,9 @@ def __iter__(self) -> Iterator[float]: # type: ignore yield td.total_seconds() def num_timepoints(self) -> int: + return len(self) + + def __len__(self) -> int: return self.loops # type: ignore # TODO def deltas(self) -> Iterator[timedelta]: @@ -48,7 +70,7 @@ class TIntervalLoops(TimePlan): of conflict. By default, `False`. """ - interval: TimeDelta + interval: NonNegativeTimeDelta loops: int = Field(..., gt=0) @property @@ -71,11 +93,15 @@ class TDurationLoops(TimePlan): of conflict. By default, `False`. """ - duration: TimeDelta + duration: NonNegativeTimeDelta loops: int = Field(..., gt=0) @property def interval(self) -> timedelta: + if self.loops == 1: + # Special case: with only 1 loop, interval is meaningless + # Return zero to indicate instant + return timedelta(0) # -1 makes it so that the last loop will *occur* at duration, not *finish* return self.duration / (self.loops - 1) @@ -95,13 +121,29 @@ class TIntervalDuration(TimePlan): of conflict. By default, `True`. """ - interval: TimeDelta - duration: TimeDelta + interval: NonNegativeTimeDelta + duration: Optional[NonNegativeTimeDelta] = None prioritize_duration: bool = True + def __iter__(self) -> Iterator[float]: # type: ignore[override] + duration_s = self.duration.total_seconds() if self.duration else None + interval_s = self.interval.total_seconds() + t = 0.0 + # when `duration_s` is None, the `or` makes it always True → infinite; + # otherwise it stops once t > duration_s + while duration_s is None or t <= duration_s: + yield t + t += interval_s + @property def loops(self) -> int: - return self.duration // self.interval + 1 + return len(self) + + def __len__(self) -> int: + """Return the number of time points in this plan.""" + if self.duration is None: + raise ValueError("Cannot determine length of infinite time plan") + return int(self.duration.total_seconds() / self.interval.total_seconds()) + 1 SinglePhaseTimePlan = Union[TIntervalDuration, TIntervalLoops, TDurationLoops] @@ -131,9 +173,12 @@ def deltas(self) -> Iterator[timedelta]: if td is not None: accum += td - def num_timepoints(self) -> int: - # TODO: is this correct? - return sum(phase.loops for phase in self.phases) - 1 + def __len__(self) -> int: + """Return the number of time points in this plan.""" + phase_sum = sum(len(phase) for phase in self.phases) + # subtract 1 for the first time point of each phase + # except the first one + return phase_sum - len(self.phases) + 1 @model_validator(mode="before") @classmethod diff --git a/src/useq/_utils.py b/src/useq/_utils.py index f081e904..d73da340 100644 --- a/src/useq/_utils.py +++ b/src/useq/_utils.py @@ -2,13 +2,12 @@ import re from datetime import timedelta -from enum import Enum from typing import TYPE_CHECKING, NamedTuple from useq._time import MultiPhaseTimePlan if TYPE_CHECKING: - from typing import Final, Literal, TypeVar + from typing import TypeVar from typing_extensions import TypeGuard @@ -19,44 +18,6 @@ VT = TypeVar("VT") -# could be an enum, but this more easily allows Axis.Z to be a string -class Axis(str, Enum): - """Recognized useq-schema axis keys. - - Attributes - ---------- - TIME : Literal["t"] - Time axis. - POSITION : Literal["p"] - XY Stage Position axis. - GRID : Literal["g"] - Grid axis (usually an additional row/column iteration around a position). - CHANNEL : Literal["c"] - Channel axis. - Z : Literal["z"] - Z axis. - """ - - TIME = "t" - POSITION = "p" - GRID = "g" - CHANNEL = "c" - Z = "z" - - def __str__(self) -> Literal["t", "p", "g", "c", "z"]: - return self.value - - -# note: order affects the default axis_order in MDASequence -AXES: Final[tuple[Axis, ...]] = ( - Axis.TIME, - Axis.POSITION, - Axis.GRID, - Axis.CHANNEL, - Axis.Z, -) - - class TimeEstimate(NamedTuple): """Record of time estimation results. diff --git a/src/useq/_z.py b/src/useq/_z.py index 622d1451..564b5c0b 100644 --- a/src/useq/_z.py +++ b/src/useq/_z.py @@ -37,6 +37,10 @@ def positions(self) -> Sequence[float]: return [float(x) for x in np.arange(start, stop, step)] def num_positions(self) -> int: + return len(self) + + def __len__(self) -> int: + """Get the number of Z positions.""" start, stop, step = self._start_stop_step() if step == 0: return 1 @@ -156,7 +160,7 @@ class ZRelativePositions(ZPlan): def positions(self) -> Sequence[float]: return self.relative - def num_positions(self) -> int: + def __len__(self) -> int: return len(self.relative) @@ -179,7 +183,7 @@ class ZAbsolutePositions(ZPlan): def positions(self) -> Sequence[float]: return self.absolute - def num_positions(self) -> int: + def __len__(self) -> int: return len(self.absolute) @property diff --git a/src/useq/pycromanager.py b/src/useq/pycromanager.py index b1d2bc45..1f2f7ff4 100644 --- a/src/useq/pycromanager.py +++ b/src/useq/pycromanager.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, overload from useq import MDAEvent, MDASequence -from useq._utils import Axis +from useq._enums import Axis if TYPE_CHECKING: from typing_extensions import Literal, Required, TypedDict diff --git a/src/useq/v2/__init__.py b/src/useq/v2/__init__.py new file mode 100644 index 00000000..ae126b97 --- /dev/null +++ b/src/useq/v2/__init__.py @@ -0,0 +1,132 @@ +"""New MDASequence API.""" + +from typing import Any + +import pydantic +from typing_extensions import deprecated + +from useq._actions import AcquireImage, Action, CustomAction, HardwareAutofocus +from useq._channel import Channel +from useq._enums import Axis, RelativeTo, Shape +from useq._hardware_autofocus import AnyAutofocusPlan, AutoFocusPlan, AxesBasedAF +from useq._mda_event import Channel as EventChannel +from useq._mda_event import MDAEvent, MutableMDAEvent, PropertyTuple, SLMImage +from useq._plate import WellPlate, WellPlatePlan +from useq._plate_registry import register_well_plates, registered_well_plate_keys +from useq._point_visiting import OrderMode, TraversalOrder +from useq.v2._axes_iterator import AxisIterable, MultiAxisSequence, SimpleValueAxis +from useq.v2._channels import ChannelsPlan +from useq.v2._grid import ( + GridFromEdges, + GridRowsColumns, + GridWidthHeight, + MultiPointPlan, + RandomPoints, + RelativeMultiPointPlan, +) +from useq.v2._iterate import iterate_multi_dim_sequence +from useq.v2._mda_sequence import MDASequence +from useq.v2._multi_point import MultiPositionPlan +from useq.v2._position import Position +from useq.v2._stage_positions import StagePositions +from useq.v2._time import ( + AnyTimePlan, + MultiPhaseTimePlan, + SinglePhaseTimePlan, + TDurationLoops, + TimePlan, + TIntervalDuration, + TIntervalLoops, +) +from useq.v2._z import ( + AnyZPlan, + ZAboveBelow, + ZAbsolutePositions, + ZPlan, + ZRangeAround, + ZRelativePositions, + ZTopBottom, +) + +AbsolutePosition = Position + + +@deprecated( + "The RelativePosition class is deprecated. " + "Use Position with is_relative=True instead.", + category=DeprecationWarning, + stacklevel=2, +) +def RelativePosition(**kwargs: Any) -> Position: + """Create a relative position.""" + return Position(**kwargs, is_relative=True) + + +__all__ = [ + "AbsolutePosition", + "AcquireImage", + "Action", + "AnyAutofocusPlan", + "AnyTimePlan", + "AnyZPlan", + "AutoFocusPlan", + "AxesBasedAF", + "Axis", + "AxisIterable", + "Channel", + "ChannelsPlan", + "CustomAction", + "EventChannel", + "GridFromEdges", + "GridRowsColumns", + "GridWidthHeight", + "HardwareAutofocus", + "MDAEvent", + "MDASequence", + "MultiAxisSequence", + "MultiPhaseTimePlan", + "MultiPointPlan", + "MultiPositionPlan", + "MutableMDAEvent", + "OrderMode", + "Position", # alias for AbsolutePosition + "PropertyTuple", + "RandomPoints", + "RelativeMultiPointPlan", + "RelativePosition", + "RelativeTo", + "SLMImage", + "Shape", + "SimpleValueAxis", + "SinglePhaseTimePlan", + "StagePositions", + "TDurationLoops", + "TIntervalDuration", + "TIntervalLoops", + "TimePlan", + "TraversalOrder", + "WellPlate", + "WellPlatePlan", + "ZAboveBelow", + "ZAbsolutePositions", + "ZPlan", + "ZRangeAround", + "ZRangeAround", + "ZRelativePositions", + "ZTopBottom", + "ZTopBottom", + "iterate_multi_dim_sequence", + "register_well_plates", + "registered_well_plate_keys", +] + + +for item in list(globals().values()): + if ( + isinstance(item, type) + and issubclass(item, pydantic.BaseModel) + and item is not pydantic.BaseModel + ): + item.model_rebuild() + +del pydantic diff --git a/src/useq/v2/_axes_iterator.py b/src/useq/v2/_axes_iterator.py new file mode 100644 index 00000000..14683f3c --- /dev/null +++ b/src/useq/v2/_axes_iterator.py @@ -0,0 +1,468 @@ +"""MultiDimensional Iteration Module. + +This module provides a declarative approach to multi-dimensional iteration, +supporting hierarchical (nested) sub-iterations as well as conditional +skipping (filtering) of final combinations. + +Key Concepts: +------------- +- **AxisIterable**: An interface (protocol) representing an axis. Each axis + has a unique `axis_key` and yields values via its iterator. A concrete axis, + such as `SimpleValueAxis`, yields plain values. To express sub-iterations, + an axis may yield a nested `MultiAxisSequence` (instead of a plain value). + +- **MultiAxisSequence**: Represents a multi-dimensional experiment or sequence. + It contains a tuple of axes (AxisIterable objects) and an optional `axis_order` + that controls the order in which axes are processed. When used as a nested override, + its `value` field is used as the representative value for that branch, and its + axes override or extend the parent's axes. + +- **Nested Overrides**: When an axis yields a nested MultiAxisSequence with a non-None + `value`, that nested sequence acts as an override for the parent's iteration. + Specifically, the parent's remaining axes that have keys matching those in the + nested sequence are removed, and the nested sequence's axes (ordered by its own + `axis_order`, or inheriting the parent's if not provided) are appended. + +- **Prefix and Skip Logic**: As the recursion proceeds, a `prefix` is built up, mapping + axis keys to a triple: (index, value, axis). Before yielding a final combination, + each axis is given an opportunity (via the `should_skip` method) to veto that + combination. By default, `SimpleValueAxis.should_skip` returns False, but you can + override it in a subclass to implement conditional skipping. + +Usage Examples: +--------------- +1. Basic Iteration (no nested sequences): + + >>> multi_dim = MultiAxisSequence( + ... axes=( + ... SimpleValueAxis("t", [0, 1, 2]), + ... SimpleValueAxis("c", ["red", "green", "blue"]), + ... SimpleValueAxis("z", [0.1, 0.2]), + ... ), + ... axis_order=("t", "c", "z"), + ... ) + >>> for combo in iterate_multi_dim_sequence(multi_dim): + ... # Clean the prefix for display (dropping the axis objects) + ... print({k: (idx, val) for k, (idx, val, _) in combo.items()}) + {'t': (0, 0), 'c': (0, 'red'), 'z': (0, 0.1)} + {'t': (0, 0), 'c': (0, 'red'), 'z': (1, 0.2)} + ... (and so on for all Cartesian products) + +2. Sub-Iteration Adding New Axes: + Here the "t" axis yields a nested MultiAxisSequence that adds an extra "q" axis. + + >>> multi_dim = MultiAxisSequence( + ... axes=( + ... SimpleValueAxis("t", [ + ... 0, + ... MultiAxisSequence( + ... value=1, + ... axes=(SimpleValueAxis("q", ["a", "b"]),), + ... ), + ... 2, + ... ]), + ... SimpleValueAxis("c", ["red", "green", "blue"]), + ... ), + ... axis_order=("t", "c"), + ... ) + >>> for combo in iterate_multi_dim_sequence(multi_dim): + ... print({k: (idx, val) for k, (idx, val, _) in combo.items()}) + {'t': (0, 0), 'c': (0, 'red')} + {'t': (0, 0), 'c': (1, 'green')} + {'t': (0, 0), 'c': (2, 'blue')} + {'t': (1, 1), 'c': (0, 'red'), 'q': (0, 'a')} + {'t': (1, 1), 'c': (0, 'red'), 'q': (1, 'b')} + {'t': (1, 1), 'c': (1, 'green'), 'q': (0, 'a')} + ... (and so on) + +3. Overriding Parent Axes: + Here the "t" axis yields a nested MultiAxisSequence whose axes override the parent's + "z" axis. + + >>> multi_dim = MultiAxisSequence( + ... axes=( + ... SimpleValueAxis("t", [ + ... 0, + ... MultiAxisSequence( + ... value=1, + ... axes=( + ... SimpleValueAxis("c", ["red", "blue"]), + ... SimpleValueAxis("z", [7, 8, 9]), + ... ), + ... axis_order=("c", "z"), + ... ), + ... 2, + ... ]), + ... SimpleValueAxis("c", ["red", "green", "blue"]), + ... SimpleValueAxis("z", [0.1, 0.2]), + ... ), + ... axis_order=("t", "c", "z"), + ... ) + >>> for combo in iterate_multi_dim_sequence(multi_dim): + ... print({k: (idx, val) for k, (idx, val, _) in combo.items()}) + {'t': (0, 0), 'c': (0, 'red'), 'z': (0, 0.1)} + ... (normal combinations for t==0 and t==2) + {'t': (1, 1), 'c': (0, 'red'), 'z': (0, 7)} + {'t': (1, 1), 'c': (0, 'red'), 'z': (1, 8)} + {'t': (1, 1), 'c': (0, 'red'), 'z': (2, 9)} + {'t': (1, 1), 'c': (1, 'blue'), 'z': (0, 7)} + ... (and so on) + +4. Conditional Skipping: + By subclassing SimpleValueAxis to override should_skip, you can filter out + combinations. For example, suppose we want to skip any combination where "c" equals + "green" and "z" is not 0.2: + + >>> class FilteredZ(SimpleValueAxis): + ... def should_skip( + ... self, prefix: dict[str, tuple[int, Any, AxisIterable]] + ... ) -> bool: + ... c_val = prefix.get("c", (None, None, None))[1] + ... z_val = prefix.get("z", (None, None, None))[1] + ... if c_val == "green" and z_val != 0.2: + ... return True + ... return False + ... + >>> multi_dim = MultiAxisSequence( + ... axes=( + ... SimpleValueAxis("t", [0, 1, 2]), + ... SimpleValueAxis("c", ["red", "green", "blue"]), + ... FilteredZ("z", [0.1, 0.2]), + ... ), + ... axis_order=("t", "c", "z"), + ... ) + >>> for combo in iterate_multi_dim_sequence(multi_dim): + ... print({k: (idx, val) for k, (idx, val, _) in combo.items()}) + (Only those combinations where if c is green then z equals 0.2 are printed.) + +Usage Notes: +------------ +- The module assumes that each axis is finite and that the final prefix (the + combination) is built by processing one axis at a time. Nested MultiAxisSequence + objects allow you to either extend the iteration with new axes or override existing + ones. +- The ordering of axes is controlled via the `axis_order` property, which is inherited + by nested sequences if not explicitly provided. +- The should_skip mechanism gives each axis an opportunity to veto a final combination. + By default, SimpleValueAxis does not skip any combination, but you can subclass it to + implement custom filtering logic. + +This module is intended for cases where complex, declarative multidimensional iteration +is required-such as in microscope acquisitions, high-content imaging, or other +experimental designs where the sequence of events must be generated in a flexible, +hierarchical manner. +""" + +from __future__ import annotations + +from abc import abstractmethod +from collections.abc import Callable, Iterable, Iterator, Mapping, Sized +from functools import cache +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Generic, + Optional, + Protocol, + TypeVar, + cast, + runtime_checkable, +) + +from pydantic import BaseModel, Field, field_validator + +from useq._base_model import MutableModel +from useq.v2._importable_object import ImportableObject + +if TYPE_CHECKING: + from collections.abc import Iterator + from typing import TypeAlias + + AxisKey: TypeAlias = str + Value: TypeAlias = Any + Index: TypeAlias = int + AxesIndex: TypeAlias = dict[AxisKey, tuple[Index, Value, "AxisIterable"]] + AxesIndexWithContext: TypeAlias = tuple[AxesIndex, tuple["MultiAxisSequence", ...]] + + +V = TypeVar("V", covariant=True, bound=Any) +EventT = TypeVar("EventT", bound=Any) +EventTco = TypeVar("EventTco", covariant=True, bound=Any) + + +class AxisIterable(BaseModel, Generic[V]): + axis_key: str + """A string id representing the axis.""" + + @abstractmethod + def __iter__(self) -> Iterator[V]: # type: ignore[override] + """Iterate over the axis. + + If a value needs to declare sub-axes, yield a nested AxesIterator. + The default iterator pattern will recurse into a nested AxesIterator. + """ + + def should_skip(self, prefix: AxesIndex) -> bool: + """Return True if this axis wants to skip the combination. + + Default implementation returns False. + """ + return False + + def contribute_event_kwargs( + self, + value: V, # type: ignore[misc] # covariant cannot be used as parameter + index: Mapping[str, int], + ) -> Mapping: + """Contribute data to the event being built. + + This method allows each axis to contribute its data to the final MDAEvent. + The default implementation does nothing - subclasses should override + to add their specific contributions. + + Parameters + ---------- + value : V + The value provided by this axis, for this iteration. + index : Mapping[str, int] + A mapping of axis keys to their current index in the iteration. + This can be used to determine the context of the value. + + Returns + ------- + event_data : dict[str, Any] + Data to be added to the MDAEvent, it is ultimately up to the + EventBuilder to decide how to merge possibly conflicting contributions from + different axes. + """ + return {} + + +class SimpleValueAxis(AxisIterable[V]): + """A basic axis implementation that yields values directly. + + If a value needs to declare sub-axes, yield a nested MultiAxisSequence. + The default should_skip always returns False. + """ + + values: list[V] = Field(default_factory=list) + + def __iter__(self) -> Iterator[V]: # type: ignore[override] + yield from self.values + + def __len__(self) -> int: + """Return the number of axis values.""" + return len(self.values) + + +@runtime_checkable +class EventBuilder(Protocol[EventTco]): + """Callable that builds an event from an AxesIndex.""" + + @abstractmethod + def __call__( + self, axes_index: AxesIndex, context: tuple[MultiAxisSequence, ...] + ) -> EventTco: + """Transform an AxesIndex into an event object.""" + + +@runtime_checkable +class EventTransform(Protocol[EventT]): + """Callable that can modify, drop, or insert events. + + The transformer receives: + + * **event** - the current (already built) event. + * **prev_event** - the *previously transformed* event that was just yielded, + or ``None`` if this is the first call. + * **make_next_event** - a zero-argument callable that lazily builds the *next* + raw event (i.e. before any transformers). Only call it if you really + need look-ahead so the pipeline stays lazy. + + The transformer must return a list. + + Return **one** event in the list for a 1-to-1 mapping, an empty list to + drop the original event, or a list with multiple items to insert extras. + """ + + def __call__( + self, + event: EventT, + *, + prev_event: EventT | None, + make_next_event: Callable[[], EventT | None], + ) -> Iterable[EventT]: ... + + +@runtime_checkable +class AxesIterator(Protocol): + """Object that iterates over a MultiAxisSequence.""" + + @abstractmethod + def __call__( + self, seq: MultiAxisSequence, axis_order: tuple[str, ...] | None = None + ) -> Iterator[AxesIndexWithContext]: + """Iterate over the axes of a MultiAxisSequence.""" + ... + + +class MultiAxisSequence(MutableModel, Generic[EventTco]): + """Represents a multidimensional sequence. + + At the top level the `value` field is ignored. + When used as a nested override, `value` is the value for that branch and + its axes are iterated using its own axis_order if provided; + otherwise, it inherits the parent's axis_order. + """ + + axes: tuple[AxisIterable, ...] = () + axis_order: Optional[tuple[str, ...]] = None + value: Any = None + + # these will rarely be needed, but offer maximum flexibility + event_builder: Optional[Annotated[EventBuilder[EventTco], ImportableObject()]] = ( + Field(default=None, repr=False) + ) + + # optional post-processing transformer chain + transforms: tuple[Annotated[EventTransform, ImportableObject()], ...] = Field( + default_factory=tuple, repr=False + ) + + def is_finite(self) -> bool: + """Return `True` if the sequence is finite (all axes are Sized).""" + return all(isinstance(ax, Sized) for ax in self.axes) + + def iter_axes( + self, axis_order: tuple[str, ...] | None = None + ) -> Iterator[AxesIndexWithContext]: + """Iterate over the axes and yield combinations with context. + + Yields + ------ + AxesIndexWithContext + A tuple of (AxesIndex, MultiAxisSequence) where AxesIndex is a dictionary + mapping axis keys to tuples of (index, value, AxisIterable), and + MultiAxisSequence is the context that generated this axes combination. + For example, when iterating over an `AxisIterable` with a single axis "t", + with values of [0.1, .2], the yielded tuples would be: + - ({'t': (0, 0.1, )}, ) + - ({'t': (1, 0.2, )}, ) + """ + from useq.v2._iterate import iterate_multi_dim_sequence + + yield from iterate_multi_dim_sequence(self, axis_order=axis_order) + + def iter_events( + self, axis_order: tuple[str, ...] | None = None + ) -> Iterator[EventTco]: + """Iterate over axes, build raw events, then apply transformers.""" + if (event_builder := self.event_builder) is None: + raise ValueError("No event builder provided for this sequence.") + + axes_iter = self.iter_axes(axis_order=axis_order) + + # Get the first item to see if we have any events + try: + next_item: AxesIndexWithContext | None = next(axes_iter) + except StopIteration: + return # empty sequence - nothing to yield + + prev_evt: EventTco | None = None + while True: + cur_axes, context = cast("AxesIndexWithContext", next_item) + + try: + next_item = next(axes_iter) + except StopIteration: + next_item = None + + cur_evt = event_builder(cur_axes, context) + transforms = self.compose_transforms(context) + + if not transforms: + # simple case - no transforms, just yield the event + yield cur_evt + prev_evt = cur_evt + else: + + @cache + def _make_next_event( + _nxt_item: AxesIndexWithContext | None = next_item, + _ctx: Any = context, + ) -> EventTco | None: + if _nxt_item is not None: + return event_builder(_nxt_item[0], _ctx) + return None + + # run through transformer pipeline + emitted: Iterable[EventTco] = (cur_evt,) + pipeline_prev_evt = prev_evt + for tf in transforms: + # Convert to list to materialize the iterable for proper chaining + emitted_list = list(emitted) + new_emitted = [] + for e in emitted_list: + transformed = list( + tf( + e, + prev_event=pipeline_prev_evt, + make_next_event=_make_next_event, + ) + ) + if transformed: + new_emitted.extend(transformed) + # Update prev_evt to last event from this transform + pipeline_prev_evt = transformed[-1] + emitted = new_emitted + + for out_evt in emitted: + yield out_evt + prev_evt = out_evt + + if next_item is None: + break + + def compose_transforms( + self, context: tuple[MultiAxisSequence, ...] = () + ) -> tuple[EventTransform, ...]: + """Compose transforms from the context of nested sequences. + + The base implementation aggregates transforms from outer to inner sequences. + Only a single instance of each transform type is kept, so if multiple + sequences in the context have the same transform type, only one will be used, + and innermost MultiAxisSequence's transform will take precedence. + """ + merged_transforms = {type(t): t for seq in context for t in seq.transforms} + + # sort by "priority" attribute (if defined) or by order of appearance + sorted_transforms = sorted( + merged_transforms.values(), + key=lambda t: getattr(t, "priority", 0), # default priority is 0 + ) + return tuple(sorted_transforms) + + # ----------------------- Validation ----------------------- + + @field_validator("axes", mode="after") + def _validate_axes(cls, v: tuple[AxisIterable, ...]) -> tuple[AxisIterable, ...]: + keys = [x.axis_key for x in v] + if dupes := {k for k in keys if keys.count(k) > 1}: + raise ValueError( + f"The following axis keys appeared more than once: {dupes}" + ) + return v + + @field_validator("axis_order", mode="before") + @classmethod + def _validate_axis_order(cls, v: Any) -> Any: + if v is None: + return None + if not isinstance(v, Iterable): + raise ValueError(f"axis_order must be iterable, got {type(v)}") + order = tuple(str(x).lower() for x in v) + if len(set(order)) < len(order): + raise ValueError(f"Duplicate entries found in acquisition order: {order}") + + return order diff --git a/src/useq/v2/_channels.py b/src/useq/v2/_channels.py new file mode 100644 index 00000000..c0e7db2b --- /dev/null +++ b/src/useq/v2/_channels.py @@ -0,0 +1,36 @@ +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Any, Literal + +from pydantic import Field, model_validator + +from useq import Axis, Channel +from useq._base_model import FrozenModel +from useq.v2._axes_iterator import SimpleValueAxis + +if TYPE_CHECKING: + from useq._mda_event import MDAEvent + + +class ChannelsPlan(SimpleValueAxis[Channel], FrozenModel): + axis_key: Literal[Axis.CHANNEL] = Field( # pyright: ignore[reportIncompatibleVariableOverride] + default=Axis.CHANNEL, frozen=True, init=False + ) + + @model_validator(mode="before") + @classmethod + def _cast_any(cls, values: Any) -> Any: + """Try to cast any value to a ChannelsPlan.""" + if isinstance(values, Sequence) and not isinstance(values, str): + values = {"values": values} + return values + + def contribute_event_kwargs( + self, value: Channel, index: Mapping[str, int] + ) -> "MDAEvent.Kwargs": + """Contribute channel information to the MDA event.""" + kwargs: MDAEvent.Kwargs = {} + if value.config is not None: + kwargs["channel"] = {"config": value.config, "group": value.group} + if value.exposure is not None: + kwargs["exposure"] = value.exposure + return kwargs diff --git a/src/useq/v2/_grid.py b/src/useq/v2/_grid.py new file mode 100644 index 00000000..ef5c2629 --- /dev/null +++ b/src/useq/v2/_grid.py @@ -0,0 +1,461 @@ +from __future__ import annotations + +import math +import warnings +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Callable, + Literal, + Optional, + Union, +) + +import numpy as np +from annotated_types import Ge, Gt +from pydantic import Field, model_validator +from typing_extensions import Self, TypeAlias, deprecated + +from useq import Axis +from useq._enums import RelativeTo, Shape +from useq._grid import _GridMixin +from useq._point_visiting import TraversalOrder +from useq.v2._multi_point import MultiPositionPlan +from useq.v2._position import Position + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator + + from matplotlib.axes import Axes + + PointGenerator: TypeAlias = Callable[ + [np.random.RandomState, int, float, float], Iterable[tuple[float, float]] + ] + +MIN_RANDOM_POINTS = 10000 + + +# used in iter_indices below, to determine the order in which indices are yielded +class _GridPlan(_GridMixin, MultiPositionPlan): + """Base class for all grid plans. + + Attributes + ---------- + overlap : float | Tuple[float, float] + Overlap between grid positions in percent. If a single value is provided, it is + used for both x and y. If a tuple is provided, the first value is used + for x and the second for y. + mode : OrderMode + Define the ways of ordering the grid positions. Options are + row_wise, column_wise, row_wise_snake, column_wise_snake and spiral. + By default, row_wise_snake. + fov_width : Optional[float] + Width of the field of view in microns. If not provided, acquisition engines + should use current width of the FOV based on the current objective and camera. + Engines MAY override this even if provided. + fov_height : Optional[float] + Height of the field of view in microns. If not provided, acquisition engines + should use current height of the FOV based on the current objective and camera. + Engines MAY override this even if provided. + """ + + axis_key: Literal[Axis.GRID] = Field(default=Axis.GRID, frozen=True, init=False) # pyright: ignore[reportIncompatibleVariableOverride] + + @deprecated( + "num_positions() is deprecated, use len(grid_plan) instead.", + category=UserWarning, + stacklevel=2, + ) + def num_positions(self) -> int: + """Return the number of positions in the grid.""" + return len(self) + + def __len__(self) -> int: + """Return the number of individual positions in the grid. + + Note: For GridFromEdges and GridWidthHeight, this will depend on field of view + size. If no field of view size is provided, the number of positions will be 1. + """ + if isinstance(self, (GridFromEdges, GridWidthHeight)) and ( + # type ignore is because mypy thinks self is Never here... + self.fov_width is None or self.fov_height is None + ): + raise ValueError( + "Retrieving the number of positions in a GridFromEdges or " + "GridWidthHeight plan requires the field of view size to be set." + ) + + dx, dy = self._step_size(self.fov_width or 1, self.fov_height or 1) + rows = self._nrows(dy) + cols = self._ncolumns(dx) + return rows * cols + + def _build_position(self, **kwargs: Any) -> Position: + """Build a position object for this grid plan.""" + return Position(**kwargs, is_relative=self.is_relative) + + +class GridFromEdges(_GridPlan): + """Yield absolute stage positions to cover a bounded area. + + The bounded area is defined by top, left, bottom and right edges in + stage coordinates. The bounds define the *outer* edges of the images, including + the field of view and overlap. + + Attributes + ---------- + top : float + Top stage position of the bounding area + left : float + Left stage position of the bounding area + bottom : float + Bottom stage position of the bounding area + right : float + Right stage position of the bounding area + overlap : float | Tuple[float, float] + Overlap between grid positions in percent. If a single value is provided, it is + used for both x and y. If a tuple is provided, the first value is used + for x and the second for y. + mode : OrderMode + Define the ways of ordering the grid positions. Options are + row_wise, column_wise, row_wise_snake, column_wise_snake and spiral. + By default, row_wise_snake. + fov_width : Optional[float] + Width of the field of view in microns. If not provided, acquisition engines + should use current width of the FOV based on the current objective and camera. + Engines MAY override this even if provided. + fov_height : Optional[float] + Height of the field of view in microns. If not provided, acquisition engines + should use current height of the FOV based on the current objective and camera. + Engines MAY override this even if provided. + """ + + # everything but fov_width and fov_height is immutable + top: float = Field(..., frozen=True) + left: float = Field(..., frozen=True) + bottom: float = Field(..., frozen=True) + right: float = Field(..., frozen=True) + + @property + def is_relative(self) -> bool: + return False + + def _nrows(self, dy: float) -> int: + if self.fov_height is None: + total_height = abs(self.top - self.bottom) + dy + return math.ceil(total_height / dy) + + span = abs(self.top - self.bottom) + # if the span is smaller than one FOV, just one row + if span <= self.fov_height: + return 1 + # otherwise: one FOV plus (nrows-1)⋅dy must cover span + return math.ceil((span - self.fov_height) / dy) + 1 + + def _ncolumns(self, dx: float) -> int: + if self.fov_width is None: + total_width = abs(self.right - self.left) + dx + return math.ceil(total_width / dx) + + span = abs(self.right - self.left) + if span <= self.fov_width: + return 1 + return math.ceil((span - self.fov_width) / dx) + 1 + + def _offset_x(self, dx: float) -> float: + # start the _centre_ half a FOV in from the left edge + return min(self.left, self.right) + (self.fov_width or 0) / 2 + + def _offset_y(self, dy: float) -> float: + # start the _centre_ half a FOV down from the top edge + return max(self.top, self.bottom) - (self.fov_height or 0) / 2 + + def plot(self, *, show: bool = True) -> Axes: + """Plot the positions in the plan.""" + from useq._plot import plot_points + + if self.fov_width is not None and self.fov_height is not None: + rect = (self.fov_width, self.fov_height) + else: + rect = None + + return plot_points( + self, + rect_size=rect, + bounding_box=(self.left, self.top, self.right, self.bottom), + show=show, + ) + + +class GridRowsColumns(_GridPlan): + """Grid plan based on number of rows and columns. + + Attributes + ---------- + rows: int + Number of rows. + columns: int + Number of columns. + relative_to : RelativeTo + Point in the grid to which the coordinates are relative. If "center", the grid + is centered around the origin. If "top_left", the grid is positioned such that + the top left corner is at the origin. + overlap : float | Tuple[float, float] + Overlap between grid positions in percent. If a single value is provided, it is + used for both x and y. If a tuple is provided, the first value is used + for x and the second for y. + mode : OrderMode + Define the ways of ordering the grid positions. Options are + row_wise, column_wise, row_wise_snake, column_wise_snake and spiral. + By default, row_wise_snake. + fov_width : Optional[float] + Width of the field of view in microns. If not provided, acquisition engines + should use current width of the FOV based on the current objective and camera. + Engines MAY override this even if provided. + fov_height : Optional[float] + Height of the field of view in microns. If not provided, acquisition engines + should use current height of the FOV based on the current objective and camera. + Engines MAY override this even if provided. + """ + + # everything but fov_width and fov_height is immutable + rows: int = Field(..., frozen=True, ge=1) + columns: int = Field(..., frozen=True, ge=1) + relative_to: RelativeTo = Field(default=RelativeTo.center, frozen=True) + + def _nrows(self, dy: float) -> int: + return self.rows + + def _ncolumns(self, dx: float) -> int: + return self.columns + + def _offset_x(self, dx: float) -> float: + return ( + -((self.columns - 1) * dx) / 2 + if self.relative_to == RelativeTo.center + else 0.0 + ) + + def _offset_y(self, dy: float) -> float: + return ( + ((self.rows - 1) * dy) / 2 if self.relative_to == RelativeTo.center else 0.0 + ) + + +class GridWidthHeight(_GridPlan): + """Grid plan based on total width and height. + + Attributes + ---------- + width: float + Minimum total width of the grid, in microns. (may be larger based on fov_width) + height: float + Minimum total height of the grid, in microns. (may be larger based on + fov_height) + relative_to : RelativeTo + Point in the grid to which the coordinates are relative. If "center", the grid + is centered around the origin. If "top_left", the grid is positioned such that + the top left corner is at the origin. + overlap : float | Tuple[float, float] + Overlap between grid positions in percent. If a single value is provided, it is + used for both x and y. If a tuple is provided, the first value is used + for x and the second for y. + mode : OrderMode + Define the ways of ordering the grid positions. Options are + row_wise, column_wise, row_wise_snake, column_wise_snake and spiral. + By default, row_wise_snake. + fov_width : Optional[float] + Width of the field of view in microns. If not provided, acquisition engines + should use current width of the FOV based on the current objective and camera. + Engines MAY override this even if provided. + fov_height : Optional[float] + Height of the field of view in microns. If not provided, acquisition engines + should use current height of the FOV based on the current objective and camera. + Engines MAY override this even if provided. + """ + + width: float = Field(..., frozen=True, gt=0) + height: float = Field(..., frozen=True, gt=0) + relative_to: RelativeTo = Field(default=RelativeTo.center, frozen=True) + + def _nrows(self, dy: float) -> int: + return math.ceil(self.height / dy) + + def _ncolumns(self, dx: float) -> int: + return math.ceil(self.width / dx) + + def _offset_x(self, dx: float) -> float: + return ( + -((self._ncolumns(dx) - 1) * dx) / 2 + if self.relative_to == RelativeTo.center + else 0.0 + ) + + def _offset_y(self, dy: float) -> float: + return ( + ((self._nrows(dy) - 1) * dy) / 2 + if self.relative_to == RelativeTo.center + else 0.0 + ) + + +# ------------------------ RANDOM ------------------------ + + +class RandomPoints(MultiPositionPlan): + """Yield random points in a specified geometric shape. + + Attributes + ---------- + num_points : int + Number of points to generate. + max_width : float + Maximum width of the bounding box in microns. + max_height : float + Maximum height of the bounding box in microns. + shape : Shape + Shape of the bounding box. Current options are "ellipse" and "rectangle". + random_seed : Optional[int] + Random numpy seed that should be used to generate the points. If None, a random + seed will be used. + allow_overlap : bool + By defaut, True. If False and `fov_width` and `fov_height` are specified, points + will not overlap and will be at least `fov_width` and `fov_height apart. + order : TraversalOrder + Order in which the points will be visited. If None, order is simply the order + in which the points are generated (random). Use 'nearest_neighbor' or + 'two_opt' to order the points in a more structured way. + start_at : int | RelativePosition + Position or index of the point to start at. This is only used if `order` is + 'nearest_neighbor' or 'two_opt'. If a position is provided, it will *always* + be included in the list of points. If an index is provided, it must be less than + the number of points, and corresponds to the index of the (randomly generated) + points; this likely only makes sense when `random_seed` is provided. + """ + + axis_key: Literal[Axis.GRID] = Field(default=Axis.GRID, frozen=True, init=False) # pyright: ignore[reportIncompatibleVariableOverride] + + num_points: Annotated[int, Gt(0)] + max_width: Annotated[float, Gt(0)] = 1 + max_height: Annotated[float, Gt(0)] = 1 + shape: Shape = Shape.ELLIPSE + random_seed: Optional[int] = None + allow_overlap: bool = True + order: Optional[TraversalOrder] = TraversalOrder.TWO_OPT + start_at: Union[Position, Annotated[int, Ge(0)]] = 0 + + @model_validator(mode="after") + def _validate_startat(self) -> Self: + if isinstance(self.start_at, int) and self.start_at > (self.num_points - 1): + warnings.warn( + "start_at is greater than the number of points. " + "Setting start_at to last point.", + stacklevel=2, + ) + self.start_at = self.num_points - 1 + return self + + def __iter__(self) -> Iterator[Position]: # type: ignore [override] + seed = np.random.RandomState(self.random_seed) + func = _POINTS_GENERATORS[self.shape] + + points: list[tuple[float, float]] = [] + needed_points = self.num_points + start_at = self.start_at + if isinstance(start_at, Position): + points = [(start_at.x, start_at.y)] # type: ignore [list-item] + needed_points -= 1 + start_at = 0 + + # in the easy case, just generate the requested number of points + if self.allow_overlap or self.fov_width is None or self.fov_height is None: + _points = func(seed, needed_points, self.max_width, self.max_height) + points.extend(_points) + + else: + # if we need to avoid overlap, generate points, check if they are valid, and + # repeat until we have enough + per_iter = needed_points + tries = 0 + while tries < MIN_RANDOM_POINTS and len(points) < self.num_points: + candidates = func(seed, per_iter, self.max_width, self.max_height) + tries += per_iter + for p in candidates: + if _is_a_valid_point(points, *p, self.fov_width, self.fov_height): + points.append(p) + if len(points) >= self.num_points: + break + + if len(points) < self.num_points: + warnings.warn( + f"Unable to generate {self.num_points} non-overlapping points. " + f"Only {len(points)} points were found.", + stacklevel=2, + ) + + if self.order is not None: + points = self.order(points, start_at=start_at) # type: ignore [assignment] + + for idx, (x, y) in enumerate(points): + yield Position(x=x, y=y, name=f"{str(idx).zfill(4)}", is_relative=True) + + def num_positions(self) -> int: + return self.num_points + + +def _is_a_valid_point( + points: list[tuple[float, float]], + x: float, + y: float, + min_dist_x: float, + min_dist_y: float, +) -> bool: + """Return True if the the point is at least min_dist away from all the others. + + note: using Manhattan distance. + """ + return not any( + abs(x - point_x) < min_dist_x and abs(y - point_y) < min_dist_y + for point_x, point_y in points + ) + + +def _random_points_in_ellipse( + seed: np.random.RandomState, n_points: int, max_width: float, max_height: float +) -> np.ndarray: + """Generate a random point around a circle with center (0, 0). + + The point is within +/- radius_x and +/- radius_y at a random angle. + """ + points = seed.uniform(0, 1, size=(n_points, 3)) + xy = points[:, :2] + angle = points[:, 2] * 2 * np.pi + xy[:, 0] *= (max_width / 2) * np.cos(angle) + xy[:, 1] *= (max_height / 2) * np.sin(angle) + return xy + + +def _random_points_in_rectangle( + seed: np.random.RandomState, n_points: int, max_width: float, max_height: float +) -> np.ndarray: + """Generate a random point around a rectangle with center (0, 0). + + The point is within the bounding box (-width/2, -height/2, width, height). + """ + xy = seed.uniform(0, 1, size=(n_points, 2)) + xy[:, 0] = (xy[:, 0] * max_width) - (max_width / 2) + xy[:, 1] = (xy[:, 1] * max_height) - (max_height / 2) + return xy + + +_POINTS_GENERATORS: dict[Shape, PointGenerator] = { + Shape.ELLIPSE: _random_points_in_ellipse, + Shape.RECTANGLE: _random_points_in_rectangle, +} + + +# all of these support __iter__() -> Iterator[PositionBase] and num_positions() -> int +RelativeMultiPointPlan = Union[GridRowsColumns, GridWidthHeight, RandomPoints] +AbsoluteMultiPointPlan = Union[GridFromEdges] +MultiPointPlan = Union[AbsoluteMultiPointPlan, RelativeMultiPointPlan] diff --git a/src/useq/v2/_importable_object.py b/src/useq/v2/_importable_object.py new file mode 100644 index 00000000..2b862183 --- /dev/null +++ b/src/useq/v2/_importable_object.py @@ -0,0 +1,80 @@ +from dataclasses import dataclass +from typing import Any, get_origin + +import pydantic +from pydantic import GetCoreSchemaHandler +from pydantic_core import core_schema + +pydantic_version = tuple(int(x) for x in pydantic.VERSION.split(".")[:2]) +if pydantic_version >= (2, 11): + json_input: dict = {"json_schema_input_schema": core_schema.str_schema()} +else: + json_input = {} + + +@dataclass(frozen=True) +class ImportableObject: + """Pydantic schema for importable objects. + + Example usage: + + ```python + field: Annotated[SomeClass, ImportableObject()] + ``` + + Putting this object in a field annotation will allow the field to accept any object + that can be imported from a string path, such as `"module.submodule.ClassName"`, and + which, when instantiated, will obey `isinstance(obj, SomeClass)`. + """ + + def __get_pydantic_core_schema__( + self, source_type: Any, handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + """Return the schema for the importable object.""" + + def import_python_path(value: Any) -> Any: + """Import a Python object from a string path.""" + if isinstance(value, str): + # If a string is provided, it should be a path to the class + # that implements the EventBuilder protocol. + from importlib import import_module + + parts = value.rsplit(".", 1) + if len(parts) != 2: + raise ValueError( + f"Invalid import path: {value!r}. " + "Expected format: 'module.submodule.ClassName'" + ) + module_name, class_name = parts + module = import_module(module_name) + cls = getattr(module, class_name) + if not isinstance(cls, type): + raise ValueError(f"Expected a class at {value!r}, but got {cls!r}.") + value = cls() + return value + + def get_python_path(value: Any) -> str: + """Get a unique identifier for the event builder.""" + val_type = type(value) + return f"{val_type.__module__}.{val_type.__qualname__}" + + # TODO: check me + origin = source_type + try: + isinstance(None, origin) + except TypeError: + origin = get_origin(origin) + try: + isinstance(None, origin) + except TypeError: + origin = object + + to_pp_ser = core_schema.plain_serializer_function_ser_schema( + function=get_python_path + ) + return core_schema.no_info_before_validator_function( + function=import_python_path, + schema=core_schema.is_instance_schema(origin), + serialization=to_pp_ser, + **json_input, + ) diff --git a/src/useq/v2/_iterate.py b/src/useq/v2/_iterate.py new file mode 100644 index 00000000..5434cc94 --- /dev/null +++ b/src/useq/v2/_iterate.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeVar + +from useq.v2._axes_iterator import AxisIterable, MultiAxisSequence + +if TYPE_CHECKING: + from collections.abc import Iterator + + from useq.v2._axes_iterator import AxesIndex, AxesIndexWithContext + + +V = TypeVar("V", covariant=True) + + +def order_axes( + seq: MultiAxisSequence, + axis_order: tuple[str, ...] | None = None, +) -> list[AxisIterable]: + """Returns the axes of a MultiDimSequence in the order specified by seq.axis_order. + + If axis_order is provided, it overrides the sequence's axis_order. + """ + if axis_order is None: + axis_order = seq.axis_order + if axis_order: + axes_map = {axis.axis_key: axis for axis in seq.axes} + return [axes_map[key] for key in axis_order if key in axes_map] + return list(seq.axes) + + +def iterate_axes_recursive( + axes: list[AxisIterable], + prefix: AxesIndex | None = None, + parent_order: tuple[str, ...] | None = None, + context: tuple[MultiAxisSequence, ...] = (), +) -> Iterator[AxesIndexWithContext]: + """Recursively iterate over a list of axes one at a time. + + If an axis yields a nested MultiDimSequence with a non-None value, + that nested sequence acts as an override for its axis key. + The parent's remaining axes having matching keys are removed, and the nested + sequence's axes (ordered by its own axis_order if provided, or else the parent's) + are appended. + + Before yielding a final combination (when no axes remain), we call should_skip + on each axis (using the full prefix). + """ + if prefix is None: + prefix = {} + + if not axes: + # Ask each axis in the prefix if the combination should be skipped + if not any(axis.should_skip(prefix) for *_, axis in prefix.values()): + yield prefix, context + return + + current_axis, *remaining_axes = axes + + for idx, item in enumerate(current_axis): + if isinstance(item, MultiAxisSequence): + if item.value is None: + raise NotImplementedError("Nested sequences must have a value.") + + value = item.value + override_keys = {ax.axis_key for ax in item.axes} + order = item.axis_order if item.axis_order is not None else parent_order + + # Remove axes from the parent that are overridden by the nested sequence, + # then append the axes from the nested sequence in the correct order. + parent_axes_not_overridden = [ + ax for ax in remaining_axes if ax.axis_key not in override_keys + ] + nested_axes_in_order = order_axes(item, order) + updated_axes = parent_axes_not_overridden + nested_axes_in_order + + # Use the nested sequence as the new context + context = (*context, item) + else: + value = item + updated_axes = remaining_axes + + yield from iterate_axes_recursive( + updated_axes, + {**prefix, current_axis.axis_key: (idx, value, current_axis)}, + parent_order=parent_order, + context=context, + ) + + +def iterate_multi_dim_sequence( + seq: MultiAxisSequence, axis_order: tuple[str, ...] | None = None +) -> Iterator[AxesIndexWithContext]: + """Iterate over a MultiDimSequence. + + Orders the base axes (if an axis_order is provided) and then iterates + over all index combinations using iterate_axes_recursive. + The parent's axis_order is passed down to nested sequences. + + Yields + ------ + AxesIndexWithContext + A tuple of (AxesIndex, MultiAxisSequence) where AxesIndex is a dictionary + mapping axis keys to tuples of (index, value, axis), and MultiAxisSequence + is the context that generated this axes combination. + """ + ordered_axes = order_axes(seq, axis_order) + yield from iterate_axes_recursive( + ordered_axes, parent_order=axis_order, context=(seq,) + ) diff --git a/src/useq/v2/_mda_sequence.py b/src/useq/v2/_mda_sequence.py new file mode 100644 index 00000000..c21bf277 --- /dev/null +++ b/src/useq/v2/_mda_sequence.py @@ -0,0 +1,344 @@ +from __future__ import annotations + +import warnings +from collections.abc import Iterable, Iterator +from contextlib import suppress +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Callable, + Optional, + overload, +) + +from pydantic import Field, TypeAdapter, field_validator, model_validator +from typing_extensions import deprecated + +from useq import v2 +from useq._enums import AXES, Axis +from useq._hardware_autofocus import AnyAutofocusPlan, AxesBasedAF +from useq._mda_event import MDAEvent +from useq._mda_sequence import MDASequence as MDASequenceV1 +from useq.v2 import _position +from useq.v2._axes_iterator import ( + AxisIterable, + EventBuilder, + EventTransform, + MultiAxisSequence, +) +from useq.v2._importable_object import ImportableObject +from useq.v2._transformers import ( + AutoFocusTransform, + KeepShutterOpenTransform, + ResetEventTimerTransform, + reset_global_timer_state, +) + +if TYPE_CHECKING: + from collections.abc import Iterator, Mapping, Sequence + + from useq._channel import Channel + from useq.v2._axes_iterator import AxesIndex + from useq.v2._position import Position + + +# Example concrete event builder for MDAEvent +class MDAEventBuilder(EventBuilder[MDAEvent]): + """Builds MDAEvent objects from AxesIndex.""" + + def __call__( + self, axes_index: AxesIndex, context: tuple[MultiAxisSequence, ...] + ) -> MDAEvent: + """Transform AxesIndex into MDAEvent using axis contributions.""" + index: dict[str, int] = {} + contributions: list[tuple[str, Mapping]] = [] + + # Let each axis contribute to the event + for axis_key, (idx, value, axis) in axes_index.items(): + index[axis_key] = idx + contribution = axis.contribute_event_kwargs(value, index) + contributions.append((axis_key, contribution)) + + if context: + contributions.append(("", {"sequence": context[-1]})) + return self._merge_contributions(index, contributions) + + def _merge_contributions( + self, index: dict[str, int], contributions: list[tuple[str, Mapping]] + ) -> MDAEvent: + event_data: dict = {"index": index} + abs_pos: dict[str, float] = {} + + # First pass: collect all contributions and detect conflicts + for axis_key, contrib in contributions: + for key, val in contrib.items(): + if key.endswith("_pos") and val is not None: + if key in abs_pos and abs_pos[key] != val: + warnings.warn( + f"Conflicting absolute position from {axis_key}: " + f"existing {key}={abs_pos[key]}, new {key}={val}", + UserWarning, + stacklevel=3, + ) + abs_pos[key] = val + elif key in event_data and event_data[key] != val: # pragma: no cover + # Could implement different strategies here + raise ValueError(f"Conflicting values for {key} from {axis_key}") + else: + event_data[key] = val + + # Second pass: handle relative positions + for _, contrib in contributions: + for key, val in contrib.items(): + if key.endswith("_pos_rel") and val is not None: + abs_key = key.replace("_rel", "") + abs_pos.setdefault(abs_key, 0.0) + abs_pos[abs_key] += val + + # Merge final positions + event_data.update(abs_pos) + return MDAEvent(**event_data) + + +def _default_transforms(data: dict) -> tuple[EventTransform[MDAEvent], ...]: + if any(ax.axis_key == Axis.TIME for ax in data.get("axes", ())): + return (ResetEventTimerTransform(),) + return () + + +class MDASequence(MultiAxisSequence[MDAEvent]): + autofocus_plan: Optional[AnyAutofocusPlan] = None + keep_shutter_open_across: tuple[str, ...] = Field(default_factory=tuple) + metadata: dict[str, Any] = Field(default_factory=dict) + event_builder: Optional[Annotated[EventBuilder[MDAEvent], ImportableObject()]] = ( + Field(default_factory=MDAEventBuilder, repr=False) + ) + + transforms: tuple[Annotated[EventTransform[MDAEvent], ImportableObject()], ...] = ( + Field( + default_factory=_default_transforms, + repr=False, + ) + ) + + if TYPE_CHECKING: + # legacy __init__ signature + @overload + def __init__( + self: MDASequence, + *, + axis_order: tuple[str, ...] | str | None = ..., + value: Any = ..., + time_plan: AxisIterable[float] | list | dict | None = ..., + z_plan: AxisIterable[Position] | None = ..., + channels: AxisIterable[Channel] | list | None = ..., + stage_positions: AxisIterable[Position] | list | None = ..., + grid_plan: AxisIterable[Position] | None = ..., + autofocus_plan: AnyAutofocusPlan | None = ..., + keep_shutter_open_across: str | tuple[str, ...] = ..., + metadata: dict[str, Any] = ..., + event_builder: EventBuilder[MDAEvent] = ..., + transforms: tuple[EventTransform[MDAEvent], ...] = ..., + ) -> None: ... + # new pattern + @overload + def __init__( + self, + *, + axes: tuple[AxisIterable, ...] = ..., + axis_order: tuple[str, ...] | None = ..., + value: Any = ..., + autofocus_plan: AnyAutofocusPlan | None = ..., + keep_shutter_open_across: tuple[str, ...] = ..., + metadata: dict[str, Any] = ..., + event_builder: EventBuilder[MDAEvent] = ..., + transforms: tuple[EventTransform[MDAEvent], ...] = ..., + ) -> None: ... + def __init__(self, **kwargs: Any) -> None: ... + + def __iter__(self) -> Iterator[MDAEvent]: # type: ignore[override] + # Reset global timer state at the beginning of each sequence (like v1) + reset_global_timer_state() + yield from self.iter_events() + + @model_validator(mode="before") + @classmethod + def _cast_legacy_kwargs(cls, data: Any) -> Any: + """Cast legacy kwargs to the new pattern.""" + if isinstance(data, MDASequenceV1): + data = data.model_dump(exclude_unset=True) + if isinstance(data, dict) and (axes := _extract_legacy_axes(data)): + if "axes" in data: # pragma: no cover + raise ValueError( + "Cannot provide both 'axes' and legacy MDASequence parameters." + ) + data["axes"] = axes + return data + + @model_validator(mode="after") + def _compose_transforms(self) -> MDASequence: + """Compose transforms after initialization.""" + # add autofocus transform if applicable + if isinstance(self.autofocus_plan, AxesBasedAF) and not any( + isinstance(ax, AutoFocusTransform) for ax in self.transforms + ): + self.transforms += (AutoFocusTransform(self.autofocus_plan),) + if self.keep_shutter_open_across and not any( + isinstance(ax, KeepShutterOpenTransform) for ax in self.transforms + ): + self.transforms += ( + KeepShutterOpenTransform(self.keep_shutter_open_across), + ) + return self + + @field_validator("keep_shutter_open_across", mode="before") + def _validate_keep_shutter_open_across(cls, v: tuple[str, ...]) -> tuple[str, ...]: + try: + v = tuple(v) + except (TypeError, ValueError): # pragma: no cover + raise ValueError( + f"keep_shutter_open_across must be string or a sequence of strings, " + f"got {type(v)}" + ) from None + return v + + # ------------------------- Old API ------------------------- + + @property + @deprecated( + "The shape of an MDASequence is ill-defined. " + "This API will be removed in a future version.", + category=FutureWarning, + stacklevel=2, + ) + def shape(self) -> tuple[int, ...]: + """Return the shape of this sequence. + + !!! note + This doesn't account for jagged arrays, like channels that exclude z + stacks or skip timepoints. + """ + return tuple(s for s in self.sizes.values() if s) + + @property + @deprecated( + "The sizes of an MDASequence is ill-defined. " + "This API will be removed in a future version.", + category=FutureWarning, + stacklevel=2, + ) + def sizes(self) -> Mapping[str, int]: + """Mapping of axis name to size of that axis.""" + if not self.is_finite(): # pragma: no cover + raise ValueError("Cannot get sizes of infinite sequence.") + + return {axis.axis_key: len(axis) for axis in self._ordered_axes()} # type: ignore[arg-type] + + def _ordered_axes(self) -> tuple[AxisIterable, ...]: + """Return the axes in the order specified by axis_order.""" + if (order := self.axis_order) is None: + return self.axes + + axes_map = {axis.axis_key: axis for axis in self.axes} + return tuple(axes_map[key] for key in order if key in axes_map) + + @property + def used_axes(self) -> tuple[str, ...]: + """Return keys of the axes whose length is not 0.""" + out = [] + for ax in self._ordered_axes(): + with suppress(TypeError, ValueError): + if not len(ax): # type: ignore[arg-type] # pragma: no cover + continue + out.append(ax.axis_key) + return tuple(out) + + @property + def time_plan(self) -> Optional[AxisIterable[float]]: + """Return the time plan.""" + return next((axis for axis in self.axes if axis.axis_key == Axis.TIME), None) + + @property + def z_plan(self) -> Optional[AxisIterable[Position]]: + """Return the z plan.""" + return next((axis for axis in self.axes if axis.axis_key == Axis.Z), None) + + @property + def channels(self) -> Sequence[Channel]: + """Return the channels.""" + for axis in self.axes: + if axis.axis_key == Axis.CHANNEL: + return tuple(axis) + # If no channel axis is found, return an empty tuple + return () + + @property + def stage_positions(self) -> Sequence[Position]: + """Return the stage positions.""" + for axis in self.axes: + if axis.axis_key == Axis.POSITION: + return tuple(axis) + return () + + @property + def grid_plan(self) -> Optional[AxisIterable[Position]]: + """Return the grid plan.""" + return next((axis for axis in self.axes if axis.axis_key == Axis.GRID), None) + + +def _extract_legacy_axes(kwargs: dict[str, Any]) -> tuple[AxisIterable, ...]: + """Extract legacy axes from kwargs.""" + + def _cast_stage_position(val: Any) -> v2.StagePositions: + if not isinstance(val, Iterable): # pragma: no cover + raise ValueError( + f"Cannot convert 'stage_position' to AxisIterable: " + f"Expected a sequence, got {type(val)}" + ) + new_val: list[v2.Position] = [] + for item in val: + if isinstance(item, dict): + item = v2.Position(**item) + elif isinstance(item, MultiAxisSequence): + if item.value is None: + item = item.model_copy(update={"value": _position.Position()}) + else: + item = _position.Position.model_validate(item) + new_val.append(item) + return v2.StagePositions.model_validate(new_val) + + def _cast_legacy_to_axis_iterable(key: str) -> AxisIterable | None: + validator: dict[str, Callable[[Any], AxisIterable]] = { + "channels": v2.ChannelsPlan.model_validate, + "z_plan": TypeAdapter(v2.AnyZPlan).validate_python, + "time_plan": TypeAdapter(v2.AnyTimePlan).validate_python, + "grid_plan": TypeAdapter(v2.MultiPointPlan).validate_python, + "stage_positions": _cast_stage_position, + } + if (val := kwargs.pop(key)) not in (None, [], (), {}): + if not isinstance(val, AxisIterable): + try: + val = validator[key](val) + except Exception as e: # pragma: no cover + breakpoint() + raise ValueError( + f"Failed to process legacy axis '{key}': {e}" + ) from e + return val + return None # pragma: no cover + + axes = [ + val + for key in list(kwargs) + if key in {"channels", "z_plan", "time_plan", "grid_plan", "stage_positions"} + and (val := _cast_legacy_to_axis_iterable(key)) is not None + ] + + if "axis_order" not in kwargs: + # sort axes by AXES + axes.sort( + key=lambda ax: AXES.index(ax.axis_key) if ax.axis_key in AXES else len(AXES) + ) + + return tuple(axes) diff --git a/src/useq/v2/_multi_point.py b/src/useq/v2/_multi_point.py new file mode 100644 index 00000000..0e7118f7 --- /dev/null +++ b/src/useq/v2/_multi_point.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Annotated, Optional + +from annotated_types import Ge + +from useq.v2._axes_iterator import AxisIterable +from useq.v2._position import Position + +if TYPE_CHECKING: + from collections.abc import Mapping + + from matplotlib.axes import Axes + + from useq._mda_event import MDAEvent + + +class MultiPositionPlan(AxisIterable[Position]): + """Base class for all multi-position plans.""" + + fov_width: Optional[Annotated[float, Ge(0)]] = None + fov_height: Optional[Annotated[float, Ge(0)]] = None + + @property + def is_relative(self) -> bool: + return True + + def contribute_event_kwargs( + self, value: Position, index: Mapping[str, int] + ) -> MDAEvent.Kwargs: + out: dict = {} + rel = "_rel" if self.is_relative else "" + if value.x is not None: + out[f"x_pos{rel}"] = value.x + if value.y is not None: + out[f"y_pos{rel}"] = value.y + if value.z is not None: + out[f"z_pos{rel}"] = value.z + # if value.name is not None: + # out["pos_name"] = value.name + + # TODO: deal with the _rel suffix hack + return out # type: ignore[return-value] + + def plot(self, *, show: bool = True) -> Axes: + """Plot the positions in the plan.""" + from useq._plot import plot_points + + rect = None + if self.fov_width is not None and self.fov_height is not None: + rect = (self.fov_width, self.fov_height) + + return plot_points(self, rect_size=rect, show=show) # type: ignore[arg-type] diff --git a/src/useq/v2/_position.py b/src/useq/v2/_position.py new file mode 100644 index 00000000..568de1dd --- /dev/null +++ b/src/useq/v2/_position.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import os +import warnings +from typing import TYPE_CHECKING, Any, Optional, SupportsIndex + +import numpy as np +from pydantic import model_validator + +from useq._base_model import MutableModel + +if TYPE_CHECKING: + from typing_extensions import Self + + +class Position(MutableModel): + """Define a position in 3D space. + + Any of the attributes can be `None` to indicate that the position is not + defined. For engines implementing support for useq, a position of `None` implies + "do not move" or "stay at current position" on that axis. + + Attributes + ---------- + x : float | None + X position in microns. + y : float | None + Y position in microns. + z : float | None + Z position in microns. + name : str | None + Optional name for the position. + is_relative : bool + If `True`, the position should be considered a delta relative to some other + position. Relative positions support addition and subtraction, while absolute + positions do not. + """ + + def __new__(cls, *args: Any, **kwargs: Any) -> Self: + if "sequence" in kwargs and (seq := kwargs.pop("sequence")) is not None: + from useq.v2._mda_sequence import MDASequence + + seq2 = MDASequence.model_validate(seq) + pos = Position.model_validate(kwargs) + warnings.warn( + "In useq.v2 Positions no longer have a sequence attribute. " + "If you want to assign a subsequence to a position, " + "use positions=[..., MDASequence(value=Position(), ...)]. " + "We will now return an MDASequence, but this is not type safe.", + DeprecationWarning, + stacklevel=2, + ) + return seq2.model_copy(update={"value": pos}) # type: ignore[return-value] + return super().__new__(cls) + + x: Optional[float] = None + y: Optional[float] = None + z: Optional[float] = None + name: Optional[str] = None + is_relative: bool = False + + @model_validator(mode="before") + @classmethod + def _cast_any(cls, values: Any) -> Any: + """Try to cast any value to a Position.""" + if isinstance(values, (np.ndarray, tuple)): + x, *v = values + y, *v = v or (None,) + z = v[0] if v else None + values = {"x": x, "y": y, "z": z} + return values + + def __add__(self, other: Position) -> Self: + """Add two positions together to create a new position.""" + if not isinstance(other, Position) or not other.is_relative: + return NotImplemented # pragma: no cover + if self.name and other.name: + new_name: str | None = f"{self.name}_{other.name}" + else: + new_name = self.name or other.name + + return self.model_copy( + update={ + "x": _none_sum(self.x, other.x), + "y": _none_sum(self.y, other.y), + "z": _none_sum(self.z, other.z), + "name": new_name, + } + ) + + # allow `sum([pos1, delta, delta2], start=Position())` + __radd__ = __add__ + + def __round__(self, ndigits: SupportsIndex | None = None) -> Self: + """Round the position to the given number of decimal places.""" + return self.model_copy( + update={ + "x": _none_round(self.x, ndigits), + "y": _none_round(self.y, ndigits), + "z": _none_round(self.z, ndigits), + } + ) + + # FIXME: before merge + if "PYTEST_VERSION" in os.environ: + + def __eq__(self, other: object) -> bool: + """Compare two positions for equality.""" + if isinstance(other, (float, int)): + return self.z == other + return super().__eq__(other) + + +def _none_sum(a: float | None, b: float | None) -> float | None: + return a + b if a is not None and b is not None else a + + +def _none_round(v: float | None, ndigits: SupportsIndex | None) -> float | None: + return round(v, ndigits) if v is not None else None diff --git a/src/useq/v2/_stage_positions.py b/src/useq/v2/_stage_positions.py new file mode 100644 index 00000000..5d4ebee6 --- /dev/null +++ b/src/useq/v2/_stage_positions.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from collections.abc import Iterator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Literal, Union + +import numpy as np +from pydantic import Field, model_validator + +from useq import Axis +from useq._base_model import FrozenModel +from useq.v2._axes_iterator import AxisIterable +from useq.v2._position import Position + +if TYPE_CHECKING: + from useq._mda_event import MDAEvent + from useq.v2._mda_sequence import MDASequence + + +class StagePositions(AxisIterable[Position], FrozenModel): + axis_key: Literal[Axis.POSITION] = Field( # pyright: ignore[reportIncompatibleVariableOverride] + default=Axis.POSITION, frozen=True, init=False + ) + values: list[Union[Position, MDASequence]] = Field(default_factory=list) + + def __iter__(self) -> Iterator[Position | MDASequence]: # type: ignore[override] + yield from self.values + + def __len__(self) -> int: + """Return the number of axis values.""" + return len(self.values) + + @model_validator(mode="before") + @classmethod + def _cast_any(cls, values: Any) -> Any: + """Try to cast any value to a ChannelsPlan.""" + if isinstance(values, np.ndarray): + if values.ndim == 1: + values = [values] + elif values.ndim == 2: + values = list(values) + else: + raise ValueError( + f"Invalid number of dimensions for stage positions: {values.ndim}" + ) + if isinstance(values, Sequence) and not isinstance(values, str): + values = {"values": values} + + return values + + def contribute_event_kwargs( + self, + value: Position, + index: Mapping[str, int], + ) -> MDAEvent.Kwargs: + """Contribute channel information to the MDA event.""" + kwargs = {} + if isinstance(value, Position): + for key in ("x", "y", "z"): + if (val := getattr(value, key)) is not None: + kwargs[f"{key}_pos"] = val + if value.name is not None: + kwargs["pos_name"] = value.name + return kwargs # type: ignore[return-value] diff --git a/src/useq/v2/_time.py b/src/useq/v2/_time.py new file mode 100644 index 00000000..4bdb0529 --- /dev/null +++ b/src/useq/v2/_time.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Union + +from pydantic import Field +from typing_extensions import deprecated + +from useq import _time +from useq._enums import Axis +from useq.v2._axes_iterator import AxisIterable + +if TYPE_CHECKING: + from collections.abc import Generator, Mapping + + from useq._mda_event import MDAEvent + + +class TimePlan(_time.TimePlan, AxisIterable[float]): + axis_key: str = Field(default=Axis.TIME, frozen=True, init=False) + + def contribute_event_kwargs( + self, value: float, index: Mapping[str, int] + ) -> MDAEvent.Kwargs: + """Contribute time data to the event being built. + + Parameters + ---------- + value : float + The time value for this iteration. + index : Mapping[str, int] + Current axis indices. + + Returns + ------- + dict + Event data to be merged into the MDAEvent. + """ + return {"min_start_time": value} + + @deprecated( + "num_timepoints() is deprecated, use len(time_plan) instead.", + category=UserWarning, + stacklevel=2, + ) + def num_timepoints(self) -> int: + """Return the number of time points in this plan. + + This is deprecated and will be removed in a future version. + Use `len()` instead. + """ + return len(self) + + +class TIntervalLoops(_time.TIntervalLoops, TimePlan): ... + + +class TDurationLoops(_time.TDurationLoops, TimePlan): ... + + +class TIntervalDuration(_time.TIntervalDuration, TimePlan): ... + + +SinglePhaseTimePlan = Union[TIntervalDuration, TIntervalLoops, TDurationLoops] + + +class MultiPhaseTimePlan(TimePlan, _time.MultiPhaseTimePlan): + phases: list[SinglePhaseTimePlan] # pyright: ignore[reportIncompatibleVariableOverride] + + def __iter__(self) -> Generator[float, bool | None, None]: # type: ignore[override] + """Yield the global elapsed time over multiple plans. + + and allow `.send(True)` to skip to the next phase. + """ + offset = 0.0 + for ip, phase in enumerate(self.phases): + last_t = 0.0 + phase_iter = iter(phase) + if ip != 0: + # skip the first time point of all the phases except the first + next(phase_iter) + while True: + try: + t = next(phase_iter) + except StopIteration: + break + last_t = t + # here `force = yield offset + t` allows the caller to do + # gen = iter(plan) + # next(gen) # start + # gen.send(True) # force the next phase + force = yield offset + t + if force: + break + + # advance our offset to the end of this phase + if (duration_td := phase.duration) is not None: + offset += duration_td.total_seconds() + else: + # infinite phase that we broke out of + # leave offset where it was + last_t + offset += last_t + + +AnyTimePlan = Union[MultiPhaseTimePlan, SinglePhaseTimePlan] diff --git a/src/useq/v2/_transformers.py b/src/useq/v2/_transformers.py new file mode 100644 index 00000000..7db11a01 --- /dev/null +++ b/src/useq/v2/_transformers.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +# transformers.py +from useq._enums import Axis +from useq._mda_event import MDAEvent +from useq.v2._axes_iterator import EventTransform # helper you already have + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + + from useq._hardware_autofocus import AxesBasedAF + + +# Global state to share reset_event_timer state across all sequences (like v1) +_global_last_t_idx: int = -1 + + +def reset_global_timer_state() -> None: + """Reset the global timer state. Should be called at the start of each sequence.""" + global _global_last_t_idx + _global_last_t_idx = -1 + + +class KeepShutterOpenTransform(EventTransform[MDAEvent]): + """Replicates the v1 `keep_shutter_open_across` behaviour. + + Parameters + ---------- + axes + Tuple of axis names (`"p"`, `"t"`, `"c"`, `"z"`, …) on which the shutter + may stay open when only **they** change between consecutive events. + """ + + def __init__(self, axes: tuple[str, ...]): + self.axes = axes + + def __call__( + self, + event: MDAEvent, + *, + prev_event: MDAEvent | None, + make_next_event: Callable[[], MDAEvent | None], + ) -> Iterable[MDAEvent]: + if (nxt := make_next_event()) is None: # last event → nothing to tweak + return [event] + + # keep shutter open iff every axis that *changes* is in `self.axes` + if all( + ax in self.axes + for ax, idx in event.index.items() + if idx != nxt.index.get(ax) + ): + event = event.model_copy(update={"keep_shutter_open": True}) + + return [event] + + +class ResetEventTimerTransform(EventTransform[MDAEvent]): + """Marks the first frame of each timepoint with ``reset_event_timer=True``.""" + + def __init__(self) -> None: + # Use global state to match v1 behavior where _last_t_idx is shared + # across all nested sequences + pass + + def __call__( + self, + event: MDAEvent, + *, + prev_event: MDAEvent | None, + make_next_event: Callable[[], MDAEvent | None], + ) -> Iterable[MDAEvent]: + global _global_last_t_idx + + # No time axis → nothing to do + if Axis.TIME not in event.index: + return [event] + + # Reset timer when t=0 and the last t_idx wasn't 0 (matching v1 behavior) + current_t_idx = event.index.get(Axis.TIME, 0) + if current_t_idx == 0 and _global_last_t_idx != 0: + event = event.model_copy(update={"reset_event_timer": True}) + + # Update the global last t index for next time + _global_last_t_idx = current_t_idx + return [event] + + +class AutoFocusTransform(EventTransform[MDAEvent]): + """Insert hardware-autofocus events created by an ``AutoFocusPlan``. + + Parameters + ---------- + plan_getter : + Function that returns the *active* autofocus plan for the + current event. By default we use ``event.sequence.autofocus_plan``, + but you can plug in something smarter if you support + per-position overrides. + """ + + priority = -1 + + def __init__(self, af_plan: AxesBasedAF) -> None: + self._af_plan = af_plan + + def __call__( + self, + event: MDAEvent, + *, + prev_event: MDAEvent | None, + make_next_event: Callable[[], MDAEvent | None], # unused, but required + ) -> Iterable[MDAEvent]: + # Skip autofocus when no axes specified + af_axes = self._af_plan.axes + if not af_axes: + return [event] + + # Determine if any specified axis has changed (or first event) + trigger = False + if prev_event is None: + trigger = True + else: + for axis in af_axes: + if prev_event.index.get(axis) != event.index.get(axis): + trigger = True + break + + if trigger: + updates: dict[str, object] = {"action": self._af_plan.as_action()} + if event.z_pos is not None and event.sequence is not None: + zplan = event.sequence.z_plan + if zplan and zplan.is_relative and "z" in event.index: + try: + positions = list(zplan) + val = positions[event.index["z"]] + offset = val.z if hasattr(val, "z") else val + updates["z_pos"] = event.z_pos - offset + except (IndexError, AttributeError): + pass # fallback to default + + af_event = event.model_copy(update=updates) + return [af_event, event] + + return [event] diff --git a/src/useq/v2/_z.py b/src/useq/v2/_z.py new file mode 100644 index 00000000..e92125b0 --- /dev/null +++ b/src/useq/v2/_z.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, Union + +from pydantic import Field +from typing_extensions import deprecated + +from useq._enums import Axis +from useq.v2._axes_iterator import AxisIterable +from useq.v2._position import Position + +if TYPE_CHECKING: + from collections.abc import Iterator, Mapping + + from useq._mda_event import MDAEvent + + +from useq import _z + + +class ZPlan(_z.ZPlan, AxisIterable[Position]): + axis_key: Literal[Axis.Z] = Field(default=Axis.Z, frozen=True, init=False) # pyright: ignore[reportIncompatibleVariableOverride] + + def __iter__(self) -> Iterator[Position]: # type: ignore[override] + """Iterate over Z positions.""" + positions = self.positions() + if not self.go_up: + positions = positions[::-1] + for p in positions: + yield Position(z=p, is_relative=self.is_relative) + + @deprecated( + "num_positions() is deprecated, use len(z_plan) instead.", + category=UserWarning, + stacklevel=2, + ) + def num_positions(self) -> int: + """Get the number of Z positions.""" + return len(self) + + def contribute_event_kwargs( + self, value: Position, index: Mapping[str, int] + ) -> MDAEvent.Kwargs: + """Contribute Z position to the MDA event.""" + if value.z is not None: + if self.is_relative: + return {"z_pos_rel": value.z} # type: ignore [typeddict-unknown-key] + else: + return {"z_pos": value.z} + return {} + + +class ZTopBottom(ZPlan, _z.ZTopBottom): ... + + +class ZAboveBelow(ZPlan, _z.ZAboveBelow): ... + + +class ZRangeAround(ZPlan, _z.ZRangeAround): ... + + +class ZAbsolutePositions(ZPlan, _z.ZAbsolutePositions): + def __len__(self) -> int: + return len(self.absolute) + + +class ZRelativePositions(ZPlan, _z.ZRelativePositions): + def __len__(self) -> int: + return len(self.relative) + + +# order matters... this is the order in which pydantic will try to coerce input. +# should go from most specific to least specific +AnyZPlan = Union[ + ZTopBottom, ZAboveBelow, ZRangeAround, ZAbsolutePositions, ZRelativePositions +] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fixtures/cases.py b/tests/fixtures/cases.py new file mode 100644 index 00000000..d72aa592 --- /dev/null +++ b/tests/fixtures/cases.py @@ -0,0 +1,1254 @@ +# pyright: reportArgumentType=false +from __future__ import annotations + +from dataclasses import dataclass +from itertools import product +from typing import TYPE_CHECKING, Any, Callable + +from useq import ( + AxesBasedAF, + Channel, + GridFromEdges, + GridRowsColumns, + HardwareAutofocus, + MDAEvent, + MDASequence, + Position, + TIntervalLoops, + ZRangeAround, + ZTopBottom, +) + +if TYPE_CHECKING: + from collections.abc import Sequence + + +@dataclass(frozen=True) +class MDATestCase: + """A test case combining an MDASequence and expected attribute values. + + Parameters + ---------- + name : str + A short identifier used for the parametrised test id. + seq : MDASequence + The :class:`useq.MDASequence` under test. + expected : dict[str, list[Any]] | list[MDAEvent] | None + one of: + - a dictionary mapping attribute names to a list of expected values, where + the list length is equal to the number of events in the sequence. + - a list of expected `useq.MDAEvent` objects, compared directly to the expanded + sequence. + predicate : Callable[[MDASequence], str] | None + A callable that takes a `useq.MDASequence`. If a non-empty string is returned, + it is raised as an assertion error with the string as the message. + """ + + name: str + seq: MDASequence + expected: dict[str, list[Any]] | list[MDAEvent] | None = None + predicate: Callable[[Sequence[MDAEvent]], str | None] | None = None + + def __post_init__(self) -> None: + if self.expected is None and self.predicate is None: + raise ValueError("Either expected or predicate must be provided. ") + + +############################################################################## +# helpers +############################################################################## + + +def genindex(axes: dict[str, int]) -> list[dict[str, int]]: + """Produce the cartesian product of `range(n)` for the given axes.""" + return [ + dict(zip(axes, prod)) for prod in product(*(range(v) for v in axes.values())) + ] + + +############################################################################## +# test cases +############################################################################## + +GRID_SUBSEQ_CASES: list[MDATestCase] = [ + MDATestCase( + name="channel_only_in_position_sub_sequence", + seq=MDASequence( + stage_positions=[ + {}, + Position( + sequence=MDASequence( + channels=[Channel(config="FITC", exposure=100)] + ) + ), + ] + ), + expected={ + "channel": [None, "FITC"], + "index": [{"p": 0}, {"p": 1, "c": 0}], + "exposure": [None, 100.0], + }, + ), + MDATestCase( + name="channel_in_main_and_position_sub_sequence", + seq=MDASequence( + stage_positions=[ + {}, + Position( + sequence=MDASequence( + channels=[Channel(config="FITC", exposure=100)] + ) + ), + ], + channels=[Channel(config="Cy5", exposure=50)], + ), + expected={ + "channel": ["Cy5", "FITC"], + "index": [{"p": 0, "c": 0}, {"p": 1, "c": 0}], + "exposure": [50.0, 100.0], + }, + ), + MDATestCase( + name="subchannel_inherits_global_channel", + seq=MDASequence( + stage_positions=[ + {}, + {"sequence": {"z_plan": ZTopBottom(bottom=28, top=30, step=1)}}, + ], + channels=[Channel(config="Cy5", exposure=50)], + ), + expected={ + "channel": ["Cy5"] * 4, + "index": [ + {"p": 0, "c": 0}, + {"p": 1, "z": 0, "c": 0}, + {"p": 1, "z": 1, "c": 0}, + {"p": 1, "z": 2, "c": 0}, + ], + }, + ), + MDATestCase( + name="grid_relative_with_multi_stage_positions", + seq=MDASequence( + stage_positions=[Position(x=0, y=0), (10, 20)], + grid_plan=GridRowsColumns(rows=2, columns=2), + ), + expected={ + "index": genindex({"p": 2, "g": 4}), + "x_pos": [-0.5, 0.5, 0.5, -0.5, 9.5, 10.5, 10.5, 9.5], + "y_pos": [0.5, 0.5, -0.5, -0.5, 20.5, 20.5, 19.5, 19.5], + }, + ), + MDATestCase( + name="grid_relative_only_in_position_sub_sequence", + seq=MDASequence( + stage_positions=[ + Position(x=0, y=0), + Position( + x=10, + y=10, + sequence={ + "grid_plan": GridRowsColumns(rows=2, columns=2), + }, + ), + ] + ), + expected={ + "index": [ + {"p": 0}, + {"p": 1, "g": 0}, + {"p": 1, "g": 1}, + {"p": 1, "g": 2}, + {"p": 1, "g": 3}, + ], + "x_pos": [0.0, 9.5, 10.5, 10.5, 9.5], + "y_pos": [0.0, 10.5, 10.5, 9.5, 9.5], + }, + ), + MDATestCase( + name="grid_absolute_only_in_position_sub_sequence", + seq=MDASequence( + stage_positions=[ + Position(x=0, y=0), + Position( + x=10, + y=10, + sequence={ + "grid_plan": GridFromEdges(top=1, bottom=-1, left=0, right=0) + }, + ), + ] + ), + expected={ + "index": [ + {"p": 0}, + {"p": 1, "g": 0}, + {"p": 1, "g": 1}, + {"p": 1, "g": 2}, + ], + "x_pos": [0.0, 0.0, 0.0, 0.0], + "y_pos": [0.0, 1.0, 0.0, -1.0], + }, + ), + MDATestCase( + name="grid_relative_in_main_and_position_sub_sequence", + seq=MDASequence( + stage_positions=[ + Position(x=0, y=0), + Position( + name="name", + x=10, + y=10, + sequence={"grid_plan": GridRowsColumns(rows=2, columns=2)}, + ), + ], + grid_plan=GridRowsColumns(rows=2, columns=2), + ), + expected={ + "index": genindex({"p": 2, "g": 4}), + "pos_name": [None] * 4 + ["name"] * 4, + "x_pos": [-0.5, 0.5, 0.5, -0.5, 9.5, 10.5, 10.5, 9.5], + "y_pos": [0.5, 0.5, -0.5, -0.5, 10.5, 10.5, 9.5, 9.5], + }, + ), + MDATestCase( + name="grid_absolute_in_main_and_position_sub_sequence", + seq=MDASequence( + stage_positions=[ + {}, + Position( + name="name", + sequence={ + "grid_plan": GridFromEdges(top=2, bottom=-1, left=0, right=0) + }, + ), + ], + grid_plan=GridFromEdges(top=1, bottom=-1, left=0, right=0), + ), + expected={ + "index": [ + {"p": 0, "g": 0}, + {"p": 0, "g": 1}, + {"p": 0, "g": 2}, + {"p": 1, "g": 0}, + {"p": 1, "g": 1}, + {"p": 1, "g": 2}, + {"p": 1, "g": 3}, + ], + "pos_name": [None] * 3 + ["name"] * 4, + "x_pos": [0.0] * 7, + "y_pos": [1.0, 0.0, -1.0, 2.0, 1.0, 0.0, -1.0], + }, + ), + MDATestCase( + name="grid_absolute_in_main_and_grid_relative_in_position_sub_sequence", + seq=MDASequence( + stage_positions=[ + {}, + Position( + name="name", + x=10, + y=10, + sequence={"grid_plan": GridRowsColumns(rows=2, columns=2)}, + ), + ], + grid_plan=GridFromEdges(top=1, bottom=-1, left=0, right=0), + ), + expected={ + "index": [ + {"p": 0, "g": 0}, + {"p": 0, "g": 1}, + {"p": 0, "g": 2}, + {"p": 1, "g": 0}, + {"p": 1, "g": 1}, + {"p": 1, "g": 2}, + {"p": 1, "g": 3}, + ], + "pos_name": [None] * 3 + ["name"] * 4, + "x_pos": [0.0, 0.0, 0.0, 9.5, 10.5, 10.5, 9.5], + "y_pos": [1.0, 0.0, -1.0, 10.5, 10.5, 9.5, 9.5], + }, + ), + MDATestCase( + name="grid_relative_in_main_and_grid_absolute_in_position_sub_sequence", + seq=MDASequence( + stage_positions=[ + Position(x=0, y=0), + Position( + name="name", + sequence={ + "grid_plan": GridFromEdges(top=1, bottom=-1, left=0, right=0) + }, + ), + ], + grid_plan=GridRowsColumns(rows=2, columns=2), + ), + expected={ + "index": [ + {"p": 0, "g": 0}, + {"p": 0, "g": 1}, + {"p": 0, "g": 2}, + {"p": 0, "g": 3}, + {"p": 1, "g": 0}, + {"p": 1, "g": 1}, + {"p": 1, "g": 2}, + ], + "pos_name": [None] * 4 + ["name"] * 3, + "x_pos": [-0.5, 0.5, 0.5, -0.5, 0.0, 0.0, 0.0], + "y_pos": [0.5, 0.5, -0.5, -0.5, 1.0, 0.0, -1.0], + }, + ), + MDATestCase( + name="multi_g_in_position_sub_sequence", + seq=MDASequence( + stage_positions=[ + {"sequence": {"grid_plan": {"rows": 1, "columns": 2}}}, + {"sequence": {"grid_plan": GridRowsColumns(rows=2, columns=2)}}, + { + "sequence": { + "grid_plan": GridFromEdges(top=1, bottom=-1, left=0, right=0) + } + }, + ] + ), + expected={ + "index": [ + {"p": 0, "g": 0}, + {"p": 0, "g": 1}, + {"p": 1, "g": 0}, + {"p": 1, "g": 1}, + {"p": 1, "g": 2}, + {"p": 1, "g": 3}, + {"p": 2, "g": 0}, + {"p": 2, "g": 1}, + {"p": 2, "g": 2}, + ], + "x_pos": [-0.5, 0.5, -0.5, 0.5, 0.5, -0.5, 0.0, 0.0, 0.0], + "y_pos": [0.0, 0.0, 0.5, 0.5, -0.5, -0.5, 1.0, 0.0, -1.0], + }, + ), + MDATestCase( + name="z_relative_with_multi_stage_positions", + seq=MDASequence( + stage_positions=[(0, 0, 0), (10, 20, 10)], + z_plan=ZRangeAround(range=2, step=1), + ), + expected={ + "index": genindex({"p": 2, "z": 3}), + "x_pos": [0.0, 0.0, 0.0, 10.0, 10.0, 10.0], + "y_pos": [0.0, 0.0, 0.0, 20.0, 20.0, 20.0], + "z_pos": [-1.0, 0.0, 1.0, 9.0, 10.0, 11.0], + }, + ), + MDATestCase( + name="z_absolute_with_multi_stage_positions", + seq=MDASequence( + stage_positions=[Position(x=0, y=0), (10, 20)], + z_plan=ZTopBottom(bottom=58, top=60, step=1), + ), + expected={ + "index": genindex({"p": 2, "z": 3}), + "x_pos": [0.0, 0.0, 0.0, 10.0, 10.0, 10.0], + "y_pos": [0.0, 0.0, 0.0, 20.0, 20.0, 20.0], + "z_pos": [58.0, 59.0, 60.0, 58.0, 59.0, 60.0], + }, + ), + MDATestCase( + name="z_relative_only_in_position_sub_sequence", + seq=MDASequence( + stage_positions=[ + Position(z=0), + Position( + name="name", + z=10, + sequence={"z_plan": ZRangeAround(range=2, step=1)}, + ), + ] + ), + expected={ + "index": [ + {"p": 0}, + {"p": 1, "z": 0}, + {"p": 1, "z": 1}, + {"p": 1, "z": 2}, + ], + "pos_name": [None, "name", "name", "name"], + "z_pos": [0.0, 9.0, 10.0, 11.0], + }, + ), + MDATestCase( + name="z_absolute_only_in_position_sub_sequence", + seq=MDASequence( + stage_positions=[ + Position(z=0), + Position( + name="name", + sequence={"z_plan": ZTopBottom(bottom=58, top=60, step=1)}, + ), + ] + ), + expected={ + "index": [ + {"p": 0}, + {"p": 1, "z": 0}, + {"p": 1, "z": 1}, + {"p": 1, "z": 2}, + ], + "pos_name": [None, "name", "name", "name"], + "z_pos": [0.0, 58.0, 59.0, 60.0], + }, + ), + MDATestCase( + name="z_relative_in_main_and_position_sub_sequence", + seq=MDASequence( + stage_positions=[ + Position(z=0), + Position( + name="name", + z=10, + sequence={"z_plan": ZRangeAround(range=3, step=1)}, + ), + ], + z_plan=ZRangeAround(range=2, step=1), + ), + expected={ + # pop the 3rd index + "index": (idx := genindex({"p": 2, "z": 4}))[:3] + idx[4:], + "pos_name": [None] * 3 + ["name"] * 4, + "z_pos": [-1.0, 0.0, 1.0, 8.5, 9.5, 10.5, 11.5], + }, + ), + MDATestCase( + name="z_absolute_in_main_and_position_sub_sequence", + seq=MDASequence( + stage_positions=[ + {}, + Position( + name="name", + sequence={"z_plan": ZTopBottom(bottom=28, top=30, step=1)}, + ), + ], + z_plan=ZTopBottom(bottom=58, top=60, step=1), + ), + expected={ + "index": genindex({"p": 2, "z": 3}), + "pos_name": [None] * 3 + ["name"] * 3, + "z_pos": [58.0, 59.0, 60.0, 28.0, 29.0, 30.0], + }, + ), + MDATestCase( + name="z_absolute_in_main_and_z_relative_in_position_sub_sequence", + seq=MDASequence( + stage_positions=[ + {}, + Position( + name="name", + z=10, + sequence={"z_plan": ZRangeAround(range=3, step=1)}, + ), + ], + z_plan=ZTopBottom(bottom=58, top=60, step=1), + ), + expected={ + "index": [ + {"p": 0, "z": 0}, + {"p": 0, "z": 1}, + {"p": 0, "z": 2}, + {"p": 1, "z": 0}, + {"p": 1, "z": 1}, + {"p": 1, "z": 2}, + {"p": 1, "z": 3}, + ], + "pos_name": [None] * 3 + ["name"] * 4, + "z_pos": [58.0, 59.0, 60.0, 8.5, 9.5, 10.5, 11.5], + }, + ), + MDATestCase( + name="z_relative_in_main_and_z_absolute_in_position_sub_sequence", + seq=MDASequence( + stage_positions=[ + Position(z=0), + Position( + name="name", + sequence={"z_plan": ZTopBottom(bottom=58, top=60, step=1)}, + ), + ], + z_plan=ZRangeAround(range=3, step=1), + ), + expected={ + "index": [ + {"p": 0, "z": 0}, + {"p": 0, "z": 1}, + {"p": 0, "z": 2}, + {"p": 0, "z": 3}, + {"p": 1, "z": 0}, + {"p": 1, "z": 1}, + {"p": 1, "z": 2}, + ], + "pos_name": [None] * 4 + ["name"] * 3, + "z_pos": [-1.5, -0.5, 0.5, 1.5, 58.0, 59.0, 60.0], + }, + ), + MDATestCase( + name="multi_z_in_position_sub_sequence", + seq=MDASequence( + stage_positions=[ + {"sequence": {"z_plan": ZTopBottom(bottom=58, top=60, step=1)}}, + {"sequence": {"z_plan": ZRangeAround(range=3, step=1)}}, + {"sequence": {"z_plan": ZTopBottom(bottom=28, top=30, step=1)}}, + ] + ), + expected={ + "index": [ + {"p": 0, "z": 0}, + {"p": 0, "z": 1}, + {"p": 0, "z": 2}, + {"p": 1, "z": 0}, + {"p": 1, "z": 1}, + {"p": 1, "z": 2}, + {"p": 1, "z": 3}, + {"p": 2, "z": 0}, + {"p": 2, "z": 1}, + {"p": 2, "z": 2}, + ], + "z_pos": [ + 58.0, + 59.0, + 60.0, + -1.5, + -0.5, + 0.5, + 1.5, + 28.0, + 29.0, + 30.0, + ], + }, + ), + MDATestCase( + name="t_with_multi_stage_positions", + seq=MDASequence( + stage_positions=[{}, {}], + time_plan=[TIntervalLoops(interval=1, loops=2)], + ), + expected={ + "index": genindex({"t": 2, "p": 2}), + "min_start_time": [0.0, 0.0, 1.0, 1.0], + }, + ), + MDATestCase( + name="t_only_in_position_sub_sequence", + seq=MDASequence( + stage_positions=[ + {}, + {"sequence": {"time_plan": [TIntervalLoops(interval=1, loops=5)]}}, + ] + ), + expected={ + "index": [ + {"p": 0}, + {"p": 1, "t": 0}, + {"p": 1, "t": 1}, + {"p": 1, "t": 2}, + {"p": 1, "t": 3}, + {"p": 1, "t": 4}, + ], + "min_start_time": [None, 0.0, 1.0, 2.0, 3.0, 4.0], + }, + ), + MDATestCase( + name="t_in_main_and_in_position_sub_sequence", + seq=MDASequence( + stage_positions=[ + {}, + {"sequence": {"time_plan": [TIntervalLoops(interval=1, loops=5)]}}, + ], + time_plan=[TIntervalLoops(interval=1, loops=2)], + ), + expected={ + "index": [ + {"t": 0, "p": 0}, + {"t": 0, "p": 1}, + {"t": 1, "p": 1}, + {"t": 2, "p": 1}, + {"t": 3, "p": 1}, + {"t": 4, "p": 1}, + {"t": 1, "p": 0}, + {"t": 0, "p": 1}, + {"t": 1, "p": 1}, + {"t": 2, "p": 1}, + {"t": 3, "p": 1}, + {"t": 4, "p": 1}, + ], + "min_start_time": [ + 0.0, + 0.0, + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 0.0, + 1.0, + 2.0, + 3.0, + 4.0, + ], + }, + ), + MDATestCase( + name="mix_cgz_axes", + seq=MDASequence( + axis_order="tpgcz", + stage_positions=[ + Position(x=0, y=0), + Position( + name="name", + x=10, + y=10, + z=30, + sequence=MDASequence( + channels=[ + {"config": "FITC", "exposure": 200}, + {"config": "Cy3", "exposure": 100}, + ], + grid_plan=GridRowsColumns(rows=2, columns=1), + z_plan=ZRangeAround(range=2, step=1), + ), + ), + ], + channels=[Channel(config="Cy5", exposure=50)], + z_plan={"top": 100, "bottom": 98, "step": 1}, + grid_plan=GridFromEdges(top=1, bottom=-1, left=0, right=0), + ), + expected={ + "index": [ + *genindex({"p": 1, "g": 3, "c": 1, "z": 3}), + {"p": 1, "g": 0, "c": 0, "z": 0}, + {"p": 1, "g": 0, "c": 0, "z": 1}, + {"p": 1, "g": 0, "c": 0, "z": 2}, + {"p": 1, "g": 0, "c": 1, "z": 0}, + {"p": 1, "g": 0, "c": 1, "z": 1}, + {"p": 1, "g": 0, "c": 1, "z": 2}, + {"p": 1, "g": 1, "c": 0, "z": 0}, + {"p": 1, "g": 1, "c": 0, "z": 1}, + {"p": 1, "g": 1, "c": 0, "z": 2}, + {"p": 1, "g": 1, "c": 1, "z": 0}, + {"p": 1, "g": 1, "c": 1, "z": 1}, + {"p": 1, "g": 1, "c": 1, "z": 2}, + ], + "pos_name": [None] * 9 + ["name"] * 12, + "x_pos": [0.0] * 9 + [10.0] * 12, + "y_pos": [1, 1, 1, 0, 0, 0, -1, -1, -1] + [10.5] * 6 + [9.5] * 6, + "z_pos": [98.0, 99.0, 100.0] * 3 + [29.0, 30.0, 31.0] * 4, + "channel": ["Cy5"] * 9 + (["FITC"] * 3 + ["Cy3"] * 3) * 2, + "exposure": [50.0] * 9 + [200.0, 200.0, 200.0, 100.0, 100.0, 100.0] * 2, + }, + ), + MDATestCase( + name="order", + seq=MDASequence( + stage_positions=[ + Position(z=0), + Position( + z=50, + sequence=MDASequence( + channels=[ + Channel(config="FITC", exposure=100), + Channel(config="Cy3", exposure=200), + ] + ), + ), + ], + channels=[ + Channel(config="FITC", exposure=100), + Channel(config="Cy5", exposure=50), + ], + z_plan=ZRangeAround(range=2, step=1), + ), + expected={ + "index": [ + {"p": 0, "c": 0, "z": 0}, + {"p": 0, "c": 0, "z": 1}, + {"p": 0, "c": 0, "z": 2}, + {"p": 0, "c": 1, "z": 0}, + {"p": 0, "c": 1, "z": 1}, + {"p": 0, "c": 1, "z": 2}, + {"p": 1, "c": 0, "z": 0}, + {"p": 1, "c": 1, "z": 0}, + {"p": 1, "c": 0, "z": 1}, + {"p": 1, "c": 1, "z": 1}, + {"p": 1, "c": 0, "z": 2}, + {"p": 1, "c": 1, "z": 2}, + ], + "z_pos": [ + -1.0, + 0.0, + 1.0, + -1.0, + 0.0, + 1.0, + 49.0, + 49.0, + 50.0, + 50.0, + 51.0, + 51.0, + ], + "channel": ["FITC"] * 3 + ["Cy5"] * 3 + ["FITC", "Cy3"] * 3, + }, + ), + MDATestCase( + name="channels_and_pos_grid_plan", + seq=MDASequence( + channels=[ + Channel(config="Cy5", exposure=50), + Channel(config="FITC", exposure=100), + ], + stage_positions=[ + Position( + x=0, + y=0, + sequence=MDASequence(grid_plan=GridRowsColumns(rows=2, columns=1)), + ) + ], + ), + expected={ + "index": genindex({"p": 1, "c": 2, "g": 2}), + "x_pos": [0.0, 0.0, 0.0, 0.0], + "y_pos": [0.5, -0.5, 0.5, -0.5], + "channel": ["Cy5", "Cy5", "FITC", "FITC"], + }, + ), + MDATestCase( + name="channels_and_pos_z_plan", + seq=MDASequence( + channels=[ + Channel(config="Cy5", exposure=50), + Channel(config="FITC", exposure=100), + ], + stage_positions=[ + Position( + x=0, + y=0, + z=0, + sequence={"z_plan": ZRangeAround(range=2, step=1)}, + ) + ], + ), + expected={ + "index": genindex({"p": 1, "c": 2, "z": 3}), + "z_pos": [-1.0, 0.0, 1.0, -1.0, 0.0, 1.0], + "channel": ["Cy5", "Cy5", "Cy5", "FITC", "FITC", "FITC"], + }, + ), + MDATestCase( + name="channels_and_pos_time_plan", + seq=MDASequence( + axis_order="tpgcz", + channels=[ + Channel(config="Cy5", exposure=50), + Channel(config="FITC", exposure=100), + ], + stage_positions=[ + Position( + x=0, + y=0, + sequence={"time_plan": [TIntervalLoops(interval=1, loops=3)]}, + ) + ], + ), + expected={ + "index": genindex({"p": 1, "c": 2, "t": 3}), + "min_start_time": [0.0, 1.0, 2.0, 0.0, 1.0, 2.0], + "channel": ["Cy5", "Cy5", "Cy5", "FITC", "FITC", "FITC"], + }, + ), + MDATestCase( + name="channels_and_pos_z_grid_and_time_plan", + seq=MDASequence( + channels=[ + Channel(config="Cy5", exposure=50), + Channel(config="FITC", exposure=100), + ], + stage_positions=[ + Position( + x=0, + y=0, + sequence=MDASequence( + grid_plan=GridRowsColumns(rows=2, columns=2), + z_plan=ZRangeAround(range=2, step=1), + time_plan=[TIntervalLoops(interval=1, loops=2)], + ), + ) + ], + ), + expected={"channel": ["Cy5"] * 24 + ["FITC"] * 24}, + ), + MDATestCase( + name="sub_channels_and_any_plan", + seq=MDASequence( + channels=["Cy5", "FITC"], + stage_positions=[ + Position( + sequence=MDASequence( + channels=["FITC"], + z_plan=ZRangeAround(range=2, step=1), + ) + ) + ], + ), + expected={"channel": ["FITC", "FITC", "FITC"]}, + ), +] + +############################################################################## +# Autofocus Tests +############################################################################## + + +def ensure_af( + expected_indices: Sequence[int] | None = None, expected_z: float | None = None +) -> Callable[[Sequence[MDAEvent]], str | None]: + """Test things about autofocus events. + + Parameters + ---------- + expected_indices : Sequence[int] | None + Ensure that the autofocus events are at these indices. + expected_z : float | None + Ensure that all autofocus events have this z position. + """ + exp = list(expected_indices) if expected_indices else [] + + def _pred(events: Sequence[MDAEvent]) -> str | None: + errors: list[str] = [] + if exp: + actual_indices = [ + i + for i, ev in enumerate(events) + if isinstance(ev.action, HardwareAutofocus) + ] + if actual_indices != exp: + errors.append(f"expected AF indices {exp}, got {actual_indices}") + + if expected_z is not None: + z_vals = [ + ev.z_pos for ev in events if isinstance(ev.action, HardwareAutofocus) + ] + if not all(z == expected_z for z in z_vals): + errors.append(f"expected all AF events at z={expected_z}, got {z_vals}") + if errors: + return ", ".join(errors) + return None + + return _pred + + +AF_CASES: list[MDATestCase] = [ + # 1. NO AXES - Should never trigger + MDATestCase( + name="af_no_axes_no_autofocus", + seq=MDASequence( + stage_positions=[Position(z=30)], + z_plan=ZRangeAround(range=2, step=1), + channels=["DAPI", "FITC"], + autofocus_plan=AxesBasedAF( + autofocus_device_name="Z", autofocus_motor_offset=40, axes=() + ), + ), + predicate=ensure_af(expected_indices=[]), + ), + # 2. CHANNEL AXIS (c) - Triggers on channel changes + MDATestCase( + name="af_axes_c_basic", + seq=MDASequence( + stage_positions=[Position(z=30)], + z_plan=ZRangeAround(range=2, step=1), + channels=["DAPI", "FITC"], + autofocus_plan=AxesBasedAF(autofocus_device_name="Z", axes=("c",)), + ), + predicate=ensure_af(expected_indices=[0, 4]), + ), + # 3. Z AXIS (z) - Triggers on z changes + MDATestCase( + name="af_axes_z_basic", + seq=MDASequence( + stage_positions=[Position(z=30)], + z_plan=ZRangeAround(range=2, step=1), + channels=["DAPI", "FITC"], + autofocus_plan=AxesBasedAF(autofocus_device_name="Z", axes=("z",)), + ), + predicate=ensure_af(expected_indices=range(0, 11, 2)), + ), + # 4. GRID AXIS (g) - Triggers on grid position changes + MDATestCase( + name="af_axes_g_basic", + seq=MDASequence( + stage_positions=[Position(z=30)], + channels=["DAPI", "FITC"], + grid_plan=GridRowsColumns(rows=2, columns=1), + autofocus_plan=AxesBasedAF(autofocus_device_name="Z", axes=("g",)), + ), + predicate=ensure_af(expected_indices=[0, 3]), + ), + # 5. POSITION AXIS (p) - Triggers on position changes + MDATestCase( + name="af_axes_p_basic", + seq=MDASequence( + stage_positions=[Position(z=30), Position(z=200)], + channels=["DAPI", "FITC"], + autofocus_plan=AxesBasedAF(autofocus_device_name="Z", axes=("p",)), + ), + predicate=ensure_af(expected_indices=[0, 3]), + ), + # 6. TIME AXIS (t) - Triggers on time changes + MDATestCase( + name="af_axes_t_basic", + seq=MDASequence( + stage_positions=[Position(z=30), Position(z=200)], + channels=["DAPI", "FITC"], + time_plan=[TIntervalLoops(interval=1, loops=2)], + autofocus_plan=AxesBasedAF(autofocus_device_name="Z", axes=("t",)), + ), + predicate=ensure_af(expected_indices=[0, 5]), + ), + # 7. AXIS ORDER EFFECTS - Different axis order changes when axes trigger + MDATestCase( + name="af_axis_order_effect", + seq=MDASequence( + stage_positions=[Position(z=30)], + z_plan=ZRangeAround(range=2, step=1), + channels=["DAPI", "FITC"], + axis_order="tpgzc", # Different from default "tpczg" + autofocus_plan=AxesBasedAF(autofocus_device_name="Z", axes=("z",)), + ), + predicate=ensure_af(expected_indices=[0, 3, 6]), + ), + # 8. SUBSEQUENCE AUTOFOCUS - AF plan within position subsequence + MDATestCase( + name="af_subsequence_af", + seq=MDASequence( + stage_positions=[ + Position(z=30), + Position( + z=10, + sequence=MDASequence( + autofocus_plan=AxesBasedAF( + autofocus_device_name="Z", + axes=("c",), + ) + ), + ), + ], + channels=["DAPI", "FITC"], + ), + predicate=ensure_af(expected_indices=[2, 4]), + ), + # 9. MIXED MAIN + SUBSEQUENCE AF + MDATestCase( + name="af_mixed_main_and_sub", + seq=MDASequence( + stage_positions=[ + Position(z=30), + Position( + z=10, + sequence=MDASequence( + autofocus_plan=AxesBasedAF( + autofocus_device_name="Z", + autofocus_motor_offset=40, + axes=("z",), + ), + ), + ), + ], + channels=["DAPI", "FITC"], + z_plan=ZRangeAround(range=2, step=1), + autofocus_plan=AxesBasedAF( + autofocus_device_name="Z", autofocus_motor_offset=40, axes=("p",) + ), + ), + predicate=ensure_af(expected_indices=[0, *range(7, 18, 2)]), + ), + # 10. Z POSITION CORRECTION - AF events get correct z position with relative z plans + MDATestCase( + name="af_z_position_correction", + seq=MDASequence( + stage_positions=[Position(z=200)], + channels=["DAPI", "FITC"], + z_plan=ZRangeAround(range=2, step=1), + autofocus_plan=AxesBasedAF( + autofocus_device_name="Z", autofocus_motor_offset=40, axes=("c",) + ), + ), + predicate=ensure_af(expected_z=200), + ), + # 11. SUBSEQUENCE Z POSITION CORRECTION + MDATestCase( + name="af_z_position_subsequence", + seq=MDASequence( + stage_positions=[ + Position( + z=10, + sequence=MDASequence( + autofocus_plan=AxesBasedAF( + autofocus_device_name="Z", + autofocus_motor_offset=40, + axes=("c",), + ) + ), + ) + ], + channels=["DAPI", "FITC"], + z_plan=ZRangeAround(range=2, step=1), + ), + predicate=ensure_af(expected_z=10), + ), + # 12. NO DEVICE NAME - Edge case for testing without device name + MDATestCase( + name="af_no_device_name", + seq=MDASequence( + time_plan=[TIntervalLoops(interval=1, loops=2)], + autofocus_plan=AxesBasedAF(axes=("t",)), + ), + predicate=lambda _: "", # Just check it doesn't crash + ), +] + +############################################################################## +# Keep Shutter Open Tests +############################################################################### + + +def ensure_shutter_behavior( + expected_indices: Sequence[int] | bool | None = None, +) -> Callable[[Sequence[MDAEvent]], str | None]: + """Test keep_shutter_open behavior.""" + + def _pred(events: Sequence[MDAEvent]) -> str | None: + errors: list[str] = [] + + if expected_indices is not None: + if expected_indices is True: + if closed_events := [ + i for i, e in enumerate(events) if not e.keep_shutter_open + ]: + errors.append( + f"expected all shutters open, but events " + f"{closed_events} have keep_shutter_open=False" + ) + elif expected_indices is False: + if open_events := [ + i for i, e in enumerate(events) if e.keep_shutter_open + ]: + errors.append( + f"expected all shutters closed, but events " + f"{open_events} have keep_shutter_open=True" + ) + else: + actual_indices = [ + i for i, e in enumerate(events) if e.keep_shutter_open + ] + if actual_indices != list(expected_indices): + errors.append( + f"expected shutter open at indices {expected_indices}, " + f"got {actual_indices}" + ) + + if errors: + return "; ".join(errors) + return None + + return _pred + + +KEEP_SHUTTER_CASES: list[MDATestCase] = [ + # with z as the last axis, the shutter will be left open + # whenever z is the first index (since there are only 2 z planes) + MDATestCase( + name="keep_shutter_open_across_z_order_tcz", + seq=MDASequence( + axis_order=tuple("tcz"), + channels=["DAPI", "FITC"], + time_plan=TIntervalLoops(loops=2, interval=0), + z_plan=ZRangeAround(range=1, step=1), + keep_shutter_open_across="z", + ), + predicate=ensure_shutter_behavior(expected_indices=[0, 2, 4, 6]), + ), + # with c as the last axis, the shutter will never be left open + MDATestCase( + name="keep_shutter_open_across_z_order_tzc", + seq=MDASequence( + axis_order=tuple("tzc"), + channels=["DAPI", "FITC"], + time_plan=TIntervalLoops(loops=2, interval=0), + z_plan=ZRangeAround(range=1, step=1), + keep_shutter_open_across="z", + ), + predicate=ensure_shutter_behavior(expected_indices=[]), + ), + # because t is changing faster than z, the shutter will never be left open + MDATestCase( + name="keep_shutter_open_across_z_order_czt", + seq=MDASequence( + axis_order=tuple("czt"), + channels=["DAPI", "FITC"], + time_plan=TIntervalLoops(loops=2, interval=0), + z_plan=ZRangeAround(range=1, step=1), + keep_shutter_open_across="z", + ), + predicate=ensure_shutter_behavior(expected_indices=[]), + ), + # but, if we include 't' in the keep_shutter_open_across, + # it will be left open except when it's the last t and last z + MDATestCase( + name="keep_shutter_open_across_zt_order_czt", + seq=MDASequence( + axis_order=tuple("czt"), + channels=["DAPI", "FITC"], + time_plan=TIntervalLoops(loops=2, interval=0), + z_plan=ZRangeAround(range=1, step=1), + keep_shutter_open_across=("z", "t"), + ), + # for event in seq: + # is_last_zt = bool(event.index["t"] == 1 and event.index["z"] == 1) + # assert event.keep_shutter_open != is_last_zt + predicate=ensure_shutter_behavior(expected_indices=[0, 1, 2, 4, 5, 6]), + ), + # even though c is the last axis, and comes after g, because the grid happens + # on a subsequence shutter will be open across the grid for each position + MDATestCase( + name="keep_shutter_open_across_g_order_pgc_with_subseq", + seq=MDASequence( + axis_order=tuple("pgc"), + channels=["DAPI", "FITC"], + stage_positions=[ + Position( + sequence=MDASequence(grid_plan=GridRowsColumns(rows=2, columns=2)) + ) + ], + keep_shutter_open_across="g", + ), + # for event in seq: + # assert event.keep_shutter_open != (event.index["g"] == 3) + predicate=ensure_shutter_behavior(expected_indices=[0, 1, 2, 4, 5, 6]), + ), +] + +# ############################################################################## +# Reset Event Timer Test Cases +# ############################################################################## + +RESET_EVENT_TIMER_CASES: list[MDATestCase] = [ + MDATestCase( + name="reset_event_timer_with_time_intervals", + seq=MDASequence( + stage_positions=[(100, 100), (0, 0)], + time_plan={"interval": 1, "loops": 2}, + axis_order=tuple("ptgcz"), + ), + expected={ + "reset_event_timer": [True, False, True, False], + }, + ), + MDATestCase( + name="reset_event_timer_with_nested_position_sequences", + seq=MDASequence( + stage_positions=[ + Position( + x=0, + y=0, + sequence=MDASequence( + channels=["Cy5"], time_plan={"interval": 1, "loops": 2} + ), + ), + Position( + x=1, + y=1, + sequence=MDASequence( + channels=["DAPI"], time_plan={"interval": 1, "loops": 2} + ), + ), + ] + ), + expected={ + "reset_event_timer": [True, False, True, False], + }, + ), +] + + +CASES: list[MDATestCase] = ( + GRID_SUBSEQ_CASES + AF_CASES + KEEP_SHUTTER_CASES + RESET_EVENT_TIMER_CASES +) + +# assert that all test cases are unique +case_names = [case.name for case in CASES] +if duplicates := {name for name in case_names if case_names.count(name) > 1}: + raise ValueError( + f"Duplicate test case names found: {duplicates}. " + "Please ensure all test cases have unique names." + ) + + +def assert_test_case_passes( + case: MDATestCase, actual_events: Sequence[MDAEvent] +) -> None: + # test case expressed the expectation as a predicate + if case.predicate is not None: + # (a function that returns a non-empty error message if the test fails) + if msg := case.predicate(actual_events): + raise AssertionError(f"\nExpectation not met in '{case.name}':\n {msg}\n") + + # test case expressed the expectation as a list of MDAEvent + if isinstance(case.expected, list): + if len(actual_events) != len(case.expected): + raise AssertionError( + f"\nMismatch in case '{case.name}':\n" + f" expected: {len(case.expected)} events\n" + f" actual: {len(actual_events)} events\n" + ) + for i, event in enumerate(actual_events): + if event != case.expected[i]: + raise AssertionError( + f"\nMismatch in case '{case.name}':\n" + f" expected: {case.expected[i]}\n" + f" actual: {event}\n" + ) + + # test case expressed the expectation as a dict of {Event attr -> values list} + elif isinstance(case.expected, dict): + actual: dict[str, list[Any]] = {k: [] for k in case.expected} + for event in actual_events: + for attr in case.expected: + actual[attr].append(getattr(event, attr)) + + if mismatched_fields := { + attr for attr in actual if actual[attr] != case.expected[attr] + }: + msg = f"\nMismatch in case '{case.name}':\n" + for attr in mismatched_fields: + msg += f" {attr}:\n" + msg += f" expected: {case.expected[attr]}\n" + msg += f" actual: {actual[attr]}\n" + raise AssertionError(msg) + + +def get_case(name: str) -> MDATestCase: + """Get a test case by name.""" + for case in CASES: + if case.name == name: + return case + + import difflib + + # If the name is not found, suggest similar names + similar_names = difflib.get_close_matches( + name, [case.name for case in CASES], cutoff=0.3 + ) + if similar_names: + raise ValueError( + f"Test case '{name}' not found. Did you mean: {', '.join(similar_names)}?" + ) + raise ValueError(f"Test case '{name}' not found in the cases list.") diff --git a/tests/test_mda_sequence_cases.py b/tests/test_mda_sequence_cases.py index 724b6119..8d06a119 100644 --- a/tests/test_mda_sequence_cases.py +++ b/tests/test_mda_sequence_cases.py @@ -1,1242 +1,10 @@ -# pyright: reportArgumentType=false from __future__ import annotations -from dataclasses import dataclass -from itertools import product -from typing import TYPE_CHECKING, Any, Callable - import pytest -from useq import ( - AxesBasedAF, - Channel, - GridFromEdges, - GridRowsColumns, - HardwareAutofocus, - MDAEvent, - MDASequence, - Position, - TIntervalLoops, - ZRangeAround, - ZTopBottom, -) - -if TYPE_CHECKING: - from collections.abc import Sequence - - -@dataclass(frozen=True) -class MDATestCase: - """A test case combining an MDASequence and expected attribute values. - - Parameters - ---------- - name : str - A short identifier used for the parametrised test id. - seq : MDASequence - The :class:`useq.MDASequence` under test. - expected : dict[str, list[Any]] | list[MDAEvent] | None - one of: - - a dictionary mapping attribute names to a list of expected values, where - the list length is equal to the number of events in the sequence. - - a list of expected `useq.MDAEvent` objects, compared directly to the expanded - sequence. - predicate : Callable[[MDASequence], str] | None - A callable that takes a `useq.MDASequence`. If a non-empty string is returned, - it is raised as an assertion error with the string as the message. - """ - - name: str - seq: MDASequence - expected: dict[str, list[Any]] | list[MDAEvent] | None = None - predicate: Callable[[MDASequence], str | None] | None = None - - def __post_init__(self) -> None: - if self.expected is None and self.predicate is None: - raise ValueError("Either expected or predicate must be provided. ") - - -############################################################################## -# helpers -############################################################################## - - -def genindex(axes: dict[str, int]) -> list[dict[str, int]]: - """Produce the cartesian product of `range(n)` for the given axes.""" - return [ - dict(zip(axes, prod)) for prod in product(*(range(v) for v in axes.values())) - ] - - -############################################################################## -# test cases -############################################################################## - -GRID_SUBSEQ_CASES: list[MDATestCase] = [ - MDATestCase( - name="channel_only_in_position_sub_sequence", - seq=MDASequence( - stage_positions=[ - {}, - Position( - sequence=MDASequence( - channels=[Channel(config="FITC", exposure=100)] - ) - ), - ] - ), - expected={ - "channel": [None, "FITC"], - "index": [{"p": 0}, {"p": 1, "c": 0}], - "exposure": [None, 100.0], - }, - ), - MDATestCase( - name="channel_in_main_and_position_sub_sequence", - seq=MDASequence( - stage_positions=[ - {}, - Position( - sequence=MDASequence( - channels=[Channel(config="FITC", exposure=100)] - ) - ), - ], - channels=[Channel(config="Cy5", exposure=50)], - ), - expected={ - "channel": ["Cy5", "FITC"], - "index": [{"p": 0, "c": 0}, {"p": 1, "c": 0}], - "exposure": [50.0, 100.0], - }, - ), - MDATestCase( - name="subchannel_inherits_global_channel", - seq=MDASequence( - stage_positions=[ - {}, - {"sequence": {"z_plan": ZTopBottom(bottom=28, top=30, step=1)}}, - ], - channels=[Channel(config="Cy5", exposure=50)], - ), - expected={ - "channel": ["Cy5"] * 4, - "index": [ - {"p": 0, "c": 0}, - {"p": 1, "z": 0, "c": 0}, - {"p": 1, "z": 1, "c": 0}, - {"p": 1, "z": 2, "c": 0}, - ], - }, - ), - MDATestCase( - name="grid_relative_with_multi_stage_positions", - seq=MDASequence( - stage_positions=[Position(x=0, y=0), (10, 20)], - grid_plan=GridRowsColumns(rows=2, columns=2), - ), - expected={ - "index": genindex({"p": 2, "g": 4}), - "x_pos": [-0.5, 0.5, 0.5, -0.5, 9.5, 10.5, 10.5, 9.5], - "y_pos": [0.5, 0.5, -0.5, -0.5, 20.5, 20.5, 19.5, 19.5], - }, - ), - MDATestCase( - name="grid_relative_only_in_position_sub_sequence", - seq=MDASequence( - stage_positions=[ - Position(x=0, y=0), - Position( - x=10, - y=10, - sequence={ - "grid_plan": GridRowsColumns(rows=2, columns=2), - }, - ), - ] - ), - expected={ - "index": [ - {"p": 0}, - {"p": 1, "g": 0}, - {"p": 1, "g": 1}, - {"p": 1, "g": 2}, - {"p": 1, "g": 3}, - ], - "x_pos": [0.0, 9.5, 10.5, 10.5, 9.5], - "y_pos": [0.0, 10.5, 10.5, 9.5, 9.5], - }, - ), - MDATestCase( - name="grid_absolute_only_in_position_sub_sequence", - seq=MDASequence( - stage_positions=[ - Position(x=0, y=0), - Position( - x=10, - y=10, - sequence={ - "grid_plan": GridFromEdges(top=1, bottom=-1, left=0, right=0) - }, - ), - ] - ), - expected={ - "index": [ - {"p": 0}, - {"p": 1, "g": 0}, - {"p": 1, "g": 1}, - {"p": 1, "g": 2}, - ], - "x_pos": [0.0, 0.0, 0.0, 0.0], - "y_pos": [0.0, 1.0, 0.0, -1.0], - }, - ), - MDATestCase( - name="grid_relative_in_main_and_position_sub_sequence", - seq=MDASequence( - stage_positions=[ - Position(x=0, y=0), - Position( - name="name", - x=10, - y=10, - sequence={"grid_plan": GridRowsColumns(rows=2, columns=2)}, - ), - ], - grid_plan=GridRowsColumns(rows=2, columns=2), - ), - expected={ - "index": genindex({"p": 2, "g": 4}), - "pos_name": [None] * 4 + ["name"] * 4, - "x_pos": [-0.5, 0.5, 0.5, -0.5, 9.5, 10.5, 10.5, 9.5], - "y_pos": [0.5, 0.5, -0.5, -0.5, 10.5, 10.5, 9.5, 9.5], - }, - ), - MDATestCase( - name="grid_absolute_in_main_and_position_sub_sequence", - seq=MDASequence( - stage_positions=[ - {}, - Position( - name="name", - sequence={ - "grid_plan": GridFromEdges(top=2, bottom=-1, left=0, right=0) - }, - ), - ], - grid_plan=GridFromEdges(top=1, bottom=-1, left=0, right=0), - ), - expected={ - "index": [ - {"p": 0, "g": 0}, - {"p": 0, "g": 1}, - {"p": 0, "g": 2}, - {"p": 1, "g": 0}, - {"p": 1, "g": 1}, - {"p": 1, "g": 2}, - {"p": 1, "g": 3}, - ], - "pos_name": [None] * 3 + ["name"] * 4, - "x_pos": [0.0] * 7, - "y_pos": [1.0, 0.0, -1.0, 2.0, 1.0, 0.0, -1.0], - }, - ), - MDATestCase( - name="grid_absolute_in_main_and_grid_relative_in_position_sub_sequence", - seq=MDASequence( - stage_positions=[ - {}, - Position( - name="name", - x=10, - y=10, - sequence={"grid_plan": GridRowsColumns(rows=2, columns=2)}, - ), - ], - grid_plan=GridFromEdges(top=1, bottom=-1, left=0, right=0), - ), - expected={ - "index": [ - {"p": 0, "g": 0}, - {"p": 0, "g": 1}, - {"p": 0, "g": 2}, - {"p": 1, "g": 0}, - {"p": 1, "g": 1}, - {"p": 1, "g": 2}, - {"p": 1, "g": 3}, - ], - "pos_name": [None] * 3 + ["name"] * 4, - "x_pos": [0.0, 0.0, 0.0, 9.5, 10.5, 10.5, 9.5], - "y_pos": [1.0, 0.0, -1.0, 10.5, 10.5, 9.5, 9.5], - }, - ), - MDATestCase( - name="grid_relative_in_main_and_grid_absolute_in_position_sub_sequence", - seq=MDASequence( - stage_positions=[ - Position(x=0, y=0), - Position( - name="name", - sequence={ - "grid_plan": GridFromEdges(top=1, bottom=-1, left=0, right=0) - }, - ), - ], - grid_plan=GridRowsColumns(rows=2, columns=2), - ), - expected={ - "index": [ - {"p": 0, "g": 0}, - {"p": 0, "g": 1}, - {"p": 0, "g": 2}, - {"p": 0, "g": 3}, - {"p": 1, "g": 0}, - {"p": 1, "g": 1}, - {"p": 1, "g": 2}, - ], - "pos_name": [None] * 4 + ["name"] * 3, - "x_pos": [-0.5, 0.5, 0.5, -0.5, 0.0, 0.0, 0.0], - "y_pos": [0.5, 0.5, -0.5, -0.5, 1.0, 0.0, -1.0], - }, - ), - MDATestCase( - name="multi_g_in_position_sub_sequence", - seq=MDASequence( - stage_positions=[ - {"sequence": {"grid_plan": {"rows": 1, "columns": 2}}}, - {"sequence": {"grid_plan": GridRowsColumns(rows=2, columns=2)}}, - { - "sequence": { - "grid_plan": GridFromEdges(top=1, bottom=-1, left=0, right=0) - } - }, - ] - ), - expected={ - "index": [ - {"p": 0, "g": 0}, - {"p": 0, "g": 1}, - {"p": 1, "g": 0}, - {"p": 1, "g": 1}, - {"p": 1, "g": 2}, - {"p": 1, "g": 3}, - {"p": 2, "g": 0}, - {"p": 2, "g": 1}, - {"p": 2, "g": 2}, - ], - "x_pos": [-0.5, 0.5, -0.5, 0.5, 0.5, -0.5, 0.0, 0.0, 0.0], - "y_pos": [0.0, 0.0, 0.5, 0.5, -0.5, -0.5, 1.0, 0.0, -1.0], - }, - ), - MDATestCase( - name="z_relative_with_multi_stage_positions", - seq=MDASequence( - stage_positions=[(0, 0, 0), (10, 20, 10)], - z_plan=ZRangeAround(range=2, step=1), - ), - expected={ - "index": genindex({"p": 2, "z": 3}), - "x_pos": [0.0, 0.0, 0.0, 10.0, 10.0, 10.0], - "y_pos": [0.0, 0.0, 0.0, 20.0, 20.0, 20.0], - "z_pos": [-1.0, 0.0, 1.0, 9.0, 10.0, 11.0], - }, - ), - MDATestCase( - name="z_absolute_with_multi_stage_positions", - seq=MDASequence( - stage_positions=[Position(x=0, y=0), (10, 20)], - z_plan=ZTopBottom(bottom=58, top=60, step=1), - ), - expected={ - "index": genindex({"p": 2, "z": 3}), - "x_pos": [0.0, 0.0, 0.0, 10.0, 10.0, 10.0], - "y_pos": [0.0, 0.0, 0.0, 20.0, 20.0, 20.0], - "z_pos": [58.0, 59.0, 60.0, 58.0, 59.0, 60.0], - }, - ), - MDATestCase( - name="z_relative_only_in_position_sub_sequence", - seq=MDASequence( - stage_positions=[ - Position(z=0), - Position( - name="name", - z=10, - sequence={"z_plan": ZRangeAround(range=2, step=1)}, - ), - ] - ), - expected={ - "index": [ - {"p": 0}, - {"p": 1, "z": 0}, - {"p": 1, "z": 1}, - {"p": 1, "z": 2}, - ], - "pos_name": [None, "name", "name", "name"], - "z_pos": [0.0, 9.0, 10.0, 11.0], - }, - ), - MDATestCase( - name="z_absolute_only_in_position_sub_sequence", - seq=MDASequence( - stage_positions=[ - Position(z=0), - Position( - name="name", - sequence={"z_plan": ZTopBottom(bottom=58, top=60, step=1)}, - ), - ] - ), - expected={ - "index": [ - {"p": 0}, - {"p": 1, "z": 0}, - {"p": 1, "z": 1}, - {"p": 1, "z": 2}, - ], - "pos_name": [None, "name", "name", "name"], - "z_pos": [0.0, 58.0, 59.0, 60.0], - }, - ), - MDATestCase( - name="z_relative_in_main_and_position_sub_sequence", - seq=MDASequence( - stage_positions=[ - Position(z=0), - Position( - name="name", - z=10, - sequence={"z_plan": ZRangeAround(range=3, step=1)}, - ), - ], - z_plan=ZRangeAround(range=2, step=1), - ), - expected={ - # pop the 3rd index - "index": (idx := genindex({"p": 2, "z": 4}))[:3] + idx[4:], - "pos_name": [None] * 3 + ["name"] * 4, - "z_pos": [-1.0, 0.0, 1.0, 8.5, 9.5, 10.5, 11.5], - }, - ), - MDATestCase( - name="z_absolute_in_main_and_position_sub_sequence", - seq=MDASequence( - stage_positions=[ - {}, - Position( - name="name", - sequence={"z_plan": ZTopBottom(bottom=28, top=30, step=1)}, - ), - ], - z_plan=ZTopBottom(bottom=58, top=60, step=1), - ), - expected={ - "index": genindex({"p": 2, "z": 3}), - "pos_name": [None] * 3 + ["name"] * 3, - "z_pos": [58.0, 59.0, 60.0, 28.0, 29.0, 30.0], - }, - ), - MDATestCase( - name="z_absolute_in_main_and_z_relative_in_position_sub_sequence", - seq=MDASequence( - stage_positions=[ - {}, - Position( - name="name", - z=10, - sequence={"z_plan": ZRangeAround(range=3, step=1)}, - ), - ], - z_plan=ZTopBottom(bottom=58, top=60, step=1), - ), - expected={ - "index": [ - {"p": 0, "z": 0}, - {"p": 0, "z": 1}, - {"p": 0, "z": 2}, - {"p": 1, "z": 0}, - {"p": 1, "z": 1}, - {"p": 1, "z": 2}, - {"p": 1, "z": 3}, - ], - "pos_name": [None] * 3 + ["name"] * 4, - "z_pos": [58.0, 59.0, 60.0, 8.5, 9.5, 10.5, 11.5], - }, - ), - MDATestCase( - name="z_relative_in_main_and_z_absolute_in_position_sub_sequence", - seq=MDASequence( - stage_positions=[ - Position(z=0), - Position( - name="name", - sequence={"z_plan": ZTopBottom(bottom=58, top=60, step=1)}, - ), - ], - z_plan=ZRangeAround(range=3, step=1), - ), - expected={ - "index": [ - {"p": 0, "z": 0}, - {"p": 0, "z": 1}, - {"p": 0, "z": 2}, - {"p": 0, "z": 3}, - {"p": 1, "z": 0}, - {"p": 1, "z": 1}, - {"p": 1, "z": 2}, - ], - "pos_name": [None] * 4 + ["name"] * 3, - "z_pos": [-1.5, -0.5, 0.5, 1.5, 58.0, 59.0, 60.0], - }, - ), - MDATestCase( - name="multi_z_in_position_sub_sequence", - seq=MDASequence( - stage_positions=[ - {"sequence": {"z_plan": ZTopBottom(bottom=58, top=60, step=1)}}, - {"sequence": {"z_plan": ZRangeAround(range=3, step=1)}}, - {"sequence": {"z_plan": ZTopBottom(bottom=28, top=30, step=1)}}, - ] - ), - expected={ - "index": [ - {"p": 0, "z": 0}, - {"p": 0, "z": 1}, - {"p": 0, "z": 2}, - {"p": 1, "z": 0}, - {"p": 1, "z": 1}, - {"p": 1, "z": 2}, - {"p": 1, "z": 3}, - {"p": 2, "z": 0}, - {"p": 2, "z": 1}, - {"p": 2, "z": 2}, - ], - "z_pos": [ - 58.0, - 59.0, - 60.0, - -1.5, - -0.5, - 0.5, - 1.5, - 28.0, - 29.0, - 30.0, - ], - }, - ), - MDATestCase( - name="t_with_multi_stage_positions", - seq=MDASequence( - stage_positions=[{}, {}], - time_plan=[TIntervalLoops(interval=1, loops=2)], - ), - expected={ - "index": genindex({"t": 2, "p": 2}), - "min_start_time": [0.0, 0.0, 1.0, 1.0], - }, - ), - MDATestCase( - name="t_only_in_position_sub_sequence", - seq=MDASequence( - stage_positions=[ - {}, - {"sequence": {"time_plan": [TIntervalLoops(interval=1, loops=5)]}}, - ] - ), - expected={ - "index": [ - {"p": 0}, - {"p": 1, "t": 0}, - {"p": 1, "t": 1}, - {"p": 1, "t": 2}, - {"p": 1, "t": 3}, - {"p": 1, "t": 4}, - ], - "min_start_time": [None, 0.0, 1.0, 2.0, 3.0, 4.0], - }, - ), - MDATestCase( - name="t_in_main_and_in_position_sub_sequence", - seq=MDASequence( - stage_positions=[ - {}, - {"sequence": {"time_plan": [TIntervalLoops(interval=1, loops=5)]}}, - ], - time_plan=[TIntervalLoops(interval=1, loops=2)], - ), - expected={ - "index": [ - {"t": 0, "p": 0}, - {"t": 0, "p": 1}, - {"t": 1, "p": 1}, - {"t": 2, "p": 1}, - {"t": 3, "p": 1}, - {"t": 4, "p": 1}, - {"t": 1, "p": 0}, - {"t": 0, "p": 1}, - {"t": 1, "p": 1}, - {"t": 2, "p": 1}, - {"t": 3, "p": 1}, - {"t": 4, "p": 1}, - ], - "min_start_time": [ - 0.0, - 0.0, - 1.0, - 2.0, - 3.0, - 4.0, - 1.0, - 0.0, - 1.0, - 2.0, - 3.0, - 4.0, - ], - }, - ), - MDATestCase( - name="mix_cgz_axes", - seq=MDASequence( - axis_order="tpgcz", - stage_positions=[ - Position(x=0, y=0), - Position( - name="name", - x=10, - y=10, - z=30, - sequence=MDASequence( - channels=[ - {"config": "FITC", "exposure": 200}, - {"config": "Cy3", "exposure": 100}, - ], - grid_plan=GridRowsColumns(rows=2, columns=1), - z_plan=ZRangeAround(range=2, step=1), - ), - ), - ], - channels=[Channel(config="Cy5", exposure=50)], - z_plan={"top": 100, "bottom": 98, "step": 1}, - grid_plan=GridFromEdges(top=1, bottom=-1, left=0, right=0), - ), - expected={ - "index": [ - *genindex({"p": 1, "g": 3, "c": 1, "z": 3}), - {"p": 1, "g": 0, "c": 0, "z": 0}, - {"p": 1, "g": 0, "c": 0, "z": 1}, - {"p": 1, "g": 0, "c": 0, "z": 2}, - {"p": 1, "g": 0, "c": 1, "z": 0}, - {"p": 1, "g": 0, "c": 1, "z": 1}, - {"p": 1, "g": 0, "c": 1, "z": 2}, - {"p": 1, "g": 1, "c": 0, "z": 0}, - {"p": 1, "g": 1, "c": 0, "z": 1}, - {"p": 1, "g": 1, "c": 0, "z": 2}, - {"p": 1, "g": 1, "c": 1, "z": 0}, - {"p": 1, "g": 1, "c": 1, "z": 1}, - {"p": 1, "g": 1, "c": 1, "z": 2}, - ], - "pos_name": [None] * 9 + ["name"] * 12, - "x_pos": [0.0] * 9 + [10.0] * 12, - "y_pos": [1, 1, 1, 0, 0, 0, -1, -1, -1] + [10.5] * 6 + [9.5] * 6, - "z_pos": [98.0, 99.0, 100.0] * 3 + [29.0, 30.0, 31.0] * 4, - "channel": ["Cy5"] * 9 + (["FITC"] * 3 + ["Cy3"] * 3) * 2, - "exposure": [50.0] * 9 + [200.0, 200.0, 200.0, 100.0, 100.0, 100.0] * 2, - }, - ), - MDATestCase( - name="order", - seq=MDASequence( - stage_positions=[ - Position(z=0), - Position( - z=50, - sequence=MDASequence( - channels=[ - Channel(config="FITC", exposure=100), - Channel(config="Cy3", exposure=200), - ] - ), - ), - ], - channels=[ - Channel(config="FITC", exposure=100), - Channel(config="Cy5", exposure=50), - ], - z_plan=ZRangeAround(range=2, step=1), - ), - expected={ - "index": [ - {"p": 0, "c": 0, "z": 0}, - {"p": 0, "c": 0, "z": 1}, - {"p": 0, "c": 0, "z": 2}, - {"p": 0, "c": 1, "z": 0}, - {"p": 0, "c": 1, "z": 1}, - {"p": 0, "c": 1, "z": 2}, - {"p": 1, "c": 0, "z": 0}, - {"p": 1, "c": 1, "z": 0}, - {"p": 1, "c": 0, "z": 1}, - {"p": 1, "c": 1, "z": 1}, - {"p": 1, "c": 0, "z": 2}, - {"p": 1, "c": 1, "z": 2}, - ], - "z_pos": [ - -1.0, - 0.0, - 1.0, - -1.0, - 0.0, - 1.0, - 49.0, - 49.0, - 50.0, - 50.0, - 51.0, - 51.0, - ], - "channel": ["FITC"] * 3 + ["Cy5"] * 3 + ["FITC", "Cy3"] * 3, - }, - ), - MDATestCase( - name="channels_and_pos_grid_plan", - seq=MDASequence( - channels=[ - Channel(config="Cy5", exposure=50), - Channel(config="FITC", exposure=100), - ], - stage_positions=[ - Position( - x=0, - y=0, - sequence=MDASequence(grid_plan=GridRowsColumns(rows=2, columns=1)), - ) - ], - ), - expected={ - "index": genindex({"p": 1, "c": 2, "g": 2}), - "x_pos": [0.0, 0.0, 0.0, 0.0], - "y_pos": [0.5, -0.5, 0.5, -0.5], - "channel": ["Cy5", "Cy5", "FITC", "FITC"], - }, - ), - MDATestCase( - name="channels_and_pos_z_plan", - seq=MDASequence( - channels=[ - Channel(config="Cy5", exposure=50), - Channel(config="FITC", exposure=100), - ], - stage_positions=[ - Position( - x=0, - y=0, - z=0, - sequence={"z_plan": ZRangeAround(range=2, step=1)}, - ) - ], - ), - expected={ - "index": genindex({"p": 1, "c": 2, "z": 3}), - "z_pos": [-1.0, 0.0, 1.0, -1.0, 0.0, 1.0], - "channel": ["Cy5", "Cy5", "Cy5", "FITC", "FITC", "FITC"], - }, - ), - MDATestCase( - name="channels_and_pos_time_plan", - seq=MDASequence( - axis_order="tpgcz", - channels=[ - Channel(config="Cy5", exposure=50), - Channel(config="FITC", exposure=100), - ], - stage_positions=[ - Position( - x=0, - y=0, - sequence={"time_plan": [TIntervalLoops(interval=1, loops=3)]}, - ) - ], - ), - expected={ - "index": genindex({"p": 1, "c": 2, "t": 3}), - "min_start_time": [0.0, 1.0, 2.0, 0.0, 1.0, 2.0], - "channel": ["Cy5", "Cy5", "Cy5", "FITC", "FITC", "FITC"], - }, - ), - MDATestCase( - name="channels_and_pos_z_grid_and_time_plan", - seq=MDASequence( - channels=[ - Channel(config="Cy5", exposure=50), - Channel(config="FITC", exposure=100), - ], - stage_positions=[ - Position( - x=0, - y=0, - sequence=MDASequence( - grid_plan=GridRowsColumns(rows=2, columns=2), - z_plan=ZRangeAround(range=2, step=1), - time_plan=[TIntervalLoops(interval=1, loops=2)], - ), - ) - ], - ), - expected={"channel": ["Cy5"] * 24 + ["FITC"] * 24}, - ), - MDATestCase( - name="sub_channels_and_any_plan", - seq=MDASequence( - channels=["Cy5", "FITC"], - stage_positions=[ - Position( - sequence=MDASequence( - channels=["FITC"], - z_plan=ZRangeAround(range=2, step=1), - ) - ) - ], - ), - expected={"channel": ["FITC", "FITC", "FITC"]}, - ), -] - -############################################################################## -# Autofocus Tests -############################################################################## - - -def ensure_af( - expected_indices: Sequence[int] | None = None, expected_z: float | None = None -) -> Callable[[MDASequence], str | None]: - """Test things about autofocus events. - - Parameters - ---------- - expected_indices : Sequence[int] | None - Ensure that the autofocus events are at these indices. - expected_z : float | None - Ensure that all autofocus events have this z position. - """ - exp = list(expected_indices) if expected_indices else [] - - def _pred(seq: MDASequence) -> str | None: - errors: list[str] = [] - if exp: - actual_indices = [ - i - for i, ev in enumerate(seq) - if isinstance(ev.action, HardwareAutofocus) - ] - if actual_indices != exp: - errors.append(f"expected AF indices {exp}, got {actual_indices}") - - if expected_z is not None: - z_vals = [ - ev.z_pos for ev in seq if isinstance(ev.action, HardwareAutofocus) - ] - if not all(z == expected_z for z in z_vals): - errors.append(f"expected all AF events at z={expected_z}, got {z_vals}") - if errors: - return ", ".join(errors) - return None - - return _pred - - -AF_CASES: list[MDATestCase] = [ - # 1. NO AXES - Should never trigger - MDATestCase( - name="af_no_axes_no_autofocus", - seq=MDASequence( - stage_positions=[Position(z=30)], - z_plan=ZRangeAround(range=2, step=1), - channels=["DAPI", "FITC"], - autofocus_plan=AxesBasedAF( - autofocus_device_name="Z", autofocus_motor_offset=40, axes=() - ), - ), - predicate=ensure_af(expected_indices=[]), - ), - # 2. CHANNEL AXIS (c) - Triggers on channel changes - MDATestCase( - name="af_axes_c_basic", - seq=MDASequence( - stage_positions=[Position(z=30)], - z_plan=ZRangeAround(range=2, step=1), - channels=["DAPI", "FITC"], - autofocus_plan=AxesBasedAF(autofocus_device_name="Z", axes=("c",)), - ), - predicate=ensure_af(expected_indices=[0, 4]), - ), - # 3. Z AXIS (z) - Triggers on z changes - MDATestCase( - name="af_axes_z_basic", - seq=MDASequence( - stage_positions=[Position(z=30)], - z_plan=ZRangeAround(range=2, step=1), - channels=["DAPI", "FITC"], - autofocus_plan=AxesBasedAF(autofocus_device_name="Z", axes=("z",)), - ), - predicate=ensure_af(expected_indices=range(0, 11, 2)), - ), - # 4. GRID AXIS (g) - Triggers on grid position changes - MDATestCase( - name="af_axes_g_basic", - seq=MDASequence( - stage_positions=[Position(z=30)], - channels=["DAPI", "FITC"], - grid_plan=GridRowsColumns(rows=2, columns=1), - autofocus_plan=AxesBasedAF(autofocus_device_name="Z", axes=("g",)), - ), - predicate=ensure_af(expected_indices=[0, 3]), - ), - # 5. POSITION AXIS (p) - Triggers on position changes - MDATestCase( - name="af_axes_p_basic", - seq=MDASequence( - stage_positions=[Position(z=30), Position(z=200)], - channels=["DAPI", "FITC"], - autofocus_plan=AxesBasedAF(autofocus_device_name="Z", axes=("p",)), - ), - predicate=ensure_af(expected_indices=[0, 3]), - ), - # 6. TIME AXIS (t) - Triggers on time changes - MDATestCase( - name="af_axes_t_basic", - seq=MDASequence( - stage_positions=[Position(z=30), Position(z=200)], - channels=["DAPI", "FITC"], - time_plan=[TIntervalLoops(interval=1, loops=2)], - autofocus_plan=AxesBasedAF(autofocus_device_name="Z", axes=("t",)), - ), - predicate=ensure_af(expected_indices=[0, 5]), - ), - # 7. AXIS ORDER EFFECTS - Different axis order changes when axes trigger - MDATestCase( - name="af_axis_order_effect", - seq=MDASequence( - stage_positions=[Position(z=30)], - z_plan=ZRangeAround(range=2, step=1), - channels=["DAPI", "FITC"], - axis_order="tpgzc", # Different from default "tpczg" - autofocus_plan=AxesBasedAF(autofocus_device_name="Z", axes=("z",)), - ), - predicate=ensure_af(expected_indices=[0, 3, 6]), - ), - # 8. SUBSEQUENCE AUTOFOCUS - AF plan within position subsequence - MDATestCase( - name="af_subsequence_af", - seq=MDASequence( - stage_positions=[ - Position(z=30), - Position( - z=10, - sequence=MDASequence( - autofocus_plan=AxesBasedAF( - autofocus_device_name="Z", - axes=("c",), - ) - ), - ), - ], - channels=["DAPI", "FITC"], - ), - predicate=ensure_af(expected_indices=[2, 4]), - ), - # 9. MIXED MAIN + SUBSEQUENCE AF - MDATestCase( - name="af_mixed_main_and_sub", - seq=MDASequence( - stage_positions=[ - Position(z=30), - Position( - z=10, - sequence=MDASequence( - autofocus_plan=AxesBasedAF( - autofocus_device_name="Z", - autofocus_motor_offset=40, - axes=("z",), - ), - ), - ), - ], - channels=["DAPI", "FITC"], - z_plan=ZRangeAround(range=2, step=1), - autofocus_plan=AxesBasedAF( - autofocus_device_name="Z", autofocus_motor_offset=40, axes=("p",) - ), - ), - predicate=ensure_af(expected_indices=[0, *range(7, 18, 2)]), - ), - # 10. Z POSITION CORRECTION - AF events get correct z position with relative z plans - MDATestCase( - name="af_z_position_correction", - seq=MDASequence( - stage_positions=[Position(z=200)], - channels=["DAPI", "FITC"], - z_plan=ZRangeAround(range=2, step=1), - autofocus_plan=AxesBasedAF( - autofocus_device_name="Z", autofocus_motor_offset=40, axes=("c",) - ), - ), - predicate=ensure_af(expected_z=200), - ), - # 11. SUBSEQUENCE Z POSITION CORRECTION - MDATestCase( - name="af_subsequence_z_position", - seq=MDASequence( - stage_positions=[ - Position( - z=10, - sequence=MDASequence( - autofocus_plan=AxesBasedAF( - autofocus_device_name="Z", - autofocus_motor_offset=40, - axes=("c",), - ) - ), - ) - ], - channels=["DAPI", "FITC"], - z_plan=ZRangeAround(range=2, step=1), - ), - predicate=ensure_af(expected_z=10), - ), - # 12. NO DEVICE NAME - Edge case for testing without device name - MDATestCase( - name="af_no_device_name", - seq=MDASequence( - time_plan=[TIntervalLoops(interval=1, loops=2)], - autofocus_plan=AxesBasedAF(axes=("t",)), - ), - predicate=lambda _: "", # Just check it doesn't crash - ), -] - -############################################################################## -# Keep Shutter Open Tests -############################################################################### - - -def ensure_shutter_behavior( - expected_indices: Sequence[int] | bool | None = None, -) -> Callable[[MDASequence], str | None]: - """Test keep_shutter_open behavior.""" - - def _pred(seq: MDASequence) -> str | None: - events = list(seq) - errors: list[str] = [] - - if expected_indices is not None: - if expected_indices is True: - if closed_events := [ - i for i, e in enumerate(events) if not e.keep_shutter_open - ]: - errors.append( - f"expected all shutters open, but events " - f"{closed_events} have keep_shutter_open=False" - ) - elif expected_indices is False: - if open_events := [ - i for i, e in enumerate(events) if e.keep_shutter_open - ]: - errors.append( - f"expected all shutters closed, but events " - f"{open_events} have keep_shutter_open=True" - ) - else: - actual_indices = [ - i for i, e in enumerate(events) if e.keep_shutter_open - ] - if actual_indices != list(expected_indices): - errors.append( - f"expected shutter open at indices {expected_indices}, " - f"got {actual_indices}" - ) - - if errors: - return "; ".join(errors) - return None - - return _pred - - -KEEP_SHUTTER_CASES: list[MDATestCase] = [ - # with z as the last axis, the shutter will be left open - # whenever z is the first index (since there are only 2 z planes) - MDATestCase( - name="keep_shutter_open_across_z_order_tcz", - seq=MDASequence( - axis_order=tuple("tcz"), - channels=["DAPI", "FITC"], - time_plan=TIntervalLoops(loops=2, interval=0), - z_plan=ZRangeAround(range=1, step=1), - keep_shutter_open_across="z", - ), - predicate=ensure_shutter_behavior(expected_indices=[0, 2, 4, 6]), - ), - # with c as the last axis, the shutter will never be left open - MDATestCase( - name="keep_shutter_open_across_z_order_tzc", - seq=MDASequence( - axis_order=tuple("tzc"), - channels=["DAPI", "FITC"], - time_plan=TIntervalLoops(loops=2, interval=0), - z_plan=ZRangeAround(range=1, step=1), - keep_shutter_open_across="z", - ), - predicate=ensure_shutter_behavior(expected_indices=[]), - ), - # because t is changing faster than z, the shutter will never be left open - MDATestCase( - name="keep_shutter_open_across_z_order_czt", - seq=MDASequence( - axis_order=tuple("czt"), - channels=["DAPI", "FITC"], - time_plan=TIntervalLoops(loops=2, interval=0), - z_plan=ZRangeAround(range=1, step=1), - keep_shutter_open_across="z", - ), - predicate=ensure_shutter_behavior(expected_indices=[]), - ), - # but, if we include 't' in the keep_shutter_open_across, - # it will be left open except when it's the last t and last z - MDATestCase( - name="keep_shutter_open_across_zt_order_czt", - seq=MDASequence( - axis_order=tuple("czt"), - channels=["DAPI", "FITC"], - time_plan=TIntervalLoops(loops=2, interval=0), - z_plan=ZRangeAround(range=1, step=1), - keep_shutter_open_across=("z", "t"), - ), - # for event in seq: - # is_last_zt = bool(event.index["t"] == 1 and event.index["z"] == 1) - # assert event.keep_shutter_open != is_last_zt - predicate=ensure_shutter_behavior(expected_indices=[0, 1, 2, 4, 5, 6]), - ), - # even though c is the last axis, and comes after g, because the grid happens - # on a subsequence shutter will be open across the grid for each position - MDATestCase( - name="keep_shutter_open_across_g_order_pgc_with_subseq", - seq=MDASequence( - axis_order=tuple("pgc"), - channels=["DAPI", "FITC"], - stage_positions=[ - Position( - sequence=MDASequence(grid_plan=GridRowsColumns(rows=2, columns=2)) - ) - ], - keep_shutter_open_across="g", - ), - # for event in seq: - # assert event.keep_shutter_open != (event.index["g"] == 3) - predicate=ensure_shutter_behavior(expected_indices=[0, 1, 2, 4, 5, 6]), - ), -] - -# ############################################################################## -# Reset Event Timer Test Cases -# ############################################################################## - -RESET_EVENT_TIMER_CASES: list[MDATestCase] = [ - MDATestCase( - name="reset_event_timer_with_time_intervals", - seq=MDASequence( - stage_positions=[(100, 100), (0, 0)], - time_plan={"interval": 1, "loops": 2}, - axis_order=tuple("ptgcz"), - ), - expected={ - "reset_event_timer": [True, False, True, False], - }, - ), - MDATestCase( - name="reset_event_timer_with_nested_position_sequences", - seq=MDASequence( - stage_positions=[ - Position( - x=0, - y=0, - sequence=MDASequence( - channels=["Cy5"], time_plan={"interval": 1, "loops": 2} - ), - ), - Position( - x=1, - y=1, - sequence=MDASequence( - channels=["DAPI"], time_plan={"interval": 1, "loops": 2} - ), - ), - ] - ), - expected={ - "reset_event_timer": [True, False, True, False], - }, - ), -] - -# ############################################################################## -# Combined Test Cases -# ############################################################################## - -CASES: list[MDATestCase] = ( - GRID_SUBSEQ_CASES + AF_CASES + KEEP_SHUTTER_CASES + RESET_EVENT_TIMER_CASES -) - -# assert that all test cases are unique -case_names = [case.name for case in CASES] -if duplicates := {name for name in case_names if case_names.count(name) > 1}: - raise ValueError( - f"Duplicate test case names found: {duplicates}. " - "Please ensure all test cases have unique names." - ) +from .fixtures.cases import CASES, MDATestCase, assert_test_case_passes @pytest.mark.parametrize("case", CASES, ids=lambda c: c.name) def test_mda_sequence(case: MDATestCase) -> None: - # test case expressed the expectation as a predicate - if case.predicate is not None: - # (a function that returns a non-empty error message if the test fails) - if msg := case.predicate(case.seq): - raise AssertionError(f"\nExpectation not met in '{case.name}':\n {msg}\n") - - # test case expressed the expectation as a list of MDAEvent - elif isinstance(case.expected, list): - actual_events = list(case.seq) - if len(actual_events) != len(case.expected): - raise AssertionError( - f"\nMismatch in case '{case.name}':\n" - f" expected: {len(case.expected)} events\n" - f" actual: {len(actual_events)} events\n" - ) - for i, event in enumerate(actual_events): - if event != case.expected[i]: - raise AssertionError( - f"\nMismatch in case '{case.name}':\n" - f" expected: {case.expected[i]}\n" - f" actual: {event}\n" - ) - - # test case expressed the expectation as a dict of {Event attr -> values list} - else: - assert isinstance(case.expected, dict), f"Invalid test case: {case.name!r}" - actual: dict[str, list[Any]] = {k: [] for k in case.expected} - for event in case.seq: - for attr in case.expected: - actual[attr].append(getattr(event, attr)) - - if mismatched_fields := { - attr for attr in actual if actual[attr] != case.expected[attr] - }: - msg = f"\nMismatch in case '{case.name}':\n" - for attr in mismatched_fields: - msg += f" {attr}:\n" - msg += f" expected: {case.expected[attr]}\n" - msg += f" actual: {actual[attr]}\n" - raise AssertionError(msg) + assert_test_case_passes(case, list(case.seq)) diff --git a/tests/v2/test_cases2.py b/tests/v2/test_cases2.py new file mode 100644 index 00000000..b3d95f45 --- /dev/null +++ b/tests/v2/test_cases2.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from rich import print # noqa: F401 +from tests.fixtures.cases import CASES, MDATestCase, assert_test_case_passes + +from useq import v2 + +if TYPE_CHECKING: + from useq._mda_event import MDAEvent + + +@pytest.mark.filterwarnings("ignore:Conflicting absolute pos") +@pytest.mark.parametrize("case", CASES, ids=lambda c: c.name) +def test_mda_sequence(case: MDATestCase) -> None: + if "af_z_position" in case.name: + pytest.xfail("af_z_position is not yet working in useq.v2, ") + + v2_seq = v2.MDASequence.model_validate(case.seq) + assert isinstance(v2_seq, v2.MDASequence) + actual_events = list(v2_seq) + + assert_test_case_passes(case, actual_events) + + if "af_" not in case.name: + assert_v2_same_as_v1(list(case.seq), actual_events) + + +def assert_v2_same_as_v1(v1_events: list[MDAEvent], v2_events: list[MDAEvent]) -> None: + """Assert that the v2 sequence is the same as the v1 sequence.""" + # test parity with v1 + if v2_events != v1_events: + # print intelligible diff to see exactly what is different, including + # total number of events, indices that differ, and a full repr + # of the first event that differs + + msg = [] + if len(v2_events) != len(v1_events): + msg.append(f"Number of events differ: {len(v2_events)} != {len(v1_events)}") + differing_indices = [ + i for i, (a, b) in enumerate(zip(v2_events, v1_events)) if a != b + ] + if differing_indices: + msg.append(f"Indices that differ: {differing_indices}") + + # show the first differing event in full + idx = differing_indices[0] + + v1_dict = v1_events[idx].model_dump(exclude={"sequence"}) + v2_dict = v2_events[idx].model_dump(exclude={"sequence"}) + + diff_fields = {f for f in v1_dict if v1_dict[f] != v2_dict.get(f)} + v1min = {k: v for k, v in v1_dict.items() if k in diff_fields} + v2min = {k: v for k, v in v2_dict.items() if k in diff_fields} + msg.extend( + [ + f"First differing event (index {idx}):", + f" EXPECT: {v1min}", + f" ACTUAL: {v2min}", + ] + ) + raise AssertionError( + "Events differ between v1 and v2 MDASequence:\n\n" + "\n ".join(msg) + ) diff --git a/tests/v2/test_grid.py b/tests/v2/test_grid.py new file mode 100644 index 00000000..887d6b85 --- /dev/null +++ b/tests/v2/test_grid.py @@ -0,0 +1,265 @@ +from __future__ import annotations + +import importlib +import importlib.util +import math +import sys +from dataclasses import dataclass +from typing import TYPE_CHECKING, cast + +import numpy as np +import pytest + +from useq.v2 import ( + GridFromEdges, + GridRowsColumns, + GridWidthHeight, + MultiPointPlan, + OrderMode, + RandomPoints, + RelativeTo, + Shape, + TraversalOrder, +) + +if TYPE_CHECKING: + from collections.abc import Iterable + + from useq.v2 import Position + + +def _in_ellipse(x: float, y: float, w: float, h: float, tol: float = 1.01) -> bool: + return (x / (w / 2)) ** 2 + (y / (h / 2)) ** 2 <= tol + + +if sys.version_info >= (3, 10): + SLOTS = {"slots": True} +else: + SLOTS = {} + + +@dataclass(**SLOTS) +class GridTestCase: + grid: MultiPointPlan + expected_coords: list[tuple[float, float]] + + +GRID_CASES: list[GridTestCase] = [ + # ------------------------------------------------------------------- + GridTestCase( + GridFromEdges( + top=10, + left=0, + bottom=0, + right=10, + fov_width=5, + fov_height=5, + ), + [(2.5, 7.5), (7.5, 7.5), (7.5, 2.5), (2.5, 2.5)], + ), + GridTestCase( + GridFromEdges( + top=5, + left=0, + bottom=0, + right=5, + fov_width=5, + fov_height=5, + ), + [(2.5, 2.5)], + ), + # ------------------------------------------------------------------- + GridTestCase( + GridRowsColumns( + rows=2, + columns=3, + relative_to=RelativeTo.center, + fov_width=1, + fov_height=1, + ), + [(-1.0, 0.5), (0.0, 0.5), (1.0, 0.5), (1.0, -0.5), (0.0, -0.5), (-1.0, -0.5)], + ), + GridTestCase( + GridRowsColumns( + rows=2, + columns=2, + relative_to=RelativeTo.top_left, + fov_width=1, + fov_height=1, + ), + [(0.0, 0.0), (1.0, 0.0), (1.0, -1.0), (0.0, -1.0)], + ), + # ------------------------------------------------------------------- + GridTestCase( + GridWidthHeight( + width=3, + height=2, + relative_to=RelativeTo.center, + fov_width=1, + fov_height=1, + ), + [(-1.0, 0.5), (0.0, 0.5), (1.0, 0.5), (1.0, -0.5), (0.0, -0.5), (-1.0, -0.5)], + ), + GridTestCase( + GridWidthHeight( + width=2, + height=2, + relative_to=RelativeTo.top_left, + fov_width=1, + fov_height=1, + ), + [(0.0, 0.0), (1.0, 0.0), (1.0, -1.0), (0.0, -1.0)], + ), + # fractional coverage (2.5 x 1.5) ⇒ same coords as 3 x 2 case + GridTestCase( + GridWidthHeight( + width=2.5, + height=1.5, + relative_to=RelativeTo.center, + fov_width=1, + fov_height=1, + ), + [(-1.0, 0.5), (0.0, 0.5), (1.0, 0.5), (1.0, -0.5), (0.0, -0.5), (-1.0, -0.5)], + ), + # ------------------------------------------------------------------- + GridTestCase( + RandomPoints( + shape=Shape.ELLIPSE, + num_points=5, + max_width=10, + max_height=6, + random_seed=42, + ), + [ + (-0.2114, -2.8339), + (-0.2337, -1.5420), + (1.6669, 0.3887), + (1.7288, 0.5794), + (3.4772, -0.0116), + ], + ), + GridTestCase( + RandomPoints( + shape=Shape.RECTANGLE, + num_points=4, + max_width=8, + max_height=4, + random_seed=123, + ), + [(1.5717, -0.8554), (-2.1851, 0.2052), (1.7557, -0.3075), (3.8461, 0.7393)], + ), +] + + +def _coords(grid: Iterable[Position]) -> list[tuple[float, float]]: + return [(p.x, p.y) for p in grid] # type: ignore + + +@pytest.mark.parametrize("tc", GRID_CASES, ids=lambda tc: type(tc.grid).__name__) +def test_grid_cases(tc: GridTestCase) -> None: + pos = list(tc.grid) + coords = _coords(pos) + np.testing.assert_allclose(coords, tc.expected_coords, atol=1e-4) + assert len(pos) == len(tc.expected_coords) + + if isinstance(tc.grid, RandomPoints): + w, h = tc.grid.max_width, tc.grid.max_height + if tc.grid.shape is Shape.ELLIPSE: + for x, y in coords: + assert _in_ellipse(x, y, w, h) + else: + for x, y in coords: + assert -w / 2 <= x <= w / 2 + assert -h / 2 <= y <= h / 2 + + +def test_grid_from_edges_with_overlap() -> None: + g = GridFromEdges( + top=10, + left=0, + bottom=0, + right=10, + fov_width=5, + fov_height=5, + overlap=50, + ) + coords = cast("list[tuple[float, float]]", [(p.x, p.y) for p in g]) + + # 50 % overlap ⇒ step = 2.5 µm + assert len(g) > 4 + assert coords[0] == (2.5, 7.5) + assert math.isclose(coords[1][0] - coords[0][0], 2.5, abs_tol=1e-6) + + +def test_grid_rows_columns_overlap_spacing() -> None: + g = GridRowsColumns( + rows=2, + columns=2, + relative_to=RelativeTo.center, + fov_width=2, + fov_height=2, + overlap=(25, 50), + ) + coords = _coords(g) + + dx, dy = 2 * (1 - 0.25), 2 * (1 - 0.5) + assert math.isclose(abs(coords[1][0] - coords[0][0]), dx, abs_tol=0.01) + assert math.isclose(abs(coords[2][1] - coords[0][1]), dy, abs_tol=0.01) + + +def test_random_points_no_overlap() -> None: + g = RandomPoints( + num_points=3, + max_width=10, + max_height=10, + shape=Shape.RECTANGLE, + fov_width=2, + fov_height=2, + allow_overlap=False, + random_seed=456, + ) + coords = _coords(g) + for i, (x1, y1) in enumerate(coords): + for j, (x2, y2) in enumerate(coords): + if i != j: + assert abs(x1 - x2) >= 2 or abs(y1 - y2) >= 2 + + if importlib.util.find_spec("matplotlib") is not None: + g.plot(show=False) + + +def test_random_points_traversal_ordering() -> None: + g1 = RandomPoints(num_points=5, random_seed=789, order=None) + g2 = RandomPoints(num_points=5, random_seed=789, order=TraversalOrder.TWO_OPT) + + coords1 = [(p.x, p.y) for p in g1] + coords2 = [(p.x, p.y) for p in g2] + + assert set(coords1) == set(coords2) and coords1 != coords2 + + +# --------------------------------------------------------------------------- +# traversal modes & naming +# --------------------------------------------------------------------------- + + +def test_row_vs_column_snake() -> None: + row = GridRowsColumns( + rows=2, columns=3, mode=OrderMode.row_wise_snake, fov_width=1, fov_height=1 + ) + col = GridRowsColumns( + rows=2, columns=3, mode=OrderMode.column_wise_snake, fov_width=1, fov_height=1 + ) + + row_coords = [(p.x, p.y) for p in row] + col_coords = [(p.x, p.y) for p in col] + + assert row_coords[0] == col_coords[0] # both start top-left + assert row_coords[1] != col_coords[1] # diverge after that + + +def test_position_naming() -> None: + names = [ + p.name for p in GridRowsColumns(rows=2, columns=2, fov_width=1, fov_height=1) + ] + assert names == ["0000", "0001", "0002", "0003"] diff --git a/tests/v2/test_grid_and_points_plans_v2.py b/tests/v2/test_grid_and_points_plans_v2.py new file mode 100644 index 00000000..8693d928 --- /dev/null +++ b/tests/v2/test_grid_and_points_plans_v2.py @@ -0,0 +1,296 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional, get_args + +import pytest +from pydantic import TypeAdapter + +from useq._point_visiting import _rect_indices, _spiral_indices +from useq.v2 import ( + GridFromEdges, + GridRowsColumns, + GridWidthHeight, + MultiPointPlan, + MultiPositionPlan, + OrderMode, + Position, + RandomPoints, + TraversalOrder, +) + +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + +def RelativePosition(**kwargs: Any) -> Position: + return Position(**kwargs, is_relative=True) + + +g_inputs = [ + ( + GridRowsColumns(overlap=10, rows=1, columns=2, relative_to="center"), + [ + RelativePosition(x=-0.45, y=0.0, name="0000", row=0, col=0), + RelativePosition(x=0.45, y=0.0, name="0001", row=0, col=1), + ], + ), + ( + GridRowsColumns(overlap=0, rows=1, columns=2, relative_to="top_left"), + [ + RelativePosition(x=0.0, y=0.0, name="0000", row=0, col=0), + RelativePosition(x=1.0, y=0.0, name="0001", row=0, col=1), + ], + ), + ( + GridRowsColumns(overlap=(20, 40), rows=2, columns=2), + [ + RelativePosition(x=-0.4, y=0.3, name="0000", row=0, col=0), + RelativePosition(x=0.4, y=0.3, name="0001", row=0, col=1), + RelativePosition(x=0.4, y=-0.3, name="0002", row=1, col=1), + RelativePosition(x=-0.4, y=-0.3, name="0003", row=1, col=0), + ], + ), + ( + GridFromEdges( + overlap=0, top=0, left=0, bottom=20, right=20, fov_height=20, fov_width=20 + ), + [ + Position(x=10.0, y=10.0, name="0000", row=0, col=0), + ], + ), + ( + GridFromEdges( + overlap=20, + top=30, + left=-10, + bottom=-10, + right=30, + fov_height=25, + fov_width=25, + ), + [ + Position(x=2.5, y=17.5, name="0000", row=0, col=0), + Position(x=22.5, y=17.5, name="0001", row=0, col=1), + Position(x=22.5, y=-2.5, name="0002", row=1, col=1), + Position(x=2.5, y=-2.5, name="0003", row=1, col=0), + ], + ), + ( + RandomPoints( + num_points=3, + max_width=4, + max_height=5, + fov_height=0.5, + fov_width=0.5, + shape="ellipse", + allow_overlap=False, + random_seed=0, + ), + [ + RelativePosition(x=-0.9, y=-1.1, name="0000"), + RelativePosition(x=0.9, y=-0.5, name="0001"), + RelativePosition(x=-0.8, y=-0.4, name="0002"), + ], + ), +] + + +@pytest.mark.filterwarnings("ignore:num_positions\\(\\) is deprecated") +@pytest.mark.parametrize("gridplan, gridexpectation", g_inputs) +def test_g_plan(gridplan: Any, gridexpectation: Sequence[Any]) -> None: + g_plan = TypeAdapter(MultiPositionPlan).validate_python(gridplan) + assert isinstance(g_plan, MultiPositionPlan) + if isinstance(gridplan, RandomPoints): + assert g_plan and [round(gp, 1) for gp in g_plan] == gridexpectation + else: + assert g_plan and list(g_plan) == gridexpectation + assert g_plan.num_positions() == len(gridexpectation) + + +EXPECT = { + (True, False): [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)], + (True, True): [(0, 0), (0, 1), (1, 1), (1, 0), (2, 0), (2, 1)], + (False, False): [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)], + (False, True): [(0, 0), (1, 0), (2, 0), (2, 1), (1, 1), (0, 1)], +} + + +@pytest.mark.parametrize("row_wise", [True, False], ids=["row_wise", "col_wise"]) +@pytest.mark.parametrize("snake", [True, False], ids=["snake", "normal"]) +def test_grid_indices(row_wise: bool, snake: bool) -> None: + indices = _rect_indices(3, 2, snake=snake, row_wise=row_wise) + assert list(indices) == EXPECT[(row_wise, snake)] + + +def test_spiral_indices() -> None: + assert list(_spiral_indices(2, 3)) == [ + (0, 1), + (0, 2), + (1, 2), + (1, 1), + (1, 0), + (0, 0), + ] + assert list(_spiral_indices(2, 3, center_origin=True)) == [ + (0, 0), + (0, 1), + (1, 1), + (1, 0), + (1, -1), + (0, -1), + ] + + +def test_position_equality() -> None: + """Order of grid positions should only change the order in which they are yielded""" + + def positions_without_name( + positions: Iterable[Position], + ) -> set[tuple[float, float, bool]]: + """Create a set of tuples of GridPosition attributes excluding 'name'""" + return {(pos.x, pos.y, pos.is_relative) for pos in positions} + + t1 = GridRowsColumns(rows=3, columns=3, mode=OrderMode.spiral) + spiral_pos = positions_without_name(t1.iter_grid_positions(1, 1)) + + t2 = GridRowsColumns(rows=3, columns=3, mode=OrderMode.row_wise) + row_pos = positions_without_name(t2.iter_grid_positions(1, 1)) + + t3 = GridRowsColumns(rows=3, columns=3, mode="row_wise_snake") + snake_row_pos = positions_without_name(t3.iter_grid_positions(1, 1)) + + t4 = GridRowsColumns(rows=3, columns=3, mode=OrderMode.column_wise) + col_pos = positions_without_name(t4.iter_grid_positions(1, 1)) + + t5 = GridRowsColumns(rows=3, columns=3, mode=OrderMode.column_wise_snake) + snake_col_pos = positions_without_name(t5.iter_grid_positions(1, 1)) + + assert spiral_pos == row_pos == snake_row_pos == col_pos == snake_col_pos + + +def test_grid_type() -> None: + g1 = GridRowsColumns(rows=2, columns=3) + assert [(g.x, g.y) for g in g1.iter_grid_positions(1, 1)] == [ + (-1.0, 0.5), + (0.0, 0.5), + (1.0, 0.5), + (1.0, -0.5), + (0.0, -0.5), + (-1.0, -0.5), + ] + assert g1.is_relative + g2 = GridWidthHeight(width=3, height=2, fov_height=1, fov_width=1) + assert [(g.x, g.y) for g in g2.iter_grid_positions()] == [ + (-1.0, 0.5), + (0.0, 0.5), + (1.0, 0.5), + (1.0, -0.5), + (0.0, -0.5), + (-1.0, -0.5), + ] + assert g2.is_relative + g3 = GridFromEdges(top=1, left=-1, bottom=-1, right=2) + assert [(g.x, g.y) for g in g3.iter_grid_positions(1, 1)] == [ + (-1.0, 1.0), + (0.0, 1.0), + (1.0, 1.0), + (2.0, 1.0), + (2.0, 0.0), + (1.0, 0.0), + (0.0, 0.0), + (-1.0, 0.0), + (-1.0, -1.0), + (0.0, -1.0), + (1.0, -1.0), + (2.0, -1.0), + ] + assert not g3.is_relative + + +@pytest.mark.filterwarnings("ignore:num_positions\\(\\) is deprecated") +def test_num_position_error() -> None: + with pytest.raises(ValueError, match="plan requires the field of view size"): + GridFromEdges(top=1, left=-1, bottom=-1, right=2).num_positions() + + with pytest.raises(ValueError, match="plan requires the field of view size"): + GridWidthHeight(width=2, height=2).num_positions() + + +expected_rectangle = [(0.2, 1.1), (0.4, 0.2), (-0.3, 0.7)] +expected_ellipse = [(-0.9, -1.1), (0.9, -0.5), (-0.8, -0.4)] + + +@pytest.mark.parametrize("n_points", [3, 100]) +@pytest.mark.parametrize("shape", ["rectangle", "ellipse"]) +@pytest.mark.parametrize("seed", [None, 0]) +def test_random_points(n_points: int, shape: str, seed: Optional[int]) -> None: + rp = RandomPoints( + num_points=n_points, + max_width=4, + max_height=5, + shape=shape, + random_seed=seed, + allow_overlap=False, + fov_width=0.5, + fov_height=0.5, + ) + + if n_points == 3: + expected = expected_rectangle if shape == "rectangle" else expected_ellipse + values = [(round(g.x, 1), round(g.y, 1)) for g in rp] + if seed is None: + assert values != expected + else: + assert values == expected + else: + with pytest.raises(UserWarning, match="Unable to generate"): + list(rp) + + +@pytest.mark.parametrize("order", list(TraversalOrder)) +def test_traversal(order: TraversalOrder): + pp = RandomPoints( + num_points=30, + max_height=3000, + max_width=3000, + order=order, + random_seed=1, + start_at=10, + fov_height=300, + fov_width=300, + allow_overlap=False, + ) + list(pp) + + +fov = {"fov_height": 200, "fov_width": 200} + + +@pytest.mark.filterwarnings("ignore:num_positions\\(\\) is deprecated") +@pytest.mark.parametrize( + "obj", + [ + GridRowsColumns(rows=1, columns=2, **fov), + GridWidthHeight(width=10, height=10, **fov), + RandomPoints(num_points=10, **fov), + ], +) +def test_points_plans_plot( + obj: MultiPointPlan, monkeypatch: pytest.MonkeyPatch +) -> None: + mpl = pytest.importorskip("matplotlib.pyplot") + monkeypatch.setattr(mpl, "show", lambda: None) + + assert isinstance(obj, get_args(MultiPointPlan)) + assert all(isinstance(x, Position) for x in obj) + assert isinstance(obj.num_positions(), int) + + obj.plot() + + +def test_grid_from_edges_plot(monkeypatch: pytest.MonkeyPatch) -> None: + mpl = pytest.importorskip("matplotlib.pyplot") + monkeypatch.setattr(mpl, "show", lambda: None) + GridFromEdges( + overlap=10, top=0, left=0, bottom=20, right=30, fov_height=10, fov_width=20 + ).plot() diff --git a/tests/v2/test_mda_seq.py b/tests/v2/test_mda_seq.py new file mode 100644 index 00000000..3c148d83 --- /dev/null +++ b/tests/v2/test_mda_seq.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from pydantic import field_validator + +import useq +from useq.v2 import ( + Channel, + MDAEvent, + MDASequence, + Position, + SimpleValueAxis, + TIntervalLoops, + ZRangeAround, +) + +if TYPE_CHECKING: + from collections.abc import Mapping + + +# Some example subclasses of SimpleAxis, to demonstrate flexibility +class APlan(SimpleValueAxis[float]): + axis_key: str = "a" + + def contribute_event_kwargs( + self, value: float, index: Mapping[str, int] + ) -> MDAEvent.Kwargs: + return {"min_start_time": value} + + +class BPlan(SimpleValueAxis[Position]): + axis_key: str = "b" + + @field_validator("values", mode="before") + def _value_to_position(cls, values: list[float]) -> list[Position]: + return [Position(z=v) for v in values] + + def contribute_event_kwargs( + self, value: Position, index: Mapping[str, int] + ) -> MDAEvent.Kwargs: + if value.z is None: + return {} + return {"z_pos": value.z} + + +class CPlan(SimpleValueAxis[Channel]): + axis_key: str = "c" + + @field_validator("values", mode="before") + def _value_to_channel(cls, values: list[str]) -> list[Channel]: + return [Channel(config=v, exposure=None) for v in values] + + def contribute_event_kwargs( + self, value: Channel, index: Mapping[str, int] + ) -> MDAEvent.Kwargs: + return {"channel": {"config": value.config}} + + +def test_new_mdasequence_manual() -> None: + seq = MDASequence( + axes=( + APlan(values=[0, 1]), + BPlan(values=[0.1, 0.3]), + CPlan(values=["red", "green", "blue"]), + ) + ) + events = [ + x.model_dump(exclude={"sequence"}, exclude_unset=True) + for x in seq.iter_events() + ] + # fmt: off + assert events == [ + {'index': {'a': 0, 'b': 0, 'c': 0}, 'channel': {'config': 'red'}, 'min_start_time': 0.0, 'z_pos': 0.1}, + {'index': {'a': 0, 'b': 0, 'c': 1}, 'channel': {'config': 'green'}, 'min_start_time': 0.0, 'z_pos': 0.1}, + {'index': {'a': 0, 'b': 0, 'c': 2}, 'channel': {'config': 'blue'}, 'min_start_time': 0.0, 'z_pos': 0.1}, + {'index': {'a': 0, 'b': 1, 'c': 0}, 'channel': {'config': 'red'}, 'min_start_time': 0.0, 'z_pos': 0.3}, + {'index': {'a': 0, 'b': 1, 'c': 1}, 'channel': {'config': 'green'}, 'min_start_time': 0.0, 'z_pos': 0.3}, + {'index': {'a': 0, 'b': 1, 'c': 2}, 'channel': {'config': 'blue'}, 'min_start_time': 0.0, 'z_pos': 0.3}, + {'index': {'a': 1, 'b': 0, 'c': 0}, 'channel': {'config': 'red'}, 'min_start_time': 1.0, 'z_pos': 0.1}, + {'index': {'a': 1, 'b': 0, 'c': 1}, 'channel': {'config': 'green'}, 'min_start_time': 1.0, 'z_pos': 0.1}, + {'index': {'a': 1, 'b': 0, 'c': 2}, 'channel': {'config': 'blue'}, 'min_start_time': 1.0, 'z_pos': 0.1}, + {'index': {'a': 1, 'b': 1, 'c': 0}, 'channel': {'config': 'red'}, 'min_start_time': 1.0, 'z_pos': 0.3}, + {'index': {'a': 1, 'b': 1, 'c': 1}, 'channel': {'config': 'green'}, 'min_start_time': 1.0, 'z_pos': 0.3}, + {'index': {'a': 1, 'b': 1, 'c': 2}, 'channel': {'config': 'blue'}, 'min_start_time': 1.0, 'z_pos': 0.3}, + ] + # fmt: on + + +def test_new_mdasequence_parity() -> None: + seq = MDASequence( + time_plan=TIntervalLoops(interval=0.2, loops=2), + z_plan=ZRangeAround(range=1, step=0.5), + channels=["DAPI", "FITC"], + ) + v1_seq = useq.MDASequence( + time_plan=useq.TIntervalLoops(interval=0.2, loops=2), + z_plan=useq.ZRangeAround(range=1, step=0.5), + channels=["DAPI", "FITC"], + ) + assert list(v1_seq) == list(seq) + + +def serialize_mda_sequence() -> None: + assert isinstance(MDASequence.model_json_schema(), str) + seq = MDASequence( + time_plan=TIntervalLoops(interval=0.2, loops=2), + z_plan=ZRangeAround(range=1, step=0.5), + channels=["DAPI", "FITC"], + ) + assert isinstance(seq.model_dump_json(), str) + assert isinstance(seq.model_dump(mode="json"), dict) + + +@pytest.mark.filterwarnings("ignore:.*ill-defined:FutureWarning") +def test_basic_properties() -> None: + seq = MDASequence( + time_plan=TIntervalLoops(interval=0.2, loops=2), + z_plan=ZRangeAround(range=1, step=0.5), + stage_positions=[(0, 0)], + channels=["DAPI", "FITC"], + axis_order=("t", "c", "z"), + ) + assert seq.time_plan is not None + assert seq.channels is not None + assert seq.z_plan is not None + assert seq.stage_positions is not None + assert seq.grid_plan is None + assert seq.shape + assert seq.sizes + assert seq.used_axes == ("t", "c", "z") diff --git a/tests/v2/test_multidim_seq.py b/tests/v2/test_multidim_seq.py new file mode 100644 index 00000000..0df3db48 --- /dev/null +++ b/tests/v2/test_multidim_seq.py @@ -0,0 +1,318 @@ +from __future__ import annotations + +from itertools import count +from typing import TYPE_CHECKING, Any + +from pydantic import Field + +from useq.v2 import Axis, AxisIterable, MultiAxisSequence, SimpleValueAxis + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator + + from useq.v2._axes_iterator import AxesIndex + + +def _index_and_values( + multi_dim: MultiAxisSequence, + axis_order: tuple[str, ...] | None = None, + max_iters: int | None = None, +) -> list[dict[str, tuple[int, Any]]]: + """Return a list of indices and values for each axis in the MultiDimSequence.""" + result = [] + for i, axes_with_context in enumerate(multi_dim.iter_axes(axis_order=axis_order)): + if max_iters is not None and i >= max_iters: + break + # Extract the AxesIndex from the tuple (AxesIndex, context) + indices, _context = axes_with_context + # cleaned version that drops the axis objects. + result.append({k: (idx, val) for k, (idx, val, _) in indices.items()}) + return result + + +def test_new_multidim_simple_seq() -> None: + multi_dim = MultiAxisSequence[Any]( + axes=( + SimpleValueAxis(axis_key=Axis.TIME, values=[0, 1]), + SimpleValueAxis(axis_key=Axis.CHANNEL, values=["red", "green", "blue"]), + SimpleValueAxis(axis_key=Axis.Z, values=[0.1, 0.3]), + ) + ) + assert multi_dim.is_finite() + + result = _index_and_values(multi_dim) + assert result == [ + {"t": (0, 0), "c": (0, "red"), "z": (0, 0.1)}, + {"t": (0, 0), "c": (0, "red"), "z": (1, 0.3)}, + {"t": (0, 0), "c": (1, "green"), "z": (0, 0.1)}, + {"t": (0, 0), "c": (1, "green"), "z": (1, 0.3)}, + {"t": (0, 0), "c": (2, "blue"), "z": (0, 0.1)}, + {"t": (0, 0), "c": (2, "blue"), "z": (1, 0.3)}, + {"t": (1, 1), "c": (0, "red"), "z": (0, 0.1)}, + {"t": (1, 1), "c": (0, "red"), "z": (1, 0.3)}, + {"t": (1, 1), "c": (1, "green"), "z": (0, 0.1)}, + {"t": (1, 1), "c": (1, "green"), "z": (1, 0.3)}, + {"t": (1, 1), "c": (2, "blue"), "z": (0, 0.1)}, + {"t": (1, 1), "c": (2, "blue"), "z": (1, 0.3)}, + ] + + +class InfiniteAxis(AxisIterable[int]): + axis_key: str = "i" + + def model_post_init(self, _ctx: Any) -> None: + self._counter = count() + + def __iter__(self) -> Iterator[int]: + yield from self._counter + + +def test_multidim_nested_seq() -> None: + inner_seq = MultiAxisSequence[Any]( + value=1, axes=(SimpleValueAxis(axis_key="q", values=["a", "b"]),) + ) + outer_seq = MultiAxisSequence[Any]( + axes=( + SimpleValueAxis(axis_key="t", values=[0, inner_seq, 2]), + SimpleValueAxis(axis_key="c", values=["red", "green", "blue"]), + ) + ) + + assert outer_seq.is_finite() + + result = _index_and_values(outer_seq) + assert result == [ + {"t": (0, 0), "c": (0, "red")}, + {"t": (0, 0), "c": (1, "green")}, + {"t": (0, 0), "c": (2, "blue")}, + {"t": (1, 1), "c": (0, "red"), "q": (0, "a")}, + {"t": (1, 1), "c": (0, "red"), "q": (1, "b")}, + {"t": (1, 1), "c": (1, "green"), "q": (0, "a")}, + {"t": (1, 1), "c": (1, "green"), "q": (1, "b")}, + {"t": (1, 1), "c": (2, "blue"), "q": (0, "a")}, + {"t": (1, 1), "c": (2, "blue"), "q": (1, "b")}, + {"t": (2, 2), "c": (0, "red")}, + {"t": (2, 2), "c": (1, "green")}, + {"t": (2, 2), "c": (2, "blue")}, + ] + + result = _index_and_values(outer_seq, axis_order=("t", "c")) + assert result == [ + {"t": (0, 0), "c": (0, "red")}, + {"t": (0, 0), "c": (1, "green")}, + {"t": (0, 0), "c": (2, "blue")}, + {"t": (1, 1), "c": (0, "red")}, + {"t": (1, 1), "c": (1, "green")}, + {"t": (1, 1), "c": (2, "blue")}, + {"t": (2, 2), "c": (0, "red")}, + {"t": (2, 2), "c": (1, "green")}, + {"t": (2, 2), "c": (2, "blue")}, + ] + + +def test_override_parent_axes() -> None: + inner_seq = MultiAxisSequence( + value=1, + axes=( + SimpleValueAxis(axis_key="c", values=["red", "blue"]), + SimpleValueAxis(axis_key="z", values=[7, 8, 9]), + ), + ) + multi_dim = MultiAxisSequence( + axes=( + SimpleValueAxis(axis_key="t", values=[0, inner_seq, 2]), + SimpleValueAxis(axis_key="c", values=["red", "green", "blue"]), + SimpleValueAxis(axis_key="z", values=[0.1, 0.2]), + ), + axis_order=("t", "c", "z"), + ) + + assert multi_dim.is_finite() + result = _index_and_values(multi_dim) + assert result == [ + {"t": (0, 0), "c": (0, "red"), "z": (0, 0.1)}, + {"t": (0, 0), "c": (0, "red"), "z": (1, 0.2)}, + {"t": (0, 0), "c": (1, "green"), "z": (0, 0.1)}, + {"t": (0, 0), "c": (1, "green"), "z": (1, 0.2)}, + {"t": (0, 0), "c": (2, "blue"), "z": (0, 0.1)}, + {"t": (0, 0), "c": (2, "blue"), "z": (1, 0.2)}, + {"t": (1, 1), "c": (0, "red"), "z": (0, 7)}, + {"t": (1, 1), "c": (0, "red"), "z": (1, 8)}, + {"t": (1, 1), "c": (0, "red"), "z": (2, 9)}, + {"t": (1, 1), "c": (1, "blue"), "z": (0, 7)}, + {"t": (1, 1), "c": (1, "blue"), "z": (1, 8)}, + {"t": (1, 1), "c": (1, "blue"), "z": (2, 9)}, + {"t": (2, 2), "c": (0, "red"), "z": (0, 0.1)}, + {"t": (2, 2), "c": (0, "red"), "z": (1, 0.2)}, + {"t": (2, 2), "c": (1, "green"), "z": (0, 0.1)}, + {"t": (2, 2), "c": (1, "green"), "z": (1, 0.2)}, + {"t": (2, 2), "c": (2, "blue"), "z": (0, 0.1)}, + {"t": (2, 2), "c": (2, "blue"), "z": (1, 0.2)}, + ] + + +class FilteredZ(SimpleValueAxis): + def __init__(self, values: Iterable) -> None: + super().__init__(axis_key=Axis.Z, values=values) + + def should_skip(self, prefix: AxesIndex) -> bool: + # If c is green, then only allow combinations where z equals 0.2. + c_val = prefix.get(Axis.CHANNEL, (None, None))[1] + z_val = prefix.get(Axis.Z, (None, None))[1] + return bool(c_val == "green" and z_val != 0.2) + + +def test_multidim_with_should_skip() -> None: + multi_dim = MultiAxisSequence( + axes=( + SimpleValueAxis(axis_key=Axis.TIME, values=[0, 1, 2]), + SimpleValueAxis(axis_key=Axis.CHANNEL, values=["red", "green", "blue"]), + FilteredZ([0.1, 0.2, 0.3]), + ), + axis_order=(Axis.TIME, Axis.CHANNEL, Axis.Z), + ) + + assert multi_dim.is_finite() + result = _index_and_values(multi_dim) + + # If c is green, then only allow combinations where z equals 0.2. + assert not any( + item["c"][1] == "green" and item["z"][1] != 0.2 for item in result + ), "FilteredZ should have filtered out green z!=0.2 combinations" + + assert result == [ + {"t": (0, 0), "c": (0, "red"), "z": (0, 0.1)}, + {"t": (0, 0), "c": (0, "red"), "z": (1, 0.2)}, + {"t": (0, 0), "c": (0, "red"), "z": (2, 0.3)}, + {"t": (0, 0), "c": (1, "green"), "z": (1, 0.2)}, + {"t": (0, 0), "c": (2, "blue"), "z": (0, 0.1)}, + {"t": (0, 0), "c": (2, "blue"), "z": (1, 0.2)}, + {"t": (0, 0), "c": (2, "blue"), "z": (2, 0.3)}, + {"t": (1, 1), "c": (0, "red"), "z": (0, 0.1)}, + {"t": (1, 1), "c": (0, "red"), "z": (1, 0.2)}, + {"t": (1, 1), "c": (0, "red"), "z": (2, 0.3)}, + {"t": (1, 1), "c": (1, "green"), "z": (1, 0.2)}, + {"t": (1, 1), "c": (2, "blue"), "z": (0, 0.1)}, + {"t": (1, 1), "c": (2, "blue"), "z": (1, 0.2)}, + {"t": (1, 1), "c": (2, "blue"), "z": (2, 0.3)}, + {"t": (2, 2), "c": (0, "red"), "z": (0, 0.1)}, + {"t": (2, 2), "c": (0, "red"), "z": (1, 0.2)}, + {"t": (2, 2), "c": (0, "red"), "z": (2, 0.3)}, + {"t": (2, 2), "c": (1, "green"), "z": (1, 0.2)}, + {"t": (2, 2), "c": (2, "blue"), "z": (0, 0.1)}, + {"t": (2, 2), "c": (2, "blue"), "z": (1, 0.2)}, + {"t": (2, 2), "c": (2, "blue"), "z": (2, 0.3)}, + ] + + +def test_all_together() -> None: + t1_overrides = MultiAxisSequence( + value=1, + axes=( + SimpleValueAxis(axis_key="c", values=["red", "blue"]), + SimpleValueAxis(axis_key="z", values=[7, 8, 9]), + ), + ) + c_blue_subseq = MultiAxisSequence( + value="blue", + axes=(SimpleValueAxis(axis_key="q", values=["a", "b"]),), + ) + multi_dim = MultiAxisSequence( + axes=( + SimpleValueAxis(axis_key="t", values=[0, t1_overrides, 2]), + SimpleValueAxis(axis_key="c", values=["red", "green", c_blue_subseq]), + FilteredZ([0.1, 0.2, 0.3]), + ), + ) + + assert multi_dim.is_finite() + result = _index_and_values(multi_dim) + assert result == [ + {"t": (0, 0), "c": (0, "red"), "z": (0, 0.1)}, + {"t": (0, 0), "c": (0, "red"), "z": (1, 0.2)}, + {"t": (0, 0), "c": (0, "red"), "z": (2, 0.3)}, + {"t": (0, 0), "c": (1, "green"), "z": (1, 0.2)}, + {"t": (0, 0), "c": (2, "blue"), "z": (0, 0.1), "q": (0, "a")}, + {"t": (0, 0), "c": (2, "blue"), "z": (0, 0.1), "q": (1, "b")}, + {"t": (0, 0), "c": (2, "blue"), "z": (1, 0.2), "q": (0, "a")}, + {"t": (0, 0), "c": (2, "blue"), "z": (1, 0.2), "q": (1, "b")}, + {"t": (0, 0), "c": (2, "blue"), "z": (2, 0.3), "q": (0, "a")}, + {"t": (0, 0), "c": (2, "blue"), "z": (2, 0.3), "q": (1, "b")}, + {"t": (1, 1), "c": (0, "red"), "z": (0, 7)}, + {"t": (1, 1), "c": (0, "red"), "z": (1, 8)}, + {"t": (1, 1), "c": (0, "red"), "z": (2, 9)}, + {"t": (1, 1), "c": (1, "blue"), "z": (0, 7)}, + {"t": (1, 1), "c": (1, "blue"), "z": (1, 8)}, + {"t": (1, 1), "c": (1, "blue"), "z": (2, 9)}, + {"t": (2, 2), "c": (0, "red"), "z": (0, 0.1)}, + {"t": (2, 2), "c": (0, "red"), "z": (1, 0.2)}, + {"t": (2, 2), "c": (0, "red"), "z": (2, 0.3)}, + {"t": (2, 2), "c": (1, "green"), "z": (1, 0.2)}, + {"t": (2, 2), "c": (2, "blue"), "z": (0, 0.1), "q": (0, "a")}, + {"t": (2, 2), "c": (2, "blue"), "z": (0, 0.1), "q": (1, "b")}, + {"t": (2, 2), "c": (2, "blue"), "z": (1, 0.2), "q": (0, "a")}, + {"t": (2, 2), "c": (2, "blue"), "z": (1, 0.2), "q": (1, "b")}, + {"t": (2, 2), "c": (2, "blue"), "z": (2, 0.3), "q": (0, "a")}, + {"t": (2, 2), "c": (2, "blue"), "z": (2, 0.3), "q": (1, "b")}, + ] + + +def test_new_multidim_with_infinite_axis() -> None: + # note... we never progress to t=1 + multi_dim = MultiAxisSequence( + axes=( + SimpleValueAxis(axis_key=Axis.TIME, values=[0, 1]), + InfiniteAxis(), + SimpleValueAxis(axis_key=Axis.Z, values=[0.1, 0.3]), + ) + ) + + assert not multi_dim.is_finite() + result = _index_and_values(multi_dim, max_iters=10) + assert result == [ + {"t": (0, 0), "i": (0, 0), "z": (0, 0.1)}, + {"t": (0, 0), "i": (0, 0), "z": (1, 0.3)}, + {"t": (0, 0), "i": (1, 1), "z": (0, 0.1)}, + {"t": (0, 0), "i": (1, 1), "z": (1, 0.3)}, + {"t": (0, 0), "i": (2, 2), "z": (0, 0.1)}, + {"t": (0, 0), "i": (2, 2), "z": (1, 0.3)}, + {"t": (0, 0), "i": (3, 3), "z": (0, 0.1)}, + {"t": (0, 0), "i": (3, 3), "z": (1, 0.3)}, + {"t": (0, 0), "i": (4, 4), "z": (0, 0.1)}, + {"t": (0, 0), "i": (4, 4), "z": (1, 0.3)}, + ] + + +class DynamicROIAxis(SimpleValueAxis[str]): + axis_key: str = "r" + values: list[str] = Field(default_factory=lambda: ["cell0", "cell1"]) + + # we add a new roi at each time step + def __iter__(self) -> Iterator[str]: + yield from self.values + self.values.append(f"cell{len(self.values)}") + + +def test_dynamic_roi_addition() -> None: + multi_dim = MultiAxisSequence(axes=(InfiniteAxis(), DynamicROIAxis())) + + assert not multi_dim.is_finite() + result = _index_and_values(multi_dim, max_iters=16) + assert result == [ + {"i": (0, 0), "r": (0, "cell0")}, + {"i": (0, 0), "r": (1, "cell1")}, + {"i": (1, 1), "r": (0, "cell0")}, + {"i": (1, 1), "r": (1, "cell1")}, + {"i": (1, 1), "r": (2, "cell2")}, + {"i": (2, 2), "r": (0, "cell0")}, + {"i": (2, 2), "r": (1, "cell1")}, + {"i": (2, 2), "r": (2, "cell2")}, + {"i": (2, 2), "r": (3, "cell3")}, + {"i": (3, 3), "r": (0, "cell0")}, + {"i": (3, 3), "r": (1, "cell1")}, + {"i": (3, 3), "r": (2, "cell2")}, + {"i": (3, 3), "r": (3, "cell3")}, + {"i": (3, 3), "r": (4, "cell4")}, + {"i": (4, 4), "r": (0, "cell0")}, + {"i": (4, 4), "r": (1, "cell1")}, + ] diff --git a/tests/v2/test_time.py b/tests/v2/test_time.py new file mode 100644 index 00000000..3708e442 --- /dev/null +++ b/tests/v2/test_time.py @@ -0,0 +1,418 @@ +"""Tests for the time module in useq.v2.""" + +from __future__ import annotations + +from datetime import timedelta + +import pytest + +from useq.v2 import ( + AnyTimePlan, + MultiPhaseTimePlan, + SinglePhaseTimePlan, + TDurationLoops, + TimePlan, + TIntervalDuration, + TIntervalLoops, +) + + +class TestTIntervalLoops: + """Test TIntervalLoops time plan.""" + + def test_basic_creation(self) -> None: + """Test basic creation and properties.""" + plan = TIntervalLoops(interval=timedelta(seconds=2), loops=5) + + assert plan.interval == timedelta(seconds=2) + assert plan.loops == 5 + assert plan.axis_key == "t" + assert len(plan) == 5 + assert plan.duration == timedelta(seconds=8) # (5-1) * 2 + + def test_interval_from_dict(self) -> None: + """Test creating interval from dict.""" + plan = TIntervalLoops(interval={"seconds": 3}, loops=3) + assert plan.interval == timedelta(seconds=3) + + def test_interval_from_float(self) -> None: + """Test creating interval from float (seconds).""" + plan = TIntervalLoops(interval=1.5, loops=4) + assert plan.interval == timedelta(seconds=1.5) + + def test_iteration(self) -> None: + """Test iterating over time values.""" + plan = TIntervalLoops(interval=timedelta(seconds=2), loops=3) + times = list(plan) + + assert times == [0.0, 2.0, 4.0] + + def test_zero_loops_invalid(self) -> None: + """Test that zero loops raises validation error.""" + with pytest.raises(ValueError, match="greater than 0"): + TIntervalLoops(interval=timedelta(seconds=1), loops=0) + + def test_negative_loops_invalid(self) -> None: + """Test that negative loops raises validation error.""" + with pytest.raises(ValueError, match="greater than 0"): + TIntervalLoops(interval=timedelta(seconds=1), loops=-1) + + def test_interval_s_method(self) -> None: + """Test _interval_s private method.""" + plan = TIntervalLoops(interval=timedelta(seconds=2.5), loops=3) + assert plan.interval.total_seconds() == 2.5 + + +class TestTDurationLoops: + """Test TDurationLoops time plan.""" + + def test_basic_creation(self) -> None: + """Test basic creation and properties.""" + plan = TDurationLoops(duration=timedelta(seconds=10), loops=6) + + assert plan.duration == timedelta(seconds=10) + assert plan.loops == 6 + assert len(plan) == 6 + assert plan.interval == timedelta(seconds=2) # 10 / (6-1) + + def test_duration_from_dict(self) -> None: + """Test creating duration from dict.""" + plan = TDurationLoops(duration={"minutes": 1}, loops=4) + assert plan.duration == timedelta(minutes=1) + + def test_iteration(self) -> None: + """Test iterating over time values.""" + plan = TDurationLoops(duration=timedelta(seconds=6), loops=4) + times = list(plan) + + # Should be evenly spaced over 6 seconds: 0, 2, 4, 6 + assert times == [0.0, 2.0, 4.0, 6.0] + + def test_single_loop(self) -> None: + """Test behavior with single loop.""" + plan = TDurationLoops(duration=timedelta(seconds=5), loops=1) + times = list(plan) + + # With 1 loop, interval would be 5/0 which would cause issues + # But the implementation should handle this gracefully + assert len(times) == 1 + assert times[0] == 0.0 + + def test_interval_s_method(self) -> None: + """Test _interval_s private method.""" + plan = TDurationLoops(duration=timedelta(seconds=8), loops=5) + assert plan.interval.total_seconds() == 2.0 # 8 / (5-1) + + +class TestTIntervalDuration: + """Test TIntervalDuration time plan.""" + + def test_basic_creation_finite(self) -> None: + """Test creation with finite duration.""" + plan = TIntervalDuration( + interval=timedelta(seconds=2), duration=timedelta(seconds=10) + ) + + assert plan.interval == timedelta(seconds=2) + assert plan.duration == timedelta(seconds=10) + assert plan.prioritize_duration is True # default + + def test_basic_creation_infinite(self) -> None: + """Test creation with infinite duration.""" + plan = TIntervalDuration(interval=timedelta(seconds=1), duration=None) + + assert plan.interval == timedelta(seconds=1) + assert plan.duration is None + assert plan.prioritize_duration is True + + def test_finite_iteration(self) -> None: + """Test iteration with finite duration.""" + plan = TIntervalDuration( + interval=timedelta(seconds=2), duration=timedelta(seconds=5) + ) + times = list(plan) + + # Should yield: 0, 2, 4 (stops before 6 which exceeds duration) + assert times == [0.0, 2.0, 4.0] + + def test_infinite_iteration_limited(self) -> None: + """Test that infinite iteration can be limited.""" + plan = TIntervalDuration(interval=timedelta(seconds=1), duration=None) + iterator = iter(plan) + + # Take first few values to test infinite sequence + times = [next(iterator) for _ in range(5)] + assert times == [0.0, 1.0, 2.0, 3.0, 4.0] + + def test_duration_from_dict(self) -> None: + """Test creating duration from dict.""" + plan = TIntervalDuration(interval={"seconds": 1}, duration={"minutes": 2}) + assert plan.duration == timedelta(minutes=2) + + def test_prioritize_duration_false(self) -> None: + """Test setting prioritize_duration to False.""" + plan = TIntervalDuration( + interval=timedelta(seconds=1), + duration=timedelta(seconds=5), + prioritize_duration=False, + ) + assert plan.prioritize_duration is False + + def test_interval_s_method(self) -> None: + """Test _interval_s private method.""" + plan = TIntervalDuration( + interval=timedelta(seconds=1.5), duration=timedelta(seconds=5) + ) + assert plan.interval.total_seconds() == 1.5 + + def test_exact_duration_boundary(self) -> None: + """Test behavior when time exactly equals duration.""" + plan = TIntervalDuration( + interval=timedelta(seconds=2), duration=timedelta(seconds=4) + ) + times = list(plan) + + # Should include exactly 4.0 since condition is t <= duration + assert times == [0.0, 2.0, 4.0] + + +class TestMultiPhaseTimePlan: + """Test MultiPhaseTimePlan.""" + + def test_basic_creation(self) -> None: + """Test basic creation with multiple phases.""" + phase1 = TIntervalLoops(interval=timedelta(seconds=1), loops=3) + phase2 = TIntervalLoops(interval=timedelta(seconds=2), loops=2) + + plan = MultiPhaseTimePlan(phases=[phase1, phase2]) + assert len(plan.phases) == 2 + + def test_iteration_multiple_finite_phases(self) -> None: + """Test iteration over multiple finite phases.""" + phase1 = TIntervalLoops(interval=timedelta(seconds=1), loops=3) + phase2 = TIntervalLoops(interval=timedelta(seconds=2), loops=2) + + plan = MultiPhaseTimePlan(phases=[phase1, phase2]) + times = list(plan) + + assert times == [0.0, 1.0, 2.0, 4.0] + + def test_iteration_mixed_phases(self) -> None: + """Test iteration with different phase types.""" + phase1 = TDurationLoops(duration=timedelta(seconds=4), loops=3) + phase2 = TIntervalLoops(interval=timedelta(seconds=1), loops=2) + + plan = MultiPhaseTimePlan(phases=[phase1, phase2]) + times = list(plan) + + assert times == [0.0, 2.0, 4.0, 5.0] + + def test_send_skip_phase(self) -> None: + """Test using send(True) to skip to next phase.""" + phase1 = TIntervalLoops(interval=timedelta(seconds=1), loops=5) + phase2 = TIntervalLoops(interval=timedelta(seconds=2), loops=2) + + plan = MultiPhaseTimePlan(phases=[phase1, phase2]) + iterator = iter(plan) + + # Start iteration + assert next(iterator) == 0.0 + assert next(iterator) == 1.0 + + # Force skip to next phase + try: + value = iterator.send(True) + # Should start phase 2 at offset of phase 1's duration (4 seconds) + assert value == 6.0 # phase 2, time 0 + except StopIteration: + # If send causes StopIteration, get next value + assert next(iterator) == 4.0 + + def test_infinite_phase_handling(self) -> None: + """Test handling of infinite phases.""" + phase1 = TIntervalLoops(interval=timedelta(seconds=1), loops=2) + phase2 = TIntervalDuration(interval=timedelta(seconds=1), duration=None) + + plan = MultiPhaseTimePlan(phases=[phase1, phase2]) + iterator = iter(plan) + + # Get first phase values + # Should get 0, 1, 1 (start of phase 2) + times = [next(iterator) for _ in range(3)] + + # Phase 1 ends after 1 second, so phase 2 starts with offset 1 + assert times[:2] == [0.0, 1.0] + assert times[2] == 2.0 # Start of infinite phase 2 + + def test_empty_phases(self) -> None: + """Test behavior with empty phases list.""" + plan = MultiPhaseTimePlan(phases=[]) + times = list(plan) + assert times == [] + + def test_single_phase(self) -> None: + """Test behavior with single phase.""" + phase = TIntervalLoops(interval=timedelta(seconds=2), loops=3) + plan = MultiPhaseTimePlan(phases=[phase]) + + times = list(plan) + assert times == [0.0, 2.0, 4.0] + + +class TestTimePlanAbstract: + """Test abstract TimePlan behavior.""" + + def test_axis_key_default(self) -> None: + """Test that axis_key defaults to 't'.""" + plan = TIntervalLoops(interval=timedelta(seconds=1), loops=2) + assert plan.axis_key == "t" + + def test_prioritize_duration_default(self) -> None: + """Test prioritize_duration defaults.""" + plan1 = TIntervalLoops(interval=timedelta(seconds=1), loops=2) + assert plan1.prioritize_duration is False + + plan2 = TIntervalDuration( + interval=timedelta(seconds=1), duration=timedelta(seconds=5) + ) + assert plan2.prioritize_duration is True + + +class TestTypeAliases: + """Test type aliases work correctly.""" + + def test_single_phase_time_plan_types(self) -> None: + """Test that SinglePhaseTimePlan accepts all expected types.""" + plans: list[SinglePhaseTimePlan] = [ + TIntervalDuration( + interval=timedelta(seconds=1), duration=timedelta(seconds=5) + ), + TIntervalLoops(interval=timedelta(seconds=1), loops=3), + TDurationLoops(duration=timedelta(seconds=6), loops=4), + ] + + for plan in plans: + assert isinstance(plan, TimePlan) + + def test_any_time_plan_types(self) -> None: + """Test that AnyTimePlan accepts all expected types.""" + phase = TIntervalLoops(interval=timedelta(seconds=1), loops=2) + + plans: list[AnyTimePlan] = [ + TIntervalDuration( + interval=timedelta(seconds=1), duration=timedelta(seconds=5) + ), + TIntervalLoops(interval=timedelta(seconds=1), loops=3), + TDurationLoops(duration=timedelta(seconds=6), loops=4), + MultiPhaseTimePlan(phases=[phase]), + ] + + for plan in plans: + assert isinstance(plan, TimePlan) + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_very_small_intervals(self) -> None: + """Test behavior with very small time intervals.""" + plan = TIntervalLoops(interval=timedelta(microseconds=1), loops=3) + times = list(plan) + + expected = [0.0, 0.000001, 0.000002] + assert len(times) == 3 + for actual, exp in zip(times, expected): + assert abs(actual - exp) < 1e-9 + + def test_large_number_of_loops(self) -> None: + """Test with large number of loops.""" + plan = TIntervalLoops(interval=timedelta(seconds=1), loops=1000) + assert len(plan) == 1000 + + # Test first and last few values + iterator = iter(plan) + assert next(iterator) == 0.0 + assert next(iterator) == 1.0 + + # Skip to end + times = list(iterator) + assert times[-1] == 999.0 + + def test_zero_interval_duration_plan(self) -> None: + """Test TIntervalDuration with zero interval.""" + plan = TIntervalDuration( + interval=timedelta(seconds=0), duration=timedelta(seconds=1) + ) + # This should theoretically create an infinite loop at t=0 + # Implementation should handle this gracefully + iterator = iter(plan) + first_few = [next(iterator) for _ in range(3)] + assert all(t == 0.0 for t in first_few) + + def test_negative_duration_loops(self) -> None: + """Test that negative duration raises appropriate error.""" + with pytest.raises(ValueError): + TDurationLoops(duration=timedelta(seconds=-5), loops=3) + + def test_duration_loops_with_one_loop_edge_case(self) -> None: + """Test duration loops with exactly one loop.""" + plan = TDurationLoops(duration=timedelta(seconds=10), loops=1) + times = list(plan) + + # With 1 loop, we expect just [0.0] + assert times == [0.0] + # With 1 loop, interval is meaningless and returns zero + assert plan.interval.total_seconds() == 0.0 + # But _interval_s returns infinity to indicate instantaneous + assert plan.interval.total_seconds() == 0 + + +@pytest.mark.parametrize( + "plan_class,kwargs", + [ + (TIntervalLoops, {"interval": timedelta(seconds=1), "loops": 3}), + (TDurationLoops, {"duration": timedelta(seconds=6), "loops": 4}), + ( + TIntervalDuration, + {"interval": timedelta(seconds=2), "duration": timedelta(seconds=10)}, + ), + ], +) +def test_time_plan_serialization(plan_class: type[TimePlan], kwargs: dict) -> None: + """Test that time plans can be serialized and deserialized.""" + plan = plan_class(**kwargs) + + # Test model dump/load cycle + data = plan.model_dump_json() + restored = plan_class.model_validate_json(data) + + assert restored == plan + assert list(restored) == list(plan) + + +def test_integration_with_mda_axis_iterable() -> None: + """Test that time plans integrate properly with MDAAxisIterable.""" + plan = TIntervalLoops(interval=timedelta(seconds=2), loops=3) + + # Should have MDAAxisIterable methods + assert hasattr(plan, "axis_key") + + # Test the axis_key + assert plan.axis_key == "t" + + # Test iteration returns float values + values = list(plan) + assert all(isinstance(v, float) for v in values) + + +def test_contribute_to_mda_event() -> None: + """Test that time plans can contribute to MDA events.""" + plan = TIntervalLoops(interval=timedelta(seconds=2), loops=3) + + # Test contribution + contribution = plan.contribute_event_kwargs(4.0, {"t": 2}) + assert contribution == {"min_start_time": 4.0} + + # Test with different value + contribution = plan.contribute_event_kwargs(0.0, {"t": 0}) + assert contribution == {"min_start_time": 0.0} diff --git a/tests/v2/test_z.py b/tests/v2/test_z.py new file mode 100644 index 00000000..eaf1311b --- /dev/null +++ b/tests/v2/test_z.py @@ -0,0 +1,308 @@ +"""Tests for v2 Z plans module.""" + +from __future__ import annotations + +import pytest + +from useq.v2 import ( + Axis, + MDAEvent, + Position, + ZAboveBelow, + ZAbsolutePositions, + ZPlan, + ZRangeAround, + ZRelativePositions, + ZTopBottom, +) + + +class TestZTopBottom: + """Test ZTopBottom plan.""" + + def test_basic_creation(self) -> None: + """Test basic creation and attributes.""" + plan = ZTopBottom(top=10.0, bottom=0.0, step=2.0) + assert plan.top == 10.0 + assert plan.bottom == 0.0 + assert plan.step == 2.0 + assert plan.go_up is True + assert plan.axis_key == Axis.Z + + def test_positions_go_up(self) -> None: + """Test positions when go_up is True.""" + plan = ZTopBottom(top=4.0, bottom=0.0, step=1.0, go_up=True) + positions = [p.z for p in plan] + expected = [0.0, 1.0, 2.0, 3.0, 4.0] + assert positions == expected + + def test_positions_go_down(self) -> None: + """Test positions when go_up is False.""" + plan = ZTopBottom(top=4.0, bottom=0.0, step=1.0, go_up=False) + positions = [p.z for p in plan] + expected = [4.0, 3.0, 2.0, 1.0, 0.0] + assert positions == expected + + def test_is_relative(self) -> None: + """Test is_relative property.""" + plan = ZTopBottom(top=4.0, bottom=0.0, step=1.0) + assert plan.is_relative is False + + def test_num_positions(self) -> None: + """Test num_positions method.""" + plan = ZTopBottom(top=4.0, bottom=0.0, step=1.0) + assert len(plan) == 5 + + def test_start_stop_step(self) -> None: + """Test _start_stop_step method.""" + plan = ZTopBottom(top=10.0, bottom=2.0, step=1.5) + start, stop, step = plan._start_stop_step() + assert start == 2.0 + assert stop == 10.0 + assert step == 1.5 + + def test_contribute_to_mda_event(self) -> None: + """Test contribute_to_mda_event method.""" + plan = ZTopBottom(top=10.0, bottom=0.0, step=2.0) + contribution = plan.contribute_event_kwargs(Position(z=5.0), {"z": 2}) + assert contribution == {"z_pos": 5.0} + + +class TestZRangeAround: + """Test ZRangeAround plan.""" + + def test_basic_creation(self) -> None: + """Test basic creation and attributes.""" + plan = ZRangeAround(range=4.0, step=1.0) + assert plan.range == 4.0 + assert plan.step == 1.0 + assert plan.go_up is True + assert plan.axis_key == Axis.Z + + def test_positions_symmetric(self) -> None: + """Test symmetric positions around center.""" + plan = ZRangeAround(range=4.0, step=1.0, go_up=True) + positions = [p.z for p in plan] + expected = [-2.0, -1.0, 0.0, 1.0, 2.0] + assert positions == expected + + def test_start_stop_step(self) -> None: + """Test _start_stop_step method.""" + plan = ZRangeAround(range=6.0, step=1.5) + start, stop, step = plan._start_stop_step() + assert start == -3.0 + assert stop == 3.0 + assert step == 1.5 + + def test_is_relative(self) -> None: + """Test is_relative property.""" + plan = ZRangeAround(range=4.0, step=1.0) + assert plan.is_relative is True + + +class TestZAboveBelow: + """Test ZAboveBelow plan.""" + + def test_basic_creation(self) -> None: + """Test basic creation and attributes.""" + plan = ZAboveBelow(above=3.0, below=2.0, step=1.0) + assert plan.above == 3.0 + assert plan.below == 2.0 + assert plan.step == 1.0 + assert plan.axis_key == Axis.Z + + def test_positions_asymmetric(self) -> None: + """Test asymmetric positions.""" + plan = ZAboveBelow(above=3.0, below=2.0, step=1.0, go_up=True) + positions = [p.z for p in plan] + expected = [-2.0, -1.0, 0.0, 1.0, 2.0, 3.0] + assert positions == expected + + def test_start_stop_step(self) -> None: + """Test _start_stop_step method.""" + plan = ZAboveBelow(above=4.0, below=3.0, step=0.5) + start, stop, step = plan._start_stop_step() + assert start == -3.0 + assert stop == 4.0 + assert step == 0.5 + + def test_negative_values(self) -> None: + """Test with negative input values (should be made absolute).""" + plan = ZAboveBelow(above=-2.0, below=-3.0, step=1.0) + start, stop, step = plan._start_stop_step() + assert start == -3.0 # abs(-3.0) = 3.0, then -3.0 + assert stop == 2.0 # abs(-2.0) = 2.0, then +2.0 + assert step == 1.0 + + +class TestZRelativePositions: + """Test ZRelativePositions plan.""" + + def test_basic_creation(self) -> None: + """Test basic creation and attributes.""" + plan = ZRelativePositions(relative=[1.0, 2.0, 3.0]) + assert plan.relative == [1.0, 2.0, 3.0] + assert plan.axis_key == Axis.Z + assert len(plan) == 3 + assert plan.is_relative is True + + def test_list_cast_validator(self) -> None: + """Test that input is cast to list.""" + plan = ZRelativePositions(relative=(1.0, 2.0, 3.0)) # tuple input + assert plan.relative == [1.0, 2.0, 3.0] # should be cast to list + + +class TestZAbsolutePositions: + """Test ZAbsolutePositions plan.""" + + def test_basic_creation(self) -> None: + """Test basic creation and attributes.""" + plan = ZAbsolutePositions(absolute=[10.0, 20.0, 30.0]) + assert plan.absolute == [10.0, 20.0, 30.0] + assert plan.axis_key == Axis.Z + assert len(plan) == 3 + assert plan.is_relative is False + + def test_list_cast_validator(self) -> None: + """Test that input is cast to list.""" + plan = ZAbsolutePositions(absolute=(10.0, 20.0, 30.0)) # tuple input + assert plan.absolute == [10.0, 20.0, 30.0] # should be cast to list + + +class TestZPlanBase: + """Test ZPlan base class functionality.""" + + def test_axis_key_default(self) -> None: + """Test that axis_key defaults to 'z'.""" + plan = ZRelativePositions(relative=[1.0, 2.0]) + assert plan.axis_key == "z" + + def test_mda_axis_iterable_interface(self) -> None: + """Test that Z plans implement MDAAxisIterable interface.""" + plan = ZTopBottom(top=2.0, bottom=0.0, step=1.0) + + # Should have MDAAxisIterable methods + assert isinstance(plan.axis_key, str) + assert plan.contribute_event_kwargs(Position(), {}) is not None + + # Test iteration returns float values + values = [p.z for p in plan] + assert all(isinstance(v, float) for v in values) + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_zero_step_single_position(self) -> None: + """Test behavior with zero step size.""" + plan = ZTopBottom(top=5.0, bottom=5.0, step=0.0) + positions = [p.z for p in plan] + assert positions == [5.0] + assert len(plan) == 1 + + def test_very_small_steps(self) -> None: + """Test with very small step sizes.""" + plan = ZTopBottom(top=1.0, bottom=0.0, step=0.1) + positions = [p.z for p in plan] + assert len(positions) == 11 + assert positions[0] == pytest.approx(0.0) + assert positions[-1] == pytest.approx(1.0) + + def test_empty_position_lists(self) -> None: + """Test with empty position lists.""" + plan = ZRelativePositions(relative=[]) + positions = [p.z for p in plan] + assert positions == [] + assert len(plan) == 0 + + def test_single_position_lists(self) -> None: + """Test with single position in lists.""" + plan = ZAbsolutePositions(absolute=[42.0]) + positions = [p.z for p in plan] + assert positions == [42.0] + assert len(plan) == 1 + + def test_large_ranges(self) -> None: + """Test with large Z ranges.""" + plan = ZTopBottom(top=1000.0, bottom=0.0, step=100.0) + positions = [p.z for p in plan] + assert len(positions) == 11 + assert positions[0] == 0.0 + assert positions[-1] == 1000.0 + + +class TestSerialization: + """Test serialization and deserialization.""" + + @pytest.mark.parametrize( + "plan_class,kwargs", + [ + (ZTopBottom, {"top": 10.0, "bottom": 0.0, "step": 2.0}), + (ZRangeAround, {"range": 4.0, "step": 1.0}), + (ZAboveBelow, {"above": 3.0, "below": 2.0, "step": 1.0}), + (ZRelativePositions, {"relative": [1.0, 2.0, 3.0]}), + (ZAbsolutePositions, {"absolute": [10.0, 20.0, 30.0]}), + ], + ) + def test_z_plan_serialization(self, plan_class: type[ZPlan], kwargs: dict) -> None: + """Test that Z plans can be serialized and deserialized.""" + original_plan = plan_class(**kwargs) + + # Test JSON serialization round-trip + json_data = original_plan.model_dump_json() + restored_plan = plan_class.model_validate_json(json_data) + + # Should be equivalent + assert list(original_plan) == list(restored_plan) + assert original_plan.axis_key == restored_plan.axis_key + if hasattr(original_plan, "go_up"): + # Check go_up attribute if it exists + assert original_plan.go_up == restored_plan.go_up # type: ignore + + +class TestTypeAliases: + """Test type aliases and union types.""" + + def test_any_z_plan_types(self) -> None: + """Test that AnyZPlan includes all Z plan types.""" + plans = [ + ZTopBottom(top=10.0, bottom=0.0, step=2.0), + ZRangeAround(range=4.0, step=1.0), + ZAboveBelow(above=3.0, below=2.0, step=1.0), + ZRelativePositions(relative=[1.0, 2.0, 3.0]), + ZAbsolutePositions(absolute=[10.0, 20.0, 30.0]), + ] + + for plan in plans: + # Should be valid instances of AnyZPlan + assert isinstance(plan, ZPlan) + + +def test_contribute_to_mda_event_integration() -> None: + """Test integration with MDAEvent.Kwargs.""" + plan = ZTopBottom(top=10.0, bottom=0.0, step=5.0) + + # Test contribution + contribution = plan.contribute_event_kwargs(Position(z=7.5), {"z": 1}) + assert contribution == {"z_pos": 7.5} + + # Test that the contribution can be used to create an MDAEvent + event_data = {"index": {"z": 1}, **contribution} + event = MDAEvent(**event_data) + assert event.z_pos == 7.5 + + +def test_integration_with_mda_axis_iterable() -> None: + """Test that Z plans integrate properly with MDAAxisIterable.""" + plan = ZTopBottom(top=4.0, bottom=0.0, step=2.0) + + # Should have MDAAxisIterable methods + assert hasattr(plan, "axis_key") + + # Test the axis_key + assert plan.axis_key == Axis.Z + + # Test iteration returns float values + values = [p.z for p in plan] + assert all(isinstance(v, float) for v in values) + assert values == [0.0, 2.0, 4.0]