Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions .github/workflows/build_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,18 @@ jobs:
- name: Checkout code
uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
- name: Install the latest version of uv
uses: astral-sh/setup-uv@v7
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install '.[docs]'
uv run echo done

- name: Build docs
run: |
mkdocs build
uv run mkdocs build

- name: Upload docs
uses: actions/upload-artifact@v4
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ jobs:
uses: patrick-kidger/action_update_python_project@v8
with:
python-version: "3.11"
# Uninstall and reinstall pytest to work around the fact that it doesn't get put into `bin` otherwise.
test-script: |
cp -r ${{ github.workspace }}/test ./test
cp ${{ github.workspace }}/pyproject.toml ./pyproject.toml
uv sync --extra tests --no-install-project --inexact
uv pip uninstall pytest
uv sync --no-install-project --inexact
uv run --no-sync python -m test
pypi-token: ${{ secrets.pypi_token }}
github-user: patrick-kidger
Expand Down
13 changes: 6 additions & 7 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,23 @@ jobs:
- name: Checkout code
uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
- name: Install the latest version of uv
uses: astral-sh/setup-uv@v7
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install '.[dev,docs,tests]'
uv run echo done

- name: Checks with pre-commit
run: |
pre-commit run --all-files
uv run prek run --all-files

- name: Test with pytest
run: |
python -m test
uv run python -m test

- name: Check that documentation can be built.
run: |
mkdocs build
uv run mkdocs build
52 changes: 34 additions & 18 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,22 +1,38 @@
fail_fast: true
repos:
- repo: local
- repo: meta
hooks:
- id: sort_pyproject
name: sort_pyproject
entry: toml-sort -i --sort-table-keys --sort-inline-tables
language: python
files: ^pyproject\.toml$
additional_dependencies: ["toml-sort==0.23.1"]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.13.0
- id: check-hooks-apply
- id: check-useless-excludes
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: ruff-format # formatter
types_or: [ python, pyi, jupyter, toml ]
- id: ruff # linter
types_or: [ python, pyi, jupyter, toml ]
args: [ --fix ]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.405
- id: trailing-whitespace
exclude: \.md$
- id: check-toml
- id: mixed-line-ending
- repo: local
hooks:
- id: pyright
additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typeguard==2.13.3, typing_extensions, wadler_lindig]
- id: sort-pyproject
name: sort pyproject
files: ^pyproject\.toml$
language: system
entry: uv run -- toml-sort -i --sort-table-keys --sort-inline-tables
- id: ruff-format
name: ruff format
types_or: [python, pyi, jupyter, toml]
language: system
entry: uv run -- ruff format --
require_serial: true
- id: ruff-lint
name: ruff lint
types_or: [python, pyi, jupyter, toml]
language: system
entry: uv run -- ruff check --fix --
require_serial: true
- id: pyright
name: pyright
types_or: [python]
language: system
entry: uv run -- pyright
require_serial: true
23 changes: 8 additions & 15 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,24 @@ Contributions (pull requests) are very welcome! Here's how to get started.

---

**Getting started**
### Getting started

First fork the library on GitHub.

