Skip to content

Commit f85f2f3

Browse files
committed
neglected to reshape sigma
1 parent c29aedf commit f85f2f3

3 files changed

Lines changed: 3 additions & 3 deletions

File tree

pynumdiff/tests/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test_estimate_integration_constant():
2626
(np.ones(5)*10, np.ones(5)*5 + 0.01*np.random.randn(5), 5), # with some noise
2727
(np.array([0]), np.array([1]), -1), # singleton case
2828
(np.vstack([np.arange(5)]*5), np.vstack([np.arange(5) + c for c in range(5)]), -np.arange(5).reshape(-1,1)), # multidimensional case
29-
(np.ones((5,5)), np.vstack([np.arange(5) + c for c in range(5)]), -np.arange(1,6).reshape(-1,1))]:
29+
(np.ones((7,5)), np.vstack([np.arange(5) + c for c in range(7)]), -np.arange(1,8).reshape(-1,1))]: # nonsquare case
3030
x0 = utility.estimate_integration_constant(x, x_hat, axis=-1)
3131
assert np.allclose(x0, c, rtol=1e-3)
3232

pynumdiff/utils/utility.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def estimate_integration_constant(x, x_hat, M=6, axis=0):
5858
:math:`\\mathbf{\\hat{x}}` with :math:`\\mathbf{x}`
5959
"""
6060
s = list(x_hat.shape); s[axis] = 1; s = tuple(s) # proper shape for multidimensional integration constants
61-
sigma = median_abs_deviation(x - x_hat, axis=axis, scale='normal') # M is in units of this robust scatter metric
61+
sigma = median_abs_deviation(x - x_hat, axis=axis, scale='normal').reshape(s) # M is in units of this robust scatter metric
6262
if M == float('inf') or np.all(sigma < 1e-3): # If no scatter, then no outliers, so use L2
6363
return np.mean(x - x_hat, axis=axis).reshape(s) # Solves the l2 distance minimization, argmin_c ||x_hat + c - x||_2^2
6464
elif M < 1e-3: # small M looks like l1 loss, and Huber gets too flat to work well

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ name = "pynumdiff"
77
dynamic = ["version"]
88
description = "pynumdiff: numerical derivatives in python"
99
readme = "README.md"
10-
license = {text = "MIT"}
10+
license = "MIT"
1111
maintainers = [
1212
{name = "Floris van Breugel", email = "fvanbreugel@unr.edu"},
1313
{name = "Pavel Komarov", email = "pvlkmrv@uw.edu"},

0 commit comments

Comments
 (0)