Skip to content

Commit 77f2b5a

Browse files
Standardised infra
1 parent 2fd3ef3 commit 77f2b5a

38 files changed

Lines changed: 136 additions & 121 deletions

.github/workflows/release.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ jobs:
1010
runs-on: ubuntu-latest
1111
steps:
1212
- name: Release
13-
uses: patrick-kidger/action_update_python_project@v6
13+
uses: patrick-kidger/action_update_python_project@v8
1414
with:
1515
python-version: "3.11"
1616
test-script: |
1717
cp -r ${{ github.workspace }}/test ./test
1818
cp ${{ github.workspace }}/pyproject.toml ./pyproject.toml
19-
python -m pip install -r ./test/requirements.txt
20-
python -m test
19+
uv sync --extra tests --no-install-project --inexact
20+
uv run --no-sync pytest
2121
pypi-token: ${{ secrets.pypi_token }}
2222
github-user: patrick-kidger
2323
github-token: ${{ github.token }}

.github/workflows/run_tests.yml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,16 @@ jobs:
2323
- name: Install dependencies
2424
run: |
2525
python -m pip install --upgrade pip
26-
python -m pip install -r ./test/requirements.txt
27-
26+
python -m pip install '.[dev,docs,tests]'
2827
2928
- name: Checks with pre-commit
30-
uses: pre-commit/action@v3.0.1
29+
run: |
30+
pre-commit run --all-files
3131
3232
- name: Test with pytest
3333
run: |
34-
python -m pip install .
3534
python -m test
35+
36+
- name: Check that documentation can be built.
37+
run: |
38+
mkdocs build

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ site/
88
.pymon
99
.idea/
1010
.venv/
11+
uv.lock

.pre-commit-config.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@ repos:
88
files: ^pyproject\.toml$
99
additional_dependencies: ["toml-sort==0.23.1"]
1010
- repo: https://github.com/astral-sh/ruff-pre-commit
11-
rev: v0.2.2
11+
rev: v0.13.0
1212
hooks:
1313
- id: ruff-format # formatter
14-
types_or: [ python, pyi, jupyter ]
14+
types_or: [ python, pyi, jupyter, toml ]
1515
- id: ruff # linter
16-
types_or: [ python, pyi, jupyter ]
16+
types_or: [ python, pyi, jupyter, toml ]
1717
args: [ --fix ]
1818
- repo: https://github.com/RobertCraigie/pyright-python
19-
rev: v1.1.350
19+
rev: v1.1.405
2020
hooks:
2121
- id: pyright
2222
additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typeguard==2.13.3, typing_extensions, wadler_lindig]

CONTRIBUTING.md

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,15 @@ Contributions (pull requests) are very welcome! Here's how to get started.
88

99
First fork the library on GitHub.
1010

11-
Then clone and install the library in development mode:
11+
Then clone and install the library:
1212

1313
```bash
1414
git clone https://github.com/your-username-here/diffrax.git
1515
cd diffrax
16-
pip install -e .
16+
pip install -e '.[dev]'
17+
pre-commit install # `pre-commit` is installed by `pip` on the previous line
1718
```
1819

19-
Then install the pre-commit hook:
20-
21-
```bash
22-
pip install pre-commit
23-
pre-commit install
24-
```
25-
26-
These hooks use ruff to lint and format the code, and pyright to type-check it.
27-
2820
---
2921

3022
**If you're making changes to the code:**
@@ -34,8 +26,8 @@ Now make your changes. Make sure to include additional tests if necessary.
3426
Next verify the tests all pass:
3527

3628
```bash
37-
pip install -r test/requirements.txt
38-
pytest
29+
pip install -e '.[tests]'
30+
pytest # `pytest` is installed by `pip` on the previous line.
3931
```
4032

4133
Then push your changes back to your fork of the repository:

