Skip to content

Commit 0ec328b

Browse files
committed
Adding angle
Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com>
1 parent 5045afa commit 0ec328b

3 files changed

Lines changed: 60 additions & 0 deletions

File tree

src/array_api_extra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
default_dtype,
2525
kron,
2626
nunique,
27+
angle,
2728
)
2829
from ._lib._lazy import lazy_apply
2930

@@ -54,4 +55,5 @@
5455
"setdiff1d",
5556
"sinc",
5657
"union1d",
58+
"angle",
5759
]

src/array_api_extra/_lib/_funcs.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,3 +818,36 @@ def union1d(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
818818
b = xp.reshape(b, (-1,))
819819
# XXX: `sparse` returns NumPy arrays from `unique_values`
820820
return xp.asarray(xp.unique_values(xp.concat([a, b])))
821+
822+
823+
def angle(z: Array, deg: bool = False, /, *, xp: ModuleType | None = None) -> Array:
824+
"""
825+
Return the angle of the complex argument.
826+
827+
Parameters
828+
----------
829+
z : Array
830+
Input array.
831+
deg : bool, optional
832+
Return angle in degrees if True, radians if False (default).
833+
xp : array_namespace, optional
834+
The standard-compatible namespace for `z`. Default: infer.
835+
836+
Returns
837+
-------
838+
angle : ndarray or scalar
839+
The counterclockwise angle from the positive real axis on the complex
840+
plane in the range ``(-pi, pi]``, with dtype as float64.
841+
"""
842+
if xp is None:
843+
xp = array_namespace(z)
844+
if xp.isdtype(z.dtype, "complex floating"):
845+
zimage = xp.imag(z)
846+
zreal = xp.real(z)
847+
else:
848+
zimage = xp.zeros_like(z, dtype=xp.float64)
849+
zreal = xp.astype(z, xp.float64)
850+
a = xp.atan2(zimage, zreal)
851+
if deg:
852+
a = a * 180 / xp.pi
853+
return a

tests/test_funcs.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
setdiff1d,
3434
sinc,
3535
union1d,
36+
angle,
3637
)
3738
from array_api_extra import (
3839
searchsorted as xpx_searchsorted,
@@ -1881,3 +1882,27 @@ def test_device(self, xp: ModuleType, device: Device):
18811882
a = xp.asarray([-1, 1, 0], device=device)
18821883
b = xp.asarray([2, -2, 0], device=device)
18831884
assert get_device(union1d(a, b)) == device
1885+
1886+
class TestAngle:
1887+
def test_simple(self, xp: ModuleType):
1888+
a = xp.asarray([1, 0])
1889+
expected = xp.asarray([0., 0.])
1890+
res = angle(a)
1891+
xp_assert_equal(res, expected)
1892+
1893+
def test_complex(self, xp: ModuleType):
1894+
a = xp.asarray([1 + 1j, 1 - 1j, -1 + 1j, -1 - 1j])
1895+
expected = xp.asarray([np.pi / 4, -np.pi / 4, 3 * np.pi / 4, -3 * np.pi / 4])
1896+
res = angle(a)
1897+
xp_assert_equal(res, expected)
1898+
1899+
def test_2d(self, xp: ModuleType):
1900+
a = xp.asarray([[1 + 1j, 1 - 1j], [-1 + 1j, -1 - 1j]])
1901+
expected = xp.asarray([[np.pi / 4, -np.pi / 4], [3 * np.pi / 4, -3 * np.pi / 4]])
1902+
res = angle(a)
1903+
xp_assert_equal(res, expected)
1904+
1905+
@pytest.mark.skip_xp_backend(Backend.TORCH, reason="materialize 'meta' device")
1906+
def test_device(self, xp: ModuleType, device: Device):
1907+
a = xp.asarray([1 + 1j], device=device)
1908+
assert get_device(angle(a)) == device

0 commit comments

Comments
 (0)