Skip to content

Commit 95d6d4d

Browse files
Fix HFFT tests to use complex input tensors
PyTorch's fft.hfft/hfft2/hfftn require complex input tensors. The tests were passing real-valued tensors (Float32/Float64), which causes 'NYI' errors in newer PyTorch versions. Fixed by using complex64/complex128 input types and ihfft for inverse verification. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 5782540 commit 95d6d4d

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

test/TorchSharpTest/TestTorchTensor.cs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6942,25 +6942,25 @@ public void Float64FFT()
69426942
[TestOf(nameof(fft.hfft))]
69436943
public void Float32HFFT()
69446944
{
6945-
var input = torch.arange(4);
6945+
var input = torch.arange(4, complex64);
69466946
var output = fft.hfft(input);
69476947
Assert.Equal(6, output.shape[0]);
69486948
Assert.Equal(ScalarType.Float32, output.dtype);
69496949

6950-
var inverted = fft.ifft(output);
6950+
var inverted = fft.ihfft(output);
69516951
Assert.Equal(ScalarType.ComplexFloat32, inverted.dtype);
69526952
}
69536953

69546954
[Fact]
69556955
[TestOf(nameof(fft.hfft))]
69566956
public void Float64HFFT()
69576957
{
6958-
var input = torch.arange(4, float64);
6958+
var input = torch.arange(4, complex128);
69596959
var output = fft.hfft(input);
69606960
Assert.Equal(6, output.shape[0]);
69616961
Assert.Equal(ScalarType.Float64, output.dtype);
69626962

6963-
var inverted = fft.ifft(output);
6963+
var inverted = fft.ihfft(output);
69646964
Assert.Equal(ScalarType.ComplexFloat64, inverted.dtype);
69656965
}
69666966

@@ -7203,10 +7203,10 @@ public void Float64RFFTN()
72037203
[TestOf(nameof(fft.hfft2))]
72047204
public void Float32HFFT2()
72057205
{
7206-
var input = torch.rand(new long[] { 5, 5, 5, 5 });
7206+
var input = torch.rand(new long[] { 5, 5, 5, 5 }, complex64);
72077207
var output = fft.hfft2(input);
72087208
Assert.Equal(new long[] { 5, 5, 5, 8 }, output.shape);
7209-
Assert.Equal(input.dtype, output.dtype);
7209+
Assert.Equal(ScalarType.Float32, output.dtype);
72107210

72117211
var inverted = fft.ihfft2(output);
72127212
Assert.Equal(new long[] { 5, 5, 5, 5 }, inverted.shape);
@@ -7217,10 +7217,10 @@ public void Float32HFFT2()
72177217
[TestOf(nameof(fft.hfft2))]
72187218
public void Float64HFFT2()
72197219
{
7220-
var input = torch.rand(new long[] { 5, 5, 5, 5 }, float64);
7220+
var input = torch.rand(new long[] { 5, 5, 5, 5 }, complex128);
72217221
var output = fft.hfft2(input);
72227222
Assert.Equal(new long[] { 5, 5, 5, 8 }, output.shape);
7223-
Assert.Equal(input.dtype, output.dtype);
7223+
Assert.Equal(ScalarType.Float64, output.dtype);
72247224

72257225
var inverted = fft.ihfft2(output);
72267226
Assert.Equal(new long[] { 5, 5, 5, 5 }, inverted.shape);
@@ -7231,10 +7231,10 @@ public void Float64HFFT2()
72317231
[TestOf(nameof(fft.hfft2))]
72327232
public void Float32HFFTN()
72337233
{
7234-
var input = torch.rand(new long[] { 5, 5, 5, 5 });
7234+
var input = torch.rand(new long[] { 5, 5, 5, 5 }, complex64);
72357235
var output = fft.hfft2(input);
72367236
Assert.Equal(new long[] { 5, 5, 5, 8 }, output.shape);
7237-
Assert.Equal(input.dtype, output.dtype);
7237+
Assert.Equal(ScalarType.Float32, output.dtype);
72387238

72397239
var inverted = fft.ihfft2(output);
72407240
Assert.Equal(new long[] { 5, 5, 5, 5 }, inverted.shape);
@@ -7249,10 +7249,10 @@ public void Float64HFFTN()
72497249

72507250
// TODO: Something in this test makes if fail on Windows / Release and MacOS / Release
72517251

7252-
var input = torch.rand(new long[] { 5, 5, 5, 5 }, float64);
7252+
var input = torch.rand(new long[] { 5, 5, 5, 5 }, complex128);
72537253
var output = fft.hfftn(input);
72547254
Assert.Equal(new long[] { 5, 5, 5, 8 }, output.shape);
7255-
Assert.Equal(input.dtype, output.dtype);
7255+
Assert.Equal(ScalarType.Float64, output.dtype);
72567256

72577257
var inverted = fft.ihfftn(output);
72587258
Assert.Equal(new long[] { 5, 5, 5, 5 }, inverted.shape);

0 commit comments

Comments
 (0)