Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Internal

- Fix CI (apply latest `black`, use latest `pytest` and `pytest-benchmark`)
([PR](https://github.com/f-dangel/backpack/pull/348))
- Improve efficiency of Hessian-vector product
([PR](https://github.com/f-dangel/backpack/pull/341))

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ requires-python = ">=3.9"
test = [
"scipy",
"numpy<2",
"pytest>=4.5.0,<5.0.0",
"pytest-benchmark>=3.2.2,<4.0.0",
"pytest",
"pytest-benchmark",
"pytest-optional-tests>=0.1.1",
"pytest-cov",
"coveralls",
Expand Down
2 changes: 1 addition & 1 deletion test/core/derivatives/activation_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"input_fn" (callable): Used for specifying input function

Optional entries:
"target_fn" (callable): Fetches the groundtruth/target classes
"target_fn" (callable): Fetches the groundtruth/target classes
of regression/classification task
"loss_function_fn" (callable): Loss function used in the model
"device" [list(torch.device)]: List of devices to run the test on.
Expand Down
2 changes: 1 addition & 1 deletion test/core/derivatives/pooling_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"input_fn" (callable): Used for specifying input function

Optional entries:
"target_fn" (callable): Fetches the groundtruth/target classes
"target_fn" (callable): Fetches the groundtruth/target classes
of regression/classification task
"loss_function_fn" (callable): Loss function used in the model
"device" [list(torch.device)]: List of devices to run the test on.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Test configurations to test batch_l2_grad

The tests are taken from `test.extensions.firstorder.firstorder_settings`,
The tests are taken from `test.extensions.firstorder.firstorder_settings`,
but additional custom tests can be defined here by appending it to the list.
"""

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Test configurations to test sum_grad_square

The tests are taken from `test.extensions.firstorder.firstorder_settings`,
The tests are taken from `test.extensions.firstorder.firstorder_settings`,
but additional custom tests can be defined here by appending it to the list.
"""

Expand Down
2 changes: 1 addition & 1 deletion test/utils/evaluation_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def initialize_training_false_recursive(module: Module) -> Module:


def initialize_batch_norm_eval(
module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d]
module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d],
) -> Union[BatchNorm1d, BatchNorm2d, BatchNorm3d]:
"""Initializes a BatchNorm module in evaluation mode.

Expand Down
2 changes: 1 addition & 1 deletion test/utils/test_conv_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


Optional entries:
"target_fn" (callable): Fetches the groundtruth/target classes
"target_fn" (callable): Fetches the groundtruth/target classes
of regression/classification task
"loss_function_fn" (callable): Loss function used in the model
"device" [list(torch.device)]: List of devices to run the test on.
Expand Down
2 changes: 1 addition & 1 deletion test/utils/test_conv_transpose_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"input_fn" (callable): Used for specifying input function

Optional entries:
"target_fn" (callable): Fetches the groundtruth/target classes
"target_fn" (callable): Fetches the groundtruth/target classes
of regression/classification task
"loss_function_fn" (callable): Loss function used in the model
"device" [list(torch.device)]: List of devices to run the test on.
Expand Down