diff --git a/.gitignore b/.gitignore
index f5b3d71..8b2bbe7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,9 @@ __pycache__/
*.py[cod]
*$py.class
+AFMReader/data/*
+AFMReader/notebooks/*
+
# C extensions
*.so
diff --git a/AFMReader/asd.py b/AFMReader/asd.py
index 64727f5..c0da406 100644
--- a/AFMReader/asd.py
+++ b/AFMReader/asd.py
@@ -10,6 +10,7 @@
import numpy.typing as npt
from matplotlib import animation
+from AFMReader.data_classes import AFMLoad
from AFMReader.io import (
read_ascii,
read_bool,
@@ -182,7 +183,7 @@ def calculate_scaling_factor(
raise ValueError(f"channel {channel} not known for .asd file type.")
-def load_asd(file_path: str | Path, channel: str):
+def load_asd(file_path: str | Path, channel: str) -> AFMLoad:
"""
Load a .asd file.
@@ -196,17 +197,16 @@ def load_asd(file_path: str | Path, channel: str):
Returns
-------
- npt.NDArray
- The .asd file frames data as a numpy 3D array N x W x H
- (Number of frames x Width of each frame x height of each frame).
- float
- The number of nanometres per pixel for the .asd file. (AKA the resolution).
- Enables converting between pixels and nanometres when working with the data, in order to use real-world length
- scales.
- dict
- Metadata for the .asd file. The number of entries is too long to list here, and changes based on the file
- version please either look into the `read_header_file_version_x` functions or print the keys too see what
- metadata is available.
+ AFMLoad
+ An AFMLoad object containing:
+ - image : npt.NDArray
+ Shape (Number of frames x Width of each frame x height of each frame).
+ - px2nm : float
+ The number of nanometres per pixel for the .asd file.
+ - metadata : dict
+ Metadata for the .asd file. The number of entries is too long to list here, and changes based on the file
+ version please either look into the `read_header_file_version_x` functions or print the keys too see what
+ metadata is available.
"""
# Ensure the file path is a Path object
file_path = Path(file_path)
@@ -285,7 +285,39 @@ def load_asd(file_path: str | Path, channel: str):
frames = np.array(frames)
logger.info(f"[{filename}] : Extracted image.")
- return frames, pixel_to_nanometre_scaling_factor, header_dict
+ return AFMLoad(image=frames, px2nm=pixel_to_nanometre_scaling_factor, metadata=header_dict)
+
+
+def get_asd_channels(file_path: Path):
+ """
+ Get the channels available in given .asd file.
+
+ Parameters
+ ----------
+ file_path : Path
+ Path to the .asd file.
+
+ Returns
+ -------
+ list
+ List of channels available in the .asd file.
+ """
+ with Path.open(file_path, "rb", encoding=None) as open_file: # pylint: disable=unspecified-encoding
+ file_version = read_file_version(open_file)
+
+ if file_version == 0:
+ header_dict = read_header_file_version_0(open_file)
+
+ elif file_version == 1:
+ header_dict = read_header_file_version_1(open_file)
+
+ elif file_version == 2:
+ header_dict = read_header_file_version_2(open_file)
+ else:
+ raise ValueError(
+ f"File version {file_version} unknown. Please add support if you know how to decode this file version."
+ )
+ return [header_dict["channel1"], header_dict["channel2"]]
def read_file_version(open_file: BinaryIO) -> int:
diff --git a/AFMReader/data_classes.py b/AFMReader/data_classes.py
new file mode 100644
index 0000000..7e4f2d5
--- /dev/null
+++ b/AFMReader/data_classes.py
@@ -0,0 +1,326 @@
+"""
+Data classes for lazy loading of curve data and metadata from files.
+
+These classes provide a consistent interface for accessing curve data and metadata in a
+lazy manner (i.e. loading data on demand rather than all at once) across different file
+formats. This is necessary for handling large datasets with massive memory consumption.
+"""
+
+import numpy as np
+
+# pylint: disable=too-few-public-methods,fixme
+
+
+class CurvesMetadata:
+ """
+ A class representing the metadata for a dataset of curves, providing lazy loaded access to pixel metadata.
+
+ This is a parent class that should be subclassed for specific file formats to implement
+ the get_pixel_metadata method, which defines how the metadata is retrieved from the
+ underlying data source.
+
+ Parameters
+ ----------
+ toplevel : dict
+ A dictionary containing the top-level metadata for the dataset.
+ shape_x : int
+ The number of columns in the image.
+ shape_y : int
+ The number of rows in the image.
+ flip_image : bool, optional
+ Whether to flip the image vertically. Default is True.
+ """
+
+ def __init__(self, toplevel: dict, shape_x: int, shape_y: int, flip_image: bool = True):
+ """
+ Initialise CurvesMetadata.
+
+ Parameters
+ ----------
+ toplevel : dict
+ A dictionary containing the top-level metadata for the dataset.
+ shape_x : int
+ The number of columns in the image.
+ shape_y : int
+ The number of rows in the image.
+ flip_image : bool, optional
+ Whether to flip the image vertically. Default is True.
+ """
+ self.toplevel = toplevel
+ self.shape_x = shape_x
+ self.shape_y = shape_y
+ self.flip_image = flip_image
+
+ def __getitem__(self, keys):
+ """
+ Fetch the metadata for a specific pixel or pixel direction.
+
+ For example, metadata[y, x] would return the metadata for the pixel at row y and column x, while
+ metadata[y, x, 0] would return the metadata for specifically the first direction of that pixel
+ (usually the approach).
+
+ Parameters
+ ----------
+ keys : tuple
+ A tuple of (y, x) or (y, x, direction) representing the indices.
+
+ Returns
+ -------
+ dict
+ The metadata for the specified pixel or direction.
+ """
+ if isinstance(keys, tuple) and len(keys) == 2:
+ y, x = keys
+ return self.get_point_metadata(y, x)
+ if isinstance(keys, tuple) and len(keys) == 3:
+ y, x, direction = keys
+ return self.get_point_metadata(y, x, direction)
+ raise IndexError(
+ f"Invalid indexing. Expected (y, x) or (y, x, direction) for point metadata indexing. Got {keys}."
+ )
+
+ # pylint: disable=unused-argument
+ def get_point_metadata(self, y: int, x: int, direction: int | None = None):
+ """
+ Fetch the metadata for a specific pixel/ point, optionally for a specific direction.
+
+ Should be implemented by subclasses if there exists per point metadata to define how the metadata is retrieved
+ from the underlying data source. If there is no per point metadata, this can simply return an empty dict.
+
+ Parameters
+ ----------
+ y : int
+ Row index of the pixel.
+ x : int
+ Column index of the pixel.
+ direction : int, optional
+ The index of the direction to fetch metadata for. If None, returns metadata for the entire pixel.
+
+ Returns
+ -------
+ dict
+ The metadata for the specified pixel (or direction, if provided).
+ """
+ return {}
+
+
+class CurvesVolume:
+ """
+ A class representing a 2D map or volume of curves, providing lazy loaded access to curve data.
+
+ An individual curve can be accessed using volume[y, x], which will load the curve data for that pixel on demand.
+
+ Parameters
+ ----------
+ name : str
+ The name of the curve volume.
+ shape_x : int
+ The number of columns in the image.
+ shape_y : int
+ The number of rows in the image.
+ channel_units : dict[str, str]
+ A dictionary mapping channel names to their units.
+ flip_image : bool, optional
+ Whether to flip the image vertically. Default is True.
+ """
+
+ def __init__(self, name: str, shape_x: int, shape_y: int, channel_units: dict[str, str], flip_image: bool = True):
+ """
+ Initialise CurvesVolume.
+
+ Parameters
+ ----------
+ name : str
+ The name of the curve volume.
+ shape_x : int
+ The number of columns in the image.
+ shape_y : int
+ The number of rows in the image.
+ channel_units : dict[str, str]
+ A dictionary mapping channel names to their units.
+ flip_image : bool, optional
+ Whether to flip the image vertically. Default is True.
+ """
+ self.shape_x = shape_x
+ self.shape_y = shape_y
+ self.dims = (shape_y, shape_x)
+ self.flip_image = flip_image
+ self.name = name
+ self.channel_units = channel_units
+
+ def __len__(self):
+ """
+ Return the total number of pixels in the image.
+
+ Returns
+ -------
+ int
+ The total number of pixels in the image.
+ """
+ return self.shape_x * self.shape_y
+
+ def __getitem__(self, keys):
+ """
+ Allow numpy style indexing to fetch curve data for a specific pixel.
+
+ Parameters
+ ----------
+ keys : tuple
+ A tuple of (y, x) representing the row and column indices of the pixel.
+
+ Returns
+ -------
+ dict
+ The QI curve data for the specified pixel.
+ """
+ if not isinstance(keys, tuple) or len(keys) != 2:
+ raise IndexError(f"Invalid indexing. Expected (y, x) for pixel indexing. Got {keys}.")
+ y, x = keys
+ if y < 0 or y >= self.shape_y or x < 0 or x >= self.shape_x:
+ raise IndexError(f"Pixel index ({y}, {x}) is out of bounds for image of shape {self.dims}.")
+ return self.get_curve(y, x)
+
+ def get_curve(self, y: int, x: int):
+ """
+ Fetch the QI curve data for a specific pixel.
+
+ Should be implemented by subclasses to define how the curve data is retrieved from the underlying data source.
+
+ Parameters
+ ----------
+ y : int
+ Row index of the pixel.
+ x : int
+ Column index of the pixel.
+
+ Returns
+ -------
+ dict
+ The QI curve data for the specified pixel.
+ """
+ raise NotImplementedError("This method should be implemented by subclasses to fetch curve data on demand.")
+
+
+class CurvesDataset:
+ """
+ A dataset containing multiple curve volumes and associated metadata.
+
+ Parameters
+ ----------
+ volumes : dict[str, CurvesVolume]
+ A dictionary mapping curve names to CurvesVolume instances that
+ provide access to the curve data for each pixel.
+ metadata : CurvesMetadata
+ An instance of CurvesMetadata that provides access to the metadata
+ for each curve.
+ default_volume_name : str, optional
+ The name of the default volume to use when accessing curve data.
+ If None, the first volume in the dictionary is used.
+ """
+
+ def __init__(
+ self, volumes: dict[str, CurvesVolume], metadata: CurvesMetadata, default_volume_name: str | None = None
+ ):
+ """
+ Initialise CurvesDataset.
+
+ Parameters
+ ----------
+ volumes : dict[str, CurvesVolume]
+ A dictionary mapping curve names to CurvesVolume instances that
+ provide access to the curve data for each pixel.
+ metadata : CurvesMetadata
+ An instance of CurvesMetadata that provides access to the metadata
+ for each curve.
+ default_volume_name : str | None, optional
+ The name of the default volume to use when accessing curve data.
+ If None, the first volume in the dictionary is used.
+ """
+ self.volumes = volumes
+ self.metadata = metadata
+ self.default_volume_name = default_volume_name or next(
+ iter(volumes)
+ ) # Use the first volume as default if not specified
+
+ def add_volume(self, name: str, volume: CurvesVolume, default: bool = False):
+ """
+ Add a CurvesVolume to the dataset.
+
+ Parameters
+ ----------
+ name : str
+ The name of the curve to add.
+ volume : CurvesVolume
+ The CurvesVolume instance containing the curve data for each pixel.
+ default : bool, optional
+ Whether to set this volume as the default volume. Default is False.
+ """
+ self.volumes[name] = volume
+ if default:
+ self.default_volume_name = name
+
+ def get_default_volume(self) -> CurvesVolume:
+ """
+ Get the default CurvesVolume for this dataset.
+
+ Returns
+ -------
+ CurvesVolume
+ The default CurvesVolume instance for this dataset.
+ """
+ return self.volumes[self.default_volume_name]
+
+
+class AFMLoad:
+ """
+ A class representing the loaded AFM data, including the image and scaling factors.
+
+ Parameters
+ ----------
+ image : np.ndarray
+ The image data.
+ px2nm : float
+ The pixel to nanometer scaling factor.
+ timestamps : dict | None, optional
+ Timestamps associated with the data. Default is None.
+ metadata : dict | None, optional
+ Metadata associated with the data. Default is None.
+ curves_dataset : CurvesDataset | None, optional
+ Curves dataset associated with the data. Default is None.
+ """
+
+ image: np.ndarray
+ px2nm: float
+ timestamps: dict | None = None
+ metadata: dict | None = None
+ curves_dataset: CurvesDataset | None = None
+
+ def __init__(
+ self,
+ image: np.ndarray,
+ px2nm: float,
+ timestamps: dict | None = None,
+ metadata: dict | None = None,
+ curves_dataset: CurvesDataset | None = None,
+ ):
+ """
+ Initialise AFMLoad.
+
+ Parameters
+ ----------
+ image : np.ndarray
+ The image data.
+ px2nm : float
+ The pixel to nanometer scaling factor.
+ timestamps : dict | None, optional
+ Timestamps associated with the data. Default is None.
+ metadata : dict | None, optional
+ Metadata associated with the data. Default is None.
+ curves_dataset : CurvesDataset | None, optional
+ Curves dataset associated with the data. Default is None.
+ """
+ self.image = image
+ self.px2nm = px2nm
+ self.timestamps = timestamps
+ self.metadata = metadata
+ self.curves_dataset = curves_dataset
diff --git a/AFMReader/general_loader.py b/AFMReader/general_loader.py
index be27937..2e201a6 100644
--- a/AFMReader/general_loader.py
+++ b/AFMReader/general_loader.py
@@ -1,16 +1,17 @@
"""Switchboard for input files."""
from pathlib import Path
+from typing import Any
-import numpy.typing as npt
-from AFMReader import asd, gwy, h5_jpk, ibw, jpk, spm, stp, top, topostats
+from AFMReader import asd, gwy, h5_jpk, ibw, jpk, raw_bin, spm, stp, top, topostats, jpk_qi
+from AFMReader.data_classes import AFMLoad
from AFMReader.logging import logger
logger.enable(__package__)
-# pylint: disable=too-few-public-methods
+# pylint: disable=too-few-public-methods,too-many-branches,too-many-statements,fixme
class LoadFile:
"""
Class to handle the general loading of an AFM file.
@@ -21,9 +22,11 @@ class LoadFile:
Path to the AFM image.
channel : str
Channel to extract from the AFM image.
+ kwargs : dict, optional
+ Additional keyword arguments to pass to the specific loaders.
"""
- def __init__(self, filepath: str | Path, channel: str):
+ def __init__(self, filepath: str | Path, channel: str, kwargs: dict | None = None):
"""
Initialise the general LoadFile class with a filepath and channel.
@@ -33,62 +36,106 @@ def __init__(self, filepath: str | Path, channel: str):
Path to the AFM image.
channel : str
Channel to extract from the AFM image.
+ kwargs : dict, optional
+ Additional keyword arguments to pass to the specific loaders.
"""
self.filepath = Path(filepath)
self.channel = channel
self.suffix = self.filepath.suffix
+ self.loaded_curves = False
+ self.kwargs = kwargs if kwargs else {}
- def load(self) -> tuple[npt.NDArray | str, float | None]: # noqa: C901
+ # Store heavy loaded data in a dict to avoid having to reload it
+ self.cached_data: dict[str, Any] = {}
+
+ def load(self, channel: str | None = None, kwargs: dict | None = None) -> AFMLoad: # noqa: C901
"""
Generally loads a file type that can be handled by AFMReader.
+ Parameters
+ ----------
+ channel : str, optional
+ Overriding channel to extract from the AFM image.
+ kwargs : dict, optional
+ Additional keyword arguments to pass to the specific loaders.
+
Returns
-------
- tuple
- The image data (stack if ''.asd'' or ''.h5-jpk'') and the pixel to nanometre scaling ratio.
+ AFMLoad
+ An AFMLoad object containing the loaded AFM image data and metadata.
Raises
------
ValueError
- Where the channel is not found, returned as a tuple of "error message" and "None" so that this can be
- propagated to Napari without outright failing.
+ Where the channel is not found.
"""
+ if channel:
+ self.channel = channel
+ if kwargs:
+ self.kwargs = kwargs
try:
if self.suffix == ".asd":
- image, pixel_to_nanometre_scaling_factor, _ = asd.load_asd(self.filepath, self.channel)
+ afm_load = asd.load_asd(self.filepath, self.channel)
elif self.suffix == ".gwy":
- image, pixel_to_nanometre_scaling_factor = gwy.load_gwy(self.filepath, self.channel)
+ afm_load = gwy.load_gwy(self.filepath, self.channel)
elif self.suffix == ".ibw":
- image, pixel_to_nanometre_scaling_factor = ibw.load_ibw(self.filepath, self.channel)
- elif self.suffix == ".jpk":
- image, pixel_to_nanometre_scaling_factor = jpk.load_jpk(self.filepath, self.channel)
+ afm_load = ibw.load_ibw(self.filepath, self.channel)
+ elif self.suffix in [".jpk", ".jpk-qi-image"]:
+ afm_load = jpk.load_jpk(self.filepath, self.channel)
elif self.suffix == ".spm":
- image, pixel_to_nanometre_scaling_factor = spm.load_spm(self.filepath, self.channel)
+ afm_load = spm.load_spm(self.filepath, self.channel)
elif self.suffix == ".h5-jpk":
- image, pixel_to_nanometre_scaling_factor, _ = h5_jpk.load_h5jpk(self.filepath, self.channel)
+ afm_load = h5_jpk.load_h5jpk(self.filepath, self.channel)
+ elif self.suffix == ".jpk-qi-data":
+ afm_load = jpk_qi.load_jpk_data(
+ filepath=self.filepath, channel=self.channel, cached_data=self.cached_data, **self.kwargs
+ )
elif self.suffix == ".stp":
- image, pixel_to_nanometre_scaling_factor = stp.load_stp(self.filepath)
+ afm_load = stp.load_stp(self.filepath)
elif self.suffix == ".top":
- image, pixel_to_nanometre_scaling_factor = top.load_top(self.filepath)
+ afm_load = top.load_top(self.filepath)
elif self.suffix == ".topostats":
- ts_dict = topostats.load_topostats(self.filepath)
- try:
- image = ts_dict[self.channel]
- pixel_to_nanometre_scaling_factor = ts_dict["pixel_to_nm_scaling"]
- except KeyError as exc:
- image_keys = ["image", "image_original"]
- topostats_keys = list(ts_dict.keys())
- raise ValueError(
- f"'{self.channel}' not in available image keys: "
- f"{[im for im in image_keys if im in topostats_keys]}"
- ) from exc
+ afm_load = topostats.load_topostats(self.filepath, self.channel)
+ elif self.suffix == ".bin":
+ afm_load = raw_bin.load_bin(self.filepath, **self.kwargs)
else:
raise ValueError(f"File type '{self.suffix}' is not currently handled by AFMReader.")
- return image, pixel_to_nanometre_scaling_factor
+ return afm_load
except ValueError as e:
logger.error(f"{e}")
- return (e, None) # cheeky return of an image, px2nm-like tuple object to propagate error message to Napari
+ raise e
- # scope for a "check what channels are available" function similar to above.
+ def get_available_channels(self): # noqa: C901
+ """
+ Get the available channels for the file type.
+
+ Returns
+ -------
+ list
+ List of available channels.
+ """
+ if self.suffix == ".asd":
+ available_channels = asd.get_asd_channels(self.filepath)
+ elif self.suffix == ".gwy":
+ available_channels = gwy.get_gwy_channels(self.filepath)
+ elif self.suffix == ".ibw":
+ available_channels = ibw.get_ibw_channels(self.filepath)
+ elif self.suffix in [".jpk", ".jpk-qi-image"]:
+ available_channels = jpk.get_jpk_channels(self.filepath)
+ elif self.suffix == ".spm":
+ available_channels = spm.get_spm_channels(self.filepath)
+ elif self.suffix == ".h5-jpk":
+ available_channels = h5_jpk.get_h5jpk_channels(self.filepath)
+ elif self.suffix == ".jpk-qi-data":
+ available_channels = jpk_qi.get_jpk_data_channels(filepath=self.filepath, cached_data=self.cached_data)
+ elif self.suffix == ".topostats":
+ available_channels = topostats.get_topostats_channels()
+ elif self.suffix == ".bin":
+ available_channels = raw_bin.get_bin_channels()
+ elif self.suffix in [".stp", ".top"]:
+ return []
+ else:
+ raise ValueError(f"File type '{self.suffix}' is not currently handled by AFMReader.")
+ return available_channels
diff --git a/AFMReader/gwy.py b/AFMReader/gwy.py
index 834cd25..653cfd9 100644
--- a/AFMReader/gwy.py
+++ b/AFMReader/gwy.py
@@ -7,10 +7,37 @@
import numpy as np
from loguru import logger
+from AFMReader.data_classes import AFMLoad
from AFMReader.io import read_char, read_double, read_null_terminated_string, read_uint32
-def load_gwy(file_path: Path | str, channel: str) -> tuple[np.ndarray[Any, np.float64], float]:
+def get_gwy_channels(file_path):
+ """
+ Extract a list of available channels and their corresponding dictionary key ids from the `.gwy` file.
+
+ Parameters
+ ----------
+ file_path : Path or str
+ Path to the .gwy file.
+
+ Returns
+ -------
+ list
+ List of available channels.
+ """
+ image_data_dict: dict[Any, Any] = {}
+ with Path.open(file_path, "rb") as open_file: # pylint: disable=unspecified-encoding
+ # Read header
+ header = open_file.read(4)
+ logger.debug(f"Gwy file header: {header.decode}")
+
+ gwy_read_object(open_file, data_dict=image_data_dict)
+ channel_ids = gwy_get_channels(gwy_file_structure=image_data_dict)
+
+ return list(channel_ids)
+
+
+def load_gwy(file_path: Path | str, channel: str) -> AFMLoad:
"""
Extract image and pixel to nm scaling from the .gwy file.
@@ -23,8 +50,8 @@ def load_gwy(file_path: Path | str, channel: str) -> tuple[np.ndarray[Any, np.fl
Returns
-------
- tuple(np.ndarray, float)
- A tuple containing the image and its pixel to nanometre scaling value.
+ AFMLoad
+ An AFMLoad object containing the image and its pixel to nanometre scaling value.
Raises
------
@@ -39,7 +66,9 @@ def load_gwy(file_path: Path | str, channel: str) -> tuple[np.ndarray[Any, np.fl
Sensor'.
>>> from AFMReader.gwy import load_gwy
- >>> image, pixel_to_nm = load_gwy(file_path="path/to/file.gwy", channel="Height")
+ >>> afm_load = load_gwy(file_path="path/to/file.gwy", channel="Height")
+ >>> image = afm_load.image
+ >>> px2nm = afm_load.px2nm
```
"""
logger.info(f"Loading image from : {file_path}")
@@ -87,7 +116,7 @@ def load_gwy(file_path: Path | str, channel: str) -> tuple[np.ndarray[Any, np.fl
raise ValueError(f"'{channel}' not found in {file_path.suffix} channel list: {channel_ids}") from e
logger.info(f"[{filename}] : Extracted image.")
- return (image, px_to_nm)
+ return AFMLoad(image=image, px2nm=px_to_nm)
def gwy_read_object(open_file: BinaryIO, data_dict: dict) -> None:
diff --git a/AFMReader/h5_jpk.py b/AFMReader/h5_jpk.py
index 13c834b..facff01 100644
--- a/AFMReader/h5_jpk.py
+++ b/AFMReader/h5_jpk.py
@@ -6,14 +6,23 @@
"""
from pathlib import Path
+from typing import Any
import h5py
import numpy as np
from AFMReader.logging import logger
+from AFMReader.data_classes import (
+ AFMLoad,
+ CurvesDataset,
+ CurvesMetadata,
+ CurvesVolume,
+)
logger.enable(__package__)
+# pylint: disable=too-few-public-methods,too-many-locals,fixme,too-many-positional-arguments
+
def _parse_channel_name(channel: str) -> tuple[str, str]:
"""
@@ -268,9 +277,266 @@ def generate_timestamps(num_frames: int, line_rate: float, image_size: int) -> d
return {f"frame {i}": timestamp for i, timestamp in enumerate(timestamps)}
-def load_h5jpk(
- file_path: Path | str, channel: str, flip_image: bool = True
-) -> tuple[np.ndarray, float, dict[str, float]]:
+def get_h5jpk_channels(file_path: Path | str):
+ """
+ Get available channels from a .h5-jpk file.
+
+ Parameters
+ ----------
+ file_path : Path | str
+ Path to the .h5-jpk file.
+
+ Returns
+ -------
+ list
+ List of available channels.
+ """
+ with h5py.File(file_path, "r") as f:
+ return list(_available_channels(f))
+
+
+class CurvesH5Volume(CurvesVolume):
+ """
+ A CurvesVolume implementation for HDF5 curve data that provides lazy loading of curve data for each pixel.
+
+ Note that the curve data in the HDF5 file is usually copied from another format for fast access.
+
+ Parameters
+ ----------
+ name : str
+ The name of the curve volume.
+ shape_x : int
+ The number of columns in the image.
+ shape_y : int
+ The number of rows in the image.
+ qi_data_group : h5py.Group
+ The HDF5 group containing the QI curve data.
+ channel_units : dict[str, str]
+ A dictionary mapping channel names to their units.
+ flip_image : bool, optional
+ Whether to flip the image vertically. Default is True.
+ """
+
+ def __init__(
+ self,
+ name: str,
+ shape_x: int,
+ shape_y: int,
+ qi_data_group: h5py.Group,
+ channel_units: dict[str, str],
+ flip_image: bool = True,
+ ):
+ """
+ Initialize the CurvesH5Volume instance.
+
+ Parameters
+ ----------
+ name : str
+ The name of the curve volume.
+ shape_x : int
+ The number of columns in the image.
+ shape_y : int
+ The number of rows in the image.
+ qi_data_group : h5py.Group
+ The HDF5 group containing the QI curve data.
+ channel_units : dict[str, str]
+ A dictionary mapping channel names to their units.
+ flip_image : bool, optional
+ Whether to flip the image vertically. Default is True.
+ """
+ super().__init__(
+ name=name,
+ shape_x=shape_x,
+ shape_y=shape_y,
+ channel_units=channel_units,
+ flip_image=flip_image,
+ )
+ self.qi_data_group = qi_data_group
+
+ def __iter__(self): # noqa: C901
+ """
+ Efficiently iterate over the QI curve data, loading one row at a time.
+
+ Yields
+ ------
+ dict
+ A dictionary containing the QI curve data for each channel and segment.
+ """
+ indices_map = {}
+ for segment, segment_group in self.qi_data_group["Curves"].items():
+ for channel in segment_group["Indices"]:
+ if channel not in indices_map:
+ indices_map[channel] = {}
+ indices_map[channel][segment] = segment_group["Indices"][channel][:]
+ for y_idx in range(self.shape_y):
+ data = {}
+ y = self.shape_y - 1 - y_idx if self.flip_image else y_idx
+ for segment, segment_group in self.qi_data_group["Curves"].items():
+ for channel in segment_group["Indices"]:
+ if channel not in data:
+ data[channel] = {}
+ indices = indices_map[channel][segment]
+ start_idx = int(indices[self.shape_x * y])
+ end_idx = int(indices[self.shape_x * (y + 1)])
+
+ data[channel][segment] = segment_group["Data"][channel][start_idx:end_idx]
+ for x in range(self.shape_x):
+ curve_data = {}
+ for channel, channel_data in data.items():
+ curve_data[channel] = {}
+ for segment, segment_data in channel_data.items():
+ indices = indices_map[channel][segment]
+ start_idx = int(indices[self.shape_x * y + x]) - int(indices[self.shape_x * y])
+ end_idx = int(indices[self.shape_x * y + x + 1]) - int(indices[self.shape_x * y])
+ curve_data[channel][segment] = segment_data[start_idx:end_idx]
+ yield curve_data
+
+ def get_curve(self, y: int, x: int):
+ """
+ Fetch the QI curve data for a specific pixel (x, y) on demand.
+
+ Parameters
+ ----------
+ y : int
+ The row index.
+ x : int
+ The column index.
+
+ Returns
+ -------
+ dict
+ A dictionary containing the QI curve data for the specified pixel.
+ """
+ if y < 0 or y >= self.shape_y or x < 0 or x >= self.shape_x:
+ raise IndexError(f"Curve index out of bounds: ({x}, {y})")
+ curve_dict: dict[str, dict[str, Any]] = {}
+ if self.flip_image:
+ y = self.shape_y - 1 - y
+ curve_num = self.shape_x * y + x
+ for segment, segment_group in self.qi_data_group["Curves"].items():
+ for channel in segment_group["Indices"]:
+ start_idx = int(segment_group["Indices"][channel][curve_num])
+ end_idx = int(segment_group["Indices"][channel][curve_num + 1])
+ if channel not in curve_dict:
+ curve_dict[channel] = {}
+ curve_dict[channel][segment] = segment_group["Data"][channel][start_idx:end_idx]
+ return curve_dict
+
+ def load_all_curves(self):
+ """
+ Load all QI curve data into memory.
+
+ Returns
+ -------
+ list
+ A 2D list containing dictionaries with QI curve data for each pixel.
+ """
+ all_curves = [[{} for _ in range(self.shape_x)] for _ in range(self.shape_y)]
+ for segment, segment_group in self.qi_data_group["Curves"].items():
+ for channel in segment_group["Indices"]:
+ indices = segment_group["Indices"][channel][:]
+ data = segment_group["Data"][channel][:]
+ for i in range(len(indices) - 1):
+ start_idx = int(indices[i])
+ end_idx = int(indices[i + 1])
+ x = i % self.shape_x
+ y = i // self.shape_x
+ if self.flip_image:
+ y = self.shape_y - 1 - y
+ if channel not in all_curves[y][x]:
+ all_curves[y][x][channel] = {}
+ all_curves[y][x][channel][segment] = data[start_idx:end_idx]
+
+ return all_curves
+
+
+class CurvesH5Metadata(CurvesMetadata):
+ """
+ Metadata class for H5 JPK data that provides access to metadata on demand.
+
+ Parameters
+ ----------
+ qi_data_group : h5py.Group
+ The HDF5 group containing the QI curve data.
+ toplevel : dict[str, Any]
+ The top-level metadata dictionary.
+ shape_x : int
+ The number of columns in the image.
+ shape_y : int
+ The number of rows in the image.
+ flip_image : bool, optional
+ Whether to flip the image vertically. Default is ``True``.
+ """
+
+ # pylint: disable=too-many-positional-arguments
+ def __init__(
+ self,
+ qi_data_group: h5py.Group,
+ toplevel: dict[str, Any],
+ shape_x: int,
+ shape_y: int,
+ flip_image: bool = True,
+ ):
+ """
+ Initialize the CurvesH5Metadata instance.
+
+ Parameters
+ ----------
+ qi_data_group : h5py.Group
+ The HDF5 group containing the QI curve data.
+ toplevel : dict[str, Any]
+ The top-level metadata dictionary.
+ shape_x : int
+ The number of columns in the image.
+ shape_y : int
+ The number of rows in the image.
+ flip_image : bool, optional
+ Whether to flip the image vertically. Default is ``True``.
+ """
+ super().__init__(toplevel, shape_x, shape_y, flip_image)
+ self.qi_data_group = qi_data_group
+
+ def get_point_metadata(self, y: int, x: int, direction: int | None = None):
+ """
+ Fetch metadata for a specific pixel (x, y) on demand.
+
+ Parameters
+ ----------
+ y : int
+ The row index.
+ x : int
+ The column index.
+ direction : int, optional
+ The direction index for segment metadata (0 or 1), required if meta_type is "segment".
+
+ Returns
+ -------
+ dict
+ A dictionary containing the fetched metadata.
+ """
+ if y < 0 or y >= self.shape_y or x < 0 or x >= self.shape_x:
+ raise IndexError(f"Curve index out of bounds: ({x}, {y})")
+ if self.flip_image:
+ y = self.shape_y - 1 - y
+ idx = (y * self.shape_x) + x
+ if direction is not None:
+ idx = (idx * 2) + direction
+ meta_dict = {}
+ for key in self.qi_data_group["Curve_Metadata"]:
+ if key.startswith(f"{'segment' if direction is not None else 'curve'}."):
+ new_key = key.split(".", 1)[1]
+ if isinstance(self.qi_data_group["Curve_Metadata"][key], h5py.Dataset):
+ meta_dict[new_key] = (
+ self.qi_data_group["Curve_Metadata"][key][idx].decode("utf-8")
+ if isinstance(self.qi_data_group["Curve_Metadata"][key][idx], bytes)
+ else self.qi_data_group["Curve_Metadata"][key][idx]
+ )
+ else:
+ meta_dict[new_key] = self.qi_data_group["Curve_Metadata"][key]
+ return meta_dict
+
+
+def load_h5jpk(file_path: Path | str, channel: str, flip_image: bool = True, load_curves: bool = True) -> AFMLoad:
"""
Load image from JPK Instruments .h5-jpk files.
@@ -282,15 +548,15 @@ def load_h5jpk(
The channel to extract from the .h5-jpk file.
flip_image : bool, optional
Whether to flip the images vertically. Default is ``True``.
+ load_curves : bool, optional
+ Whether to load QI curve data if present. Default is ``True``.
Returns
-------
- image : np.ndarray
- 3D array of shape (frames, height, width) with image data.
- pixel_to_nm_scaling : float
- Scaling factor converting pixels to nanometers.
- timestamps : dict[str, float]
- Dictionary mapping frame labels (e.g., "frame 0") to timestamp values in seconds.
+ AFMLoad
+ An AFMLoad object containing the image, its pixel to nanometre scaling value, timestamps, and
+ optionally the curves dataset. Curves dataset only if load_curves is True and curve data is
+ present in the file.
Raises
------
@@ -304,9 +570,9 @@ def load_h5jpk(
Load height trace channel from the .jpk file. 'height_trace' is the default channel name.
>>> from AFMReader.jpk import load_h5jpk
- >>> frames, pixel_to_nanometre_scaling_factor, timestamps = load_h5jpk(file_path="./my_jpk_file.jpk",
- >>> channel="height_trace",
- >>> flip_image=True)
+ >>> afm_load = load_h5jpk(file_path="./my_jpk_file.jpk", channel="height_trace", flip_image=True)
+ >>> image = afm_load.image
+ >>> pixel_to_nm_scaling = afm_load.px2nm
"""
logger.info(f"Loading H5-JPK file from : {file_path}")
file_path = Path(file_path)
@@ -323,18 +589,25 @@ def load_h5jpk(
images = (images * scaling) + offset
# Select and reshape a flattened frame
- image_size = measurement_group.attrs["position-pattern.grid.ilength"] # number of pixels
+ shape_x = measurement_group.attrs["position-pattern.grid.ilength"]
+ shape_y = measurement_group.attrs.get("position-pattern.grid.jlength", shape_x) # number of pixels
# Reshape each column vector (height, width) to get (num_frames, height, width)
num_frames = images.shape[1]
- image_stack = np.empty((num_frames, image_size, image_size), dtype=images.dtype)
+ if num_frames == 1:
+ image_stack = np.empty((shape_y, shape_x), dtype=images.dtype)
+ else:
+ image_stack = np.empty((num_frames, shape_y, shape_x), dtype=images.dtype)
for i in range(num_frames):
- frame = images[:, i].reshape((image_size, image_size))
+ frame = images[:, i].reshape((shape_y, shape_x))
# Flip images
if flip_image:
frame = np.flipud(frame)
- image_stack[i] = frame
+ if num_frames == 1:
+ image_stack = frame
+ else:
+ image_stack[i] = frame
# Convert to nm
if dataset_name.lower() in ("height", "error", "measuredheight", "amplitude"):
@@ -342,7 +615,43 @@ def load_h5jpk(
# Generate a dictionary of timestamps
line_rate = _get_line_rate(measurement_group)
- timestamps = generate_timestamps(num_frames, line_rate, image_size)
+ timestamps = generate_timestamps(num_frames, line_rate, shape_y)
logger.info(f"[{file_path.stem}] : Extracted {num_frames} frames from channel '{channel}'")
- return (image_stack, _jpk_pixel_to_nm_scaling_h5(measurement_group), timestamps)
+ px2nm = _jpk_pixel_to_nm_scaling_h5(measurement_group)
+
+ if "QI_Curve_Data" not in f:
+ load_curves = False
+
+ if load_curves:
+ f = h5py.File(file_path, "r")
+ logger.info(f"[{file_path.stem}] : Found Force Curves QI data in file.")
+ qi_data_group = f["QI_Curve_Data"]
+ channels_units = {}
+ top_level_meta = {}
+ for key, value in qi_data_group["Global_Metadata"].attrs.items():
+ if key.startswith("channel.unit."):
+ channels_units[key.split(".")[-1]] = value
+ top_level_meta[key] = value
+
+ curves_volume = CurvesH5Volume(
+ name="Trace",
+ shape_x=shape_x,
+ shape_y=shape_y,
+ qi_data_group=qi_data_group,
+ channel_units=channels_units,
+ flip_image=flip_image,
+ )
+ curves_metadata = CurvesH5Metadata(
+ qi_data_group=qi_data_group,
+ toplevel=top_level_meta,
+ shape_x=shape_x,
+ shape_y=shape_y,
+ flip_image=flip_image,
+ )
+
+ curves_data = CurvesDataset(volumes={"Trace": curves_volume}, metadata=curves_metadata)
+
+ return AFMLoad(image=image_stack, px2nm=px2nm, timestamps=timestamps, curves_dataset=curves_data)
+
+ return AFMLoad(image=image_stack, px2nm=px2nm, timestamps=timestamps)
diff --git a/AFMReader/ibw.py b/AFMReader/ibw.py
index f084403..442d459 100644
--- a/AFMReader/ibw.py
+++ b/AFMReader/ibw.py
@@ -7,6 +7,7 @@
import numpy as np
from igor2 import binarywave
+from AFMReader.data_classes import AFMLoad
from AFMReader.logging import logger
logger.enable(__package__)
@@ -39,7 +40,33 @@ def _ibw_pixel_to_nm_scaling(scan: dict) -> float:
)[0]
-def load_ibw(file_path: Path | str, channel: str) -> tuple[np.ndarray, float]:
+def get_ibw_channels(file_path: Path | str):
+ """
+ Extract a list of available channels and their corresponding dictionary key ids from the `.ibw` file.
+
+ Parameters
+ ----------
+ file_path : Path or str
+ Path to the .ibw file.
+
+ Returns
+ -------
+ list
+ List of available channels.
+ """
+ file_path = Path(file_path)
+ filename = file_path.stem
+ scan = binarywave.load(file_path)
+ logger.info(f"[{filename}] : Loaded image from : {file_path}")
+ labels = []
+ for label_list in scan["wave"]["labels"]:
+ for label in label_list:
+ if label:
+ labels.append(label.decode())
+ return labels
+
+
+def load_ibw(file_path: Path | str, channel: str) -> AFMLoad:
"""
Load image from Asylum Research (Igor) .ibw files.
@@ -52,8 +79,8 @@ def load_ibw(file_path: Path | str, channel: str) -> tuple[np.ndarray, float]:
Returns
-------
- tuple[np.ndarray, float]
- A tuple containing the image and its pixel to nanometre scaling value.
+ AFMLoad
+ An AFMLoad object containing the image and its pixel to nanometre scaling value.
Raises
------
@@ -68,7 +95,9 @@ def load_ibw(file_path: Path | str, channel: str) -> tuple[np.ndarray, float]:
not a typo!).
>>> from AFMReader.ibw import load_ibw
- >>> image, pixel_to_nanometre_scaling_factor = load_ibw(file_path="./my_ibw_file.ibw", channel="HeightTracee")
+ >>> afm_load = load_ibw(file_path="./my_ibw_file.ibw", channel="HeightTracee")
+ >>> image = afm_load.image
+ >>> pixel_to_nanometre_scaling_factor = afm_load.px2nm
"""
logger.info(f"Loading image from : {file_path}")
file_path = Path(file_path)
@@ -100,4 +129,4 @@ def load_ibw(file_path: Path | str, channel: str) -> tuple[np.ndarray, float]:
raise e
logger.info(f"[{filename}] : Extracted image.")
- return (image, _ibw_pixel_to_nm_scaling(scan))
+ return AFMLoad(image=image, px2nm=_ibw_pixel_to_nm_scaling(scan))
diff --git a/AFMReader/jpk.py b/AFMReader/jpk.py
index 4a13a61..9f1d185 100644
--- a/AFMReader/jpk.py
+++ b/AFMReader/jpk.py
@@ -2,16 +2,18 @@
from importlib import resources
from pathlib import Path
+from io import BytesIO
import numpy as np
import tifffile
from AFMReader.io import read_yaml
from AFMReader.logging import logger
+from AFMReader.data_classes import AFMLoad
logger.enable(__package__)
-# pylint: disable=too-many-locals
+# pylint: disable=too-many-locals,too-many-positional-arguments,fixme
def _jpk_pixel_to_nm_scaling(tiff_page: tifffile.tifffile.TiffPage, jpk_tags: dict[str, int]) -> float:
@@ -171,9 +173,72 @@ def _get_z_scaling(tif: tifffile.tifffile, channel_idx: int, jpk_tags: dict[str,
return scaling, offset
+def _get_jpk_channels(
+ file: Path | BytesIO, filename: str, file_path: Path | str, config_path: Path | str | None = None
+):
+ """
+ Retrieve the list of available channels from a JPK TIFF file.
+
+ Parameters
+ ----------
+ file : Path | BytesIO
+ Path to the JPK TIFF file.
+ filename : str
+ Name of the JPK TIFF file.
+ file_path : Path | str
+ Path to the JPK TIFF file.
+ config_path : Path | str | None, optional
+ Path to a configuration file. If ''None'' (default) then the packages
+ default configuration is loaded from ''default_config.yaml''.
+
+ Returns
+ -------
+ dict
+ Dictionary of available channels with their corresponding page indices.
+ """
+ jpk_tags = _load_jpk_tags(config_path)
+ try:
+ tif = tifffile.TiffFile(file)
+ except FileNotFoundError:
+ logger.error(f"[{filename}] File not found : {file_path}")
+ raise
+ # Obtain channel list for all channels in file
+ channel_list = {}
+ for i, page in enumerate(tif.pages[1:]): # [0] is thumbnail
+ available_channel = page.tags[jpk_tags["channel_name"]].value # keys are hexadecimal values
+ if page.tags[jpk_tags["trace_retrace"]].value == 0: # whether img is trace or retrace
+ tr_rt = "trace"
+ else:
+ tr_rt = "retrace"
+ channel_list[f"{available_channel}_{tr_rt}"] = i + 1
+ return channel_list
+
+
+def get_jpk_channels(file_path: Path | str, config_path: Path | str | None = None) -> list[str]:
+ """
+ Get the list of channels available in the .jpk file.
+
+ Parameters
+ ----------
+ file_path : Path | str
+ Path to the .jpk file.
+ config_path : Path | str | None
+ Path to a configuration file. If ''None'' (default) then the packages
+ default configuration is loaded from ''default_config.yaml''.
+
+ Returns
+ -------
+ list[str]
+ List of available channels.
+ """
+ file_path = Path(file_path)
+ filename = file_path.stem
+ return _get_jpk_channels(file_path, filename, file_path, config_path)
+
+
def load_jpk(
- file_path: Path | str, channel: str, config_path: Path | str | None = None, flip_image: bool | None = True
-) -> tuple[np.ndarray, float]:
+ file_path: Path | str, channel: str, config_path: Path | str | None = None, flip_image: bool = True
+) -> AFMLoad:
"""
Load image from JPK Instruments .jpk files.
@@ -186,13 +251,13 @@ def load_jpk(
config_path : Path | str | None
Path to a configuration file. If ''None'' (default) then the packages default configuration is loaded from
''default_config.yaml''.
- flip_image : bool, optional
+ flip_image : bool
Whether to flip the image vertically. Default is ``True``.
Returns
-------
- tuple[npt.NDArray, float]
- A tuple containing the image and its pixel to nanometre scaling value.
+ AFMLoad
+ An AFMLoad object containing the image and its pixel to nanometre scaling value.
Raises
------
@@ -206,18 +271,64 @@ def load_jpk(
Load height trace channel from the .jpk file. 'height_trace' is the default channel name.
>>> from AFMReader.jpk import load_jpk
- >>> image, pixel_to_nanometre_scaling_factor = load_jpk(file_path="./my_jpk_file.jpk",
- >>> channel="height_trace",
- >>> flip_image=True)
+ >>> afm_load = load_jpk(file_path="./my_jpk_file.jpk", channel="height_trace", flip_image=True)
+ >>> image = afm_load.image
+ >>> pixel_to_nanometre_scaling_factor = afm_load.px2nm
"""
logger.info(f"Loading image from : {file_path}")
file_path = Path(file_path)
filename = file_path.stem
+ image, px2nm = _load_jpk(
+ file=file_path,
+ filename=filename,
+ channel=channel,
+ file_suffix=file_path.suffix,
+ config_path=config_path,
+ flip_image=flip_image,
+ )
+ return AFMLoad(image=image, px2nm=px2nm)
+
+
+def _load_jpk(
+ file: Path | BytesIO,
+ filename: str,
+ channel: str,
+ file_suffix: str,
+ config_path: Path | str | None = None,
+ flip_image: bool = True,
+ convert_to_nm: bool = True,
+) -> tuple[np.ndarray, float]:
+ """
+ Load image data and pixel scaling from a JPK TIFF file for a given channel.
+
+ Parameters
+ ----------
+ file : Path | BytesIO
+ Path to the JPK TIFF file.
+ filename : str
+ Name of the JPK TIFF file.
+ channel : str
+ The channel to extract from the JPK TIFF file.
+ file_suffix : str
+ The file suffix of the JPK TIFF file.
+ config_path : Path | str | None, optional
+ Path to a configuration file. If ''None'' (default) then the packages default configuration is
+ loaded from ''default_config.yaml''.
+ flip_image : bool, optional
+ Whether to flip the image vertically. Default is True.
+ convert_to_nm : bool, optional
+ Whether to convert the image to nanometres. Default is True.
+
+ Returns
+ -------
+ tuple[np.ndarray, float]
+ A tuple containing the image and its pixel to nanometre scaling value.
+ """
jpk_tags = _load_jpk_tags(config_path)
try:
- tif = tifffile.TiffFile(file_path)
+ tif = tifffile.TiffFile(file)
except FileNotFoundError:
- logger.error(f"[{filename}] File not found : {file_path}")
+ logger.error(f"[{filename}] File not found : {file}")
raise
# Obtain channel list for all channels in file
channel_list = {}
@@ -231,8 +342,8 @@ def load_jpk(
try:
channel_idx = channel_list[channel]
except KeyError as e:
- logger.error(f"'{channel}' not in {file_path.suffix} channel list: {channel_list}")
- raise ValueError(f"'{channel}' not in {file_path.suffix} channel list: {channel_list}") from e
+ logger.error(f"'{channel}' not in {file_suffix} channel list: {channel_list}")
+ raise ValueError(f"'{channel}' not in {file_suffix} channel list: {channel_list}") from e
# Get image and if applicable, scale it
channel_page = tif.pages[channel_idx]
@@ -242,7 +353,7 @@ def load_jpk(
if flip_image is True:
image = np.flipud(image)
- if channel_page.tags[jpk_tags["channel_name"]].value in ("height", "measuredHeight", "amplitude"):
+ if convert_to_nm and channel_page.tags[jpk_tags["channel_name"]].value in ("height", "measuredHeight", "amplitude"):
image = image * 1e9
# Get page for common metadata between scans
diff --git a/AFMReader/jpk_qi.py b/AFMReader/jpk_qi.py
new file mode 100644
index 0000000..2338ff4
--- /dev/null
+++ b/AFMReader/jpk_qi.py
@@ -0,0 +1,1368 @@
+"""
+Module to decode and load JPK QI (Quantitative Imaging) data files.
+
+It provides lazy loading for curve data and metadata to minimize memory usage,
+and supports exporting to HDF5 format.
+"""
+
+# pylint: disable=too-many-lines,too-many-positional-arguments,too-few-public-methods,too-many-instance-attributes
+# pylint: disable=too-many-locals,too-many-branches,protected-access,attribute-defined-outside-init,fixme
+# pylint: disable=too-many-arguments
+
+import os
+import io
+import zipfile
+import time
+from pathlib import Path
+from contextlib import nullcontext
+from typing import Any
+
+import numpy as np
+import javaproperties
+import h5py
+import psutil
+
+from AFMReader.data_classes import AFMLoad, CurvesMetadata, CurvesVolume, CurvesDataset
+from AFMReader.logging import logger
+from AFMReader import jpk
+
+
+class CurvesJPKDataset(CurvesDataset):
+ """
+ A dataset class for JPK QI data that holds the raw data as well as metadata.
+
+ Parameters
+ ----------
+ volumes : dict[str, CurvesVolume]
+ A dictionary mapping curve names to CurvesVolume instances that
+ provide access to the curve data for each pixel.
+ metadata : CurvesMetadata
+ An instance of CurvesMetadata that provides access to the metadata
+ for each curve.
+ archive : zipfile.ZipFile
+ The ZIP archive containing the JPK data.
+ """
+
+ def __init__(self, volumes: dict[str, CurvesVolume], metadata: CurvesMetadata, archive: zipfile.ZipFile):
+ """
+ Initialise CurvesJPKDataset.
+
+ Parameters
+ ----------
+ volumes : dict[str, CurvesVolume]
+ A dictionary mapping curve names to CurvesVolume instances that
+ provide access to the curve data for each pixel.
+ metadata : CurvesMetadata
+ An instance of CurvesMetadata that provides access to the metadata
+ for each curve.
+ archive : zipfile.ZipFile
+ The ZIP archive containing the JPK data.
+ """
+ super().__init__(volumes, metadata)
+ self.archive = archive
+
+ def close(self):
+ """Close the ZIP archive when done to free up resources."""
+ self.archive.close()
+
+
+class CurvesJPKMetadata(CurvesMetadata):
+ """
+ A metadata class for JPK QI data that provides lazy loading of pixel metadata.
+
+ Parameters
+ ----------
+ toplevel : dict
+ A dictionary containing the top-level metadata for the dataset.
+ archive : zipfile.ZipFile
+ The ZIP archive containing the JPK data.
+ shape_x : int
+ Number of columns in the image.
+ shape_y : int
+ Number of rows in the image.
+ flip_image : bool, optional
+ Whether to flip the image vertically. Default is True.
+ """
+
+ def __init__(
+ self,
+ toplevel: dict,
+ archive: zipfile.ZipFile,
+ shape_x: int,
+ shape_y: int,
+ flip_image: bool = True,
+ ):
+ """
+ Initialize the CurvesJPKMetadata instance.
+
+ Parameters
+ ----------
+ toplevel : dict
+ A dictionary containing the top-level metadata for the dataset.
+ archive : zipfile.ZipFile
+ The ZIP archive containing the JPK data.
+ shape_x : int
+ Number of columns in the image.
+ shape_y : int
+ Number of rows in the image.
+ flip_image : bool, optional
+ Whether to flip the image vertically. Default is ``True``.
+ """
+ super().__init__(toplevel, shape_x, shape_y, flip_image)
+ self.archive = archive
+
+ def get_point_metadata(self, y: int, x: int, direction: int | None = None):
+ """
+ Fetch the metadata for a specific pixel or direction.
+
+ Parameters
+ ----------
+ y : int
+ Row index of the pixel.
+ x : int
+ Column index of the pixel.
+ direction : int, optional
+ The index of the direction to fetch metadata for. If None, returns metadata for the entire pixel.
+
+ Returns
+ -------
+ dict
+ The metadata for the specified pixel (or direction, if provided).
+ """
+ if y < 0 or y >= self.shape_y or x < 0 or x >= self.shape_x:
+ raise IndexError(f"Curve index out of bounds: ({x}, {y})")
+ if self.flip_image:
+ y = self.shape_y - 1 - y
+ idx = (y * self.shape_x) + x
+ if direction is None:
+ path = f"index/{idx}/header.properties"
+ else:
+ path = f"index/{idx}/segments/{direction}/segment-header.properties"
+
+ try:
+ with self.archive.open(path) as f:
+ meta_dict = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()}
+ except KeyError:
+ meta_dict = {}
+
+ return meta_dict
+
+
+class CurvesJPKVolume(CurvesVolume):
+ """
+ A CurvesVolume implementation for JPK QI curve data that provides lazy loading of curve data for each pixel.
+
+ Parameters
+ ----------
+ name : str
+ The name of the curve volume.
+ shape_x : int
+ The number of columns in the image.
+ shape_y : int
+ The number of rows in the image.
+ archive : zipfile.ZipFile
+ The ZIP archive containing the JPK data.
+ channel_scaling : dict[str, dict[str, float]]
+ A dictionary mapping channel names to their scaling factors.
+ channel_units : dict[str, str]
+ A dictionary mapping channel names to their units.
+ flip_image : bool, optional
+ Whether to flip the image vertically. Default is True.
+ """
+
+ def __init__(
+ self,
+ name: str,
+ shape_x: int,
+ shape_y: int,
+ archive: zipfile.ZipFile,
+ channel_scaling: dict[str, dict[str, float]],
+ channel_units: dict[str, str],
+ flip_image: bool = True,
+ ):
+ """
+ Initialise CurvesJPKVolume.
+
+ Parameters
+ ----------
+ name : str
+ The name of the curve volume.
+ shape_x : int
+ The number of columns in the image.
+ shape_y : int
+ The number of rows in the image.
+ archive : zipfile.ZipFile
+ The ZIP archive containing the JPK data.
+ channel_scaling : dict[str, dict[str, float]]
+ A dictionary mapping channel names to their scaling factors.
+ channel_units : dict[str, str]
+ A dictionary mapping channel names to their units.
+ flip_image : bool, optional
+ Whether to flip the image vertically. Default is True.
+ """
+ super().__init__(
+ name=name,
+ shape_x=shape_x,
+ shape_y=shape_y,
+ channel_units=channel_units,
+ flip_image=flip_image,
+ )
+ self.archive = archive
+ self.channel_scaling = channel_scaling
+
+ def __iter__(self):
+ """Yield the curve data for each pixel in the image, iterating in row-major order (y first, then x)."""
+ for y in range(self.shape_y):
+ for x in range(self.shape_x):
+ yield self.get_curve(y, x)
+
+ def get_curve(self, y: int, x: int):
+ """
+ Fetch the curve data for a specific pixel.
+
+ Parameters
+ ----------
+ y : int
+ Row index of the pixel.
+ x : int
+ Column index of the pixel.
+
+ Returns
+ -------
+ dict
+ Dictionary containing the curve data for the specified pixel.
+ """
+ if y < 0 or y >= self.shape_y or x < 0 or x >= self.shape_x:
+ raise IndexError(f"Curve index out of bounds: ({x}, {y})")
+ if self.flip_image:
+ y = self.shape_y - 1 - y
+ curve_num = y * self.shape_x + x
+ curve_data: dict[str, Any] = {}
+
+ for chan_name, scale in self.channel_scaling.items():
+ curve_data[chan_name] = {}
+ for direction in (0, 1):
+ dat_path = f"index/{curve_num}/segments/{direction}/channels/{chan_name}.dat"
+ try:
+ # Access the file directly without re-parsing the ZIP directory
+ with self.archive.open(dat_path) as f:
+ raw_array = np.frombuffer(f.read(), dtype=">i4")
+ curve_data[chan_name][f"Segment_{direction}"] = (raw_array * scale["multiplier"]) + scale[
+ "offset"
+ ]
+ except KeyError:
+ pass # File doesn't exist for this segment
+
+ return curve_data
+
+
+def _get_channel_scaling(props, channel_index):
+ """
+ Parse the JPK properties dictionary to find cumulative multiplier and offset for a specific channel index.
+
+ Parameters
+ ----------
+ props : dict
+ The properties dictionary loaded from the JPK file.
+ channel_index : str
+ The index of the channel to find the scaling for (e.g., '1' for vDeflection).
+
+ Returns
+ -------
+ final_multiplier : float
+ The cumulative multiplier for the specified channel.
+ final_offset : float
+ The cumulative offset for the specified channel.
+ unit : str
+ The unit of the channel.
+ """
+ prefix = f"lcd-info.{channel_index}."
+
+ current_slot = props.get(f"{prefix}conversion-set.conversions.default")
+
+ if not current_slot:
+ mult = float(props.get(f"{prefix}encoder.scaling.multiplier", "1.0"))
+ off = float(props.get(f"{prefix}encoder.scaling.offset", "0.0"))
+ unit = props.get(f"{prefix}encoder.scaling.unit.unit", "Unknown")
+ return mult, off, unit
+
+ cumulative_multiplier = 1.0
+ cumulative_offset = 0.0
+ unit = props.get(f"{prefix}conversion-set.conversion.{current_slot}.scaling.unit.unit")
+
+ while current_slot:
+ slot_prefix = f"{prefix}conversion-set.conversion.{current_slot}."
+
+ if f"{slot_prefix}scaling.multiplier" in props:
+ m = float(props[f"{slot_prefix}scaling.multiplier"])
+ c = float(props[f"{slot_prefix}scaling.offset"])
+
+ cumulative_offset = (cumulative_multiplier * c) + cumulative_offset
+ cumulative_multiplier *= m
+
+ current_slot = props.get(f"{slot_prefix}base-calibration-slot")
+
+ if current_slot == props.get(f"{prefix}conversion-set.conversions.base"):
+ break
+ else:
+ break
+
+ enc_m = float(props.get(f"{prefix}encoder.scaling.multiplier", "1.0"))
+ enc_c = float(props.get(f"{prefix}encoder.scaling.offset", "0.0"))
+
+ final_multiplier = cumulative_multiplier * enc_m
+ final_offset = (cumulative_multiplier * enc_c) + cumulative_offset
+ if not unit:
+ unit = props.get(f"{prefix}encoder.scaling.unit.unit", "Unknown")
+
+ return final_multiplier, final_offset, unit
+
+
+def _make_num_min_characters(num: int, min_chars: int = 3):
+ """
+ Zero-pad an integer to a minimum number of characters.
+
+ Parameters
+ ----------
+ num : int
+ The integer to pad.
+ min_chars : int
+ The minimum number of characters the resulting string should have. Default is 3.
+
+ Returns
+ -------
+ str
+ The zero-padded string.
+ """
+ string_num = str(num)
+ if len(string_num) >= min_chars:
+ return string_num
+ return "0" * (min_chars - len(string_num)) + string_num
+
+
+class JPKQILoader:
+ """
+ Class for readability and improving modularity in the load jpk qi data function.
+
+ Parameters
+ ----------
+ filepath : Path | str
+ The path to the .jpk-qi file to be loaded.
+ channel : str | None, optional
+ The specific channel to be extracted (e.g., "measuredHeight"). Default is None.
+ config_path : Path | str | None, optional
+ The path to the configuration file, if any. Default is None.
+ flip_image : bool | None, optional
+ Whether to flip the image vertically. Default is True.
+ save_as_h5 : bool, optional
+ Whether to save the loaded data as an H5 file. Default is False.
+ """
+
+ def __init__(
+ self,
+ filepath: Path | str,
+ channel: str | None = None,
+ config_path: Path | str | None = None,
+ flip_image: bool | None = True,
+ save_as_h5: bool = False,
+ ):
+ """
+ Initialize the loader with the provided parameters.
+
+ Parameters
+ ----------
+ filepath : Path | str
+ The path to the .jpk-qi file to be loaded.
+ channel : str | None, optional
+ The specific channel to be extracted (e.g., "measuredHeight"). Default is None.
+ config_path : Path | str | None, optional
+ The path to the configuration file, if any. Default is None.
+ flip_image : bool | None, optional
+ Whether to flip the image vertically. Default is True.
+ save_as_h5 : bool, optional
+ Whether to save the loaded data as an H5 file. Default is False.
+ """
+ self.filepath = Path(filepath)
+ self.channel = channel
+ self.config_path = config_path
+ self.flip_image = flip_image
+ self.save_as_h5 = save_as_h5
+
+ # Open the ZIP archive once and keep it open for the duration of the loading process
+ self.qi_archive = zipfile.ZipFile(self.filepath, "r") # pylint: disable=consider-using-with
+ logger.info(f"Opened JPK QI archive at {self.filepath}")
+ # Store the list of all paths in the archive to avoid having to call namelist() multiple times
+ self.list_of_all_paths = self.qi_archive.namelist()
+ # For holding the reference to where the actual .jqk-qi image is (not the metadata).
+ self.path_to_image = None
+
+ # Chunk size for H5 datasets
+ self.DATA_CHUNKSIZE = 512 * 1024
+ # Chunk size for indices datasets
+ self.INDICES_CHUNKSIZE = 64 * 1024
+ # Chunk size for metadata datasets (if needed)
+ self.META_CHUNKSIZE = 64 * 1024
+ # Maximum number of curves to check for changing metadata keys (to avoid checking every curve)
+ self.MAX_CURVE_CHECKS = 20
+ # Number of curves to hold in buffer
+ self.BUFFER_SIZE = 500
+
+ # Initialize key attributes that will be returned / accessed frequently
+
+ # Just the top level metadata extracted from the header files
+ self.top_level_meta: dict[str, Any] = {}
+ # A lazy reference containing all metadata
+ self.full_metadata: CurvesJPKMetadata | None = None
+ # A 2D list of curve data dictionaries
+ self.curves_volume: CurvesJPKVolume | None = None
+ # A lookup for channel name to unit to be returned
+ self.channels_units: dict[str, str] = {}
+ # The list of channels for the segments with their scaling information extracted from the shared header
+ self.segment_channels: list[dict[str, Any]] = []
+ self.curve_meta: dict[str, Any] = {}
+ self.segment_meta: dict[str, Any] = {}
+
+ # Define the image shape and size attributes
+ self.size_x: float | None = None
+ self.size_y: float | None = None
+ self.shape_x: int | None = None
+ self.shape_y: int | None = None
+ self.failed_curves: set[tuple[int, int | None, str | None]] = set()
+
+ # Instantiate containers for data to be saved (so an exception is not caused if not saving)
+ self.curve_groups = None
+ self.saved_to_h5 = False
+
+ def get_available_channels(self):
+ """
+ Retrieve available channels from the .jpk-qi-image file within the archive.
+
+ Returns
+ -------
+ channels : list
+ A list of available channels including the calculated channels.
+ metadata_options : dict
+ A dictionary of options for what metadata to return.
+ """
+ # Look for the jpk-qi-image file in the archive
+ if self.path_to_image is None:
+ for file_name in self.list_of_all_paths:
+ if file_name.endswith(".jpk-qi-image"):
+ self.path_to_image = file_name
+
+ # Add the channels which exist in the jpk-qi-image file
+ with self.qi_archive.open(self.path_to_image, "r") as image_file:
+ channels = jpk._get_jpk_channels(
+ file=image_file, filename=self.filepath.stem, file_path=self.filepath / Path(self.path_to_image)
+ )
+ return channels, {"save_as_h5": bool}
+
+ def load(
+ self,
+ channel: str | None = None,
+ config_path: Path | str | None = None,
+ flip_image: bool | None = True,
+ save_as_h5: bool | None = None,
+ ) -> AFMLoad:
+ """
+ Load the .jpk-qi-data file.
+
+ Parameters
+ ----------
+ channel : str | None, optional
+ The specific channel to be extracted. Default is None.
+ config_path : Path | str | None, optional
+ Path to the configuration file. Default is None.
+ flip_image : bool | None, optional
+ Whether to flip the image. Default is True.
+ save_as_h5 : bool, optional
+ Whether to save the data as an H5 file. Default is False.
+
+ Returns
+ -------
+ AFMLoad
+ An AFMLoad object containing the image, its pixel to nanometre scaling value, and curves dataset.
+ """
+ # Update instance attributes based on provided parameters
+ self.channel = channel if channel else self.channel
+ self.config_path = config_path if config_path else self.config_path
+ self.flip_image = flip_image if flip_image is not None else self.flip_image
+ self.save_as_h5 = save_as_h5 if save_as_h5 is not None else self.save_as_h5
+
+ if self.save_as_h5:
+ self.h5_path = self.filepath.parent / f"{self.filepath.stem}.h5-jpk"
+ i = 0
+ while self.h5_path.exists():
+ self.h5_path = self.filepath.parent / f"{self.filepath.stem}_{i}.h5-jpk"
+ i += 1
+
+ logger.info(f"Loading JPK QI data from {self.filepath} with channel {self.channel}")
+ self.extract_global_metadata()
+
+ self.parse_dimension_data()
+
+ # Setup H5 Data structures if needed
+ if self.save_as_h5 and not self.saved_to_h5:
+ self.save_to_h5()
+
+ # Establish the lazy loading structures for curve data and metadata. Note how lazy structure is used even if
+ # all the data has been accessed and saved to H5 to prevent excessive memory usage
+ self.full_metadata = CurvesJPKMetadata(
+ self.top_level_meta,
+ self.qi_archive,
+ self.shape_x or 0,
+ self.shape_y or 0,
+ flip_image=bool(self.flip_image),
+ )
+ self.curves_volume = CurvesJPKVolume(
+ name="Trace",
+ shape_x=self.shape_x or 0,
+ shape_y=self.shape_y or 0,
+ archive=self.qi_archive,
+ channel_scaling=self.channel_scaling,
+ channel_units=self.channels_units,
+ flip_image=bool(self.flip_image),
+ )
+ self.curves_dataset = CurvesJPKDataset(
+ volumes={"Trace": self.curves_volume}, metadata=self.full_metadata, archive=self.qi_archive
+ )
+
+ # Load the image
+ self.image, _ = self.get_image()
+
+ return AFMLoad(image=self.image, px2nm=self.px2nm, curves_dataset=self.curves_dataset)
+
+ def output_summary(self):
+ """Output a summary of the loading process, including any failed curve loads and their details."""
+ if self.failed_curves:
+ logger.warning(f"Failed to load {len(self.failed_curves)} files.")
+ logger.warning("Summary of missing files (up to 10 shown):")
+
+ # Output the first 10 failed loads with details
+ for i, (curve_num, direction, chan_name) in enumerate(self.failed_curves):
+ if i < 10:
+ if chan_name:
+ logger.warning(
+ f"Failed to load data for curve {curve_num}, direction {direction}, channel {chan_name}"
+ )
+ else:
+ if direction is not None:
+ logger.warning(
+ f"Failed to load segment meta file for curve {curve_num}, direction {direction}"
+ )
+ else:
+ logger.warning(f"Failed to load curve meta file for curve {curve_num}")
+ else:
+ break
+ else:
+ # If there are no failed loads, log that all data was loaded successfully
+ logger.info("Successfully loaded all curve data without any missing files.")
+
+ def extract_data_to_h5(
+ self, h5_datasets, h5_meta_datasets, h5_datasets_buffer, h5_meta_datasets_buffer, include_metadata: bool = True
+ ):
+ """
+ Load all curve data and optionally metadata from the JPK QI archive into HDF5 datasets.
+
+ Parameters
+ ----------
+ h5_datasets : dict
+ Dictionary of HDF5 datasets for storing curve data.
+ h5_meta_datasets : dict
+ Dictionary of HDF5 datasets for storing metadata.
+ h5_datasets_buffer : dict
+ Dictionary of buffers for HDF5 curve data.
+ h5_meta_datasets_buffer : dict
+ Dictionary of buffers for HDF5 metadata.
+ include_metadata : bool, optional
+ Whether to include metadata in the loading process, by default True.
+ """
+ logger.info(
+ f"Loading all curve data from JPK QI archive with {len(self.list_of_all_paths)} files "
+ f"{'' if include_metadata else 'not '}including metadata"
+ )
+ progress_counter = 0
+ process = psutil.Process(os.getpid())
+ if include_metadata:
+ # Prepare keys for metadata to speed up processing
+ curve_work = [
+ (f"{k}=".encode(), h5_meta_datasets[f"curve.{k}"], h5_meta_datasets_buffer[f"curve.{k}"])
+ for k in self.changing_curve_keys
+ ]
+ seg_work = [
+ (f"{k}=".encode(), h5_meta_datasets[f"segment.{k}"], h5_meta_datasets_buffer[f"segment.{k}"])
+ for k in self.changing_segment_keys
+ ]
+ for curve_num in range(self.num_of_curves):
+ # Output progress every 1000 curves to give some indication of how long the loading is taking
+ if progress_counter % 1000 == 0:
+ mem = process.memory_info().rss / 1024 / 1024
+ logger.info(
+ f"Progress: {progress_counter}/{self.num_of_curves} curves processed, Memory usage: {mem:.2f} MB"
+ )
+ progress_counter += 1
+
+ for direction in range(2):
+ for chan in self.segment_channels:
+ # Save the actual curve data to the h5 datasets
+ self.extract_dat_file(
+ h5_datasets=h5_datasets,
+ h5_datasets_buffer=h5_datasets_buffer,
+ curve_num=curve_num,
+ direction=direction,
+ chan_name=chan["name"],
+ )
+
+ if include_metadata:
+ # Extract and store the segment metadata for later saving
+ self.extract_segment_metadata(curve_num=curve_num, direction=direction, seg_work=seg_work)
+
+ if include_metadata:
+ # Extract and store the curve metadata for later saving
+ self.extract_curve_metadata(curve_num=curve_num, curve_work=curve_work)
+
+ # Add the last index to the indices datasets to mark the end of the last curve
+ for direction in range(2):
+ seg_name = f"Segment_{direction}"
+ for chan in self.segment_channels:
+ chan_name = chan["name"]
+ current_dataset = h5_datasets[seg_name][chan_name]["Data"]
+ indices_dataset = h5_datasets[seg_name][chan_name]["Indices"]
+ indices_dataset[-1] = current_dataset.shape[0]
+
+ def save_to_h5(
+ self,
+ include_metadata: bool = True,
+ ):
+ """
+ Save data as an H5 file. If include_metadata is False, only curve data is saved.
+
+ Parameters
+ ----------
+ include_metadata : bool, optional
+ If True, metadata will be included in the saved H5 file. Default is True.
+ """
+ with self.get_saving_context() as file:
+
+ t0 = time.perf_counter()
+
+ # Sample curves in dataset to make a best guess for the meta keys
+ self.changing_curve_keys, self.changing_segment_keys = self.get_changing_keys()
+ self.points_for_channel_segment = self.predict_total_points()
+ self.t_changing_keys = time.perf_counter() - t0
+
+ # Setup H5 structure for saving the data
+ global_meta_group, h5_datasets, h5_meta_datasets, h5_datasets_buffer, h5_meta_datasets_buffer = (
+ self.setup_h5_structure(file)
+ )
+
+ # Set up current_offsets to keep track of how many points have been read
+ self.current_offsets: dict[int, dict[str, int]] = {}
+ for direction in range(2):
+ self.current_offsets[direction] = {}
+ for chan in self.segment_channels:
+ self.current_offsets[direction][chan["name"]] = 0
+
+ # Reset points_for_channel_segment to 0 to store actual number of points
+ self.points_for_channel_segment[direction][chan["name"]] = 0
+
+ # Extract data from the JPK QI archive and save to H5 datasets
+ self.extract_data_to_h5(
+ h5_datasets,
+ h5_meta_datasets,
+ h5_datasets_buffer,
+ h5_meta_datasets_buffer,
+ include_metadata=include_metadata,
+ )
+ # Resize the datasets to the actual number of points read
+ for direction in range(2):
+ for chan in self.segment_channels:
+ h5_datasets[f"Segment_{direction}"][chan["name"]]["Data"].resize(
+ (self.points_for_channel_segment[direction][chan["name"]],)
+ )
+
+ self.output_summary()
+
+ if include_metadata:
+ # Save the global metadata to the h5 file
+ for key, value in self.get_collated_metadata().items():
+ global_meta_group.attrs[key] = str(value).encode("utf-8")
+
+ logger.info(f"QI data copied to h5 data {file.filename}")
+ # Save a lite form of the images (precalculated) if saving to a file
+ self.save_lite_data()
+ self.saved_to_h5 = True
+
+ def get_curves_sample(self):
+ """
+ Get a sample of curve numbers distrubuted evenly across the dataset.
+
+ Returns
+ -------
+ range:
+ A range object representing the sampled curve numbers.
+ """
+ # Check evenly spaced curves in the dataset to sample metadata without having to load every curve
+ step = 1 if self.num_of_curves <= self.MAX_CURVE_CHECKS else self.num_of_curves // self.MAX_CURVE_CHECKS
+ # If the step is equal to a shape dimension, we might just go down the row or column
+ while step in [self.shape_x, self.shape_y] and step > 1:
+ # So make the step slightly smaller (more checks) to ensure we get a good sample
+ step -= 1
+ return range(0, self.num_of_curves, step)
+
+ def predict_total_points(self):
+ """
+ Predict the total number of points for each channel and segment.
+
+ This is done by sampling a subset of curves and extrapolating based on the maximum number
+ of points found in the sample.
+
+ Returns
+ -------
+ dict:
+ A dictionary containing the predicted total points for each channel and segment.
+ """
+ # Get a sample of curve (indices)
+ curves_to_check = self.get_curves_sample()
+ points_for_channel_segment = {}
+
+ # Iterate through the segments, channels and our curve indices
+ for direction in range(2):
+ points_for_channel_segment[direction] = {}
+ for channel in self.segment_channels:
+ points_for_channel_segment[direction][channel["name"]] = []
+ for curve_num in curves_to_check:
+ # Loop until we successfully retrieve some data
+ while True:
+ dat_path = f"index/{curve_num}/segments/{direction}/channels/{channel['name']}.dat"
+ try:
+ # Count points in extracted data
+ with self.qi_archive.open(dat_path) as f:
+ raw_array = np.frombuffer(f.read(), dtype=">i4")
+ points_for_channel_segment[direction][channel["name"]].append(len(raw_array))
+ break
+
+ except KeyError:
+ # If the file doesn't exist for this curve, check the next curve so we don't just get
+ # a smaller sample
+ if curve_num + 1 >= self.num_of_curves:
+ # If we've gone past the number of curves, stop checking
+ break
+ curve_num += 1
+ continue
+ # Calculate a prediction for total number of points based on maximum number of points then assuming
+ # maximum points throughout data is no more than 10% higher
+ points_for_channel_segment[direction][channel["name"]] = (
+ int(np.max(points_for_channel_segment[direction][channel["name"]]) * 1.1) * self.num_of_curves
+ )
+ return points_for_channel_segment
+
+ def get_changing_keys(self): # noqa: C901
+ """
+ Check a sample of curves to see which metadata keys change across curves and segments.
+
+ This allows us to extract only the changing keys for each curve and segment.
+ Non-changing keys are moved to the top-level metadata and not extracted for each curve/segment.
+
+ Returns
+ -------
+ tuple:
+ A tuple containing two sets: changing_curve_keys and changing_segment_keys.
+ """
+ curve_meta_dict: dict[str, list[Any]] = {}
+ segment_meta_dict: dict[str, list[Any]] = {}
+ curves_to_check = self.get_curves_sample()
+ for curve_num in curves_to_check:
+ for direction in range(2):
+ while True:
+ meta_path = f"index/{curve_num}/segments/{direction}/segment-header.properties"
+ try:
+ with self.qi_archive.open(meta_path) as f:
+ meta_dict = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()}
+ for k, v in meta_dict.items():
+ if k not in segment_meta_dict:
+ segment_meta_dict[k] = []
+ segment_meta_dict[k].append(v)
+ break
+
+ except KeyError:
+ if curve_num + 1 >= self.num_of_curves:
+ break # If we've gone past the number of curves, stop checking
+ curve_num += 1
+ continue
+ meta_path = f"index/{curve_num}/header.properties"
+ while True:
+ try:
+ with self.qi_archive.open(meta_path) as f:
+ meta_dict = {".".join(k.split(".")[1:]): v for k, v in javaproperties.load(f).items()}
+ for k, v in meta_dict.items():
+ if k not in curve_meta_dict:
+ curve_meta_dict[k] = []
+ curve_meta_dict[k].append(v)
+ break
+ except KeyError:
+ if curve_num + 1 >= self.num_of_curves:
+ break # If we've gone past the number of curves, stop checking
+ curve_num += 1
+ continue
+
+ changing_curve_keys, changing_segment_keys = set(), set()
+ for key, values in curve_meta_dict.items():
+ if len({v for v in values if v is not None}) > 1:
+ changing_curve_keys.add(key)
+ else:
+ # If the key does not change across curves, move it to the top level metadata
+ self.top_level_meta[f"curve.{key}"] = values[0]
+ for key, values in segment_meta_dict.items():
+ if len({v for v in values if v is not None}) > 1:
+ changing_segment_keys.add(key)
+ else:
+ # If the key does not change across segments, move it to the top level metadata
+ self.top_level_meta[f"segment.{key}"] = values[0]
+ return changing_curve_keys, changing_segment_keys
+
+ def get_collated_metadata(self):
+ """
+ Collate metadata from being split by curve to being split by attribute.
+
+ Returns
+ -------
+ dict
+ A dictionary containing the collated metadata.
+ """
+ collated_meta = {}
+ for seg_chan in self.segment_channels:
+ collated_meta[f"channel.unit.{seg_chan['name']}"] = seg_chan["unit"]
+ for key, value in self.top_level_meta.items():
+ collated_meta[key] = value
+ return collated_meta
+
+ def get_image(
+ self, overide_channel: str | None = None, convert_to_nm: bool = True, flip_image: bool | None = None
+ ) -> tuple[np.ndarray, float]:
+ """
+ Process the flat curve data dictionary into a 2D list structure matching the image dimensions.
+
+ Parameters
+ ----------
+ overide_channel : str | None, optional
+ Channel name to use instead of the instance default. Default is None.
+ convert_to_nm : bool, optional
+ Whether to convert the image data to nanometres. Default is True.
+ flip_image : bool | None, optional
+ Whether to flip the image vertically. Defaults to the instance setting if None.
+
+ Returns
+ -------
+ tuple[np.ndarray, float]
+ A 2D array representing the image data and the pixel-to-nm scaling factor.
+ """
+ # Get channel and flip_image parameters
+ channel = str(overide_channel) if overide_channel else str(self.channel)
+
+ if flip_image is None:
+ flip_image = bool(self.flip_image)
+
+ # Search through the namelist to find the .jpk-qi-image file
+ path_to_image = None
+ for file_name in self.list_of_all_paths:
+ if file_name.endswith(".jpk-qi-image"):
+ path_to_image = file_name
+ if path_to_image is None:
+ raise FileNotFoundError(f"{path_to_image} not found in JPK archive")
+
+ # Read the .jpk-qi-image file as bytes and load it using the existing jpk loading function
+ tif_bytes = self.qi_archive.read(path_to_image)
+
+ virtual_file = io.BytesIO(tif_bytes)
+ logger.info(f"Looking for channel {channel} in {path_to_image}")
+ return jpk._load_jpk(
+ virtual_file,
+ path_to_image,
+ channel=channel,
+ file_suffix=".jpk-qi-data",
+ config_path=self.config_path,
+ convert_to_nm=convert_to_nm,
+ flip_image=bool(flip_image),
+ )
+
+ def save_lite_data(self):
+ """Save a lite form of the data (e.g., the calculated image data) to H5."""
+ with h5py.File(self.h5_path, "a") as h5file:
+ # Save data required for reading the h5 file as a normal image file
+ meas_grp = h5file.require_group("Measurement_000")
+ # Save dimensions data
+ meas_grp.attrs["position-pattern.grid.ulength"] = self.size_x
+ meas_grp.attrs["position-pattern.grid.ilength"] = self.shape_x
+ meas_grp.attrs["position-pattern.grid.vlength"] = self.size_y
+ meas_grp.attrs["position-pattern.grid.jlength"] = self.shape_y
+ meas_grp.attrs["timing-settings.scanRate"] = 1.0 # Dummy value to satisfy reader
+
+ logger.info(f"Saving a hdf5 copy of the data {self.h5_path}")
+
+ h5_channels = [self.channel]
+ # Look for the jpk-qi-image file in the archive
+ path_to_image = None
+ for file_name in self.list_of_all_paths:
+ if file_name.endswith(".jpk-qi-image"):
+ path_to_image = file_name
+ break
+ # Add the channels which exist in the jpk-qi-image file
+ if path_to_image:
+ with self.qi_archive.open(path_to_image, "r") as image_file:
+ h5_channels += jpk._get_jpk_channels(
+ file=image_file, filename=self.filepath.stem, file_path=self.filepath / Path(path_to_image)
+ )
+ for i, h5_channel in enumerate(h5_channels):
+ # For each available channel, save the required data to the h5 file
+ # TODO make sure this metadata is accurate for the channels coming from the .jpk-qi-image file
+ chan_grp = meas_grp.require_group(f"Channel_{_make_num_min_characters(i)}")
+ # Extract name and retrace information from the channel name
+ if h5_channel and "_" in str(h5_channel):
+ base_name, trace_dir = str(h5_channel).rsplit("_", 1)
+ is_retrace = "true" if trace_dir.lower() == "retrace" else "false"
+ else:
+ base_name = h5_channel
+ is_retrace = "false"
+
+ # Add the necessary attributes to the channel group
+ chan_grp.attrs["channel.name"] = base_name.encode("utf-8")
+ chan_grp.attrs["retrace"] = is_retrace.encode("utf-8")
+ chan_grp.attrs["net-encoder.scaling.multiplier"] = 1.0
+ chan_grp.attrs["net-encoder.scaling.offset"] = 0.0
+
+ # Format name and reshape image (flattened frame stack)
+ dataset_name = h5_channel.split("_")[0].capitalize()
+ # Include all the channels including the calculated channel
+ # TODO make this slightly faster by remembering we have load a channel already but
+ # difficult cause of scaling
+ channel_image, _ = self.get_image(overide_channel=h5_channel, convert_to_nm=False, flip_image=False)
+ frame_stack = channel_image.flatten().reshape(-1, 1)
+
+ # Update/ replace the channels dataset
+ if dataset_name in chan_grp:
+ del chan_grp[dataset_name]
+ chan_grp.create_dataset(dataset_name, data=frame_stack)
+
+ def extract_dat_file(self, h5_datasets, h5_datasets_buffer, curve_num: int, direction: int, chan_name: str):
+ """
+ Extract the data from a .dat file in the JPK QI archive.
+
+ Applies the appropriate scaling and saves it to the internal data structure and h5 dataset if required.
+
+ Parameters
+ ----------
+ h5_datasets : dict
+ A dictionary containing the h5 datasets for each channel and segment direction, used for saving the
+ data.
+ h5_datasets_buffer : dict
+ A dictionary containing the buffer for each h5 dataset, used for temporary storage before writing to
+ the dataset.
+ curve_num : int
+ The curve number associated with the .dat file, parsed from the filename.
+ direction : int
+ The segment direction (0 or 1) associated with the .dat file, parsed from the filename.
+ chan_name : str
+ The channel name associated with the .dat file, parsed from the filename.
+ """
+ if chan_name in self.channel_scaling:
+ # Get data structures for this channel and segment
+ scale = self.channel_scaling[chan_name]
+ dat_path = f"index/{curve_num}/segments/{direction}/channels/{chan_name}.dat"
+ data_set = h5_datasets[f"Segment_{direction}"][chan_name]["Data"]
+ indices_set = h5_datasets[f"Segment_{direction}"][chan_name]["Indices"]
+ data_size = data_set.shape[0]
+ buf = h5_datasets_buffer[f"Segment_{direction}"][chan_name]
+ filled_size = self.points_for_channel_segment[direction][chan_name]
+ start_offset = self.current_offsets[direction][chan_name]
+
+ try:
+ with self.qi_archive.open(dat_path) as f:
+ # Read binary data as big-endian 32-bit integers
+ raw_bytes = f.read()
+
+ raw_array = np.frombuffer(raw_bytes, dtype=">i4")
+
+ # Apply scaling to convert raw values into real world values
+ segment_array = (raw_array * scale["multiplier"]) + scale["offset"]
+
+ # Update the current offset so it include the length of the data we have just read
+ self.current_offsets[direction][chan_name] += len(segment_array)
+
+ buf["Data"].append(segment_array)
+ if len(buf["Data"]) >= self.BUFFER_SIZE or curve_num == self.num_of_curves - 1:
+ buffered_data = np.concatenate(buf["Data"])
+ required_size = filled_size + len(buffered_data)
+ if required_size > data_size:
+ # Fetch and resize the existing dataset for this channel and segment to fit the new data
+ data_set.resize((required_size,))
+
+ # Add the buffer to the dataset
+ data_set[filled_size:required_size] = buffered_data
+ # Update the filled size for this channel and segment
+ self.points_for_channel_segment[direction][chan_name] = required_size
+ # Clear the buffer
+ buf["Data"].clear()
+
+ except KeyError:
+ self.failed_curves.add((curve_num, direction, chan_name))
+
+ # Limit the number of warnings to avoid spamming the logs
+ if len(self.failed_curves) < 10:
+ logger.warning(
+ f"Data file {dat_path} not found in archive. Skipping data for curve {curve_num}, "
+ f"direction {direction}, channel {chan_name}."
+ )
+ elif len(self.failed_curves) == 10:
+ logger.warning(
+ "Lots of missing files, further warnings will be suppressed. View summary at the end."
+ )
+
+ # Append the new index to the indices buffer
+ buf["Indices"].append(start_offset)
+
+ # If the indices buffer is full add it to the indices dataset and clear the buffer
+ if len(buf["Indices"]) > 0 and len(buf["Indices"]) % self.BUFFER_SIZE == 0:
+ indices_set[curve_num - self.BUFFER_SIZE + 1 : curve_num + 1] = buf["Indices"]
+ buf["Indices"].clear()
+
+ # Or if this is the last curve and there are still indices in the buffer
+ elif len(buf["Indices"]) > 0 and curve_num == self.num_of_curves - 1:
+ # Add the remaining indices to the indices dataset and clear the buffer
+ items_in_buffer = len(buf["Indices"])
+ indices_set[curve_num - items_in_buffer + 1 : curve_num + 1] = buf["Indices"]
+ buf["Indices"].clear()
+
+ else:
+ # Log if curve failed
+ self.failed_curves.add((curve_num, direction, chan_name))
+ if len(self.failed_curves) < 10: # Limit the number of warnings to avoid spamming the logs
+ logger.warning(
+ f"Channel {chan_name} not found in scaling information. Skipping data for curve {curve_num}, "
+ f"direction {direction}."
+ )
+
+ def extract_curve_metadata(self, curve_num: int, curve_work):
+ """
+ Extract the curve metadata from its header.properties file in the JPK QI archive and save to h5.
+
+ Parameters
+ ----------
+ curve_num : int
+ The curve number associated with the metadata.
+ curve_work : list
+ A list of tuples containing the search term for the metadata, the h5 dataset to save to,
+ and the buffer for that dataset.
+ """
+ meta_path = f"index/{curve_num}/header.properties"
+ raw_bytes = b""
+ try:
+ # Read metadata file as raw bytes
+ with self.qi_archive.open(meta_path) as f:
+ raw_bytes = f.read()
+ except KeyError:
+ self.failed_curves.add((curve_num, None, None))
+ # Limit the number of warnings to avoid spamming the logs
+ if len(self.failed_curves) < 10:
+ logger.warning(
+ f"Metadata file {meta_path} not found in archive. Skipping metadata for curve {curve_num}."
+ )
+ elif len(self.failed_curves) == 10:
+ logger.warning("Lots of missing files, further warnings will be suppressed. View summary at the end.")
+
+ for search_term, meta_set, meta_buffer in curve_work:
+ # Find the location of the metadata value in the raw bytes
+ start = raw_bytes.find(search_term)
+ # If found, extract the actual value
+ if start != -1:
+ start += len(search_term)
+ end = raw_bytes.find(b"\n", start)
+ value = raw_bytes[start:end].decode("utf-8").strip()
+ # Save a no data value if the search term is not found in the metadata file
+ else:
+ value = "No data"
+ if meta_buffer is not None:
+ meta_buffer.append(value)
+ if len(meta_buffer) >= self.BUFFER_SIZE or curve_num == self.num_of_curves - 1:
+ meta_set[curve_num - len(meta_buffer) + 1 : curve_num + 1] = meta_buffer
+ meta_buffer.clear()
+ else:
+ logger.error(
+ f"Metadata dataset for key {search_term.decode('utf-8')} not found when trying to save "
+ f"metadata for curve {curve_num}"
+ )
+
+ def extract_segment_metadata(self, curve_num: int, direction: int, seg_work):
+ """
+ Extract segment metadata from its header.properties file.
+
+ Parameters
+ ----------
+ curve_num : int
+ The curve number associated with the metadata.
+ direction : int
+ The segment direction (0 or 1) associated with the metadata.
+ seg_work : list
+ A list of tuples containing metadata extraction information.
+ """
+ meta_path = f"index/{curve_num}/segments/{direction}/segment-header.properties"
+ raw_content = b""
+ try:
+ with self.qi_archive.open(meta_path) as f:
+ raw_content = f.read()
+ except KeyError:
+ self.failed_curves.add((curve_num, direction, None))
+ if len(self.failed_curves) < 10: # Limit the number of warnings to avoid spamming the logs
+ logger.warning(
+ f"Metadata file {meta_path} not found in archive. Skipping metadata for curve {curve_num}, "
+ f"direction {direction}."
+ )
+ elif len(self.failed_curves) == 10:
+ logger.warning("Lots of missing files, further warnings will be suppressed. View summary at the end.")
+ for search_term, meta_set, meta_buffer in seg_work:
+ start = raw_content.find(search_term)
+ if start != -1:
+ start += len(search_term)
+ end = raw_content.find(b"\n", start)
+ value = raw_content[start:end].decode("utf-8").strip()
+ else:
+ value = "No data"
+ if meta_buffer is not None:
+ meta_buffer.append(value)
+ if len(meta_buffer) >= self.BUFFER_SIZE or curve_num == self.num_of_curves - 1:
+ idx = curve_num * 2 + direction
+ meta_set[idx - len(meta_buffer) + 1 : idx + 1] = meta_buffer
+ meta_buffer.clear()
+ else:
+ logger.error(
+ f"Metadata dataset for key {search_term.decode('utf-8')} not found when trying to save "
+ f"metadata for curve {curve_num}, direction {direction}"
+ )
+
+ def setup_h5_structure(self, h5file):
+ """
+ Set up structure in the h5 file for saving curve data and metadata.
+
+ Parameters
+ ----------
+ h5file : h5py.File
+ The h5 file in which to set up the structure.
+
+ Returns
+ -------
+ global_meta_group : h5py.Group
+ The h5 group for storing global metadata.
+ h5_datasets : dict
+ A dictionary containing the h5 datasets for storing curve data.
+ h5_meta_datasets : dict
+ A dictionary containing the h5 datasets for storing metadata.
+ h5_datasets_buffer : dict
+ A dictionary containing buffers for the curve data datasets for temporary pre-writing storage.
+ h5_meta_datasets_buffer : dict
+ A dictionary containing buffers for the metadata datasets for temporary pre-writing storage.
+ """
+ # Create the main group for the QI curve data that all the curve data will be in
+ qi_group = h5file.require_group("QI_Curve_Data")
+
+ # Establish empty groups for global metadata and curve metadata
+ global_meta_group = qi_group.require_group("Global_Metadata")
+ curves_meta_group = qi_group.require_group("Curve_Metadata")
+ curves_group = qi_group.require_group("Curves")
+
+ curve_groups = {"Data": {}, "Indices": {}}
+ h5_datasets = {}
+ h5_meta_datasets = {}
+ h5_datasets_buffer = {}
+ h5_meta_datasets_buffer = {}
+ for key in self.changing_curve_keys:
+ h5_meta_datasets[f"curve.{key}"] = curves_meta_group.create_dataset(
+ name=f"curve.{key}",
+ shape=(self.num_of_curves,),
+ maxshape=(None,),
+ chunks=self.META_CHUNKSIZE,
+ dtype=h5py.string_dtype(encoding="utf-8"),
+ )
+ h5_meta_datasets_buffer[f"curve.{key}"] = []
+ for key in self.changing_segment_keys:
+ h5_meta_datasets[f"segment.{key}"] = curves_meta_group.create_dataset(
+ name=f"segment.{key}",
+ shape=(self.num_of_curves * 2,),
+ maxshape=(None,),
+ chunks=self.META_CHUNKSIZE,
+ dtype=h5py.string_dtype(encoding="utf-8"),
+ )
+ h5_meta_datasets_buffer[f"segment.{key}"] = []
+
+ for direction in range(2):
+ # For each segment direction, establish necessary group structure that will contain each channel dataset
+ seg_name = f"Segment_{direction}"
+ dir_group = curves_group.require_group(seg_name)
+ h5_datasets[seg_name] = {}
+ h5_datasets_buffer[seg_name] = {}
+ # Create the Data and Indices subfolders and store their references
+ curve_groups["Data"][seg_name] = dir_group.require_group("Data")
+ curve_groups["Indices"][seg_name] = dir_group.require_group("Indices")
+ for chan in self.segment_channels:
+ h5_datasets[seg_name][chan["name"]] = {}
+ # For each channel, create an empty dataset
+ h5_datasets[seg_name][chan["name"]]["Data"] = curve_groups["Data"][seg_name].create_dataset(
+ name=chan["name"],
+ shape=(self.points_for_channel_segment[direction][chan["name"]],),
+ maxshape=(None,),
+ chunks=(self.DATA_CHUNKSIZE,),
+ dtype=np.float32,
+ )
+ h5_datasets[seg_name][chan["name"]]["Indices"] = curve_groups["Indices"][seg_name].create_dataset(
+ name=chan["name"],
+ shape=(self.num_of_curves + 1,),
+ maxshape=(None,),
+ chunks=(self.INDICES_CHUNKSIZE,),
+ dtype=np.int32,
+ )
+ h5_datasets_buffer[seg_name][chan["name"]] = {"Data": [], "Indices": []}
+ return global_meta_group, h5_datasets, h5_meta_datasets, h5_datasets_buffer, h5_meta_datasets_buffer
+
+ def get_saving_context(self):
+ """
+ Return the appropriate context manager for saving the data based on the save_as_h5 attribute.
+
+ If save_as_h5 is True, it returns a context manager for an h5 file. Otherwise, it returns a null context.
+
+ Returns
+ -------
+ contextlib.AbstractContextManager
+ The context manager for saving the data.
+ """
+ if self.save_as_h5:
+ return h5py.File(self.h5_path, "a")
+ return nullcontext()
+
+ def parse_dimension_data(self):
+ """Parse dimension data and calculate the pixel to nanometer scaling factor."""
+ # Extract both real size and pixel dimensions from the metadata
+ for key, value in self.top_level_meta.items():
+ if key.endswith(".ulength"):
+ self.size_x = float(value)
+ if key.endswith(".vlength"):
+ self.size_y = float(value)
+ if key.endswith(".ilength"):
+ self.shape_x = int(value)
+ if key.endswith(".jlength"):
+ self.shape_y = int(value)
+
+ # Log an error if any of these do not exist
+ if None in [self.size_x, self.size_y, self.shape_x, self.shape_y]:
+ logger.error(f"Incomplete dimension data in {self.filepath}")
+
+ # Calculate the pixel to nano metre scaling as an average of the scale for each direction
+ pixel_to_nm_scaling_factor_x = self.size_x / self.shape_x * 1e9 if self.shape_x > 0 else 1.0
+ pixel_to_nm_scaling_factor_y = self.size_y / self.shape_y * 1e9 if self.shape_y > 0 else 1.0
+ self.px2nm = (pixel_to_nm_scaling_factor_x + pixel_to_nm_scaling_factor_y) / 2
+
+ # Establish number of curves
+ self.num_of_curves = self.shape_x * self.shape_y
+
+ def extract_global_metadata(self):
+ """Extract global metadata and populate top level metadata dictionary and segment channels list."""
+ # Load the metadata from the global properties file
+ if "header.properties" in self.list_of_all_paths:
+ with self.qi_archive.open("header.properties") as archive_meta_file:
+ props = javaproperties.load(archive_meta_file)
+
+ # Add data from the main header to the top level metadata with a prefix to avoid key clashes
+ for key, value in props.items():
+ self.top_level_meta[f"main-header.{key}"] = value
+ else:
+ logger.error(f"File {self.filepath} does not contain essential metadata and cannot be loaded")
+
+ # Load the metadata from the shared header
+ if "shared-data/header.properties" in self.list_of_all_paths:
+ with self.qi_archive.open("shared-data/header.properties") as shared_data_file:
+ shared_meta = javaproperties.load(shared_data_file)
+ channel_idx = 0
+
+ # Add all the data from the shared header to the top level metadata with a prefix to avoid key clashes
+ for key, value in shared_meta.items():
+ self.top_level_meta[f"shared-data.{key}"] = value
+
+ # Collect channel data from the shared metadata
+ while f"lcd-info.{channel_idx}.channel.name" in shared_meta:
+ channel_dict = {}
+ # Calculate and store the offset and multiplier to convert raw values into real world values
+ multiplier, offset, unit = _get_channel_scaling(shared_meta, channel_idx)
+ channel_dict["name"] = shared_meta[f"lcd-info.{channel_idx}.channel.name"]
+ channel_dict["offset"] = offset
+ channel_dict["multiplier"] = multiplier
+ channel_dict["unit"] = unit
+ # Add the channel dict to the list
+ self.segment_channels.append(channel_dict)
+ # Increment the channel index to look for the next channel
+ channel_idx += 1
+ else:
+ logger.error(f"File {self.filepath} does not contain essential channel metadata and cannot be loaded")
+
+ if len(self.segment_channels) == 0:
+ logger.error("Could not find channels for segments")
+
+ # Create a lookup for channel name to unit to be returned
+ self.channels_units = {seg_chan["name"]: seg_chan["unit"] for seg_chan in self.segment_channels}
+ # Lookup map for binary scaling
+ self.channel_scaling = {chan["name"]: chan for chan in self.segment_channels}
+
+ def close(self):
+ """Close the ZIP archive when done to free up system resources."""
+ self.qi_archive.close()
+ self.image = None
+ self.curve_data = None
+ self.curve_meta = {}
+ self.segment_meta = {}
+ self.top_level_meta = {}
+ self.failed_curves = set()
+ self.points_for_channel_segment = {}
+ self.list_of_all_paths = []
+
+
+def load_jpk_data(filepath: str | Path, channel: str, cached_data: dict, save_as_h5: bool = False) -> AFMLoad:
+ """
+ Load the JPK QI data using the JPKQILoader.
+
+ Parameters
+ ----------
+ filepath : str | Path
+ Path to the JPK QI file.
+ channel : str
+ The channel to load from the file.
+ cached_data : dict
+ Cached data to avoid reloading heavy data.
+ save_as_h5 : bool, optional
+ Whether to save the loaded data as an h5 file for faster future loading. Default is False.
+
+ Returns
+ -------
+ AFMLoad
+ The loaded JPK QI data.
+ """
+ if "jpk_qi_loader" not in cached_data:
+ cached_data["jpk_qi_loader"] = JPKQILoader(filepath=filepath, channel=channel, save_as_h5=save_as_h5)
+ return cached_data["jpk_qi_loader"].load(channel=channel, save_as_h5=save_as_h5)
+
+
+def get_jpk_data_channels(filepath: str | Path, cached_data: dict) -> list[str]:
+ """
+ Get the available channels in the JPK QI data.
+
+ Parameters
+ ----------
+ filepath : str | Path
+ Path to the JPK QI file.
+ cached_data : dict
+ Cached data to avoid reloading heavy data.
+
+ Returns
+ -------
+ list[str]
+ A list of available channels in the JPK QI data.
+ """
+ if "jpk_qi_loader" not in cached_data:
+ cached_data["jpk_qi_loader"] = JPKQILoader(filepath=filepath)
+ return cached_data["jpk_qi_loader"].get_available_channels()
diff --git a/AFMReader/logging.py b/AFMReader/logging.py
index 5d3a70f..1655df7 100644
--- a/AFMReader/logging.py
+++ b/AFMReader/logging.py
@@ -7,7 +7,7 @@
logger.remove()
# Set the format to have blue time, green file, module, function and line, and white message
logger.add(
- sys.stderr,
+ lambda msg: sys.stderr.write(msg), # pylint: disable=unnecessary-lambda
colorize=True,
format="{time:HH:mm:ss} | {level} |"
"{file}:{module}:"
diff --git a/AFMReader/raw_bin.py b/AFMReader/raw_bin.py
new file mode 100644
index 0000000..40bcd24
--- /dev/null
+++ b/AFMReader/raw_bin.py
@@ -0,0 +1,120 @@
+"""Module to decode and load .bin AFM files into Python Numpy arrays."""
+
+import math
+from pathlib import Path
+
+import numpy as np
+
+from AFMReader.data_classes import AFMLoad
+
+from .logging import logger
+
+# pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals,fixme
+
+DTYPE_MAP = {
+ "IEEE double": np.float64,
+ "DBL": np.float64,
+ "IEEE single": np.float32,
+ "SGL": np.float32,
+ "U32": np.uint32,
+ "I32": np.int32,
+ "U16": np.uint16,
+ "I16": np.int16,
+ "U8": np.uint8,
+ "I8": np.int8,
+ "float64": np.float64,
+ "float32": np.float32,
+ "int32": np.int32,
+}
+
+
+def load_bin(
+ filepath: str | Path,
+ data_type: str,
+ offset_bytes: int,
+ size_x: float | None = None,
+ size_y: float | None = None,
+ shape_x: int | None = None,
+ shape_y: int | None = None,
+ z_scaling: float = 1.0,
+) -> AFMLoad:
+ """
+ Load image from binary file. Parameters to interpret the binary file must be provided.
+
+ Parameters
+ ----------
+ filepath : str | Path
+ Path to the binary file.
+ data_type : str
+ Data type of the binary file.
+ offset_bytes : int
+ Number of bytes to skip at the beginning of the file.
+ size_x : float, optional
+ Size of the image in the x direction (default is None).
+ size_y : float, optional
+ Size of the image in the y direction (default is None).
+ shape_x : int, optional
+ Number of pixels in the x direction (default is None).
+ shape_y : int, optional
+ Number of pixels in the y direction (default is None).
+ z_scaling : float, optional
+ Scaling factor for the z values (default is 1.0).
+
+ Returns
+ -------
+ AFMLoad
+ An AFMLoad object containing the image and its pixel to nanometre scaling value.
+ """
+ filepath = Path(filepath)
+ dt_key = str(data_type).strip()
+ shape_x = None if shape_x == 0 else shape_x
+ shape_y = None if shape_y == 0 else shape_y
+
+ if dt_key in DTYPE_MAP:
+ np_dtype = DTYPE_MAP[dt_key]
+ else:
+ logger.warning(f"Unknown data type '{dt_key}'. Defaulting to float64.")
+ np_dtype = np.float64
+ with filepath.open("rb") as f:
+ f.seek(offset_bytes)
+ flat_data = np.fromfile(f, dtype=np_dtype)
+ if None in [shape_x, shape_y]:
+ dimension = int(math.sqrt(len(flat_data)))
+ shape_x, shape_y = dimension, dimension
+ assert shape_x is not None and shape_y is not None # noqa: PT018
+ assert size_x is not None and size_y is not None # noqa: PT018
+ if shape_x * shape_y != len(flat_data):
+ logger.error(f"Loading binary file {filepath.stem} did not receive a shape and is not square")
+ image = flat_data.reshape((shape_x, shape_y))
+ image *= z_scaling
+ pixel_to_nm_scaling_factor_x = size_x / shape_x if shape_x > 0 else 1.0
+ pixel_to_nm_scaling_factor_y = size_y / shape_y if shape_y > 0 else 1.0
+ px2nm = (pixel_to_nm_scaling_factor_x + pixel_to_nm_scaling_factor_y) / 2
+ return AFMLoad(image=image, px2nm=px2nm)
+
+
+def get_bin_channels():
+ """
+ Get the list of channels available in the binary file.
+
+ Since binary files do not have a standard structure,
+ this function returns an empty list (as no standard channels are available) and the expected keyword
+ arguments for loading a binary file.
+
+ Returns
+ -------
+ list
+ Empty list.
+ dict
+ Dictionary of expected keyword arguments for loading a binary file.
+ """
+ kwarg_types = {
+ "data_type": (str, DTYPE_MAP.keys()),
+ "offset_bytes": int,
+ "size_x": float,
+ "size_y": float,
+ "shape_x": int,
+ "shape_y": int,
+ "z_scaling": float,
+ }
+ return [], kwarg_types
diff --git a/AFMReader/spm.py b/AFMReader/spm.py
index 0a803e6..8bc4355 100644
--- a/AFMReader/spm.py
+++ b/AFMReader/spm.py
@@ -5,6 +5,7 @@
import numpy as np
import pySPM
+from AFMReader.data_classes import AFMLoad
from AFMReader.logging import logger
logger.enable(__package__)
@@ -54,7 +55,7 @@ def spm_pixel_to_nm_scaling(filename: str, channel_data: pySPM.SPM.SPM_image) ->
return pixel_to_nm_scaling
-def load_spm(file_path: Path | str, channel: str) -> tuple:
+def load_spm(file_path: Path | str, channel: str) -> AFMLoad:
"""
Extract image and pixel to nm scaling from the Bruker .spm file.
@@ -67,8 +68,8 @@ def load_spm(file_path: Path | str, channel: str) -> tuple:
Returns
-------
- tuple(np.ndarray, float)
- A tuple containing the image and its pixel to nanometre scaling value.
+ AFMLoad
+ An AFMLoad object containing the image and its pixel to nanometre scaling value.
Raises
------
@@ -83,7 +84,9 @@ def load_spm(file_path: Path | str, channel: str) -> tuple:
Sensor'.
>>> from AFMReader.spm import load_spm
- >>> image, pixel_to_nm = load_spm(file_path="path/to/file.spm", channel="Height")
+ >>> afm_load = load_spm(file_path="path/to/file.spm", channel="Height")
+ >>> image = afm_load.image
+ >>> pixel_to_nm = afm_load.px2nm
```
"""
logger.info(f"Loading image from : {file_path}")
@@ -109,4 +112,32 @@ def load_spm(file_path: Path | str, channel: str) -> tuple:
raise ValueError(f"'{channel}' not in {file_path.suffix} channel list: {labels}") from e
raise e
- return (image, spm_pixel_to_nm_scaling(filename, channel_data))
+ return AFMLoad(image=image, px2nm=spm_pixel_to_nm_scaling(filename, channel_data))
+
+
+def get_spm_channels(file_path: Path | str) -> list:
+ """
+ Get the list of channels available in the .spm file.
+
+ Parameters
+ ----------
+ file_path : Path or str
+ Path to the .spm file.
+
+ Returns
+ -------
+ list
+ List of available channels.
+ """
+ labels = []
+ file_path = Path(file_path)
+ filename = file_path.stem
+ try:
+ scan = pySPM.Bruker(file_path)
+ except FileNotFoundError:
+ logger.error(f"[{filename}] File not found : {file_path}")
+ raise
+ for channel_option in [layer[b"@2:Image Data"][0] for layer in scan.layers]:
+ channel_name = channel_option.decode("latin1").split('"')[1]
+ labels.append(channel_name)
+ return labels
diff --git a/AFMReader/stp.py b/AFMReader/stp.py
index ffda53b..2eeb5c5 100644
--- a/AFMReader/stp.py
+++ b/AFMReader/stp.py
@@ -5,6 +5,7 @@
import numpy as np
+from AFMReader.data_classes import AFMLoad
from AFMReader.io import read_double
from AFMReader.logging import logger
@@ -13,9 +14,7 @@
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
-def load_stp( # noqa: C901 (ignore too complex)
- file_path: Path | str, header_encoding: str = "latin-1"
-) -> tuple[np.ndarray, float]:
+def load_stp(file_path: Path | str, header_encoding: str = "latin-1") -> AFMLoad: # noqa: C901 (ignore too complex)
"""
Load image from STP files.
@@ -28,8 +27,8 @@ def load_stp( # noqa: C901 (ignore too complex)
Returns
-------
- tuple[np.ndarray, float]
- A tuple containing the image and its pixel to nanometre scaling value.
+ AFMLoad
+ An AFMLoad object containing the image and its pixel to nanometre scaling value.
Raises
------
@@ -108,4 +107,4 @@ def load_stp( # noqa: C901 (ignore too complex)
raise e
logger.info(f"[{filename}] : Extracted image.")
- return (image, pixel_to_nm_scaling)
+ return AFMLoad(image=image, px2nm=pixel_to_nm_scaling)
diff --git a/AFMReader/top.py b/AFMReader/top.py
index 6219a68..c555e23 100644
--- a/AFMReader/top.py
+++ b/AFMReader/top.py
@@ -5,6 +5,7 @@
import numpy as np
+from AFMReader.data_classes import AFMLoad
from AFMReader.io import read_int16
from AFMReader.logging import logger
@@ -14,9 +15,7 @@
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
# pylint: disable=too-many-branches
-def load_top( # noqa: C901 (ignore too complex)
- file_path: Path | str, header_encoding: str = "latin-1"
-) -> tuple[np.ndarray, float]:
+def load_top(file_path: Path | str, header_encoding: str = "latin-1") -> AFMLoad: # noqa: C901 (ignore too complex)
"""
Load image from TOP files.
@@ -29,8 +28,8 @@ def load_top( # noqa: C901 (ignore too complex)
Returns
-------
- tuple[np.ndarray, float]
- A tuple containing the image and its pixel to nanometre scaling value.
+ AFMLoad
+ An AFMLoad object containing the image and its pixel to nanometre scaling value.
Raises
------
@@ -117,4 +116,4 @@ def load_top( # noqa: C901 (ignore too complex)
raise e
logger.info(f"[{filename}] : Extracted image.")
- return (image, pixel_to_nm_scaling)
+ return AFMLoad(image=image, px2nm=pixel_to_nm_scaling)
diff --git a/AFMReader/topostats.py b/AFMReader/topostats.py
index 9b3ac44..46ec2d2 100644
--- a/AFMReader/topostats.py
+++ b/AFMReader/topostats.py
@@ -1,30 +1,32 @@
"""For decoding and loading .topostats (HDF5 format) AFM file format into Python Nympy arrays."""
from pathlib import Path
-from typing import Any
import h5py
from packaging.version import parse as parse_version
+from AFMReader.data_classes import AFMLoad
from AFMReader.io import unpack_hdf5
from AFMReader.logging import logger
logger.enable(__package__)
-def load_topostats(file_path: Path | str) -> dict[str, Any]:
+def load_topostats(file_path: Path | str, channel: str) -> AFMLoad:
"""
Extract image and pixel to nm scaling from the .topostats (HDF5 format) file.
Parameters
----------
- file_path : Path or str
+ file_path : Path | str
Path to the .topostats file.
+ channel : str
+ The channel to load.
Returns
-------
- dict[str, Any]
- A dictionary containing the image, its pixel to nm scaling factor and nested Numpy arrays representing the
+ AFMLoad
+ An AFMLoad object containing the image, its pixel to nm scaling factor, and nested Numpy arrays representing the
analyses performed on the data.
Raises
@@ -34,7 +36,9 @@ def load_topostats(file_path: Path | str) -> dict[str, Any]:
Examples
--------
- >>> image, pixel_to_nm_scaling = load_topostats("path/to/topostats_file.topostats")
+ >>> afm_load = load_topostats("path/to/topostats_file.topostats", channel="image")
+ >>> image = afm_load.image
+ >>> pixel_to_nm_scaling = afm_load.px2nm
"""
logger.info(f"Loading image from : {file_path}")
file_path = Path(file_path)
@@ -58,4 +62,27 @@ def load_topostats(file_path: Path | str) -> dict[str, Any]:
raise e
logger.info(f"[{filename}] : Extracted .topostats dictionary.")
- return data
+ try:
+ image = data.pop(channel)
+ pixel_to_nanometre_scaling_factor = data.pop("pixel_to_nm_scaling")
+ except KeyError as exc:
+ image_keys = ["image", "image_original"]
+ topostats_keys = list(data.keys())
+ raise ValueError(
+ f"'{channel}' not in available image keys: " f"{[im for im in image_keys if im in topostats_keys]}"
+ ) from exc
+
+ # Analyses are stored to metadata - this might be a bit clunky and potentially should be stored to their own attr
+ return AFMLoad(image=image, px2nm=pixel_to_nanometre_scaling_factor, metadata=data)
+
+
+def get_topostats_channels() -> list[str]:
+ """
+ Get the available channels for a .topostats file.
+
+ Returns
+ -------
+ list[str]
+ A list of available channels in the .topostats file.
+ """
+ return ["image", "image_original"]
diff --git a/README.md b/README.md
index c574127..9d48ca7 100644
--- a/README.md
+++ b/README.md
@@ -35,10 +35,12 @@ Supported file formats
| `.ibw` | [WaveMetrics](https://www.wavemetrics.com/) |
| `.jpk-qi-image` | [Bruker](https://www.bruker.com/) |
| `.jpk` | [Bruker](https://www.bruker.com/) |
+| `.jpk-qi-data` | [Bruker](https://www.bruker.com/) |
| `.spm` | [Bruker's Format](https://www.bruker.com/) |
| `.stp` | [WSXM AFM software files](http://www.wsxm.eu) |
| `.top` | `.stp` variant |
| `.topostats` | [TopoStats](https://github.com/AFM-SPM/TopoStats) |
+| `.bin` | Unspecificied binary file format |
Support for the following additional formats is planned. Some of these are already supported in TopoStats and are
awaiting refactoring to move their functionality into AFMReader these are denoted in bold below.
@@ -116,9 +118,9 @@ from AFMReader.ibw import load_ibw
image, pixel_to_nanometre_scaling_factor = load_ibw(file_path="./my_ibw_file.ibw", channel="HeightTracee")
```
-### .jpk
+### .jpk and .jpk-qi-image
-You can open `.jpk` files using the `load_jpk` function. Just pass in the path
+You can open `.jpk` and `.jpk-qi-image` files using the `load_jpk` function. Just pass in the path
to the file and the channel name you want to use. (If in doubt, use `height_trace` or `measuredHeight_trace`).
```python
@@ -127,6 +129,19 @@ from AFMReader.jpk import load_jpk
image, pixel_to_nanometre_scaling_factor = load_jpk(file_path="./my_jpk_file.jpk", channel="height_trace")
```
+### .jpk-qi-data
+
+You can open `.jpk-qi-data` files using the `jpk_qi_loader` class. Just pass in the path to the file
+and the channel name you want to use. Then call the `my_jpk_qi_loader.load()` method. If in doubt,
+use `height_trace` or `measuredHeight_trace`.
+
+```python
+from AFMReader.jpk_qi import jpk_qi_loader
+
+my_jpk_qi_loader = jpk_qi_loader(file_path="./my_jpk_qi_data_file.jpk-qi-data", channel="height_trace")
+image, pixel_to_nanometre_scaling_factor, force_curves = my_jpk_qi_loader.load()
+```
+
### .h5-jpk
You can open `.h5-jpk` files using the `load_h5jpk` function. Just pass in the path
@@ -138,7 +153,16 @@ Note: Since `.h5-jpk` stores timeseries AFM data a dictionary of timestamps for
```python
from AFMReader.h5_jpk import load_h5jpk
-frames, pixel_to_nanometre_scaling_factor, timestamp_dict = load_h5jpk(file_path="./my_jpk_file.jpk", channel="height_trace")
+frames, pixel_to_nanometre_scaling_factor, timestamp_dict = load_h5jpk(file_path="./my_jpk_file.h5-jpk", channel="height_trace")
+```
+
+If your `.h5-jpk` file was created from a `.jpk-qi-data` file, then the curve data can be read like so. Note that reading
+force curves like this will keep the file open as the force curves are lazy loaded from your hard drive.
+
+```python
+from AFMReader.h5_jpk import load_h5jpk
+
+frames, pixel_to_nanometre_scaling_factor, timestamp_dict, force_curves = load_h5jpk(file_path="./my_jpk_file.h5-jpk", channel="height_trace")
```
### .stp
@@ -163,6 +187,28 @@ from AFMReader.top import load_top
image, pixel_to_nanometre_scaling_factor = load_top(file_path="./my_top_file.top")
```
+### .bin
+
+You can open unspecified binary files using the `load_bin` function. You must supply the path
+to the file, the data type, the byte offset where the image data begins, and the physical dimensions
+of the scan. Supported `data_type` values include `"IEEE double"`, `"IEEE single"`, `"float64"`,
+`"float32"`, `"I32"`, `"U32"`, `"I16"`, `"U16"`, `"I8"`, and `"U8"`.
+
+```python
+from AFMReader.bin import load_bin
+
+image, pixel_to_nanometre_scaling_factor = load_bin(
+ filepath="./my_binary_file.bin",
+ data_type="IEEE double",
+ offset_bytes=0,
+ size_x=1000.0, # physical width in nm
+ size_y=1000.0, # physical height in nm
+ shape_x=512, # pixels along x
+ shape_y=512, # pixels along y
+ z_scaling=1.0, # optional z-axis scaling factor
+)
+```
+
## Contributing
Bug reports and feature requests are welcome. Please search for existing issues, if none relating to your bug/feature
diff --git a/pyproject.toml b/pyproject.toml
index f4cb025..cc0051e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,6 +41,7 @@ dependencies = [
"pySPM",
"tifffile",
"ruamel.yaml",
+ "javaproperties"
]
[project.optional-dependencies]
diff --git a/tests/test_asd.py b/tests/test_asd.py
index 3f411df..a85bbf0 100644
--- a/tests/test_asd.py
+++ b/tests/test_asd.py
@@ -19,19 +19,30 @@
)
def test_load_asd(file_name: str, channel: str, number_of_frames: int, pixel_to_nm_scaling: float) -> None:
"""Test the normal operation of loading a .asd file."""
- result_frames = list
- result_pixel_to_nm_scaling = float
- result_metadata = dict
-
file_path = RESOURCES / file_name
- result_frames, result_pixel_to_nm_scaling, result_metadata = asd.load_asd(file_path, channel)
+ afm_load = asd.load_asd(file_path, channel)
- assert len(result_frames) == number_of_frames # type: ignore
- assert result_pixel_to_nm_scaling == pixel_to_nm_scaling
- assert isinstance(result_metadata, dict)
+ assert len(afm_load.image) == number_of_frames # type: ignore
+ assert afm_load.px2nm == pixel_to_nm_scaling
+ assert isinstance(afm_load.metadata, dict)
def test_load_asd_file_not_found() -> None:
"""Ensure FileNotFound error is raised."""
with pytest.raises(FileNotFoundError):
asd.load_asd("nonexistant_file.asd", channel="TP")
+
+
+@pytest.mark.parametrize(
+ ("file_name", "expected_channels"),
+ [
+ pytest.param("sample_0.asd", ["TP", "PH"], id="sample_0.asd"),
+ pytest.param("sample_1.asd", ["TP", "PH"], id="sample_1.asd"),
+ pytest.param("extra_sample.asd", ["TP", "PH"], id="extra_sample.asd"),
+ ],
+)
+def test_get_asd_channels(file_name: str, expected_channels: list[str]) -> None:
+ """Test get_asd_channels."""
+ file_path = RESOURCES / file_name
+ channels = asd.get_asd_channels(file_path)
+ assert sorted(channels) == sorted(expected_channels)
diff --git a/tests/test_general_loader.py b/tests/test_general_loader.py
index b04547d..ccc0ac3 100644
--- a/tests/test_general_loader.py
+++ b/tests/test_general_loader.py
@@ -1,6 +1,8 @@
"""Test the general loader module."""
+import re
from pathlib import Path
+from typing import Any
import numpy as np
import pytest
@@ -126,23 +128,19 @@
),
],
)
-def test_load(caplog: pytest.LogCaptureFixture, filepath: Path, channel: str, error: bool, message: str) -> None:
+def test_load(capsys: pytest.CaptureFixture, filepath: Path, channel: str, error: bool, message: str) -> None:
"""Test loading of all (asd, gwy, ibw, jpk, spm, stp, top, topostats) filetypes."""
loader = general_loader.LoadFile(filepath, channel)
-
- image, px2nm = loader.load()
-
- if not error:
- # check array and px2nm returned
- assert isinstance(image, np.ndarray)
- assert isinstance(px2nm, float)
+ if error:
+ with pytest.raises(ValueError, match=re.escape(message)):
+ loader.load()
else:
- # check when channel wrong
- assert isinstance(image, ValueError)
- assert px2nm is None
-
+ afm_load = loader.load()
+ assert isinstance(afm_load.image, np.ndarray)
+ assert isinstance(afm_load.px2nm, float)
# check output logs
- assert message in caplog.text
+ captured = capsys.readouterr()
+ assert message in captured.err
@pytest.mark.parametrize(
@@ -159,5 +157,122 @@ def test_load_filenotfounderror(filepath: Path) -> None:
loader = general_loader.LoadFile(filepath, "channel")
with pytest.raises(FileNotFoundError) as execinfo: # noqa: PT012
- _, _ = loader.load()
+ loader.load()
assert "[not_a_real_file] FileNotFoundError" in execinfo.value
+
+
+@pytest.mark.parametrize(
+ ("file_name", "expected"),
+ [
+ pytest.param("sample_0.asd", ["TP", "PH"], id="asd"),
+ pytest.param(
+ "sample_0.gwy",
+ [
+ "ZSensor",
+ "Peak Force Error",
+ "Stiffness",
+ "LogStiffness",
+ "Adhesion",
+ "Deformation",
+ "Dissipation",
+ "Height",
+ ],
+ id="gwy",
+ ),
+ pytest.param(
+ "sample_0.ibw",
+ [
+ "HeightTracee",
+ "HeightRetrace",
+ "ZSensorTrace",
+ "ZSensorRetrace",
+ "UserIn0Trace",
+ "UserIn0Retrace",
+ "UserIn1Trace",
+ "UserIn1Retrace",
+ ],
+ id="ibw",
+ ),
+ pytest.param(
+ "sample_0.jpk",
+ {
+ "height_retrace": 1,
+ "measuredHeight_retrace": 2,
+ "amplitude_retrace": 3,
+ "phase_retrace": 4,
+ "error_retrace": 5,
+ "height_trace": 6,
+ "measuredHeight_trace": 7,
+ "amplitude_trace": 8,
+ "phase_trace": 9,
+ "error_trace": 10,
+ },
+ id="jpk",
+ ),
+ pytest.param(
+ "sample_0.jpk-qi-image",
+ {
+ "measuredHeight_trace": 3,
+ "vDeflection_trace": 2,
+ "adhesion_trace": 4,
+ "height_trace": 5,
+ "slope_trace": 6,
+ },
+ id="jpk-qi-image",
+ ),
+ pytest.param(
+ "sample_0.spm",
+ [
+ "Height Sensor",
+ "Peak Force Error",
+ "DMTModulus",
+ "LogDMTModulus",
+ "Adhesion",
+ "Deformation",
+ "Dissipation",
+ "Height",
+ ],
+ id="spm",
+ ),
+ pytest.param(
+ "sample_0.h5-jpk",
+ [
+ "error_trace",
+ "height_trace",
+ "phase_retrace",
+ "height_retrace",
+ "measuredheight_trace",
+ "error_retrace",
+ "amplitude_trace",
+ "amplitude_retrace",
+ "phase_trace",
+ ],
+ id="h5-jpk sample_0",
+ ),
+ pytest.param(
+ "sample_0_1.topostats",
+ ["image", "image_original"],
+ id="topostats 0.1",
+ ),
+ pytest.param(
+ "sample_0_2.topostats",
+ ["image", "image_original"],
+ id="topostats 0.2",
+ ),
+ ],
+)
+def test_get_available_channels_all_formats(file_name: str, expected: Any) -> None:
+ """Test get_available_channels for all formats."""
+ file_path = RESOURCES / file_name
+ loader = general_loader.LoadFile(file_path, channel="")
+ channels = loader.get_available_channels()
+
+ if isinstance(expected, list):
+ assert sorted(channels) == sorted(expected)
+ elif isinstance(expected, tuple) and len(expected) == 2:
+ assert isinstance(channels, tuple)
+ assert len(channels) == 2
+ assert channels[0] == expected[0]
+ assert channels[1] == expected[1]
+ else:
+ assert channels == expected
diff --git a/tests/test_gwy.py b/tests/test_gwy.py
index 9b81a33..eb8e942 100644
--- a/tests/test_gwy.py
+++ b/tests/test_gwy.py
@@ -16,12 +16,12 @@ def test_load_gwy() -> None:
"""Test the normal operation of loading a .gwy file."""
channel = "ZSensor"
file_path = RESOURCES / "sample_0.gwy"
- result_image, result_pixel_to_nm_scaling = gwy.load_gwy(file_path, channel=channel)
- assert isinstance(result_image, np.ndarray)
- assert result_image.shape == (512, 512)
- assert result_image.sum() == pytest.approx(33836850.232917726)
- assert isinstance(result_pixel_to_nm_scaling, float)
- assert result_pixel_to_nm_scaling == pytest.approx(0.8468632812499975)
+ afm_load = gwy.load_gwy(file_path, channel=channel)
+ assert isinstance(afm_load.image, np.ndarray)
+ assert afm_load.image.shape == (512, 512)
+ assert afm_load.image.sum() == pytest.approx(33836850.232917726)
+ assert isinstance(afm_load.px2nm, float)
+ assert afm_load.px2nm == pytest.approx(0.8468632812499975)
def test_gwy_read_object() -> None:
@@ -117,3 +117,30 @@ def test_load_gwy_file_not_found() -> None:
"""Ensure FileNotFound error is raised."""
with pytest.raises(FileNotFoundError):
gwy.load_gwy("nonexistant_file.gwy", channel="TP")
+
+
+@pytest.mark.parametrize(
+ ("file_name", "expected_channels"),
+ [
+ pytest.param(
+ "sample_0.gwy",
+ [
+ "ZSensor",
+ "Peak Force Error",
+ "Stiffness",
+ "LogStiffness",
+ "Adhesion",
+ "Deformation",
+ "Dissipation",
+ "Height",
+ ],
+ id="sample_0.gwy",
+ ),
+ ],
+)
+def test_get_gwy_channels(file_name: str, expected_channels: list[str]) -> None:
+ """Test get_gwy_channels."""
+ file_path = RESOURCES / file_name
+ channels = gwy.get_gwy_channels(file_path)
+ # The order might not be guaranteed, so sort before comparing
+ assert sorted(channels) == sorted(expected_channels)
diff --git a/tests/test_h5jpk.py b/tests/test_h5jpk.py
index b0336b2..a5dde28 100644
--- a/tests/test_h5jpk.py
+++ b/tests/test_h5jpk.py
@@ -1,5 +1,7 @@
"""Test the loading of .5h-jpk files."""
+# mypy: disable-error-code="arg-type,index"
+
from pathlib import Path
import numpy as np
@@ -127,24 +129,113 @@ def test_load_h5jpk(
image_sum: float,
) -> None:
"""Test the normal operation of loading a .h5-jpk file."""
- result_image, result_pixel_to_nm_scaling, results_timestamps = h5_jpk.load_h5jpk(
- RESOURCES / file_name, channel, flip_image
- )
+ afm_load = h5_jpk.load_h5jpk(RESOURCES / file_name, channel, flip_image)
- assert result_pixel_to_nm_scaling == pytest.approx(pixel_to_nm_scaling)
- assert isinstance(result_image, np.ndarray)
- assert result_image.shape == image_shape
- assert result_image.dtype == np.dtype(image_dtype)
- assert isinstance(results_timestamps, timestamps_dtype)
- assert result_image.sum() == pytest.approx(image_sum)
- assert len(results_timestamps) == result_image.shape[0]
+ assert afm_load.px2nm == pytest.approx(pixel_to_nm_scaling)
+ assert isinstance(afm_load.image, np.ndarray)
+ assert afm_load.image.shape == image_shape
+ assert afm_load.image.dtype == np.dtype(image_dtype)
+ assert isinstance(afm_load.timestamps, timestamps_dtype)
+ assert afm_load.image.sum() == pytest.approx(image_sum)
+ assert len(afm_load.timestamps) == afm_load.image.shape[0]
assert all(
- results_timestamps[f"frame {i}"] < results_timestamps[f"frame {i + 1}"]
- for i in range(len(results_timestamps) - 1)
+ afm_load.timestamps[f"frame {i}"] < afm_load.timestamps[f"frame {i + 1}"]
+ for i in range(len(afm_load.timestamps) - 1)
)
+@pytest.mark.skip(reason="Test files are too large to store in the repo; a remote storage solution is needed.")
+@pytest.mark.parametrize(
+ (
+ "file_name",
+ "channel",
+ "flip_image",
+ "curve_coords",
+ "curve_direction",
+ "curve_targets",
+ ),
+ [
+ pytest.param(
+ "sample_0_curves.h5-jpk",
+ "height_trace",
+ True,
+ (0, 0),
+ "Segment_0",
+ {
+ "height": (31, 0.00019106604),
+ "measuredHeight": (31, 0.00027384484),
+ "smoothedMeasuredHeight": (31, -38066577408.0),
+ "vDeflection": (31, 2.7409627e-07),
+ },
+ id="test curves 0",
+ ),
+ ],
+)
+def test_load_h5jpk_curves(
+ file_name: str,
+ channel: str,
+ flip_image: bool,
+ curve_coords: tuple[int, int],
+ curve_direction: str,
+ curve_targets: dict[str, tuple[int, float]],
+) -> None:
+ """
+ Test loading of curve data from a .h5-jpk file.
+
+ Parameters
+ ----------
+ file_name : str
+ The name of the .h5-jpk file to load (should be located in the test resources directory).
+ channel : str
+ The channel to load curve data for.
+ flip_image : bool
+ Whether to flip the image vertically.
+ curve_coords : tuple[int, int]
+ The coordinates of the curve to load.
+ curve_direction : str
+ The direction of the curve to load.
+ curve_targets : dict[str, tuple[int, float]]
+ A dictionary mapping curve channels to their expected size and sum, used for validating the loaded curve data.
+ """
+ afm_load = h5_jpk.load_h5jpk(RESOURCES / file_name, channel, flip_image)
+ curve_dataset = afm_load.curves_dataset
+ assert curve_dataset is not None
+ curve_at_coords = curve_dataset.get_default_volume()[curve_coords[0], curve_coords[1]]
+ for curve_channel, (expected_size, expected_sum) in curve_targets.items():
+ curve = curve_at_coords[curve_channel][curve_direction]
+ assert curve.shape == (expected_size,)
+ assert curve.sum() == pytest.approx(expected_sum)
+
+
def test_load_h5jpk_file_not_found() -> None:
"""Ensure FileNotFound error is raised."""
with pytest.raises(FileNotFoundError):
h5_jpk.load_h5jpk("nonexistant_file.h5-jpk", channel="TP")
+
+
+@pytest.mark.parametrize(
+ ("file_name", "expected_channels"),
+ [
+ pytest.param(
+ "sample_0.h5-jpk",
+ [
+ "error_trace",
+ "height_trace",
+ "phase_retrace",
+ "height_retrace",
+ "measuredheight_trace",
+ "error_retrace",
+ "amplitude_trace",
+ "amplitude_retrace",
+ "phase_trace",
+ ],
+ id="sample_0.h5-jpk",
+ ),
+ ],
+)
+def test_get_h5jpk_channels(file_name: str, expected_channels: list[str]) -> None:
+ """Test get_h5jpk_channels."""
+ file_path = RESOURCES / file_name
+ channels = h5_jpk.get_h5jpk_channels(file_path)
+ # The order might not be guaranteed, so sort before comparing
+ assert sorted(channels) == sorted(expected_channels)
diff --git a/tests/test_ibw.py b/tests/test_ibw.py
index 330bdb4..2fa745a 100644
--- a/tests/test_ibw.py
+++ b/tests/test_ibw.py
@@ -24,20 +24,44 @@ def test_load_ibw(
image_sum: float,
) -> None:
"""Test the normal operation of loading an .ibw file."""
- result_image = np.ndarray
- result_pixel_to_nm_scaling = float
-
file_path = RESOURCES / file_name
- result_image, result_pixel_to_nm_scaling = ibw.load_ibw(file_path, channel) # type: ignore
+ afm_load = ibw.load_ibw(file_path, channel)
- assert result_pixel_to_nm_scaling == pytest.approx(pixel_to_nm_scaling)
- assert isinstance(result_image, np.ndarray)
- assert result_image.shape == image_shape
- assert result_image.dtype == image_dtype
- assert result_image.sum() == pytest.approx(image_sum)
+ assert afm_load.px2nm == pytest.approx(pixel_to_nm_scaling)
+ assert isinstance(afm_load.image, np.ndarray)
+ assert afm_load.image.shape == image_shape
+ assert afm_load.image.dtype == image_dtype
+ assert afm_load.image.sum() == pytest.approx(image_sum)
def test_load_ibw_file_not_found() -> None:
"""Ensure FileNotFound error is raised."""
with pytest.raises(FileNotFoundError):
ibw.load_ibw("nonexistant_file.ibw", channel="TP")
+
+
+@pytest.mark.parametrize(
+ ("file_name", "expected_channels"),
+ [
+ pytest.param(
+ "sample_0.ibw",
+ [
+ "HeightTracee",
+ "HeightRetrace",
+ "ZSensorTrace",
+ "ZSensorRetrace",
+ "UserIn0Trace",
+ "UserIn0Retrace",
+ "UserIn1Trace",
+ "UserIn1Retrace",
+ ],
+ id="sample_0.ibw",
+ ),
+ ],
+)
+def test_get_ibw_channels(file_name: str, expected_channels: list[str]) -> None:
+ """Test get_ibw_channels."""
+ file_path = RESOURCES / file_name
+ channels = ibw.get_ibw_channels(file_path)
+ # The order might not be guaranteed, so sort before comparing
+ assert sorted(channels) == sorted(expected_channels)
diff --git a/tests/test_jpk.py b/tests/test_jpk.py
index 2e574f8..e9fea19 100644
--- a/tests/test_jpk.py
+++ b/tests/test_jpk.py
@@ -73,19 +73,56 @@ def test_load_jpk(
image_sum: float,
) -> None:
"""Test the normal operation of loading a .jpk file."""
- result_image = np.ndarray
- result_pixel_to_nm_scaling = float
file_path = RESOURCES / file_name
- result_image, result_pixel_to_nm_scaling = jpk.load_jpk(file_path, channel) # type: ignore
+ afm_load = jpk.load_jpk(file_path, channel)
- assert result_pixel_to_nm_scaling == pytest.approx(pixel_to_nm_scaling)
- assert isinstance(result_image, np.ndarray)
- assert result_image.shape == image_shape
- assert result_image.dtype == image_dtype
- assert result_image.sum() == pytest.approx(image_sum)
+ assert afm_load.px2nm == pytest.approx(pixel_to_nm_scaling)
+ assert isinstance(afm_load.image, np.ndarray)
+ assert afm_load.image.shape == image_shape
+ assert afm_load.image.dtype == image_dtype
+ assert afm_load.image.sum() == pytest.approx(image_sum)
def test_load_jpk_file_not_found() -> None:
"""Ensure FileNotFound error is raised."""
with pytest.raises(FileNotFoundError):
jpk.load_jpk("nonexistant_file.jpk", channel="TP")
+
+
+@pytest.mark.parametrize(
+ ("file_name", "expected"),
+ [
+ pytest.param(
+ "sample_0.jpk",
+ {
+ "height_retrace": 1,
+ "measuredHeight_retrace": 2,
+ "amplitude_retrace": 3,
+ "phase_retrace": 4,
+ "error_retrace": 5,
+ "height_trace": 6,
+ "measuredHeight_trace": 7,
+ "amplitude_trace": 8,
+ "phase_trace": 9,
+ "error_trace": 10,
+ },
+ id="sample_0.jpk",
+ ),
+ pytest.param(
+ "sample_0.jpk-qi-image",
+ {
+ "measuredHeight_trace": 3,
+ "vDeflection_trace": 2,
+ "adhesion_trace": 4,
+ "height_trace": 5,
+ "slope_trace": 6,
+ },
+ id="sample_0.jpk-qi-image",
+ ),
+ ],
+)
+def test_get_jpk_channels(file_name: str, expected: dict[str, int]) -> None:
+ """Test get_jpk_channels."""
+ file_path = RESOURCES / file_name
+ channels = jpk.get_jpk_channels(file_path)
+ assert channels == expected
diff --git a/tests/test_jpk_qi.py b/tests/test_jpk_qi.py
new file mode 100644
index 0000000..febbd73
--- /dev/null
+++ b/tests/test_jpk_qi.py
@@ -0,0 +1,152 @@
+"""Test the loading of jpk-qi-data files."""
+
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+from AFMReader import jpk_qi
+
+BASE_DIR = Path.cwd()
+RESOURCES = BASE_DIR / "tests" / "resources"
+
+
+@pytest.mark.skip(reason="Test files are too large to store in the repo; a remote storage solution is needed.")
+@pytest.mark.parametrize(
+ (
+ "file_name",
+ "channel",
+ "pixel_to_nm_scaling",
+ "image_shape",
+ "image_dtype",
+ "image_sum",
+ "curve_coords",
+ "curve_direction",
+ "curve_targets",
+ ),
+ [
+ pytest.param(
+ "sample_0.jpk-qi-data",
+ "height_trace",
+ 390.62499999999994,
+ (256, 256),
+ float,
+ 412271271.9961158,
+ (0, 0),
+ "Segment_0",
+ {
+ "height": (31, 0.00019106601492896875),
+ "vDeflection": (31, 2.740962611337846e-07),
+ "measuredHeight": (31, 0.00027384485398464497),
+ "smoothedMeasuredHeight": (31, -38066578894.999535),
+ },
+ id="qi-data 0; height_trace",
+ ),
+ pytest.param(
+ "sample_0.jpk-qi-data",
+ "slope_trace",
+ 390.62499999999994,
+ (256, 256),
+ float,
+ 267675.3050073493,
+ (0, 0),
+ "Segment_0",
+ {
+ "height": (31, 0.00019106601492896875),
+ "vDeflection": (31, 2.740962611337846e-07),
+ "measuredHeight": (31, 0.00027384485398464497),
+ "smoothedMeasuredHeight": (31, -38066578894.999535),
+ },
+ id="qi-data 0; slope_trace",
+ ),
+ pytest.param(
+ "sample_0.jpk-qi-data",
+ "adhesion_trace",
+ 390.62499999999994,
+ (256, 256),
+ float,
+ 0.0008930453784792601,
+ (0, 0),
+ "Segment_0",
+ {
+ "height": (31, 0.00019106601492896875),
+ "vDeflection": (31, 2.740962611337846e-07),
+ "measuredHeight": (31, 0.00027384485398464497),
+ "smoothedMeasuredHeight": (31, -38066578894.999535),
+ },
+ id="qi-data 0; adhesion_trace",
+ ),
+ pytest.param(
+ "sample_0.jpk-qi-data",
+ "measuredHeight_trace",
+ 390.62499999999994,
+ (256, 256),
+ float,
+ 590908347.7454677,
+ (0, 0),
+ "Segment_0",
+ {
+ "height": (31, 0.00019106601492896875),
+ "vDeflection": (31, 2.740962611337846e-07),
+ "measuredHeight": (31, 0.00027384485398464497),
+ "smoothedMeasuredHeight": (31, -38066578894.999535),
+ },
+ id="qi-data 0; measuredHeight_trace",
+ ),
+ pytest.param(
+ "sample_0.jpk-qi-data",
+ "vDeflection_trace",
+ 390.62499999999994,
+ (256, 256),
+ float,
+ 0.0004062236060247368,
+ (0, 0),
+ "Segment_0",
+ {
+ "height": (31, 0.00019106601492896875),
+ "vDeflection": (31, 2.740962611337846e-07),
+ "measuredHeight": (31, 0.00027384485398464497),
+ "smoothedMeasuredHeight": (31, -38066578894.999535),
+ },
+ id="qi-data 0; vDeflection_trace",
+ ),
+ ],
+)
+def test_load_jpk_qi_data( # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals
+ file_name: str,
+ channel: str,
+ pixel_to_nm_scaling: float,
+ image_shape: tuple[int, int],
+ image_dtype: type,
+ image_sum: float,
+ curve_coords: tuple[int, int],
+ curve_direction: str,
+ curve_targets: dict[str, tuple[int, float]],
+) -> None:
+ """Test the normal operation of loading a .jpk-qi-data file."""
+ file_path = RESOURCES / file_name
+ jpk_qi_loader = jpk_qi.JPKQILoader(file_path, channel)
+ afm_load = jpk_qi_loader.load()
+
+ assert afm_load.px2nm == pytest.approx(pixel_to_nm_scaling)
+ assert isinstance(afm_load.image, np.ndarray)
+ assert afm_load.image.shape == image_shape
+ assert afm_load.image.dtype == image_dtype
+ assert afm_load.image.sum() == pytest.approx(image_sum)
+
+ # Test curve data for all targets
+ curve_dataset = afm_load.curves_dataset
+ assert curve_dataset is not None, "Curves data not found/ is None"
+ curve_at_coords = curve_dataset.get_default_volume()[curve_coords[0], curve_coords[1]]
+ for curve_channel, (expected_size, expected_sum) in curve_targets.items():
+ curve = curve_at_coords[curve_channel][curve_direction]
+ assert curve.shape == (expected_size,)
+ assert curve.sum() == pytest.approx(expected_sum)
+
+ jpk_qi_loader.close()
+
+
+def test_load_jpk_data_file_not_found() -> None:
+ """Ensure FileNotFound error is raised."""
+ with pytest.raises(FileNotFoundError):
+ jpk_qi.JPKQILoader("noexistant_file.jpk-qi-data", "TP")
diff --git a/tests/test_spm.py b/tests/test_spm.py
index f15632e..7866911 100644
--- a/tests/test_spm.py
+++ b/tests/test_spm.py
@@ -33,17 +33,14 @@ def test_load_spm(
image_sum: float,
) -> None:
"""Test the normal operation of loading a .spm file."""
- result_image = np.ndarray
- result_pixel_to_nm_scaling = float
-
file_path = RESOURCES / file_name
- result_image, result_pixel_to_nm_scaling = spm.load_spm(file_path, channel=channel)
+ afm_load = spm.load_spm(file_path, channel=channel)
- assert result_pixel_to_nm_scaling == pytest.approx(pixel_to_nm_scaling)
- assert isinstance(result_image, np.ndarray)
- assert result_image.shape == image_shape
- assert result_image.dtype == image_dtype
- assert result_image.sum() == pytest.approx(image_sum)
+ assert afm_load.px2nm == pytest.approx(pixel_to_nm_scaling)
+ assert isinstance(afm_load.image, np.ndarray)
+ assert afm_load.image.shape == image_shape
+ assert afm_load.image.dtype == image_dtype
+ assert afm_load.image.sum() == pytest.approx(image_sum)
@patch("pySPM.SPM.SPM_image")
@@ -132,7 +129,7 @@ def test_load_spm_file_not_found() -> None:
],
)
def test_load_spm_channel_not_found(
- caplog: pytest.LogCaptureFixture,
+ capsys: pytest.CaptureFixture,
channel: str,
message: str,
error: bool,
@@ -143,4 +140,31 @@ def test_load_spm_channel_not_found(
spm.load_spm(RESOURCES / "sample_0.spm", channel)
else:
spm.load_spm(RESOURCES / "sample_0.spm", channel)
- assert message in caplog.text
+ captured = capsys.readouterr()
+ assert message in captured.err
+
+
+@pytest.mark.parametrize(
+ ("file_name", "expected_channels"),
+ [
+ pytest.param(
+ "sample_0.spm",
+ [
+ "Height Sensor",
+ "Peak Force Error",
+ "DMTModulus",
+ "LogDMTModulus",
+ "Adhesion",
+ "Deformation",
+ "Dissipation",
+ "Height",
+ ],
+ id="sample_0.spm",
+ ),
+ ],
+)
+def test_get_spm_channels(file_name: str, expected_channels: list[str]) -> None:
+ """Test get_spm_channels."""
+ file_path = RESOURCES / file_name
+ channels = spm.get_spm_channels(file_path)
+ assert channels == expected_channels
diff --git a/tests/test_stp.py b/tests/test_stp.py
index 271a85a..ba928b8 100644
--- a/tests/test_stp.py
+++ b/tests/test_stp.py
@@ -33,10 +33,10 @@ def test_load_stp(
) -> None:
"""Test the normal operation of loading a .stp file."""
file_path = RESOURCES / file_name
- result_image, result_pixel_to_nm_scaling = load_stp(file_path=file_path)
+ afm_load = load_stp(file_path=file_path)
- assert result_pixel_to_nm_scaling == pytest.approx(expected_pixel_to_nm_scaling)
- assert isinstance(result_image, np.ndarray)
- assert result_image.shape == expected_image_shape
- assert result_image.dtype == expected_image_dtype
- assert result_image.sum() == pytest.approx(expected_image_sum)
+ assert afm_load.px2nm == pytest.approx(expected_pixel_to_nm_scaling)
+ assert isinstance(afm_load.image, np.ndarray)
+ assert afm_load.image.shape == expected_image_shape
+ assert afm_load.image.dtype == expected_image_dtype
+ assert afm_load.image.sum() == pytest.approx(expected_image_sum)
diff --git a/tests/test_top.py b/tests/test_top.py
index e44b5ce..ae20b36 100644
--- a/tests/test_top.py
+++ b/tests/test_top.py
@@ -33,10 +33,10 @@ def test_load_top(
) -> None:
"""Test the normal operation of loading a .top file."""
file_path = RESOURCES / file_name
- result_image, result_pixel_to_nm_scaling = load_top(file_path=file_path)
+ afm_load = load_top(file_path=file_path)
- assert result_pixel_to_nm_scaling == pytest.approx(expected_pixel_to_nm_scaling)
- assert isinstance(result_image, np.ndarray)
- assert result_image.shape == expected_image_shape
- assert result_image.dtype == expected_image_dtype
- assert result_image.sum() == pytest.approx(expected_image_sum)
+ assert afm_load.px2nm == pytest.approx(expected_pixel_to_nm_scaling)
+ assert isinstance(afm_load.image, np.ndarray)
+ assert afm_load.image.shape == expected_image_shape
+ assert afm_load.image.dtype == expected_image_dtype
+ assert afm_load.image.sum() == pytest.approx(expected_image_sum)
diff --git a/tests/test_topostats.py b/tests/test_topostats.py
index 66a7989..c2c2a18 100644
--- a/tests/test_topostats.py
+++ b/tests/test_topostats.py
@@ -89,21 +89,23 @@ def test_load_topostats(
) -> None:
"""Test the normal operation of loading a .topostats (HDF5 format) file."""
file_path = RESOURCES / file_name
- topostats_data = topostats.load_topostats(file_path)
+ afm_load = topostats.load_topostats(file_path, channel="image")
- assert set(topostats_data.keys()) == data_keys # type: ignore
+ expected_metadata_keys = data_keys - {"image", "pixel_to_nm_scaling"}
+ assert afm_load.metadata is not None
+ assert set(afm_load.metadata.keys()) == expected_metadata_keys
if version_key == "topostats_file_version":
- assert topostats_data[version_key] == float(version)
+ assert afm_load.metadata[version_key] == float(version)
else:
- assert topostats_data[version_key] == version
- assert topostats_data["pixel_to_nm_scaling"] == pytest.approx(pixel_to_nm_scaling)
- assert topostats_data["image"].shape == image_shape
- assert topostats_data["image"].sum() == pytest.approx(image_sum)
+ assert afm_load.metadata[version_key] == version
+ assert afm_load.px2nm == pytest.approx(pixel_to_nm_scaling)
+ assert afm_load.image.shape == image_shape
+ assert afm_load.image.sum() == pytest.approx(image_sum)
if version > "0.2":
- assert isinstance(topostats_data["img_path"], Path)
+ assert isinstance(afm_load.metadata["img_path"], Path)
def test_load_topostats_file_not_found() -> None:
"""Ensure FileNotFound error is raised."""
with pytest.raises(FileNotFoundError):
- topostats.load_topostats("nonexistant_file.topostats")
+ topostats.load_topostats("nonexistant_file.topostats", channel="image")