Skip to content

Commit 68a950d

Browse files
committed
Add OKLab and Oklch image utilities and nodes
1 parent dd5758b commit 68a950d

5 files changed

Lines changed: 337 additions & 0 deletions

File tree

invokeai/app/invocations/image.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@
2525
from invokeai.app.services.image_records.image_records_common import ImageCategory
2626
from invokeai.app.services.shared.invocation_context import InvocationContext
2727
from invokeai.app.util.misc import SEED_MAX
28+
from invokeai.backend.image_util.color_conversion import (
29+
linear_srgb_from_oklab,
30+
linear_srgb_from_oklch,
31+
linear_srgb_from_srgb,
32+
oklab_from_linear_srgb,
33+
oklch_from_oklab,
34+
srgb_from_linear_srgb,
35+
)
2836
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
2937
from invokeai.backend.image_util.safety_checker import SafetyChecker
3038

@@ -397,6 +405,51 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
397405
)
398406

399407

408+
@invocation(
    "unsharp_mask_oklab",
    title="Unsharp Mask (Oklab)",
    tags=["image", "unsharp_mask", "oklab"],
    category="image",
    version="1.0.0",
)
class OklabUnsharpMaskInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Applies an unsharp mask filter to an image in the Oklab color space"""

    # Reference to the input image; RGBA alpha is preserved across the filter.
    image: ImageField = InputField(description="The image to use")
    # Gaussian blur radius used to build the low-frequency estimate.
    radius: float = InputField(gt=0, description="Unsharp mask radius", default=2)
    # Amount of high-frequency detail added back, as a percentage (50 -> 0.5x).
    strength: float = InputField(ge=0, description="Unsharp mask strength", default=50)

    def pil_from_array(self, arr: numpy.ndarray) -> Image.Image:
        # Clamp to [0, 1] before quantizing to 8-bit to avoid wraparound.
        return Image.fromarray((numpy.clip(arr, 0.0, 1.0) * 255).astype("uint8"))

    def array_from_pil(self, img: Image.Image) -> numpy.ndarray:
        # Normalize 8-bit channel data to float32 in [0, 1].
        return numpy.array(img, dtype=numpy.float32) / 255.0

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.images.get_pil(self.image.image_name)
        mode = image.mode

        # Detach the alpha channel (if any) so sharpening only touches color.
        alpha_channel = image.getchannel("A") if mode == "RGBA" else None
        image = image.convert("RGB")

        # Low-pass copy; the blur itself runs on the gamma-encoded RGB image.
        image_blurred = self.array_from_pil(image.filter(ImageFilter.GaussianBlur(radius=self.radius)))
        image_arr = self.array_from_pil(image)

        # Move both images into perceptual Oklab before taking the difference.
        image_oklab = oklab_from_linear_srgb(linear_srgb_from_srgb(image_arr))
        image_blurred_oklab = oklab_from_linear_srgb(linear_srgb_from_srgb(image_blurred))

        # Classic unsharp mask: original + (original - blurred) * amount.
        image_oklab += (image_oklab - image_blurred_oklab) * (self.strength / 100.0)
        # Oklab a/b can go negative; keep components in a sane range before converting back.
        image_oklab = numpy.clip(image_oklab, -1.0, 1.0)

        # Back to gamma-encoded sRGB, then restore the caller's original mode.
        image = self.pil_from_array(srgb_from_linear_srgb(linear_srgb_from_oklab(image_oklab))).convert(mode)

        if alpha_channel is not None:
            image.putalpha(alpha_channel)

        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)
451+
452+
400453
PIL_RESAMPLING_MODES = Literal[
401454
"nearest",
402455
"box",
@@ -802,6 +855,40 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
802855
return ImageOutput.build(image_dto)
803856

804857

858+
@invocation(
    "img_hue_adjust_oklch",
    title="Adjust Image Hue (Oklch)",
    tags=["image", "hue", "oklch"],
    category="image",
    version="1.0.0",
)
class OklchImageHueAdjustmentInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Adjusts the hue of an image in Oklch space."""

    # Reference to the input image; RGBA alpha is preserved across the rotation.
    image: ImageField = InputField(description="The image to adjust")
    # Hue rotation in degrees; applied modulo 360, so any integer is accepted.
    hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.images.get_pil(self.image.image_name)
        mode = image.mode
        # Keep the alpha channel aside so only the color planes are transformed.
        alpha_channel = image.getchannel("A") if mode == "RGBA" else None

        # Normalize to float RGB in [0, 1], then move into cylindrical Oklch.
        rgb = numpy.asarray(image.convert("RGB"), dtype=numpy.float32) / 255.0
        oklch = oklch_from_oklab(oklab_from_linear_srgb(linear_srgb_from_srgb(rgb)))
        # Rotate only the hue channel (index 2); lightness and chroma are untouched.
        oklch[..., 2] = (oklch[..., 2] + self.hue) % 360.0

        # Convert back to 8-bit sRGB, clamping any out-of-gamut results.
        image = Image.fromarray(
            numpy.clip(srgb_from_linear_srgb(linear_srgb_from_oklch(oklch)) * 255.0, 0.0, 255.0).astype(numpy.uint8),
            mode="RGB",
        ).convert(mode)

        if alpha_channel is not None:
            image.putalpha(alpha_channel)

        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)
