Skip to content

Commit 3345d3d

Browse files
authored
Implement KMeans and MedianCut extractors in terms of the extractor base class (#55)
* Extract methods to separate file * Add a protocol and base class for color extractors * Implement median cut with a colorextractor class * Implement KMeans using the Extractor Base Class * Fix misc typing in utils * Bump patch version, and update Changelog
1 parent 34d11ff commit 3345d3d

7 files changed

Lines changed: 141 additions & 66 deletions

File tree

CHANGELOG.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010

1111
# Released
1212

13+
## 4.0.1 27/01/2025
14+
15+
### Added
16+
17+
- A `ColorExtractor` protocol that defines the interface for color extractors.
18+
- A `ColorExtractorBase` abstract class that extractors can inherit from to implement the interface.
19+
20+
### Changed
21+
- `median_cut_extraction` and `k_means_extraction` are now
22+
implemented in terms of subclasses of `ColorExtractorBase`.
23+
1324
## [4.0.0] 08/10/2024
1425

1526
### Changed

Pylette/src/color_extraction.py

Lines changed: 2 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@
88
import requests # type: ignore
99
from numpy.typing import NDArray
1010
from PIL import Image
11-
from sklearn.cluster import KMeans
1211

13-
from Pylette.src.color import Color
12+
from Pylette.src.extractors.k_means import k_means_extraction
13+
from Pylette.src.extractors.median_cut import median_cut_extraction
1414
from Pylette.src.palette import Palette
15-
from Pylette.src.utils import ColorBox
1615

1716
ImageType_T: TypeAlias = Union["os.PathLike[Any]", bytes, NDArray[float], str]
1817

@@ -25,35 +24,6 @@ class ImageType(str, Enum):
2524
NONE = "none"
2625

2726

28-
def median_cut_extraction(arr: np.ndarray, height: int, width: int, palette_size: int) -> list[Color]:
29-
"""
30-
Extracts a color palette using the median cut algorithm.
31-
32-
Parameters:
33-
arr (np.ndarray): The input array.
34-
height (int): The height of the image.
35-
width (int): The width of the image.
36-
palette_size (int): The number of colors to extract from the image.
37-
38-
Returns:
39-
list[Color]: A list of colors extracted from the image.
40-
"""
41-
42-
arr = arr.reshape((width * height, -1))
43-
c = [ColorBox(arr)]
44-
45-
# Each iteration, find the largest box, split it, remove original box from list of boxes, and add the two new boxes.
46-
while len(c) < palette_size:
47-
largest_c_idx = np.argmax(c)
48-
# add the two new boxes to the list, while removing the split box.
49-
c = c[:largest_c_idx] + c[largest_c_idx].split() + c[largest_c_idx + 1 :]
50-
51-
total_pixels = width * height
52-
colors = [Color(tuple(map(int, box.average)), box.pixel_count / total_pixels) for box in c]
53-
54-
return colors
55-
56-
5727
def _parse_image_type(image: ImageType_T) -> ImageType:
5828
"""
5929
Determines the type of the input image.
@@ -169,28 +139,3 @@ def request_image(image_url: str) -> Image.Image:
169139
return img
170140
else:
171141
raise ValueError("The URL did not point to a valid image.")
172-
173-
174-
def k_means_extraction(arr: NDArray[float], height: int, width: int, palette_size: int) -> list[Color]:
175-
"""
176-
Extracts a color palette using KMeans.
177-
178-
Parameters:
179-
arr (NDArray[float]): The input array.
180-
height (int): The height of the image.
181-
width (int): The width of the image.
182-
palette_size (int): The number of colors to extract from the image.
183-
184-
Returns:
185-
list[Color]: A palette of colors sorted by frequency.
186-
"""
187-
arr = np.reshape(arr, (width * height, -1))
188-
model = KMeans(n_clusters=palette_size, n_init="auto", init="k-means++", random_state=2024)
189-
labels = model.fit_predict(arr)
190-
palette = np.array(model.cluster_centers_, dtype=int)
191-
color_count = np.bincount(labels)
192-
color_frequency = color_count / float(np.sum(color_count))
193-
colors = []
194-
for color, freq in zip(palette, color_frequency):
195-
colors.append(Color(color, freq))
196-
return colors

Pylette/src/extractors/k_means.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import numpy as np
2+
from numpy.typing import NDArray
3+
from sklearn.cluster import KMeans
4+
5+
from Pylette.src.color import Color
6+
from Pylette.src.extractors.protocol import NP_T, ColorExtractorBase
7+
8+
9+
class KMeansExtractor(ColorExtractorBase):
    """Color extractor that clusters the image's pixels with scikit-learn's KMeans."""

    def extract(self, arr: NDArray[NP_T], height: int, width: int, palette_size: int) -> list[Color]:
        """
        Extracts a color palette using KMeans.

        Parameters:
            arr (NDArray): The input array; it is reshaped to ``height * width`` rows.
            height (int): The height of the image.
            width (int): The width of the image.
            palette_size (int): The number of colors (clusters) to extract from the image.

        Returns:
            list[Color]: One color per cluster center, in cluster-index order, each
                paired with the fraction of pixels assigned to that cluster.
                No sorting is applied here.
        """
        # Go through the shared base-class helper, as MedianCutExtractor does.
        pixels = self._reshape_array(arr=arr, height=height, width=width)

        # Fixed random_state keeps the extracted palette reproducible between runs.
        model = KMeans(n_clusters=palette_size, n_init="auto", init="k-means++", random_state=2024)
        labels = model.fit_predict(pixels)
        centers = np.array(model.cluster_centers_, dtype=int)

        # minlength keeps the counts aligned with the cluster centers: without it,
        # bincount stops at the highest label that actually occurs, and zip() below
        # would silently drop any trailing clusters that received no pixels.
        counts = np.bincount(labels, minlength=palette_size)
        frequencies = counts / float(np.sum(counts))

        return [Color(center, frequency) for center, frequency in zip(centers, frequencies)]
33+
34+
35+
def k_means_extraction(arr: NDArray[NP_T], height: int, width: int, palette_size: int) -> list[Color]:
    """
    Functional entry point for KMeans palette extraction.

    Builds a fresh KMeansExtractor and forwards all arguments to its
    ``extract`` method.

    Parameters:
        arr (NDArray): The input array.
        height (int): The height of the image.
        width (int): The width of the image.
        palette_size (int): The number of colors to extract from the image.

    Returns:
        list[Color]: The colors returned by ``KMeansExtractor.extract``.
    """
    extractor = KMeansExtractor()
    return extractor.extract(arr, height, width, palette_size)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import numpy as np
2+
from numpy.typing import NDArray
3+
4+
from Pylette.src.color import Color
5+
from Pylette.src.extractors.protocol import NP_T, ColorExtractorBase
6+
from Pylette.src.utils import ColorBox
7+
8+
9+
class MedianCutExtractor(ColorExtractorBase):
    """Color extractor based on the median cut algorithm."""

    def extract(self, arr: NDArray[NP_T], height: int, width: int, palette_size: int) -> list[Color]:
        """
        Extracts a color palette using the median cut algorithm.

        Parameters:
            arr (np.ndarray): The input array.
            height (int): The height of the image.
            width (int): The width of the image.
            palette_size (int): The number of colors to extract from the image.

        Returns:
            list[Color]: One color per box, built from the box average and the
                share of all pixels that the box holds.
        """
        pixels = self._reshape_array(arr=arr, height=height, width=width)
        pixel_total = width * height

        # Start from one box holding every pixel, then keep splitting the box
        # that np.argmax selects until there are enough boxes.
        partitions = [ColorBox(pixels)]
        while len(partitions) < palette_size:
            chosen = np.argmax(partitions)  # type: ignore
            # Replace the chosen box, in place, with the boxes its split produces.
            partitions[chosen : chosen + 1] = partitions[chosen].split()

        palette = []
        for box in partitions:
            rgb = tuple(map(int, box.average))
            palette.append(Color(rgb, box.pixel_count / pixel_total))
        return palette
33+
34+
35+
def median_cut_extraction(arr: np.ndarray, height: int, width: int, palette_size: int) -> list[Color]:
    """
    Functional entry point for median cut palette extraction.

    Builds a fresh MedianCutExtractor and forwards all arguments to its
    ``extract`` method.

    Parameters:
        arr (np.ndarray): The input array.
        height (int): The height of the image.
        width (int): The width of the image.
        palette_size (int): The number of colors to extract from the image.

    Returns:
        list[Color]: The colors returned by ``MedianCutExtractor.extract``.
    """
    extractor = MedianCutExtractor()
    return extractor.extract(arr=arr, height=height, width=width, palette_size=palette_size)

Pylette/src/extractors/protocol.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Protocol, TypeVar
3+
4+
import numpy as np
5+
from numpy.typing import NDArray
6+
7+
from Pylette.src.color import Color
8+
9+
NP_T = TypeVar("NP_T", bound=np.generic, covariant=True)
10+
11+
12+
class ColorExtractor(Protocol):
    """Structural interface for objects that can extract a list of colors from an array."""

    def extract(self, arr: NDArray[NP_T], height: int, width: int, palette_size: int) -> list[Color]:
        """
        Extract colors from ``arr``.

        Parameters:
            arr (NDArray): The input array.
            height (int): The height of the image.
            width (int): The width of the image.
            palette_size (int): The number of colors to extract from the image.

        Returns:
            list[Color]: The extracted colors.
        """
        ...
14+
15+
16+
class ColorExtractorBase(ABC):
    """Abstract base class for color extractors; subclasses must implement ``extract``."""

    @abstractmethod
    def extract(self, arr: NDArray[NP_T], height: int, width: int, palette_size: int) -> list[Color]:
        """
        Extract colors from ``arr``; concrete extractors provide the algorithm.

        Parameters:
            arr (NDArray): The input array.
            height (int): The height of the image.
            width (int): The width of the image.
            palette_size (int): The number of colors to extract from the image.

        Returns:
            list[Color]: The extracted colors.
        """

    def _reshape_array(self, arr: NDArray[NP_T], height: int, width: int) -> NDArray[NP_T]:
        """Reshape ``arr`` to ``height * width`` rows, letting NumPy infer the second dimension."""
        row_count = height * width
        return arr.reshape((row_count, -1))

Pylette/src/utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def _get_min_max(self) -> None:
2323
"""
2424
Calculates the minimum and maximum values for each color channel in the ColorBox.
2525
"""
26-
self.min_channel: NDArray[np.uint8, (3,)] = np.min(self.colors, axis=0)
27-
self.max_channel: NDArray[np.uint8, (3,)] = np.max(self.colors, axis=0)
26+
self.min_channel: NDArray[np.uint8] = np.min(self.colors, axis=0)
27+
self.max_channel: NDArray[np.uint8] = np.max(self.colors, axis=0)
2828

2929
def __lt__(self, other: "ColorBox") -> bool:
3030
"""
@@ -36,10 +36,10 @@ def __lt__(self, other: "ColorBox") -> bool:
3636
Returns:
3737
bool: True if the volume of this ColorBox is less than the volume of the other ColorBox, False otherwise.
3838
"""
39-
return self.size < other.size
39+
return bool(self.size < other.size)
4040

4141
@property
42-
def size(self) -> np.uint64:
42+
def size(self) -> int:
4343
"""
4444
Returns the volume of the ColorBox.
4545
@@ -55,12 +55,12 @@ def _get_dominant_channel(self) -> int:
5555
Returns:
5656
int: The index of the dominant color channel.
5757
"""
58-
diff: NDArray[np.uint8, (3,)] = self.max_channel - self.min_channel
58+
diff: NDArray[np.uint8] = self.max_channel - self.min_channel
5959
dominant_channel = np.argmax(diff)
60-
return dominant_channel
60+
return int(dominant_channel)
6161

6262
@property
63-
def average(self) -> np.ndarray:
63+
def average(self) -> NDArray[np.uint8]:
6464
"""
6565
Calculates the average color contained in the ColorBox.
6666
@@ -80,7 +80,7 @@ def volume(self) -> int:
8080
Returns:
8181
int: The volume of the ColorBox.
8282
"""
83-
diff: NDArray[np.uint8, (3,)] = self.max_channel - self.min_channel
83+
diff: NDArray[np.uint8] = self.max_channel - self.min_channel
8484
return np.prod(diff).item()
8585

8686
def split(self) -> list["ColorBox"]:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "pylette"
3-
version = "4.0.0"
3+
version = "4.0.1"
44
description = "A Python library for extracting color palettes from images."
55
authors = ["Ivar Stangeby"]
66
license = "MIT"

0 commit comments

Comments
 (0)