Skip to content

Commit dc89964

Browse files
committed
make PriorityQueue and TriggeredService generic
1 parent 2e669e2 commit dc89964

7 files changed

Lines changed: 41 additions & 39 deletions

File tree

cms/io/priorityqueue.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,13 @@
4141
from datetime import datetime
4242
from functools import total_ordering
4343
from typing import TypedDict
44+
import typing
4445

4546
from gevent.event import Event
4647

4748
from cmscommon.datetime import make_datetime, make_timestamp
4849

4950

50-
# TODO: make PriorityQueue generic over the exact type of the QueueItem.
51-
# Allows more exact typing in other classes (i.e. TriggeredService and its children).
5251
class QueueItem:
5352

5453
"""Payload of an item in the queue.
@@ -61,15 +60,16 @@ def to_dict(self) -> dict:
6160
"""Return a dict() representation of the object."""
6261
return self.__dict__
6362

63+
QueueItemT = typing.TypeVar('QueueItemT', bound=QueueItem)
6464

6565
@total_ordering
66-
class QueueEntry:
66+
class QueueEntry(typing.Generic[QueueItemT]):
6767

6868
"""Type of the actual objects in the queue.
6969
7070
"""
7171

72-
def __init__(self, item: QueueItem, priority: int, timestamp: datetime, index: int):
72+
def __init__(self, item: QueueItemT, priority: int, timestamp: datetime, index: int):
7373
"""Create a QueueEntry object.
7474
7575
item: the payload.
@@ -102,7 +102,7 @@ class QueueEntryDict(TypedDict):
102102
priority: int
103103
timestamp: float
104104

105-
class PriorityQueue:
105+
class PriorityQueue(typing.Generic[QueueItemT]):
106106

107107
"""A priority queue.
108108
@@ -125,11 +125,11 @@ def __init__(self):
125125
"""Create a priority queue."""
126126
# The queue: a min-heap whose elements are of the form
127127
# (priority, timestamp, item), where item is the actual data.
128-
self._queue: list[QueueEntry] = []
128+
self._queue: list[QueueEntry[QueueItemT]] = []
129129

130130
# Reverse lookup for the items in the queue: a dictionary
131131
# associating the index in the queue to each item.
132-
self._reverse = {}
132+
self._reverse: dict[QueueItemT, int] = {}
133133

134134
# Event to signal that there are items in the queue.
135135
self._event = Event()
@@ -159,7 +159,7 @@ def _verify(self) -> bool:
159159
return False
160160
return True
161161

162-
def __contains__(self, item: QueueItem) -> bool:
162+
def __contains__(self, item: QueueItemT) -> bool:
163163
"""Implement the 'in' operator for an item in the queue.
164164
165165
item: an item to search.
@@ -234,7 +234,7 @@ def _updown_heap(self, idx: int) -> int:
234234
idx = self._up_heap(idx)
235235
return self._down_heap(idx)
236236

