-
Notifications
You must be signed in to change notification settings - Fork 15
refactor: Improve PSD typing #522
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 6 commits
Commits
Show all changes
37 commits
Select commit
Hold shift + click to select a range
9fec107
refactor(linalg): Add `PSDQuadraticForm` and `GeneralizedMatrix`.
PierreQuinton acf8e58
Merge branch 'main' into add-generalized-matrix-psd-matrix
PierreQuinton a744fa2
Sort items of `__all__` of `_linalg.__init__`
PierreQuinton 23de54d
one line
PierreQuinton 2bd603e
fix `is_psd_quadratic_form`
PierreQuinton d6f8375
remove outdated comment
PierreQuinton 24d24bb
Add `assert_psd_quadratic_form` and TODOs for where to test it. I als…
PierreQuinton 242cb55
fix is_psd_quadratic_form
PierreQuinton 72a9a5f
Rename `PSDQuadraticForm` to `PSDGeneralizedMatrix`
PierreQuinton 09df593
fix type of weighting in Flattening
PierreQuinton 0a1d45c
Add parametrization of zero matrix for test_gramian_is_psd
PierreQuinton 6f63182
Add test of the PSD property for functions in aggregation/_utils/gramian
PierreQuinton 5a42ecd
rename test of equivariance accordingly
PierreQuinton 0497f3a
Rename functions in `autogram/_gramian_utils` so that they don't incl…
PierreQuinton 48df0a8
Test the PSD property on outputs of functions in `autogram/_gramian_u…
PierreQuinton 40977f3
Remove internal checks of shapes of matrices
PierreQuinton 92b975b
Remove uninformative shadowing of assertion error in assert_psd_*
PierreQuinton bda0a5f
Factorize `compute_gramian` from `forward_backward` by making the one…
PierreQuinton 97bcf42
Revert "Factorize `compute_gramian` from `forward_backward` by making…
PierreQuinton 03aebae
Generalizes `compute_gramian` to take a `GeneralizedMatrix` instead.
PierreQuinton ee54c09
Move `aggregation/_utils/gramian.py` to `_linalg/gramian.py`
PierreQuinton 2b94d78
Merge branch 'main' into add-generalized-matrix-psd-matrix
ValerianRey e347075
Apply suggestions from code review
PierreQuinton 3d9742c
Remove outdated comments
PierreQuinton f2d0d1b
Improve style
PierreQuinton 5eafa74
Improve typing of `forward_backward.compute_gramian`
PierreQuinton d60e9fa
improve asserts
PierreQuinton f4d611b
Merge branch 'main' into add-generalized-matrix-psd-matrix
ValerianRey 57af9f1
Merge branch 'main' into add-generalized-matrix-psd-matrix
ValerianRey a793693
Can parametrize number of dimensions to contract in `compute_gramian`
PierreQuinton 7da352c
Remove GeneralizedMatrix
ValerianRey 994932b
Rename PSDGeneralizedMatrix to PSDTensor
ValerianRey ab809c6
Add comment about using classes
ValerianRey 47bf743
Remove useless overload of compute_gramian
ValerianRey 55bc6f8
Rename matrix to t in compute_gramian
ValerianRey a80f3f6
Add overload for compute_gramian when t is matrix and contracted_dims…
ValerianRey 09393cc
Stop expecting coverage for overload functions
ValerianRey File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,23 @@ | ||
| from .gramian import compute_gramian | ||
| from .matrix import Matrix, PSDMatrix | ||
| from ._gramian import compute_gramian | ||
| from ._matrix import ( | ||
| GeneralizedMatrix, | ||
| Matrix, | ||
| PSDMatrix, | ||
| PSDQuadraticForm, | ||
| is_generalized_matrix, | ||
| is_matrix, | ||
| is_psd_matrix, | ||
| is_psd_quadratic_form, | ||
| ) | ||
|
|
||
| __all__ = ["compute_gramian", "Matrix", "PSDMatrix"] | ||
| __all__ = [ | ||
| "compute_gramian", | ||
| "GeneralizedMatrix", | ||
| "Matrix", | ||
| "PSDMatrix", | ||
| "PSDQuadraticForm", | ||
| "is_generalized_matrix", | ||
| "is_matrix", | ||
| "is_psd_matrix", | ||
| "is_psd_quadratic_form", | ||
| ] |
6 changes: 4 additions & 2 deletions
6
src/torchjd/_linalg/gramian.py → src/torchjd/_linalg/_gramian.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,11 @@ | ||
| from .matrix import Matrix, PSDMatrix | ||
| from ._matrix import Matrix, PSDMatrix, is_psd_matrix | ||
|
|
||
|
|
||
| def compute_gramian(matrix: Matrix) -> PSDMatrix: | ||
| """ | ||
| Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given matrix. | ||
| """ | ||
|
|
||
| return matrix @ matrix.T | ||
| gramian = matrix @ matrix.T | ||
| assert is_psd_matrix(gramian) | ||
| return gramian |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| from typing import TypeGuard | ||
|
|
||
| from torch import Tensor | ||
|
|
||
|
|
||
| class GeneralizedMatrix(Tensor): | ||
| pass | ||
|
|
||
|
|
||
| class Matrix(GeneralizedMatrix): | ||
| pass | ||
|
PierreQuinton marked this conversation as resolved.
Outdated
|
||
|
|
||
|
|
||
| class PSDQuadraticForm(Tensor): | ||
|
PierreQuinton marked this conversation as resolved.
Outdated
|
||
| pass | ||
|
PierreQuinton marked this conversation as resolved.
Outdated
|
||
|
|
||
|
|
||
| class PSDMatrix(PSDQuadraticForm, Matrix): | ||
| pass | ||
|
PierreQuinton marked this conversation as resolved.
Outdated
|
||
|
|
||
|
|
||
| def is_generalized_matrix(t: Tensor) -> TypeGuard[GeneralizedMatrix]: | ||
| return t.ndim >= 1 | ||
|
|
||
|
|
||
| def is_matrix(t: Tensor) -> TypeGuard[Matrix]: | ||
| return t.ndim == 2 | ||
|
|
||
|
|
||
| def is_psd_quadratic_form(t: Tensor) -> TypeGuard[PSDQuadraticForm]: | ||
| half_dim = t.ndim // 2 | ||
| return not t.ndim % 2 != 0 and t.shape[:half_dim] == t.shape[: half_dim - 1 : -1] | ||
| # We do not check that t is PSD as it is expensive, but this must be checked in the tests of | ||
| # every function that use this TypeGuard. | ||
| # TODO: Say with what assert we check that | ||
|
|
||
|
|
||
| def is_psd_matrix(t: Tensor) -> TypeGuard[PSDMatrix]: | ||
| return t.ndim == 2 and t.shape[0] == t.shape[1] | ||
| # We do not check that t is PSD as it is expensive, but this must be checked in the tests of | ||
| # every function that use this TypeGuard. | ||
| # TODO: Say with what assert we check that | ||
|
ValerianRey marked this conversation as resolved.
Outdated
|
||
This file was deleted.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.