Skip to content

Commit 2d677bf

Browse files
authored
Encryption added and other improvements
1 parent c323025 commit 2d677bf

1 file changed

Lines changed: 212 additions & 43 deletions

File tree

filters/time_token_tracker.py

Lines changed: 212 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,119 @@
2222
import base64
2323
import hashlib
2424
import datetime
25-
import requests
2625
import os
27-
from typing import Optional
26+
import logging
27+
import aiohttp
28+
from typing import Optional, Any
29+
from open_webui.env import AIOHTTP_CLIENT_TIMEOUT, SRC_LOG_LEVELS
30+
from cryptography.fernet import Fernet, InvalidToken
2831
import tiktoken
29-
from pydantic import BaseModel, Field
32+
from pydantic import BaseModel, Field, GetCoreSchemaHandler
33+
from pydantic_core import core_schema
3034

3135
# Global variables to track start time and token counts
3236
global start_time, request_token_count, response_token_count
37+
start_time = 0
38+
request_token_count = 0
39+
response_token_count = 0
40+
41+
# Simplified encryption implementation with automatic handling
42+
class EncryptedStr(str):
43+
"""A string type that automatically handles encryption/decryption"""
44+
45+
@classmethod
46+
def _get_encryption_key(cls) -> Optional[bytes]:
47+
"""
48+
Generate encryption key from WEBUI_SECRET_KEY if available
49+
Returns None if no key is configured
50+
"""
51+
secret = os.getenv("WEBUI_SECRET_KEY")
52+
if not secret:
53+
return None
54+
55+
hashed_key = hashlib.sha256(secret.encode()).digest()
56+
return base64.urlsafe_b64encode(hashed_key)
57+
58+
@classmethod
59+
def encrypt(cls, value: str) -> str:
60+
"""
61+
Encrypt a string value if a key is available
62+
Returns the original value if no key is available
63+
"""
64+
if not value or value.startswith("encrypted:"):
65+
return value
66+
67+
key = cls._get_encryption_key()
68+
if not key: # No encryption if no key
69+
return value
70+
71+
f = Fernet(key)
72+
encrypted = f.encrypt(value.encode())
73+
return f"encrypted:{encrypted.decode()}"
74+
75+
@classmethod
76+
def decrypt(cls, value: str) -> str:
77+
"""
78+
Decrypt an encrypted string value if a key is available
79+
Returns the original value if no key is available or decryption fails
80+
"""
81+
if not value or not value.startswith("encrypted:"):
82+
return value
83+
84+
key = cls._get_encryption_key()
85+
if not key: # No decryption if no key
86+
return value[len("encrypted:") :] # Return without prefix
3387

88+
try:
89+
encrypted_part = value[len("encrypted:") :]
90+
f = Fernet(key)
91+
decrypted = f.decrypt(encrypted_part.encode())
92+
return decrypted.decode()
93+
except (InvalidToken, Exception):
94+
return value
95+
96+
# Pydantic integration
97+
@classmethod
98+
def __get_pydantic_core_schema__(
99+
cls, _source_type: Any, _handler: GetCoreSchemaHandler
100+
) -> core_schema.CoreSchema:
101+
return core_schema.union_schema(
102+
[
103+
core_schema.is_instance_schema(cls),
104+
core_schema.chain_schema(
105+
[
106+
core_schema.str_schema(),
107+
core_schema.no_info_plain_validator_function(
108+
lambda value: cls(cls.encrypt(value) if value else value)
109+
),
110+
]
111+
),
112+
],
113+
serialization=core_schema.plain_serializer_function_ser_schema(
114+
lambda instance: str(instance)
115+
),
116+
)
117+
118+
def get_decrypted(self) -> str:
119+
"""Get the decrypted value"""
120+
return self.decrypt(self)
121+
122+
# Helper functions
123+
async def cleanup_response(
124+
response: Optional[aiohttp.ClientResponse],
125+
session: Optional[aiohttp.ClientSession],
126+
) -> None:
127+
"""
128+
Clean up the response and session objects.
129+
130+
Args:
131+
response: The ClientResponse object to close
132+
session: The ClientSession object to close
133+
"""
134+
if response:
135+
response.close()
136+
if session:
137+
await session.close()
34138

