"""
Local Queue implementation for SchedulerMessageItem objects.
This module provides a local-based queue implementation that can replace
the local memos_message_queue functionality in BaseScheduler.
"""
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Callable

from memos.log import get_logger
from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.mem_scheduler.schemas.task_schemas import DEFAULT_STREAM_KEY_PREFIX
from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator
from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker
from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule

logger = get_logger(__name__)
class SchedulerLocalQueue(RedisSchedulerModule):
    def __init__(
        self,
        maxsize: int = 0,
        stream_key_prefix: str = DEFAULT_STREAM_KEY_PREFIX,
        orchestrator: SchedulerOrchestrator | None = None,
        status_tracker: TaskStatusTracker | None = None,
    ):
        """
        Initialize the SchedulerLocalQueue with a maximum queue size limit.

        Arguments match SchedulerRedisQueue for compatibility.

        Args:
            maxsize (int): Maximum number of messages allowed in each individual queue.
            stream_key_prefix (str): Prefix for stream keys (simulated).
            orchestrator: SchedulerOrchestrator instance (ignored).
            status_tracker: TaskStatusTracker instance (ignored).
        """
        super().__init__()
        self.stream_key_prefix = stream_key_prefix or "local_queue"
        self.max_internal_message_queue_size = maxsize
        # Per-stream queues: key = stream_key, value = Queue[ScheduleMessageItem]
        self.queue_streams: dict[str, Queue[ScheduleMessageItem]] = {}
        self.orchestrator = orchestrator
        self.status_tracker = status_tracker
        self._is_listening = False
        self._message_handler: Callable[[ScheduleMessageItem], None] | None = None
        logger.info(
            f"SchedulerLocalQueue initialized with "
            f"max_internal_message_queue_size={self.max_internal_message_queue_size}"
        )
    def get_stream_key(self, user_id: str, mem_cube_id: str, task_label: str) -> str:
        """Build the stream key that identifies the queue for this user/cube/label."""
        return f"{self.stream_key_prefix}:{user_id}:{mem_cube_id}:{task_label}"
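
    # For example (illustrative values): with stream_key_prefix "local_queue",
    # get_stream_key(user_id="u1", mem_cube_id="cube1", task_label="add")
    # returns "local_queue:u1:cube1:add".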
    def put(
        self, message: ScheduleMessageItem, block: bool = True, timeout: float | None = None
    ) -> None:
        """
        Put a message into the appropriate internal queue based on user_id,
        mem_cube_id, and task label.

        If the corresponding queue does not exist, it is created automatically.
        This method uses a local in-memory queue (not Redis) for buffering messages.

        Args:
            message (ScheduleMessageItem): The message to enqueue.
            block (bool): If True, block if the queue is full; if False, raise Full immediately.
            timeout (float | None): Maximum time to wait for the queue to become available.
                If None, block indefinitely. Ignored if block=False.

        Raises:
            queue.Full: If the queue is full and block=False or the timeout expires.
            Exception: Any underlying error during the queue.put() operation.
        """
        stream_key = self.get_stream_key(
            user_id=message.user_id, mem_cube_id=message.mem_cube_id, task_label=message.label
        )
        message.stream_key = stream_key

        # Create the queue if it doesn't exist yet
        if stream_key not in self.queue_streams:
            logger.info(f"Creating new internal queue for stream: {stream_key}")
            self.queue_streams[stream_key] = Queue(maxsize=self.max_internal_message_queue_size)

        try:
            self.queue_streams[stream_key].put(item=message, block=block, timeout=timeout)
            logger.debug(
                "Local queue enqueued. stream=%s size=%s label=%s item_id=%s",
                stream_key,
                self.queue_streams[stream_key].qsize(),
                message.label,
                message.item_id,
            )
        except Exception as e:
            logger.error(f"Failed to put message into queue '{stream_key}': {e}", exc_info=True)
            raise  # Re-raise to maintain caller expectations
    def get(
        self,
        stream_key: str,
        block: bool = True,
        timeout: float | None = None,
        batch_size: int | None = 1,
    ) -> list[ScheduleMessageItem]:
        """
        Retrieve up to batch_size messages from the given stream.

        Args:
            stream_key (str): The stream/queue identifier.
            block (bool): If True, wait for messages; if False, return immediately.
            timeout (float | None): Maximum time to wait when blocking.
            batch_size (int | None): Number of messages to retrieve. If None,
                retrieves one message.

        Returns:
            list[ScheduleMessageItem]: Retrieved messages, or an empty list if the
                stream does not exist or batch_size is invalid.
        """
        if batch_size is not None and batch_size <= 0:
            logger.warning(
                f"get() called with invalid batch_size: {batch_size}. Returning empty list."
            )
            return []

        # Return an empty list if the queue does not exist
        if stream_key not in self.queue_streams:
            logger.debug("Stream %s does not exist when trying to get messages", stream_key)
            return []

        # Always request a batch so the underlying queue returns a list
        effective_batch_size = batch_size if batch_size is not None else 1
        # Note: assumes the AutoDroppingQueue implementation supports a batch_size parameter
        res = self.queue_streams[stream_key].get(
            block=block, timeout=timeout, batch_size=effective_batch_size
        )
        logger.debug(
            f"Retrieved {len(res)} messages from queue '{stream_key}'. "
            f"Current size: {self.queue_streams[stream_key].qsize()}"
        )
        return res
    def get_nowait(self, stream_key: str, batch_size: int | None = 1) -> list[ScheduleMessageItem]:
        """
        Non-blocking version of get(). Equivalent to
        get(stream_key, block=False, batch_size=batch_size). Returns immediately
        with any available messages, or an empty list if the queue is empty.

        Args:
            stream_key (str): The stream/queue identifier.
            batch_size (int | None): Number of messages to retrieve in a batch.
                If None, retrieves one message.

        Returns:
            list[ScheduleMessageItem]: Retrieved messages, or an empty list if the
                queue is empty.
        """
        logger.debug(f"get_nowait() called for {stream_key} with batch_size: {batch_size}")
        return self.get(stream_key=stream_key, block=False, batch_size=batch_size)
    def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]:
        """
        Get messages from all streams, greedily draining each in turn.

        Equivalent to SchedulerRedisQueue.get_messages. The Redis implementation
        uses a more elaborate fairness strategy; locally we keep it simple and
        take whatever is available (non-blocking) from each stream until
        batch_size messages have been collected.
        """
        messages = []
        # Snapshot keys to avoid issues if streams are added while iterating
        stream_keys = list(self.queue_streams.keys())
        for stream_key in stream_keys:
            if len(messages) >= batch_size:
                break
            needed = batch_size - len(messages)
            # Use get_nowait to avoid blocking on empty streams
            fetched = self.get_nowait(stream_key=stream_key, batch_size=needed)
            messages.extend(fetched)

        if messages:
            logger.debug(
                "Local queue dequeued batch. batch_size=%s requested_batch_size=%s active_streams=%s",
                len(messages),
                batch_size,
                len(stream_keys),
            )
        return messages
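
    # Illustrative example (hypothetical counts): with two streams holding 3 and
    # 2 messages respectively, get_messages(batch_size=4) drains 3 from the
    # first stream, then takes 1 from the second, and returns 4 messages.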
    def qsize(self) -> dict:
        """
        Return the current size of all internal queues as a dictionary.

        Each key is the stream name, and each value is the number of messages in
        that queue. A 'total_size' entry holds the sum across all streams.

        Returns:
            dict[str, int]: Mapping from stream name to current queue size.
        """
        sizes = {stream: queue.qsize() for stream, queue in self.queue_streams.items()}
        sizes["total_size"] = sum(sizes.values())
        logger.debug(f"Current queue sizes: {sizes}")
        return sizes
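
    # Shape of the returned mapping (illustrative values):
    #   {"local_queue:u1:cube1:add": 2, "local_queue:u2:cube1:add": 1, "total_size": 3}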
    def clear(self, stream_key: str | None = None) -> None:
        """Clear one stream if stream_key is given, otherwise clear all streams."""
        if stream_key:
            if stream_key in self.queue_streams:
                self.queue_streams[stream_key].clear()
                logger.info("Cleared local queue stream: %s", stream_key)
        else:
            for queue in self.queue_streams.values():
                queue.clear()
            logger.info(
                "Cleared all local queue streams. stream_count=%s", len(self.queue_streams)
            )
    @property
    def unfinished_tasks(self) -> int:
        """
        Calculate the total number of unprocessed messages across all queues.

        This is a convenience property for monitoring overall system load.

        Returns:
            int: Sum of all message counts in all internal queues.
        """
        # Sum per-queue sizes directly; qsize() also contains a "total_size"
        # entry, so summing its values would double count.
        total = sum(queue.qsize() for queue in self.queue_streams.values())
        logger.debug(f"Total unfinished tasks across all queues: {total}")
        return total
    def get_stream_keys(self, stream_key_prefix: str | None = None) -> list[str]:
        """Return the list of active stream keys matching the given prefix."""
        prefix = stream_key_prefix or self.stream_key_prefix
        return [k for k in self.queue_streams if k.startswith(prefix)]

    def size(self) -> int:
        """Total size of all queues."""
        return sum(q.qsize() for q in self.queue_streams.values())

    def empty(self) -> bool:
        """Check if all queues are empty."""
        return self.size() == 0

    def full(self) -> bool:
        """Check if any queue is full (approximate)."""
        if self.max_internal_message_queue_size <= 0:
            return False
        return any(
            q.qsize() >= self.max_internal_message_queue_size
            for q in self.queue_streams.values()
        )
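

# Minimal usage sketch (illustrative only). The ScheduleMessageItem field names
# below (user_id, mem_cube_id, label, content) are assumptions inferred from the
# attributes this module reads; adjust them to the actual schema.
if __name__ == "__main__":
    q = SchedulerLocalQueue(maxsize=100)
    msg = ScheduleMessageItem(user_id="u1", mem_cube_id="cube1", label="add", content="hello")
    q.put(msg)

    key = q.get_stream_key(user_id="u1", mem_cube_id="cube1", task_label="add")
    print(q.qsize())          # e.g. {"local_queue:u1:cube1:add": 1, "total_size": 1}
    print(q.get_nowait(key))  # [msg]
    print(q.empty())          # True once the message has been consumed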