-
Notifications
You must be signed in to change notification settings - Fork 742
Expand file tree
/
Copy pathtask_queue.py
More file actions
179 lines (154 loc) · 7.57 KB
/
task_queue.py
File metadata and controls
179 lines (154 loc) · 7.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""
Redis Queue implementation for SchedulerMessageItem objects.
This module provides a Redis-based queue implementation that can replace
the local memos_message_queue functionality in BaseScheduler.
"""
from memos.context.context import get_current_trace_id
from memos.log import get_logger
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue
from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator
from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue
from memos.mem_scheduler.utils.db_utils import get_utc_now
from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube
from memos.mem_scheduler.utils.monitor_event_utils import emit_monitor_event, to_iso
from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker
logger = get_logger(__name__)
class ScheduleTaskQueue:
    """Facade over the scheduler's message-queue backends.

    Depending on ``use_redis_queue``, wraps either a Redis-stream-backed queue
    (``SchedulerRedisQueue``) or an in-process queue (``SchedulerLocalQueue``)
    behind a single submit/get/ack interface.
    """

    def __init__(
        self,
        use_redis_queue: bool,
        maxsize: int | None,
        disabled_handlers: list | None = None,
        orchestrator: SchedulerOrchestrator | None = None,
        status_tracker: TaskStatusTracker | None = None,
    ):
        """Initialize the task queue.

        Args:
            use_redis_queue: If True, back the queue with Redis streams;
                otherwise use an in-process local queue.
            maxsize: Maximum queue length. For the Redis backend, any
                non-positive or non-int value is treated as "unbounded".
            disabled_handlers: Labels whose messages are silently dropped
                at submit time (logged at DEBUG level).
            orchestrator: Optional pre-built orchestrator; a fresh
                ``SchedulerOrchestrator`` is created when omitted.
            status_tracker: Optional task status tracker, propagated to the
                Redis backend so task states can be reported.
        """
        self.use_redis_queue = use_redis_queue
        self.maxsize = maxsize
        self.orchestrator = SchedulerOrchestrator() if orchestrator is None else orchestrator
        self.status_tracker = status_tracker
        if self.use_redis_queue:
            # Normalize invalid sizes to None so the Redis stream is unbounded
            # rather than created with a nonsensical max length.
            if maxsize is None or not isinstance(maxsize, int) or maxsize <= 0:
                maxsize = None
            self.memos_message_queue = SchedulerRedisQueue(
                max_len=maxsize,
                consumer_group="scheduler_group",
                consumer_name="scheduler_consumer",
                orchestrator=self.orchestrator,
                status_tracker=self.status_tracker,  # Propagate status_tracker
            )
        else:
            self.memos_message_queue = SchedulerLocalQueue(maxsize=self.maxsize)
        self.disabled_handlers = disabled_handlers

    def set_status_tracker(self, status_tracker: TaskStatusTracker) -> None:
        """
        Set the status tracker for this queue and propagate it to the underlying queue implementation.
        This allows the tracker to be injected after initialization (e.g., when Redis connection becomes available).
        """
        self.status_tracker = status_tracker
        if self.memos_message_queue and hasattr(self.memos_message_queue, "status_tracker"):
            # SchedulerRedisQueue exposes a status_tracker attribute;
            # SchedulerLocalQueue can also accept it dynamically if it doesn't use __slots__.
            self.memos_message_queue.status_tracker = status_tracker
            logger.info("Propagated status_tracker to underlying message queue")

    def ack_message(
        self,
        user_id: str,
        mem_cube_id: str,
        task_label: str,
        redis_message_id,
        message: ScheduleMessageItem | None,
    ) -> None:
        """Acknowledge a consumed message on the Redis backend.

        No-op (with a warning) when the local queue backend is in use, since
        local queues have no acknowledgement semantics.
        """
        if not isinstance(self.memos_message_queue, SchedulerRedisQueue):
            logger.warning("ack_message is only supported for Redis queues")
            return
        self.memos_message_queue.ack_message(
            user_id=user_id,
            mem_cube_id=mem_cube_id,
            task_label=task_label,
            redis_message_id=redis_message_id,
            message=message,
        )

    def get_stream_keys(self) -> list[str]:
        """Return the stream keys currently known to the active backend."""
        if isinstance(self.memos_message_queue, SchedulerRedisQueue):
            stream_keys = self.memos_message_queue.get_stream_keys()
        else:
            stream_keys = list(self.memos_message_queue.queue_streams.keys())
        return stream_keys

    def _enqueue_message(self, message: ScheduleMessageItem) -> bool:
        """Stamp, filter, monitor, and enqueue a single message.

        Fills a missing timestamp, drops messages whose handler label is
        disabled (logged at DEBUG), emits an "enqueue" monitor event, and
        puts the message on the underlying queue.

        Returns:
            True if the message was enqueued, False if it was skipped
            because its handler is disabled.
        """
        if getattr(message, "timestamp", None) is None:
            message.timestamp = get_utc_now()
        if self.disabled_handlers and message.label in self.disabled_handlers:
            logger.debug(
                "Skip disabled handler. label=%s item_id=%s user_id=%s mem_cube_id=%s",
                message.label,
                message.item_id,
                message.user_id,
                message.mem_cube_id,
            )
            return False
        enqueue_ts = to_iso(getattr(message, "timestamp", None))
        emit_monitor_event(
            "enqueue",
            message,
            {"enqueue_ts": enqueue_ts, "event_duration_ms": 0, "total_duration_ms": 0},
        )
        self.memos_message_queue.put(message)
        return True

    def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]):
        """Submit messages to the message queue (either local queue or Redis).

        Accepts a single message or a list. Every message is validated,
        tagged with the current trace id (when available) and its stream key,
        then enqueued via :meth:`_enqueue_message`. Batches are grouped by
        (user_id, mem_cube_id) before enqueueing.

        Raises:
            TypeError: If any item is not a ScheduleMessageItem. Validation
                is done up front, before any message is mutated or enqueued.
        """
        if isinstance(messages, ScheduleMessageItem):
            messages = [messages]
        if not messages:
            logger.error("submit_messages called with empty payload")
            return
        # Fail fast on bad payloads so no partial batch is ever enqueued.
        for msg in messages:
            if not isinstance(msg, ScheduleMessageItem):
                error_msg = f"Invalid message type: {type(msg)}, expected ScheduleMessageItem"
                logger.error(error_msg)
                raise TypeError(error_msg)
        current_trace_id = get_current_trace_id()
        for msg in messages:
            if current_trace_id:
                # Prefer current request trace_id so logs can be correlated
                msg.trace_id = current_trace_id
            msg.stream_key = self.memos_message_queue.get_stream_key(
                user_id=msg.user_id, mem_cube_id=msg.mem_cube_id, task_label=msg.label
            )
        if len(messages) == 1:
            self._enqueue_message(messages[0])
        else:
            user_cube_groups = group_messages_by_user_and_mem_cube(messages)
            # Process each user and mem_cube combination
            for cube_groups in user_cube_groups.values():
                for user_cube_msgs in cube_groups.values():
                    for message in user_cube_msgs:
                        self._enqueue_message(message)
        logger.info(
            "Queue submit completed. backend=%s total=%s",
            "redis_queue" if self.use_redis_queue else "local_queue",
            len(messages),
        )

    def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]:
        """Fetch up to ``batch_size`` messages from the underlying queue."""
        return self.memos_message_queue.get_messages(batch_size=batch_size)

    def clear(self):
        """Remove all pending messages from the underlying queue."""
        self.memos_message_queue.clear()

    def qsize(self):
        """Return the number of pending messages in the underlying queue."""
        return self.memos_message_queue.qsize()