Skip to content

Commit b1773d9

Browse files
prady0tlucascolley
andauthored
ENH: add angle (#718)
* Adding angle Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> * test correction + pre commit Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> * adding suggestions Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> * fix test Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> * make deg keyword only and z position only Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> * Update tests/test_funcs.py Co-authored-by: Lucas Colley <lucas.colley8@gmail.com> * Update src/array_api_extra/_lib/_funcs.py Co-authored-by: Lucas Colley <lucas.colley8@gmail.com> * Update tests/test_funcs.py Co-authored-by: Lucas Colley <lucas.colley8@gmail.com> * Update tests/test_funcs.py Co-authored-by: Lucas Colley <lucas.colley8@gmail.com> * adding suggestions Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> * adding suggestions Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> * adding a note for real input Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> * tidy note --------- Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> Co-authored-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> Co-authored-by: Lucas Colley <lucas.colley8@gmail.com>
1 parent 25754cd commit b1773d9

4 files changed

Lines changed: 128 additions & 0 deletions

File tree

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
:nosignatures:
77
:toctree: generated
88
9+
angle
910
apply_where
1011
argpartition
1112
at

src/array_api_extra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
)
2121
from ._lib._at import at
2222
from ._lib._funcs import (
23+
angle,
2324
apply_where,
2425
default_dtype,
2526
kron,
@@ -32,6 +33,7 @@
3233
# pylint: disable=duplicate-code
3334
__all__ = [
3435
"__version__",
36+
"angle",
3537
"apply_where",
3638
"argpartition",
3739
"at",

src/array_api_extra/_lib/_funcs.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from ._utils._typing import Array, Device, DType
2424

2525
__all__ = [
26+
"angle",
2627
"apply_where",
2728
"atleast_nd",
2829
"broadcast_shapes",
@@ -782,3 +783,51 @@ def union1d(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
782783
b = xp.reshape(b, (-1,))
783784
# XXX: `sparse` returns NumPy arrays from `unique_values`
784785
return xp.asarray(xp.unique_values(xp.concat([a, b])))
786+
787+
788+
def angle(z: Array, /, *, deg: bool = False, xp: ModuleType | None = None) -> Array:
789+
"""
790+
Return the angle of the complex argument.
791+
792+
Parameters
793+
----------
794+
z : Array
795+
Input array.
796+
deg : bool, optional
797+
Return angle in degrees if True, radians if False (default).
798+
xp : array_namespace, optional
799+
The standard-compatible namespace for `z`. Default: infer.
800+
801+
Returns
802+
-------
803+
array
804+
The counterclockwise angle from the positive real axis on the complex
805+
plane in the range ``(-pi, pi]``.
806+
807+
Notes
808+
-----
809+
Real input ``x`` is interpreted as ``x + 0j``.
810+
811+
Examples
812+
--------
813+
>>> import array_api_strict as xp
814+
>>> import array_api_extra as xpx
815+
>>> xpx.angle(xp.asarray([1.0, 1.0j, 1 + 1j]), xp=xp)
816+
Array([0. , 1.57079633, 0.78539816], dtype=array_api_strict.float64)
817+
>>> xpx.angle(xp.asarray([1.0, 1.0j, 1 + 1j]), deg=True, xp=xp)
818+
Array([ 0., 90., 45.], dtype=array_api_strict.float64)
819+
"""
820+
if xp is None:
821+
xp = array_namespace(z)
822+
if xp.isdtype(z.dtype, "complex floating"):
823+
zimag = xp.imag(z)
824+
zreal = xp.real(z)
825+
else:
826+
if not xp.isdtype(z.dtype, "real floating"):
827+
z = xp.astype(z, default_dtype(xp, device=_compat.device(z)))
828+
zimag = xp.zeros_like(z)
829+
zreal = z
830+
a = xp.atan2(zimag, zreal)
831+
if deg:
832+
a = a * 180 / xp.pi
833+
return a

tests/test_funcs.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from typing_extensions import override
1414

1515
from array_api_extra import (
16+
angle,
1617
apply_where,
1718
argpartition,
1819
at,
@@ -1906,3 +1907,78 @@ def test_device(self, xp: ModuleType, device: Device):
19061907
a = xp.asarray([-1, 1, 0], device=device)
19071908
b = xp.asarray([2, -2, 0], device=device)
19081909
assert get_device(union1d(a, b)) == device
1910+
1911+
1912+
class TestAngle:
1913+
def test_simple(self, xp: ModuleType):
1914+
a = xp.asarray([1, 0])
1915+
res = angle(a)
1916+
expected = xp.asarray([0.0, 0.0], dtype=res.dtype)
1917+
xp_assert_equal(res, expected)
1918+
1919+
def test_basic(self, xp: ModuleType):
1920+
x = xp.asarray(
1921+
[
1922+
1 + 3j,
1923+
np.sqrt(2) / 2.0 + 1j * np.sqrt(2) / 2,
1924+
1,
1925+
1j,
1926+
-1,
1927+
-1j,
1928+
1 - 3j,
1929+
-1 + 3j,
1930+
],
1931+
dtype=xp.complex128,
1932+
)
1933+
expected = xp.asarray(
1934+
[
1935+
np.arctan(3.0 / 1.0),
1936+
np.arctan(1.0),
1937+
0,
1938+
np.pi / 2,
1939+
np.pi,
1940+
-np.pi / 2.0,
1941+
-np.arctan(3.0 / 1.0),
1942+
np.pi - np.arctan(3.0 / 1.0),
1943+
],
1944+
dtype=xp.float64,
1945+
)
1946+
xp_assert_close(angle(x, xp=xp), expected, rtol=0, atol=1e-11)
1947+
xp_assert_close(
1948+
angle(x, deg=True, xp=xp),
1949+
expected * 180 / xp.pi,
1950+
rtol=0,
1951+
atol=1e-11,
1952+
)
1953+
1954+
def test_real(self, xp: ModuleType):
1955+
x = xp.asarray([0.0, -0.0, 1.0, -1.0])
1956+
expected = xp.asarray([0.0, xp.pi, 0.0, xp.pi], dtype=x.dtype)
1957+
xp_assert_close(angle(x, xp=xp), expected)
1958+
1959+
def test_complex(self, xp: ModuleType):
1960+
a = xp.asarray([1 + 1j, 1 - 1j, -1 + 1j, -1 - 1j])
1961+
expected = xp.asarray([xp.pi / 4, -xp.pi / 4, 3 * xp.pi / 4, -3 * xp.pi / 4])
1962+
res = angle(a, xp=xp)
1963+
xp_assert_equal(res, expected)
1964+
1965+
def test_integral(self, xp: ModuleType):
1966+
x = xp.asarray([0, -1, 1], dtype=xp.int32)
1967+
actual = angle(x, xp=xp)
1968+
expected = xp.asarray(
1969+
[0.0, xp.pi, 0.0], dtype=default_dtype(xp, device=get_device(x))
1970+
)
1971+
xp_assert_close(actual, expected)
1972+
1973+
def test_2d(self, xp: ModuleType):
1974+
a = xp.asarray([[1 + 1j, 1 - 1j], [-1 + 1j, -1 - 1j]])
1975+
expected = xp.asarray(
1976+
[[xp.pi / 4, -xp.pi / 4], [3 * xp.pi / 4, -3 * xp.pi / 4]]
1977+
)
1978+
res = angle(a, xp=xp)
1979+
xp_assert_equal(res, expected)
1980+
1981+
@pytest.mark.skip_xp_backend(Backend.TORCH, reason="materialize 'meta' device")
1982+
def test_device(self, xp: ModuleType, device: Device):
1983+
a = xp.asarray([1 + 1j], device=device)
1984+
assert get_device(angle(a)) == device

0 commit comments

Comments
 (0)