diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..adf79b0 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,132 @@ +name: CI + +on: + push: + branches: + - main + - test + tags: + - "*" + pull_request: + branches: + - main + +permissions: + contents: read + actions: read + +env: + PYTHONDONTWRITEBYTECODE: 1 + FORCE_COLOR: 1 + PYTHON_VERSION: "3.12" + +jobs: + test: + name: Run tests 🧪 + runs-on: ubuntu-latest + permissions: + contents: write + + steps: + - name: Checkout repo 🛎️ + uses: actions/checkout@v6 + with: + persist-credentials: false + + - name: Set up Python 🐍 + uses: actions/setup-python@v6 + with: + python-version: ${{ env.PYTHON_VERSION }} + cache: pip + + - name: Restore testing cache 📥 + uses: actions/cache@v4 + with: + path: | + testing/datasets + testing/extractions + testing/protocol + testing/share + key: testing-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('isimip_utils/tests/constants.py') }} + restore-keys: | + testing-${{ runner.os }}-${{ env.PYTHON_VERSION }}- + + - name: Install cdo 🌍 + run: | + sudo apt-get update + sudo apt-get install -y cdo netcdf-bin --no-install-recommends + + - name: Install package 📦 + run: pip install -e .[all] + + - name: Run download script 🌐 + run: python testing/download.py + + - name: Run setup script 🔧 + run: python testing/setup.py + + - name: Save testing cache 📤 + if: always() + uses: actions/cache/save@v4 + with: + path: | + testing/datasets + testing/extractions + testing/protocol + testing/share + key: testing-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('isimip_utils/tests/constants.py') }} + + - name: Run pytest 🧪 + run: pytest --cov=isimip_utils --cov-fail-under=90 --cov-report=term-missing + + build: + name: Build distribution 👷 + needs: test + runs-on: ubuntu-latest + + steps: + - name: Checkout repo 🛎️ + uses: actions/checkout@v6 + with: + persist-credentials: false + + - name: Set up
Python 🐍 + uses: actions/setup-python@v6 + with: + python-version: ${{ env.PYTHON_VERSION }} + cache: pip + + - name: Install build 🧱 + run: python3 -m pip install build --user + + - name: Build a binary wheel and a source tarball 🛠️ + run: python3 -m build + + - name: Store the distribution packages 📤 + uses: actions/upload-artifact@v5 + with: + name: python-package-distributions + path: dist/ + + pypi: + name: Publish distribution to PyPI 📦 + if: startsWith(github.ref, 'refs/tags/') + needs: build + runs-on: ubuntu-latest + + environment: + name: pypi + url: https://pypi.org/p/isimip-utils + + permissions: + id-token: write + + steps: + - name: Download the distribution packages 📥 + uses: actions/download-artifact@v6 + with: + name: python-package-distributions + path: dist/ + + - name: Publish to PyPI 🚀 + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.gitignore b/.gitignore index 3e122cd..26c23b3 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,19 @@ __pycache__/ /build /dist /*.egg-info + +/.aider* /.pytest_cache +/.ruff_cache + +/.coverage +/htmlcov + +/site + +/testing/datasets +/testing/extractions +/testing/output +/testing/plots +/testing/protocol +/testing/share diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 419443e..a37beb0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,15 +3,21 @@ repos: hooks: - id: check-hooks-apply - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v6.0.0 hooks: - id: check-ast - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - id: debug-statements + - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.284 + rev: v0.14.6 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] + + - repo: https://github.com/crate-ci/typos + rev: v1.39.2 + hooks: + - id: typos diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 0000000..4725156 --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,19 @@
+cff-version: 1.2.0 +message: If you use this software in your research, please cite it using the provided Digital Object Identifier (DOI). + +title: ISIMIP utils +abstract: Common functionality for different ISIMIP tools. + +authors: +- family-names: Klar + given-names: Jochen + orcid: https://orcid.org/0000-0002-5883-4273 +- family-names: Büchner + given-names: Matthias + orcid: https://orcid.org/0000-0002-1382-7424 +- family-names: Sauer + given-names: Inga + orcid: https://orcid.org/0000-0002-9302-2131 + +license: MIT +repository-code: https://github.com/ISI-MIP/isimip-utils diff --git a/LICENSE b/LICENSE index fe4013f..9d51b46 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2022 Potsdam Institute for Climate Impact Research +Copyright (c) 2022-2026 Potsdam Institute for Climate Impact Research Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 3fb074c..ac1a29d 100644 --- a/README.md +++ b/README.md @@ -1,48 +1,85 @@ ISIMIP utils ============ -[![Latest release](https://shields.io/github/v/release/ISI-MIP/isimip-utils)](https://github.com/ISI-MIP/isimip-utils/releases) -[![PyPI Release](https://img.shields.io/pypi/v/isimip-utils)](https://pypi.org/project/isimip-utils/) -[![Python Version](https://img.shields.io/badge/python->=3.8-blue)](https://www.python.org/) -[![License](https://img.shields.io/badge/license-MIT-green)](https://github.com/ISI-MIP/django-datacite/blob/master/LICENSE) +[![Python Version](https://img.shields.io/badge/python->=3.11-blue)](https://www.python.org/) +[![License](https://img.shields.io/github/license/ISI-MIP/isimip-utils?style=flat)](https://github.com/ISI-MIP/isimip-utils/blob/main/LICENSE) +[![CI status](https://github.com/ISI-MIP/isimip-utils/actions/workflows/ci.yaml/badge.svg)](https://github.com/ISI-MIP/isimip-utils/actions/workflows/ci.yaml)
+[![Coverage status](https://coveralls.io/repos/ISI-MIP/isimip-utils/badge.svg?branch=main&service=github)](https://coveralls.io/github/ISI-MIP/isimip-utils?branch=main) +[![Latest release](https://img.shields.io/pypi/v/isimip-utils.svg?style=flat)](https://pypi.python.org/pypi/isimip-utils/) -This package contains common functionality for different ISIMIP tools, namely: -* https://github.com/ISI-MIP/isimip-publisher -* https://github.com/ISI-MIP/isimip-qa -* https://github.com/ISI-MIP/isimip-qc +[ISIMIP](https://isimip.org) offers a framework for consistently projecting the impacts +of climate change across affected sectors and spatial scales. An international network +of climate-impact modellers contribute to a comprehensive and consistent picture of the +world under different climate-change scenarios. -It comprises of: +This package contains various utility methods for use in custom scripts as well +as in different ISIMIP tools: -* `isimip_utils.checksum`: Functions to compute the SHA-512 checksum of a file. -* `isimip_utils.config`: A settings class to combine input from `argparse`, the environment (via `python-dotenv`) and config files. -* `isimip_utils.exceptions`: Custom exceptions for ISIMIP tools. -* `isimip_utils.fetch`: Functions to fetch files from the machine-actionable ISIMIP protocols. -* `isimip_utils.netcdf`: Functions to open and read NetCDF files. -* `isimip_utils.patterns`: Functions to match the file names and extract the ISIMIP specifiers. -* `isimip_utils.utils`: Additional utility functions. +* [ISIMIP quality control](https://github.com/ISI-MIP/isimip-qc) +* [ISIMIP quality assurance](https://github.com/ISI-MIP/isimip-qa) +* [ISIMIP publisher](https://github.com/ISI-MIP/isimip-publisher) + + +The different methods are [documented here](docs/index.md). Setup -===== +----- -Working on the package requires a running Python3 on your system.
Installing those prerequisites is covered [here](https://github.com/ISI-MIP/isimip-utils/blob/master/docs/releases.md). +Using the package requires a running Python 3 on your system. The installation for different systems is covered +[here](docs/prerequisites.md). -The package itself can be installed via pip: +Unless you already use an environment manager (e.g. `conda` or `uv`), it is highly recommended to use a +[virtual environment](https://docs.python.org/3/library/venv.html), which can be created using: -``` -pip install isimip-utils +```bash +python3 -m venv env +source env/bin/activate # needs to be invoked in every new terminal session ``` -The package can also be installed directly from GitHub: +The package itself can be installed via `pip`: -``` -pip install git+https://github.com/ISI-MIP/isimip-utils +```bash +pip install isimip-utils ``` For a development setup, the repo should be cloned and installed in *editable* mode: -``` +```bash git clone git@github.com:ISI-MIP/isimip-utils pip install -e isimip-utils ``` + + +Usage +----- + +Once installed, the modules can be used like any other Python library, e.g. in order to create an ISIMIP +compliant NetCDF file, you can use: + +```python +from isimip_utils.xarray import init_dataset, write_dataset + +time = np.arange(0, 365, dtype=np.float64) +var = np.ones((365, 360, 720), dtype=np.float32) + +attrs={ + 'global': { + 'contact': 'mail@example.com' + }, + 'var': { + 'standard_name': 'var', + 'long_name': 'Variable', + 'units': '1', + } +} + +# create an xarray.Dataset +ds = init_dataset(time=time, var=var, attrs=attrs) + +# write the dataset as NetCDF file +write_dataset(ds, 'output.nc') +``` + +Please also note our [examples page](examples.md) and the [API reference](api.md).
diff --git a/docs/api/checksum.md b/docs/api/checksum.md new file mode 100644 index 0000000..53f5adf --- /dev/null +++ b/docs/api/checksum.md @@ -0,0 +1,3 @@ +# isimip_utils.checksum + +::: isimip_utils.checksum diff --git a/docs/api/cli.md b/docs/api/cli.md new file mode 100644 index 0000000..9066d91 --- /dev/null +++ b/docs/api/cli.md @@ -0,0 +1,3 @@ +# isimip_utils.cli + +::: isimip_utils.cli diff --git a/docs/api/config.md b/docs/api/config.md new file mode 100644 index 0000000..e6e2d51 --- /dev/null +++ b/docs/api/config.md @@ -0,0 +1,3 @@ +# isimip_utils.config + +::: isimip_utils.config diff --git a/docs/api/exceptions.md b/docs/api/exceptions.md new file mode 100644 index 0000000..3362831 --- /dev/null +++ b/docs/api/exceptions.md @@ -0,0 +1,3 @@ +# isimip_utils.exceptions + +::: isimip_utils.exceptions diff --git a/docs/api/extractions.md b/docs/api/extractions.md new file mode 100644 index 0000000..bc0be87 --- /dev/null +++ b/docs/api/extractions.md @@ -0,0 +1,3 @@ +# isimip_utils.extractions + +::: isimip_utils.extractions diff --git a/docs/api/fetch.md b/docs/api/fetch.md new file mode 100644 index 0000000..91317da --- /dev/null +++ b/docs/api/fetch.md @@ -0,0 +1,3 @@ +# isimip_utils.fetch + +::: isimip_utils.fetch diff --git a/docs/api/files.md b/docs/api/files.md new file mode 100644 index 0000000..e8a29c0 --- /dev/null +++ b/docs/api/files.md @@ -0,0 +1,3 @@ +# isimip_utils.files + +::: isimip_utils.files diff --git a/docs/api/netcdf.md b/docs/api/netcdf.md new file mode 100644 index 0000000..f47af95 --- /dev/null +++ b/docs/api/netcdf.md @@ -0,0 +1,3 @@ +# isimip_utils.netcdf + +::: isimip_utils.netcdf diff --git a/docs/api/pandas.md b/docs/api/pandas.md new file mode 100644 index 0000000..37731a9 --- /dev/null +++ b/docs/api/pandas.md @@ -0,0 +1,3 @@ +# isimip_utils.pandas + +::: isimip_utils.pandas diff --git a/docs/api/parameters.md b/docs/api/parameters.md new file mode 100644 index 0000000..ac1a072 --- /dev/null +++ b/docs/api/parameters.md @@ 
-0,0 +1,3 @@ +# isimip_utils.parameters + +::: isimip_utils.parameters diff --git a/docs/api/patterns.md b/docs/api/patterns.md new file mode 100644 index 0000000..e6e8bcf --- /dev/null +++ b/docs/api/patterns.md @@ -0,0 +1,3 @@ +# isimip_utils.patterns + +::: isimip_utils.patterns diff --git a/docs/api/plot.md b/docs/api/plot.md new file mode 100644 index 0000000..5f4aa30 --- /dev/null +++ b/docs/api/plot.md @@ -0,0 +1,3 @@ +# isimip_utils.plot + +::: isimip_utils.plot diff --git a/docs/api/protocol.md b/docs/api/protocol.md new file mode 100644 index 0000000..7f986e0 --- /dev/null +++ b/docs/api/protocol.md @@ -0,0 +1,3 @@ +# isimip_utils.protocol + +::: isimip_utils.protocol diff --git a/docs/api/utils.md b/docs/api/utils.md new file mode 100644 index 0000000..adb212f --- /dev/null +++ b/docs/api/utils.md @@ -0,0 +1,3 @@ +# isimip_utils.utils + +::: isimip_utils.utils diff --git a/docs/api/xarray.md b/docs/api/xarray.md new file mode 100644 index 0000000..243f615 --- /dev/null +++ b/docs/api/xarray.md @@ -0,0 +1,3 @@ +# isimip_utils.xarray + +::: isimip_utils.xarray diff --git a/docs/examples.md b/docs/examples.md new file mode 100644 index 0000000..bac945d --- /dev/null +++ b/docs/examples.md @@ -0,0 +1,2 @@ +Examples +======== diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..60c38a2 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,107 @@ +ISIMIP utils +============ + +[ISIMIP](https://isimip.org) offers a framework for consistently projecting the impacts +of climate change across affected sectors and spatial scales. An international network +of climate-impact modellers contribute to a comprehensive and consistent picture of the +world under different climate-change scenarios. 
+ +Overview +-------- + +This package contains various utility methods for use in custom scripts as well +as in different ISIMIP tools: + +* [ISIMIP quality control](https://github.com/ISI-MIP/isimip-qc) +* [ISIMIP quality assurance](https://github.com/ISI-MIP/isimip-qa) +* [ISIMIP publisher](https://github.com/ISI-MIP/isimip-publisher) + +The following modules contain high-level methods to extract data (e.g. aggregated time series of points, areas, shapes) +from global ISIMIP data sets and create gridded plots visualizing the data: + +* [`isimip_utils.extractions`](api/extractions.md): Create extractions using [Xarray](https://docs.xarray.dev). +* [`isimip_utils.plot`](api/plot.md): Plotting utilities using [Vega-Altair](https://altair-viz.github.io). + +Lower-level functions are provided to interact with the data sets and customize `xarray`, `pandas`, and `netcdf` +for ISIMIP conventions. + +* [`isimip_utils.xarray`](api/xarray.md): Functions for working with `xarray` datasets. +* [`isimip_utils.netcdf`](api/netcdf.md): Functions to open and read NetCDF files using netCDF4. +* [`isimip_utils.pandas`](api/pandas.md): Pandas utilities for ISIMIP data processing. + +Two modules focus on the interface to the [machine-readable ISIMIP protocol](https://protocol.isimip.org): + +* [`isimip_utils.patterns`](api/patterns.md): Functions to match file names and extract ISIMIP specifiers. +* [`isimip_utils.protocol`](api/protocol.md): Functions to fetch information from machine-actionable ISIMIP protocols. + +The remaining modules contain utility functions which are used by the other modules or by the ISIMIP tools mentioned above: + +* [`isimip_utils.checksum`](api/checksum.md): Checksum computation utilities for file integrity verification. +* [`isimip_utils.cli`](api/cli.md): Command-line interface utilities for argument parsing and configuration. +* [`isimip_utils.config`](api/config.md): A `Settings` class for command-line interface utilities.
+* [`isimip_utils.exceptions`](api/exceptions.md): Custom exceptions for ISIMIP tools. +* [`isimip_utils.fetch`](api/fetch.md): Functions to fetch files from URLs or local paths. +* [`isimip_utils.files`](api/files.md): File search utilities with regex pattern matching. +* [`isimip_utils.parameters`](api/parameters.md): Utility functions for working with parameters and placeholders. +* [`isimip_utils.utils`](api/utils.md): Additional utility functions. + + +Setup +----- + +Using the package requires a running Python 3 on your system. The installation for different systems is covered +[here](prerequisites.md). + +Unless you already use an environment manager (e.g. `conda` or `uv`), it is highly recommended to use a +[virtual environment](https://docs.python.org/3/library/venv.html), which can be created using: + +```bash +python3 -m venv env +source env/bin/activate # needs to be invoked in every new terminal session +``` + +The package itself can be installed via `pip`: + +```bash +pip install isimip-utils +``` + +For a development setup, the repo should be cloned and installed in *editable* mode: + +```bash +git clone git@github.com:ISI-MIP/isimip-utils +pip install -e isimip-utils +``` + + +Usage +----- + +Once installed, the modules can be used like any other Python library, e.g. in order to create an ISIMIP +compliant NetCDF file, you can use: + +```python +from isimip_utils.xarray import init_dataset, write_dataset + +time = np.arange(0, 365, dtype=np.float64) +var = np.ones((365, 360, 720), dtype=np.float32) + +attrs={ + 'global': { + 'contact': 'mail@example.com' + }, + 'var': { + 'standard_name': 'var', + 'long_name': 'Variable', + 'units': '1', + } +} + +# create an xarray.Dataset +ds = init_dataset(time=time, var=var, attrs=attrs) + +# write the dataset as NetCDF file +write_dataset(ds, 'output.nc') +``` + +Please also note our page with additional [examples](examples.md) and the [API reference](api.md).
diff --git a/docs/prerequisites.md b/docs/prerequisites.md index a739038..fc5a6e7 100644 --- a/docs/prerequisites.md +++ b/docs/prerequisites.md @@ -1,9 +1,11 @@ -Prerequisites -------------- +Python installation +------------------- -The installation of Python (and its developing packages) differs from operating system to operating system. Optional Git is needed if a package is installed directly from GitHub. +Using the package requires a running Python 3 on your system. The installation of Python (and its developing +packages) differs from operating system to operating system. Optional Git is needed if a package is installed +directly from GitHub. -### Linux +## Linux On Linux, Python3 is probably already installed, but the development packages are usually not. Optionally, Git can be installed as well. You should be able to install all prerequisites using: @@ -21,7 +23,7 @@ sudo zypper install python3 python3-devel sudo zypper install git ``` -### macOS +## macOS While we reccoment using [Homebrew](https://brew.sh) to install Python3 on a Mac, other means of obtaining Python like [Anaconda](https://www.anaconda.com/products/individual), [MacPorts](https://www.macports.org/), or [Fink](https://www.finkproject.org/) should work just as fine: @@ -30,9 +32,9 @@ brew install python brew install git ``` -### Windows +## Windows -#### Regular installation +### Regular installation The software prerequisites need to be downloaded and installed from their particular web sites. @@ -47,6 +49,6 @@ For git: All further steps need to be performed using the windows shell `cmd.exe`. You can open it from the Start-Menu. -#### Using the Windows Subsystem for Linux (WSL) +### Using the Windows Subsystem for Linux (WSL) -As an alternative for advanced users, you can use the Windows Subsystem for Linux (WSL) to install a Linux distribution whithin Windows 10. The installation is explained in the [Microsoft documentation](https://docs.microsoft.com/en-us/windows/wsl/install-win10). 
When using WSL, please install Python3 as explained in the Linux section. +As an alternative for advanced users, you can use the Windows Subsystem for Linux (WSL) to install a Linux distribution within Windows 10. The installation is explained in the [Microsoft documentation](https://docs.microsoft.com/en-us/windows/wsl/install-win10). When using WSL, please install Python3 as explained in the Linux section. diff --git a/docs/releases.md b/docs/releases.md deleted file mode 100644 index de98902..0000000 --- a/docs/releases.md +++ /dev/null @@ -1,77 +0,0 @@ -Releases -======== - -Requirements ------------- - -Install `build` and `twine` - -``` -pip install build twine -``` - -Create `~/.pypirc` - -``` -[pypi] -username: ... -password: ... - -[testpypi] -repository: https://test.pypi.org/legacy/ -username: ... -password: ... -``` - -Prepare repo ------------- - -1) Ensure tests are passing. - -2) Update version in `isimip_utils/__init__.py`. - -3) Build `sdist` and `bdist_wheel`: - - ``` - python -m build - ``` - -4) Check: - - ``` - twine check dist/* - ``` - - -Release on Test PyPI --------------------- - -1) Upload with `twine` to Test PyPI: - - ``` - twine upload -r testpypi dist/* - ``` - -2) Check at https://test.pypi.org/project/isimip-utils/. - - -Release on PyPI ---------------- - -1) Upload with `twine` to PyPI: - - ``` - twine upload dist/* - ``` - -2) Check at https://pypi.org/project/isimip-utils/. - - -Create release on GitHub ------------------------- - -1) Commit local changes. - -2) Push changes. - -3) Create release on https://github.com/ISI-MIP/isimip-utils/releases). 
diff --git a/isimip_utils/__init__.py b/isimip_utils/__init__.py index a080f04..787bb6e 100644 --- a/isimip_utils/__init__.py +++ b/isimip_utils/__init__.py @@ -1 +1,7 @@ -VERSION = __version__ = '1.3.2' +from importlib.metadata import PackageNotFoundError +from importlib.metadata import version as _version + +try: + VERSION = __version__ = _version(__package__) +except PackageNotFoundError: + VERSION = __version__ = "0.0.0+unknown" diff --git a/isimip_utils/checksum.py b/isimip_utils/checksum.py index 8155f41..80c03d7 100644 --- a/isimip_utils/checksum.py +++ b/isimip_utils/checksum.py @@ -1,5 +1,7 @@ +"""Checksum computation utilities for file integrity verification.""" import hashlib import logging +from pathlib import Path logger = logging.getLogger(__name__) @@ -7,7 +9,16 @@ CHECKSUM_TYPE = 'sha512' -def get_checksum(abspath, checksum_type=CHECKSUM_TYPE): +def get_checksum(abspath: str | Path, checksum_type: str = CHECKSUM_TYPE) -> str: + """Compute the checksum of a file. + + Args: + abspath (str | Path): Absolute path to the file to checksum. + checksum_type (str): Type of checksum algorithm to use (default: 'sha512'). + + Returns: + The hexadecimal digest string of the file's checksum. + """ m = hashlib.new(checksum_type) with open(abspath, 'rb') as f: # read and update in blocks of 64K @@ -16,9 +27,19 @@ def get_checksum(abspath, checksum_type=CHECKSUM_TYPE): return m.hexdigest() -def get_checksum_type(): +def get_checksum_type() -> str: + """Get the default checksum type. + + Returns: + The default checksum algorithm name (e.g., 'sha512'). + """ return CHECKSUM_TYPE -def get_checksum_suffix(): +def get_checksum_suffix() -> str: + """Get the file suffix for checksum files. + + Returns: + The checksum file extension (e.g., '.sha512'). + """ return '.' 
+ CHECKSUM_TYPE diff --git a/isimip_utils/cli.py b/isimip_utils/cli.py new file mode 100644 index 0000000..6c517df --- /dev/null +++ b/isimip_utils/cli.py @@ -0,0 +1,298 @@ +"""Command-line interface utilities for ISIMIP tools.""" +import argparse +import logging +import os +import tomllib +from datetime import datetime +from pathlib import Path +from urllib.parse import urlparse + +from dotenv import load_dotenv +from rich.logging import RichHandler + +from .exceptions import ConfigError + + +def setup_env() -> None: + """Load environment variables from .env file in current working directory.""" + load_dotenv(Path().cwd() / '.env') + + +def setup_logs(log_level: str = 'WARNING', log_file: str | None = None, + log_console: bool = True, log_rich: bool = True, + show_time: bool = False, show_path: bool = False) -> None: + """Configure logging with console and/or file handlers. + + Args: + log_level (str): Logging level (default: 'WARNING'). + log_file (str | None): Path to log file, or None for no file logging (default: None). + log_console (bool): Whether to log to console (default: True). + log_rich (bool): Whether to use RichHandler for console logging (default: True). + show_time (bool): Whether to show the time in the console logs (default: False). + show_path (bool): Whether to show the path in the console logs (default: False). 
+ """ + log_level = log_level.upper() + + root_logger = logging.getLogger() + root_logger.setLevel(log_level) + + if log_console: + if log_rich: + console_handler = RichHandler(show_time=show_time, show_path=show_path) + else: + fmt = '' + if show_time: + fmt += '[%(asctime)s] ' + fmt += '%(levelname)s - ' + if show_path: + fmt += '%(filename)s:%(lineno)d - ' + fmt += '%(message)s' + + console_handler = logging.StreamHandler() + console_handler.setFormatter(logging.Formatter(fmt)) + + console_handler.setLevel(log_level) + root_logger.addHandler(console_handler) + + if log_file is not None: + Path(log_file).parent.mkdir(exist_ok=True, parents=True) + + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(log_level) + file_handler.setFormatter(logging.Formatter('[%(asctime)s] %(levelname)s: %(message)s')) + + root_logger.addHandler(file_handler) + + +def parse_dict(string: str) -> dict[str, list[str]] | None: + """Parse a string in format 'key=value1,value2' into a dictionary. + + Args: + string (str): String to parse in format 'key=value1,value2,value3'. + + Returns: + Dictionary with single key mapping to list of values. + """ + if string: + key, values = string.split('=') + return { + key.strip(): [value.strip() for value in values.split(',')] + } + + +def parse_list(string: str) -> list[str]: + """Parse a comma-separated string into a list. + + Args: + string (str): Comma-separated string to parse. + + Returns: + List of stripped values. + """ + if string: + return [value.strip() for value in string.split(',')] + else: + return [] + + +def parse_version(value: str) -> str: + """Parse a version string in YYYYMMDD format. + + Args: + value (str): Version string in YYYYMMDD format. + + Returns: + Version string in YYYYMMDD format. + + Raises: + argparse.ArgumentTypeError: If format is incorrect. 
+ """ + try: + datetime.strptime(value, '%Y%m%d') + return value + except ValueError as e: + raise argparse.ArgumentTypeError('incorrect format, should be YYYYMMDD') from e + + +def parse_path(value: str) -> Path: + """Parse and expand a path string. + + Args: + value (str): Path string to parse. + + Returns: + Expanded Path object. + """ + return Path(value).expanduser() + + +def parse_locations(value: str | list) -> list[str | Path]: + """Parse and expand a location string as list of URL or Path objects. + + Args: + value (str): Location string to parse. + + Returns: + List of URL or Path objects. + """ + if value: + return [ + string.rstrip('/') if urlparse(string).scheme else Path(string).expanduser() + for string in (value.split() if isinstance(value, str) else value) + ] + else: + return [] + + +def parse_filelist(filelist_file: str | Path | None) -> set[str]: + """Parse a filelist file into a set of file paths. + + Args: + filelist_file (str | Path | None): Path to file containing list of paths (one per line). + Lines starting with '#' are treated as comments. + + Returns: + List of file paths. + """ + if filelist_file: + with open(filelist_file) as f: + return list({line for line in f.read().splitlines() if (line and not line.startswith('#'))}) + else: + return [] + + +def parse_parameters(value: str) -> Path: + """Parse and expand a parameters string (a=b). + + Args: + value (str): Parameter string to parse. + + Returns: + Dict of the form {key: values} + """ + if value: + key, values_str = value.split('=') + values = values_str.split(',') + return {key: values} + else: + return {} + + +class ArgumentParser(argparse.ArgumentParser): + """Extended ArgumentParser that reads defaults from config files and environment. + + Supports reading configuration from TOML files in the following order: + + - `./isimip.toml` + - `~/.isimip.toml` + - `/etc/isimip.toml` + + Environment variables (uppercase) override config file values. 
+ """ + + config_files = [ + 'isimip.toml', + '~/.isimip.toml', + '/etc/isimip.toml', + ] + + env_prefix = 'ISIMIP_' + + def parse_args(self, *args, config_path=None) -> argparse.Namespace: + return super().parse_args(*args, namespace=self.build_default_args(config_path)) + + def get_defaults(self) -> dict: + defaults = {} + for action in self._actions: + if not action.required and action.dest != 'help': + defaults[action.dest] = action.default + + defaults.update(vars(self.build_default_args())) + return defaults + + def read_global_config(self) -> dict: + for config_file in self.config_files: + config_path = Path(config_file).expanduser() + if config_path.is_file(): + with open(config_path, 'rb') as fp: + data = tomllib.load(fp) + if self.prog in data: + return data[self.prog] + return {} + + def read_local_config(self, config_path) -> dict: + if config_path and config_path.is_file(): + with open(config_path, 'rb') as fp: + return tomllib.load(fp) + return {} + + def build_default_args(self, config_path=None) -> argparse.Namespace: + # read config file(s) + config = self.read_global_config() + config.update(self.read_local_config(config_path)) + + # init the default namespace + default_args = argparse.Namespace() + + for action in self._actions: + if action.dest not in ['config', 'help']: + key = action.dest + key_upper = key.upper() + key_env = self.env_prefix + key_upper + + value = None + + if os.getenv(key_env): + # if the attribute is in the environment, take the value + value = os.getenv(key_env) + if value.lower() == 'true': + value = True + elif value.lower() == 'false': + value = False + elif value.lower() == 'none': + value = None + + # apply action type + if value and action.type is not None and value not in [True, False]: + try: + value = action.type(value) + except argparse.ArgumentTypeError as e: + raise ConfigError(f'argument "{key}": {e}') from e + + elif config and key in config: + # if the attribute is in the config file, take it from there + 
class Settings(Singleton):
    """Attribute-style settings store backed by a plain dictionary.

    Public (non-underscore) attributes are read from and written to the
    ``_settings`` dictionary; names starting with an underscore are handled
    as regular Python attributes. ``from_dict`` and ``from_toml`` store all
    keys in uppercase.

    NOTE(review): the sharing semantics come from ``Singleton`` (imported from
    ``.utils``), which is not visible here - confirm before relying on them.
    """
    # backing store for all public settings
    _settings: dict[str, Any] = {}

    # keys which are dropped by from_dict (compared before upper-casing)
    ignore_keys = ('config', )

    def __repr__(self) -> str:
        return str(self._settings)

    def __getattr__(self, name: str) -> Any:
        # only called when regular attribute lookup fails
        store = self._settings
        if name not in store:
            raise AttributeError(f"{self.__class__.__name__} object has no attribute '{name}'")
        return store[name]

    def __setattr__(self, name: str, value: Any) -> None:
        if not name.startswith('_'):
            # public names end up in the settings dictionary
            self._settings[name] = value
            return

        # internal data is stored as a normal attribute
        super().__setattr__(name, value)

    def to_dict(self) -> dict[str, Any]:
        """Return the settings as a dictionary.

        Returns:
            The internal settings dictionary (not a copy).
        """
        return self._settings

    @classmethod
    def from_dict(cls, values: dict[str, Any]) -> 'Settings':
        """Create a Settings instance from a dictionary.

        Args:
            values (dict[str, Any]): Dictionary of setting key-value pairs.

        Returns:
            A Settings instance holding the provided values under uppercase keys.
            Keys listed in ``ignore_keys`` are skipped.
        """
        collected = {}
        for key, value in values.items():
            if key in cls.ignore_keys:
                continue
            collected[key.upper()] = value

        instance = cls()
        instance._settings = collected
        logger.debug('settings = %s', instance)
        return instance

    @classmethod
    def from_toml(cls, path: Path, section: str | None = None) -> 'Settings':
        """Create a Settings instance from a TOML file.

        Args:
            path (Path): Path to the TOML file.
            section (str | None): Top-level table to use. If omitted, the whole
                document is used; if given but missing, no values are loaded.

        Returns:
            A Settings instance populated with the content of the TOML file.
            All keys are converted to uppercase.
        """
        with open(path, 'rb') as fp:
            data = tomllib.load(fp)

        if not section:
            values = data
        elif section in data:
            values = data[section]
        else:
            values = {}

        return cls.from_dict(values)
def select_time(ds: xr.Dataset, timestamp: datetime) -> xr.Dataset | None:
    """Pick the time step closest to a timestamp.

    Args:
        ds (xr.Dataset): Dataset with a time dimension.
        timestamp (datetime): Timestamp to select.

    Returns:
        Dataset at the nearest time step, or None if the timestamp lies
        before the first or after the last time step.
    """
    logger.info(f'select time time={timestamp}')

    # presumably a 'units' entry in the encoding means the time axis was decoded
    # to datetime64 by xarray; otherwise compute_time translates the timestamp
    has_units = ds.time.encoding.get('units')
    time = np.datetime64(timestamp) if has_units else compute_time(ds, timestamp)

    time_axis = ds['time']
    outside = time < time_axis.min() or time > time_axis.max()
    if not outside:
        return ds.sel(time=time, method='nearest')

    logger.warning(f'Selected time={time} is outside the dataset.')
    return None
def select_bbox(ds: xr.Dataset, west: float, east: float, south: float, north: float) -> xr.Dataset:
    """Select a bounding box region from a dataset.

    Args:
        ds (xr.Dataset): Dataset with lat/lon dimensions.
        west (float): Western longitude boundary (-180 to 180).
        east (float): Eastern longitude boundary (-180 to 180).
        south (float): Southern latitude boundary (-90 to 90).
        north (float): Northern latitude boundary (-90 to 90).

    Returns:
        Dataset with lat/lon dimensions sliced to the bounding box.

    Raises:
        ValidationError: If coordinates are out of valid range.
        ExtractionError: If no lat or lon axis remains after selection.
    """
    logger.info(f'select bbox west={west} east={east} south={south} north={north}')

    validate_lat(south)
    validate_lat(north)
    validate_lon(west)
    validate_lon(east)

    # A label-based slice has to run in the same direction as the index it is
    # applied to, so the orientation of the *latitude* axis decides the order
    # of the bounds (the previous check looked at the longitude axis instead).
    lat_values = ds.lat.values
    lat_descending = lat_values.size > 1 and lat_values[0] > lat_values[-1]
    lat_slice = slice(north, south) if lat_descending else slice(south, north)

    # NOTE(review): the lon axis is assumed to be ascending, and a box crossing
    # the antimeridian (west > east) is not handled here - confirm with callers.
    lon_slice = slice(west, east)

    ds = ds.sel(lat=lat_slice, lon=lon_slice)

    if 'lat' not in ds.sizes:
        raise ExtractionError('No lat axis remains after selecting bbox.')
    elif 'lon' not in ds.sizes:
        raise ExtractionError('No lon axis remains after selecting bbox.')

    return ds
def compute_aggregation(ds: xr.Dataset, type: Literal['mean', 'min', 'max', 'sum', 'std'],
                        dim: str | Iterable | None = None, weights: xr.DataArray | None = None) -> xr.Dataset:
    """Compute aggregated values over selected dimensions and add dummy dimensions like CDO.

    The reduced dimensions are re-added with a single coordinate ``0`` and the
    original dimension order is restored. The result is cast to float32.

    Args:
        ds (xr.Dataset): Dataset to process.
        type (str): Type of aggregation ('mean', 'min', 'max', 'sum' or 'std').
        dim (str | Iterable | None): Dimensions to aggregate over [default: ('lat', 'lon')].
        weights (xr.DataArray | None): Weights for 'mean', 'std' and 'sum' over lat/lon.
            If None, latitude-dependent weights are used. Weights are not applied
            to 'min'/'max' or when aggregating over other dimensions.

    Returns:
        Dataset with aggregated values over selected dimensions.

    Raises:
        RuntimeError: If ``type`` is not a known aggregation.
    """
    # fail before any work is done
    if type not in ('mean', 'std', 'sum', 'min', 'max'):
        raise RuntimeError(f'unknown type "{type}" in compute_aggregation')

    # normalise once to a tuple: ``dim`` may be a one-shot iterable, and it is
    # needed for the dummy dimensions as well as for the reduction itself
    dim = dim or ('lat', 'lon')
    dims = (dim,) if isinstance(dim, str) else tuple(dim)
    dim_expand = {d: [0] for d in dims}
    dim_transpose = list(ds.dims)

    logger.info('compute %s %s', type, dim if isinstance(dim, str) else dims)

    attrs = get_attrs(ds)

    ds = set_fill_value_to_nan(ds)

    # area weighting applies to any spelling of the lat/lon pair (tuple, list, ...)
    if type in ('mean', 'std', 'sum') and set(dims) == {'lat', 'lon'}:
        if weights is None:
            logger.warning('no weights provided, using latitude-dependent weights')
            # difference of sin(lat) across +/- 0.25 degrees around each latitude
            # NOTE(review): the fixed 0.25 presumably assumes a 0.5 degree grid - confirm
            weights = np.sin(np.deg2rad(ds.lat + 0.25)) - np.sin(np.deg2rad(ds.lat - 0.25))

        ds = ds.weighted(weights)

    # the reduction methods are named exactly like the validated type
    ds = getattr(ds, type)(dim=dims, skipna=True)

    ds = ds.expand_dims(**dim_expand).transpose(*dim_transpose).astype(np.float32)
    ds = set_attrs(ds, attrs)

    return ds
def compute_std(ds: xr.Dataset, dim: str | Iterable | None = None, weights: xr.DataArray | None = None) -> xr.Dataset:
    """
    Compute the standard deviation over selected dimensions and add dummy dimensions like CDO.
    Wrapper for compute_aggregation with type 'std'.

    Args:
        ds (xr.Dataset): Dataset to process.
        dim (str|Iterable): Dimensions along which to compute the standard deviation [default: ('lat', 'lon')]
        weights (xr.DataArray | None): Weights used when aggregating over lat/lon.
            If None, compute_aggregation falls back to latitude-dependent weights.

    Returns:
        Dataset with the standard deviation over selected dimensions.
    """
    return compute_aggregation(ds, 'std', dim, weights)
def compute_max(ds: xr.Dataset, dim: str | Iterable | None = None) -> xr.Dataset:
    """
    Compute maximum values over selected dimensions and add dummy dimensions like CDO.
    Wrapper for compute_aggregation with type 'max'.

    Args:
        ds (xr.Dataset): Dataset to process.
        dim (str|Iterable): Dimensions along which to take the maximum [default: ('lat', 'lon')]

    Returns:
        Dataset with maximum values over selected dimensions.
    """
    # no weights are passed on: this wrapper has no weights parameter
    return compute_aggregation(ds, 'max', dim)
+ """ + if ds1 is None: + return ds2.copy() + elif not ds2.sizes.get('time'): + return ds1 + else: + if not ds1.time.encoding or not ds1.time.encoding.get('units'): + # apply offset when time units or calendar diverges, but only if times where not decoded + offset = compute_offset(ds1, ds2) + if offset is not None: + ds2 = ds2.assign_coords(time=ds2['time'] + offset) + + return xr.concat([ds1, ds2], 'time') diff --git a/isimip_utils/fetch.py b/isimip_utils/fetch.py index d56c1d5..9392d21 100644 --- a/isimip_utils/fetch.py +++ b/isimip_utils/fetch.py @@ -1,166 +1,95 @@ +"""Functions to fetch files from urls or local paths.""" import json import logging -import os -import re +import shutil from pathlib import Path -from urllib.parse import urlparse +from typing import Any import requests -from isimip_utils.exceptions import NotFound - logger = logging.getLogger(__name__) -def fetch_definitions(bases, path): - path_components = Path(path).parts - for i in range(len(path_components), 0, -1): - definitions_path = Path('definitions').joinpath(os.sep.join(path_components[:i+1])).with_suffix('.json') - definitions_json = fetch_json(bases, definitions_path, extend_base='output') - - if definitions_json: - logger.debug('definitions_path = %s', definitions_path) - logger.debug('definitions_json = %s', definitions_json) - - definitions = {} - for definition_name, definition in definitions_json.items(): - # convert the definitions to dicts if they are lists - if isinstance(definition, list): - definitions[definition_name] = { - row['specifier']: row for row in definition - } - else: - definitions[definition_name] = definition - - logger.debug('definitions = %s', definitions) - return definitions - - raise NotFound(f'no definitions found for {path}') - - -def fetch_pattern(bases, path): - path_components = Path(path).parts - for i in range(len(path_components), 0, -1): - pattern_path = Path('pattern').joinpath(os.sep.join(path_components[:i+1]) + '.json') - pattern_json = 
def fetch_file(url: str, target: None | str | Path = None) -> str | Path | None:
    """Download a file from a URL.

    Args:
        url (str): URL to download the file from.
        target (str | Path | None): Path to write the file to, or None if the
            content should be returned instead.

    Returns:
        The target path (as Path) if a target was provided, the decoded content
        otherwise, or None if the connection fails or the response status is not 200.
    """
    logger.debug('url = %s', url)

    try:
        response = requests.get(url)
    except requests.exceptions.ConnectionError:
        return None

    # treat every non-200 response as a failed request, so that neither an
    # error page is returned as content nor a path which was never written
    if response.status_code != 200:
        return None

    if target is None:
        return response.content.decode()

    # accept plain strings as well, as the signature promises
    target = Path(target)
    target.parent.mkdir(exist_ok=True, parents=True)
    target.write_bytes(response.content)
    return target
- else: - json_path = Path(base).expanduser() - if extend_base is not None: - json_path /= extend_base - if path is not None: - json_path /= path - - logger.debug('json_path = %s', json_path) - - if json_path.exists(): - return json.loads(open(json_path).read()) - - -def fetch_file(bases, path=None, extend_base=None): - for base in bases: - if urlparse(base).scheme: - if path is not None: - file_url = base.rstrip('/') + '/' + path.as_posix() - else: - file_url = base.rstrip('/') - - logger.debug('file_url = %s', file_url) - - try: - response = requests.get(file_url) - except requests.exceptions.ConnectionError: - return None - - if response.status_code == 200: - return response.content + Returns: + Target path if it was provided, the content otherwise, or None if the request fails. + """ + logger.debug('path = %s', path) + path = Path(path) + if path.is_file(): + if target is None: + return path.read_text() else: - file_path = Path(base).expanduser() - if extend_base is not None: - file_path /= extend_base - if path is not None: - file_path /= path - - logger.debug('file_path = %s', file_path) - - if file_path.exists(): - return file_path.read() + target.parent.mkdir(exist_ok=True, parents=True) + shutil.copy(path, target) + return target diff --git a/isimip_utils/files.py b/isimip_utils/files.py new file mode 100644 index 0000000..f974471 --- /dev/null +++ b/isimip_utils/files.py @@ -0,0 +1,39 @@ +"""Functions to find files for specific datasets.""" +import logging +import re +from collections.abc import Iterable +from pathlib import Path + +logger = logging.getLogger(__name__) + + +def find_files(file_iter: Iterable[Path], + pattern: str = r'_(?P\d{4})_(?P\d{4})?\.nc\d*$') -> tuple[list[tuple], int, int]: + """Find files for a given (dataset) path, matching a regex pattern for start and end year. + + Args: + file_iter (Iterable[Path]): Iterator over file paths to search through. + pattern (str): Regular expression for start and end year matching. 
+ + Returns: + Tuple containing (a) the List of tuples containing the path and the start and end years for each file, + (b) the lowest start year, and (c) the highest end year. + """ + files = [] + + for file_path in sorted(file_iter): + match = re.search(pattern, str(file_path), re.IGNORECASE) + if match: + try: + start_year = int(match.group('start_year')) + except TypeError: + start_year = None + + try: + end_year = int(match.group('end_year')) + except TypeError: + end_year = None + + files.append((file_path, start_year, end_year)) + + return files diff --git a/isimip_utils/netcdf.py b/isimip_utils/netcdf.py index 8072d94..f8e2430 100644 --- a/isimip_utils/netcdf.py +++ b/isimip_utils/netcdf.py @@ -1,4 +1,7 @@ +"""Functions to open and read NetCDF files using netCDF4.""" from datetime import datetime +from pathlib import Path +from typing import Any import numpy as np from netCDF4 import Dataset @@ -9,28 +12,78 @@ INT_TYPES = [np.int8, np.uint8, np.int16, np.uint16, np.int32, np.uint32, np.int64, np.uint64] -def open_dataset_read(file_path): - return Dataset(file_path, 'r') +def open_dataset(file_path: str | Path, mode: str = 'r') -> Dataset: + """Open a NetCDF dataset (just a wrapper for netcdf.Dataset). + Args: + file_path (str | Path): Path to the NetCDF file. + mode (str): + Returns: + NetCDF4 Dataset object opened in the selected mode. + """ + return Dataset(file_path, mode) -def open_dataset_write(file_path): + +def open_dataset_read(file_path: str | Path) -> Dataset: + """Open a NetCDF dataset in read-only mode. + + Args: + file_path (str | Path): Path to the NetCDF file. + + Returns: + NetCDF4 Dataset object opened in read mode. + """ + return open_dataset(file_path) + + +def open_dataset_write(file_path: str | Path) -> Dataset: + """Open a NetCDF dataset in read/write mode. + + Args: + file_path (str | Path): Path to the NetCDF file. + + Returns: + NetCDF4 Dataset object opened in read/write mode. 
+ """ return Dataset(file_path, 'r+') -def init_dataset(file_path, diskless=False, lon=720, lat=360, time=True, - time_unit='days since 1601-1-1 00:00:00', - time_calendar='proleptic_gregorian', **variables): +def init_dataset(file_path: str | Path, diskless: bool = False, overwrite: bool = False, lon: int = 720, lat: int = 360, + time: None | np.ndarray = None, time_unit: str = 'days since 1601-1-1 00:00:00', + time_calendar: str = 'proleptic_gregorian', attrs: None | dict = None, **variables: Any) -> Dataset: + """Initialize a new NetCDF4 dataset with standard dimensions and variables. + + Args: + file_path (str | Path): Path where the NetCDF file will be created. + diskless (bool): If True, create dataset in memory (default: False). + overwrite (bool): If True, overwrite existing dataset (default: False). + lon (int): Number of longitude points (default: 720). + lat (int): Number of latitude points (default: 360). + time (np.ndarray): Time dimension configuration (default: None). + time_unit (str): Units for the time dimension (default: 'days since 1601-1-1 00:00:00'). + time_calendar (str): Calendar type for time dimension (default: 'proleptic_gregorian'). + attrs (dict): Dictionary of attributes for variables and global attributes. + **variables (Any): Data variables to create in the dataset. + + Returns: + Initialized NetCDF4 Dataset object. 
+ """ + # overwrite existing file + if overwrite and file_path.exists(): + file_path.unlink() + + # create NetCDF dataset ds = Dataset(file_path, 'w', format='NETCDF4_CLASSIC', diskless=diskless) + # create time dimension if time is set if time is not None and time is not False: ds.createDimension('time', None) - d_lon = 360.0 / lon - d_lat = 180.0 / lat - + # create lon and lat dimensions ds.createDimension('lon', lon) ds.createDimension('lat', lat) + # create time variable if time is set if time is not None: time_variable = ds.createVariable('time', 'f8', ('time',), fill_value=FILL_VALUE) time_variable.missing_value = FILL_VALUE @@ -42,46 +95,73 @@ def init_dataset(file_path, diskless=False, lon=720, lat=360, time=True, if isinstance(time, np.ndarray): time_variable[:] = time + # create lon variable + lon_delta = 360.0 / lon lon_variable = ds.createVariable('lon', 'f8', ('lon',), fill_value=FILL_VALUE) lon_variable.missing_value = FILL_VALUE lon_variable.standard_name = 'longitude' lon_variable.long_name = 'Longitude' lon_variable.units = 'degrees_east' lon_variable.axis = 'X' - lon_variable[:] = np.arange(-180 + 0.5 * d_lon, 180, d_lon) + lon_variable[:] = np.arange(-180 + 0.5 * lon_delta, 180, lon_delta) + # create lat variable + lat_delta = 180.0 / lat lat_variable = ds.createVariable('lat', 'f8', ('lat',), fill_value=FILL_VALUE) lat_variable.missing_value = FILL_VALUE lat_variable.standard_name = 'latitude' lat_variable.long_name = 'Latitude' lat_variable.units = 'degrees_north' lat_variable.axis = 'Y' - lat_variable[:] = np.arange(90 - 0.5 * d_lat, -90, -d_lat) - - for variable_name, variable_dict in variables.items(): - long_name = variable_dict.get('long_name') - dtype = variable_dict.get('dtype', 'f8') - dimensions = variable_dict.get('dimensions', ('time', 'lat', 'lon')) - units = variable_dict.get('units') - - if variable_name: - variable = ds.createVariable(variable_name, dtype, dimensions, - fill_value=FILL_VALUE, compression='zlib') - 
def get_dimensions(dataset: Dataset) -> dict[str, int]:
    """Collect the size of every dimension of a NetCDF dataset.

    Args:
        dataset (Dataset): NetCDF4 Dataset object.

    Returns:
        Dictionary mapping each dimension name to its ``size``, in the order
        in which the dataset lists its dimensions.
    """
    return {name: dimension.size for name, dimension in dataset.dimensions.items()}
def get_index(dataset: Dataset, lat: float, lon: float) -> tuple[int, int]:
    """Compute the grid indices closest to a geographic point.

    The spacing of each axis is taken from its first two coordinates only.
    NOTE(review): this presumably expects regularly spaced lat/lon axes - confirm.
    The result is not checked against the axis length.

    Args:
        dataset (Dataset): NetCDF4 Dataset object with 'lat' and 'lon' variables.
        lat (float): Latitude of the point.
        lon (float): Longitude of the point.

    Returns:
        Tuple ``(ix, iy)`` with the index along lon first and along lat second.
    """
    def nearest(axis: Any, value: float) -> int:
        # offset from the first coordinate, measured in grid steps
        origin = axis[0]
        step = axis[1] - origin
        return round(float((value - origin) / step))

    variables = dataset.variables
    return nearest(variables['lon'], lon), nearest(variables['lat'], lat)
def value2string(value: Any) -> str:
    """Convert a value to its string representation.

    Args:
        value (Any): Value to convert. Datetime objects are written in ISO format
            with a 'Z' suffix. Timezone-aware datetimes are converted to UTC first;
            naive datetimes are written unchanged.

    Returns:
        String representation of the value.
    """
    if isinstance(value, datetime):
        if value.utcoffset() is not None:
            # An aware datetime would otherwise render its own offset in front
            # of the 'Z' (e.g. '...+02:00Z'), which is not a valid timestamp.
            from datetime import timezone
            value = value.astimezone(timezone.utc).replace(tzinfo=None)

        return value.isoformat() + 'Z'
    else:
        return str(value)
+ """ + labels = [] + for coord in get_coords(df): + name = df.attrs['coords'][coord].get('long_name', coord) + units = df.attrs['coords'][coord].get('units') + labels.append(f'{name} [{units}]' if units else name) + return tuple(labels) + + +def get_first_coord_label(df: pd.DataFrame) -> str | None: + """Get a formatted label for the coordinate with units. + + Args: + df (pd.DataFrame): DataFrame with 'coords' in attrs. + + Returns: + Formatted string like "Coordinate Name [units]" or just the name if no units. + """ + return next(iter(get_coord_labels(df))) + + +def get_coord_axes(df: pd.DataFrame) -> tuple: + """Get the axis attribute for all coordinates. + + Args: + df (pd.DataFrame): DataFrame with 'coords' in attrs. + + Returns: + Axis attribute (e.g., 'T', 'X', 'Y'). + """ + axes = [] + for coord in get_coords(df): + axes.append(df.attrs['coords'][coord].get('axis')) + return tuple(axes) + + +def get_first_coord_axis(df: pd.DataFrame) -> str | None: + """Get the axis attribute for the first coordinate. + + Args: + df (pd.DataFrame): DataFrame with 'coords' in attrs. + + Returns: + Axis attribute (e.g., 'T', 'X', 'Y'), or None if not set. + """ + return next(iter(get_coord_axes(df))) + + +def get_data_vars(df: pd.DataFrame) -> tuple: + """Get the data variable names from DataFrame attributes. + + Args: + df (pd.DataFrame): DataFrame with 'data_vars' in attrs. + + Returns: + Names of the data variables. + """ + return tuple(df.attrs['data_vars']) + + +def get_first_data_var(df: pd.DataFrame) -> str: + """Get the first data variable name from DataFrame attributes. + + Args: + df (pd.DataFrame): DataFrame with 'data_vars' in attrs. + + Returns: + Name of the first data variable. + """ + return next(iter(get_data_vars(df))) + + +def get_data_var_labels(df: pd.DataFrame) -> str: + """Get a formatted label for the data variable with units. + + Args: + df (pd.DataFrame): DataFrame with 'data_vars' in attrs. 
+ + Returns: + Formatted string like "Variable Name [units]" or just the name if no units. + """ + labels = [] + for data_var in get_data_vars(df): + data_var_name = df.attrs['data_vars'][data_var].get('name', data_var) + data_var_units = df.attrs['data_vars'][data_var].get('units') + labels.append(f'{data_var_name} [{data_var_units}]' if data_var_units else data_var_name) + return tuple(labels) + + +def get_first_data_var_label(df: pd.DataFrame) -> str: + """Get a formatted label for the data variable with units. + + Args: + df (pd.DataFrame): DataFrame with 'data_vars' in attrs. + + Returns: + Formatted string like "Variable Name [units]" or just the name if no units. + """ + return next(iter(get_data_var_labels(df))) + + +def compute_average(df: pd.DataFrame, data_var: None | str = None, area: bool = True, + type: Literal['annual', 'monthly'] = 'annual') -> pd.DataFrame: + """Compute yearly or monthly average with optional standard deviation bounds. + + Args: + df (pd.DataFrame): DataFrame with time column and data variable. + data_var (str): Name of the data variable (default: first data var). + area (bool): Whether to include lower/upper bounds using std (default: True). + type ('annual' | 'monthly'): Compute annual or monthly averages + Returns: + DataFrame with yearly aggregated data. 
+ """ + data_var = data_var or get_first_data_var(df) + data_var_long_name = df.attrs['data_vars'][data_var].get('long_name') + data_var_units = df.attrs['data_vars'][data_var].get('units') + + attrs = df.attrs + + if type == 'annual': + column_name = 'year' + df[column_name] = df['time'].dt.year + elif type == 'monthly': + column_name = 'month' + df[column_name] = df['time'].values.astype('datetime64[M]') + else: + raise RuntimeError(f'unknown type "{type}" must be "annual" or "monthly"') + + kwargs = {'mean': (data_var, 'mean')} + if area: + kwargs['lower'] = (data_var, lambda y: y.mean() - y.std()) + kwargs['upper'] = (data_var, lambda y: y.mean() + y.std()) + + df = df.groupby(column_name).agg(**kwargs).reset_index() + + # cast to double + df['mean'] = df['mean'].astype('float64') + if area: + df['lower'] = df['lower'].astype('float64') + df['upper'] = df['upper'].astype('float64') + + # update attrs + df.attrs = attrs + df.attrs['coords'] = {column_name: {'long_name': column_name.capitalize(), 'axis': 'T'}} + df.attrs['data_vars'] = { + 'mean': { + 'name': f'avg {type} {data_var}' + } + } + if data_var_long_name: + df.attrs['data_vars']['mean']['long_name'] = f'Average {type} {data_var_long_name.lower()}' + if data_var_units: + df.attrs['data_vars']['mean']['units'] = data_var_units + + return df + + +def group_by_day(df: pd.DataFrame, data_var: None | str = None) -> pd.DataFrame: + """Group data by day of year and compute mean. + + Args: + df (pd.DataFrame): DataFrame with time column and data variable. + data_var (str): Name of the data variable (default: first data var). + + Returns: + DataFrame grouped by day of year (1-365/366). 
+ """ + data_var = data_var or get_first_data_var(df) + + df['day'] = df['time'].dt.dayofyear + df = df.groupby('day')[data_var].mean().reset_index() + df.attrs['coords'] = {'day': { 'long_name': 'Day of the year'}} + + return df + + +def group_by_month(df: pd.DataFrame, data_var: None | str = None) -> pd.DataFrame: + """Group data by month and compute mean. + + Args: + df (pd.DataFrame): DataFrame with time column and data variable. + data_var (str): Name of the data variable (default: first data var). + + Returns: + DataFrame grouped by month (1-12). + """ + data_var = data_var or get_first_data_var(df) + + df['month'] = df['time'].dt.month + df = df.groupby('month')[data_var].mean().reset_index() + df.attrs['coords'] = {'month': {'long_name': 'Month of the year'}} + + return df + + +def normalize(df: pd.DataFrame, data_var: None | str = None) -> pd.DataFrame: + """Normalize data variable using z-score normalization. + + Args: + df (pd.DataFrame): DataFrame with data variable to normalize. + data_var (str): Name of the data variable (default: first data var). + + Returns: + DataFrame with normalized data variable (mean=0, std=1). + """ + data_var = data_var or get_first_data_var(df) + data_var_long_name = df.attrs['data_vars'][data_var].get('long_name') + + mean, std = df[data_var].mean(), df[data_var].std() + df[data_var] = (df[data_var] - mean) / (std if std > 0 else 1.0) + if data_var_long_name: + df.attrs['data_vars'][data_var]['long_name'] = f'Normalized {data_var_long_name.lower()}' + del df.attrs['data_vars'][data_var]['units'] + + return df + + +def create_label(df: pd.DataFrame, labels: list[str]) -> pd.DataFrame: + """Add a label column to DataFrame by joining label strings. + + Args: + df (pd.DataFrame): DataFrame to add label to. + labels (list[str]): List of label strings to join with spaces. + + Returns: + DataFrame with added 'label' column. 
+ """ + df['label'] = ' '.join(labels) + return df diff --git a/isimip_utils/parameters.py b/isimip_utils/parameters.py new file mode 100644 index 0000000..e898813 --- /dev/null +++ b/isimip_utils/parameters.py @@ -0,0 +1,85 @@ +"""Utility functions for the work with parameters and placeholders.""" +from itertools import product +from pathlib import Path +from typing import Any + + +def get_permutations(parameters: dict[str, list]) -> tuple[tuple]: + """Generate all permutations from parameter value lists. + + Args: + parameters (dict[str, list]): Dictionary mapping parameter names to lists of values. + + Returns: + Tuple of tuples representing all possible combinations of parameter values. + """ + return tuple(product(*parameters.values())) + + +def get_placeholders(parameters: dict[str, list], permutation: tuple) -> dict: + """Convert a permutation tuple into a dictionary of placeholders. + + Args: + parameters (dict[str, list]): Dictionary mapping parameter names to lists of values. + permutation (tuple): Tuple of values representing one permutation. + + Returns: + Dictionary mapping parameter names to their values in this permutation. + """ + return dict(zip(parameters.keys(), permutation, strict=True)) + + +def join_parameters(parameters: dict[str, list[str]], max_count: int = 5, + max_label: str = 'various') -> dict[str, str]: + """Join parameter values into strings, with fallback for large value sets. + + Args: + parameters (dict[str, list[str]]): Dictionary mapping parameter names to lists of values. + max_count (int): Maximum number of values to join (default: 5). + max_label (str): Label to use when value count exceeds max_count (default: 'various'). + + Returns: + Dictionary mapping parameter names to joined strings or max_label. 
+ """ + return { + key: (max_label if len(values) > max_count else '+'.join(values)) + for key, values in parameters.items() + } + + +def copy_placeholders(*placeholder_args: dict, **kwargs: Any) -> dict: + """Merge multiple placeholder dictionaries and additional kwargs. + + Args: + *placeholder_args (dict): Variable number of placeholder dictionaries to merge. + **kwargs (Any): Additional key-value pairs to add to the result. + + Returns: + Dictionary containing all merged placeholders. + """ + placeholders = { + key: value + for placeholder_arg in placeholder_args + for key, value in placeholder_arg.items() + } + placeholders.update(**kwargs) + return placeholders + + +def apply_placeholders(path_template: str | Path, placeholders: dict) -> Path: + """Apply placeholders to a string or path, ensuring that the name of the path is lower case + + Args: + path_template (str | Path): Path template as string or path. + placeholders (dict): Placeholder dictionary. + + Returns: + Path with the applied placeholders. 
+ """ + try: + path = str(path_template).format(**placeholders) + except KeyError as e: + raise RuntimeError(f'Some of the placeholders are missing ({e}).') from e + + path = Path(path) + return path.with_stem(path.stem.lower()) diff --git a/isimip_utils/parser.py b/isimip_utils/parser.py deleted file mode 100644 index 07b36f7..0000000 --- a/isimip_utils/parser.py +++ /dev/null @@ -1,62 +0,0 @@ -import argparse -import configparser -import os -from pathlib import Path - -from dotenv import load_dotenv - - -class ArgumentParser(argparse.ArgumentParser): - - config_file = None - default_config_files = [ - 'isimip.conf', - '~/.isimip.conf', - '/etc/isimip.conf' - ] - - def parse_args(self, *args): - # parse the command line arguments with the default namespace - # obtained from the config file and the environment - return super().parse_args(*args, namespace=self.build_default_args()) - - def get_defaults(self): - defaults = {} - for action in self._actions: - if not action.required and action.dest != 'help': - defaults[action.dest] = action.default - - defaults.update(vars(self.build_default_args())) - return defaults - - def read_config(self): - config_files = [self.config_file] if self.config_file else self.default_config_files - for config_file in config_files: - config_path = Path(config_file).expanduser() - config = configparser.ConfigParser() - config.read(config_path) - if self.prog in config: - return config[self.prog] - - def build_default_args(self): - # setup env from .env file - load_dotenv(Path().cwd() / '.env') - - # read config file - config = self.read_config() - - # init the default namespace - default_args = argparse.Namespace() - - for action in self._actions: - if not action.required and action.dest != 'help': - key = action.dest - key_upper = key.upper() - if os.getenv(key_upper): - # if the attribute is in the environment, take the value - setattr(default_args, key, os.getenv(key_upper)) - elif config and key in config: - # if the attribute is in 
the config file, take it from there - setattr(default_args, key, config.get(key)) - - return default_args diff --git a/isimip_utils/patterns.py b/isimip_utils/patterns.py index 7cd302f..a66eb86 100644 --- a/isimip_utils/patterns.py +++ b/isimip_utils/patterns.py @@ -1,3 +1,4 @@ +"""Functions to match file names and extract ISIMIP specifiers.""" import logging import re from pathlib import Path @@ -9,20 +10,59 @@ year_pattern = re.compile(r'^\d{4}$') -def match_dataset_path(pattern, dataset_path): +def match_dataset_path(pattern: dict, dataset_path: Path) -> tuple[Path, dict]: + """Match a dataset path against a pattern. + + Args: + pattern (dict): Pattern dictionary containing regex patterns. + dataset_path (Path): Path to the dataset to match. + + Returns: + Tuple of (matched_path, specifiers_dict). + + Raises: + DidNotMatch: If the path doesn't match the pattern. + """ return match_path(pattern, dataset_path, filename_pattern_key='dataset') -def match_file_path(pattern, file_path): +def match_file_path(pattern: dict, file_path: Path) -> tuple[Path, dict]: + """Match a file path against a pattern. + + Args: + pattern (dict): Pattern dictionary containing regex patterns. + file_path (Path): Path to the file to match. + + Returns: + Tuple of (matched_path, specifiers_dict). + + Raises: + DidNotMatch: If the path doesn't match the pattern. + """ return match_path(pattern, file_path) -def match_path(pattern, path, dirname_pattern_key='path', filename_pattern_key='file'): +def match_path(pattern: dict, path: Path, dirname_pattern_key: str = 'path', + filename_pattern_key: str = 'file') -> tuple[Path, dict]: + """Match both directory and filename components of a path against patterns. + + Args: + pattern (dict): Pattern dictionary containing regex patterns and specifiers. + path (Path): Path object to match. + dirname_pattern_key (str): Key in pattern dict for directory pattern (default: 'path'). 
+ filename_pattern_key (str): Key in pattern dict for filename pattern (default: 'file'). + + Returns: + Tuple of (matched_path, specifiers_dict) containing extracted specifiers. + + Raises: + DidNotMatch: If dirname and filename specifiers conflict. + """ dirname_pattern = pattern[dirname_pattern_key] filename_pattern = pattern[filename_pattern_key] # match the dirname and the filename - dirname_path, dirname_specifiers = match_string(dirname_pattern, path.parent.as_posix()) + dirname_path, dirname_specifiers = match_string(dirname_pattern, str(path.parent)) filename_path, filename_specifiers = match_string(filename_pattern, path.name) path = dirname_path / filename_path @@ -30,7 +70,7 @@ def match_path(pattern, path, dirname_pattern_key='path', filename_pattern_key=' # assert that any value in dirname_specifiers at least starts with # its corresponding value (same key) in filename_specifiers # e.g. 'ewe' and 'ewe_north-sea' - for key, value in filename_specifiers.items(): + for key, _ in filename_specifiers.items(): if key in dirname_specifiers: f, d = filename_specifiers[key], dirname_specifiers[key] @@ -52,17 +92,54 @@ def match_path(pattern, path, dirname_pattern_key='path', filename_pattern_key=' return path, specifiers -def match_dataset(pattern, path): +def match_dataset(pattern: dict, path: Path) -> tuple[Path, dict]: + """Match a dataset name against a pattern. + + Args: + pattern (dict): Pattern dictionary containing regex patterns. + path (Path): Path object with dataset name. + + Returns: + Tuple of (matched_path, specifiers_dict). + + Raises: + DidNotMatch: If the dataset name doesn't match the pattern. + """ return match_string(pattern['dataset'], path.name) -def match_file(pattern, path): +def match_file(pattern: dict, path: Path) -> tuple[Path, dict]: + """Match a file name against a pattern. + + Args: + pattern (dict): Pattern dictionary containing regex patterns. + path (Path): Path object with file name. 
+ + Returns: + Tuple of (matched_path, specifiers_dict). + + Raises: + DidNotMatch: If the file name doesn't match the pattern. + """ return match_string(pattern['file'], path.name) -def match_string(pattern, string): - logger.debug(pattern.pattern) - logger.debug(string) +def match_string(pattern: re.Pattern, string: str) -> tuple[Path, dict]: + """Match a string against a regex pattern and extract specifiers. + + Args: + pattern (re.Pattern): Compiled regex pattern with named groups. + string (str): String to match against the pattern. + + Returns: + Tuple of (Path of matched portion, specifiers_dict). + Year values (4-digit numbers) are converted to integers. + + Raises: + DidNotMatch: If the string doesn't match the pattern. + """ + logger.debug('pattern = "%s"', pattern.pattern) + logger.debug('string = "%s"', string) # try to match the string match = pattern.search(string) @@ -77,4 +154,17 @@ def match_string(pattern, string): return Path(match.group(0)), specifiers else: + # try to find a matching fragment + split_pattern = pattern.pattern.split('_') + for i in range(len(split_pattern), 0, -1): + try: + sub_pattern = re.compile('_'.join(split_pattern[:i])) + sub_match = sub_pattern.search(string) + if sub_match: + fragment = sub_match.group(0) + raise DidNotMatch(f'No match for "{string}", last matching fragment was "{fragment}"') + except re.error: + pass + + # just raise the exception if no fragment was found raise DidNotMatch(f'No match for {string} ("{pattern.pattern}")') diff --git a/isimip_utils/plot.py b/isimip_utils/plot.py new file mode 100644 index 0000000..7c84158 --- /dev/null +++ b/isimip_utils/plot.py @@ -0,0 +1,362 @@ +"""Plotting utilities using Altair for ISIMIP data visualization.""" +import json +import logging +from pathlib import Path +from typing import Any + +import altair as alt +import numpy as np +import pandas as pd + +from .pandas import ( + get_first_coord, + get_first_coord_axis, + get_first_coord_label, + get_first_data_var, + 
get_first_data_var_label, +) + +logger = logging.getLogger(__name__) + +alt.data_transformers.enable('vegafusion') + +@alt.theme.register('isimip_utils', enable=True) +def custom_theme(): + return alt.theme.ThemeConfig({ + "config": { + "mark": { + "color": "steelblue" + } + } + }) + + +def save_plot(chart: alt.Chart, path: str | Path, *args: Any, **kwargs: Any) -> None: + """Save an Altair chart to a file. + + Args: + chart (alt.Chart): Altair chart to save. + path (str | Path): Output file path. + *args (Any): Additional positional arguments for chart.save(). + **kwargs (Any): Additional keyword arguments for chart.save(). + """ + path = Path(path) + + logger.info(f'save {path.absolute()}') + path.parent.mkdir(exist_ok=True, parents=True) + chart.save(path, *args, **kwargs) + + +def save_index(index_path: Path) -> None: + """Save an HTML index file for browsing plot images. + + Creates an interactive HTML page for viewing SVG/PNG files in a directory. + + Args: + index_path (Path): Path where the index.html file will be saved. + """ + index_json = json.dumps([ + str(p.name) for p in sorted(index_path.parent.iterdir()) if p.suffix in ['.svg', '.png'] + ], indent=2).replace('\n', '\n ') + + logger.info(f'save {index_path.absolute()}') + index_path.with_suffix('.html').write_text(r''' + + + + + + + + +
+ + + + +
+
+ +
+ + +'''.replace(r'{{ index_json }}', index_json).strip()) + + +def format_title(permutation: tuple) -> dict: + """Create a plot title from a permutation tuple. + + Args: + permutation (tuple): Tuple of strings to join as title. + + Returns: + Dictionary with Altair title configuration. + """ + return { + "text": ' Β· '.join(permutation), + "fontSize": 16, + "dy": -10 + } + + +def plot_line(df: pd.DataFrame, x_field: str | None = None, x_label: str | None = None, + x_type: str | None = None, y_field: str | None = None, y_label: str | None = None, + y_type: str | None = None, y_format: str | None = None, color_field: str | None = None, + color_type: str | None = None, color_domain: list | None = None, color_range: list | None = None, + color_scheme: str | None = None, color_title: str | None = 'Legend', legend: bool = True, + empty: bool = False, **mark_kwargs: Any) -> alt.Chart: + """Create a line plot from a DataFrame. + + Args: + df (pd.DataFrame): DataFrame to plot. + x_field (str | None): Column name for x-axis (default: auto-detect from attrs). + x_label (str | None): Label for x-axis (default: auto-detect from attrs). + x_type (str | None): Altair type for x-axis (default: 'T' for time, 'Q' for quantitative). + y_field (str | None): Column name for y-axis (default: auto-detect from attrs). + y_label (str | None): Label for y-axis (default: auto-detect from attrs). + y_type (str | None): Altair type for y-axis (default: 'Q'). + y_format (str | None): Format string for y-axis values. + color_field (str | None): Column name for color encoding (default: 'label'). + color_type (str | None): Altair type for color (default: 'N'). + color_domain (list | None): Custom color domain. + color_range (list | None): Custom color range for scale. + color_scheme (str | None): Custom color scheme for scale. + color_title (str | None): Title for color (default: 'Legend'). + legend (bool): Whether to show legend (default: True). 
+ empty (bool): Whether to create an empty plot with NaN values (default: False). + **mark_kwargs (Any): Additional keyword arguments for mark_line(). + + Returns: + Altair Chart object with line plot (and optional area for lower/upper bounds). + """ + + x_field = get_first_coord(df) if x_field is None else x_field + x_label = get_first_coord_label(df) if x_label is None else x_label + x_type = ('T' if get_first_coord_axis(df) == 'T' else 'Q') if x_type is None else x_type + x = alt.X( + f'{x_field}:{x_type}', + title=x_label + ) + + y_field = get_first_data_var(df) if y_field is None else y_field + y_label = get_first_data_var_label(df) if y_label is None else y_label + y_type = 'Q' if y_type is None else y_type + y = alt.Y( + f'{y_field}:{y_type}', + title=y_label, + axis=alt.Axis(format=y_format) if y_format else alt.Axis(), + scale=alt.Scale(zero=False, nice=False) + ) + + color_field = 'label' if color_field is None else color_field + if empty or color_field not in df: + color = alt.Color() + else: + color_type = 'N' if color_type is None else color_type + color_scale_args = {} + if color_domain: + color_scale_args['domain'] = color_domain + if color_range: + color_scale_args['range'] = color_range + if color_scheme: + color_scale_args['scheme'] = color_scheme + + color_legend_args = {} + if color_title: + color_legend_args['title'] = color_title + + color = alt.Color( + f'{color_field}:{color_type}', + scale=alt.Scale(**color_scale_args), + legend=alt.Legend(padding=10, **color_legend_args) if legend else None + ) + + if empty: + df = pd.DataFrame({ + x_field: df[x_field], + y_field: np.full_like(df[y_field], np.nan, dtype=float) + }) + + # the base chart contains only the x axis + base = alt.Chart(df).mark_line(**mark_kwargs).encode(x=x) + + chart = base.mark_line(**mark_kwargs).encode(y=y, color=color) + + if 'lower' in df and 'upper' in df: + chart += base.mark_area(**mark_kwargs, opacity=0.5).encode( + y='lower:Q', + y2='upper:Q', + color=color + ) + + 
return chart + + +def plot_map(df: pd.DataFrame, color_field: str | None = None, color_type: str | None = None, + color_domain: list | None = None, color_range: list | None = None, color_scheme: str | None = None, + color_label: str | None = None, color_format: str | None = None, bin_size: int = 1, legend: bool = True, + empty: bool = False) -> alt.Chart: + """Create a geographic map plot from a DataFrame with lat/lon coordinates. + + Args: + df (pd.DataFrame): DataFrame with 'lat' and 'lon' columns. + color_field (str | None): Column name for color encoding (default: auto-detect from attrs). + color_type (str | None): Altair type for color (default: 'Q'). + color_domain (list | None): Custom color domain. + color_range (list | None): Custom color range for scale. + color_scheme (str | None): Custom color scheme for scale. + color_label (str | None): Label for color legend (default: auto-detect from attrs). + color_format (str | None): Format string for color legend values. + bin_size (int): Bin size for aggregating grid cells (default: 1). + legend (bool): Whether to show legend (default: True). + empty (bool): Whether to create an empty plot (default: False). + + Returns: + Altair Chart object with rectangular heatmap. 
+ """ + lon = np.sort(df['lon'].unique()) + lon_size = len(lon) + lon_bin = float(abs(lon[1] - lon[0])) * bin_size + lon_domain = (lon.min() - 0.5 * lon_bin, lon.max() + 0.5 * lon_bin) + lon_ticks = np.linspace(lon_domain[0], lon_domain[1], num=7) + + x = alt.X( + 'lon:Q', + title='lon', + bin=alt.Bin(step=lon_bin), + axis=alt.Axis(values=lon_ticks), + scale=alt.Scale(domain=lon_domain, padding=0, round=True) + ) + + lat = np.sort(df['lat'].unique()) + lat_size = len(lat) + lat_bin = float(abs(lat[1] - lat[0])) * bin_size + lat_domain = (lat.min() - 0.5 * lat_bin, lat.max() + 0.5 * lat_bin) + lat_ticks = np.linspace(lat_domain[0], lat_domain[1], num=5) + + y = alt.Y( + 'lat:Q', + title='lat', + bin=alt.Bin(step=lat_bin), + axis=alt.Axis(values=lat_ticks), + scale=alt.Scale(domain=lat_domain, padding=0, round=True) + ) + + if empty: + color = alt.Color() + else: + color_field = get_first_data_var(df) if color_field is None else color_field + color_type = 'Q' if color_type is None else color_type + color_label = get_first_data_var_label(df) if color_label is None else color_label + + color_scale_args = {} + if color_domain: + color_scale_args['domain'] = color_domain + if color_range: + color_scale_args['range'] = color_range + if color_scheme: + color_scale_args['scheme'] = color_scheme + + color_legend_args = {} + if color_format: + color_legend_args['format'] = color_format + + color = alt.Color( + f'{color_field}:{color_type}', + title=color_label, + scale=alt.Scale(**color_scale_args), + legend=alt.Legend(padding=10, **color_legend_args) if legend else None + ) + + if empty: + df = pd.DataFrame({ + 'lon': [], + 'lat': [] + }) + + return alt.Chart(df).mark_rect().encode(x=x, y=y, color=color).properties( + width=lon_size, + height=lat_size + ) + + +def plot_grid(grid_permutations: list[tuple], plot_permutations: list[tuple], plots: dict, empty_plot: alt.Chart, + x: str = 'shared', y: str = 'shared', color: str = 'shared') -> alt.Chart: + """Create a grid of plots 
organized by parameter permutations. + + Args: + grid_permutations (list): List the permutations (with tuples of parameters) which span the grid. + plot_permutations (list): List the permutations (with tuples of parameters) for each plot. + plots (dict): Dictionary mapping permutation tuples to Chart objects. + empty_plot (alt.Chart): Chart to use when a permutation has no data. + x (str): Scale resolution for x-axis ('shared', 'independent', default: 'shared'). + y (str): Scale resolution for y-axis ('shared', 'independent', default: 'shared'). + color (str): Scale resolution for color ('shared', 'independent', default: 'shared'). + + Returns: + Altair Chart object with grid layout. + """ + rows = [] + prev = None + + for grid_permutation in grid_permutations: + row_title = grid_permutation[0] if len(grid_permutation) > 0 else '' + column_title = grid_permutation[1] if len(grid_permutation) > 1 else '' + + if prev is None or (len(grid_permutation) > 0 and grid_permutation[0] != prev[0]): + # start a new row + column = [] + row = [(column_title, column)] + rows.append((row_title, row)) + elif prev is None or (len(grid_permutation) > 1 and grid_permutation[1] != prev[1]): + # start a new column + column = [] + row.append((column_title, column)) + + for plot_permutation in plot_permutations: + plot = plots.get(grid_permutation + plot_permutation) + if plot: + column.append(plot) + + prev = grid_permutation + + chart = alt.vconcat(*[ + alt.hconcat(*[ + alt.layer(*column, title=column_title) if column else empty_plot + for column_title, column in row + ], title=row_title).resolve_scale(x=x, y=y) + for row_title, row in rows + ]).resolve_scale(x=x, y=y) + + return chart diff --git a/isimip_utils/protocol.py b/isimip_utils/protocol.py new file mode 100644 index 0000000..3326397 --- /dev/null +++ b/isimip_utils/protocol.py @@ -0,0 +1,172 @@ +"""Functions to fetch information from machine-actionable ISIMIP protocols.""" +import logging +import os +import re +from 
collections.abc import Generator +from pathlib import Path +from typing import Any +from urllib.parse import urlparse + +from .exceptions import NotFound +from .fetch import fetch_json, load_json + +logger = logging.getLogger(__name__) + +PROTOCOL_LOCATIONS = [ + 'https://protocol.isimip.org', + 'https://protocol2.isimip.org', +] + + +def fetch_definitions(path: str | Path, protocol_locations: str | list[str] = PROTOCOL_LOCATIONS) -> dict[str, Any]: + """Fetch definitions from ISIMIP protocol locations. + + Args: + path (str | Path): Path to search for definitions. + protocol_locations (str | list[str]): List of protocol locations to search (default: https://protocol.isimip.org). + + Returns: + Dictionary of definitions with specifiers as keys. + + Raises: + NotFound: If no definitions are found for the given path. + """ + if isinstance(protocol_locations, str): + protocol_locations = [protocol_locations] + + for protocol_location in protocol_locations: + definitions_json = find_json(protocol_location, 'definitions', path) + if definitions_json: + definitions = {} + for definition_name, definition in definitions_json.items(): + # convert the definitions to dicts if they are lists + if isinstance(definition, list): + definitions[definition_name] = { + row['specifier']: row for row in definition + } + else: + definitions[definition_name] = definition + + logger.debug('definitions = %s', definitions) + return definitions + + raise NotFound(f'No definitions found for {path}.') + + +def fetch_pattern(path: str | Path, protocol_locations: str | list[str] = PROTOCOL_LOCATIONS) -> dict[str, Any]: + """Fetch pattern definitions from ISIMIP protocol locations. + + Args: + path (str | Path): Path to search for patterns. + protocol_locations (str | list[str]): List of protocol locations to search (default: https://protocol.isimip.org). 
+ + Returns: + Dictionary containing compiled regex patterns for 'path', 'file', 'dataset', + and lists of 'suffix', 'specifiers', and 'specifiers_map'. + + Raises: + NotFound: If no pattern is found for the given path. + """ + if isinstance(protocol_locations, str): + protocol_locations = [protocol_locations] + + for protocol_location in protocol_locations: + pattern_json = find_json(protocol_location, 'pattern', path) + if pattern_json: + if not all([ + isinstance(pattern_json['path'], str), + isinstance(pattern_json['file'], str), + isinstance(pattern_json['dataset'], str), + isinstance(pattern_json['suffix'], list) + ]): + break + + pattern = { + 'path': re.compile(pattern_json['path']), + 'file': re.compile(pattern_json['file']), + 'dataset': re.compile(pattern_json['dataset']), + 'suffix': pattern_json['suffix'], + 'specifiers': pattern_json.get('specifiers', []), + 'specifiers_map': pattern_json.get('specifiers_map', {}) + } + + logger.debug('pattern = %s', pattern) + + return pattern + + raise NotFound(f'No pattern found for {path}.') + + +def fetch_schema(path: str | Path, protocol_locations: str | list[str] = PROTOCOL_LOCATIONS) -> Any: + """Fetch schema from ISIMIP protocol locations. + + Args: + path (str | Path): Path to search for schema. + protocol_locations (str | list[str]): List of protocol locations to search (default: https://protocol.isimip.org). + + Returns: + Schema JSON object. + + Raises: + NotFound: If no schema is found for the given path. + """ + if isinstance(protocol_locations, str): + protocol_locations = [protocol_locations] + + for protocol_location in protocol_locations: + schema_json = find_json(protocol_location, 'schema', path) + if schema_json: + return schema_json + + raise NotFound(f'No schema found for {path}.') + + +def fetch_tree(path: str | Path, protocol_locations: str | list[str] = PROTOCOL_LOCATIONS) -> Any: + """Fetch tree structure from ISIMIP protocol locations. 
+ + Args: + path (str | Path): Path to search for tree structure. + protocol_locations (str | list[str]): List of protocol locations to search (default: https://protocol.isimip.org). + + Returns: + Tree JSON object. + + Raises: + NotFound: If no tree is found for the given path. + """ + if isinstance(protocol_locations, str): + protocol_locations = [protocol_locations] + + for protocol_location in protocol_locations: + tree_json = find_json(protocol_location, 'tree', path) + if tree_json: + return tree_json + + raise NotFound(f'No tree found for {path}.') + + +def find_json(protocol_location: str, sub_location: str, path: str | Path) -> Generator[tuple[Path, Any], None, None]: + """Find JSON files in protocol locations by traversing path components. + + Args: + protocol_location (str): Base protocol location URL or path. + sub_location (str): Subdirectory within protocol location (e.g., 'definitions', 'pattern'). + path (str | Path): Path to search for JSON files. + + Returns: + The JSON response from the first matching path. 
+ """ + path_components = Path(path).parts + for i in range(len(path_components), 0, -1): + current_path = Path(os.sep.join(path_components[:i+1])).with_suffix('.json') + + if not isinstance(protocol_location, Path) and urlparse(protocol_location).scheme: + data = fetch_json(f'{protocol_location}/{sub_location}/{current_path}') + else: + data = load_json(Path(protocol_location) / 'output' / sub_location / current_path) + + logger.debug('path = %s', current_path) + logger.debug('data = %s', data) + + if data is not None: + return data diff --git a/isimip_utils/tests/__init__.py b/isimip_utils/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/isimip_utils/tests/constants.py b/isimip_utils/tests/constants.py new file mode 100644 index 0000000..a3b977f --- /dev/null +++ b/isimip_utils/tests/constants.py @@ -0,0 +1,48 @@ +from datetime import date +from pathlib import Path + +DATASETS_PATH = Path("testing/datasets") +EXTRACTIONS_PATH = Path("testing/extractions") +PLOTS_PATH = Path("testing/plots") +OUTPUT_PATH = Path("testing/output") + +PROTOCOL_PATH = Path("testing/protocol/output") +SHARE_PATH = Path("testing/share") + +LANDSEAMASK_PATH = "ISIMIP3a/InputData/geo_conditions/landseamask/landseamask.nc" + +TAS_PATH = "ISIMIP3a/InputData/climate/atmosphere/obsclim/global/daily/" \ + "historical/20CRv3-ERA5/20crv3-era5_obsclim_tas_global_daily_2021_2021.nc" + +TAS_DATE_SPECIFIERS = '2021_2021' + +TAS_SPLIT_PERIOD = ( + (date(2021, 1, 1), date(2021, 4, 30)), + (date(2021, 5, 1), date(2021, 8, 31)), + (date(2021, 9, 1), date(2021, 12, 31)) +) +TAS_SPLIT_PATHS = [ + TAS_PATH.replace(TAS_DATE_SPECIFIERS, f"{start_date.strftime('%Y%m%d')}_{end_date.strftime('%Y%m%d')}") + for start_date, end_date in TAS_SPLIT_PERIOD +] + +YIELD_PATH = "ISIMIP3a/OutputData/agriculture/LPJmL/gswp3-w5e5/historical/" \ + "lpjml_gswp3-w5e5_obsclim_2015soc_default_yield-mai-noirr_global_annual-gs_1901_2016.nc" + +PROTOCOL_PATHS = [ + 
"definitions/ISIMIP3a/OutputData/agriculture.json", + "pattern/ISIMIP3a/OutputData/agriculture.json", + "schema/ISIMIP3a/OutputData/agriculture.json", + "tree/ISIMIP3a/OutputData/agriculture.json" +] + +PROTOCOL_LOCATIONS = ['testing/protocol'] +PATTERN_PATH = 'ISIMIP3a/OutputData/agriculture.json' + +DATE = date(2021, 1, 1) +PERIOD = date(2021, 4, 1), date(2021, 9, 30) + +BBOX = (70, 80, -5, 5) + +POINT = (52.395833, 13.061389) +POINT_INDEX = (386, 75) diff --git a/isimip_utils/tests/helper.py b/isimip_utils/tests/helper.py new file mode 100644 index 0000000..7248492 --- /dev/null +++ b/isimip_utils/tests/helper.py @@ -0,0 +1,56 @@ +import json +import re +import subprocess +from io import BytesIO +from pathlib import Path +from unittest.mock import MagicMock + + +def call(cmd): + print(cmd) + return subprocess.check_output(cmd, shell=True).decode() + + +def normalize_whitespace(string): + return re.sub(r'\s+', ' ', string).strip() + + +def assert_multiline_strings_equal(a, b): + for a_line, b_line in zip(a.strip().splitlines(), b.strip().splitlines(), strict=True): + assert normalize_whitespace(a_line) == normalize_whitespace(b_line), (a_line, b_line) + + +def mock_json(url, *args, **kwargs): + mock_response = MagicMock() + mock_path = Path(url.replace('https://protocol.isimip.org', 'testing/protocol/output')) + + if mock_path.exists(): + with mock_path.open() as fp: + mock_response.status_code = 200 + mock_response.json.return_value = json.load(fp) + else: + mock_response.status_code = 404 + mock_response.json.return_value = None + + return mock_response + + +def mock_content(url, *args, **kwargs): + mock_response = MagicMock() + mock_path = Path(url.replace('https://protocol.isimip.org', 'testing/protocol/output')) + + if mock_path.exists(): + data = mock_path.read_bytes() + + mock_response.status_code = 200 + mock_response.raw = BytesIO(data) + mock_response.content = data + mock_response.iter_content.return_value = [data] + + else: + mock_response.status_code 
= 404 + mock_response.raw = BytesIO() + mock_response.content = b"" + mock_response.iter_content.return_value = [] + + return mock_response diff --git a/isimip_utils/tests/test_checksum.py b/isimip_utils/tests/test_checksum.py new file mode 100644 index 0000000..ab9f239 --- /dev/null +++ b/isimip_utils/tests/test_checksum.py @@ -0,0 +1,16 @@ +from isimip_utils.checksum import get_checksum, get_checksum_suffix, get_checksum_type +from isimip_utils.tests import constants + + +def test_get_checksum(): + file_path = constants.DATASETS_PATH / constants.LANDSEAMASK_PATH + checksum = get_checksum(file_path) + assert checksum == '30f34d0720b8a6b670d0c093d488a3cd564e232a94d7ebafef99c1d7c18cec5d127fbc663f6378b4b99f9434fa10f71e8413b533c5cc5314d149ab9e2f7cca98' # noqa: E501 + + +def test_get_checksum_type(): + assert get_checksum_type() == 'sha512' + + +def test_get_checksum_suffix(): + assert get_checksum_suffix() == '.sha512' diff --git a/isimip_utils/tests/test_cli.py b/isimip_utils/tests/test_cli.py new file mode 100644 index 0000000..b6617e2 --- /dev/null +++ b/isimip_utils/tests/test_cli.py @@ -0,0 +1,122 @@ +import argparse +import os +import tempfile +from pathlib import Path + +import pytest + +from isimip_utils.cli import ( + ArgumentParser, + parse_dict, + parse_filelist, + parse_list, + parse_locations, + parse_parameters, + parse_path, + parse_version, +) + + +def test_parse_dict(): + result = parse_dict("key=value1,value2") + assert result == {"key": ["value1", "value2"]} + + +def test_parse_list(): + result = parse_list("a,b,c") + assert result == ["a", "b", "c"] + + +def test_parse_version(): + result = parse_version("20230101") + assert result == "20230101" + + +def test_parse_version_invalid(): + with pytest.raises(argparse.ArgumentTypeError): + parse_version("invalid") + + +def test_parse_path(): + result = parse_path("~/test") + assert isinstance(result, Path) + + +def test_parse_locations(): + result = parse_locations('https://example.com /opt/test 
~/test') + assert result == ['https://example.com', Path('/opt/test'), Path('~/test').expanduser()] + + +def test_parse_locations_none(): + result = parse_locations('') + assert result == [] + + +def test_parse_filelist(): + with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: + f.write("/path/to/file1\n") + f.write("#comment\n") + f.write("/path/to/file2\n") + temp_file = f.name + + try: + result = parse_filelist(temp_file) + assert "/path/to/file1" in result + assert "/path/to/file2" in result + assert "#comment" not in result + finally: + os.unlink(temp_file) + + +def test_parse_filelist_none(): + result = parse_filelist(None) + assert result == [] + + +def test_parse_parameters(): + result = parse_parameters('egg=spam,foo,bar') + assert result == {'egg': ['spam', 'foo', 'bar']} + + +def test_parse_parameters_none(): + result = parse_parameters('') + assert result == {} + + +def test_argument_parser(): + parser = ArgumentParser() + parser.add_argument("--test", default="default") + + args = parser.parse_args([]) + assert args.test == "default" + + +def test_argument_parser_with_config(tmp_path): + config_file = tmp_path / "isimip.toml" + config_file.write_text("[test]\ntest = \"config_value\"\n") + + # Temporarily change the config files list to use our test config + original_config_files = ArgumentParser.config_files + ArgumentParser.config_files = [str(config_file)] + + try: + parser = ArgumentParser(prog="test") + parser.add_argument("--test", default="default") + + args = parser.parse_args([]) + assert args.test == "config_value" + finally: + ArgumentParser.config_files = original_config_files + + +def test_argument_parser_with_env(): + os.environ["ISIMIP_TEST"] = "env_value" + + try: + parser = ArgumentParser() + parser.add_argument("--test", default="default") + + args = parser.parse_args([]) + assert args.test == "env_value" + finally: + del os.environ["ISIMIP_TEST"] diff --git a/isimip_utils/tests/test_extractions.py 
b/isimip_utils/tests/test_extractions.py new file mode 100644 index 0000000..cfc5f60 --- /dev/null +++ b/isimip_utils/tests/test_extractions.py @@ -0,0 +1,324 @@ +import pytest + +import numpy as np + +from isimip_utils.extractions import ( + compute_aggregation, + compute_max, + compute_mean, + compute_min, + compute_std, + compute_sum, + concat_extraction, + count_values, + mask_bbox, + mask_mask, + select_bbox, + select_period, + select_point, + select_time, +) +from isimip_utils.tests import constants, helper +from isimip_utils.xarray import open_dataset, write_dataset + + +@pytest.mark.parametrize('decode_cf', (True, False)) +def test_select_time(decode_cf): + date = constants.DATE + date_specifiers = date.strftime('%Y%m%d') + + dataset_path = constants.DATASETS_PATH / constants.TAS_PATH + extraction_path = ( + constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-time_') + .replace(constants.TAS_DATE_SPECIFIERS, date_specifiers) + ) + extraction_path.unlink(missing_ok=True) + + with open_dataset(dataset_path, decode_cf=decode_cf) as file_ds: + ds = select_time(file_ds, date) + write_dataset(ds, extraction_path) + + cdo_path = ( + constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-time-cdo_') + .replace(constants.TAS_DATE_SPECIFIERS, date_specifiers) + ) + helper.call(f'cdo diff {extraction_path} {cdo_path}') + + +@pytest.mark.parametrize('decode_cf', (True, False)) +def test_select_period(decode_cf): + start_date, end_date = constants.PERIOD + date_specifiers = f"{start_date.strftime('%Y%m%d')}_{end_date.strftime('%Y%m%d')}" + + dataset_path = constants.DATASETS_PATH / constants.TAS_PATH + extraction_path = ( + constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-period_') + .replace(constants.TAS_DATE_SPECIFIERS, date_specifiers) + ) + extraction_path.unlink(missing_ok=True) + + with open_dataset(dataset_path, decode_cf=decode_cf) as file_ds: + ds = select_period(file_ds, start_date, 
end_date) + write_dataset(ds, extraction_path) + + cdo_path = ( + constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-period-cdo_') + .replace(constants.TAS_DATE_SPECIFIERS, date_specifiers) + ) + helper.call(f'cdo diff {extraction_path} {cdo_path}') + + +@pytest.mark.parametrize('decode_cf', (True, False)) +def test_select_point(decode_cf): + lat, lon = constants.POINT + + dataset_path = constants.DATASETS_PATH / constants.TAS_PATH + extraction_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-point_') + extraction_path.unlink(missing_ok=True) + + with open_dataset(dataset_path, decode_cf=decode_cf) as file_ds: + ds = select_point(file_ds, lat, lon) + write_dataset(ds, extraction_path) + + cdo_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-point-cdo_') + helper.call(f'cdo diff {extraction_path} {cdo_path}') + + +@pytest.mark.parametrize('decode_cf', (True, False)) +def test_select_point_concat(decode_cf): + lat, lon = constants.POINT + + extraction_ds = None + for path in constants.TAS_SPLIT_PATHS: + dataset_path = constants.DATASETS_PATH / path + + with open_dataset(dataset_path, decode_cf=decode_cf) as file_ds: + ds = select_point(file_ds, lat, lon) + extraction_ds = concat_extraction(extraction_ds, ds) + + extraction_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-point_') + extraction_path.unlink(missing_ok=True) + + write_dataset(extraction_ds, extraction_path) + + cdo_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-point-cdo_') + helper.call(f'cdo diff {extraction_path} {cdo_path}') + + +@pytest.mark.parametrize('decode_cf', (True, False)) +def test_select_bbox(decode_cf): + west, east, south, north = constants.BBOX + + dataset_path = constants.DATASETS_PATH / constants.TAS_PATH + extraction_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-bbox_') + 
extraction_path.unlink(missing_ok=True) + + with open_dataset(dataset_path, decode_cf=decode_cf) as file_ds: + ds = select_bbox(file_ds, west, east, south, north) + write_dataset(ds, extraction_path) + + cdo_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-bbox-cdo_') + helper.call(f'cdo diff {extraction_path} {cdo_path}') + + +@pytest.mark.parametrize('decode_cf', (True, False)) +def test_select_bbox_concat(decode_cf): + west, east, south, north = constants.BBOX + + extraction_ds = None + for path in constants.TAS_SPLIT_PATHS: + dataset_path = constants.DATASETS_PATH / path + + with open_dataset(dataset_path, decode_cf=decode_cf) as file_ds: + ds = select_bbox(file_ds, west, east, south, north) + extraction_ds = concat_extraction(extraction_ds, ds) + + extraction_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-bbox_') + extraction_path.unlink(missing_ok=True) + + write_dataset(extraction_ds, extraction_path) + + cdo_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-bbox-cdo_') + helper.call(f'cdo diff {extraction_path} {cdo_path}') + + +@pytest.mark.parametrize('decode_cf', (True, False)) +def test_mask_bbox(decode_cf): + west, east, south, north = constants.BBOX + + dataset_path = constants.DATASETS_PATH / constants.TAS_PATH + extraction_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_mask-bbox_') + extraction_path.unlink(missing_ok=True) + + with open_dataset(dataset_path, decode_cf=decode_cf) as file_ds: + ds = mask_bbox(file_ds, west, east, south, north) + write_dataset(ds, extraction_path) + + cdo_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_mask-bbox-cdo_') + helper.call(f'cdo diff {extraction_path} {cdo_path}') + + +@pytest.mark.parametrize('decode_cf', (True, False)) +def test_mask_bbox_concat(decode_cf): + west, east, south, north = constants.BBOX + + extraction_ds = None + for path in 
constants.TAS_SPLIT_PATHS: + dataset_path = constants.DATASETS_PATH / path + + with open_dataset(dataset_path, decode_cf=decode_cf) as file_ds: + ds = mask_bbox(file_ds, west, east, south, north) + extraction_ds = concat_extraction(extraction_ds, ds) + + extraction_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_mask-bbox_') + extraction_path.unlink(missing_ok=True) + + write_dataset(extraction_ds, extraction_path) + + cdo_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_mask-bbox-cdo_') + helper.call(f'cdo diff {extraction_path} {cdo_path}') + + +@pytest.mark.parametrize('decode_cf', (True, False)) +def test_mask_mask(decode_cf): + mask_path = constants.DATASETS_PATH / constants.LANDSEAMASK_PATH + mask_ds = open_dataset(mask_path) + + dataset_path = constants.DATASETS_PATH / constants.TAS_PATH + extraction_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_mask-mask_') + extraction_path.unlink(missing_ok=True) + + with open_dataset(dataset_path, decode_cf=decode_cf) as file_ds: + ds = mask_mask(file_ds, mask_ds) + write_dataset(ds, extraction_path) + + cdo_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_mask-mask-cdo_') + helper.call(f'cdo diff {extraction_path} {cdo_path}') + + +@pytest.mark.parametrize('decode_cf', (True, False)) +def test_mask_mask_concat(decode_cf): + mask_path = constants.DATASETS_PATH / constants.LANDSEAMASK_PATH + mask_ds = open_dataset(mask_path) + + extraction_ds = None + for path in constants.TAS_SPLIT_PATHS: + dataset_path = constants.DATASETS_PATH / path + + with open_dataset(dataset_path, decode_cf=decode_cf) as file_ds: + ds = mask_mask(file_ds, mask_ds) + extraction_ds = concat_extraction(extraction_ds, ds) + + extraction_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_mask-mask_') + extraction_path.unlink(missing_ok=True) + + write_dataset(extraction_ds, extraction_path) + + cdo_path = 
constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_mask-mask-cdo_') + helper.call(f'cdo diff {extraction_path} {cdo_path}') + + +@pytest.mark.parametrize('type', ('mean', 'min', 'max', 'sum', 'std')) +@pytest.mark.parametrize('decode_cf', (True, False)) +def test_compute_aggregation(type, decode_cf): + gridarea_path = constants.SHARE_PATH / 'gridarea.nc' + gridarea_ds = open_dataset(gridarea_path) + + west, east, south, north = constants.BBOX + + dataset_path = constants.DATASETS_PATH / constants.TAS_PATH + extraction_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', f'_select-bbox-{type}_') + extraction_path.unlink(missing_ok=True) + + with open_dataset(dataset_path, decode_cf=decode_cf) as file_ds: + ds = select_bbox(file_ds, west, east, south, north) + + if type == 'mean': + ds = compute_mean(ds, weights=gridarea_ds['cell_area']) + elif type == 'std': + ds = compute_std(ds, weights=gridarea_ds['cell_area']) + elif type == 'sum': + ds = compute_sum(ds, weights=gridarea_ds['cell_area']) + elif type == 'max': + ds = compute_max(ds) + elif type == 'min': + ds = compute_min(ds) + + write_dataset(ds, extraction_path) + + # allow for a small relative difference, translated into an absolute difference + if type == 'sum': + abslim = 3.36e+07 + elif type == 'std': + abslim = 1e-7 + else: + abslim = 0.0 + + cdo_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', f'_select-bbox-{type}-cdo_') + helper.call(f'cdo diff,abslim={abslim} {extraction_path} {cdo_path}') + + +@pytest.mark.parametrize('decode_cf', (True, False)) +def test_compute_aggregation_nan(decode_cf): + dataset_path = constants.DATASETS_PATH / constants.YIELD_PATH + + with open_dataset(dataset_path, decode_cf=decode_cf) as file_ds: + ds = compute_max(file_ds) + + # check that the max is not FILL_VALUE + assert (ds['yield-mai-noirr'] < 20).all() + + +@pytest.mark.parametrize('decode_cf', (True, False)) +def test_compute_mean_time(decode_cf): + 
west, east, south, north = constants.BBOX + + dataset_path = constants.DATASETS_PATH / constants.TAS_PATH + extraction_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-bbox-map_') + extraction_path.unlink(missing_ok=True) + + with open_dataset(dataset_path, decode_cf=decode_cf) as file_ds: + ds = select_bbox(file_ds, west, east, south, north) + ds = compute_aggregation(ds, 'mean', dim='time') + write_dataset(ds, extraction_path) + + # allow for a small relative difference, translated into an absolute difference + abslim = 2.45e-4 + + cdo_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-bbox-map-cdo_') + helper.call(f'cdo diff,abslim={abslim} {extraction_path} {cdo_path}') + + +@pytest.mark.parametrize('decode_cf', (True, False)) +def test_count_values(decode_cf): + dataset_path = constants.DATASETS_PATH / constants.TAS_PATH + + with open_dataset(dataset_path, decode_cf=decode_cf) as file_ds: + ds = count_values(file_ds) + assert (ds['tas'] == 720*360).all() + + +@pytest.mark.parametrize('decode_cf', (True, False)) +def test_count_values_nan(decode_cf): + dataset_path = constants.DATASETS_PATH / constants.YIELD_PATH + + cdo_counts = np.array([ + int(line.split()[5]) - int(line.split()[6]) + for line in helper.call(f'cdo info {dataset_path}').splitlines()[1:-1] + ]) + + with open_dataset(dataset_path, decode_cf=decode_cf) as file_ds: + ds = count_values(file_ds) + assert (ds['yield-mai-noirr'].values == cdo_counts).all() + + +@pytest.mark.parametrize('decode_cf', (True, False)) +def test_count_values_mask(decode_cf): + west, east, south, north = constants.BBOX + + dataset_path = constants.DATASETS_PATH / constants.TAS_PATH + + with open_dataset(dataset_path, decode_cf=decode_cf) as file_ds: + ds = mask_bbox(file_ds, west, east, south, north) + ds = count_values(ds) + assert (ds['tas'] == 400).all() diff --git a/isimip_utils/tests/test_fetch.py b/isimip_utils/tests/test_fetch.py new file mode 100644 
index 0000000..ea7131f --- /dev/null +++ b/isimip_utils/tests/test_fetch.py @@ -0,0 +1,51 @@ +from unittest.mock import patch + +from isimip_utils.fetch import fetch_file, fetch_json, load_file, load_json +from isimip_utils.tests import constants, helper + +paths = [ + 'ISIMIP3a/OutputData/agriculture/ACEA/gswp3-w5e5.json', + 'ISIMIP3a/OutputData/agriculture/ACEA.json', + 'ISIMIP3a/OutputData/agriculture.json' +] + + +def test_fetch_json(): + with patch('isimip_utils.fetch.requests.get', side_effect=helper.mock_json): + data = fetch_json("https://protocol.isimip.org/definitions/ISIMIP3a/OutputData/agriculture.json") + assert data is not None + + +def test_fetch_json_not_found(): + with patch('isimip_utils.fetch.requests.get', side_effect=helper.mock_json): + data = fetch_json("https://protocol.isimip.org/definitions/ISIMIP3a/OutputData/agriculture/ACEA.json") + assert data is None + + +def test_fetch_file(): + with patch('isimip_utils.fetch.requests.get', side_effect=helper.mock_content): + output_path = constants.OUTPUT_PATH / 'test.json' + output_path.parent.mkdir(exist_ok=True, parents=True) + output_path.unlink(missing_ok=True) + + fetch_file("https://protocol.isimip.org/definitions/ISIMIP3a/OutputData/agriculture.json", output_path) + assert output_path.is_file() + + +def test_load_json(): + data = load_json('testing/protocol/output/definitions/ISIMIP3a/OutputData/agriculture.json') + assert data is not None + + +def test_load_json_not_found(): + data = load_json('testing/protocol/output/definitions/ISIMIP3a/OutputData/agriculture/ACEA.json') + assert data is None + + +def test_load_file(): + output_path = constants.OUTPUT_PATH / 'test.json' + output_path.parent.mkdir(exist_ok=True, parents=True) + output_path.unlink(missing_ok=True) + + load_file('testing/protocol/output/definitions/ISIMIP3a/OutputData/agriculture.json', output_path) + assert output_path.is_file() diff --git a/isimip_utils/tests/test_files.py b/isimip_utils/tests/test_files.py new file mode 
100644 index 0000000..cbe04eb --- /dev/null +++ b/isimip_utils/tests/test_files.py @@ -0,0 +1,40 @@ +from pathlib import Path + +from isimip_utils.files import find_files +from isimip_utils.tests import constants + + +def test_find_files(): + file_path = Path(constants.YIELD_PATH) + fake_path = file_path.with_stem(file_path.stem + '_a') + files = [ + file_path.name, + fake_path.name + ] + + result = find_files(files) + assert len(result) + assert result == [ + (file_path.name, 1901, 2016) + ] + + +def test_find_files_with_pattern(): + file_path = Path(constants.YIELD_PATH) + fake_path = file_path.with_stem(file_path.stem + '_a') + none_path = file_path.with_stem(file_path.stem.replace('_1901_2016', '')) + files = [ + file_path.name, + fake_path.name, + none_path.name, + ] + + pattern = r'(_(?P\d{4}))?(_(?P\d{4}))?(_\w+)?\.nc\d*$' + + result = find_files(files, pattern=pattern) + assert len(result) + assert result == [ + (none_path.name, None, None), # result is sorted + (file_path.name, 1901, 2016), + (fake_path.name, 1901, 2016), + ] diff --git a/isimip_utils/tests/test_netcdf.py b/isimip_utils/tests/test_netcdf.py new file mode 100644 index 0000000..e1d2c68 --- /dev/null +++ b/isimip_utils/tests/test_netcdf.py @@ -0,0 +1,171 @@ +from datetime import datetime +from pathlib import Path + +import pytest + +import numpy as np +from netCDF4 import Dataset + +from isimip_utils.netcdf import ( + convert_attribute, + get_data_model, + get_dimensions, + get_global_attributes, + get_index, + get_variables, + init_dataset, + open_dataset, + open_dataset_read, + open_dataset_write, + update_global_attributes, + value2string, +) +from isimip_utils.tests import constants + + +def test_open_dataset(): + dataset = open_dataset(constants.DATASETS_PATH / constants.LANDSEAMASK_PATH) + assert isinstance(dataset, Dataset) + + +def test_open_dataset_read(): + dataset = open_dataset_read(constants.DATASETS_PATH / constants.LANDSEAMASK_PATH) + assert isinstance(dataset, Dataset) + + 
def _fresh_test_path():
    """Return the scratch NetCDF path with its directory created and any stale file removed."""
    test_path = constants.OUTPUT_PATH / 'test.nc'
    # parents=True so the tests also work when the output tree does not exist yet
    test_path.parent.mkdir(exist_ok=True, parents=True)
    test_path.unlink(missing_ok=True)
    return test_path


def test_open_dataset_write():
    """open_dataset_write returns a netCDF4 Dataset for a new file."""
    # the with-block closes the handle before the next test unlinks the same file
    with open_dataset_write(_fresh_test_path()) as dataset:
        assert isinstance(dataset, Dataset)


def test_init_dataset():
    """init_dataset returns a netCDF4 Dataset when given time, variable data and attributes."""
    with init_dataset(
        _fresh_test_path(),
        time=np.arange(0, 10, dtype=np.float64),
        var=np.random.rand(10, 360, 720).astype(np.float64),
        attrs={'var': {'long_name': 'Variable'}}
    ) as dataset:
        assert isinstance(dataset, Dataset)


@pytest.mark.parametrize('point,result', [
    ((89.75, -179.75), (0, 0)),
    ((89.75, -179.25), (1, 0)),
    ((89.25, -179.75), (0, 1)),
    ((52.395833, 13.061389), (386, 75))
])
def test_get_index(point, result):
    """get_index maps a (lat, lon) point to the expected index pair."""
    lat, lon = point
    with init_dataset(_fresh_test_path(), overwrite=True) as dataset:
        assert get_index(dataset, lat, lon) == result


def test_get_data_model():
    """get_data_model reports the data model of the land-sea mask file."""
    with Dataset(constants.DATASETS_PATH / constants.LANDSEAMASK_PATH) as dataset:
        data_model = get_data_model(dataset)

    assert data_model == 'NETCDF4_CLASSIC'


def test_get_dimensions():
    """get_dimensions returns the lon/lat dimensions of a default dataset, in order."""
    with init_dataset(_fresh_test_path(), overwrite=True) as dataset:
        dimensions = get_dimensions(dataset)
        assert list(dimensions.items()) == [
            ('lon', 720),
            ('lat', 360)
        ]


def test_get_variables():
    """get_variables returns the lon/lat variables with their standard names."""
    with init_dataset(_fresh_test_path(), overwrite=True) as dataset:
        variables = get_variables(dataset)
        assert [(variable_name, variable['standard_name']) for variable_name, variable in variables.items()] == [
            ('lon', 'longitude'),
            ('lat', 'latitude')
        ]


+def test_get_global_attributes(): + test_path = Path('testing/output') / 'test.nc' + test_path.parent.mkdir(exist_ok=True) + test_path.unlink(missing_ok=True) + + dataset = init_dataset(test_path, overwrite=True, attrs={ + 'global': { + 'egg': 'spam', + 'x': np.float32(3.0) + } + }) + global_attrs = get_global_attributes(dataset) + + assert global_attrs['egg'] == 'spam' + assert global_attrs['x'] == np.float32(3.0) + + +@pytest.mark.parametrize('value,return_value', [ + (np.float32(3.0), 3.0), + (np.int32(42), 42), + ([1, 2, 3], [1, 2, 3]), + (np.array([1, 2, 3]), [1, 2, 3]), + ([np.float32(1.0), np.int32(2)], [1.0, 2]) +]) +def test_convert_attribute(value, return_value): + assert convert_attribute(value) == return_value + + +def test_update_global_attributes_set(): + test_path = Path('testing/output') / 'test.nc' + test_path.parent.mkdir(exist_ok=True) + test_path.unlink(missing_ok=True) + + dataset = init_dataset(test_path, overwrite=True) + update_global_attributes(dataset, set_attributes={ + 'egg': 'spam' + }) + + assert dataset.egg == 'spam' + + +def test_update_global_attributes_delete(): + test_path = Path('testing/output') / 'test.nc' + test_path.parent.mkdir(exist_ok=True) + test_path.unlink(missing_ok=True) + + dataset = init_dataset(test_path, overwrite=True, attrs={ + 'global': { + 'egg': 'spam' + } + }) + update_global_attributes(dataset, delete_attributes=['egg']) + + with pytest.raises(AttributeError): + assert dataset.egg + + +@pytest.mark.parametrize('value,string', [ + (datetime(2023, 1, 1, 12, 0, 0), '2023-01-01T12:00:00Z'), + (123, '123'), + ('test', 'test'), + (None, 'None') +]) +def test_value2string(value, string): + assert value2string(value) == string diff --git a/isimip_utils/tests/test_pandas.py b/isimip_utils/tests/test_pandas.py new file mode 100644 index 0000000..0c2906e --- /dev/null +++ b/isimip_utils/tests/test_pandas.py @@ -0,0 +1,174 @@ +import pytest + +from isimip_utils.pandas import ( + compute_average, + create_label, + 
get_coord_axes, + get_coord_labels, + get_coords, + get_data_var_labels, + get_data_vars, + get_first_coord, + get_first_coord_axis, + get_first_coord_label, + get_first_data_var, + get_first_data_var_label, + group_by_day, + group_by_month, + normalize, +) +from isimip_utils.tests import constants +from isimip_utils.xarray import open_dataset, to_dataframe + +extractions = { + 'bbox': constants.TAS_PATH.replace('_global_', '_select-bbox-cdo_'), + 'point': constants.TAS_PATH.replace('_global_', '_select-point-cdo_') +} + +@pytest.mark.parametrize('extraction,result', [ + ('bbox', ('lon', 'lat', 'time')), + ('point', ('time', )) +]) +def test_get_coords(extraction, result): + with open_dataset(constants.EXTRACTIONS_PATH / extractions[extraction]) as ds: + df = to_dataframe(ds) + assert get_coords(df) == result + + +@pytest.mark.parametrize('extraction,result', [ + ('point', 'time') +]) +def test_get_first_coord(extraction, result): + with open_dataset(constants.EXTRACTIONS_PATH / extractions[extraction]) as ds: + df = to_dataframe(ds) + assert get_first_coord(df) == result + + +@pytest.mark.parametrize('extraction,result', [ + ('bbox', ('Longitude [degrees_east]', 'Latitude [degrees_north]', 'Time')), + ('point', ('Time', )) +]) +def test_get_coord_labels(extraction, result): + with open_dataset(constants.EXTRACTIONS_PATH / extractions[extraction]) as ds: + df = to_dataframe(ds) + assert get_coord_labels(df) == result + + +@pytest.mark.parametrize('extraction,result', [ + ('point', 'Time') +]) +def test_get_first_coord_label(extraction, result): + with open_dataset(constants.EXTRACTIONS_PATH / extractions[extraction]) as ds: + df = to_dataframe(ds) + assert get_first_coord_label(df) == result + + +@pytest.mark.parametrize('extraction,result', [ + ('bbox', ('X', 'Y', 'T')), + ('point', ('T', )) +]) +def test_get_coord_axes(extraction, result): + with open_dataset(constants.EXTRACTIONS_PATH / extractions[extraction]) as ds: + df = to_dataframe(ds) + assert 
get_coord_axes(df) == result + + +@pytest.mark.parametrize('extraction,result', [ + ('point', 'T') +]) +def test_get_first_coord_axis(extraction, result): + with open_dataset(constants.EXTRACTIONS_PATH / extractions[extraction]) as ds: + df = to_dataframe(ds) + assert get_first_coord_axis(df) == result + + +@pytest.mark.parametrize('extraction,result', [ + ('bbox', ('tas', )), + ('point', ('tas', )) +]) +def test_get_data_vars(extraction, result): + with open_dataset(constants.EXTRACTIONS_PATH / extractions[extraction]) as ds: + df = to_dataframe(ds) + assert get_data_vars(df) == result + + +@pytest.mark.parametrize('extraction,result', [ + ('point', 'tas') +]) +def test_get_first_data_var(extraction, result): + with open_dataset(constants.EXTRACTIONS_PATH / extractions[extraction]) as ds: + df = to_dataframe(ds) + assert get_first_data_var(df) == result + + +@pytest.mark.parametrize('extraction,result', [ + ('bbox', ('tas [K]', )), + ('point', ('tas [K]', )) +]) +def test_get_data_var_labels(extraction, result): + with open_dataset(constants.EXTRACTIONS_PATH / extractions[extraction]) as ds: + df = to_dataframe(ds) + assert get_data_var_labels(df) == result + + +@pytest.mark.parametrize('extraction,result', [ + ('point', 'tas [K]') +]) +def test_get_first_data_var_label(extraction, result): + with open_dataset(constants.EXTRACTIONS_PATH / extractions[extraction]) as ds: + df = to_dataframe(ds) + assert get_first_data_var_label(df) == result + + +def test_compute_average(): + with open_dataset(constants.EXTRACTIONS_PATH / extractions['point']) as ds: + df = to_dataframe(ds) + df = compute_average(df, 'tas') + + assert df['lower'].between(270, 280).all() + assert df['mean'].between(280, 290).all() + assert df['upper'].between(290, 300).all() + + +def test_compute_average_monthly(): + with open_dataset(constants.EXTRACTIONS_PATH / extractions['point']) as ds: + df = to_dataframe(ds) + df = compute_average(df, 'tas', type='monthly') + + assert df['lower'].between(260, 
300).all() + assert df['mean'].between(270, 300).all() + assert df['upper'].between(270, 305).all() + + +def test_group_by_day(): + with open_dataset(constants.EXTRACTIONS_PATH / extractions['point']) as ds: + df = to_dataframe(ds) + df = group_by_day(df, 'tas') + + assert len(df) == 365 + assert df['tas'].between(260, 305).all() + + +def test_group_by_month(): + with open_dataset(constants.EXTRACTIONS_PATH / extractions['point']) as ds: + df = to_dataframe(ds) + df = group_by_month(df, 'tas') + + assert len(df) == 12 + assert df['tas'].between(260, 300).all() + + +def test_normalize(): + with open_dataset(constants.EXTRACTIONS_PATH / extractions['point']) as ds: + df = to_dataframe(ds) + df = normalize(df, 'tas') + + assert df['tas'].between(-4, 4).all() + + +def test_create_label(): + with open_dataset(constants.EXTRACTIONS_PATH / extractions['point']) as ds: + df = to_dataframe(ds) + df = create_label(df, ['x', 'y', 'z']) + + assert (df['label'] == 'x y z').all() diff --git a/isimip_utils/tests/test_parameters.py b/isimip_utils/tests/test_parameters.py new file mode 100644 index 0000000..cf35aa5 --- /dev/null +++ b/isimip_utils/tests/test_parameters.py @@ -0,0 +1,71 @@ +from pathlib import Path + +import pytest + +from isimip_utils.parameters import ( + apply_placeholders, + copy_placeholders, + get_permutations, + get_placeholders, + join_parameters, +) + +parameters = { + 'model': ['model_a', 'model_b'], + 'variable': ['x', 'y', 'z'] +} + + +def test_get_permutations(): + assert get_permutations(parameters) == ( + ('model_a', 'x'), + ('model_a', 'y'), + ('model_a', 'z'), + ('model_b', 'x'), + ('model_b', 'y'), + ('model_b', 'z') + ) + + +def test_get_placeholders(): + assert get_placeholders(parameters, ('model_a', 'x')) == { + 'model': 'model_a', + 'variable': 'x' + } + + +def test_join_parameters(): + assert join_parameters(parameters) == { + 'model': 'model_a+model_b', + 'variable': 'x+y+z' + } + + +def test_join_parameters_max_count(): + assert 
join_parameters(parameters, 2) == { + 'model': 'model_a+model_b', + 'variable': 'various' + } + + +def test_join_parameters_max_count_label(): + assert join_parameters(parameters, 2, 'label') == { + 'model': 'model_a+model_b', + 'variable': 'label' + } + + +def test_copy_placeholders(): + assert copy_placeholders({'foo': 'bar'}, {'egg': 'spam'}) == { + 'foo': 'bar', + 'egg': 'spam' + } + + +def test_apply_placeholders(): + assert apply_placeholders('{foo}_{egg}', {'foo': 'bar', 'egg': 'spam'}) == Path('bar_spam') + + +def test_apply_placeholders_error(): + with pytest.raises(RuntimeError): + apply_placeholders('{foo}_{egg}', {'foo': 'bar'}) diff --git a/isimip_utils/tests/test_patterns.py b/isimip_utils/tests/test_patterns.py new file mode 100644 index 0000000..38d9fcd --- /dev/null +++ b/isimip_utils/tests/test_patterns.py @@ -0,0 +1,98 @@ +from pathlib import Path + +from isimip_utils.patterns import match_dataset, match_dataset_path, match_file, match_file_path, match_path +from isimip_utils.protocol import fetch_pattern +from isimip_utils.tests import constants + +protocol_locations = ['testing/protocol'] + +pattern_path = 'ISIMIP3a/OutputData/agriculture.json' + +path_specifiers = { + 'simulation_round': 'ISIMIP3a', + 'product': 'OutputData', + 'sector': 'agriculture', + 'period': 'historical' +} + +dataset_specifiers = { + 'model': 'lpjml', + 'climate_forcing': 'gswp3-w5e5', + 'climate_scenario': 'obsclim', + 'soc_scenario': '2015soc', + 'sens_scenario': 'default', + 'variable': 'yield', + 'crop': 'mai', + 'irrigation': 'noirr', + 'region': 'global', + 'time_step': 'annual-gs' +} + +file_specifiers = { + **dataset_specifiers, + 'start_year': 1901, + 'end_year': 2016, +} + + +def test_match_dataset_path(): + dataset_path = Path(constants.YIELD_PATH.replace('_1901_2016.nc', '')) + + pattern = fetch_pattern(pattern_path, protocol_locations) + path, specifiers = match_dataset_path(pattern, constants.DATASETS_PATH / dataset_path) + + assert str(path) == 
str(dataset_path) + assert specifiers == {**path_specifiers, **dataset_specifiers} + + +def test_match_file_path(): + file_path = Path(constants.YIELD_PATH) + + pattern = fetch_pattern(pattern_path, protocol_locations) + path, specifiers = match_file_path(pattern, constants.DATASETS_PATH / file_path) + + assert str(path) == str(file_path) + assert specifiers == {**path_specifiers, **file_specifiers} + + +def test_match_dataset(): + dataset_path = Path(constants.YIELD_PATH.replace('_1901_2016.nc', '')) + + pattern = fetch_pattern(pattern_path, protocol_locations) + path, specifiers = match_dataset(pattern, constants.DATASETS_PATH / dataset_path) + + assert str(path) == dataset_path.name + assert specifiers == dataset_specifiers + + +def test_match_file(): + file_path = Path(constants.YIELD_PATH) + + pattern = fetch_pattern(pattern_path, protocol_locations) + path, specifiers = match_file(pattern, constants.DATASETS_PATH / file_path) + + assert str(path) == file_path.name + assert specifiers == file_specifiers + + +def test_match_path(): + file_path = Path(constants.YIELD_PATH) + + pattern = fetch_pattern(pattern_path, protocol_locations) + path, specifiers = match_path(pattern, constants.DATASETS_PATH / constants.YIELD_PATH) + + assert str(path) == str(file_path) + assert specifiers == {**path_specifiers, **file_specifiers} + + +def test_match_path_specifiers_map(): + file_path = Path(constants.YIELD_PATH) + + pattern = fetch_pattern(pattern_path, protocol_locations) + pattern['specifiers_map'] = { + 'global': 'spam' + } + path, specifiers = match_path(pattern, constants.DATASETS_PATH / file_path) + + assert str(path) == str(file_path) + assert specifiers == {**path_specifiers, **file_specifiers, 'region': 'spam'} diff --git a/isimip_utils/tests/test_plot.py b/isimip_utils/tests/test_plot.py new file mode 100644 index 0000000..a732821 --- /dev/null +++ b/isimip_utils/tests/test_plot.py @@ -0,0 +1,275 @@ + +import numpy as np +import pandas as pd + +from 
isimip_utils.pandas import compute_average, create_label +from isimip_utils.plot import format_title, plot_grid, plot_line, plot_map, save_index, save_plot +from isimip_utils.tests import constants +from isimip_utils.xarray import open_dataset, to_dataframe + + +def test_plot_line(): + extraction_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-point-cdo_') + + plot_path = constants.PLOTS_PATH / 'plot_line.png' + plot_path.unlink(missing_ok=True) + + with open_dataset(extraction_path) as ds: + df = to_dataframe(ds) + + chart = plot_line(df) + + assert chart.data.equals(df) + assert chart.encoding.x.shorthand == 'time:T' + assert chart.encoding.y.shorthand == 'tas:Q' + + save_plot(chart, plot_path) + + assert plot_path.is_file() # call the method; the bare attribute is always truthy + + +def test_plot_line_nocf(): + extraction_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-point-cdo_') + + plot_path = constants.PLOTS_PATH / 'plot_line_nocf.png' + plot_path.unlink(missing_ok=True) + + with open_dataset(extraction_path, decode_cf=True) as ds: + df = to_dataframe(ds) + + chart = plot_line(df, x_type='Q') + + assert chart.data.equals(df) + assert chart.encoding.x.shorthand == 'time:Q' + assert chart.encoding.y.shorthand == 'tas:Q' + + save_plot(chart, plot_path) + + assert plot_path.is_file() + + +def test_plot_line_empty(): + extraction_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-point-cdo_') + + plot_path = constants.PLOTS_PATH / 'plot_line_empty.png' + plot_path.unlink(missing_ok=True) + + with open_dataset(extraction_path) as ds: + df = to_dataframe(ds) + df_empty = pd.DataFrame({ 'time': df['time'], 'tas': np.nan }) + + chart = plot_line(df, empty=True) + + assert chart.data.equals(df_empty) + assert chart.encoding.x.shorthand == 'time:T' + assert chart.encoding.y.shorthand == 'tas:Q' + + save_plot(chart, plot_path) + + assert plot_path.is_file() + + +def test_plot_line_area(): + extraction_path =
constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-point-cdo_') + + plot_path = constants.PLOTS_PATH / 'plot_line_area.png' + plot_path.unlink(missing_ok=True) + + with open_dataset(extraction_path) as ds: + df = to_dataframe(ds) + df = compute_average(df, 'tas', type='monthly') + + chart = plot_line(df) + + assert chart.data.equals(df) + + mean, area = chart.layer + + assert mean.encoding.x.shorthand == 'month:T' + assert mean.encoding.y.shorthand == 'mean:Q' + + assert area.encoding.x.shorthand == 'month:T' + assert area.encoding.y.shorthand == 'lower:Q' + assert area.encoding.y2.shorthand == 'upper:Q' + + save_plot(chart, plot_path) + + assert plot_path.is_file() # call the method; the bare attribute is always truthy + + +def test_plot_line_color(): + extraction_path = constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-point-cdo_') + + plot_path = constants.PLOTS_PATH / 'plot_line_color.png' + plot_path.unlink(missing_ok=True) + + with open_dataset(extraction_path) as ds: + df = to_dataframe(ds) + df = compute_average(df, 'tas', type='monthly') + df = create_label(df, ('a', 'b', 'c')) + + chart = plot_line(df, color_scheme='viridis') + + assert chart.data.equals(df) + + mean, area = chart.layer + + assert mean.encoding.x.shorthand == 'month:T' + assert mean.encoding.y.shorthand == 'mean:Q' + + assert area.encoding.x.shorthand == 'month:T' + assert area.encoding.y.shorthand == 'lower:Q' + assert area.encoding.y2.shorthand == 'upper:Q' + + save_plot(chart, plot_path) + + assert plot_path.is_file() + + +def test_plot_map(): + date = constants.DATE + date_specifiers = date.strftime('%Y%m%d') + extraction_path = ( + constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-time-cdo_') + .replace(constants.TAS_DATE_SPECIFIERS, date_specifiers) + ) + + plot_path = constants.PLOTS_PATH / 'plot_map.png' + plot_path.unlink(missing_ok=True) + + with open_dataset(extraction_path) as ds: + df = to_dataframe(ds) + chart = plot_map(df) + + assert
chart.data.equals(df) + assert chart.encoding.x.shorthand == 'lon:Q' + assert chart.encoding.y.shorthand == 'lat:Q' + assert chart.encoding.color.shorthand == 'tas:Q' + + save_plot(chart, plot_path) + + assert plot_path.is_file() # call the method; the bare attribute is always truthy + + +def test_plot_map_nocf(): + date = constants.DATE + date_specifiers = date.strftime('%Y%m%d') + extraction_path = ( + constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-time-cdo_') + .replace(constants.TAS_DATE_SPECIFIERS, date_specifiers) + ) + + plot_path = constants.PLOTS_PATH / 'plot_map_nocf.png' + plot_path.unlink(missing_ok=True) + + with open_dataset(extraction_path) as ds: + df = to_dataframe(ds) + chart = plot_map(df) + + assert chart.data.equals(df) + assert chart.encoding.x.shorthand == 'lon:Q' + assert chart.encoding.y.shorthand == 'lat:Q' + assert chart.encoding.color.shorthand == 'tas:Q' + + save_plot(chart, plot_path) + + assert plot_path.is_file() + + +def test_plot_map_empty(): + date = constants.DATE + date_specifiers = date.strftime('%Y%m%d') + extraction_path = ( + constants.EXTRACTIONS_PATH / constants.TAS_PATH.replace('_global_', '_select-time-cdo_') + .replace(constants.TAS_DATE_SPECIFIERS, date_specifiers) + ) + + plot_path = constants.PLOTS_PATH / 'plot_map_empty.png' + plot_path.unlink(missing_ok=True) + + with open_dataset(extraction_path) as ds: + df = to_dataframe(ds) + df_empty = pd.DataFrame({ + 'lon': [], + 'lat': [] + }) + + chart = plot_map(df, empty=True) + + assert chart.data.equals(df_empty) + assert chart.encoding.x.shorthand == 'lon:Q' + assert chart.encoding.y.shorthand == 'lat:Q' + + save_plot(chart, plot_path) + + assert plot_path.is_file() + + +def test_plot_grid(): + extraction_paths = [ + constants.EXTRACTIONS_PATH / dataset_path.replace('_global_', '_select-point-cdo_') + for dataset_path in constants.TAS_SPLIT_PATHS + ] + + plot_path = constants.PLOTS_PATH / 'plot_grid.png' + plot_path.unlink(missing_ok=True) + + dataframes = [] + for extraction_path in
extraction_paths: + with open_dataset(extraction_path) as ds: + dataframes.append(to_dataframe(ds)) + + df_empty = pd.DataFrame({ 'time': dataframes[2]['time'], 'tas': np.nan }) + + grid_permutations = [ + ('a', 'x'), + ('a', 'y'), + ('b', 'x'), + ] + plot_permutations = [()] + + plots = {} + for permutation, df in zip(grid_permutations, dataframes, strict=True): + plots[permutation] = plot_line(df) + + empty_plot = plot_line(df, empty=True) + + grid_permutations.append(('b', 'y')) + + chart = plot_grid(grid_permutations, plot_permutations, plots, x='independent', empty_plot=empty_plot) + + top, bottom = chart.vconcat + top_left, top_right = top.hconcat + bottom_left, bottom_right = bottom.hconcat + + assert top_left.data.equals(dataframes[0]) + assert top_right.data.equals(dataframes[1]) + assert bottom_left.data.equals(dataframes[2]) + assert bottom_right.data.equals(df_empty) + + for compound_chart in [chart, top, bottom]: + assert compound_chart.resolve.scale.x == 'independent' + assert compound_chart.resolve.scale.y == 'shared' + + save_plot(chart, plot_path) + + assert plot_path.is_file() # call the method; the bare attribute is always truthy + + +def test_save_index(): + index_path = constants.PLOTS_PATH / 'index.html' + index_path.unlink(missing_ok=True) + + save_index(index_path) + + assert index_path.is_file() + + +def test_format_title(): + permutation = ('a', 'b', 'c') + + assert format_title(permutation) == { + "text": 'a Β· b Β· c', + "fontSize": 16, + "dy": -10 + } diff --git a/isimip_utils/tests/test_protocol.py b/isimip_utils/tests/test_protocol.py new file mode 100644 index 0000000..a34fbb1 --- /dev/null +++ b/isimip_utils/tests/test_protocol.py @@ -0,0 +1,76 @@ +from unittest.mock import patch + +import pytest + +from isimip_utils.protocol import ( + fetch_definitions, + fetch_pattern, + fetch_schema, + fetch_tree, + find_json, +) +from isimip_utils.tests import helper + +paths = [ + 'ISIMIP3a/OutputData/agriculture/ACEA/gswp3-w5e5.json', + 'ISIMIP3a/OutputData/agriculture/ACEA.json', +
'ISIMIP3a/OutputData/agriculture.json' +] + + +@pytest.mark.parametrize('path', paths) +def test_fetch_definitions_local(path): + data = fetch_definitions(path, 'testing/protocol') + assert data and isinstance(data, dict) + + +@pytest.mark.parametrize('path', paths) +def test_fetch_pattern(path): + with patch('isimip_utils.fetch.requests.get', side_effect=helper.mock_json): + data = fetch_pattern(path) + assert data and isinstance(data, dict) + + +@pytest.mark.parametrize('path', paths) +def test_fetch_pattern_local(path): + data = fetch_pattern(path, 'testing/protocol') + assert data and isinstance(data, dict) + + +@pytest.mark.parametrize('path', paths) +def test_fetch_schema(path): + with patch('isimip_utils.fetch.requests.get', side_effect=helper.mock_json): + data = fetch_schema(path) + assert data and isinstance(data, dict) + + +@pytest.mark.parametrize('path', paths) +def test_fetch_schema_local(path): + data = fetch_schema(path, 'testing/protocol') + assert data and isinstance(data, dict) + + +@pytest.mark.parametrize('path', paths) +def test_fetch_tree(path): + with patch('isimip_utils.fetch.requests.get', side_effect=helper.mock_json): + data = fetch_tree(path) + assert data and isinstance(data, dict) + + +@pytest.mark.parametrize('path', paths) +def test_fetch_tree_local(path): + data = fetch_tree(path, 'testing/protocol') + assert data and isinstance(data, dict) + + +@pytest.mark.parametrize('path', paths) +def test_find_json_fetch(path): + with patch('isimip_utils.fetch.requests.get', side_effect=helper.mock_json): + data = find_json('https://protocol.isimip.org', 'definitions', path) + assert data is not None + + +@pytest.mark.parametrize('path', paths) +def test_find_json_load(path): + data = find_json('testing/protocol', 'definitions', path) + assert data is not None diff --git a/isimip_utils/tests/test_utils.py b/isimip_utils/tests/test_utils.py new file mode 100644 index 0000000..6efbd03 --- /dev/null +++ b/isimip_utils/tests/test_utils.py @@ -0,0 
+1,114 @@ +import pytest + +from isimip_utils.exceptions import ValidationError +from isimip_utils.utils import ( + Singleton, + cached_property, + exclude_path, + get_max_value, + get_min_value, + include_path, + validate_lat, + validate_lon, +) + + +def test_singleton(): + a = Singleton() + a.egg = 'spam' + + b = Singleton() + assert b.egg == 'spam' + + +def test_cached_property(): + + class Test: + + def __init__(self): + self.counter = 0 + + @cached_property + def egg(self): + self.counter += 1 + return 'spam' + + t = Test() + assert t.egg == 'spam' + assert t.egg == 'spam' + assert t.counter == 1 + + +@pytest.mark.parametrize('lat', (-90.0, -45.5, 0, 45, 90)) +def test_validate_lat(lat): + validate_lat(lat) + + +@pytest.mark.parametrize('lat', (-91, 91, None, '', 'none')) +def test_validate_lat_error(lat): + with pytest.raises(ValidationError): + validate_lat(lat) + + +@pytest.mark.parametrize('lon', (-180.0, -45.5, 0, 45, 180)) +def test_validate_lon(lon): + validate_lon(lon) + + +@pytest.mark.parametrize('lon', (-181, 181, None, '', 'none')) +def test_validate_lon_error(lon): + with pytest.raises(ValidationError): + validate_lon(lon) + + +@pytest.mark.parametrize('exclude,path,match,result', ( + ([], 'a/b/c', 'any', False), + (['a/b/c', 'a/b/d', 'a/b/e'], 'a/b/c', 'any', True), + (['a/b/c', 'a/b/d', 'a/b/e'], 'a/b/cc', 'any', True), + (['a/b/c', 'a/b/d', 'a/b/e'], 'a/b/f', 'any', False), + (['a_b', 'c_d'], 'a_b_c_d', 'any', True), + (['a_b', 'c_d'], 'a_b_c_d', 'all', True), + (['a_b', 'c_e'], 'a_b_c_d', 'any', True), + (['a_b', 'c_e'], 'a_b_c_d', 'all', False), + (['a_e', 'c_d'], 'a_b_c_d', 'any', True), + (['a_e', 'c_d'], 'a_b_c_e', 'all', False), +)) +def test_exclude_path(exclude, path, match, result): + assert exclude_path(exclude, path, match) is result + + + +@pytest.mark.parametrize('include,path,match,result', ( + ([], 'a/b/c', 'any', True), + (['a/b/c', 'a/b/d', 'a/b/e'], 'a/b/c', 'any', True), + (['a/b/c', 'a/b/d', 'a/b/e'], 'a/b/cc', 'any', True), 
+ (['a/b/c', 'a/b/d', 'a/b/e'], 'a/b/f', 'any', False), + (['a_b', 'c_d'], 'a_b_c_d', 'any', True), + (['a_b', 'c_d'], 'a_b_c_d', 'all', True), + (['a_b', 'c_e'], 'a_b_c_d', 'any', True), + (['a_b', 'c_e'], 'a_b_c_d', 'all', False), + (['a_e', 'c_d'], 'a_b_c_d', 'any', True), + (['a_e', 'c_d'], 'a_b_c_e', 'all', False), +)) +def test_include_path(include, path, match, result): + assert include_path(include, path, match) is result + + +@pytest.mark.parametrize('values,result', [ + ([1, 2, 3], 1), + ([None, 2, 3], 2), + ([None, None, None], None), + ([], None) +]) +def test_get_min_value(values, result): + assert get_min_value(values) == result + + +@pytest.mark.parametrize('values,result', [ + ([1, 2, 3], 3), + ([1, 2, None], 2), + ([None, None, None], None), + ([], None) +]) +def test_get_max_value(values, result): + assert get_max_value(values) == result diff --git a/isimip_utils/tests/test_xarray.py b/isimip_utils/tests/test_xarray.py new file mode 100644 index 0000000..57933e2 --- /dev/null +++ b/isimip_utils/tests/test_xarray.py @@ -0,0 +1,555 @@ +from datetime import timedelta + +import cftime +import geopandas as gpd +import numpy as np +import pandas as pd +import xarray as xr +from shapely.geometry import box + +from isimip_utils.netcdf import open_dataset_read +from isimip_utils.tests import constants, helper +from isimip_utils.xarray import ( + add_compression_to_data_vars, + add_fill_value_to_data_vars, + convert_time, + create_mask, + get_attrs, + init_dataset, + load_dataset, + open_dataset, + order_variables, + remove_fill_value_from_coords, + set_attrs, + to_dataframe, + write_dataset, +) + + +def test_init_dataset(): + ds = init_dataset() + + assert isinstance(ds, xr.Dataset) + assert ds.sizes['lon'] == 720 + assert ds.sizes['lat'] == 360 + + test_path = constants.OUTPUT_PATH / 'test.nc' + test_path.unlink(missing_ok=True) + + write_dataset(ds, test_path) + + output = helper.call(f'ncdump -h {test_path}') + + 
helper.assert_multiline_strings_equal(output, ''' +netcdf test { +dimensions: + lon = 720 ; + lat = 360 ; +variables: + double lon(lon) ; + lon:standard_name = "longitude" ; + lon:long_name = "Longitude" ; + lon:units = "degrees_east" ; + lon:axis = "X" ; + double lat(lat) ; + lat:standard_name = "latitude" ; + lat:long_name = "Latitude" ; + lat:units = "degrees_north" ; + lat:axis = "Y" ; +} +''') + + +def test_init_dataset_float(): + lon_size, lat_size = 18, 9 + + var = np.random.rand(lat_size, lon_size).astype(np.float32) + + attrs = { + 'var': { + 'long_name': 'Variable' + } + } + + ds = init_dataset(lon=lon_size, lat=lat_size, var=var, attrs=attrs) + + assert isinstance(ds, xr.Dataset) + assert ds.sizes['lon'] == lon_size + assert ds.sizes['lat'] == lat_size + + test_path = constants.OUTPUT_PATH / 'test.nc' + test_path.unlink(missing_ok=True) + + write_dataset(ds, test_path) + + output = helper.call(f'ncdump -h {test_path}') + + helper.assert_multiline_strings_equal(output, ''' +netcdf test { +dimensions: + lon = 18 ; + lat = 9 ; +variables: + double lon(lon) ; + lon:standard_name = "longitude" ; + lon:long_name = "Longitude" ; + lon:units = "degrees_east" ; + lon:axis = "X" ; + double lat(lat) ; + lat:standard_name = "latitude" ; + lat:long_name = "Latitude" ; + lat:units = "degrees_north" ; + lat:axis = "Y" ; + float var(lat, lon) ; + var:_FillValue = 1.e+20f ; + var:long_name = "Variable" ; + var:missing_value = 1.e+20f ; +} +''') + +def test_init_dataset_double(): + lon_size, lat_size = 18, 9 + + var = np.random.rand(lat_size, lon_size).astype(np.float64) + + attrs = { + 'var': { + 'long_name': 'Variable' + } + } + + ds = init_dataset(lon=lon_size, lat=lat_size, var=var, attrs=attrs) + + assert isinstance(ds, xr.Dataset) + assert ds.sizes['lon'] == lon_size + assert ds.sizes['lat'] == lat_size + + test_path = constants.OUTPUT_PATH / 'test.nc' + test_path.unlink(missing_ok=True) + + write_dataset(ds, test_path) + + output = helper.call(f'ncdump -h 
{test_path}') + + helper.assert_multiline_strings_equal(output, ''' +netcdf test { +dimensions: + lon = 18 ; + lat = 9 ; +variables: + double lon(lon) ; + lon:standard_name = "longitude" ; + lon:long_name = "Longitude" ; + lon:units = "degrees_east" ; + lon:axis = "X" ; + double lat(lat) ; + lat:standard_name = "latitude" ; + lat:long_name = "Latitude" ; + lat:units = "degrees_north" ; + lat:axis = "Y" ; + double var(lat, lon) ; + var:_FillValue = 1.e+20 ; + var:long_name = "Variable" ; + var:missing_value = 1.e+20 ; +} +''') + +def test_init_dataset_args(): + lon_size, lat_size, time_size = 180, 90, 10 + + time = np.arange(time_size, dtype=np.float64) + var = np.random.rand(time_size, lat_size, lon_size).astype(np.float32) + + attrs = { + 'var': { + 'long_name': 'Variable' + }, + 'time': { + 'calendar': '365_day', + 'units': 'days since 2000-01-01 00:00:00' + } + } + + ds = init_dataset(lon=lon_size, lat=lat_size, time=time, attrs=attrs, var=var) + + assert isinstance(ds, xr.Dataset) + assert ds.sizes['lon'] == lon_size + assert ds.sizes['lat'] == lat_size + + assert ds['time'].units == attrs['time']['units'] + assert ds['time'].calendar == attrs['time']['calendar'] + + assert np.array_equal(ds['var'].values, var) + assert ds['var'].long_name == attrs['var']['long_name'] + + test_path = constants.OUTPUT_PATH / 'test.nc' + test_path.unlink(missing_ok=True) + + write_dataset(ds, test_path) + + output = helper.call(f'ncdump -h {test_path}') + + helper.assert_multiline_strings_equal(output, ''' +netcdf test { +dimensions: + time = UNLIMITED ; // (10 currently) + lon = 180 ; + lat = 90 ; +variables: + double lon(lon) ; + lon:standard_name = "longitude" ; + lon:long_name = "Longitude" ; + lon:units = "degrees_east" ; + lon:axis = "X" ; + double lat(lat) ; + lat:standard_name = "latitude" ; + lat:long_name = "Latitude" ; + lat:units = "degrees_north" ; + lat:axis = "Y" ; + double time(time) ; + time:standard_name = "time" ; + time:long_name = "Time" ; + time:calendar = 
"365_day" ; + time:units = "days since 2000-01-01 00:00:00" ; + time:axis = "T" ; + float var(time, lat, lon) ; + var:_FillValue = 1.e+20f ; + var:long_name = "Variable" ; + var:missing_value = 1.e+20f ; +} +''') + + +def test_init_dataset_latlon(): + var = np.random.rand(10, 1, 1).astype(np.float32) + + attrs = { + 'var': { + 'long_name': 'Variable' + } + } + + ds = init_dataset( + lon=np.array([10], dtype=np.float64), + lat=np.array([20], dtype=np.float64), + time=10, attrs=attrs, var=var + ) + + assert isinstance(ds, xr.Dataset) + assert ds.sizes['lon'] == 1 + assert ds.sizes['lat'] == 1 + + assert ds['time'].units == 'days since 1601-01-01 00:00:00' + assert ds['time'].calendar == 'proleptic_gregorian' + + assert np.array_equal(ds['var'].values, var) + assert ds['var'].long_name == attrs['var']['long_name'] + + test_path = constants.OUTPUT_PATH / 'test.nc' + test_path.unlink(missing_ok=True) + + write_dataset(ds, test_path) + + output = helper.call(f'ncdump -h {test_path}') + + helper.assert_multiline_strings_equal(output, ''' +netcdf test { +dimensions: + time = UNLIMITED ; // (10 currently) + lon = 1 ; + lat = 1 ; +variables: + double lon(lon) ; + lon:standard_name = "longitude" ; + lon:long_name = "Longitude" ; + lon:units = "degrees_east" ; + lon:axis = "X" ; + double lat(lat) ; + lat:standard_name = "latitude" ; + lat:long_name = "Latitude" ; + lat:units = "degrees_north" ; + lat:axis = "Y" ; + double time(time) ; + time:standard_name = "time" ; + time:long_name = "Time" ; + time:calendar = "proleptic_gregorian" ; + time:units = "days since 1601-01-01 00:00:00" ; + time:axis = "T" ; + float var(time, lat, lon) ; + var:_FillValue = 1.e+20f ; + var:long_name = "Variable" ; + var:missing_value = 1.e+20f ; +} +''') + + +def test_init_dataset_dims(): + a = np.arange(0, 2, dtype=np.float64) + b = np.arange(0, 3, dtype=np.float64) + var = np.random.rand(b.size, a.size, 360, 720).astype(np.float32) + + attrs = { + 'var': { + 'long_name': 'Variable' + }, + 'a': { + 
'long_name': 'A Axis', + 'axis': 'A' + }, + 'b': { + 'long_name': 'B Axis', + 'axis': 'B' + } + } + + ds = init_dataset(dims=('b', 'a', 'lat', 'lon'), attrs=attrs, a=a, b=b, var=var) + + assert isinstance(ds, xr.Dataset) + + assert ds['a'].long_name == attrs['a']['long_name'] + assert ds['b'].long_name == attrs['b']['long_name'] + + assert np.array_equal(ds['var'].values, var) + assert ds['var'].long_name == attrs['var']['long_name'] + + test_path = constants.OUTPUT_PATH / 'test.nc' + test_path.unlink(missing_ok=True) + + write_dataset(ds, test_path) + + output = helper.call(f'ncdump -h {test_path}') + + helper.assert_multiline_strings_equal(output, ''' +netcdf test { +dimensions: + lon = 720 ; + lat = 360 ; + b = 3 ; + a = 2 ; +variables: + double lon(lon) ; + lon:standard_name = "longitude" ; + lon:long_name = "Longitude" ; + lon:units = "degrees_east" ; + lon:axis = "X" ; + double lat(lat) ; + lat:standard_name = "latitude" ; + lat:long_name = "Latitude" ; + lat:units = "degrees_north" ; + lat:axis = "Y" ; + double b(b) ; + b:long_name = "B Axis" ; + b:axis = "B" ; + double a(a) ; + a:long_name = "A Axis" ; + a:axis = "A" ; + float var(b, a, lat, lon) ; + var:_FillValue = 1.e+20f ; + var:long_name = "Variable" ; + var:missing_value = 1.e+20f ; +} +''') + + +def test_open_dataset(): + with open_dataset(constants.DATASETS_PATH / constants.TAS_PATH) as ds: + assert isinstance(ds, xr.Dataset) + assert ds['time'].dtype.type == np.datetime64 + + +def test_open_dataset_decode_cf_false(): + with open_dataset(constants.DATASETS_PATH / constants.TAS_PATH, decode_cf=False) as ds: + assert isinstance(ds, xr.Dataset) + assert ds['time'].dtype.type == np.float64 + + +def test_open_dataset_growing_seasons(): + with open_dataset(constants.DATASETS_PATH / constants.YIELD_PATH) as ds: + assert isinstance(ds, xr.Dataset) + assert isinstance(ds['time'].dtype, object) + assert ds['time'].values[0].isoformat() == '1901-01-01T00:00:00' + + +def test_load_dataset(): + with 
load_dataset(constants.DATASETS_PATH / constants.LANDSEAMASK_PATH) as ds: + assert isinstance(ds, xr.Dataset) + + +def test_order_variables(): + test_path = constants.OUTPUT_PATH / 'test.nc' + test_path.unlink(missing_ok=True) + + ds = init_dataset( + var=np.random.rand(360, 720).astype(np.float64) + ) + ds = ds[[*ds.data_vars, *ds.coords]] + ds.to_netcdf(test_path) + + dataset = open_dataset_read(test_path) + assert tuple(dataset.variables) == ('var', 'lon', 'lat') + + test_path.unlink(missing_ok=True) + + ds = order_variables(ds) + ds.to_netcdf(test_path) + + dataset = open_dataset_read(test_path) + assert tuple(dataset.variables) == ('lon', 'lat', 'var') + + +def test_get_attrs(): + with open_dataset(constants.DATASETS_PATH / constants.TAS_PATH) as ds: + attrs = get_attrs(ds) + assert attrs['lon']['long_name'] == 'Longitude' + assert attrs['lat']['long_name'] == 'Latitude' + assert attrs['tas']['long_name'] == 'Near-Surface Air Temperature' + + +def test_set_attrs(): + with open_dataset(constants.DATASETS_PATH / constants.TAS_PATH) as ds: + attrs = get_attrs(ds) + attrs['tas']['egg'] = 'spam' + set_attrs(ds, attrs) + assert attrs['tas']['egg'] == 'spam' + + +def test_remove_fill_value_from_coords(): + ds = xr.Dataset( + coords={ + 'time': np.arange(10, dtype=np.float64) + }, + data_vars={ + 'var': (['time'], np.ones(10)) + } + ) + remove_fill_value_from_coords(ds) + assert '_FillValue' not in ds['time'] + + +def test_add_fill_value_to_data_vars(): + ds = xr.Dataset( + coords={ + 'time': np.arange(10, dtype=np.float64) + }, + data_vars={ + 'var': (['time'], np.ones(10)) + } + ) + + assert not ds['var'].encoding + + add_fill_value_to_data_vars(ds) + + assert ds['var'].encoding.get('_FillValue') == 1e20 + assert ds['var'].encoding.get('missing_value') == 1e20 + + +def test_add_compression_to_data_vars(): + ds = xr.Dataset( + coords={ + 'time': np.arange(10, dtype=np.float64) + }, + data_vars={ + 'var': (['time'], np.ones(10)) + } + ) + + assert not 
ds['var'].encoding + + add_compression_to_data_vars(ds, 9) + + assert ds['var'].encoding.get('zlib') is True + assert ds['var'].encoding.get('complevel') == 9 + + +def test_create_mask(): + ds = init_dataset( + var=np.ones((360, 720)) + ) + + geometry = box(-10, -5, 10, 5) + + df = gpd.GeoDataFrame( + [{'geometry': geometry}], + crs='EPSG:4326' # WGS84 coordinate system + ) + + mask_ds = create_mask(ds, df, layer=0) + + assert mask_ds['lon'].shape == (720, ) + assert mask_ds['lat'].shape == (360, ) + + assert mask_ds['mask'].dims == ('lat', 'lon') + assert mask_ds['mask'].shape == (360, 720) + + inside_region = mask_ds.sel(lat=slice(5, -5), lon=slice(-10, 10)) + assert np.all(inside_region['mask'].values == 1.0) + + outside_regions = [ + mask_ds.sel(lon=slice(90, 5)), + mask_ds.sel(lon=slice(-5, -90)), + mask_ds.sel(lon=slice(10, 180)), + mask_ds.sel(lon=slice(-180, -10)) + ] + for outside_region in outside_regions: + assert np.all(np.isnan(outside_region['mask'].values)) + + +def test_convert_time_datetime(): + calendar = 'proleptic_gregorian' + units = 'days since 2000-01-01 00:00:00' + + start_day = cftime.datetime(2000, 1, 1, calendar=calendar) + end_day = cftime.datetime(2000, 12, 31, calendar=calendar) + + time = np.array([start_day + timedelta(days=i) for i in range((end_day - start_day).days + 1)], dtype=object) + time_converted = convert_time(time, calendar=calendar, units=units) + + start = 0 + assert np.array_equal(time_converted, np.arange(start, start + 366, dtype=np.float64)) + + +def test_convert_time_datetime64(): + time = np.array(pd.date_range(start='2000-01-01', end='2000-12-31', freq='D')) + time_converted = convert_time(time) + + start = 145731 + assert np.array_equal(time_converted, np.arange(start, start + 366, dtype=np.float64)) + + +def test_convert_time_datetime64_index(): + time = pd.date_range(start='2000-01-01', end='2000-12-31', freq='D') + time_converted = convert_time(time) + + start = 145731 + assert np.array_equal(time_converted, 
np.arange(start, start + 366, dtype=np.float64)) + + +def test_convert_time_datetime64_series(): + time = pd.Series(pd.date_range(start='2000-01-01', end='2000-12-31', freq='D')) + time_converted = convert_time(time) + + start = 145731 + assert np.array_equal(time_converted, np.arange(start, start + 366, dtype=np.float64)) + + +def test_convert_time_datetime_str(): + time = pd.date_range(start='2000-01-01', end='2000-12-31', freq='D').astype(str) + time_converted = convert_time(time) + + start = 145731 + assert np.array_equal(time_converted, np.arange(start, start + 366, dtype=np.float64)) + + +def test_to_dataframe(): + ds = xr.Dataset( + coords={ + 'time': np.arange(10, dtype=np.float64) + }, + data_vars={ + 'var': (['time'], np.ones(10)) + } + ) + df = to_dataframe(ds) + assert np.array_equal(df['time'], ds['time']) + assert np.array_equal(df['var'], ds['var']) diff --git a/isimip_utils/utils.py b/isimip_utils/utils.py index fcb07b4..f455e74 100644 --- a/isimip_utils/utils.py +++ b/isimip_utils/utils.py @@ -1,26 +1,154 @@ -def parse_filelist(filelist_file): - if filelist_file: - with open(filelist_file) as f: - filelist = {line for line in f.read().splitlines() if (line and not line.startswith('#'))} - else: - filelist = None +"""Additional utility functions for ISIMIP tools.""" +from collections.abc import Callable +from pathlib import Path +from typing import Any, Literal + +from .exceptions import ValidationError + + +class Singleton: + """Base class for implementing the singleton pattern. + + Ensures only one instance of a class exists. Subclasses will share + a single instance with a 'data' attribute initialized as an empty dict. + """ + _instance: Any = None + + def __new__(cls) -> 'Singleton': + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance.data = {} + return cls._instance + + +class cached_property: + """Decorator that converts a method into a cached property. 
+ + The property value is computed once and then cached as an instance attribute. + Subsequent accesses return the cached value without re-computing. - return filelist + Simplified version of + [Django's cached_property](https://github.com/django/django/blob/main/django/utils/functional.py). + """ + name: str | None = None -def exclude_path(exclude, path): + def __init__(self, func: Callable) -> None: + self.func = func + + def __set_name__(self, owner: type, name: str) -> None: + if self.name is None: + self.name = name + else: + raise TypeError("Cannot assign the same cached_property to two different names") + + def __get__(self, instance: Any, cls: type | None = None) -> Any: + if instance is None: + return self + value = instance.__dict__[self.name] = self.func(instance) + return value + + +def exclude_path(exclude: list[str] | None, path: Path | str, match: Literal['any', 'all'] = 'any') -> bool: + """Check if a path should be excluded based on exclude patterns. + + Args: + exclude (list[str] | None): List of exclude patterns (strings). Path is excluded if it + contains any or all patterns, depending on the match argument; never excluded if the list is None/empty. + path (Path | str): Path to check for exclusion. + match ('any', 'all'): Match all or any of the lines in exclude. + + Returns: + True if path should be excluded, False otherwise. + """ if exclude: - for exclude_string in exclude: - if str(path).startswith(exclude_string): - return True + if match == 'any': + return any(string in str(path) for string in exclude) + elif match == 'all': + return all(string in str(path) for string in exclude) + else: + raise ValidationError(f'match={match} needs to be "any" or "all"') return False -def include_path(include, path): +def include_path(include: list[str] | None, path: Path | str, match: Literal['any', 'all'] = 'any') -> bool: + """Check if a path should be included based on include patterns.
+ + Args: + include (list[str] | None): List of include patterns (strings). Path is included if it + contains any or all patterns, depending on the match argument or if include list is None/empty. + path (Path | str): Path to check for inclusion. + match ('any', 'all'): Match all or any of the lines in exclude. + + Returns: + True if path should be included, False otherwise. + """ if include: - for include_string in include: - if str(path).startswith(include_string): - return True - return False + if match == 'any': + return any(string in str(path) for string in include) + elif match == 'all': + return all(string in str(path) for string in include) + else: + raise ValidationError(f'match={match} needs to be "any" or "all"') else: return True + + +def validate_lat(lat: float) -> None: + """Validate latitude value is within valid range. + + Args: + lat (float): Latitude value to validate. + + Raises: + ValidationError: If latitude is outside -90 to 90 range. + """ + try: + if lat < -90: + raise ValidationError(f'lat={lat} must be > -90') + elif lat > 90: + raise ValidationError(f'lat={lat} must be < 90') + except TypeError as e: + raise ValidationError(f'lat={lat} is a valid number') from e + + +def validate_lon(lon: float) -> None: + """Validate longitude value is within valid range. + + Args: + lon (float): Longitude value to validate. + + Raises: + ValidationError: If longitude is outside -180 to 180 range. + """ + try: + if lon < -180: + raise ValidationError(f'lon={lon} must be > -180') + elif lon > 180: + raise ValidationError(f'lon={lon} must be < 180') + except TypeError as e: + raise ValidationError(f'lon={lon} is a valid number') from e + + +def get_min_value(values): + """Get the minimal value of the input values, excluding None and using None as default. + + Args: + values (list): Input values. 
+ + Returns: + Minimal value + """ + return min([v for v in values if v is not None], default=None) + + +def get_max_value(values): + """Get the maximum value of the input values, excluding None and using None as default. + + Args: + values (list): Input values. + + Returns: + Maximum value + """ + return max([v for v in values if v is not None], default=None) diff --git a/isimip_utils/xarray.py b/isimip_utils/xarray.py new file mode 100644 index 0000000..66759f9 --- /dev/null +++ b/isimip_utils/xarray.py @@ -0,0 +1,490 @@ +"""Functions for working with xarray datasets for ISIMIP data.""" +import logging +import warnings +from datetime import date, datetime +from pathlib import Path + +import cftime +import numpy as np +import pandas as pd +import xarray as xr + +logger = logging.getLogger(__name__) + +DEFAULT_ATTRS = { + 'lon': { + 'standard_name': 'longitude', + 'long_name': 'Longitude', + 'units': 'degrees_east', + 'axis': 'X' + }, + 'lat': { + 'standard_name': 'latitude', + 'long_name': 'Latitude', + 'units': 'degrees_north', + 'axis': 'Y' + }, + 'time': { + 'standard_name': 'time', + 'long_name': 'Time', + 'calendar': 'proleptic_gregorian', + 'units': 'days since 1601-01-01 00:00:00', + 'axis': 'T' + } +} + +FILL_VALUE = 1e20 + +def init_dataset(lon: None | int | np.ndarray = 720, + lat: None | int | np.ndarray = 360, + time: None | int | np.ndarray = None, + dims: None | list = None, + attrs: None | dict = None, + **variables: np.ndarray) -> xr.Dataset: + """Initialize a new xarray dataset with standard ISIMIP dimensions. + + Args: + lon (int | np.ndarray): Number of longitude points, or longitude array, or None to omit (default: 720). + lat (int | np.ndarray): Number of latitude points, or latitude array, or None to omit (default: 360). + time (int | np.ndarray): Number of time steps, or time array, or None to omit time dimension (default: None). + attrs (dict): Dictionary of attributes for variables and global attributes. 
+ dims (list): List of dimensions (default time, lat, lon). + **variables (np.ndarray): Data variables to include in the dataset. + + Returns: + Initialized xarray Dataset with coordinates and data variables. + """ + + # create dimensions + if dims is None: + dims = [] + if time is not None: + dims.append('time') + if lat is not None: + dims.append('lat') + if lon is not None: + dims.append('lon') + + # create coordinates + coords = {} + if isinstance(lon, int): + lon_delta = 360.0 / lon + coords['lon'] = np.arange(-180 + 0.5 * lon_delta, 180, lon_delta) + elif isinstance(lon, np.ndarray): + coords['lon'] = lon + + if isinstance(lat, int): + lat_delta = 180.0 / lat + coords['lat'] = np.arange(90 - 0.5 * lat_delta, -90, -lat_delta) + elif isinstance(lat, np.ndarray): + coords['lat'] = lat + + if isinstance(time, int): + coords['time'] = np.arange(time, dtype=np.float64) + elif isinstance(time, np.ndarray): + coords['time'] = time + + for dim in dims: + if dim not in ['lon', 'lat', 'time']: + coords[dim] = variables[dim] + + # create data variables + data_vars = { + var_name: (dims, var) + for var_name, var in variables.items() + if var_name not in dims + } + + # create dataset + ds = xr.Dataset(coords=coords, data_vars=data_vars) + + # combine attrs + attrs = { + key: {**DEFAULT_ATTRS.get(key, {}), **(attrs or {}).get(key, {})} + for key in {*DEFAULT_ATTRS.keys(), *(attrs or {}).keys()} + } + + # set attributes + for coord in ds.coords: + if coord in attrs: + ds.coords[coord].attrs.update(attrs[coord]) + + for data_var in ds.data_vars: + if attrs: + if data_var in attrs: + ds.data_vars[data_var].attrs.update(attrs[data_var]) + + # set global attributes + ds.attrs = attrs.get('global', {}) + + return ds + + +def open_dataset(path: str | Path, decode_cf: bool = True, load: bool = False) -> xr.Dataset: + """Open a NetCDF dataset using xarray. + + Args: + path (str | Path): Path to the NetCDF file. + decode_cf (bool): Whether to decode CF conventions (default: True). 
+ load (bool): Whether to load data into memory immediately (default: False). + + Returns: + Xarray Dataset object. + + Note: + Handles non-standard time units like 'growing seasons' by converting + them to 'common_years' with a 365_day calendar. + """ + path = Path(path) + + logger.info(f'load {path.absolute()}' if load else f'open {path.absolute()}') + + try: + ds = xr.open_dataset(path, decode_cf=decode_cf) + except ValueError as e: + # workaround for non standard times (e.g. growing seasons) + ds = xr.open_dataset(path, decode_cf=decode_cf, decode_times=False) + + units = ds['time'].units + calendar = ds['time'].calendar + + if units.startswith('months'): + ds['time'] = cftime.num2date(ds['time'].values, units=units, calendar='360_day') + elif units.startswith('years'): + units = units.replace('years', 'common_years') + ds['time'] = cftime.num2date(ds['time'].values, units=units, calendar='365_day') + elif units.startswith('growing seasons'): + units = units.replace('growing seasons', 'common_years') + ds['time'] = cftime.num2date(ds['time'].values, units=units, calendar='365_day') + else: + raise ValueError(f'unable to decode time units "{units}" with calendar "{calendar}"') from e + + if load: + ds.load() + + return ds + + +def load_dataset(path: str | Path, decode_cf: bool = True) -> xr.Dataset: + """Open a NetCDF dataset using xarray and load data into memory immediately. + + Args: + path (str | Path): Path to the NetCDF file. + decode_cf (bool): Whether to decode CF conventions (default: True). + + Returns: + Xarray Dataset object. + + Note: + Handles non-standard time units like 'growing seasons' by converting + them to 'common_years' with a 365_day calendar. + + This is a shortcut for `open_dataset(path, decode_cf, load=True)`. + """ + return open_dataset(path, decode_cf, load=True) + + +def write_dataset(ds: xr.Dataset, path: str | Path): + """Write an xarray dataset to a NetCDF file. + + Args: + ds (xr.Dataset): Xarray Dataset to write. 
+ path (str | Path): Path where the NetCDF file will be written. + + Note: + Automatically adds fill values, converts NaN to fill values, + orders variables, and sets time as unlimited dimension. + """ + path = Path(path) + path.parent.mkdir(exist_ok=True, parents=True) + + logger.info(f'write {path.absolute()}') + + ds = remove_fill_value_from_coords(ds) + ds = add_fill_value_to_data_vars(ds) + ds = add_compression_to_data_vars(ds) + ds = order_variables(ds) + + # time should be an unlimited dimension + unlimited_dims = ['time'] if 'time' in ds.dims else [] + + # write dataset as netcdf + ds.to_netcdf(path, format='NETCDF4_CLASSIC', unlimited_dims=unlimited_dims) + + +def order_variables(ds: xr.Dataset) -> xr.Dataset: + """Reorder dataset variables with coordinates first, then data variables. + + Args: + ds (xr.Dataset): Xarray Dataset to reorder. + + Returns: + Dataset with reordered variables. + """ + preferred_coords = ['lon', 'lat', 'time'] + + ordered_coords = [coord for coord in preferred_coords if coord in ds.coords] + remaining_coords = [coord for coord in ds.coords if coord not in preferred_coords] + + return ds[[*ordered_coords, *remaining_coords, *ds.data_vars]] + + +def get_attrs(ds: xr.Dataset) -> dict: + """Get all attributes from coordinates and data variables. + + Args: + ds (xr.Dataset): Xarray Dataset. + + Returns: + Dictionary mapping variable names to their attributes. + """ + attrs = {} + for coord in ds.coords: + attrs[coord] = ds[coord].attrs + for data_var in ds.data_vars: + attrs[data_var] = ds[data_var].attrs + return attrs + + +def set_attrs(ds: xr.Dataset, attrs: dict) -> xr.Dataset: + """Set attributes on coordinates and data variables. + + Args: + ds (xr.Dataset): Xarray Dataset to modify. + attrs (dict): Dictionary mapping variable names to their attributes. + + Returns: + Modified dataset with updated attributes. 
+ """ + for coord in ds.coords: + if coord in attrs: + ds[coord].attrs = attrs[coord] + for data_var in ds.data_vars: + if data_var in attrs: + ds[data_var].attrs = attrs[data_var] + return ds + + +def set_fill_value_to_nan(ds: xr.Dataset) -> xr.Dataset: + """Replace fill values with NaN in data variables. This is only needed for datasets + which are read with decode_cf=False and _FillValue is not in encoding. + + Args: + ds (xr.Dataset): Xarray Dataset to modify. + + Returns: + Dataset with fill values replaced by NaN. + """ + for data_var in ds.data_vars: + if '_FillValue' not in ds[data_var].encoding: + ds[data_var] = ds[data_var].where(ds[data_var] != FILL_VALUE) + return ds + + +def set_nan_to_fill_value(ds: xr.Dataset) -> xr.Dataset: + """Replace NaN values with fill values in data variables. This is only needed for datasets + which are read with decode_cf=False and _FillValue is not in encoding. + + Args: + ds (xr.Dataset): Xarray Dataset to modify. + + Returns: + Dataset with NaN values replaced by fill values. + """ + for data_var in ds.data_vars: + if '_FillValue' not in ds[data_var].encoding: + ds[data_var] = ds[data_var].fillna(FILL_VALUE) + return ds + + +def remove_fill_value_from_coords(ds: xr.Dataset) -> xr.Dataset: + """Remove _FillValue and missing_value attributes from the coords. + + Args: + ds (xr.Dataset): Xarray Dataset to modify. + + Returns: + Dataset with fill value removed for the coords. + """ + for coord in ds.coords: + if '_FillValue' not in ds[coord].encoding: + ds[coord].encoding['_FillValue'] = None + return ds + + +def add_fill_value_to_data_vars(ds: xr.Dataset) -> xr.Dataset: + """Add _FillValue and missing_value to data_vars if no encoding is present. This + is the case for a newly created Dataset. + + Args: + ds (xr.Dataset): Xarray Dataset to modify. + + Returns: + Dataset with encoding added for the data_vars. 
+ """ + for data_var in ds.data_vars: + encoding = ds[data_var].encoding + if not encoding: + ds[data_var].attrs.pop('_FillValue', None) + ds[data_var].attrs.pop('missing_value', None) + ds[data_var].encoding.update({ + '_FillValue': FILL_VALUE, + 'missing_value': ds[data_var].dtype.type(FILL_VALUE) + }) + + return ds + + +def add_compression_to_data_vars(ds, complevel=5) -> xr.Dataset: + """Add compression to data variables. + + Args: + ds (xr.Dataset): Xarray Dataset to reorder. + complevel (int): Compression level + + Returns: + Dataset with updated encoding. + """ + for data_var in ds.data_vars: + ds[data_var].encoding.update({ + 'zlib': True, + 'complevel': complevel + }) + return ds + + +def compute_time(ds: xr.Dataset, timestamp: datetime | None) -> float | None: + """Convert a datetime to numeric time value for dataset. + + Args: + ds (xr.Dataset): Dataset with time coordinate containing units and calendar. + timestamp (datetime | date | None): Timestamp to convert, or None. + + Returns: + Numeric time value in dataset's units, or None if timestamp is None. + """ + if type(timestamp) is date: + timestamp = datetime.combine(timestamp, datetime.min.time()) + + units = ds.time.encoding.get('units') or ds.coords['time'].attrs.get('units') + calendar = ds.time.encoding.get('calendar') or ds.coords['time'].attrs.get('calendar') + return cftime.date2num(timestamp, units=units, calendar=calendar) if timestamp else None + + +def compute_offset(ds1: xr.Dataset, ds2: xr.Dataset) -> xr.DataArray | None: + """Compute time offset between two datasets with different time units. + + Args: + ds1 (xr.Dataset): First dataset with time coordinate. + ds2 (xr.Dataset): Second dataset with time coordinate. + + Returns: + Time offset to apply to ds2, or None if units/calendars match. 
+ """ + + units1 = ds1.time.encoding.get('units') or ds1.coords['time'].attrs.get('units') + calendar1 = ds1.time.encoding.get('calendar') or ds1.coords['time'].attrs.get('calendar') + units2 = ds2.time.encoding.get('units') or ds2.coords['time'].attrs.get('units') + calendar2 = ds2.time.encoding.get('calendar') or ds2.coords['time'].attrs.get('calendar') + if units1 != units2 or calendar1 != calendar2: + start_time = ds2['time'][0] + start_date = cftime.num2date(start_time, units=units2, calendar=calendar2) + offset = cftime.date2num(start_date, units=units1, calendar=calendar1) - start_time + logger.debug(f'time axis diverges "{units1}"/"{units2}" "{calendar1}"/"{calendar2}" offset={offset.values}') + return offset + + +def create_mask(ds: xr.Dataset, df: pd.DataFrame, layer: int) -> xr.Dataset: + """Create a spatial mask from a geometry layer. + + Args: + ds (xr.Dataset): Xarray Dataset with lat/lon coordinates. + df (pd.DataFrame): GeoDataFrame with geometry column. + layer (int): Index of the layer to use from the GeoDataFrame. + + Returns: + Xarray dataset with a 'mask' variable clipped to the geometry. + + Note: + Requires geopandas and rioxarray to be installed. + """ + import shapely.geometry + logger.info('create mask') + + df_row = df.iloc[layer] + geometry = shapely.geometry.mapping(df_row['geometry']) + + ds_lat = ds.coords['lat'] + ds_lon = ds.coords['lon'] + mask_ds = xr.Dataset( + data_vars={ + 'mask': (('lat', 'lon'), np.ones((ds_lat.size, ds_lon.size), dtype=np.float32)) + }, + coords={'lat': ds_lat, 'lon': ds_lon} + ) + mask_ds.rio.write_crs(df.crs, inplace=True) + mask_ds = mask_ds.rio.clip([geometry], drop=False) + mask_ds = mask_ds.drop_vars('spatial_ref') + return mask_ds + + +def convert_time(time: np.ndarray, units='days since 1601-1-1 00:00:00', calendar='proleptic_gregorian') -> np.ndarray: + """Convert an time coordinate array to np.float64 using cftime.date2num. + + Args: + time (np.ndarray): Time coordinate array. 
+ units (str): Units for the time coordinate (default: 'days since 1601-1-1 00:00:00'). + calendar (str): Calendar type for time coordinate (default: 'proleptic_gregorian'). + + Returns: + time (np.ndarray): Time coordinate array as np.float64. + """ + if isinstance(time.dtype, pd.StringDtype): + time = np.array([datetime.fromisoformat(t) for t in time], dtype=object) + elif np.issubdtype(time.dtype, np.datetime64): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", FutureWarning) + + if isinstance(time, pd.DatetimeIndex): + time = time.to_pydatetime() + elif isinstance(time, pd.Series): + time = time.dt.to_pydatetime() + else: + time = pd.to_datetime(time).to_pydatetime() + + return cftime.date2num( + time, calendar=calendar, units=units + ).astype(np.float64) + + +def to_dataframe(ds: xr.Dataset) -> pd.DataFrame: + """Convert an xarray Dataset to a pandas DataFrame. + + Args: + ds (xr.Dataset): Xarray Dataset to convert. + + Returns: + Pandas DataFrame with coordinates as columns and data variables as columns. + Attributes are preserved in df.attrs['coords'] and df.attrs['data_vars']. + + Note: + Time coordinates are converted to datetime64[ns] format. + Data variables are converted to float64. 
+ """ + if 'time' in ds.coords: + ds.coords['time'] = ds.coords['time'].astype('datetime64[ns]') + + ds = ds.assign({ + data_var: ds[data_var].astype('float64') + for data_var in ds.data_vars + }) + + df = ds.to_dataframe().reset_index() + df.attrs['coords'] = { + coord: ds[coord].attrs for coord in ds.coords if (ds[coord].size > 1) + } + df.attrs['data_vars'] = { + data_var: ds[data_var].attrs for data_var in ds.data_vars if (ds[data_var].size > 1) + } + + return df diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000..f4dd447 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,33 @@ +site_name: ISIMIP utils + +theme: + name: material + +plugins: +- mkdocstrings: + handlers: + python: + options: + show_source: false + show_bases: false + members_order: source + +nav: + - Getting started: index.md + - Examples: examples.md + - API reference: + - isimip_utils.checksum: api/checksum.md + - isimip_utils.cli: api/cli.md + - isimip_utils.config: api/config.md + - isimip_utils.exceptions: api/exceptions.md + - isimip_utils.extractions: api/extractions.md + - isimip_utils.fetch: api/fetch.md + - isimip_utils.files: api/files.md + - isimip_utils.netcdf: api/netcdf.md + - isimip_utils.pandas: api/pandas.md + - isimip_utils.patterns: api/patterns.md + - isimip_utils.parameters: api/parameters.md + - isimip_utils.plot: api/plot.md + - isimip_utils.utils: api/utils.md + - isimip_utils.xarray: api/xarray.md + - Python installation: prerequisites.md diff --git a/pyproject.toml b/pyproject.toml index 165fc61..dce2760 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,7 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + [project] name = "isimip-utils" authors = [ @@ -8,21 +12,20 @@ maintainers = [ ] description = "This package contains common functionality for different ISIMIP tools." 
readme = "README.md" -requires-python = ">=3.8" -license = { file = "LICENSE" } +requires-python = ">=3.11" +license = "MIT" +license-files = ["LICENSE"] classifiers = [ - 'Operating System :: OS Independent', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', + "Operating System :: OS Independent", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", ] dependencies = [ - "colorlog", - "netCDF4", "python-dotenv", - "requests" + "requests", + "rich", ] dynamic = ["version"] @@ -30,22 +33,55 @@ dynamic = ["version"] Repository = "https://github.com/ISI-MIP/isimip-utils" [project.optional-dependencies] +all = [ + "isimip-utils[netcdf,altair,geopandas,xarray,dev,pytest,docs]" +] +recommended = [ + "isimip-utils[netcdf,altair,geopandas,xarray]" +] +netcdf = [ + "netCDF4~=1.7" +] +altair = [ + "altair[all]~=6.0", + "palettable~=3.3", +] +geopandas = [ + "geopandas~=1.1", + "rioxarray>=0.19", +] +xarray = [ + "cftime~=1.6", + "xarray>=2025.11" +] +pytest = [ + "pytest~=9.0", + "pytest-cov~=7.0" +] dev = [ "build", "pre-commit", "ruff", "twine", ] +docs = [ + "mkdocs", + "mkdocs-material", + "mkdocstrings-python", +] -[tool.setuptools] -packages = ["isimip_utils"] +[tool.hatch.version] +source = "vcs" -[tool.setuptools.dynamic] -version = { attr = "isimip_utils.__version__" } +[tool.hatch.build.targets.wheel] +packages = ["isimip_utils"] +exclude = ["isimip_utils/tests"] [tool.ruff] -target-version = "py38" +target-version = "py311" line-length = 120 + +[tool.ruff.lint] select = [ "B", # flake8-bugbear "C4", # flake8-comprehensions @@ -59,20 +95,35 @@ select = [ "YTT", # flake8-2020 ] ignore = [ - "B006", # mutable-argument-default - "B007", # unused-loop-control-variable - "B018", # 
useless-expression "RUF012", # mutable-class-default ] -[tool.ruff.isort] +[tool.ruff.lint.isort] known-first-party = [ "isimip_utils" ] section-order = [ "future", "standard-library", + "pytest", "third-party", "first-party", "local-folder" ] + +[tool.ruff.lint.isort.sections] +pytest = ["pytest"] + +[tool.typos.default.extend-words] +iy = "iy" +arange = "arange" + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "except PackageNotFoundError:", + "except requests.exceptions.ConnectionError", + "raise AssertionError", + "raise NotImplementedError", + "raise RuntimeError" +] diff --git a/testing/download.py b/testing/download.py new file mode 100755 index 0000000..cb1eb15 --- /dev/null +++ b/testing/download.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 +from isimip_utils.tests import constants, helper + + +def main(): + download_datasets() + download_protocol() + + +def download_datasets(): + constants.DATASETS_PATH.mkdir(parents=True, exist_ok=True) + + for path in [constants.LANDSEAMASK_PATH, constants.TAS_PATH, constants.YIELD_PATH]: + file_path = constants.DATASETS_PATH / path + file_path.parent.mkdir(parents=True, exist_ok=True) + + url = f"https://files.isimip.org/{path}" + + helper.call(f'wget -c {url} -O {file_path}') + + +def download_protocol(): + constants.PROTOCOL_PATH.mkdir(parents=True, exist_ok=True) + + for path in constants.PROTOCOL_PATHS: + file_path = constants.PROTOCOL_PATH / path + file_path.parent.mkdir(parents=True, exist_ok=True) + + url = f"https://protocol.isimip.org/{path}" + + helper.call(f'wget -c {url} -O {file_path}') + + +if __name__ == "__main__": + main() diff --git a/testing/setup.py b/testing/setup.py new file mode 100755 index 0000000..208e19a --- /dev/null +++ b/testing/setup.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +from isimip_utils.tests import constants, helper + + +def main(): + run_gridfile() + run_seldate() + run_select_time() + run_select_period() + run_select_point() + run_select_bbox() + 
run_select_bbox_aggregations() + run_select_bbox_map() + run_mask_bbox() + run_mask_mask() + + +def download_datasets(): + constants.DATASETS_PATH.mkdir(parents=True, exist_ok=True) + + for path in [constants.LANDSEAMASK_PATH, constants.TAS_PATH, constants.YIELD_PATH]: + file_path = constants.DATASETS_PATH / path + file_path.parent.mkdir(parents=True, exist_ok=True) + + url = f"https://files.isimip.org/{path}" + + helper.call(f'wget -c {url} -O {file_path}') + + +def download_protocol(): + constants.PROTOCOL_PATH.mkdir(parents=True, exist_ok=True) + + for path in constants.PROTOCOL_PATHS: + file_path = constants.PROTOCOL_PATH / path + file_path.parent.mkdir(parents=True, exist_ok=True) + + url = f"https://protocol.isimip.org/{path}" + + helper.call(f'wget -c {url} -O {file_path}') + + +def run_gridfile(): + input_path = constants.DATASETS_PATH / constants.TAS_PATH + output_path = constants.SHARE_PATH / 'gridarea.nc' + output_path.parent.mkdir(parents=True, exist_ok=True) + + if not output_path.exists(): + helper.call(f'cdo gridarea {input_path} {output_path}') + + +def run_seldate(): + input_path = constants.DATASETS_PATH / constants.TAS_PATH + + for period, path in zip(constants.TAS_SPLIT_PERIOD, constants.TAS_SPLIT_PATHS, strict=True): + start_date, end_date = period + start_date, end_date = start_date.strftime('%Y-%m-%d'), end_date.strftime('%Y-%m-%d') + output_path = constants.DATASETS_PATH / path + + if not output_path.exists(): + helper.call(f'cdo -f nc4c -z zip_5 -L seldate,{start_date},{end_date} {input_path} {output_path}') + + +def run_select_time(): + date = constants.DATE + date_specifiers = date.strftime('%Y%m%d') + + path = constants.TAS_PATH + + input_path = constants.DATASETS_PATH / path + + output_path = ( + constants.EXTRACTIONS_PATH / path.replace('_global_', '_select-time-cdo_') + .replace(constants.TAS_DATE_SPECIFIERS, date_specifiers) + ) + output_path.parent.mkdir(parents=True, exist_ok=True) + + if not output_path.exists(): + 
helper.call(f"cdo -f nc4c -z zip_5 -L seldate,{date.strftime('%Y-%m-%d')} {input_path} {output_path}") + + +def run_select_period(): + start_date, end_date = constants.PERIOD + date_specifiers = f"{start_date.strftime('%Y%m%d')}_{end_date.strftime('%Y%m%d')}" + + path = constants.TAS_PATH + + input_path = constants.DATASETS_PATH / path + + output_path = ( + constants.EXTRACTIONS_PATH / path.replace('_global_', '_select-period-cdo_') \ + .replace(constants.TAS_DATE_SPECIFIERS, date_specifiers) + ) + output_path.parent.mkdir(parents=True, exist_ok=True) + + if not output_path.exists(): + helper.call(f'cdo -f nc4c -z zip_5 -L seldate,{start_date},{end_date} {input_path} {output_path}') + + +def run_select_point(): + ix, iy = constants.POINT_INDEX + + # add one since cdo is counting from 1! + ix, iy = ix + 1, iy + 1 + + for path in [constants.TAS_PATH, *constants.TAS_SPLIT_PATHS]: + input_path = constants.DATASETS_PATH / path + + output_path = constants.EXTRACTIONS_PATH / path.replace('_global_', '_select-point-cdo_') + output_path.parent.mkdir(parents=True, exist_ok=True) + + if not output_path.exists(): + helper.call(f'cdo -f nc4c -z zip_5 -L -selindexbox,{ix},{ix},{iy},{iy} {input_path} {output_path}') + + +def run_select_bbox(): + west, east, south, north = constants.BBOX + + for path in [constants.TAS_PATH, *constants.TAS_SPLIT_PATHS]: + input_path = constants.DATASETS_PATH / path + + output_path = constants.EXTRACTIONS_PATH / path.replace('_global_', '_select-bbox-cdo_') + output_path.parent.mkdir(parents=True, exist_ok=True) + + if not output_path.exists(): + helper.call(f'cdo -f nc4c -z zip_5 -L ' \ + f'-sellonlatbox,{west},{east},{south},{north} {input_path} {output_path}') + + +def run_select_bbox_aggregations(): + west, east, south, north = constants.BBOX + + for path in [constants.TAS_PATH, *constants.TAS_SPLIT_PATHS]: + input_path = constants.DATASETS_PATH / path + + for aggregation in ['mean', 'min', 'max', 'sum', 'std']: + output_path = 
constants.EXTRACTIONS_PATH / path.replace('_global_', f'_select-bbox-{aggregation}-cdo_') + output_path.parent.mkdir(parents=True, exist_ok=True) + + gridarea_path = constants.SHARE_PATH / 'gridarea.nc' + + if not output_path.exists(): + if aggregation == 'sum': + helper.call(f'cdo -f nc4c -z zip_5 -L -fld{aggregation} ' + f'-sellonlatbox,{west},{east},{south},{north} ' + f'-mul {input_path} {gridarea_path} {output_path}') + else: + helper.call(f'cdo -f nc4c -z zip_5 -L -fld{aggregation} ' + f'-sellonlatbox,{west},{east},{south},{north} {input_path} {output_path}') + + +def run_select_bbox_map(): + west, east, south, north = constants.BBOX + + for path in [constants.TAS_PATH, *constants.TAS_SPLIT_PATHS]: + input_path = constants.DATASETS_PATH / path + + output_path = constants.EXTRACTIONS_PATH / path.replace('_global_', '_select-bbox-map-cdo_') + output_path.parent.mkdir(parents=True, exist_ok=True) + + if not output_path.exists(): + helper.call('cdo -f nc4c -z zip_5 -L timmean ' + f'-sellonlatbox,{west},{east},{south},{north} {input_path} {output_path}') + + +def run_mask_bbox(): + west, east, south, north = constants.BBOX + + for path in [constants.TAS_PATH, *constants.TAS_SPLIT_PATHS]: + input_path = constants.DATASETS_PATH / path + + output_path = constants.EXTRACTIONS_PATH / path.replace('_global_', '_mask-bbox-cdo_') + output_path.parent.mkdir(parents=True, exist_ok=True) + + if not output_path.exists(): + helper.call(f'cdo -f nc4c -z zip_5 -L ' + f'-masklonlatbox,{west},{east},{south},{north} {input_path} {output_path}') + + +def run_mask_mask(): + mask_path = constants.DATASETS_PATH / constants.LANDSEAMASK_PATH + + for path in [constants.TAS_PATH, *constants.TAS_SPLIT_PATHS]: + input_path = constants.DATASETS_PATH / path + + output_path = constants.EXTRACTIONS_PATH / path.replace('_global_', '_mask-mask-cdo_') + output_path.parent.mkdir(parents=True, exist_ok=True) + + if not output_path.exists(): + helper.call(f'cdo -f nc4c -z zip_5 -L -ifthen 
-selname,mask {mask_path} {input_path} {output_path}') + + +if __name__ == "__main__": + main()