diff --git a/changelog.md b/changelog.md index 55bb6d22..3af3cab0 100644 --- a/changelog.md +++ b/changelog.md @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Internal +- Fix CI (apply latest `black`, use latest `pytest` and `pytest-benchmark`) + ([PR](https://github.com/f-dangel/backpack/pull/348)) - Improve efficiency of Hessian-vector product ([PR](https://github.com/f-dangel/backpack/pull/341)) diff --git a/pyproject.toml b/pyproject.toml index bb920a11..f7215c9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,8 +47,8 @@ requires-python = ">=3.9" test = [ "scipy", "numpy<2", - "pytest>=4.5.0,<5.0.0", - "pytest-benchmark>=3.2.2,<4.0.0", + "pytest", + "pytest-benchmark", "pytest-optional-tests>=0.1.1", "pytest-cov", "coveralls", diff --git a/test/core/derivatives/activation_settings.py b/test/core/derivatives/activation_settings.py index a8dacbad..d3646be2 100644 --- a/test/core/derivatives/activation_settings.py +++ b/test/core/derivatives/activation_settings.py @@ -6,7 +6,7 @@ "input_fn" (callable): Used for specifying input function Optional entries: - "target_fn" (callable): Fetches the groundtruth/target classes + "target_fn" (callable): Fetches the groundtruth/target classes of regression/classification task "loss_function_fn" (callable): Loss function used in the model "device" [list(torch.device)]: List of devices to run the test on. diff --git a/test/core/derivatives/pooling_settings.py b/test/core/derivatives/pooling_settings.py index 0c0627fa..07d7ef66 100644 --- a/test/core/derivatives/pooling_settings.py +++ b/test/core/derivatives/pooling_settings.py @@ -5,7 +5,7 @@ "input_fn" (callable): Used for specifying input function Optional entries: - "target_fn" (callable): Fetches the groundtruth/target classes + "target_fn" (callable): Fetches the groundtruth/target classes of regression/classification task "loss_function_fn" (callable): Loss function used in the model "device" [list(torch.device)]: List of devices to run the test on. diff --git a/test/extensions/firstorder/batch_l2_grad/batchl2grad_settings.py b/test/extensions/firstorder/batch_l2_grad/batchl2grad_settings.py index 16e02930..cc1f7cee 100644 --- a/test/extensions/firstorder/batch_l2_grad/batchl2grad_settings.py +++ b/test/extensions/firstorder/batch_l2_grad/batchl2grad_settings.py @@ -1,6 +1,6 @@ """Test configurations to test batch_l2_grad -The tests are taken from `test.extensions.firstorder.firstorder_settings`, +The tests are taken from `test.extensions.firstorder.firstorder_settings`, but additional custom tests can be defined here by appending it to the list. """ diff --git a/test/extensions/firstorder/sum_grad_squared/sumgradsquared_settings.py b/test/extensions/firstorder/sum_grad_squared/sumgradsquared_settings.py index 3b3d465d..2565135d 100644 --- a/test/extensions/firstorder/sum_grad_squared/sumgradsquared_settings.py +++ b/test/extensions/firstorder/sum_grad_squared/sumgradsquared_settings.py @@ -1,6 +1,6 @@ """Test configurations to test sum_grad_square -The tests are taken from `test.extensions.firstorder.firstorder_settings`, +The tests are taken from `test.extensions.firstorder.firstorder_settings`, but additional custom tests can be defined here by appending it to the list. """ diff --git a/test/utils/evaluation_mode.py b/test/utils/evaluation_mode.py index f41ba2f5..a8b31589 100644 --- a/test/utils/evaluation_mode.py +++ b/test/utils/evaluation_mode.py @@ -24,7 +24,7 @@ def initialize_training_false_recursive(module: Module) -> Module: def initialize_batch_norm_eval( - module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d] + module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], ) -> Union[BatchNorm1d, BatchNorm2d, BatchNorm3d]: """Initializes a BatchNorm module in evaluation mode. diff --git a/test/utils/test_conv_settings.py b/test/utils/test_conv_settings.py index bf073726..33d2392e 100644 --- a/test/utils/test_conv_settings.py +++ b/test/utils/test_conv_settings.py @@ -7,7 +7,7 @@ Optional entries: - "target_fn" (callable): Fetches the groundtruth/target classes + "target_fn" (callable): Fetches the groundtruth/target classes of regression/classification task "loss_function_fn" (callable): Loss function used in the model "device" [list(torch.device)]: List of devices to run the test on. diff --git a/test/utils/test_conv_transpose_settings.py b/test/utils/test_conv_transpose_settings.py index 34986e42..1f69e7be 100644 --- a/test/utils/test_conv_transpose_settings.py +++ b/test/utils/test_conv_transpose_settings.py @@ -4,7 +4,7 @@ "input_fn" (callable): Used for specifying input function Optional entries: - "target_fn" (callable): Fetches the groundtruth/target classes + "target_fn" (callable): Fetches the groundtruth/target classes of regression/classification task "loss_function_fn" (callable): Loss function used in the model "device" [list(torch.device)]: List of devices to run the test on.