Skip to content

Commit 9cb8eda

Browse files
committed
update hilbert2 to match scipy version 1.17
1 parent 6168b98 commit 9cb8eda

3 files changed

Lines changed: 16 additions & 9 deletions

File tree

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ tool-clean:
8585

8686
update: tool
8787
. .venv/bin/activate && python -m pip install --upgrade pip
88-
@for package in $$(./tools/taplo/taplo get -f pyproject.toml project.optional-dependencies.dev); do \
89-
. .venv/bin/activate && python -m pip install --upgrade $$package; \
88+
@./tools/taplo/taplo get -f pyproject.toml project.optional-dependencies.dev | while read -r package; do \
89+
. .venv/bin/activate && python -m pip install --upgrade "$$package"; \
9090
done
9191

9292
clean: dist-clean doc-clean test-clean tool-clean

diffsptk/modules/hilbert2.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,19 @@ def _precompute(
112112
TwoDimensionalHilbertTransform._check(dim)
113113
if isinstance(fft_length, int):
114114
fft_length = (fft_length, fft_length)
115-
_, _, h1 = HilbertTransform._precompute(
116-
fft_length[0], None, device=device, dtype=torch.double
117-
)
118-
_, _, h2 = HilbertTransform._precompute(
119-
fft_length[1], None, device=device, dtype=torch.double
120-
)
121-
h = h1[0].unsqueeze(1) * h2[0].unsqueeze(0)
115+
116+
def get_weights(n: int) -> torch.Tensor:
117+
_, _, h = HilbertTransform._precompute(
118+
n, None, device=device, dtype=torch.double
119+
)
120+
h = h[0]
121+
if 2 <= n:
122+
h[(n + 1) // 2] = 0
123+
return h
124+
125+
h1 = get_weights(fft_length[0])
126+
h2 = get_weights(fft_length[1])
127+
h = h1.unsqueeze(1) * h2.unsqueeze(0)
122128
return (dim,), None, (to(h, dtype=dtype),)
123129

124130
@staticmethod

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ dev = [
4949
"pytest",
5050
"pytest-cov",
5151
"ruff",
52+
"scipy >= 1.17.0",
5253
"sphinx",
5354
"twine",
5455
]

0 commit comments

Comments
 (0)