|
13 | 13 | from typing_extensions import override |
14 | 14 |
|
15 | 15 | from array_api_extra import ( |
| 16 | + angle, |
16 | 17 | apply_where, |
17 | 18 | argpartition, |
18 | 19 | at, |
@@ -1906,3 +1907,78 @@ def test_device(self, xp: ModuleType, device: Device): |
1906 | 1907 | a = xp.asarray([-1, 1, 0], device=device) |
1907 | 1908 | b = xp.asarray([2, -2, 0], device=device) |
1908 | 1909 | assert get_device(union1d(a, b)) == device |
| 1910 | + |
| 1911 | + |
| 1912 | +class TestAngle: |
| 1913 | + def test_simple(self, xp: ModuleType): |
| 1914 | + a = xp.asarray([1, 0]) |
| 1915 | + res = angle(a) |
| 1916 | + expected = xp.asarray([0.0, 0.0], dtype=res.dtype) |
| 1917 | + xp_assert_equal(res, expected) |
| 1918 | + |
| 1919 | + def test_basic(self, xp: ModuleType): |
| 1920 | + x = xp.asarray( |
| 1921 | + [ |
| 1922 | + 1 + 3j, |
| 1923 | + np.sqrt(2) / 2.0 + 1j * np.sqrt(2) / 2, |
| 1924 | + 1, |
| 1925 | + 1j, |
| 1926 | + -1, |
| 1927 | + -1j, |
| 1928 | + 1 - 3j, |
| 1929 | + -1 + 3j, |
| 1930 | + ], |
| 1931 | + dtype=xp.complex128, |
| 1932 | + ) |
| 1933 | + expected = xp.asarray( |
| 1934 | + [ |
| 1935 | + np.arctan(3.0 / 1.0), |
| 1936 | + np.arctan(1.0), |
| 1937 | + 0, |
| 1938 | + np.pi / 2, |
| 1939 | + np.pi, |
| 1940 | + -np.pi / 2.0, |
| 1941 | + -np.arctan(3.0 / 1.0), |
| 1942 | + np.pi - np.arctan(3.0 / 1.0), |
| 1943 | + ], |
| 1944 | + dtype=xp.float64, |
| 1945 | + ) |
| 1946 | + xp_assert_close(angle(x, xp=xp), expected, rtol=0, atol=1e-11) |
| 1947 | + xp_assert_close( |
| 1948 | + angle(x, deg=True, xp=xp), |
| 1949 | + expected * 180 / xp.pi, |
| 1950 | + rtol=0, |
| 1951 | + atol=1e-11, |
| 1952 | + ) |
| 1953 | + |
| 1954 | + def test_real(self, xp: ModuleType): |
| 1955 | + x = xp.asarray([0.0, -0.0, 1.0, -1.0]) |
| 1956 | + expected = xp.asarray([0.0, xp.pi, 0.0, xp.pi], dtype=x.dtype) |
| 1957 | + xp_assert_close(angle(x, xp=xp), expected) |
| 1958 | + |
| 1959 | + def test_complex(self, xp: ModuleType): |
| 1960 | + a = xp.asarray([1 + 1j, 1 - 1j, -1 + 1j, -1 - 1j]) |
| 1961 | + expected = xp.asarray([xp.pi / 4, -xp.pi / 4, 3 * xp.pi / 4, -3 * xp.pi / 4]) |
| 1962 | + res = angle(a, xp=xp) |
| 1963 | + xp_assert_equal(res, expected) |
| 1964 | + |
| 1965 | + def test_integral(self, xp: ModuleType): |
| 1966 | + x = xp.asarray([0, -1, 1], dtype=xp.int32) |
| 1967 | + actual = angle(x, xp=xp) |
| 1968 | + expected = xp.asarray( |
| 1969 | + [0.0, xp.pi, 0.0], dtype=default_dtype(xp, device=get_device(x)) |
| 1970 | + ) |
| 1971 | + xp_assert_close(actual, expected) |
| 1972 | + |
| 1973 | + def test_2d(self, xp: ModuleType): |
| 1974 | + a = xp.asarray([[1 + 1j, 1 - 1j], [-1 + 1j, -1 - 1j]]) |
| 1975 | + expected = xp.asarray( |
| 1976 | + [[xp.pi / 4, -xp.pi / 4], [3 * xp.pi / 4, -3 * xp.pi / 4]] |
| 1977 | + ) |
| 1978 | + res = angle(a, xp=xp) |
| 1979 | + xp_assert_equal(res, expected) |
| 1980 | + |
| 1981 | + @pytest.mark.skip_xp_backend(Backend.TORCH, reason="materialize 'meta' device") |
| 1982 | + def test_device(self, xp: ModuleType, device: Device): |
| 1983 | + a = xp.asarray([1 + 1j], device=device) |
| 1984 | + assert get_device(angle(a)) == device |
0 commit comments