Skip to content

Commit bb387f9

Browse files
authored
feat: Batch processing images (#91)
- Added batch_extract_colors() for parallel processing of multiple images - CLI now supports processing multiple image sources in a single command - CLI has an enhanced progress display with recently extracted palettes and palette previews - Added metadata-field to Palette-class. It contains information about the image source, extraction parameters and extraction performance. This is a preparatory step for better palette exports. - removed --filename and --image-url parameters - mode parameter to color extraction methods now accepts a ExtractionMethod enum member instead of a string.
1 parent 4a360a8 commit bb387f9

11 files changed

Lines changed: 619 additions & 231 deletions

File tree

Pylette/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from Pylette.src.color import Color
2-
from Pylette.src.color_extraction import extract_colors
2+
from Pylette.src.color_extraction import batch_extract_colors, extract_colors
33
from Pylette.src.palette import Palette
44
from Pylette.src.types import ImageInput
55

6-
__all__ = ["extract_colors", "Palette", "Color", "ImageInput"]
6+
__all__ = ["extract_colors", "batch_extract_colors", "Palette", "Color", "ImageInput"]

Pylette/cmd.py

Lines changed: 64 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
import pathlib
22
from enum import Enum
3+
from typing import Annotated, List
34

45
import typer
56

6-
from Pylette import extract_colors
7-
8-
9-
class ExtractionMode(str, Enum):
10-
KM = "KM"
11-
MC = "MC"
7+
from Pylette.src.cli_utils import PyletteProgress
8+
from Pylette.src.color_extraction import batch_extract_colors
9+
from Pylette.src.types import BatchResult, ExtractionMethod
1210

1311

1412
class SortBy(str, Enum):
@@ -25,12 +23,12 @@ class ColorSpace(str, Enum):
2523
pylette_app = typer.Typer()
2624

2725

28-
@pylette_app.command()
26+
@pylette_app.command(no_args_is_help=True)
2927
def main(
30-
ctx: typer.Context,
31-
filename: pathlib.Path | None = None,
32-
image_url: str | None = None,
33-
mode: ExtractionMode = ExtractionMode.KM,
28+
image_sources: Annotated[
29+
List[str], typer.Argument(help="A list of paths / directories / URLs pointing to images.")
30+
], # These can be paths or URLs
31+
mode: ExtractionMethod = ExtractionMethod.KM,
3432
n: int = 5,
3533
sort_by: SortBy = SortBy.luminance,
3634
stdout: bool = True,
@@ -43,41 +41,68 @@ def main(
4341
max=255,
4442
help="Alpha threshold for transparent image masking (0-255). Pixels with alpha below this value are excluded.",
4543
),
44+
num_threads: int | None = typer.Option(
45+
None, min=1, help="Number of threads used for batch extraction of color palettes"
46+
),
4647
):
47-
if filename is None and image_url is None:
48-
typer.echo(ctx.get_help())
49-
raise typer.Exit(code=0)
50-
51-
if filename is not None and image_url is not None:
52-
typer.echo("Please provide either a filename or an image-url, but not both.")
53-
raise typer.Exit(code=1)
54-
55-
if filename is not None:
56-
image = filename # Path
57-
else:
58-
assert image_url is not None
59-
image = image_url # str (URL)
60-
6148
output_file_path = str(out_filename) if out_filename is not None else None
62-
try:
63-
palette = extract_colors(
64-
image=image,
49+
50+
# Set up progress bar for CLI
51+
with PyletteProgress(palette_size=n) as progress:
52+
task_id = progress.add_task("Extracting colors...", total=len(image_sources))
53+
54+
def progress_callback(task_number: int, result: BatchResult):
55+
if result.success and result.palette:
56+
progress.mark_task_complete(
57+
task_number=task_number,
58+
task_id=task_id,
59+
completed_task_name=result.palette.metadata["image_source"] if result.palette.metadata else "",
60+
palette_colors=result.palette.colors,
61+
)
62+
else:
63+
progress.update(task_id, advance=1)
64+
65+
results = batch_extract_colors(
66+
images=image_sources,
6567
palette_size=n,
6668
sort_mode=sort_by.value,
67-
mode=mode.value,
69+
mode=mode,
6870
alpha_mask_threshold=alpha_mask_threshold,
71+
max_workers=num_threads,
72+
progress_callback=progress_callback,
6973
)
70-
except ValueError as e:
71-
typer.echo(str(e))
72-
raise typer.Exit(code=1)
73-
74-
palette.to_csv(filename=output_file_path, frequency=True, stdout=stdout, colorspace=colorspace.value)
75-
if display_colors:
76-
palette.display()
77-
7874

79-
def docs():
80-
typer.launch("https://qtiptip.github.io/Pylette/")
75+
successful = [r for r in results if r.success]
76+
failed = [r for r in results if not r.success]
77+
for success in successful:
78+
if success.palette is not None:
79+
success.palette.to_csv(
80+
filename=output_file_path, frequency=True, stdout=stdout, colorspace=colorspace.value
81+
)
82+
if display_colors:
83+
success.palette.display()
84+
85+
if failed:
86+
print_extraction_summary(successful, failed)
87+
88+
# If we have no successful extractions, return with code 1
89+
if not successful:
90+
raise typer.Exit(1)
91+
# Otherwise, if we have some failures, return with code 2
92+
elif failed:
93+
raise typer.Exit(2)
94+
95+
96+
def print_extraction_summary(successful: list[BatchResult], failed: list[BatchResult]):
97+
total = len(successful) + len(failed)
98+
99+
if successful:
100+
typer.secho(f"✓ Processed {len(successful)}/{total} images successfully", fg=typer.colors.GREEN)
101+
if failed:
102+
typer.secho(f"✗ {len(failed)} images failed:", fg=typer.colors.RED)
103+
for result in failed:
104+
error_msg = str(result.error)
105+
typer.secho(f"{result.source}: {error_msg}", err=True)
81106

82107

83108
def main_typer() -> None:

Pylette/src/cli_utils.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from collections import deque
2+
from typing import Iterator
3+
4+
from rich.console import Console, RenderableType
5+
from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeElapsedColumn
6+
from rich.table import Table
7+
from rich.text import Text
8+
9+
from Pylette.src.color import Color
10+
11+
12+
class RecentlyCompletedDisplay:
13+
def __init__(self, num_items: int = 5, max_text_width: int = 20, max_num_preview_colors: int = 6) -> None:
14+
self.max_items = num_items
15+
self.max_name_width = max_text_width
16+
self.items = deque(maxlen=self.max_items)
17+
self.num_preview_colors = max_num_preview_colors
18+
19+
super().__init__()
20+
21+
def add_completed_task(self, task_number: int, task_name: str, colors: list[Color]) -> None:
22+
# Truncate name if too long
23+
display_name = task_name
24+
if len(display_name) > self.max_name_width:
25+
display_name = task_name[:-3] + "..."
26+
27+
self.items.append(
28+
{
29+
"number": task_number,
30+
"name": display_name,
31+
"colors": colors[: self.num_preview_colors],
32+
}
33+
)
34+
35+
def render(self) -> RenderableType:
36+
table = Table.grid("Number", "Image Source", "Preview", padding=(0, 1))
37+
table.add_column() # Task number
38+
table.add_column(width=self.max_name_width) # Task name
39+
table.add_column() # Palette dots
40+
41+
# Add each recent task as a row
42+
table.add_row("Number", "Image Source", "Preview")
43+
for task_info in list(self.items):
44+
task_name = task_info["name"]
45+
task_number = str(task_info["number"])
46+
colors = task_info["colors"]
47+
48+
# Create colored dots for palette
49+
dots_text = Text()
50+
for c in colors:
51+
r, g, b = c.rgb
52+
color = f"rgb({r},{g},{b})"
53+
dots_text.append("●", style=color)
54+
55+
table.add_row(task_number, task_name, dots_text)
56+
57+
return table
58+
59+
60+
class PyletteProgress(Progress):
61+
"""Custom Progress class for Pylette color extraction with palette preview"""
62+
63+
def __init__(
64+
self,
65+
palette_size: int = 5,
66+
max_recent_items: int = 5,
67+
max_name_width: int = 40,
68+
console_width: int = 140,
69+
*args, # pyright: ignore[reportMissingParameterType]
70+
**kwargs, # pyright: ignore[reportMissingParameterType]
71+
):
72+
# Create the recently completed display
73+
self.recently_completed = RecentlyCompletedDisplay(
74+
num_items=max_recent_items, max_text_width=max_name_width, max_num_preview_colors=palette_size
75+
)
76+
77+
# Set up the console with appropriate width
78+
console = Console(width=console_width)
79+
80+
# Initialize Progress with custom columns
81+
super().__init__(
82+
TextColumn("[bold blue]{task.description}"),
83+
BarColumn(bar_width=40),
84+
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
85+
TextColumn("({task.completed}/{task.total})"),
86+
TimeElapsedColumn(),
87+
console=console,
88+
*args,
89+
**kwargs,
90+
)
91+
92+
def get_renderables(self) -> Iterator[RenderableType]:
93+
for renderable in super().get_renderables():
94+
yield renderable
95+
96+
yield ""
97+
yield self.recently_completed.render()
98+
99+
def mark_task_complete(
100+
self, task_id: TaskID, task_number: int, completed_task_name: str, palette_colors: list[Color]
101+
) -> None:
102+
self.update(task_id, advance=1)
103+
self.recently_completed.add_completed_task(task_number, completed_task_name, palette_colors)

0 commit comments

Comments
 (0)