88import requests # type: ignore
99from numpy .typing import NDArray
1010from PIL import Image
11- from sklearn .cluster import KMeans
1211
13- from Pylette .src .color import Color
12+ from Pylette .src .extractors .k_means import k_means_extraction
13+ from Pylette .src .extractors .median_cut import median_cut_extraction
1414from Pylette .src .palette import Palette
15- from Pylette .src .utils import ColorBox
1615
1716ImageType_T : TypeAlias = Union ["os.PathLike[Any]" , bytes , NDArray [float ], str ]
1817
@@ -25,35 +24,6 @@ class ImageType(str, Enum):
2524 NONE = "none"
2625
2726
28- def median_cut_extraction (arr : np .ndarray , height : int , width : int , palette_size : int ) -> list [Color ]:
29- """
30- Extracts a color palette using the median cut algorithm.
31-
32- Parameters:
33- arr (np.ndarray): The input array.
34- height (int): The height of the image.
35- width (int): The width of the image.
36- palette_size (int): The number of colors to extract from the image.
37-
38- Returns:
39- list[Color]: A list of colors extracted from the image.
40- """
41-
42- arr = arr .reshape ((width * height , - 1 ))
43- c = [ColorBox (arr )]
44-
45- # Each iteration, find the largest box, split it, remove original box from list of boxes, and add the two new boxes.
46- while len (c ) < palette_size :
47- largest_c_idx = np .argmax (c )
48- # add the two new boxes to the list, while removing the split box.
49- c = c [:largest_c_idx ] + c [largest_c_idx ].split () + c [largest_c_idx + 1 :]
50-
51- total_pixels = width * height
52- colors = [Color (tuple (map (int , box .average )), box .pixel_count / total_pixels ) for box in c ]
53-
54- return colors
55-
56-
5727def _parse_image_type (image : ImageType_T ) -> ImageType :
5828 """
5929 Determines the type of the input image.
@@ -169,28 +139,3 @@ def request_image(image_url: str) -> Image.Image:
169139 return img
170140 else :
171141 raise ValueError ("The URL did not point to a valid image." )
172-
173-
174- def k_means_extraction (arr : NDArray [float ], height : int , width : int , palette_size : int ) -> list [Color ]:
175- """
176- Extracts a color palette using KMeans.
177-
178- Parameters:
179- arr (NDArray[float]): The input array.
180- height (int): The height of the image.
181- width (int): The width of the image.
182- palette_size (int): The number of colors to extract from the image.
183-
184- Returns:
185- list[Color]: A palette of colors sorted by frequency.
186- """
187- arr = np .reshape (arr , (width * height , - 1 ))
188- model = KMeans (n_clusters = palette_size , n_init = "auto" , init = "k-means++" , random_state = 2024 )
189- labels = model .fit_predict (arr )
190- palette = np .array (model .cluster_centers_ , dtype = int )
191- color_count = np .bincount (labels )
192- color_frequency = color_count / float (np .sum (color_count ))
193- colors = []
194- for color , freq in zip (palette , color_frequency ):
195- colors .append (Color (color , freq ))
196- return colors
0 commit comments