1414from torchvision .transforms .functional import to_pil_image as pil_image_from_tensor
1515
1616from invokeai .app .invocations .primitives import ImageOutput
17- from invokeai .backend .image_util .composition import (
18- CIELAB_TO_UPLAB_ICC_PATH ,
19- MAX_FLOAT ,
20- equivalent_achromatic_lightness ,
21- gamut_clip_tensor ,
17+ from invokeai .backend .image_util .color_conversion import (
2218 hsl_from_srgb ,
2319 linear_srgb_from_oklab ,
20+ linear_srgb_from_oklch ,
2421 linear_srgb_from_srgb ,
2522 okhsl_from_srgb ,
2623 okhsv_from_srgb ,
2724 oklab_from_linear_srgb ,
28- remove_nans ,
25+ oklab_from_oklch ,
26+ oklch_from_oklab ,
2927 srgb_from_hsl ,
30- srgb_from_linear_srgb ,
3128 srgb_from_okhsl ,
3229 srgb_from_okhsv ,
30+ )
31+ from invokeai .backend .image_util .composition import (
32+ CIELAB_TO_UPLAB_ICC_PATH ,
33+ MAX_FLOAT ,
34+ equivalent_achromatic_lightness ,
35+ gamut_clip_tensor ,
36+ remove_nans ,
37+ srgb_from_linear_srgb ,
3338 tensor_from_pil_image ,
3439)
3540from invokeai .backend .stable_diffusion .diffusers_pipeline import image_resized_to_grid_as_tensor
@@ -136,20 +141,20 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
136141
137142 if space == "hsv" :
138143 hsv_tensor = image_resized_to_grid_as_tensor (image_in .convert ("HSV" ), normalize = False , multiple_of = 1 )
139- hsv_tensor [0 , :, :] = torch .remainder (torch .add (hsv_tensor [0 , :, :], torch . div ( self .degrees , 360.0 )), 1.0 )
144+ hsv_tensor [0 , :, :] = torch .remainder (torch .add (hsv_tensor [0 , :, :] * 360.0 , self .degrees ) , 360.0 ) / 360.0
140145 image_out = pil_image_from_tensor (hsv_tensor , mode = "HSV" ).convert ("RGB" )
141146
142147 elif space == "okhsl" :
143148 rgb_tensor = image_resized_to_grid_as_tensor (image_in .convert ("RGB" ), normalize = False , multiple_of = 1 )
144149 hsl_tensor = okhsl_from_srgb (rgb_tensor , steps = (3 if self .ok_high_precision else 1 ))
145- hsl_tensor [0 , :, :] = torch .remainder (torch .add (hsl_tensor [0 , :, :], torch . div ( self .degrees , 360.0 )), 1 .0 )
150+ hsl_tensor [0 , :, :] = torch .remainder (torch .add (hsl_tensor [0 , :, :], self .degrees ), 360 .0 )
146151 rgb_tensor = srgb_from_okhsl (hsl_tensor , alpha = 0.0 )
147152 image_out = pil_image_from_tensor (rgb_tensor , mode = "RGB" )
148153
149154 elif space == "okhsv" :
150155 rgb_tensor = image_resized_to_grid_as_tensor (image_in .convert ("RGB" ), normalize = False , multiple_of = 1 )
151156 hsv_tensor = okhsv_from_srgb (rgb_tensor , steps = (3 if self .ok_high_precision else 1 ))
152- hsv_tensor [0 , :, :] = torch .remainder (torch .add (hsv_tensor [0 , :, :], torch . div ( self .degrees , 360.0 )), 1 .0 )
157+ hsv_tensor [0 , :, :] = torch .remainder (torch .add (hsv_tensor [0 , :, :], self .degrees ), 360 .0 )
153158 rgb_tensor = srgb_from_okhsv (hsv_tensor , alpha = 0.0 )
154159 image_out = pil_image_from_tensor (rgb_tensor , mode = "RGB" )
155160
@@ -197,24 +202,9 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
197202 rgb_tensor = image_resized_to_grid_as_tensor (image_in .convert ("RGB" ), normalize = False , multiple_of = 1 )
198203
199204 linear_srgb_tensor = linear_srgb_from_srgb (rgb_tensor )
200-
201- lab_tensor = oklab_from_linear_srgb (linear_srgb_tensor )
202-
203- # L*a*b* to L*C*h
204- c_tensor = torch .sqrt (torch .add (torch .pow (lab_tensor [1 , :, :], 2.0 ), torch .pow (lab_tensor [2 , :, :], 2.0 )))
205- h_tensor = torch .atan2 (lab_tensor [2 , :, :], lab_tensor [1 , :, :])
206-
207- # Rotate h
208- rot_rads = (self .degrees / 180.0 ) * PI
209-
210- h_rot = torch .add (h_tensor , rot_rads )
211- h_rot = torch .remainder (torch .add (h_rot , 2 * PI ), 2 * PI )
212-
213- # L*C*h to L*a*b*
214- lab_tensor [1 , :, :] = torch .mul (c_tensor , torch .cos (h_rot ))
215- lab_tensor [2 , :, :] = torch .mul (c_tensor , torch .sin (h_rot ))
216-
217- linear_srgb_tensor = linear_srgb_from_oklab (lab_tensor )
205+ oklch_tensor = oklch_from_oklab (oklab_from_linear_srgb (linear_srgb_tensor ))
206+ oklch_tensor [2 , :, :] = torch .remainder (torch .add (oklch_tensor [2 , :, :], self .degrees ), 360.0 )
207+ linear_srgb_tensor = linear_srgb_from_oklch (oklch_tensor )
218208
219209 rgb_tensor = srgb_from_linear_srgb (
220210 linear_srgb_tensor , alpha = self .ok_adaptive_gamut , steps = (3 if self .ok_high_precision else 1 )
@@ -602,14 +592,14 @@ def prepare_tensors_from_images(
602592 image_hsv_upper , image_hsv_lower = image_upper .convert ("HSV" ), image_lower .convert ("HSV" )
603593 upper_hsv_tensor = torch .stack (
604594 [
605- tensor_from_pil_image (image_hsv_upper .getchannel ("H" ), normalize = False )[0 , :, :],
595+ tensor_from_pil_image (image_hsv_upper .getchannel ("H" ), normalize = False )[0 , :, :] * 360.0 ,
606596 tensor_from_pil_image (image_hsv_upper .getchannel ("S" ), normalize = False )[0 , :, :],
607597 tensor_from_pil_image (image_hsv_upper .getchannel ("V" ), normalize = False )[0 , :, :],
608598 ]
609599 )
610600 lower_hsv_tensor = torch .stack (
611601 [
612- tensor_from_pil_image (image_hsv_lower .getchannel ("H" ), normalize = False )[0 , :, :],
602+ tensor_from_pil_image (image_hsv_lower .getchannel ("H" ), normalize = False )[0 , :, :] * 360.0 ,
613603 tensor_from_pil_image (image_hsv_lower .getchannel ("S" ), normalize = False )[0 , :, :],
614604 tensor_from_pil_image (image_hsv_lower .getchannel ("V" ), normalize = False )[0 , :, :],
615605 ]
@@ -655,29 +645,8 @@ def prepare_tensors_from_images(
655645 if "oklch" in required :
656646 upper_oklab_tensor = oklab_from_linear_srgb (upper_rgb_l_tensor )
657647 lower_oklab_tensor = oklab_from_linear_srgb (lower_rgb_l_tensor )
658-
659- upper_oklch_tensor = torch .stack (
660- [
661- upper_oklab_tensor [0 , :, :],
662- torch .sqrt (
663- torch .add (
664- torch .pow (upper_oklab_tensor [1 , :, :], 2.0 ), torch .pow (upper_oklab_tensor [2 , :, :], 2.0 )
665- )
666- ),
667- torch .atan2 (upper_oklab_tensor [2 , :, :], upper_oklab_tensor [1 , :, :]),
668- ]
669- )
670- lower_oklch_tensor = torch .stack (
671- [
672- lower_oklab_tensor [0 , :, :],
673- torch .sqrt (
674- torch .add (
675- torch .pow (lower_oklab_tensor [1 , :, :], 2.0 ), torch .pow (lower_oklab_tensor [2 , :, :], 2.0 )
676- )
677- ),
678- torch .atan2 (lower_oklab_tensor [2 , :, :], lower_oklab_tensor [1 , :, :]),
679- ]
680- )
648+ upper_oklch_tensor = oklch_from_oklab (upper_oklab_tensor )
649+ lower_oklch_tensor = oklch_from_oklab (lower_oklab_tensor )
681650
682651 return (
683652 upper_rgb_l_tensor ,
@@ -736,7 +705,17 @@ def adaptive_clipped(rgb_tensor: torch.Tensor, clamp: bool = True, replace_with:
736705 "HSL" : lambda t : linear_srgb_from_srgb (srgb_from_hsl (t )),
737706 "HSV" : lambda t : linear_srgb_from_srgb (
738707 tensor_from_pil_image (
739- pil_image_from_tensor (t .clamp (0.0 , 1.0 ), mode = "HSV" ).convert ("RGB" ), normalize = False
708+ pil_image_from_tensor (
709+ torch .stack (
710+ [
711+ torch .remainder (t [0 , :, :], 360.0 ) / 360.0 ,
712+ t [1 , :, :].clamp (0.0 , 1.0 ),
713+ t [2 , :, :].clamp (0.0 , 1.0 ),
714+ ]
715+ ),
716+ mode = "HSV" ,
717+ ).convert ("RGB" ),
718+ normalize = False ,
740719 )
741720 ),
742721 "Okhsl" : lambda t : linear_srgb_from_srgb (
@@ -745,15 +724,7 @@ def adaptive_clipped(rgb_tensor: torch.Tensor, clamp: bool = True, replace_with:
745724 "Okhsv" : lambda t : linear_srgb_from_srgb (
746725 srgb_from_okhsv (t , alpha = self .adaptive_gamut , steps = (3 if self .high_precision else 1 ))
747726 ),
748- "Oklch" : lambda t : linear_srgb_from_oklab (
749- torch .stack (
750- [
751- t [0 , :, :],
752- torch .mul (t [1 , :, :], torch .cos (t [2 , :, :])),
753- torch .mul (t [1 , :, :], torch .sin (t [2 , :, :])),
754- ]
755- )
756- ),
727+ "Oklch" : lambda t : linear_srgb_from_oklab (oklab_from_oklch (t )),
757728 "LCh" : lambda t : linear_srgb_from_srgb (
758729 tensor_from_pil_image (
759730 self .image_convert_with_xform (
@@ -784,9 +755,9 @@ def adaptive_clipped(rgb_tensor: torch.Tensor, clamp: bool = True, replace_with:
784755 alpha_upper_tensor ,
785756 alpha_lower_tensor ,
786757 mask_tensor ,
787- upper_hsv_tensor , # h_rgb, s_hsv, v_hsv
758+ upper_hsv_tensor , # h_hsv_degrees, s_hsv, v_hsv
788759 lower_hsv_tensor ,
789- upper_hsl_tensor , # , s_hsl, l_hsl
760+ upper_hsl_tensor , # h_hsl_degrees, s_hsl, l_hsl
790761 lower_hsl_tensor ,
791762 upper_lab_tensor , # l_lab, a_lab, b_lab
792763 lower_lab_tensor ,
@@ -796,11 +767,11 @@ def adaptive_clipped(rgb_tensor: torch.Tensor, clamp: bool = True, replace_with:
796767 lower_l_eal_tensor ,
797768 upper_oklab_tensor , # l_oklab, a_oklab, b_oklab
798769 lower_oklab_tensor ,
799- upper_oklch_tensor , # , c_oklab, h_oklab
770+ upper_oklch_tensor , # l_oklab , c_oklab, h_oklab_degrees
800771 lower_oklch_tensor ,
801- upper_okhsv_tensor , # h_okhsv , s_okhsv, v_okhsv
772+ upper_okhsv_tensor , # h_okhsv_degrees , s_okhsv, v_okhsv
802773 lower_okhsv_tensor ,
803- upper_okhsl_tensor , # h_okhsl , s_okhsl, l_r_oklab
774+ upper_okhsl_tensor , # h_okhsl_degrees , s_okhsl, l_r_oklab
804775 lower_okhsl_tensor ,
805776 ) = image_tensors
806777
@@ -850,6 +821,17 @@ def adaptive_clipped(rgb_tensor: torch.Tensor, clamp: bool = True, replace_with:
850821 "LCh" : 2 ,
851822 }[color_space ]
852823
824+ hue_period = {
825+ "RGB" : None ,
826+ "Linear" : None ,
827+ "HSL" : 360.0 ,
828+ "HSV" : 360.0 ,
829+ "Okhsl" : 360.0 ,
830+ "Okhsv" : 360.0 ,
831+ "Oklch" : 360.0 ,
832+ "LCh" : 2.0 * PI ,
833+ }[color_space ]
834+
853835 if blend_mode == "Normal" :
854836 upper_rgb_l_tensor = reassembly_function (upper_space_tensor )
855837
@@ -982,19 +964,19 @@ def adaptive_clipped(rgb_tensor: torch.Tensor, clamp: bool = True, replace_with:
982964 elif blend_mode == "Linear Dodge (Add)" :
983965 lower_space_tensor = torch .add (lower_space_tensor , upper_space_tensor )
984966 if hue_index is not None :
985- lower_space_tensor [hue_index , :, :] = torch .remainder (lower_space_tensor [hue_index , :, :], 1.0 )
967+ lower_space_tensor [hue_index , :, :] = torch .remainder (lower_space_tensor [hue_index , :, :], hue_period )
986968 upper_rgb_l_tensor = adaptive_clipped (reassembly_function (lower_space_tensor ))
987969
988970 elif blend_mode == "Color Dodge" :
989971 lower_space_tensor = torch .div (lower_space_tensor , torch .add (torch .mul (upper_space_tensor , - 1.0 ), 1.0 ))
990972 if hue_index is not None :
991- lower_space_tensor [hue_index , :, :] = torch .remainder (lower_space_tensor [hue_index , :, :], 1.0 )
973+ lower_space_tensor [hue_index , :, :] = torch .remainder (lower_space_tensor [hue_index , :, :], hue_period )
992974 upper_rgb_l_tensor = adaptive_clipped (reassembly_function (lower_space_tensor ))
993975
994976 elif blend_mode == "Divide" :
995977 lower_space_tensor = torch .div (lower_space_tensor , upper_space_tensor )
996978 if hue_index is not None :
997- lower_space_tensor [hue_index , :, :] = torch .remainder (lower_space_tensor [hue_index , :, :], 1.0 )
979+ lower_space_tensor [hue_index , :, :] = torch .remainder (lower_space_tensor [hue_index , :, :], hue_period )
998980 upper_rgb_l_tensor = adaptive_clipped (reassembly_function (lower_space_tensor ))
999981
1000982 elif blend_mode == "Linear Burn" :
@@ -1088,7 +1070,7 @@ def adaptive_clipped(rgb_tensor: torch.Tensor, clamp: bool = True, replace_with:
10881070 elif blend_mode == "Subtract" :
10891071 lower_space_tensor = torch .sub (lower_space_tensor , upper_space_tensor )
10901072 if hue_index is not None :
1091- lower_space_tensor [hue_index , :, :] = torch .remainder (lower_space_tensor [hue_index , :, :], 1.0 )
1073+ lower_space_tensor [hue_index , :, :] = torch .remainder (lower_space_tensor [hue_index , :, :], hue_period )
10921074 upper_rgb_l_tensor = adaptive_clipped (reassembly_function (lower_space_tensor ))
10931075
10941076 elif blend_mode == "Difference" :
0 commit comments