Commit 2bf9824
authored
0.1.1 release (#220)
This is mostly a bug fix, documentation and under-the-hood performance improvement release with one new feature—the `lx.invert` [transformation](https://docs.kidger.site/lineax/api/linear_solve/#invert) which produces an operator representing the inverse of a matrix. Use of coloring rules should make using implicit solvers in [diffrax](https://docs.kidger.site/diffrax/) for tridiagonal `Jacobian/FunctionLinearOperator`s at least an order of magnitude faster.
## Breaking Changes
* Extraction of diagonal/tridiagonals of now leverages the promise of a matrix being tagged as diagonal/tridiagonal more heavily. If you have previously used the tag for an operator that you just wanted lineax to TREAT as diagonal/tridiagonal you may now get incorrect results. In most cases the right fix will probably be to first manually extract (tri)diagonal and construct the `(Tri)DiagonalOperator` explicitly, please raise an [issue](https://github.com/patrick-kidger/lineax/issues/new) if you need any further assistance.
* `lineax.linear_solve` now stop-gradient's automatically (#213), it is unlikely this will break any existing use-cases but may make manual stop-gradienting unecessary
* Removed AuxLinearOperator (#203)
## Features
* Add invert helper function to wrap `lineax.linear_solve` in `FunctionLinearOperator`. Materialising an inverse is now as simple as `lx.invert(op).as_matrix()`. (#206)
## Compatibility
* lineax v0.1.1 now requires JAX >= 0.10.0 which provides a lowering to LAPACK/cuSolver's` ormqr for more efficient QR solve adopted in #219.
## Bugfixes
* Fix derived tag check rules for composite operators (e.g. `Composed/Neg/Mul/AddLinearOperator`) (#192)
* Linearisation of functions `custom_vjp`'s are now supported by `lineax.linearize(JacobianLinearOperator(f, x, jac="bwd"))` by using `jax.linear_transpose` under the hood. (#191)
* Complex positive/negative semi-definite matrices no longer register as symmetric (#200)
* `lineax.LSMR` no longer fails when initial residual is exactly zero. (HUGE thanks to @f0uriest for spotting this tricky and hard-to-spot bug #202)
* Differentiating through `linear_solve`'s no longer differentiates through `solver.init` this means using solver's with no or incorrect jvp rule is now possible (#212)
## Performance
* Coloring rules now used to _massively_ speed up diagonal/tridiagonal extraction of tagged `Jacobian/FunctionLinearOperators` (#164, #165)
* Normal and iterative solvers now apply `lineax.linearise` under the hood to avoid multiple sequential AD passes (#198)
* Furthermore, `lineax.Normal(lineax.Cholesky())` now materialises the inner operator before constructing the Gram matrix (#207)
* `ComposedLinearOperator.as_matrix` no longer materialises each matrix first but instead batches `mv` of the first operator over the second matrix (#196)
* JAX's [ormqr](https://docs.jax.dev/en/latest/_autosummary/jax.lax.linalg.ormqr.html) now used for more efficient QR solves (#219)
## Documentation
* The `lineax.LSMR` iterative least square solver is now properly documented (#204) after @f0uriest's #202 bug-fixes make it more robust.
Other repo infra PR's not affecting Python package include #214, #216 and #218.
## New Contributors
* @patrick-kidger-bot 🤖 made their first contribution in #216
**Full Changelog**: v0.1.0...v0.1.11 parent 53edb88 commit 2bf9824
1 file changed
Lines changed: 1 addition & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
55 | 55 | | |
56 | 56 | | |
57 | 57 | | |
58 | | - | |
| 58 | + | |
59 | 59 | | |
60 | 60 | | |
61 | 61 | | |
| |||
0 commit comments