|
1 | 1 | from typing import Any |
2 | 2 |
|
3 | | -import array_api_compat.torch as torch |
| 3 | +import numba |
| 4 | +import numpy as np |
4 | 5 | from array_api_compat import array_namespace |
| 6 | +from cm_time import timer |
| 7 | +from numba import prange |
| 8 | +from numba.extending import overload |
5 | 9 |
|
6 | 10 | from array_api_jit.main import jit |
7 | 11 |
|
8 | 12 |
|
| 13 | +@overload(np.stack) |
| 14 | +def stack(arrays, axis=0): |
| 15 | + def inner(arrays, axis=0): |
| 16 | + if axis == 0: |
| 17 | + shape = (len(arrays),) + arrays[0].shape # noqa: RUF005 |
| 18 | + stacked_array = np.empty(shape, dtype=arrays[0].dtype) |
| 19 | + for j in prange(len(arrays)): |
| 20 | + stacked_array[j] = arrays[j] |
| 21 | + elif axis == -1: |
| 22 | + shape = arrays[0].shape + (len(arrays),) # noqa: RUF005 |
| 23 | + stacked_array = np.empty(shape, dtype=arrays[0].dtype) |
| 24 | + for j in prange(len(arrays)): |
| 25 | + stacked_array[..., j] = arrays[j] |
| 26 | + return stacked_array |
| 27 | + |
| 28 | + return inner |
| 29 | + |
| 30 | + |
| 31 | +def legendre(x: Any, n_end: int) -> Any: |
| 32 | + """ |
| 33 | + Legendre polynomial of order 0 to n_end-1. |
| 34 | +
|
| 35 | + Parameters |
| 36 | + ---------- |
| 37 | + x : Any |
| 38 | + The points at which to evaluate the polynomial. |
| 39 | + n_end : int |
| 40 | + The number of orders to compute, starting from 0. |
| 41 | +
|
| 42 | + Returns |
| 43 | + ------- |
| 44 | + Any |
| 45 | + Array of shape (*x.shape, n_end), |
| 46 | + where [..., i] is the value of the i-th order polynomial at x. |
| 47 | +
|
| 48 | + """ |
| 49 | + xp = array_namespace(x) |
| 50 | + prevprev = xp.zeros_like(x) |
| 51 | + prev = xp.ones_like(x) |
| 52 | + result = [prevprev, prev] |
| 53 | + for n in range(2, n_end): |
| 54 | + prevprev, prev = prev, (((2 * n - 1) * x * prev - (n - 1) * prevprev) / n) |
| 55 | + result.append(prev) |
| 56 | + return xp.stack(result[:n_end], axis=-1) |
| 57 | + |
| 58 | + |
| 59 | +def legendre_assign(x: Any, n_end: int) -> Any: |
| 60 | + """ |
| 61 | + Legendre polynomial of order 0 to n_end-1. |
| 62 | +
|
| 63 | + Parameters |
| 64 | + ---------- |
| 65 | + x : Any |
| 66 | + The points at which to evaluate the polynomial. |
| 67 | + n_end : int |
| 68 | + The number of orders to compute, starting from 0. |
| 69 | +
|
| 70 | + Returns |
| 71 | + ------- |
| 72 | + Any |
| 73 | + Array of shape (*x.shape, n_end), |
| 74 | + where [..., i] is the value of the i-th order polynomial at x. |
| 75 | +
|
| 76 | + """ |
| 77 | + xp = array_namespace(x) |
| 78 | + prevprev = xp.zeros_like(x) |
| 79 | + prev = xp.ones_like(x) |
| 80 | + result = xp.empty((*x.shape, n_end), dtype=x.dtype) |
| 81 | + if n_end > 0: |
| 82 | + result[..., 0] = prevprev |
| 83 | + if n_end > 1: |
| 84 | + result[..., 1] = prev |
| 85 | + for n in range(2, n_end): |
| 86 | + prevprev, prev = prev, (((2 * n - 1) * x * prev - (n - 1) * prevprev) / n) |
| 87 | + result[..., n] = prev |
| 88 | + return result |
| 89 | + |
| 90 | + |
| 91 | +legendre_jit = jit( |
| 92 | + {"numpy": numba.jit(nogil=True, parallel=True)}, |
| 93 | + decorator_kwargs={"jax": {"static_argnames": ["n_end"]}}, |
| 94 | +)(legendre) |
| 95 | +legendre_assign_jit = jit({"numpy": numba.jit(nogil=True, parallel=True)})(legendre_assign) |
| 96 | + |
| 97 | + |
9 | 98 | def test_jit(xp: Any) -> None: |
10 | | - @jit() # type: ignore |
11 | | - def spherical_coordinates( |
12 | | - r: torch.Tensor, theta: torch.Tensor, phi: torch.Tensor |
13 | | - ) -> torch.Tensor: |
14 | | - xp = array_namespace(r, theta, phi) |
15 | | - rsin = r * xp.sin(theta) |
16 | | - x = rsin * xp.cos(phi) |
17 | | - y = rsin * xp.sin(phi) |
18 | | - z = r * xp.cos(theta) |
19 | | - return xp.stack((x, y, z), axis=-1) |
20 | | - |
21 | | - r = xp.arange(100) |
22 | | - theta = xp.linspace(0, xp.pi, 100) |
23 | | - phi = xp.linspace(0, 2 * xp.pi, 100) |
24 | | - x = spherical_coordinates(r, theta, phi) # type: ignore |
25 | | - assert x.shape == (100, 3) |
| 99 | + t = {} |
| 100 | + for name, func in [("nojit", legendre), ("jit", legendre_jit)] + ( |
| 101 | + [ |
| 102 | + ("assign-nojit", legendre_assign), |
| 103 | + ("assign-jit", legendre_assign_jit), |
| 104 | + ] |
| 105 | + if "jax" not in xp.__name__ |
| 106 | + else [] |
| 107 | + ): |
| 108 | + print(xp.__name__, name) |
| 109 | + for i in range(20): |
| 110 | + with timer() as timer_: |
| 111 | + x = xp.arange(1000000, dtype=xp.float32) |
| 112 | + p = func(x, 10) # type: ignore |
| 113 | + assert p.shape == (1000000, 10) |
| 114 | + t[name] = timer_.elapsed |
| 115 | + if i == 0: |
| 116 | + print(f"First call: {timer_.elapsed:g}s") |
| 117 | + print(f"Last call: {timer_.elapsed:g}s") |
| 118 | + if "numpy" not in xp.__name__: |
| 119 | + assert t["jit"] < t["nojit"], ( |
| 120 | + f"JIT time {t['jit']} should be less than non-JIT time {t['nojit']}" |
| 121 | + ) |
| 122 | + assert t["assign-jit"] < t["assign-nojit"], ( |
| 123 | + f"JIT assign time {t['assign-jit']} should be " |
| 124 | + f"less than non-JIT assign time {t['assign-nojit']}" |
| 125 | + ) |
0 commit comments