237-
def push(self, item: QueueItem, priority: int | None = None,
237+
def push(self, item: QueueItemT, priority: int | None = None,
238238
timestamp: datetime | None = None) -> bool:
239239
"""Push an item in the queue. If timestamp is not specified,
240240
uses the current time.
@@ -270,7 +270,7 @@ def push(self, item: QueueItem, priority: int | None = None,
270270

271271
return True
272272

273-
def top(self, wait: bool = False) -> QueueEntry:
273+
def top(self, wait: bool = False) -> QueueEntry[QueueItemT]:
274274
"""Return the first element in the queue without extracting it.
275275
276276
wait: if True, block until an element is present.
@@ -292,7 +292,7 @@ def top(self, wait: bool = False) -> QueueEntry:
292292
continue
293293
return self._queue[0]
294294

295-
def pop(self, wait: bool = False) -> QueueEntry:
295+
def pop(self, wait: bool = False) -> QueueEntry[QueueItemT]:
296296
"""Extract (and return) the first element in the queue.
297297
298298
wait: if True, block until an element is present.
@@ -317,7 +317,7 @@ def pop(self, wait: bool = False) -> QueueEntry:
317317
self._event.clear()
318318
return top
319319

320-
def remove(self, item: QueueItem) -> QueueEntry:
320+
def remove(self, item: QueueItemT) -> QueueEntry[QueueItemT]:
321321
"""Remove an item from the queue. Raise a KeyError if not present.
322322
323323
item: the item to remove.
@@ -343,7 +343,7 @@ def remove(self, item: QueueItem) -> QueueEntry:
343343

344344
return entry
345345

346-
def set_priority(self, item: QueueItem, priority: int):
346+
def set_priority(self, item: QueueItemT, priority: int):
347347
"""Change the priority of an item inside the queue. Raises an
348348
exception if the item is not in the queue.
349349

cms/io/triggeredservice.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@
3131
from gevent.event import Event
3232

3333
from cms.io import PriorityQueue, Service, rpc_method
34-
from cms.io.priorityqueue import QueueEntry, QueueEntryDict, QueueItem
34+
from cms.io.priorityqueue import QueueEntry, QueueEntryDict, QueueItemT
3535

3636

3737
logger = logging.getLogger(__name__)
3838

3939

40-
class Executor(metaclass=ABCMeta):
40+
class Executor(typing.Generic[QueueItemT], metaclass=ABCMeta):
4141

4242
"""A class taking care of executing operations.
4343
@@ -61,9 +61,9 @@ def __init__(self, batch_executions: bool = False):
6161
super().__init__()
6262

6363
self._batch_executions = batch_executions
64-
self._operation_queue = PriorityQueue()
64+
self._operation_queue: PriorityQueue[QueueItemT] = PriorityQueue()
6565

66-
def __contains__(self, item: QueueItem) -> bool:
66+
def __contains__(self, item: QueueItemT) -> bool:
6767
"""Return whether the item is in the queue.
6868
6969
item: the item to look for.
@@ -84,7 +84,7 @@ def get_status(self) -> list[QueueEntryDict]:
8484
"""
8585
return self._operation_queue.get_status()
8686

87-
def enqueue(self, item: QueueItem, priority: int | None = None, timestamp: datetime | None = None) -> bool:
87+
def enqueue(self, item: QueueItemT, priority: int | None = None, timestamp: datetime | None = None) -> bool:
8888
"""Add an item to the queue.
8989
9090
item: the item to add.
@@ -97,7 +97,7 @@ def enqueue(self, item: QueueItem, priority: int | None = None, timestamp: datet
9797
"""
9898
return self._operation_queue.push(item, priority, timestamp)
9999

100-
def dequeue(self, item: QueueItem) -> QueueEntry:
100+
def dequeue(self, item: QueueItemT) -> QueueEntry[QueueItemT]:
101101
"""Remove an item from the queue.
102102
103103
item: the item to remove.
@@ -107,7 +107,7 @@ def dequeue(self, item: QueueItem) -> QueueEntry:
107107
"""
108108
return self._operation_queue.remove(item)
109109

110-
def _pop(self, wait: bool = False) -> QueueEntry:
110+
def _pop(self, wait: bool = False) -> QueueEntry[QueueItemT]:
111111
"""Extract (and return) the first element in the queue.
112112
113113
wait: if True, block until an element is present.
@@ -180,7 +180,7 @@ def max_operations_per_batch(self) -> int:
180180
return 0
181181

182182
@abstractmethod
183-
def execute(self, entry: QueueEntry | list[QueueEntry]):
183+
def execute(self, entry: QueueEntry[QueueItemT] | list[QueueEntry[QueueItemT]]):
184184
"""Perform a single operation.
185185
186186
Must be implemented if batch_execution is false.
@@ -196,8 +196,10 @@ def execute(self, entry: QueueEntry | list[QueueEntry]):
196196
pass
197197

198198

199+
# The correct bound here is Executor[QueueItemT], but expressing that would
200+
# require higher-kinded types, which python's typechecking does not support.
199201
ExecutorT = typing.TypeVar('ExecutorT', bound=Executor)
200-
class TriggeredService(Service, typing.Generic[ExecutorT]):
202+
class TriggeredService(Service, typing.Generic[QueueItemT, ExecutorT]):
201203

202204
"""A service receiving notifications to perform an operation.
203205
@@ -263,7 +265,7 @@ def get_executor(self) -> ExecutorT:
263265
"""
264266
return self._executors[0]
265267

266-
def enqueue(self, operation: QueueItem, priority: int | None = None,
268+
def enqueue(self, operation: QueueItemT, priority: int | None = None,
267269
timestamp: datetime | None = None) -> int:
268270
"""Add an operation to the queue of each executor.
269271
@@ -282,7 +284,7 @@ def enqueue(self, operation: QueueItem, priority: int | None = None,
282284
ret += 1
283285
return ret
284286

285-
def dequeue(self, operation: QueueItem):
287+
def dequeue(self, operation: QueueItemT):
286288
"""Remove an operation from the queue of each executor.
287289
288290
operation: the operation to dequeue.

cms/service/EvaluationService.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
logger = logging.getLogger(__name__)
6060

6161

62-
class EvaluationExecutor(Executor):
62+
class EvaluationExecutor(Executor[ESOperation]):
6363

6464
# Real maximum number of operations to be sent to a worker.
6565
MAX_OPERATIONS_PER_BATCH = 25
@@ -122,7 +122,7 @@ def max_operations_per_batch(self) -> int:
122122
ratio, ret)
123123
return ret
124124

125-
def execute(self, entries: list[QueueEntry]):
125+
def execute(self, entries: list[QueueEntry[ESOperation]]):
126126
"""Execute a batch of operations in the queue.
127127
128128
The operations might not be executed immediately because of
@@ -190,7 +190,7 @@ def _pop(self, wait=False):
190190
self._remove_from_cumulative_status(queue_entry)
191191
return queue_entry
192192

193-
def _remove_from_cumulative_status(self, queue_entry: QueueEntry):
193+
def _remove_from_cumulative_status(self, queue_entry: QueueEntry[ESOperation]):
194194
# Remove the item from the cumulative status dictionary.
195195
key = queue_entry.item.short_key() + (queue_entry.priority,)
196196
self.queue_status_cumulative[key]["item"]["multiplicity"] -= 1
@@ -223,7 +223,7 @@ def __init__(self, job: Job, job_success: bool):
223223
self.job_success = job_success
224224

225225

226-
class EvaluationService(TriggeredService[EvaluationExecutor]):
226+
class EvaluationService(TriggeredService[ESOperation, EvaluationExecutor]):
227227
"""Evaluation service.
228228
229229
"""

cms/service/PrintingService.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def to_dict(self):
6161
return {"printjob_id": self.printjob_id}
6262

6363

64-
class PrintingExecutor(Executor):
64+
class PrintingExecutor(Executor[PrintingOperation]):
6565
def __init__(self, file_cacher):
6666
super().__init__()
6767

@@ -72,7 +72,7 @@ def __init__(self, file_cacher):
7272
self.jinja2_env.filters["escape_tex_normal"] = escape_tex_normal
7373
self.jinja2_env.filters["escape_tex_tt"] = escape_tex_tt
7474

75-
def execute(self, entry: QueueEntry):
75+
def execute(self, entry: QueueEntry[PrintingOperation]):
7676
"""Print a print job.
7777
7878
This is the core of PrintingService.
@@ -208,12 +208,12 @@ def execute(self, entry: QueueEntry):
208208
rmtree(directory)
209209

210210

211-
class PrintingService(TriggeredService):
211+
class PrintingService(TriggeredService[PrintingOperation, PrintingExecutor]):
212212
"""A service that prepares print jobs and sends them to a printer.
213213
214214
"""
215215

216-
def __init__(self, shard):
216+
def __init__(self, shard: int):
217217
"""Initialize the PrintingService.
218218
219219
"""

cms/service/ProxyService.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def to_dict(self):
128128
"data": self.data}
129129

