Skip to content

Commit cbe701d

Browse files
committed
fix (tasks): handling encryption in prod db
1 parent 8d840bc commit cbe701d

5 files changed

Lines changed: 253 additions & 201 deletions

File tree

src/server/main/db.py

Lines changed: 10 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -12,50 +12,15 @@
1212

1313
# Import config from the current 'main' directory
1414
from main.config import MONGO_URI, MONGO_DB_NAME, ENVIRONMENT
15-
from main.auth.utils import aes_encrypt, aes_decrypt
15+
from workers.utils.crypto import encrypt_doc, decrypt_doc, encrypt_field, decrypt_field
1616

1717
DB_ENCRYPTION_ENABLED = ENVIRONMENT == 'stag'
1818

19-
def _datetime_serializer(obj):
20-
"""JSON serializer for objects not serializable by default json code, like datetime."""
21-
if isinstance(obj, datetime.datetime):
22-
return obj.isoformat()
23-
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
24-
25-
def _encrypt_field(data: Any) -> Any:
26-
if not DB_ENCRYPTION_ENABLED or data is None:
27-
return data
28-
data_str = json.dumps(data, default=_datetime_serializer)
29-
return aes_encrypt(data_str)
30-
31-
def _decrypt_field(data: Any) -> Any:
32-
if not DB_ENCRYPTION_ENABLED or data is None or not isinstance(data, str):
33-
return data
34-
try:
35-
decrypted_str = aes_decrypt(data)
36-
return json.loads(decrypted_str)
37-
except Exception:
38-
return data
39-
40-
def _encrypt_doc(doc: Dict, fields: List[str]):
41-
if not DB_ENCRYPTION_ENABLED or not doc:
42-
return
43-
for field in fields:
44-
if field in doc and doc[field] is not None:
45-
doc[field] = _encrypt_field(doc[field])
46-
47-
def _decrypt_doc(doc: Optional[Dict], fields: List[str]):
48-
if not DB_ENCRYPTION_ENABLED or not doc:
49-
return
50-
for field in fields:
51-
if field in doc and doc[field] is not None:
52-
doc[field] = _decrypt_field(doc[field])
53-
5419
def _decrypt_docs(docs: List[Dict], fields: List[str]):
5520
if not DB_ENCRYPTION_ENABLED or not docs:
5621
return
5722
for doc in docs:
58-
_decrypt_doc(doc, fields)
23+
decrypt_doc(doc, fields)
5924

6025
USER_PROFILES_COLLECTION = "user_profiles"
6126
NOTIFICATIONS_COLLECTION = "notifications"
@@ -144,7 +109,7 @@ async def get_user_profile(self, user_id: str) -> Optional[Dict]:
144109
SENSITIVE_USER_DATA_FIELDS = ["onboardingAnswers", "personalInfo", "pwa_subscription", "privacyFilters"]
145110
for field in SENSITIVE_USER_DATA_FIELDS:
146111
if field in user_data and user_data[field] is not None:
147-
user_data[field] = _decrypt_field(user_data[field])
112+
user_data[field] = decrypt_field(user_data[field])
148113
return doc
149114

150115
async def update_user_profile(self, user_id: str, profile_data: Dict) -> bool:
@@ -156,7 +121,7 @@ async def update_user_profile(self, user_id: str, profile_data: Dict) -> bool:
156121
data_to_update = profile_data.copy()
157122
for key, value in profile_data.items():
158123
if key.startswith("userData.") and key.split('.')[1] in SENSITIVE_USER_DATA_FIELDS:
159-
data_to_update[key] = _encrypt_field(value)
124+
data_to_update[key] = encrypt_field(value)
160125
profile_data = data_to_update
161126

162127
update_operations = {"$set": {}, "$setOnInsert": {}}
@@ -241,7 +206,7 @@ async def get_notifications(self, user_id: str) -> List[Dict]:
241206
for notification in notifications_list:
242207
for field in SENSITIVE_NOTIFICATION_FIELDS:
243208
if field in notification and notification[field] is not None:
244-
notification[field] = _decrypt_field(notification[field])
209+
notification[field] = decrypt_field(notification[field])
245210

246211
# Serialize datetime objects before returning, as they are not JSON-serializable by default.
247212
for notification in notifications_list:
@@ -260,7 +225,7 @@ async def add_notification(self, user_id: str, notification_data: Dict) -> Optio
260225
SENSITIVE_NOTIFICATION_FIELDS = ["message", "suggestion_payload"]
261226
for field in SENSITIVE_NOTIFICATION_FIELDS:
262227
if field in notification_data and notification_data[field] is not None:
263-
notification_data[field] = _encrypt_field(notification_data[field])
228+
notification_data[field] = encrypt_field(notification_data[field])
264229