Then clone and install the library:
[We assume that you have `uv` installed.](https://docs.astral.sh/uv/) Now fork the library on GitHub. Then clone and install the library:

```bash
git clone https://github.com/your-username-here/diffrax.git
cd diffrax
pip install -e '.[dev]'
pre-commit install # `pre-commit` is installed by `pip` on the previous line
uv run prek install # Creates a local venv + installs dependencies + installs pre-commit hooks.
```

---

**If you're making changes to the code:**

Now make your changes. Make sure to include additional tests if necessary.
### If you're making changes to the code

Next verify the tests all pass:
Now make your changes. Make sure to include additional tests if necessary. Next verify the tests all pass:

```bash
pip install -e '.[tests]'
pytest # `pytest` is installed by `pip` on the previous line.
uv run pytest
```

Then push your changes back to your fork of the repository:
Expand All @@ -40,13 +34,12 @@ Finally, open a pull request on GitHub!

---

**If you're making changes to the documentation:**
### If you're making changes to the documentation

Make your changes. You can then build the documentation by doing

```bash
pip install -e '.[docs]'
mkdocs serve
uv run mkdocs serve
```

You can then see your local copy of the documentation by navigating to `localhost:8000` in a web browser.
6 changes: 3 additions & 3 deletions diffrax/_autocitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def _backsolve_adjoint(adjoint, terms=None):
r"""
% You are backpropagating through an SDE using optimise-then-discretise
% (`adjoint=BacksolveAdjoint(...)`)
% This technique was introduced in
% This technique was introduced in
"""
+ vbt_ref
+ r"""
Expand Down Expand Up @@ -273,10 +273,10 @@ def _discrete_adjoint(adjoint):
% If using forward-mode autodifferentiation, then this was studied in:
@inproceedings{ma2021comparison,
title={A Comparison of Automatic Differentiation and Continuous Sensitivity Analysis
for Derivatives of Differential Equation Solutions},
for Derivatives of Differential Equation Solutions},
author={Ma, Yingbo and Dixit, Vaibhav and Innes, Michael J and Guo, Xingjian and
Rackauckas, Chris},
booktitle={2021 IEEE High Performance Extreme Computing Conference (HPEC)},
booktitle={2021 IEEE High Performance Extreme Computing Conference (HPEC)},
year={2021},
pages={1-9},
doi={10.1109/HPEC49654.2021.9622796}
Expand Down
10 changes: 5 additions & 5 deletions diffrax/_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(
the exact time of the event. If the triggered condition function returns a real
number, then the final time will be the time at which that real number equals zero.
(If the triggered condition function returns a boolean, then the returned time will
just be the end of the step on which it becomes `True`.)
just be the end of the step on which it becomes `True`.)
[`optimistix.Newton`](https://docs.kidger.site/optimistix/api/root_find/#optimistix.Newton)
would be a typical choice here.

Expand All @@ -74,12 +74,12 @@ def __init__(

!!! Example

Consider a bouncing ball dropped from some intial height $x_0$. We can model
Consider a bouncing ball dropped from some intial height $x_0$. We can model
the ball by a 2-dimensional ODE

$\\frac{dx_t}{dt} = v_t, \\quad \\frac{dv_t}{dt} = -g,$

where $x_t$ represents the height of the ball, $v_t$ its velocity,
where $x_t$ represents the height of the ball, $v_t$ its velocity,
and $g$ is the gravitational constant. With $g=8$, this corresponds to the
vector field:

Expand All @@ -89,8 +89,8 @@ def vector_field(t, y, args):
return jnp.array([v, -8.0])
```

Figuring out exactly when the ball hits the ground amounts to
solving the ODE until the event $x_t=0$ is triggered. This can be done by using
Figuring out exactly when the ball hits the ground amounts to
solving the ODE until the event $x_t=0$ is triggered. This can be done by using
the real-valued condition function:

```python
Expand Down
4 changes: 2 additions & 2 deletions diffrax/_progress_meter.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,8 @@ def __check_init__(self):
)

@staticmethod
def _init_bar() -> "tqdm.tqdm": # pyright: ignore # noqa: F821
import tqdm # pyright: ignore
def _init_bar() -> "tqdm.tqdm": # pyright: ignore[reportUndefinedVariable] # noqa: F821
import tqdm

bar_format = (
"{percentage:.2f}%|{bar}| [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_root_finder/_verychord.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def postprocess(

- `rtol`: Relative tolerance for terminating the solve.
- `atol`: Absolute tolerance for terminating the solve.
- `norm`: The norm used to determine the difference between two iterates in the
- `norm`: The norm used to determine the difference between two iterates in the
convergence criteria. Should be any function `PyTree -> Scalar`, for example
`optimistix.max_norm`.
- `kappa`: A tolerance for the early convergence check.
Expand Down
4 changes: 2 additions & 2 deletions diffrax/_saveat.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __check_init__(self):
- `t0`: If `True`, save the initial input `y0`.
- `t1`: If `True`, save the output at `t1`.
- `ts`: Some array of times at which to save the output.
- `steps`: If `n>0`, save the output at every `n`th step of the numerical solver.
- `steps`: If `n>0`, save the output at every `n`th step of the numerical solver.
`0` means no saving.
- `fn`: A function `fn(t, y, args)` which specifies what to save into `sol.ys` when
using `t0`, `t1`, `ts` or `steps`. Defaults to `fn(t, y, args) -> y`, so that the
Expand Down Expand Up @@ -110,7 +110,7 @@ def __init__(
- `t0`: If `True`, save the initial input `y0`.
- `t1`: If `True`, save the output at `t1`.
- `ts`: Some array of times at which to save the output.
- `steps`: If `n>0`, save the output at every `n`th step of the numerical solver.
- `steps`: If `n>0`, save the output at every `n`th step of the numerical solver.
`0` means no saving.
- `dense`: If `True`, save dense output, that can later be evaluated at any part of
the interval $[t_0, t_1]$ via `sol = diffeqsolve(...); sol.evaluate(...)`.
Expand Down
2 changes: 1 addition & 1 deletion docs/_static/mathjax.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ window.MathJax = {
}
};

document$.subscribe(() => {
document$.subscribe(() => {
MathJax.typesetPromise()
})
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ edit_uri: ""

strict: true # Don't allow warnings during the build process

extra_javascript:
extra_javascript:
# The below two make MathJax work, see https://squidfunk.github.io/mkdocs-material/reference/mathjax/
- _static/mathjax.js
- https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
Expand Down
54 changes: 31 additions & 23 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,34 @@
build-backend = "hatchling.build"
requires = ["hatchling"]

[dependency-groups]
dev = [
"prek==0.3.9",
"pyright==1.1.405",
"ruff==0.13.0",
"toml-sort==0.23.1"
]
docs = [
"hippogriffe==0.2.2",
"griffe==1.7.3",
"mkdocs==1.6.1",
"mkdocs-include-exclude-files==0.1.0",
"mkdocs-ipynb==0.1.1",
"mkdocs-material==9.6.7",
"mkdocstrings==0.28.3",
"mkdocstrings-python==1.16.8",
"pygments==2.20.0",
"pymdown-extensions==10.21.2"
]
tests = [
"beartype",
"jaxlib",
"optax",
"pytest",
"scipy",
"tqdm"
]

[project]
authors = [
{email = "contact@kidger.site", name = "Patrick Kidger"}
Expand Down Expand Up @@ -29,29 +57,6 @@ requires-python = ">=3.11"
urls = {repository = "https://github.com/patrick-kidger/diffrax"}
version = "0.7.2"

[project.optional-dependencies]
dev = ["pre-commit"]
docs = [
"hippogriffe==0.2.2",
"griffe==1.7.3",
"mkdocs==1.6.1",
"mkdocs-include-exclude-files==0.1.0",
"mkdocs-ipynb==0.1.1",
"mkdocs-material==9.6.7",
"mkdocstrings==0.28.3",
"mkdocstrings-python==1.16.8",
"pygments==2.20.0",
"pymdown-extensions==10.21.2"
]
tests = [
"beartype",
"jaxlib",
"optax",
"pytest",
"scipy",
"tqdm"
]

[tool.hatch.build]
include = ["diffrax/*"]

Expand Down Expand Up @@ -84,3 +89,6 @@ combine-as-imports = true
extra-standard-library = ["typing_extensions"]
lines-after-imports = 2
order-by-type = false

[tool.uv]
default-groups = ["dev", "docs", "tests"]
6 changes: 3 additions & 3 deletions test/test_saveat_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def test_saveat_solution():
assert sol.t1 == _t1
assert sol.ts.shape == (4096,) # pyright: ignore
assert sol.ys.shape == (4096, 1) # pyright: ignore
_ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts)
_ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts) # pyright: ignore[reportArgumentType]
_ts = cast(jax.Array, _ts)
with jax.numpy_rank_promotion("allow"):
_ys = _y0 * jnp.exp(-0.5 * (_ts - _t0))[:, None]
Expand All @@ -141,7 +141,7 @@ def test_saveat_solution():
n = (4096 - 1) // 2 + 1
assert sol.ts.shape == (n,) # pyright: ignore
assert sol.ys.shape == (n, 1) # pyright: ignore
_ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts)
_ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts) # pyright: ignore[reportArgumentType]
_ts = cast(jax.Array, _ts)
with jax.numpy_rank_promotion("allow"):
_ys = _y0 * jnp.exp(-0.5 * (_ts - _t0))[:, None]
Expand All @@ -163,7 +163,7 @@ def test_saveat_solution():
n = (4096 - 1) // 2 + 1
assert sol.ts.shape == (n,) # pyright: ignore
assert sol.ys.shape == (n, 1) # pyright: ignore
_ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts)
_ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts) # pyright: ignore[reportArgumentType]
_ts = cast(jax.Array, _ts)
with jax.numpy_rank_promotion("allow"):
_ys = _y0 * jnp.exp(-0.5 * (_ts - _t0))[:, None]
Expand Down
Loading