130130

131-
class ProxyExecutor(Executor):
131+
class ProxyExecutor(Executor[ProxyOperation]):
132132
"""A thread that sends data to one ranking.
133133
134134
The object is used as a thread-local storage and its run method is
@@ -182,7 +182,7 @@ def __init__(self, ranking: str):
182182
self._ranking = ranking
183183
self._visible_ranking = safe_url(ranking)
184184

185-
def execute(self, entries: list[QueueEntry]):
185+
def execute(self, entries: list[QueueEntry[ProxyOperation]]):
186186
"""Consume (i.e. send) the data put in the queue, forever.
187187
188188
Pick all operations found in the queue (if there aren't any,
@@ -230,7 +230,7 @@ def execute(self, entries: list[QueueEntry]):
230230
gevent.sleep(self.FAILURE_WAIT)
231231

232232

233-
class ProxyService(TriggeredService):
233+
class ProxyService(TriggeredService[ProxyOperation, ProxyExecutor]):
234234
"""Maintain the information held by rankings up-to-date.
235235
236236
Discover (by receiving notifications and by periodically sweeping

cms/service/ScoringService.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@
3838
logger = logging.getLogger(__name__)
3939

4040

41-
class ScoringExecutor(Executor):
41+
class ScoringExecutor(Executor[ScoringOperation]):
4242
def __init__(self, proxy_service):
4343
super().__init__()
4444
self.proxy_service = proxy_service
4545

46-
def execute(self, entry: QueueEntry):
46+
def execute(self, entry: QueueEntry[ScoringOperation]):
4747
"""Assign a score to a submission result.
4848
4949
This is the core of ScoringService: here we retrieve the result
@@ -113,7 +113,7 @@ def execute(self, entry: QueueEntry):
113113
submission_id=submission.id)
114114

115115

116-
class ScoringService(TriggeredService):
116+
class ScoringService(TriggeredService[ScoringOperation, ScoringExecutor]):
117117
"""A service that assigns a score to submission results.
118118
119119
A submission result is ready to be scored when its compilation is

cmstestsuite/unit_tests/io/triggeredservice_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def get_notifications(self):
4040
return self._notifications
4141

4242

43-
class FakeExecutor(Executor):
43+
class FakeExecutor(Executor[FakeQueueItem]):
4444
def __init__(self, notifier, batch_executions=False):
4545
super().__init__(batch_executions)
4646
self._notifier = notifier

0 commit comments

Comments
 (0)