-
Notifications
You must be signed in to change notification settings - Fork 261
Expand file tree
/
Copy pathcosmosdb_service.py
More file actions
202 lines (177 loc) · 7.08 KB
/
cosmosdb_service.py
File metadata and controls
202 lines (177 loc) · 7.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import uuid
from datetime import datetime, timezone
from typing import Any

from azure.cosmos import exceptions
from azure.cosmos.aio import CosmosClient
class CosmosConversationClient:
    """Async Cosmos DB store for chat conversations and their messages.

    Conversations and messages share one container and are distinguished by
    their ``type`` field (``"conversation"`` / ``"message"``). Point reads and
    deletes pass the user id as the partition key, and every query filters on
    ``c.userId``.
    """

    # Sort directions accepted by get_conversations. The direction cannot be a
    # query parameter, so it is interpolated into the SQL text and must be
    # restricted to known-safe values.
    _SORT_ORDERS = ("ASC", "DESC")

    def __init__(
        self,
        cosmosdb_endpoint: str,
        credential: Any,
        database_name: str,
        container_name: str,
        enable_message_feedback: bool = False,
    ):
        """Create the Cosmos DB client and the database/container proxies.

        Args:
            cosmosdb_endpoint: Cosmos DB account endpoint URL.
            credential: Credential passed to ``azure.cosmos.aio.CosmosClient``
                (presumably an account key or Azure AD credential; confirm
                against caller).
            database_name: Database holding the container.
            container_name: Container holding conversations and messages.
            enable_message_feedback: When True, new messages are created with
                an empty ``feedback`` field.

        Raises:
            ValueError: If client construction fails with a Cosmos HTTP error
                (401 -> "Invalid credentials", otherwise "Invalid CosmosDB
                endpoint"), or a database/container lookup reports not-found.
        """
        self.cosmosdb_endpoint = cosmosdb_endpoint
        self.credential = credential
        self.database_name = database_name
        self.container_name = container_name
        self.enable_message_feedback = enable_message_feedback
        try:
            self.cosmosdb_client = CosmosClient(
                self.cosmosdb_endpoint, credential=credential
            )
        except exceptions.CosmosHttpResponseError as e:
            if e.status_code == 401:
                raise ValueError("Invalid credentials") from e
            raise ValueError("Invalid CosmosDB endpoint") from e
        # NOTE(review): the get_*_client calls below are not awaited, so they
        # presumably perform no network I/O; missing resources are more likely
        # to be detected by ensure().
        try:
            self.database_client = self.cosmosdb_client.get_database_client(
                database_name
            )
        except exceptions.CosmosResourceNotFoundError as e:
            raise ValueError("Invalid CosmosDB database name") from e
        try:
            self.container_client = self.database_client.get_container_client(
                container_name
            )
        except exceptions.CosmosResourceNotFoundError as e:
            raise ValueError("Invalid CosmosDB container name") from e

    @staticmethod
    def _utc_now_iso():
        """Return the current UTC time as a naive ISO-8601 string.

        Produces the same format ``datetime.utcnow().isoformat()`` did (no
        ``+00:00`` suffix), so stored timestamps keep sorting consistently
        with existing documents, without the deprecated ``utcnow``.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    @classmethod
    def _build_conversations_query(cls, sort_order, limit, offset):
        """Build the conversation-listing query for get_conversations.

        Args:
            sort_order: ``"ASC"`` or ``"DESC"`` (case-insensitive).
            limit: Maximum rows to return, or None for no OFFSET/LIMIT clause.
            offset: Rows to skip; only used when ``limit`` is not None.

        Raises:
            ValueError: If ``sort_order`` is not ASC/DESC.
            TypeError/ValueError: If ``offset``/``limit`` are not integer-like.
        """
        direction = str(sort_order).upper()
        if direction not in cls._SORT_ORDERS:
            raise ValueError(f"Invalid sort order: {sort_order!r}")
        query = (
            "SELECT * FROM c where c.userId = @userId and c.type='conversation' "
            f"order by c.updatedAt {direction}"
        )
        if limit is not None:
            # int() coercion keeps arbitrary text out of the query string.
            query += f" offset {int(offset)} limit {int(limit)}"
        return query

    async def ensure(self):
        """Check that the client, database, and container are usable.

        Returns:
            ``(True, message)`` on success, otherwise ``(False, reason)``.
            Failures are reported through the tuple, never raised.
        """
        if not (
            self.cosmosdb_client and self.database_client and self.container_client
        ):
            return False, "CosmosDB client not initialized correctly"
        try:
            await self.database_client.read()
        except Exception:
            # Any failure (missing database, auth, network) is reported as
            # "not found"; this is a health-check boundary.
            return (
                False,
                f"CosmosDB database {self.database_name} on account {self.cosmosdb_endpoint} not found",
            )
        try:
            await self.container_client.read()
        except Exception:
            return False, f"CosmosDB container {self.container_name} not found"
        return True, "CosmosDB client initialized successfully"

    async def create_conversation(self, user_id, conversation_id=None, title=""):
        """Create (upsert) a conversation document owned by ``user_id``.

        Args:
            user_id: Owner of the conversation.
            conversation_id: Document id; a new UUID4 string is generated when
                omitted or None.
            title: Conversation title.

        Returns:
            The upserted document returned by Cosmos DB, or False when the
            response is empty.
        """
        # Generated per call: a default in the signature would be evaluated
        # once at definition time and shared by every call.
        if conversation_id is None:
            conversation_id = str(uuid.uuid4())
        now = self._utc_now_iso()
        conversation = {
            "id": conversation_id,
            "type": "conversation",
            "createdAt": now,
            "updatedAt": now,
            "userId": user_id,
            "title": title,
            "conversation_id": conversation_id,
        }
        # TODO: add some error handling based on the output of the upsert_item call
        resp = await self.container_client.upsert_item(conversation)
        return resp or False

    async def upsert_conversation(self, conversation):
        """Insert or replace a conversation document.

        Returns:
            The upserted document, or False when the response is empty.
        """
        resp = await self.container_client.upsert_item(conversation)
        return resp or False

    async def delete_conversation(self, user_id, conversation_id):
        """Delete a conversation document (its messages are not touched).

        Returns:
            The ``delete_item`` response when the document existed, or True
            when there was nothing to delete.
        """
        try:
            conversation = await self.container_client.read_item(
                item=conversation_id, partition_key=user_id
            )
        except exceptions.CosmosResourceNotFoundError:
            # read_item raises instead of returning None for a missing item;
            # an already-absent conversation counts as deleted.
            return True
        if not conversation:
            return True
        return await self.container_client.delete_item(
            item=conversation_id, partition_key=user_id
        )

    async def delete_messages(self, conversation_id, user_id):
        """Delete every message belonging to a conversation, one at a time.

        Note the argument order (conversation_id, user_id) is the reverse of
        most other methods in this class.

        Returns:
            A list with one ``delete_item`` response per message (empty when
            the conversation has no messages).
        """
        messages = await self.get_messages(user_id, conversation_id)
        response_list = []
        for message in messages:
            resp = await self.container_client.delete_item(
                item=message["id"], partition_key=user_id
            )
            response_list.append(resp)
        return response_list

    async def get_conversations(self, user_id, limit, sort_order="DESC", offset=0):
        """List a user's conversations ordered by ``updatedAt``.

        Args:
            user_id: Owner whose conversations are listed.
            limit: Maximum number to return, or None for all.
            sort_order: ``"ASC"`` or ``"DESC"`` (case-insensitive).
            offset: Number of conversations to skip when ``limit`` is set.

        Returns:
            A list of conversation documents.

        Raises:
            ValueError: If ``sort_order`` is not ASC/DESC.
        """
        parameters = [{"name": "@userId", "value": user_id}]
        query = self._build_conversations_query(sort_order, limit, offset)
        return [
            item
            async for item in self.container_client.query_items(
                query=query, parameters=parameters
            )
        ]

    async def get_conversation(self, user_id, conversation_id):
        """Fetch one conversation by id for ``user_id``.

        Returns:
            The first matching conversation document, or None if none match.
        """
        parameters = [
            {"name": "@conversationId", "value": conversation_id},
            {"name": "@userId", "value": user_id},
        ]
        query = "SELECT * FROM c where c.id = @conversationId and c.type='conversation' and c.userId = @userId"
        conversations = [
            item
            async for item in self.container_client.query_items(
                query=query, parameters=parameters
            )
        ]
        return conversations[0] if conversations else None

    async def create_message(self, uuid, conversation_id, user_id, input_message: dict):
        """Store a message and bump its conversation's ``updatedAt``.

        Args:
            uuid: Id for the new message document. (This parameter shadows the
                ``uuid`` module inside this method; the name is kept for
                callers that pass it by keyword.)
            conversation_id: Conversation the message belongs to.
            user_id: Owner of the conversation.
            input_message: Message payload; must contain a ``"role"`` key and
                is stored whole under ``content``.

        Returns:
            The upserted message; the string ``"Conversation not found"`` if
            the parent conversation is missing (the message is still stored);
            or False when the upsert response is empty.
        """
        now = self._utc_now_iso()
        message = {
            "id": uuid,
            "type": "message",
            "userId": user_id,
            "createdAt": now,
            "updatedAt": now,
            "conversationId": conversation_id,
            "role": input_message["role"],
            "content": input_message,
        }
        if self.enable_message_feedback:
            message["feedback"] = ""
        resp = await self.container_client.upsert_item(message)
        if not resp:
            return False
        # Update the parent conversation's updatedAt with this message's createdAt.
        conversation = await self.get_conversation(user_id, conversation_id)
        if not conversation:
            return "Conversation not found"
        conversation["updatedAt"] = message["createdAt"]
        await self.upsert_conversation(conversation)
        return resp

    async def update_message_feedback(self, user_id, message_id, feedback):
        """Set the ``feedback`` field on an existing message.

        Returns:
            The upserted message, or False if the message does not exist.
        """
        try:
            message = await self.container_client.read_item(
                item=message_id, partition_key=user_id
            )
        except exceptions.CosmosResourceNotFoundError:
            # read_item raises for a missing item rather than returning None.
            return False
        if not message:
            return False
        message["feedback"] = feedback
        return await self.container_client.upsert_item(message)

    async def get_messages(self, user_id, conversation_id):
        """List a conversation's messages in creation order (oldest first).

        Returns:
            A list of message documents.
        """
        parameters = [
            {"name": "@conversationId", "value": conversation_id},
            {"name": "@userId", "value": user_id},
        ]
        # Order by createdAt, the timestamp create_message actually writes.
        query = "SELECT * FROM c WHERE c.conversationId = @conversationId AND c.type='message' AND c.userId = @userId ORDER BY c.createdAt ASC"
        return [
            item
            async for item in self.container_client.query_items(
                query=query, parameters=parameters
            )
        ]