Skip to content

Commit 2f83d75

Browse files
committed
refactor: unify shared color conversions
1 parent 8f614bf commit 2f83d75

File tree

7 files changed

+1934
-964
lines changed

7 files changed

+1934
-964
lines changed

invokeai/app/invocations/composition-nodes.py

Lines changed: 54 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,27 @@
1414
from torchvision.transforms.functional import to_pil_image as pil_image_from_tensor
1515

1616
from invokeai.app.invocations.primitives import ImageOutput
17-
from invokeai.backend.image_util.composition import (
18-
CIELAB_TO_UPLAB_ICC_PATH,
19-
MAX_FLOAT,
20-
equivalent_achromatic_lightness,
21-
gamut_clip_tensor,
17+
from invokeai.backend.image_util.color_conversion import (
2218
hsl_from_srgb,
2319
linear_srgb_from_oklab,
20+
linear_srgb_from_oklch,
2421
linear_srgb_from_srgb,
2522
okhsl_from_srgb,
2623
okhsv_from_srgb,
2724
oklab_from_linear_srgb,
28-
remove_nans,
25+
oklab_from_oklch,
26+
oklch_from_oklab,
2927
srgb_from_hsl,
30-
srgb_from_linear_srgb,
3128
srgb_from_okhsl,
3229
srgb_from_okhsv,
30+
)
31+
from invokeai.backend.image_util.composition import (
32+
CIELAB_TO_UPLAB_ICC_PATH,
33+
MAX_FLOAT,
34+
equivalent_achromatic_lightness,
35+
gamut_clip_tensor,
36+
remove_nans,
37+
srgb_from_linear_srgb,
3338
tensor_from_pil_image,
3439
)
3540
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@@ -136,20 +141,20 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
136141

137142
if space == "hsv":
138143
hsv_tensor = image_resized_to_grid_as_tensor(image_in.convert("HSV"), normalize=False, multiple_of=1)
139-
hsv_tensor[0, :, :] = torch.remainder(torch.add(hsv_tensor[0, :, :], torch.div(self.degrees, 360.0)), 1.0)
144+
hsv_tensor[0, :, :] = torch.remainder(torch.add(hsv_tensor[0, :, :] * 360.0, self.degrees), 360.0) / 360.0
140145
image_out = pil_image_from_tensor(hsv_tensor, mode="HSV").convert("RGB")
141146

142147
elif space == "okhsl":
143148
rgb_tensor = image_resized_to_grid_as_tensor(image_in.convert("RGB"), normalize=False, multiple_of=1)
144149
hsl_tensor = okhsl_from_srgb(rgb_tensor, steps=(3 if self.ok_high_precision else 1))
145-
hsl_tensor[0, :, :] = torch.remainder(torch.add(hsl_tensor[0, :, :], torch.div(self.degrees, 360.0)), 1.0)
150+
hsl_tensor[0, :, :] = torch.remainder(torch.add(hsl_tensor[0, :, :], self.degrees), 360.0)
146151
rgb_tensor = srgb_from_okhsl(hsl_tensor, alpha=0.0)
147152
image_out = pil_image_from_tensor(rgb_tensor, mode="RGB")
148153

149154
elif space == "okhsv":
150155
rgb_tensor = image_resized_to_grid_as_tensor(image_in.convert("RGB"), normalize=False, multiple_of=1)
151156
hsv_tensor = okhsv_from_srgb(rgb_tensor, steps=(3 if self.ok_high_precision else 1))
152-
hsv_tensor[0, :, :] = torch.remainder(torch.add(hsv_tensor[0, :, :], torch.div(self.degrees, 360.0)), 1.0)
157+
hsv_tensor[0, :, :] = torch.remainder(torch.add(hsv_tensor[0, :, :], self.degrees), 360.0)
153158
rgb_tensor = srgb_from_okhsv(hsv_tensor, alpha=0.0)
154159
image_out = pil_image_from_tensor(rgb_tensor, mode="RGB")
155160

@@ -197,24 +202,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
197202
rgb_tensor = image_resized_to_grid_as_tensor(image_in.convert("RGB"), normalize=False, multiple_of=1)
198203

199204
linear_srgb_tensor = linear_srgb_from_srgb(rgb_tensor)
200-
201-
lab_tensor = oklab_from_linear_srgb(linear_srgb_tensor)
202-
203-
# L*a*b* to L*C*h
204-
c_tensor = torch.sqrt(torch.add(torch.pow(lab_tensor[1, :, :], 2.0), torch.pow(lab_tensor[2, :, :], 2.0)))
205-
h_tensor = torch.atan2(lab_tensor[2, :, :], lab_tensor[1, :, :])
206-
207-
# Rotate h
208-
rot_rads = (self.degrees / 180.0) * PI
209-
210-
h_rot = torch.add(h_tensor, rot_rads)
211-
h_rot = torch.remainder(torch.add(h_rot, 2 * PI), 2 * PI)
212-
213-
# L*C*h to L*a*b*
214-
lab_tensor[1, :, :] = torch.mul(c_tensor, torch.cos(h_rot))
215-
lab_tensor[2, :, :] = torch.mul(c_tensor, torch.sin(h_rot))
216-
217-
linear_srgb_tensor = linear_srgb_from_oklab(lab_tensor)
205+
oklch_tensor = oklch_from_oklab(oklab_from_linear_srgb(linear_srgb_tensor))
206+
oklch_tensor[2, :, :] = torch.remainder(torch.add(oklch_tensor[2, :, :], self.degrees), 360.0)
207+
linear_srgb_tensor = linear_srgb_from_oklch(oklch_tensor)
218208

219209
rgb_tensor = srgb_from_linear_srgb(
220210
linear_srgb_tensor, alpha=self.ok_adaptive_gamut, steps=(3 if self.ok_high_precision else 1)
@@ -602,14 +592,14 @@ def prepare_tensors_from_images(
602592
image_hsv_upper, image_hsv_lower = image_upper.convert("HSV"), image_lower.convert("HSV")
603593
upper_hsv_tensor = torch.stack(
604594
[
605-
tensor_from_pil_image(image_hsv_upper.getchannel("H"), normalize=False)[0, :, :],
595+
tensor_from_pil_image(image_hsv_upper.getchannel("H"), normalize=False)[0, :, :] * 360.0,
606596
tensor_from_pil_image(image_hsv_upper.getchannel("S"), normalize=False)[0, :, :],
607597
tensor_from_pil_image(image_hsv_upper.getchannel("V"), normalize=False)[0, :, :],
608598
]
609599
)
610600
lower_hsv_tensor = torch.stack(
611601
[
612-
tensor_from_pil_image(image_hsv_lower.getchannel("H"), normalize=False)[0, :, :],
602+
tensor_from_pil_image(image_hsv_lower.getchannel("H"), normalize=False)[0, :, :] * 360.0,
613603
tensor_from_pil_image(image_hsv_lower.getchannel("S"), normalize=False)[0, :, :],
614604
tensor_from_pil_image(image_hsv_lower.getchannel("V"), normalize=False)[0, :, :],
615605
]
@@ -655,29 +645,8 @@ def prepare_tensors_from_images(
655645
if "oklch" in required:
656646
upper_oklab_tensor = oklab_from_linear_srgb(upper_rgb_l_tensor)
657647
lower_oklab_tensor = oklab_from_linear_srgb(lower_rgb_l_tensor)
658-
659-
upper_oklch_tensor = torch.stack(
660-
[
661-
upper_oklab_tensor[0, :, :],
662-
torch.sqrt(
663-
torch.add(
664-
torch.pow(upper_oklab_tensor[1, :, :], 2.0), torch.pow(upper_oklab_tensor[2, :, :], 2.0)
665-
)
666-
),
667-
torch.atan2(upper_oklab_tensor[2, :, :], upper_oklab_tensor[1, :, :]),
668-
]
669-
)
670-
lower_oklch_tensor = torch.stack(
671-
[
672-
lower_oklab_tensor[0, :, :],
673-
torch.sqrt(
674-
torch.add(
675-
torch.pow(lower_oklab_tensor[1, :, :], 2.0), torch.pow(lower_oklab_tensor[2, :, :], 2.0)
676-
)
677-
),
678-
torch.atan2(lower_oklab_tensor[2, :, :], lower_oklab_tensor[1, :, :]),
679-
]
680-
)
648+
upper_oklch_tensor = oklch_from_oklab(upper_oklab_tensor)
649+
lower_oklch_tensor = oklch_from_oklab(lower_oklab_tensor)
681650

682651
return (
683652
upper_rgb_l_tensor,
@@ -736,7 +705,17 @@ def adaptive_clipped(rgb_tensor: torch.Tensor, clamp: bool = True, replace_with:
736705
"HSL": lambda t: linear_srgb_from_srgb(srgb_from_hsl(t)),
737706
"HSV": lambda t: linear_srgb_from_srgb(
738707
tensor_from_pil_image(
739-
pil_image_from_tensor(t.clamp(0.0, 1.0), mode="HSV").convert("RGB"), normalize=False
708+
pil_image_from_tensor(
709+
torch.stack(
710+
[
711+
torch.remainder(t[0, :, :], 360.0) / 360.0,
712+
t[1, :, :].clamp(0.0, 1.0),
713+
t[2, :, :].clamp(0.0, 1.0),
714+
]
715+
),
716+
mode="HSV",
717+
).convert("RGB"),
718+
normalize=False,
740719
)
741720
),
742721
"Okhsl": lambda t: linear_srgb_from_srgb(
@@ -745,15 +724,7 @@ def adaptive_clipped(rgb_tensor: torch.Tensor, clamp: bool = True, replace_with:
745724
"Okhsv": lambda t: linear_srgb_from_srgb(
746725
srgb_from_okhsv(t, alpha=self.adaptive_gamut, steps=(3 if self.high_precision else 1))
747726
),
748-
"Oklch": lambda t: linear_srgb_from_oklab(
749-
torch.stack(
750-
[
751-
t[0, :, :],
752-
torch.mul(t[1, :, :], torch.cos(t[2, :, :])),
753-
torch.mul(t[1, :, :], torch.sin(t[2, :, :])),
754-
]
755-
)
756-
),
727+
"Oklch": lambda t: linear_srgb_from_oklab(oklab_from_oklch(t)),
757728
"LCh": lambda t: linear_srgb_from_srgb(
758729
tensor_from_pil_image(
759730
self.image_convert_with_xform(
@@ -784,9 +755,9 @@ def adaptive_clipped(rgb_tensor: torch.Tensor, clamp: bool = True, replace_with:
784755
alpha_upper_tensor,
785756
alpha_lower_tensor,
786757
mask_tensor,
787-
upper_hsv_tensor, # h_rgb, s_hsv, v_hsv
758+
upper_hsv_tensor, # h_hsv_degrees, s_hsv, v_hsv
788759
lower_hsv_tensor,
789-
upper_hsl_tensor, # , s_hsl, l_hsl
760+
upper_hsl_tensor, # h_hsl_degrees, s_hsl, l_hsl
790761
lower_hsl_tensor,
791762
upper_lab_tensor, # l_lab, a_lab, b_lab
792763
lower_lab_tensor,
@@ -796,11 +767,11 @@ def adaptive_clipped(rgb_tensor: torch.Tensor, clamp: bool = True, replace_with:
796767
lower_l_eal_tensor,
797768
upper_oklab_tensor, # l_oklab, a_oklab, b_oklab
798769
lower_oklab_tensor,
799-
upper_oklch_tensor, # , c_oklab, h_oklab
770+
upper_oklch_tensor, # l_oklab, c_oklab, h_oklab_degrees
800771
lower_oklch_tensor,
801-
upper_okhsv_tensor, # h_okhsv, s_okhsv, v_okhsv
772+
upper_okhsv_tensor, # h_okhsv_degrees, s_okhsv, v_okhsv
802773
lower_okhsv_tensor,
803-
upper_okhsl_tensor, # h_okhsl, s_okhsl, l_r_oklab
774+
upper_okhsl_tensor, # h_okhsl_degrees, s_okhsl, l_r_oklab
804775
lower_okhsl_tensor,
805776
) = image_tensors
806777

@@ -850,6 +821,17 @@ def adaptive_clipped(rgb_tensor: torch.Tensor, clamp: bool = True, replace_with:
850821
"LCh": 2,
851822
}[color_space]
852823

824+
hue_period = {
825+
"RGB": None,
826+
"Linear": None,
827+
"HSL": 360.0,
828+
"HSV": 360.0,
829+
"Okhsl": 360.0,
830+
"Okhsv": 360.0,
831+
"Oklch": 360.0,
832+
"LCh": 2.0 * PI,
833+
}[color_space]
834+
853835
if blend_mode == "Normal":
854836
upper_rgb_l_tensor = reassembly_function(upper_space_tensor)
855837

@@ -982,19 +964,19 @@ def adaptive_clipped(rgb_tensor: torch.Tensor, clamp: bool = True, replace_with:
982964
elif blend_mode == "Linear Dodge (Add)":
983965
lower_space_tensor = torch.add(lower_space_tensor, upper_space_tensor)
984966
if hue_index is not None:
985-
lower_space_tensor[hue_index, :, :] = torch.remainder(lower_space_tensor[hue_index, :, :], 1.0)
967+
lower_space_tensor[hue_index, :, :] = torch.remainder(lower_space_tensor[hue_index, :, :], hue_period)
986968
upper_rgb_l_tensor = adaptive_clipped(reassembly_function(lower_space_tensor))
987969

988970
elif blend_mode == "Color Dodge":
989971
lower_space_tensor = torch.div(lower_space_tensor, torch.add(torch.mul(upper_space_tensor, -1.0), 1.0))
990972
if hue_index is not None:
991-
lower_space_tensor[hue_index, :, :] = torch.remainder(lower_space_tensor[hue_index, :, :], 1.0)
973+
lower_space_tensor[hue_index, :, :] = torch.remainder(lower_space_tensor[hue_index, :, :], hue_period)
992974
upper_rgb_l_tensor = adaptive_clipped(reassembly_function(lower_space_tensor))
993975

994976
elif blend_mode == "Divide":
995977
lower_space_tensor = torch.div(lower_space_tensor, upper_space_tensor)
996978
if hue_index is not None:
997-
lower_space_tensor[hue_index, :, :] = torch.remainder(lower_space_tensor[hue_index, :, :], 1.0)
979+
lower_space_tensor[hue_index, :, :] = torch.remainder(lower_space_tensor[hue_index, :, :], hue_period)
998980
upper_rgb_l_tensor = adaptive_clipped(reassembly_function(lower_space_tensor))
999981

1000982
elif blend_mode == "Linear Burn":
@@ -1088,7 +1070,7 @@ def adaptive_clipped(rgb_tensor: torch.Tensor, clamp: bool = True, replace_with:
10881070
elif blend_mode == "Subtract":
10891071
lower_space_tensor = torch.sub(lower_space_tensor, upper_space_tensor)
10901072
if hue_index is not None:
1091-
lower_space_tensor[hue_index, :, :] = torch.remainder(lower_space_tensor[hue_index, :, :], 1.0)
1073+
lower_space_tensor[hue_index, :, :] = torch.remainder(lower_space_tensor[hue_index, :, :], hue_period)
10921074
upper_rgb_l_tensor = adaptive_clipped(reassembly_function(lower_space_tensor))
10931075

10941076
elif blend_mode == "Difference":

invokeai/app/invocations/image.py

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,29 @@
3838
from invokeai.backend.image_util.safety_checker import SafetyChecker
3939

4040

41+
def _extract_alpha_channel(image: Image.Image) -> Image.Image | None:
42+
if image.mode in ("RGBA", "LA", "PA"):
43+
return image.getchannel("A")
44+
return None
45+
46+
47+
def _restore_original_mode(image: Image.Image, mode: str, alpha_channel: Image.Image | None) -> Image.Image:
48+
if alpha_channel is None:
49+
return image.convert(mode)
50+
51+
if mode == "RGBA":
52+
image = image.convert("RGB")
53+
elif mode == "LA":
54+
image = image.convert("L")
55+
elif mode == "PA":
56+
image = image.convert("P")
57+
else:
58+
return image.convert(mode)
59+
60+
image.putalpha(alpha_channel)
61+
return image
62+
63+
4164
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.1")
4265
class ShowImageInvocation(BaseInvocation):
4366
"""Displays a provided image using the OS image viewer, and passes it forward in the pipeline."""
@@ -382,7 +405,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
382405
image = context.images.get_pil(self.image.image_name)
383406
mode = image.mode
384407

385-
alpha_channel = image.getchannel("A") if mode == "RGBA" else None
408+
alpha_channel = _extract_alpha_channel(image)
386409
image = image.convert("RGB")
387410
image_blurred = self.array_from_pil(image.filter(ImageFilter.GaussianBlur(radius=self.radius)))
388411

@@ -431,7 +454,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
431454
image = context.images.get_pil(self.image.image_name)
432455
mode = image.mode
433456

434-
alpha_channel = image.getchannel("A") if mode == "RGBA" else None
457+
alpha_channel = _extract_alpha_channel(image)
435458
image = image.convert("RGB")
436459

437460
image_blurred = self.tensor_from_pil(image.filter(ImageFilter.GaussianBlur(radius=self.radius)))
@@ -443,10 +466,11 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
443466
image_oklab[0, ...] += (image_oklab[0, ...] - image_blurred_oklab[0, ...]) * (self.strength / 100.0)
444467
image_oklab = torch.clamp(image_oklab, -1.0, 1.0)
445468

446-
image = self.pil_from_tensor(srgb_from_linear_srgb(linear_srgb_from_oklab(image_oklab))).convert(mode)
447-
448-
if alpha_channel is not None:
449-
image.putalpha(alpha_channel)
469+
image = _restore_original_mode(
470+
self.pil_from_tensor(srgb_from_linear_srgb(linear_srgb_from_oklab(image_oklab))),
471+
mode,
472+
alpha_channel,
473+
)
450474

451475
image_dto = context.images.save(image=image)
452476
return ImageOutput.build(image_dto)
@@ -854,25 +878,26 @@ class OklchImageHueAdjustmentInvocation(BaseInvocation, WithMetadata, WithBoard)
854878
def invoke(self, context: InvocationContext) -> ImageOutput:
855879
image = context.images.get_pil(self.image.image_name)
856880
mode = image.mode
857-
alpha_channel = image.getchannel("A") if mode == "RGBA" else None
881+
alpha_channel = _extract_alpha_channel(image)
858882

859883
rgb = torch.from_numpy(numpy.asarray(image.convert("RGB"), dtype=numpy.float32) / 255.0).permute(2, 0, 1)
860884
oklch = oklch_from_oklab(oklab_from_linear_srgb(linear_srgb_from_srgb(rgb)))
861885
oklch[2, ...] = (oklch[2, ...] + self.hue) % 360.0
862886

863-
image = Image.fromarray(
864-
(
865-
torch.clamp(srgb_from_linear_srgb(linear_srgb_from_oklch(oklch)), 0.0, 1.0)
866-
.permute(1, 2, 0)
867-
.cpu()
868-
.numpy()
869-
* 255.0
870-
).astype(numpy.uint8),
871-
mode="RGB",
872-
).convert(mode)
873-
874-
if alpha_channel is not None:
875-
image.putalpha(alpha_channel)
887+
image = _restore_original_mode(
888+
Image.fromarray(
889+
(
890+
torch.clamp(srgb_from_linear_srgb(linear_srgb_from_oklch(oklch)), 0.0, 1.0)
891+
.permute(1, 2, 0)
892+
.cpu()
893+
.numpy()
894+
* 255.0
895+
).astype(numpy.uint8),
896+
mode="RGB",
897+
),
898+
mode,
899+
alpha_channel,
900+
)
876901

877902
image_dto = context.images.save(image=image)
878903
return ImageOutput.build(image_dto)

0 commit comments

Comments
 (0)