265230
result = await self.notifications_collection.update_one(
266231
{"user_id": user_id},
@@ -393,7 +358,7 @@ async def add_task(self, user_id: str, task_data: dict) -> str:
393358
"swarm_details": task_data.get("swarm_details") # Will be None for single tasks
394359
}
395360
SENSITIVE_TASK_FIELDS = ["name", "description", "plan", "runs", "original_context", "chat_history", "error", "clarifying_questions", "result", "swarm_details"]
396-
_encrypt_doc(task_doc, SENSITIVE_TASK_FIELDS)
361+
encrypt_doc(task_doc, SENSITIVE_TASK_FIELDS)
397362

398363
await self.task_collection.insert_one(task_doc)
399364
logger.info(f"Created new task {task_id} (type: {task_doc['task_type']}) for user {user_id} with status 'planning'.")
@@ -403,7 +368,7 @@ async def get_task(self, task_id: str, user_id: str) -> Optional[Dict]:
403368
"""Fetches a single task by its ID, ensuring it belongs to the user."""
404369
doc = await self.task_collection.find_one({"task_id": task_id, "user_id": user_id})
405370
SENSITIVE_TASK_FIELDS = ["name", "description", "plan", "runs", "original_context", "chat_history", "error", "clarifying_questions", "result", "swarm_details"]
406-
_decrypt_doc(doc, SENSITIVE_TASK_FIELDS)
371+
decrypt_doc(doc, SENSITIVE_TASK_FIELDS)
407372
return doc
408373

409374
async def get_all_tasks_for_user(self, user_id: str) -> List[Dict]:
@@ -418,7 +383,7 @@ async def update_task(self, task_id: str, updates: Dict) -> bool:
418383
"""Updates an existing task document."""
419384
updates["updated_at"] = datetime.datetime.now(datetime.timezone.utc)
420385
SENSITIVE_TASK_FIELDS = ["name", "description", "plan", "runs", "original_context", "chat_history", "error", "clarifying_questions", "result", "swarm_details"]
421-
_encrypt_doc(updates, SENSITIVE_TASK_FIELDS)
386+
encrypt_doc(updates, SENSITIVE_TASK_FIELDS)
422387
result = await self.task_collection.update_one(
423388
{"task_id": task_id},
424389
{"$set": updates}
@@ -552,7 +517,7 @@ async def add_message(self, user_id: str, role: str, content: str, message_id: O
552517
message_doc["tool_results"] = tool_results
553518

554519
SENSITIVE_MESSAGE_FIELDS = ["content", "thoughts", "tool_calls", "tool_results"]
555-
_encrypt_doc(message_doc, SENSITIVE_MESSAGE_FIELDS)
520+
encrypt_doc(message_doc, SENSITIVE_MESSAGE_FIELDS)
556521

557522
await self.messages_collection.insert_one(message_doc)
558523
logger.info(f"Added message for user {user_id} with role {role}")

src/server/main/tasks/routes.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -358,23 +358,28 @@ async def task_chat(
358358
if not task:
359359
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Task not found.")
360360

361-
# Append user message to top-level chat history
361+
# Decryption is handled by get_task, so we can safely modify the list
362362
new_message = {
363363
"role": "user",
364364
"content": request.message,
365365
"timestamp": datetime.now(timezone.utc)
366366
}
367367

368-
# Put the task back into planning state for re-evaluation.
369-
# The planner will see the chat history and the original context.
370-
# It will NOT create a new run. It will overwrite the top-level plan.
371-
await mongo_manager.task_collection.update_one(
372-
{"task_id": task_id},
373-
{
374-
"$set": {"status": "planning"},
375-
"$push": {"chat_history": new_message}
376-
}
377-
)
368+
chat_history = task.get("chat_history", [])
369+
if not isinstance(chat_history, list):
370+
# This can happen if the field was somehow corrupted or is in an old format
371+
logger.warning(f"Task {task_id} chat_history is not a list. Resetting it.")
372+
chat_history = []
373+
chat_history.append(new_message)
374+
375+
# Prepare the payload for a full field update
376+
update_payload = {
377+
"status": "planning", # Revert to planning for re-evaluation
378+
"chat_history": chat_history
379+
}
380+
381+
# The update_task method will handle re-encrypting the entire chat_history field
382+
await mongo_manager.update_task(task_id, update_payload)
378383

379384
# Re-trigger the planner for the same task
380385
generate_plan_from_context.delay(task_id, user_id)

0 commit comments

Comments
 (0)