Skip to content

Commit a7a1efa

Browse files
authored
feat(orchestrator): refactor DownloadOrchestrator into specialized components (#63)
2 parents 1873f58 + 914693a commit a7a1efa

6 files changed

Lines changed: 773 additions & 372 deletions

File tree

.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,9 @@ cython_debug/
205205
marimo/_static/
206206
marimo/_lsp/
207207
__marimo__/
208+
209+
# backups
210+
*.py.backup
211+
212+
# Road Map Plan files
213+
*PLAN.md
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
"""
2+
Concurrency manager for handling semaphore-controlled async operations.
3+
"""
4+
5+
import asyncio
6+
from typing import List, Optional, Tuple, Callable, Any, TypeVar
7+
from dataclasses import dataclass, field
8+
from datetime import datetime
9+
10+
from forklet.infrastructure.logger import logger
11+
12+
# Generic type variable. It is not referenced anywhere in the visible code —
# presumably reserved for typing processor results (TODO confirm or remove).
T = TypeVar("T")
13+
14+
15+
@dataclass
class ConcurrencyStats:
    """Counters and timestamps describing a batch of concurrent operations."""

    # Number of tasks handed to the manager.
    started_tasks: int = 0
    # Number of tasks that finished and produced a result.
    completed_tasks: int = 0
    # Number of tasks that ended by raising an exception.
    failed_tasks: int = 0
    # Number of tasks that were cancelled.
    cancelled_tasks: int = 0
    # Set to the creation time of this object unless given explicitly.
    start_time: Optional[datetime] = field(default_factory=datetime.now)
    # Stays None until a finish time is recorded.
    end_time: Optional[datetime] = None

    @property
    def duration_seconds(self) -> float:
        """Return the seconds elapsed from start_time to end_time.

        Yields 0.0 whenever either timestamp is missing.
        """
        if not self.start_time or not self.end_time:
            return 0.0
        elapsed = self.end_time - self.start_time
        return elapsed.total_seconds()
32+
33+
34+
class ConcurrencyManager:
    """
    Manages concurrent async operations with semaphore control.

    Each item is processed in its own asyncio task; at most
    ``max_concurrent`` processors run at the same time. The manager tracks
    the tasks of the current run so they can be cancelled, and records
    per-run statistics.
    """

    def __init__(self, max_concurrent: int = 10):
        """
        Create a manager.

        Args:
            max_concurrent: Maximum number of processors allowed to run
                simultaneously (initial value of the internal semaphore).
        """
        self.max_concurrent = max_concurrent
        self._semaphore = asyncio.Semaphore(max_concurrent)
        self._stats = ConcurrencyStats()
        self._active_tasks: List[asyncio.Task] = []
        self._is_cancelled = False
        self._cancellation_event = asyncio.Event()

    async def execute_with_concurrency(
        self,
        items: List[Any],
        processor: Callable[[Any], Any],
        *,
        return_exceptions: bool = True,
    ) -> Tuple[List[Any], List[Exception]]:
        """
        Process items concurrently with semaphore control.

        Args:
            items: List of items to process
            processor: Async function to process each item
            return_exceptions: Whether to return exceptions or raise them

        Returns:
            Tuple of (results, exceptions) where results contains successful
            outputs and exceptions contains failed operations
            (if return_exceptions=True). Tasks that ended cancelled are
            counted in ``cancelled_tasks`` and appear in neither list. Items
            skipped because cancellation was already requested contribute
            ``None`` to results.

        Raises:
            RuntimeError: If the manager has already been cancelled.
            asyncio.CancelledError: If the surrounding operation is cancelled.
            Exception: The first processor error when return_exceptions=False.
        """
        if self._is_cancelled:
            raise RuntimeError("Concurrency manager has been cancelled")

        # Reset every per-run counter in place so a reused manager (or a
        # caller holding the stats object) never sees numbers from an earlier
        # run mixed with the current one.
        stats = self._stats
        stats.start_time = datetime.now()
        stats.end_time = None
        stats.started_tasks = len(items)
        stats.completed_tasks = 0
        stats.failed_tasks = 0
        stats.cancelled_tasks = 0

        # Wrap each coroutine in a real Task: only Task objects expose
        # done()/cancel(), which cancel() and the cleanup below depend on.
        tasks = [
            asyncio.create_task(self._process_with_semaphore(item, processor))
            for item in items
        ]

        # Store active tasks for potential cancellation
        self._active_tasks = tasks

        try:
            # Execute all tasks concurrently
            results = await asyncio.gather(
                *tasks, return_exceptions=return_exceptions
            )

            # Separate successful results from exceptions
            successful_results: List[Any] = []
            exceptions: List[Exception] = []

            for result in results:
                # CancelledError must be checked first: since Python 3.8 it
                # derives from BaseException, so an isinstance(..., Exception)
                # test alone would misfile it as a successful result.
                if isinstance(result, asyncio.CancelledError):
                    stats.cancelled_tasks += 1
                elif isinstance(result, Exception):
                    stats.failed_tasks += 1
                    exceptions.append(result)
                else:
                    stats.completed_tasks += 1
                    successful_results.append(result)

            return successful_results, exceptions

        except asyncio.CancelledError:
            logger.info("Concurrent operation was cancelled")
            raise
        finally:
            # Never leave work running untracked: on cancellation, or on an
            # early failure with return_exceptions=False, some tasks may
            # still be pending. On the success path all of them are done.
            for task in tasks:
                if not task.done():
                    task.cancel()
            self._active_tasks = []
            stats.end_time = datetime.now()

    async def _process_with_semaphore(
        self, item: Any, processor: Callable[[Any], Any]
    ) -> Any:
        """
        Process a single item with semaphore control.

        Returns the processor's result, or None (without calling the
        processor) when cancellation was requested before the semaphore
        slot was obtained. Processor exceptions propagate to the caller.
        """
        async with self._semaphore:
            # Check for cancellation before processing
            if self._cancellation_event.is_set():
                return None

            return await processor(item)

    def cancel(self) -> None:
        """
        Cancel all pending operations.

        Does nothing when no tasks are active. Otherwise the manager is
        marked as cancelled (later execute_with_concurrency calls raise
        RuntimeError) and every unfinished task is cancelled.
        """
        if not self._active_tasks:
            return

        self._is_cancelled = True
        self._cancellation_event.set()

        # Cancel all active tasks
        for task in self._active_tasks:
            if not task.done():
                task.cancel()

        logger.info("Concurrency manager cancelled")

    def get_stats(self) -> "ConcurrencyStats":
        """Get current concurrency statistics."""
        return self._stats

    def is_busy(self) -> bool:
        """Check if there are active tasks."""
        return bool(self._active_tasks)

0 commit comments

Comments
 (0)