diffrax/_adjoint.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -362,8 +362,7 @@ def loop(
362362
if is_unsafe_sde(terms):
363363
kind = "lax"
364364
msg = (
365-
"Cannot reverse-mode autodifferentiate when using "
366-
"`UnsafeBrownianPath`."
365+
"Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`."
367366
)
368367
elif max_steps is None:
369368
kind = "lax"

diffrax/_brownian/path.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@ class UnsafeBrownianPath(AbstractBrownianPath):
6262
"""
6363

6464
shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
65-
levy_area: type[
66-
BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea
67-
] = eqx.field(static=True)
65+
levy_area: type[BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea] = (
66+
eqx.field(static=True)
67+
)
6868
key: PRNGKeyArray
6969

7070
def __init__(

diffrax/_brownian/tree.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,9 +235,9 @@ class VirtualBrownianTree(AbstractBrownianPath):
235235
t1: RealScalarLike
236236
tol: RealScalarLike
237237
shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
238-
levy_area: type[
239-
BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea
240-
] = eqx.field(static=True)
238+
levy_area: type[BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea] = (
239+
eqx.field(static=True)
240+
)
241241
key: PyTree[PRNGKeyArray]
242242
_spline: _Spline = eqx.field(static=True)
243243

diffrax/_integrate.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections.abc import Callable
44
from typing import ( # noqa: UP035
55
Any,
6+
cast,
67
get_args,
78
get_origin,
89
Tuple,
@@ -1164,7 +1165,10 @@ def _wrap(term):
11641165
def _get_tols(x):
11651166
outs = []
11661167
for attr in ("rtol", "atol", "norm"):
1167-
if getattr(solver.root_finder, attr) is use_stepsize_tol:
1168+
if (
1169+
getattr(cast(AbstractImplicitSolver, solver).root_finder, attr)
1170+
is use_stepsize_tol
1171+
):
11681172
outs.append(getattr(x, attr))
11691173
return tuple(outs)
11701174

diffrax/_solution.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ._path import AbstractPath
1111

1212

13-
class RESULTS(optx.RESULTS): # pyright: ignore
13+
class RESULTS(optx.RESULTS): # pyright: ignore[reportGeneralTypeIssues]
1414
successful = ""
1515
max_steps_reached = (
1616
"The maximum number of solver steps was reached. Try increasing `max_steps`."
@@ -121,8 +121,8 @@ class Solution(AbstractPath):
121121
# the structure of `subs`.
122122
# SaveAt(fn=...) means that `ys` will then follow with arbitrary sub-dependent
123123
# PyTree structures.
124-
ts: PyTree[Real[Array, " ?times"], " S"] | None
125-
ys: PyTree[Shaped[Array, "?times ?*shape"], "S ..."] | None
124+
ts: PyTree[Real[Array, " ?times"], " S"] | None # pyright: ignore[reportUndefinedVariable]
125+
ys: PyTree[Shaped[Array, "?times ?*shape"], "S ..."] | None # pyright: ignore
126126
interpolation: DenseInterpolation | None
127127
stats: dict[str, Any]
128128
result: RESULTS
@@ -133,7 +133,7 @@ class Solution(AbstractPath):
133133

134134
def evaluate(
135135
self, t0: RealScalarLike, t1: RealScalarLike | None = None, left: bool = True
136-
) -> PyTree[Shaped[Array, "?*shape"], " Y"]:
136+
) -> PyTree[Shaped[Array, "?*shape"], " Y"]: # pyright: ignore[reportUndefinedVariable]
137137
"""If dense output was saved, then evaluate the solution at any point in the
138138
region of integration `self.t0` to `self.t1`.
139139
@@ -153,7 +153,7 @@ def evaluate(
153153

154154
def derivative(
155155
self, t: RealScalarLike, left: bool = True
156-
) -> PyTree[Shaped[Array, "?*shape"], " Y"]:
156+
) -> PyTree[Shaped[Array, "?*shape"], " Y"]: # pyright: ignore[reportUndefinedVariable]
157157
r"""If dense output was saved, then calculate an **approximation** to the
158158
derivative of the solution at any point in the region of integration `self.t0`
159159
to `self.t1`.

0 commit comments

Comments
 (0)