|
1 | | -from typing import Sequence, Literal |
| 1 | +# Standard libraries |
| 2 | +from typing import Literal, Sequence |
| 3 | + |
| 4 | +# Non-standard libraries |
2 | 5 | import jax |
3 | | -import numpy as np |
4 | 6 | import jax.numpy as jnp |
| 7 | +import numpy as np |
5 | 8 |
|
6 | 9 |
|
7 | 10 | class Color: |
8 | | - |
9 | 11 | def __init__(self, values: np.ndarray): |
10 | 12 | self._values = values |
11 | 13 | return |
12 | 14 |
|
13 | 15 | def rgb( |
14 | | - self, |
15 | | - unit: Literal['int', 'percent', 'fraction'] = 'int', |
16 | | - as_str: bool = False, |
17 | | - with_comma: bool = True |
| 16 | + self, |
| 17 | + unit: Literal["int", "percent", "fraction"] = "int", |
| 18 | + as_str: bool = False, |
| 19 | + with_comma: bool = True, |
18 | 20 | ) -> np.ndarray: |
19 | 21 | rgb_colors = self._values |
20 | | - sep = ', ' if with_comma else ' ' |
21 | | - if unit == 'int': |
| 22 | + sep = ", " if with_comma else " " |
| 23 | + if unit == "int": |
22 | 24 | if not as_str: |
23 | 25 | return rgb_colors |
24 | | - return np.array([f'rgb({sep.join(color)})' for color in rgb_colors.astype(str)]) |
| 26 | + return np.array([f"rgb({sep.join(color)})" for color in rgb_colors.astype(str)]) |
25 | 27 | rgb_colors_norm = rgb_colors / 255 |
26 | | - if unit == 'percent': |
| 28 | + if unit == "percent": |
27 | 29 | rgb_colors_norm *= 100 |
28 | 30 | if not as_str: |
29 | 31 | return rgb_colors_norm |
30 | 32 | return np.array( |
31 | 33 | [ |
32 | | - f'rgb({sep.join(color)})' |
33 | | - for color in np.char.mod("%.1f%%" if unit == 'percent' else "%.1f", rgb_colors_norm) |
| 34 | + f"rgb({sep.join(color)})" |
| 35 | + for color in np.char.mod("%.1f%%" if unit == "percent" else "%.1f", rgb_colors_norm) |
34 | 36 | ] |
35 | 37 | ) |
36 | 38 |
|
37 | 39 | def hex( |
38 | | - self, |
39 | | - with_hash: bool = False, |
| 40 | + self, |
| 41 | + with_hash: bool = False, |
40 | 42 | ): |
41 | | - sep = '#' if with_hash else '' |
42 | | - return np.array([f'{sep}{rgb[0]:02X}{rgb[1]:02X}{rgb[2]:02X}' for rgb in self._values]) |
| 43 | + sep = "#" if with_hash else "" |
| 44 | + return np.array([f"{sep}{rgb[0]:02X}{rgb[1]:02X}{rgb[2]:02X}" for rgb in self._values]) |
43 | 45 |
|
44 | 46 | @property |
45 | 47 | def hsl(self): |
46 | 48 | return _rgb_to_hsl(rgb=self._values) |
47 | 49 |
|
48 | 50 |
|
49 | | - |
50 | 51 | def rgb(values: tuple[int, int, int] | Sequence[tuple[int, int, int]]): |
51 | 52 | colors = np.asarray(values) |
52 | 53 | if not np.issubdtype(colors.dtype, np.integer): |
53 | | - raise ValueError(f"`values` must be a sequence of integers, but found elements with type {colors.dtype}") |
| 54 | + raise ValueError( |
| 55 | + f"`values` must be a sequence of integers, but found elements with type {colors.dtype}" |
| 56 | + ) |
54 | 57 | if np.any(np.logical_or(colors < 0, colors > 255)): |
55 | 58 | raise ValueError("`values` must be in the range [0, 255].") |
56 | 59 | colors = colors.astype(np.ubyte) |
57 | 60 | if colors.ndim > 2 or colors.shape[-1] != 3: |
58 | | - raise ValueError(f"`values` must either have a shape of (3, ) or (n, 3). The input shape was {colors.shape}") |
| 61 | + raise ValueError( |
| 62 | + f"`values` must either have a shape of (3, ) or (n, 3). The input shape was {colors.shape}" |
| 63 | + ) |
59 | 64 | colors = colors[np.newaxis] if colors.ndim == 1 else colors |
60 | 65 | return Color(values=colors) |
61 | 66 |
|
62 | 67 |
|
63 | 68 | def hexa(values: str | Sequence[str]) -> Color: |
64 | | - |
65 | 69 | def process_single_hex(val: str) -> tuple[int, int, int]: |
66 | 70 | if len(val) == 3: |
67 | | - val = ''.join([d * 2 for d in val]) |
| 71 | + val = "".join([d * 2 for d in val]) |
68 | 72 | elif len(val) != 6: |
69 | 73 | raise ValueError(f"Hex color '{val}' not recognized.") |
70 | | - return tuple(int(val[i:i + 2], 16) for i in range(0, 5, 2)) |
| 74 | + return tuple(int(val[i : i + 2], 16) for i in range(0, 5, 2)) |
71 | 75 |
|
72 | 76 | colors = np.asarray(values) |
73 | 77 | if not np.issubdtype(colors.dtype, np.str_): |
74 | | - raise ValueError(f"`values` must be a sequence of strings, but found elements with type {colors.dtype}") |
| 78 | + raise ValueError( |
| 79 | + f"`values` must be a sequence of strings, but found elements with type {colors.dtype}" |
| 80 | + ) |
75 | 81 | if colors.ndim == 0: |
76 | 82 | colors = colors[np.newaxis] |
77 | 83 | elif colors.ndim > 1: |
78 | 84 | raise ValueError( |
79 | 85 | f"`values` must either be a string, or a 1-dimensional sequence. The input dimension was {colors.ndim}" |
80 | 86 | ) |
81 | | - colors = np.char.lstrip(colors, '#') |
| 87 | + colors = np.char.lstrip(colors, "#") |
82 | 88 | colors_rgb = np.empty(shape=(colors.size, 3), dtype=np.ubyte) |
83 | 89 | for i, color in enumerate(colors): |
84 | 90 | colors_rgb[i] = process_single_hex(color) |
@@ -112,19 +118,18 @@ def _rgb_to_hsl(rgb: jax.Array): |
112 | 118 | # 'H' and 'S' values are non-zero and must be calculated. |
113 | 119 | # Calculate all 'S' values. |
114 | 120 | # Here conditioning on `minmax_dist == 0` is not necessary, since the numerator is `minmax_dist` |
115 | | - hsl_s = jnp.where( |
116 | | - hsl_l > 0.5, |
117 | | - minmax_dist / (2 - minmax_sum), |
118 | | - minmax_dist / minmax_sum |
119 | | - ) |
| 121 | + hsl_s = jnp.where(hsl_l > 0.5, minmax_dist / (2 - minmax_sum), minmax_dist / minmax_sum) |
120 | 122 | # Calculate all 'H' values. |
121 | | - hsl_h = jnp.where( |
122 | | - minmax_dist == 0, |
123 | | - 0, |
| 123 | + hsl_h = ( |
124 | 124 | jnp.where( |
125 | | - max_val == rs, |
126 | | - jnp.where(gs < bs, 6, 0) + (gs - bs) / minmax_dist, |
127 | | - jnp.where(max_val == gs, (bs - rs) / minmax_dist + 2, (rs - gs) / minmax_dist + 4) |
| 125 | + minmax_dist == 0, |
| 126 | + 0, |
| 127 | + jnp.where( |
| 128 | + max_val == rs, |
| 129 | + jnp.where(gs < bs, 6, 0) + (gs - bs) / minmax_dist, |
| 130 | + jnp.where(max_val == gs, (bs - rs) / minmax_dist + 2, (rs - gs) / minmax_dist + 4), |
| 131 | + ), |
128 | 132 | ) |
129 | | - ) / 6 |
| 133 | + / 6 |
| 134 | + ) |
130 | 135 | return jnp.stack([hsl_h, hsl_s, hsl_l], axis=-1) * 100 |
0 commit comments