|
1 | | -import numpy |
| 1 | +import torch |
2 | 2 |
|
3 | 3 |
|
4 | | -def linear_srgb_from_srgb(srgb: numpy.ndarray) -> numpy.ndarray: |
5 | | - return numpy.where(srgb <= 0.0404482362771082, srgb / 12.92, ((srgb + 0.055) / 1.055) ** 2.4) |
| 4 | +def srgb_from_linear_srgb(linear_srgb_tensor: torch.Tensor) -> torch.Tensor: |
| 5 | + """Convert a 3xHxW linear-light sRGB tensor in [0, 1] to gamma-corrected sRGB.""" |
6 | 6 |
|
| 7 | + linear_srgb_tensor = linear_srgb_tensor.clamp(0.0, 1.0) |
| 8 | + return torch.where( |
| 9 | + linear_srgb_tensor <= 0.0031308, |
| 10 | + linear_srgb_tensor * 12.92, |
| 11 | + 1.055 * torch.pow(linear_srgb_tensor, 1.0 / 2.4) - 0.055, |
| 12 | + ) |
| 13 | + |
| 14 | + |
| 15 | +def linear_srgb_from_srgb(srgb_tensor: torch.Tensor) -> torch.Tensor: |
| 16 | + """Convert a 3xHxW gamma-corrected sRGB tensor in [0, 1] to linear-light sRGB.""" |
7 | 17 |
|
8 | | -def srgb_from_linear_srgb(linear_srgb: numpy.ndarray) -> numpy.ndarray: |
9 | | - linear_srgb = numpy.clip(linear_srgb, 0.0, 1.0) |
10 | | - return numpy.where( |
11 | | - linear_srgb <= 0.0031308, |
12 | | - linear_srgb * 12.92, |
13 | | - 1.055 * numpy.power(linear_srgb, 1.0 / 2.4) - 0.055, |
| 18 | + return torch.where( |
| 19 | + srgb_tensor <= 0.0404482362771082, |
| 20 | + srgb_tensor / 12.92, |
| 21 | + torch.pow((srgb_tensor + 0.055) / 1.055, 2.4), |
14 | 22 | ) |
15 | 23 |
|
16 | 24 |
|
17 | | -def oklab_from_linear_srgb(linear_srgb: numpy.ndarray) -> numpy.ndarray: |
18 | | - lms_l = 0.4122214708 * linear_srgb[..., 0] + 0.5363325363 * linear_srgb[..., 1] + 0.0514459929 * linear_srgb[..., 2] |
19 | | - lms_m = 0.2119034982 * linear_srgb[..., 0] + 0.6806995451 * linear_srgb[..., 1] + 0.1073969566 * linear_srgb[..., 2] |
20 | | - lms_s = 0.0883024619 * linear_srgb[..., 0] + 0.2817188376 * linear_srgb[..., 1] + 0.6299787005 * linear_srgb[..., 2] |
| 25 | +def oklab_from_linear_srgb(linear_srgb_tensor: torch.Tensor) -> torch.Tensor: |
| 26 | + """Convert a 3xHxW linear-light sRGB tensor to Oklab.""" |
21 | 27 |
|
22 | | - lms_l_cbrt = numpy.cbrt(lms_l) |
23 | | - lms_m_cbrt = numpy.cbrt(lms_m) |
24 | | - lms_s_cbrt = numpy.cbrt(lms_s) |
| 28 | + lms_l = ( |
| 29 | + 0.4122214708 * linear_srgb_tensor[0, ...] |
| 30 | + + 0.5363325363 * linear_srgb_tensor[1, ...] |
| 31 | + + 0.0514459929 * linear_srgb_tensor[2, ...] |
| 32 | + ) |
| 33 | + lms_m = ( |
| 34 | + 0.2119034982 * linear_srgb_tensor[0, ...] |
| 35 | + + 0.6806995451 * linear_srgb_tensor[1, ...] |
| 36 | + + 0.1073969566 * linear_srgb_tensor[2, ...] |
| 37 | + ) |
| 38 | + lms_s = ( |
| 39 | + 0.0883024619 * linear_srgb_tensor[0, ...] |
| 40 | + + 0.2817188376 * linear_srgb_tensor[1, ...] |
| 41 | + + 0.6299787005 * linear_srgb_tensor[2, ...] |
| 42 | + ) |
25 | 43 |
|
26 | | - return numpy.stack( |
| 44 | + lms_l_cbrt = torch.sign(lms_l) * torch.pow(torch.abs(lms_l), 1.0 / 3.0) |
| 45 | + lms_m_cbrt = torch.sign(lms_m) * torch.pow(torch.abs(lms_m), 1.0 / 3.0) |
| 46 | + lms_s_cbrt = torch.sign(lms_s) * torch.pow(torch.abs(lms_s), 1.0 / 3.0) |
| 47 | + |
| 48 | + return torch.stack( |
27 | 49 | [ |
28 | 50 | 0.2104542553 * lms_l_cbrt + 0.7936177850 * lms_m_cbrt - 0.0040720468 * lms_s_cbrt, |
29 | 51 | 1.9779984951 * lms_l_cbrt - 2.4285922050 * lms_m_cbrt + 0.4505937099 * lms_s_cbrt, |
30 | 52 | 0.0259040371 * lms_l_cbrt + 0.7827717662 * lms_m_cbrt - 0.8086757660 * lms_s_cbrt, |
31 | | - ], |
32 | | - axis=-1, |
| 53 | + ] |
33 | 54 | ) |
34 | 55 |
|
35 | 56 |
|
36 | | -def linear_srgb_from_oklab(oklab: numpy.ndarray) -> numpy.ndarray: |
37 | | - lms_l_cbrt = oklab[..., 0] + 0.3963377774 * oklab[..., 1] + 0.2158037573 * oklab[..., 2] |
38 | | - lms_m_cbrt = oklab[..., 0] - 0.1055613458 * oklab[..., 1] - 0.0638541728 * oklab[..., 2] |
39 | | - lms_s_cbrt = oklab[..., 0] - 0.0894841775 * oklab[..., 1] - 1.2914855480 * oklab[..., 2] |
| 57 | +def linear_srgb_from_oklab(oklab_tensor: torch.Tensor) -> torch.Tensor: |
| 58 | + """Convert a 3xHxW Oklab tensor to linear-light sRGB.""" |
| 59 | + |
| 60 | + lms_l_cbrt = oklab_tensor[0, ...] + 0.3963377774 * oklab_tensor[1, ...] + 0.2158037573 * oklab_tensor[2, ...] |
| 61 | + lms_m_cbrt = oklab_tensor[0, ...] - 0.1055613458 * oklab_tensor[1, ...] - 0.0638541728 * oklab_tensor[2, ...] |
| 62 | + lms_s_cbrt = oklab_tensor[0, ...] - 0.0894841775 * oklab_tensor[1, ...] - 1.2914855480 * oklab_tensor[2, ...] |
40 | 63 |
|
41 | 64 | lms_l = lms_l_cbrt**3 |
42 | 65 | lms_m = lms_m_cbrt**3 |
43 | 66 | lms_s = lms_s_cbrt**3 |
44 | 67 |
|
45 | | - return numpy.stack( |
| 68 | + return torch.stack( |
46 | 69 | [ |
47 | 70 | 4.0767416621 * lms_l - 3.3077115913 * lms_m + 0.2309699292 * lms_s, |
48 | 71 | -1.2684380046 * lms_l + 2.6097574011 * lms_m - 0.3413193965 * lms_s, |
49 | 72 | -0.0041960863 * lms_l - 0.7034186147 * lms_m + 1.7076147010 * lms_s, |
50 | | - ], |
51 | | - axis=-1, |
| 73 | + ] |
52 | 74 | ) |
53 | 75 |
|
54 | 76 |
|
55 | | -def oklch_from_oklab(oklab: numpy.ndarray) -> numpy.ndarray: |
56 | | - lightness = oklab[..., 0] |
57 | | - chroma = numpy.sqrt(oklab[..., 1] ** 2 + oklab[..., 2] ** 2) |
58 | | - hue = numpy.degrees(numpy.arctan2(oklab[..., 2], oklab[..., 1])) % 360.0 |
59 | | - return numpy.stack([lightness, chroma, hue], axis=-1) |
| 77 | +def oklch_from_oklab(oklab_tensor: torch.Tensor) -> torch.Tensor: |
| 78 | + """Convert a 3xHxW Oklab tensor to Oklch, with hue in degrees.""" |
| 79 | + |
| 80 | + lightness = oklab_tensor[0, ...] |
| 81 | + chroma = torch.sqrt(oklab_tensor[1, ...] ** 2 + oklab_tensor[2, ...] ** 2) |
| 82 | + hue = torch.remainder(torch.rad2deg(torch.atan2(oklab_tensor[2, ...], oklab_tensor[1, ...])), 360.0) |
| 83 | + return torch.stack([lightness, chroma, hue]) |
| 84 | + |
| 85 | + |
| 86 | +def oklab_from_oklch(oklch_tensor: torch.Tensor) -> torch.Tensor: |
| 87 | + """Convert a 3xHxW Oklch tensor, with hue in degrees, to Oklab.""" |
60 | 88 |
|
| 89 | + hue_radians = torch.deg2rad(oklch_tensor[2, ...]) |
| 90 | + a_channel = oklch_tensor[1, ...] * torch.cos(hue_radians) |
| 91 | + b_channel = oklch_tensor[1, ...] * torch.sin(hue_radians) |
| 92 | + return torch.stack([oklch_tensor[0, ...], a_channel, b_channel]) |
61 | 93 |
|
62 | | -def oklab_from_oklch(oklch: numpy.ndarray) -> numpy.ndarray: |
63 | | - hue_radians = numpy.radians(oklch[..., 2]) |
64 | | - a_channel = oklch[..., 1] * numpy.cos(hue_radians) |
65 | | - b_channel = oklch[..., 1] * numpy.sin(hue_radians) |
66 | | - return numpy.stack([oklch[..., 0], a_channel, b_channel], axis=-1) |
67 | 94 |
|
| 95 | +def linear_srgb_from_oklch(oklch_tensor: torch.Tensor) -> torch.Tensor: |
| 96 | + """Convert a 3xHxW Oklch tensor directly to linear-light sRGB.""" |
68 | 97 |
|
69 | | -def linear_srgb_from_oklch(oklch: numpy.ndarray) -> numpy.ndarray: |
70 | | - return linear_srgb_from_oklab(oklab_from_oklch(oklch)) |
| 98 | + return linear_srgb_from_oklab(oklab_from_oklch(oklch_tensor)) |
0 commit comments