35139
class Filter:
36140
class Valves(BaseModel):
@@ -54,9 +158,15 @@ class Valves(BaseModel):
54158
SHOW_TOKENS_PER_SECOND: bool = Field(
55159
default=True, description="Show tokens per second for the response."
56160
)
57-
SEND_TO_LOG_ANALYTICS: bool = os.getenv("SEND_TO_LOG_ANALYTICS", "False")
58-
LOG_ANALYTICS_WORKSPACE_ID: str = os.getenv("LOG_ANALYTICS_WORKSPACE_ID", "")
59-
LOG_ANALYTICS_SHARED_KEY: str = os.getenv("LOG_ANALYTICS_SHARED_KEY", "")
161+
SEND_TO_LOG_ANALYTICS: bool = Field(
162+
default=bool(os.getenv("SEND_TO_LOG_ANALYTICS", False)), description="Send logs to Azure Log Analytics workspace"
163+
)
164+
LOG_ANALYTICS_WORKSPACE_ID: str = Field(
165+
default=os.getenv("LOG_ANALYTICS_WORKSPACE_ID", ""), description="Azure Log Analytics Workspace ID"
166+
)
167+
LOG_ANALYTICS_SHARED_KEY: EncryptedStr = Field(
168+
default=os.getenv("LOG_ANALYTICS_SHARED_KEY", ""), description="Azure Log Analytics Workspace Shared Key"
169+
)
60170
LOG_ANALYTICS_LOG_TYPE: str = Field(
61171
default="OpenWebuiMetrics", description="Log Analytics log type name."
62172
)
@@ -80,30 +190,33 @@ def _build_signature(self, date, content_length, method, content_type, resource)
80190
+ resource
81191
)
82192
bytes_to_hash = string_to_hash.encode("utf-8")
83-
decoded_key = base64.b64decode(self.valves.LOG_ANALYTICS_SHARED_KEY)
193+
decoded_key = base64.b64decode(self.valves.LOG_ANALYTICS_SHARED_KEY.get_decrypted())
84194
encoded_hash = base64.b64encode(
85195
hmac.new(decoded_key, bytes_to_hash, digestmod=hashlib.sha256).digest()
86196
).decode("utf-8")
87197
authorization = (
88198
f"SharedKey {self.valves.LOG_ANALYTICS_WORKSPACE_ID}:{encoded_hash}"
89199
)
90200
return authorization
91-
92-
def _send_to_log_analytics(self, data):
93-
"""Send data to Azure Log Analytics."""
201+
202+
async def _send_to_log_analytics_async(self, data):
203+
"""Send data to Azure Log Analytics asynchronously using aiohttp."""
94204
if (
95205
not self.valves.SEND_TO_LOG_ANALYTICS
96206
or not self.valves.LOG_ANALYTICS_WORKSPACE_ID
97207
or not self.valves.LOG_ANALYTICS_SHARED_KEY
98208
):
99209
return False
210+
211+
log = logging.getLogger("time_token_tracker._send_to_log_analytics_async")
212+
log.setLevel(SRC_LOG_LEVELS["OPENAI"])
100213

101214
method = "POST"
102215
content_type = "application/json"
103216
resource = "/api/logs"
104-
rfc1123date = datetime.datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S GMT")
217+
rfc1123date = datetime.datetime.now(datetime.timezone.utc).strftime("%a, %d %b %Y %H:%M:%S GMT")
105218
content_length = len(json.dumps(data))
106-
219+
107220
signature = self._build_signature(
108221
rfc1123date, content_length, method, content_type, resource
109222
)
@@ -118,18 +231,36 @@ def _send_to_log_analytics(self, data):
118231
"time-generated-field": "timestamp",
119232
}
120233

