Skip to content

Commit 33c7a87

Browse files
committed
refactor: unify oklab color conversions
1 parent e678764 commit 33c7a87

5 files changed

Lines changed: 258 additions & 177 deletions

File tree

invokeai/app/invocations/image.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import cv2
66
import numpy
7+
import torch
78
from PIL import Image, ImageChops, ImageFilter, ImageOps
89

910
from invokeai.app.invocations.baseinvocation import (
@@ -419,11 +420,12 @@ class OklabUnsharpMaskInvocation(BaseInvocation, WithMetadata, WithBoard):
419420
radius: float = InputField(gt=0, description="Unsharp mask radius", default=2)
420421
strength: float = InputField(ge=0, description="Unsharp mask strength", default=50)
421422

def pil_from_tensor(self, tensor: torch.Tensor) -> Image.Image:
    """Convert a 3xHxW float tensor with values in [0, 1] to an 8-bit RGB PIL image."""
    # Clamp first so the *255 scaling cannot overflow uint8, then move channels
    # last (HxWx3) because that is the layout PIL expects.
    hwc_array = tensor.clamp(0.0, 1.0).permute(1, 2, 0).cpu().numpy()
    return Image.fromarray((hwc_array * 255).astype("uint8"))
424426

def tensor_from_pil(self, img: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a 3xHxW float tensor scaled to [0, 1]."""
    # numpy gives HxWxC; permute to the CxHxW layout used by the color helpers.
    scaled = numpy.array(img, dtype=numpy.float32) / 255.0
    return torch.from_numpy(scaled).permute(2, 0, 1)
427429

428430
def invoke(self, context: InvocationContext) -> ImageOutput:
429431
image = context.images.get_pil(self.image.image_name)
@@ -432,16 +434,16 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
432434
alpha_channel = image.getchannel("A") if mode == "RGBA" else None
433435
image = image.convert("RGB")
434436

435-
image_blurred = self.array_from_pil(image.filter(ImageFilter.GaussianBlur(radius=self.radius)))
436-
image_arr = self.array_from_pil(image)
437+
image_blurred = self.tensor_from_pil(image.filter(ImageFilter.GaussianBlur(radius=self.radius)))
438+
image_tensor = self.tensor_from_pil(image)
437439

438-
image_oklab = oklab_from_linear_srgb(linear_srgb_from_srgb(image_arr))
440+
image_oklab = oklab_from_linear_srgb(linear_srgb_from_srgb(image_tensor))
439441
image_blurred_oklab = oklab_from_linear_srgb(linear_srgb_from_srgb(image_blurred))
440442

441-
image_oklab += (image_oklab - image_blurred_oklab) * (self.strength / 100.0)
442-
image_oklab = numpy.clip(image_oklab, -1.0, 1.0)
443+
image_oklab[0, ...] += (image_oklab[0, ...] - image_blurred_oklab[0, ...]) * (self.strength / 100.0)
444+
image_oklab = torch.clamp(image_oklab, -1.0, 1.0)
443445

444-
image = self.pil_from_array(srgb_from_linear_srgb(linear_srgb_from_oklab(image_oklab))).convert(mode)
446+
image = self.pil_from_tensor(srgb_from_linear_srgb(linear_srgb_from_oklab(image_oklab))).convert(mode)
445447

446448
if alpha_channel is not None:
447449
image.putalpha(alpha_channel)
@@ -854,12 +856,18 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
854856
mode = image.mode
855857
alpha_channel = image.getchannel("A") if mode == "RGBA" else None
856858

857-
rgb = numpy.asarray(image.convert("RGB"), dtype=numpy.float32) / 255.0
859+
rgb = torch.from_numpy(numpy.asarray(image.convert("RGB"), dtype=numpy.float32) / 255.0).permute(2, 0, 1)
858860
oklch = oklch_from_oklab(oklab_from_linear_srgb(linear_srgb_from_srgb(rgb)))
859-
oklch[..., 2] = (oklch[..., 2] + self.hue) % 360.0
861+
oklch[2, ...] = (oklch[2, ...] + self.hue) % 360.0
860862

861863
image = Image.fromarray(
862-
numpy.clip(srgb_from_linear_srgb(linear_srgb_from_oklch(oklch)) * 255.0, 0.0, 255.0).astype(numpy.uint8),
864+
(
865+
torch.clamp(srgb_from_linear_srgb(linear_srgb_from_oklch(oklch)), 0.0, 1.0)
866+
.permute(1, 2, 0)
867+
.cpu()
868+
.numpy()
869+
* 255.0
870+
).astype(numpy.uint8),
863871
mode="RGB",
864872
).convert(mode)
865873

Lines changed: 66 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,98 @@
1-
import numpy
1+
import torch
22

33

def srgb_from_linear_srgb(linear_srgb_tensor: torch.Tensor) -> torch.Tensor:
    """Convert a 3xHxW linear-light sRGB tensor in [0, 1] to gamma-corrected sRGB.

    Input is clamped to [0, 1] first, so out-of-range values cannot feed a
    fractional power with a negative base.
    """
    clamped = linear_srgb_tensor.clamp(0.0, 1.0)
    # Piecewise sRGB transfer function: linear segment near black, power curve above.
    linear_segment = clamped * 12.92
    gamma_segment = 1.055 * clamped.pow(1.0 / 2.4) - 0.055
    return torch.where(clamped <= 0.0031308, linear_segment, gamma_segment)
def linear_srgb_from_srgb(srgb_tensor: torch.Tensor) -> torch.Tensor:
    """Convert a 3xHxW gamma-corrected sRGB tensor in [0, 1] to linear-light sRGB."""
    # Inverse of the piecewise sRGB transfer function.
    linear_segment = srgb_tensor / 12.92
    gamma_segment = ((srgb_tensor + 0.055) / 1.055).pow(2.4)
    return torch.where(srgb_tensor <= 0.0404482362771082, linear_segment, gamma_segment)
def oklab_from_linear_srgb(linear_srgb_tensor: torch.Tensor) -> torch.Tensor:
    """Convert a 3xHxW linear-light sRGB tensor to Oklab."""
    red = linear_srgb_tensor[0, ...]
    green = linear_srgb_tensor[1, ...]
    blue = linear_srgb_tensor[2, ...]

    # Linear sRGB -> LMS cone response.
    lms_l = 0.4122214708 * red + 0.5363325363 * green + 0.0514459929 * blue
    lms_m = 0.2119034982 * red + 0.6806995451 * green + 0.1073969566 * blue
    lms_s = 0.0883024619 * red + 0.2817188376 * green + 0.6299787005 * blue

    def _signed_cbrt(component: torch.Tensor) -> torch.Tensor:
        # Cube root that stays defined for negative (out-of-gamut) components.
        return torch.sign(component) * torch.pow(torch.abs(component), 1.0 / 3.0)

    lms_l_cbrt = _signed_cbrt(lms_l)
    lms_m_cbrt = _signed_cbrt(lms_m)
    lms_s_cbrt = _signed_cbrt(lms_s)

    # Non-linear LMS -> Oklab.
    return torch.stack(
        [
            0.2104542553 * lms_l_cbrt + 0.7936177850 * lms_m_cbrt - 0.0040720468 * lms_s_cbrt,
            1.9779984951 * lms_l_cbrt - 2.4285922050 * lms_m_cbrt + 0.4505937099 * lms_s_cbrt,
            0.0259040371 * lms_l_cbrt + 0.7827717662 * lms_m_cbrt - 0.8086757660 * lms_s_cbrt,
        ]
    )
def linear_srgb_from_oklab(oklab_tensor: torch.Tensor) -> torch.Tensor:
    """Convert a 3xHxW Oklab tensor to linear-light sRGB."""
    lightness = oklab_tensor[0, ...]
    a_channel = oklab_tensor[1, ...]
    b_channel = oklab_tensor[2, ...]

    # Oklab -> cube roots of the LMS cone response.
    lms_l_cbrt = lightness + 0.3963377774 * a_channel + 0.2158037573 * b_channel
    lms_m_cbrt = lightness - 0.1055613458 * a_channel - 0.0638541728 * b_channel
    lms_s_cbrt = lightness - 0.0894841775 * a_channel - 1.2914855480 * b_channel

    # Undo the cube root; cubing preserves sign so negatives stay well-defined.
    lms_l = lms_l_cbrt**3
    lms_m = lms_m_cbrt**3
    lms_s = lms_s_cbrt**3

    # LMS -> linear sRGB.
    red = 4.0767416621 * lms_l - 3.3077115913 * lms_m + 0.2309699292 * lms_s
    green = -1.2684380046 * lms_l + 2.6097574011 * lms_m - 0.3413193965 * lms_s
    blue = -0.0041960863 * lms_l - 0.7034186147 * lms_m + 1.7076147010 * lms_s
    return torch.stack([red, green, blue])

def oklch_from_oklab(oklab_tensor: torch.Tensor) -> torch.Tensor:
    """Convert a 3xHxW Oklab tensor to Oklch, with hue in degrees."""
    a_channel = oklab_tensor[1, ...]
    b_channel = oklab_tensor[2, ...]
    # Chroma is the radial distance in the ab-plane; hue is the polar angle,
    # folded into [0, 360) degrees.
    chroma = torch.sqrt(a_channel**2 + b_channel**2)
    hue_degrees = torch.remainder(torch.rad2deg(torch.atan2(b_channel, a_channel)), 360.0)
    return torch.stack([oklab_tensor[0, ...], chroma, hue_degrees])
def oklab_from_oklch(oklch_tensor: torch.Tensor) -> torch.Tensor:
    """Convert a 3xHxW Oklch tensor, with hue in degrees, to Oklab."""
    lightness = oklch_tensor[0, ...]
    chroma = oklch_tensor[1, ...]
    hue_radians = torch.deg2rad(oklch_tensor[2, ...])
    # Polar -> Cartesian in the ab-plane.
    return torch.stack([lightness, chroma * torch.cos(hue_radians), chroma * torch.sin(hue_radians)])
6794

def linear_srgb_from_oklch(oklch_tensor: torch.Tensor) -> torch.Tensor:
    """Convert a 3xHxW Oklch tensor directly to linear-light sRGB."""
    # Compose the two existing conversions: Oklch -> Oklab -> linear sRGB.
    oklab_tensor = oklab_from_oklch(oklch_tensor)
    return linear_srgb_from_oklab(oklab_tensor)

invokeai/backend/image_util/composition.py

Lines changed: 9 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@
1414
import torch
1515
from PIL import Image
1616

17+
from invokeai.backend.image_util.color_conversion import (
18+
linear_srgb_from_oklab,
19+
linear_srgb_from_srgb,
20+
oklab_from_linear_srgb,
21+
)
22+
from invokeai.backend.image_util.color_conversion import (
23+
srgb_from_linear_srgb as shared_srgb_from_linear_srgb,
24+
)
1725
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
1826

1927
MAX_FLOAT = torch.finfo(torch.tensor(1.0).dtype).max
@@ -60,23 +68,7 @@ def srgb_from_linear_srgb(linear_srgb_tensor: torch.Tensor, alpha: float = 0.0,
6068

6169
if 0.0 < alpha:
6270
linear_srgb_tensor = gamut_clip_tensor(linear_srgb_tensor, alpha=alpha, steps=steps)
63-
linear_srgb_tensor = linear_srgb_tensor.clamp(0.0, 1.0)
64-
mask = torch.lt(linear_srgb_tensor, 0.0404482362771082 / 12.92)
65-
rgb_tensor = torch.sub(torch.mul(torch.pow(linear_srgb_tensor, (1 / 2.4)), 1.055), 0.055)
66-
rgb_tensor[mask] = torch.mul(linear_srgb_tensor[mask], 12.92)
67-
68-
return rgb_tensor
69-
70-
71-
def linear_srgb_from_srgb(srgb_tensor: torch.Tensor):
72-
"""Get linear-light sRGB from a standard gamma-corrected sRGB image tensor"""
73-
74-
linear_srgb_tensor = torch.pow(torch.div(torch.add(srgb_tensor, 0.055), 1.055), 2.4)
75-
linear_srgb_tensor_1 = torch.div(srgb_tensor, 12.92)
76-
mask = torch.le(srgb_tensor, 0.0404482362771082)
77-
linear_srgb_tensor[mask] = linear_srgb_tensor_1[mask]
78-
79-
return linear_srgb_tensor
71+
return shared_srgb_from_linear_srgb(linear_srgb_tensor)
8072

8173

8274
def max_srgb_saturation_tensor(units_ab_tensor: torch.Tensor, steps: int = 1):
@@ -175,63 +167,6 @@ def max_srgb_saturation_tensor(units_ab_tensor: torch.Tensor, steps: int = 1):
175167
return s_tensor
176168

177169

178-
def linear_srgb_from_oklab(oklab_tensor: torch.Tensor):
179-
"""Get linear-light sRGB from an Oklab image tensor"""
180-
181-
# L*a*b* to LMS
182-
lms_matrix_1 = torch.tensor(
183-
[[1.0, 0.3963377774, 0.2158037573], [1.0, -0.1055613458, -0.0638541728], [1.0, -0.0894841775, -1.2914855480]]
184-
)
185-
186-
lms_tensor_1 = torch.einsum("lwh, kl -> kwh", oklab_tensor, lms_matrix_1)
187-
lms_tensor = torch.pow(lms_tensor_1, 3.0)
188-
189-
# LMS to linear RGB
190-
rgb_matrix = torch.tensor(
191-
[
192-
[4.0767416621, -3.3077115913, 0.2309699292],
193-
[-1.2684380046, 2.6097574011, -0.3413193965],
194-
[-0.0041960863, -0.7034186147, 1.7076147010],
195-
]
196-
)
197-
198-
linear_srgb_tensor = torch.einsum("kwh, sk -> swh", lms_tensor, rgb_matrix)
199-
200-
return linear_srgb_tensor
201-
202-
203-
def oklab_from_linear_srgb(linear_srgb_tensor: torch.Tensor):
204-
"""Get an Oklab image tensor from a tensor of linear-light sRGB"""
205-
# linear RGB to LMS
206-
lms_matrix = torch.tensor(
207-
[
208-
[0.4122214708, 0.5363325363, 0.0514459929],
209-
[0.2119034982, 0.6806995451, 0.1073969566],
210-
[0.0883024619, 0.2817188376, 0.6299787005],
211-
]
212-
)
213-
214-
lms_tensor = torch.einsum("cwh, kc -> kwh", linear_srgb_tensor, lms_matrix)
215-
216-
# LMS to L*a*b*
217-
lms_tensor_neg_mask = torch.lt(lms_tensor, 0.0)
218-
lms_tensor[lms_tensor_neg_mask] = torch.mul(lms_tensor[lms_tensor_neg_mask], -1.0)
219-
lms_tensor_1 = torch.pow(lms_tensor, 1.0 / 3.0)
220-
lms_tensor[lms_tensor_neg_mask] = torch.mul(lms_tensor[lms_tensor_neg_mask], -1.0)
221-
lms_tensor_1[lms_tensor_neg_mask] = torch.mul(lms_tensor_1[lms_tensor_neg_mask], -1.0)
222-
lab_matrix = torch.tensor(
223-
[
224-
[0.2104542553, 0.7936177850, -0.0040720468],
225-
[1.9779984951, -2.4285922050, 0.4505937099],
226-
[0.0259040371, 0.7827717662, -0.8086757660],
227-
]
228-
)
229-
230-
lab_tensor = torch.einsum("kwh, lk -> lwh", lms_tensor_1, lab_matrix)
231-
232-
return lab_tensor
233-
234-
235170
def find_cusp_tensor(units_ab_tensor: torch.Tensor, steps: int = 1):
236171
"""Compute maximum sRGB lightness and chroma from a tensor of Oklab ab unit vectors"""
237172

0 commit comments

Comments
 (0)