|
if Code.ensure_loaded?(Nx) do
  defmodule Image.Color do
    @moduledoc """
    Vectorised colour-space conversions for `Nx` tensors of pixel
    rows.

    Per-pixel conversion via `Color.convert/2` is correct but slow
    when there are tens of thousands of pixels in flight (e.g.
    palette extraction from an image). The helpers in this module
    do the same conversions as whole-tensor operations so the cost
    is `O(matmul + element-wise)` rather than `O(n × Elixir-call)`.

    All functions in this module require [`Nx`](https://hex.pm/packages/nx).
    They are not compiled if `Nx` is not loaded.

    """

    # ---- sRGB -> XYZ ------------------------------------------
    #
    # Bruce Lindbloom's published sRGB -> CIE XYZ matrix relative
    # to D65 (https://www.brucelindbloom.com/), in row-major form.
    @srgb_to_xyz_d65 [
      [0.4124564, 0.3575761, 0.1804375],
      [0.2126729, 0.7151522, 0.0721750],
      [0.0193339, 0.1191920, 0.9503041]
    ]

    # ---- XYZ (D65) -> LMS -------------------------------------
    #
    # Ottosson's M1 matrix (`Color.Conversion.Oklab` mirrors
    # this). Kept here as a literal row-major list; it is turned
    # into an f32 tensor inside `matmul_rows/2`.
    @m1 [
      [0.8189330101, 0.3618667424, -0.1288597137],
      [0.0329845436, 0.9293118715, 0.0361456387],
      [0.0482003018, 0.2643662691, 0.6338517070]
    ]

    # ---- LMS' -> Oklab ----------------------------------------
    #
    # Ottosson's M2 matrix, applied after the cube root.
    @m2 [
      [0.2104542553, 0.7936177850, -0.0040720468],
      [1.9779984951, -2.4285922050, 0.4505937099],
      [0.0259040371, 0.7827717662, -0.8086757660]
    ]

    @doc """
    Converts a tensor of sRGB pixel rows to an Oklab tensor.

    Integer tensors are interpreted as 8-bit sRGB (0–255) and are
    scaled by `1/255`. Float tensors are interpreted as
    gamma-encoded sRGB on the [0, 1] scale; they are **not**
    assumed to be linear sRGB. Either way the values are treated
    as gamma-encoded sRGB, which is what `Image.to_nx/2` returns
    when the source image is in the sRGB colourspace.

    The pipeline is the standard one:

        sRGB → linear-sRGB → XYZ (D65) → LMS → ∛ → LMS' → Oklab

    All of it is expressed as Nx tensor ops so a 90 000-row
    input is one matmul-heavy pass rather than 90 000 Elixir
    function calls. The computation is carried out in `f32`
    regardless of the input type.

    ### Arguments

    * `tensor` is an `Nx.Tensor.t/0` of shape `{n, 3}` (alpha
      bands must be stripped before calling — alpha is a property
      of the *source* image, not of the colour conversion). The
      last axis must have size `3`.

    ### Returns

    * An `Nx.Tensor.t/0` of shape `{n, 3}`, type `f32`, where
      column `0` is `L`, column `1` is `a`, and column `2` is
      `b`.

    ### Raises

    * `ArgumentError` if the last axis of `tensor` does not have
      size `3`, or if the tensor type is neither integer nor
      real floating point.

    ### Examples

        iex> rgb = Nx.tensor([[255, 0, 0], [0, 255, 0], [0, 0, 255]], type: :u8)
        iex> oklab = Image.Color.srgb_tensor_to_oklab(rgb)
        iex> Nx.shape(oklab)
        {3, 3}
        iex> Nx.type(oklab)
        {:f, 32}

    """
    @spec srgb_tensor_to_oklab(Nx.Tensor.t()) :: Nx.Tensor.t()
    def srgb_tensor_to_oklab(tensor) do
      tensor
      |> validate_rgb_shape!()
      |> normalise_to_unit()
      |> srgb_to_linear()
      |> matmul_rows(@srgb_to_xyz_d65)
      |> matmul_rows(@m1)
      |> Nx.cbrt()
      |> matmul_rows(@m2)
    end

    # Rejects inputs whose last axis is not the three RGB channels.
    # Without this check a scalar would be silently broadcast by
    # `Nx.dot/2`, and an RGBA tensor would fail later with a
    # message that does not mention the alpha band.
    defp validate_rgb_shape!(tensor) do
      shape = Nx.shape(tensor)
      rank = tuple_size(shape)

      if rank > 0 and elem(shape, rank - 1) == 3 do
        tensor
      else
        raise ArgumentError,
              "expected a tensor of sRGB pixel rows with shape {n, 3} " <>
                "(strip any alpha band first), got shape: #{inspect(shape)}"
      end
    end

    # Dispatch on tensor type. Integer inputs are scaled by 1/255
    # (dividing by a float yields f32); float inputs are assumed to
    # already live in [0, 1] and are cast to f32 so the result type
    # matches the documented contract and the gamma curve is not
    # evaluated in half precision.
    defp normalise_to_unit(tensor) do
      case Nx.type(tensor) do
        {kind, _} when kind in [:u, :s] ->
          Nx.divide(tensor, 255.0)

        {kind, _} when kind in [:f, :bf] ->
          Nx.as_type(tensor, :f32)

        other ->
          raise ArgumentError,
                "expected an integer or real floating-point tensor, got type: #{inspect(other)}"
      end
    end

    # IEC 61966-2-1 inverse-gamma. Element-wise, branchless via
    # Nx.select so the whole tensor processes in one pass. Both
    # branches are computed for every element; the mask picks the
    # linear segment at or below the threshold and the power
    # segment above it.
    defp srgb_to_linear(tensor) do
      threshold = 0.04045
      lin_lo = Nx.divide(tensor, 12.92)
      lin_hi = Nx.pow(Nx.divide(Nx.add(tensor, 0.055), 1.055), 2.4)

      Nx.select(Nx.less_equal(tensor, threshold), lin_lo, lin_hi)
    end

    # Right-multiply each row of `tensor` (shape `{n, k}`) by the
    # transpose of `matrix` (a literal `k×k` row-major list). The
    # net effect is that each output row is `matrix · row`, the
    # standard convention used in `Color.Conversion.Lindbloom`.
    defp matmul_rows(tensor, matrix) do
      m = Nx.tensor(matrix, type: :f32)
      Nx.dot(tensor, Nx.transpose(m))
    end
  end
end
0 commit comments