2222import base64
2323import hashlib
2424import datetime
25- import requests
2625import os
27- from typing import Optional
26+ import logging
27+ import aiohttp
28+ from typing import Optional , Any
29+ from open_webui .env import AIOHTTP_CLIENT_TIMEOUT , SRC_LOG_LEVELS
30+ from cryptography .fernet import Fernet , InvalidToken
2831import tiktoken
29- from pydantic import BaseModel , Field
32+ from pydantic import BaseModel , Field , GetCoreSchemaHandler
33+ from pydantic_core import core_schema
3034
3135# Global variables to track start time and token counts
3236global start_time , request_token_count , response_token_count
37+ start_time = 0
38+ request_token_count = 0
39+ response_token_count = 0
40+
41+ # Simplified encryption implementation with automatic handling
42+ class EncryptedStr (str ):
43+ """A string type that automatically handles encryption/decryption"""
44+
45+ @classmethod
46+ def _get_encryption_key (cls ) -> Optional [bytes ]:
47+ """
48+ Generate encryption key from WEBUI_SECRET_KEY if available
49+ Returns None if no key is configured
50+ """
51+ secret = os .getenv ("WEBUI_SECRET_KEY" )
52+ if not secret :
53+ return None
54+
55+ hashed_key = hashlib .sha256 (secret .encode ()).digest ()
56+ return base64 .urlsafe_b64encode (hashed_key )
57+
58+ @classmethod
59+ def encrypt (cls , value : str ) -> str :
60+ """
61+ Encrypt a string value if a key is available
62+ Returns the original value if no key is available
63+ """
64+ if not value or value .startswith ("encrypted:" ):
65+ return value
66+
67+ key = cls ._get_encryption_key ()
68+ if not key : # No encryption if no key
69+ return value
70+
71+ f = Fernet (key )
72+ encrypted = f .encrypt (value .encode ())
73+ return f"encrypted:{ encrypted .decode ()} "
74+
75+ @classmethod
76+ def decrypt (cls , value : str ) -> str :
77+ """
78+ Decrypt an encrypted string value if a key is available
79+ Returns the original value if no key is available or decryption fails
80+ """
81+ if not value or not value .startswith ("encrypted:" ):
82+ return value
83+
84+ key = cls ._get_encryption_key ()
85+ if not key : # No decryption if no key
86+ return value [len ("encrypted:" ) :] # Return without prefix
3387
88+ try :
89+ encrypted_part = value [len ("encrypted:" ) :]
90+ f = Fernet (key )
91+ decrypted = f .decrypt (encrypted_part .encode ())
92+ return decrypted .decode ()
93+ except (InvalidToken , Exception ):
94+ return value
95+
96+ # Pydantic integration
97+ @classmethod
98+ def __get_pydantic_core_schema__ (
99+ cls , _source_type : Any , _handler : GetCoreSchemaHandler
100+ ) -> core_schema .CoreSchema :
101+ return core_schema .union_schema (
102+ [
103+ core_schema .is_instance_schema (cls ),
104+ core_schema .chain_schema (
105+ [
106+ core_schema .str_schema (),
107+ core_schema .no_info_plain_validator_function (
108+ lambda value : cls (cls .encrypt (value ) if value else value )
109+ ),
110+ ]
111+ ),
112+ ],
113+ serialization = core_schema .plain_serializer_function_ser_schema (
114+ lambda instance : str (instance )
115+ ),
116+ )
117+
118+ def get_decrypted (self ) -> str :
119+ """Get the decrypted value"""
120+ return self .decrypt (self )
121+
122+ # Helper functions
123+ async def cleanup_response (
124+ response : Optional [aiohttp .ClientResponse ],
125+ session : Optional [aiohttp .ClientSession ],
126+ ) -> None :
127+ """
128+ Clean up the response and session objects.
129+
130+ Args:
131+ response: The ClientResponse object to close
132+ session: The ClientSession object to close
133+ """
134+ if response :
135+ response .close ()
136+ if session :
137+ await session .close ()
34138
35139class Filter :
36140 class Valves (BaseModel ):
@@ -54,9 +158,15 @@ class Valves(BaseModel):
54158 SHOW_TOKENS_PER_SECOND : bool = Field (
55159 default = True , description = "Show tokens per second for the response."
56160 )
57- SEND_TO_LOG_ANALYTICS : bool = os .getenv ("SEND_TO_LOG_ANALYTICS" , "False" )
58- LOG_ANALYTICS_WORKSPACE_ID : str = os .getenv ("LOG_ANALYTICS_WORKSPACE_ID" , "" )
59- LOG_ANALYTICS_SHARED_KEY : str = os .getenv ("LOG_ANALYTICS_SHARED_KEY" , "" )
161+ SEND_TO_LOG_ANALYTICS : bool = Field (
162+ default = bool (os .getenv ("SEND_TO_LOG_ANALYTICS" , False )), description = "Send logs to Azure Log Analytics workspace"
163+ )
164+ LOG_ANALYTICS_WORKSPACE_ID : str = Field (
165+ default = os .getenv ("LOG_ANALYTICS_WORKSPACE_ID" , "" ), description = "Azure Log Analytics Workspace ID"
166+ )
167+ LOG_ANALYTICS_SHARED_KEY : EncryptedStr = Field (
168+ default = os .getenv ("LOG_ANALYTICS_SHARED_KEY" , "" ), description = "Azure Log Analytics Workspace Shared Key"
169+ )
60170 LOG_ANALYTICS_LOG_TYPE : str = Field (
61171 default = "OpenWebuiMetrics" , description = "Log Analytics log type name."
62172 )
@@ -80,30 +190,33 @@ def _build_signature(self, date, content_length, method, content_type, resource)
80190 + resource
81191 )
82192 bytes_to_hash = string_to_hash .encode ("utf-8" )
83- decoded_key = base64 .b64decode (self .valves .LOG_ANALYTICS_SHARED_KEY )
193+ decoded_key = base64 .b64decode (self .valves .LOG_ANALYTICS_SHARED_KEY . get_decrypted ())
84194 encoded_hash = base64 .b64encode (
85195 hmac .new (decoded_key , bytes_to_hash , digestmod = hashlib .sha256 ).digest ()
86196 ).decode ("utf-8" )
87197 authorization = (
88198 f"SharedKey { self .valves .LOG_ANALYTICS_WORKSPACE_ID } :{ encoded_hash } "
89199 )
90200 return authorization
91-
92- def _send_to_log_analytics (self , data ):
93- """Send data to Azure Log Analytics."""
201+
202+ async def _send_to_log_analytics_async (self , data ):
203+ """Send data to Azure Log Analytics asynchronously using aiohttp ."""
94204 if (
95205 not self .valves .SEND_TO_LOG_ANALYTICS
96206 or not self .valves .LOG_ANALYTICS_WORKSPACE_ID
97207 or not self .valves .LOG_ANALYTICS_SHARED_KEY
98208 ):
99209 return False
210+
211+ log = logging .getLogger ("time_token_tracker._send_to_log_analytics_async" )
212+ log .setLevel (SRC_LOG_LEVELS ["OPENAI" ])
100213
101214 method = "POST"
102215 content_type = "application/json"
103216 resource = "/api/logs"
104- rfc1123date = datetime .datetime .utcnow ( ).strftime ("%a, %d %b %Y %H:%M:%S GMT" )
217+ rfc1123date = datetime .datetime .now ( datetime . timezone . utc ).strftime ("%a, %d %b %Y %H:%M:%S GMT" )
105218 content_length = len (json .dumps (data ))
106-
219+
107220 signature = self ._build_signature (
108221 rfc1123date , content_length , method , content_type , resource
109222 )
@@ -118,18 +231,36 @@ def _send_to_log_analytics(self, data):
118231 "time-generated-field" : "timestamp" ,
119232 }
120233
234+ session = None
235+ response = None
236+
121237 try :
122- response = requests .post (uri , json = data , headers = headers )
123- if response .status_code == 200 :
238+ session = aiohttp .ClientSession (
239+ trust_env = True ,
240+ timeout = aiohttp .ClientTimeout (total = AIOHTTP_CLIENT_TIMEOUT ),
241+ )
242+
243+ response = await session .request (
244+ method = "POST" ,
245+ url = uri ,
246+ json = data ,
247+ headers = headers ,
248+ )
249+
250+ if response .status == 200 :
124251 return True
125252 else :
126- print (
127- f"Error sending to Log Analytics: { response .status_code } - { response .text } "
253+ response_text = await response .text ()
254+ log .error (
255+ f"Error sending to Log Analytics: { response .status } - { response_text } "
128256 )
129257 return False
258+
130259 except Exception as e :
131- print (f"Exception when sending to Log Analytics: { str (e )} " )
260+ log . error (f"Exception when sending to Log Analytics asynchronously : { str (e )} " )
132261 return False
262+ finally :
263+ await cleanup_response (response , session )
133264
134265 async def inlet (
135266 self , body : dict , __user__ : Optional [dict ] = None , __event_emitter__ = None
@@ -172,16 +303,61 @@ async def inlet(
172303 request_messages = [last_user_system ] if last_user_system else []
173304
174305 request_token_count = sum (
175- len (encoding .encode (m . get ( "content" , "" )))
306+ len (encoding .encode (self . _get_message_content ( m )))
176307 for m in request_messages
177- if m and isinstance ( m . get ( "content" ), str )
308+ if m
178309 )
179310
180311 return body
181312
313+ def _get_message_content (self , message ):
314+ """Extract content from a message, handling different formats."""
315+ content = message .get ("content" , "" )
316+
317+ # Handle None content
318+ if content is None :
319+ content = ""
320+
321+ # Handle string content
322+ if isinstance (content , str ):
323+ return content
324+
325+ # Handle list content (e.g., for messages with multiple content parts)
326+ if isinstance (content , list ):
327+ text_parts = []
328+ for part in content :
329+ if isinstance (part , dict ):
330+ if part .get ("type" ) == "text" :
331+ text_parts .append (part .get ("text" , "" ))
332+ else :
333+ # Try to convert other types to string
334+ try :
335+ text_parts .append (str (part ))
336+ except :
337+ pass
338+ return " " .join (text_parts )
339+
340+ # Handle function_call in message
341+ if message .get ("function_call" ):
342+ try :
343+ func_call = message ["function_call" ]
344+ func_str = f"function: { func_call .get ('name' , '' )} , arguments: { func_call .get ('arguments' , '' )} "
345+ return func_str
346+ except :
347+ return ""
348+
349+ # If nothing else works, try converting to string or return empty
350+ try :
351+ return str (content )
352+ except :
353+ return ""
354+
182355 async def outlet (
183356 self , body : dict , __user__ : Optional [dict ] = None , __event_emitter__ = None
184357 ) -> dict :
358+ log = logging .getLogger ("time_token_tracker.outlet" )
359+ log .setLevel (SRC_LOG_LEVELS ["OPENAI" ])
360+
185361 global start_time , request_token_count , response_token_count
186362 end_time = time .time ()
187363 response_time = end_time - start_time
@@ -194,9 +370,7 @@ async def outlet(
194370 except KeyError :
195371 encoding = tiktoken .get_encoding ("cl100k_base" )
196372
197- reversed_messages = list (reversed (all_messages ))
198-
199- # If CALCULATE_ALL_MESSAGES is true, use all "assistant" messages
373+ reversed_messages = list (reversed (all_messages )) # If CALCULATE_ALL_MESSAGES is true, use all "assistant" messages
200374 if self .valves .CALCULATE_ALL_MESSAGES :
201375 assistant_messages = [
202376 m for m in all_messages if m .get ("role" ) == "assistant"
@@ -209,22 +383,18 @@ async def outlet(
209383 assistant_messages = [last_assistant ] if last_assistant else []
210384
211385 response_token_count = sum (
212- len (encoding .encode (m . get ( "content" , "" )))
386+ len (encoding .encode (self . _get_message_content ( m )))
213387 for m in assistant_messages
214- if m and isinstance (m .get ("content" ), str )
215- )
216-
217- # Calculate tokens per second (only for the last assistant response)
388+ if m
389+ ) # Calculate tokens per second (only for the last assistant response)
218390 resp_tokens_per_sec = 0
219391 if self .valves .SHOW_TOKENS_PER_SECOND :
220392 last_assistant_msg = next (
221393 (m for m in reversed_messages if m .get ("role" ) == "assistant" ), None
222394 )
223395 last_assistant_tokens = (
224- len (encoding .encode (last_assistant_msg .get ("content" , "" )))
225- if last_assistant_msg
226- and isinstance (last_assistant_msg .get ("content" ), str )
227- else 0
396+ len (encoding .encode (self ._get_message_content (last_assistant_msg )))
397+ if last_assistant_msg else 0
228398 )
229399 resp_tokens_per_sec = (
230400 0 if response_time == 0 else last_assistant_tokens / response_time
@@ -292,18 +462,17 @@ async def outlet(
292462 # Add averages if calculated
293463 if self .valves .SHOW_AVERAGE_TOKENS and self .valves .CALCULATE_ALL_MESSAGES :
294464 log_data [0 ]["avgRequestTokens" ] = avg_request_tokens
295- log_data [0 ]["avgResponseTokens" ] = avg_response_tokens
296-
465+ log_data [0 ]["avgResponseTokens" ] = avg_response_tokens
466+
297467 # Send to Log Analytics asynchronously (non-blocking)
298- # For true async, you might want to use asyncio or threading
299468 try :
300- import threading
301-
302- threading . Thread (
303- target = self . _send_to_log_analytics , args = ( log_data ,)
304- ). start ( )
305- except :
306- # Fallback to synchronous if threading fails
307- self . _send_to_log_analytics ( log_data )
308-
469+ result = await self . _send_to_log_analytics_async ( log_data )
470+ if result :
471+ log . info ( f"Log Analytics data sent successfully" )
472+ else :
473+ log . warning ( f"Failed to send data to Log Analytics" )
474+ except Exception as e :
475+ # Handle exceptions during sending to Log Analytics
476+ log . error ( f"Error sending to Log Analytics: { e } " )
477+
309478 return body
0 commit comments