890+
891+
805892
COLOR_CHANNELS = Literal[
806893
"Red (RGBA)",
807894
"Green (RGBA)",
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import numpy
2+
3+
4+
def linear_srgb_from_srgb(srgb: numpy.ndarray) -> numpy.ndarray:
    """Convert gamma-encoded sRGB values in [0, 1] to linear sRGB.

    Applies the piecewise sRGB EOTF: a linear segment near black and a
    2.4-exponent power curve elsewhere.
    """
    linear_segment = srgb / 12.92
    curved_segment = ((srgb + 0.055) / 1.055) ** 2.4
    return numpy.where(srgb <= 0.0404482362771082, linear_segment, curved_segment)
6+
7+
8+
def srgb_from_linear_srgb(linear_srgb: numpy.ndarray) -> numpy.ndarray:
    """Convert linear sRGB values to gamma-encoded sRGB in [0, 1].

    Input is clamped to [0, 1] before the piecewise sRGB OETF is applied.
    """
    clamped = numpy.clip(linear_srgb, 0.0, 1.0)
    below_knee = clamped * 12.92
    above_knee = 1.055 * numpy.power(clamped, 1.0 / 2.4) - 0.055
    return numpy.where(clamped <= 0.0031308, below_knee, above_knee)
15+
16+
17+
def oklab_from_linear_srgb(linear_srgb: numpy.ndarray) -> numpy.ndarray:
    """Convert linear sRGB to Oklab.

    Follows Björn Ottosson's reference formulation: project RGB onto the
    LMS cone responses, take cube roots, then mix into the L/a/b axes.
    The last axis of the input holds the R, G, B channels.
    """
    red, green, blue = (linear_srgb[..., channel] for channel in range(3))

    long_cone = numpy.cbrt(0.4122214708 * red + 0.5363325363 * green + 0.0514459929 * blue)
    medium_cone = numpy.cbrt(0.2119034982 * red + 0.6806995451 * green + 0.1073969566 * blue)
    short_cone = numpy.cbrt(0.0883024619 * red + 0.2817188376 * green + 0.6299787005 * blue)

    lightness = 0.2104542553 * long_cone + 0.7936177850 * medium_cone - 0.0040720468 * short_cone
    a_axis = 1.9779984951 * long_cone - 2.4285922050 * medium_cone + 0.4505937099 * short_cone
    b_axis = 0.0259040371 * long_cone + 0.7827717662 * medium_cone - 0.8086757660 * short_cone

    return numpy.stack([lightness, a_axis, b_axis], axis=-1)
34+
35+
36+
def linear_srgb_from_oklab(oklab: numpy.ndarray) -> numpy.ndarray:
    """Convert Oklab to linear sRGB (inverse of ``oklab_from_linear_srgb``).

    Reconstructs the cube-rooted LMS cone responses from L/a/b, cubes them,
    and projects back to RGB with Björn Ottosson's reference matrix.
    """
    lightness, a_axis, b_axis = (oklab[..., channel] for channel in range(3))

    long_cone = (lightness + 0.3963377774 * a_axis + 0.2158037573 * b_axis) ** 3
    medium_cone = (lightness - 0.1055613458 * a_axis - 0.0638541728 * b_axis) ** 3
    short_cone = (lightness - 0.0894841775 * a_axis - 1.2914855480 * b_axis) ** 3

    red = 4.0767416621 * long_cone - 3.3077115913 * medium_cone + 0.2309699292 * short_cone
    green = -1.2684380046 * long_cone + 2.6097574011 * medium_cone - 0.3413193965 * short_cone
    blue = -0.0041960863 * long_cone - 0.7034186147 * medium_cone + 1.7076147010 * short_cone

    return numpy.stack([red, green, blue], axis=-1)
53+
54+
55+
def oklch_from_oklab(oklab: numpy.ndarray) -> numpy.ndarray:
    """Convert rectangular Oklab to its cylindrical form Oklch.

    Returns [lightness, chroma, hue] stacked on the last axis, with hue
    expressed in degrees and normalized to [0, 360).
    """
    a_axis = oklab[..., 1]
    b_axis = oklab[..., 2]
    chroma = numpy.sqrt(a_axis * a_axis + b_axis * b_axis)
    hue_degrees = numpy.degrees(numpy.arctan2(b_axis, a_axis)) % 360.0
    return numpy.stack([oklab[..., 0], chroma, hue_degrees], axis=-1)
60+
61+
62+
def oklab_from_oklch(oklch: numpy.ndarray) -> numpy.ndarray:
    """Convert cylindrical Oklch (hue in degrees) back to rectangular Oklab."""
    hue_radians = numpy.radians(oklch[..., 2])
    chroma = oklch[..., 1]
    return numpy.stack(
        [oklch[..., 0], chroma * numpy.cos(hue_radians), chroma * numpy.sin(hue_radians)],
        axis=-1,
    )
67+
68+
69+
def linear_srgb_from_oklch(oklch: numpy.ndarray) -> numpy.ndarray:
    """Convert Oklch to linear sRGB by way of rectangular Oklab."""
    oklab = oklab_from_oklch(oklch)
    return linear_srgb_from_oklab(oklab)

invokeai/invocation_api/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,15 @@
8080
from invokeai.app.services.shared.invocation_context import InvocationContext
8181
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
8282
from invokeai.app.util.misc import SEED_MAX, get_random_seed
83+
from invokeai.backend.image_util.color_conversion import (
84+
linear_srgb_from_oklab,
85+
linear_srgb_from_oklch,
86+
linear_srgb_from_srgb,
87+
oklab_from_linear_srgb,
88+
oklab_from_oklch,
89+
oklch_from_oklab,
90+
srgb_from_linear_srgb,
91+
)
8392
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
8493
from invokeai.backend.model_manager.load.load_base import LoadedModel
8594
from invokeai.backend.model_manager.taxonomy import (
@@ -207,6 +216,14 @@
207216
# invokeai.app.util.misc
208217
"SEED_MAX",
209218
"get_random_seed",
219+
# invokeai.backend.image_util.color_conversion
220+
"linear_srgb_from_srgb",
221+
"srgb_from_linear_srgb",
222+
"oklab_from_linear_srgb",
223+
"linear_srgb_from_oklab",
224+
"oklch_from_oklab",
225+
"oklab_from_oklch",
226+
"linear_srgb_from_oklch",
210227
# invokeai.backend.model_manager.taxonomy
211228
"BaseModelType",
212229
"ModelType",
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from types import SimpleNamespace
2+
from unittest.mock import MagicMock
3+
4+
import numpy
5+
from PIL import Image, ImageFilter
6+
7+
from invokeai.app.invocations.image import ImageField, OklabUnsharpMaskInvocation, OklchImageHueAdjustmentInvocation
8+
from invokeai.backend.image_util.color_conversion import (
9+
linear_srgb_from_oklab,
10+
linear_srgb_from_oklch,
11+
linear_srgb_from_srgb,
12+
oklab_from_linear_srgb,
13+
oklch_from_oklab,
14+
srgb_from_linear_srgb,
15+
)
16+
17+
18+
def test_oklab_unsharp_mask_invocation_preserves_alpha_and_sharpens_in_oklab() -> None:
    """The Oklab unsharp-mask node keeps alpha intact and differs from RGB-space sharpening."""
    # 3x1 RGBA image with a distinct alpha per pixel so alpha preservation is observable.
    input_image = Image.new("RGBA", (3, 1))
    input_image.putdata(
        [
            (255, 0, 0, 32),
            (0, 255, 0, 128),
            (0, 0, 255, 224),
        ]
    )

    # Mocked invocation context: serves the input image and echoes back save() metadata.
    context = MagicMock()
    context.images.get_pil.return_value = input_image
    context.images.save.side_effect = lambda image: SimpleNamespace(
        image_name="out", width=image.width, height=image.height
    )

    invocation = OklabUnsharpMaskInvocation(image=ImageField(image_name="in"), radius=1.0, strength=50.0)
    output = invocation.invoke(context)
    # The image actually persisted by the node.
    saved_image = context.images.save.call_args.kwargs["image"]

    assert output.image.image_name == "out"
    assert output.width == 3
    assert output.height == 1
    # Alpha must survive the RGB round trip untouched.
    assert numpy.asarray(saved_image.getchannel("A")).reshape(-1).tolist() == [32, 128, 224]

    rgb = numpy.asarray(input_image.convert("RGB"), dtype=numpy.float32) / 255.0
    blurred_rgb = (
        numpy.asarray(input_image.convert("RGB").filter(ImageFilter.GaussianBlur(radius=1.0)), dtype=numpy.float32)
        / 255.0
    )

    # Reference result of sharpening directly in sRGB (what the node must NOT produce).
    rgb_unsharp = numpy.clip(rgb + (rgb - blurred_rgb) * 0.5, 0.0, 1.0)
    # Reference result of sharpening in Oklab, mirroring the node's own pipeline
    # (strength 50.0 -> factor 0.5, clip to [-1, 1] in Oklab before converting back).
    oklab_unsharp = srgb_from_linear_srgb(
        linear_srgb_from_oklab(
            numpy.clip(
                oklab_from_linear_srgb(linear_srgb_from_srgb(rgb))
                + (
                    oklab_from_linear_srgb(linear_srgb_from_srgb(rgb))
                    - oklab_from_linear_srgb(linear_srgb_from_srgb(blurred_rgb))
                )
                * 0.5,
                -1.0,
                1.0,
            )
        )
    )

    # Sharpening in Oklab must be measurably different from sharpening in sRGB...
    assert not numpy.allclose(oklab_unsharp, rgb_unsharp, atol=1e-3)
    # ...and the node's output must match the Oklab reference within 8-bit quantization error.
    assert numpy.allclose(
        numpy.asarray(saved_image.convert("RGB"), dtype=numpy.float32) / 255.0, oklab_unsharp, atol=1 / 255.0
    )
69+
70+
71+
def test_oklch_hue_adjustment_invocation_preserves_alpha_and_rotates_hue_in_oklch() -> None:
    """The Oklch hue node keeps alpha intact and matches a manual 180-degree hue rotation."""
    # 2x1 RGBA image with distinct alpha values per pixel.
    input_image = Image.new("RGBA", (2, 1))
    input_image.putdata(
        [
            (210, 80, 30, 64),
            (40, 160, 220, 192),
        ]
    )

    # Mocked invocation context: serves the input image and echoes back save() metadata.
    context = MagicMock()
    context.images.get_pil.return_value = input_image
    context.images.save.side_effect = lambda image: SimpleNamespace(
        image_name="out", width=image.width, height=image.height
    )

    invocation = OklchImageHueAdjustmentInvocation(image=ImageField(image_name="in"), hue=180)
    output = invocation.invoke(context)
    # The image actually persisted by the node.
    saved_image = context.images.save.call_args.kwargs["image"]

    # Build the expected image by manually rotating the hue channel 180 degrees in Oklch.
    rgb = numpy.asarray(input_image.convert("RGB"), dtype=numpy.float32) / 255.0
    oklch = oklch_from_oklab(oklab_from_linear_srgb(linear_srgb_from_srgb(rgb)))
    rotated_oklch = oklch.copy()
    rotated_oklch[..., 2] = (rotated_oklch[..., 2] + 180.0) % 360.0
    expected_rgb = srgb_from_linear_srgb(linear_srgb_from_oklch(rotated_oklch))

    assert output.image.image_name == "out"
    assert output.width == 2
    assert output.height == 1
    # Alpha must survive untouched.
    assert numpy.asarray(saved_image.getchannel("A")).reshape(-1).tolist() == [64, 192]
    # Node output must match the manual rotation within 8-bit quantization error.
    assert numpy.allclose(
        numpy.asarray(saved_image.convert("RGB"), dtype=numpy.float32) / 255.0, expected_rgb, atol=1 / 255.0
    )
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import numpy
2+
3+
from invokeai.backend.image_util.color_conversion import (
4+
linear_srgb_from_oklab,
5+
linear_srgb_from_oklch,
6+
linear_srgb_from_srgb,
7+
oklab_from_linear_srgb,
8+
oklab_from_oklch,
9+
oklch_from_oklab,
10+
srgb_from_linear_srgb,
11+
)
12+
13+
14+
def test_srgb_oklab_round_trip() -> None:
    """sRGB -> linear -> Oklab -> linear -> sRGB is the identity (within float error)."""
    samples = numpy.array(
        [
            [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]],
            [[1.0, 0.0, 0.0], [0.1, 0.6, 0.9]],
        ],
        dtype=numpy.float32,
    )

    oklab = oklab_from_linear_srgb(linear_srgb_from_srgb(samples))
    restored = srgb_from_linear_srgb(linear_srgb_from_oklab(oklab))

    assert numpy.allclose(restored, samples, atol=1e-5)
26+
27+
28+
def test_oklab_from_pure_srgb_red_matches_reference_value() -> None:
    """Pure sRGB red converts to Björn Ottosson's published Oklab reference value."""
    pure_red = numpy.array([[[1.0, 0.0, 0.0]]], dtype=numpy.float32)
    expected_oklab = [0.62795536, 0.22486306, 0.1258463]

    result = oklab_from_linear_srgb(linear_srgb_from_srgb(pure_red))

    assert numpy.allclose(result[0, 0], expected_oklab, atol=1e-6)
34+
35+
36+
def test_oklab_oklch_round_trip() -> None:
    """Oklab -> Oklch -> Oklab is the identity (within float error)."""
    samples = numpy.array(
        [[[0.6, 0.2, 0.1], [0.4, -0.1, 0.05]]],
        dtype=numpy.float32,
    )

    restored = oklab_from_oklch(oklch_from_oklab(samples))

    assert numpy.allclose(restored, samples, atol=1e-6)
47+
48+
49+
def test_srgb_oklch_round_trip() -> None:
    """sRGB -> Oklch -> sRGB through the full pipeline is the identity (within float error)."""
    samples = numpy.array(
        [[[0.2, 0.4, 0.8], [0.9, 0.3, 0.1]]],
        dtype=numpy.float32,
    )

    oklch = oklch_from_oklab(oklab_from_linear_srgb(linear_srgb_from_srgb(samples)))
    restored = srgb_from_linear_srgb(linear_srgb_from_oklch(oklch))

    assert numpy.allclose(restored, samples, atol=1e-5)

0 commit comments

Comments
 (0)