234+
session = None
235+
response = None
236+
121237
try:
122-
response = requests.post(uri, json=data, headers=headers)
123-
if response.status_code == 200:
238+
session = aiohttp.ClientSession(
239+
trust_env=True,
240+
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT),
241+
)
242+
243+
response = await session.request(
244+
method="POST",
245+
url=uri,
246+
json=data,
247+
headers=headers,
248+
)
249+
250+
if response.status == 200:
124251
return True
125252
else:
126-
print(
127-
f"Error sending to Log Analytics: {response.status_code} - {response.text}"
253+
response_text = await response.text()
254+
log.error(
255+
f"Error sending to Log Analytics: {response.status} - {response_text}"
128256
)
129257
return False
258+
130259
except Exception as e:
131-
print(f"Exception when sending to Log Analytics: {str(e)}")
260+
log.error(f"Exception when sending to Log Analytics asynchronously: {str(e)}")
132261
return False
262+
finally:
263+
await cleanup_response(response, session)
133264

134265
async def inlet(
135266
self, body: dict, __user__: Optional[dict] = None, __event_emitter__=None
@@ -172,16 +303,61 @@ async def inlet(
172303
request_messages = [last_user_system] if last_user_system else []
173304

174305
request_token_count = sum(
175-
len(encoding.encode(m.get("content", "")))
306+
len(encoding.encode(self._get_message_content(m)))
176307
for m in request_messages
177-
if m and isinstance(m.get("content"), str)
308+
if m
178309
)
179310

180311
return body
181312

313+
def _get_message_content(self, message):
314+
"""Extract content from a message, handling different formats."""
315+
content = message.get("content", "")
316+
317+
# Handle None content
318+
if content is None:
319+
content = ""
320+
321+
# Handle string content
322+
if isinstance(content, str):
323+
return content
324+
325+
# Handle list content (e.g., for messages with multiple content parts)
326+
if isinstance(content, list):
327+
text_parts = []
328+
for part in content:
329+
if isinstance(part, dict):
330+
if part.get("type") == "text":
331+
text_parts.append(part.get("text", ""))
332+
else:
333+
# Try to convert other types to string
334+
try:
335+
text_parts.append(str(part))
336+
except:
337+
pass
338+
return " ".join(text_parts)
339+
340+
# Handle function_call in message
341+
if message.get("function_call"):
342+
try:
343+
func_call = message["function_call"]
344+
func_str = f"function: {func_call.get('name', '')}, arguments: {func_call.get('arguments', '')}"
345+
return func_str
346+
except:
347+
return ""
348+
349+
# If nothing else works, try converting to string or return empty
350+
try:
351+
return str(content)
352+
except:
353+
return ""
354+
182355
async def outlet(
183356
self, body: dict, __user__: Optional[dict] = None, __event_emitter__=None
184357
) -> dict:
358+
log = logging.getLogger("time_token_tracker.outlet")
359+
log.setLevel(SRC_LOG_LEVELS["OPENAI"])
360+
185361
global start_time, request_token_count, response_token_count
186362
end_time = time.time()
187363
response_time = end_time - start_time
@@ -194,9 +370,7 @@ async def outlet(
194370
except KeyError:
195371
encoding = tiktoken.get_encoding("cl100k_base")
196372

197-
reversed_messages = list(reversed(all_messages))
198-
199-
# If CALCULATE_ALL_MESSAGES is true, use all "assistant" messages
373+
reversed_messages = list(reversed(all_messages)) # If CALCULATE_ALL_MESSAGES is true, use all "assistant" messages
200374
if self.valves.CALCULATE_ALL_MESSAGES:
201375
assistant_messages = [
202376
m for m in all_messages if m.get("role") == "assistant"
@@ -209,22 +383,18 @@ async def outlet(
209383
assistant_messages = [last_assistant] if last_assistant else []
210384

211385
response_token_count = sum(
212-
len(encoding.encode(m.get("content", "")))
386+
len(encoding.encode(self._get_message_content(m)))
213387
for m in assistant_messages
214-
if m and isinstance(m.get("content"), str)
215-
)
216-
217-
# Calculate tokens per second (only for the last assistant response)
388+
if m
389+
) # Calculate tokens per second (only for the last assistant response)
218390
resp_tokens_per_sec = 0
219391
if self.valves.SHOW_TOKENS_PER_SECOND:
220392
last_assistant_msg = next(
221393
(m for m in reversed_messages if m.get("role") == "assistant"), None
222394
)
223395
last_assistant_tokens = (
224-
len(encoding.encode(last_assistant_msg.get("content", "")))
225-
if last_assistant_msg
226-
and isinstance(last_assistant_msg.get("content"), str)
227-
else 0
396+
len(encoding.encode(self._get_message_content(last_assistant_msg)))
397+
if last_assistant_msg else 0
228398
)
229399
resp_tokens_per_sec = (
230400
0 if response_time == 0 else last_assistant_tokens / response_time
@@ -292,18 +462,17 @@ async def outlet(
292462
# Add averages if calculated
293463
if self.valves.SHOW_AVERAGE_TOKENS and self.valves.CALCULATE_ALL_MESSAGES:
294464
log_data[0]["avgRequestTokens"] = avg_request_tokens
295-
log_data[0]["avgResponseTokens"] = avg_response_tokens
296-
465+
log_data[0]["avgResponseTokens"] = avg_response_tokens
466+
297467
# Send to Log Analytics asynchronously (non-blocking)
298-
# For true async, you might want to use asyncio or threading
299468
try:
300-
import threading
301-
302-
threading.Thread(
303-
target=self._send_to_log_analytics, args=(log_data,)
304-
).start()
305-
except:
306-
# Fallback to synchronous if threading fails
307-
self._send_to_log_analytics(log_data)
308-
469+
result = await self._send_to_log_analytics_async(log_data)
470+
if result:
471+
log.info(f"Log Analytics data sent successfully")
472+
else:
473+
log.warning(f"Failed to send data to Log Analytics")
474+
except Exception as e:
475+
# Handle exceptions during sending to Log Analytics
476+
log.error(f"Error sending to Log Analytics: {e}")
477+
309478
return body

0 commit comments

Comments
 (0)