diff --git a/src/openlayer/lib/__init__.py b/src/openlayer/lib/__init__.py index c46e72c1..6bf3ec9a 100644 --- a/src/openlayer/lib/__init__.py +++ b/src/openlayer/lib/__init__.py @@ -39,6 +39,18 @@ def trace_openai(client): return openai_tracer.trace_openai(client) +def trace_async_openai(client): + """Trace OpenAI chat completions.""" + # pylint: disable=import-outside-toplevel + import openai + + from .integrations import async_openai_tracer + + if not isinstance(client, (openai.AsyncOpenAI, openai.AsyncAzureOpenAI)): + raise ValueError("Invalid client. Please provide an OpenAI client.") + return async_openai_tracer.trace_async_openai(client) + + def trace_openai_assistant_thread_run(client, run): """Trace OpenAI Assistant thread run.""" # pylint: disable=import-outside-toplevel diff --git a/src/openlayer/lib/integrations/async_openai_tracer.py b/src/openlayer/lib/integrations/async_openai_tracer.py new file mode 100644 index 00000000..4e65f45a --- /dev/null +++ b/src/openlayer/lib/integrations/async_openai_tracer.py @@ -0,0 +1,264 @@ +"""Module with methods used to trace async OpenAI / Azure OpenAI LLMs.""" + +import json +import logging +import time +from functools import wraps +from typing import Any, Dict, Iterator, Optional, Union + +import openai + +from .openai_tracer import ( + get_model_parameters, + create_trace_args, + add_to_trace, + parse_non_streaming_output_data, +) + +logger = logging.getLogger(__name__) + + +def trace_async_openai( + client: Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI], +) -> Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]: + """Patch the AsyncOpenAI or AsyncAzureOpenAI client to trace chat completions. + + The following information is collected for each chat completion: + - start_time: The time when the completion was requested. + - end_time: The time when the completion was received. + - latency: The time it took to generate the completion. + - tokens: The total number of tokens used to generate the completion. 
+ - prompt_tokens: The number of tokens in the prompt. + - completion_tokens: The number of tokens in the completion. + - model: The model used to generate the completion. + - model_parameters: The parameters used to configure the model. + - raw_output: The raw output of the model. + - inputs: The inputs used to generate the completion. + - metadata: Additional metadata about the completion. For example, the time it + took to generate the first token, when streaming. + + Parameters + ---------- + client : Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI] + The AsyncOpenAI client to patch. + + Returns + ------- + Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI] + The patched AsyncOpenAI client. + """ + is_azure_openai = isinstance(client, openai.AsyncAzureOpenAI) + create_func = client.chat.completions.create + + @wraps(create_func) + async def traced_create_func(*args, **kwargs): + inference_id = kwargs.pop("inference_id", None) + stream = kwargs.get("stream", False) + + if stream: + return await handle_async_streaming_create( + *args, + **kwargs, + create_func=create_func, + inference_id=inference_id, + is_azure_openai=is_azure_openai, + ) + return await handle_async_non_streaming_create( + *args, + **kwargs, + create_func=create_func, + inference_id=inference_id, + is_azure_openai=is_azure_openai, + ) + + client.chat.completions.create = traced_create_func + return client + + +async def handle_async_streaming_create( + create_func: callable, + *args, + is_azure_openai: bool = False, + inference_id: Optional[str] = None, + **kwargs, +) -> Iterator[Any]: + """Handles the create method when streaming is enabled. + + Parameters + ---------- + create_func : callable + The create method to handle. 
+    is_azure_openai : bool, optional +        Whether the client is an Azure OpenAI client, by default False +    inference_id : Optional[str], optional +        A user-generated inference id, by default None + +    Returns +    ------- +    Iterator[Any] +        A generator that yields the chunks of the completion. +    """ +    chunks = await create_func(*args, **kwargs) +    return stream_async_chunks( +        chunks=chunks, +        kwargs=kwargs, +        inference_id=inference_id, +        is_azure_openai=is_azure_openai, +    ) + + +async def stream_async_chunks( +    chunks: Iterator[Any], +    kwargs: Dict[str, Any], +    is_azure_openai: bool = False, +    inference_id: Optional[str] = None, +): +    """Streams the chunks of the completion and traces the completion.""" +    collected_output_data = [] +    collected_function_call = { +        "name": "", +        "arguments": "", +    } +    raw_outputs = [] +    start_time = time.time() +    end_time = None +    first_token_time = None +    num_of_completion_tokens = None +    latency = None +    try: +        i = 0 +        async for chunk in chunks: +            raw_outputs.append(chunk.model_dump()) +            if i == 0: +                first_token_time = time.time() +            if i > 0: +                num_of_completion_tokens = i + 1 +            i += 1 + +            delta = chunk.choices[0].delta + +            if delta.content: +                collected_output_data.append(delta.content) +            elif delta.function_call: +                if delta.function_call.name: +                    collected_function_call["name"] += delta.function_call.name +                if delta.function_call.arguments: +                    collected_function_call["arguments"] += ( +                        delta.function_call.arguments +                    ) +            elif delta.tool_calls: +                if delta.tool_calls[0].function.name: +                    collected_function_call["name"] += delta.tool_calls[0].function.name +                if delta.tool_calls[0].function.arguments: +                    collected_function_call["arguments"] += delta.tool_calls[ +                        0 +                    ].function.arguments + +            yield chunk +        end_time = time.time() +        latency = (end_time - start_time) * 1000 +    # pylint: disable=broad-except +    except Exception as e: +        logger.error("Failed yield chunk. 
%s", e) + finally: + # Try to add step to the trace + try: + collected_output_data = [ + message for message in collected_output_data if message is not None + ] + if collected_output_data: + output_data = "".join(collected_output_data) + else: + collected_function_call["arguments"] = json.loads( + collected_function_call["arguments"] + ) + output_data = collected_function_call + + trace_args = create_trace_args( + end_time=end_time, + inputs={"prompt": kwargs["messages"]}, + output=output_data, + latency=latency, + tokens=num_of_completion_tokens, + prompt_tokens=0, + completion_tokens=num_of_completion_tokens, + model=kwargs.get("model"), + model_parameters=get_model_parameters(kwargs), + raw_output=raw_outputs, + id=inference_id, + metadata={ + "timeToFirstToken": ( + (first_token_time - start_time) * 1000 + if first_token_time + else None + ) + }, + ) + add_to_trace( + **trace_args, + is_azure_openai=is_azure_openai, + ) + + # pylint: disable=broad-except + except Exception as e: + logger.error( + "Failed to trace the create chat completion request with Openlayer. %s", + e, + ) + + +async def handle_async_non_streaming_create( + create_func: callable, + *args, + is_azure_openai: bool = False, + inference_id: Optional[str] = None, + **kwargs, +) -> "openai.types.chat.chat_completion.ChatCompletion": + """Handles the create method when streaming is disabled. + + Parameters + ---------- + create_func : callable + The create method to handle. + is_azure_openai : bool, optional + Whether the client is an Azure OpenAI client, by default False + inference_id : Optional[str], optional + A user-generated inference id, by default None + + Returns + ------- + openai.types.chat.chat_completion.ChatCompletion + The chat completion response. 
+ """ + start_time = time.time() + response = await create_func(*args, **kwargs) + end_time = time.time() + + # Try to add step to the trace + try: + output_data = parse_non_streaming_output_data(response) + trace_args = create_trace_args( + end_time=end_time, + inputs={"prompt": kwargs["messages"]}, + output=output_data, + latency=(end_time - start_time) * 1000, + tokens=response.usage.total_tokens, + prompt_tokens=response.usage.prompt_tokens, + completion_tokens=response.usage.completion_tokens, + model=response.model, + model_parameters=get_model_parameters(kwargs), + raw_output=response.model_dump(), + id=inference_id, + ) + + add_to_trace( + is_azure_openai=is_azure_openai, + **trace_args, + ) + # pylint: disable=broad-except + except Exception as e: + logger.error( + "Failed to trace the create chat completion request with Openlayer. %s", e + ) + + return response diff --git a/src/openlayer/lib/integrations/openai_tracer.py b/src/openlayer/lib/integrations/openai_tracer.py index 064c35a9..e3faab0d 100644 --- a/src/openlayer/lib/integrations/openai_tracer.py +++ b/src/openlayer/lib/integrations/openai_tracer.py @@ -137,12 +137,16 @@ def stream_chunks( if delta.function_call.name: collected_function_call["name"] += delta.function_call.name if delta.function_call.arguments: - collected_function_call["arguments"] += delta.function_call.arguments + collected_function_call["arguments"] += ( + delta.function_call.arguments + ) elif delta.tool_calls: if delta.tool_calls[0].function.name: collected_function_call["name"] += delta.tool_calls[0].function.name if delta.tool_calls[0].function.arguments: - collected_function_call["arguments"] += delta.tool_calls[0].function.arguments + collected_function_call["arguments"] += delta.tool_calls[ + 0 + ].function.arguments yield chunk end_time = time.time() @@ -153,11 +157,15 @@ def stream_chunks( finally: # Try to add step to the trace try: - collected_output_data = [message for message in collected_output_data if message is 
not None] + collected_output_data = [ + message for message in collected_output_data if message is not None + ] if collected_output_data: output_data = "".join(collected_output_data) else: - collected_function_call["arguments"] = json.loads(collected_function_call["arguments"]) + collected_function_call["arguments"] = json.loads( + collected_function_call["arguments"] + ) output_data = collected_function_call trace_args = create_trace_args( @@ -172,7 +180,13 @@ def stream_chunks( model_parameters=get_model_parameters(kwargs), raw_output=raw_outputs, id=inference_id, - metadata={"timeToFirstToken": ((first_token_time - start_time) * 1000 if first_token_time else None)}, + metadata={ + "timeToFirstToken": ( + (first_token_time - start_time) * 1000 + if first_token_time + else None + ) + }, ) add_to_trace( **trace_args, @@ -240,8 +254,12 @@ def create_trace_args( def add_to_trace(is_azure_openai: bool = False, **kwargs) -> None: """Add a chat completion step to the trace.""" if is_azure_openai: - tracer.add_chat_completion_step_to_trace(**kwargs, name="Azure OpenAI Chat Completion", provider="Azure") - tracer.add_chat_completion_step_to_trace(**kwargs, name="OpenAI Chat Completion", provider="OpenAI") + tracer.add_chat_completion_step_to_trace( + **kwargs, name="Azure OpenAI Chat Completion", provider="Azure" + ) + tracer.add_chat_completion_step_to_trace( + **kwargs, name="OpenAI Chat Completion", provider="OpenAI" + ) def handle_non_streaming_create( @@ -294,7 +312,9 @@ def handle_non_streaming_create( ) # pylint: disable=broad-except except Exception as e: - logger.error("Failed to trace the create chat completion request with Openlayer. %s", e) + logger.error( + "Failed to trace the create chat completion request with Openlayer. 
%s", e + ) return response @@ -336,7 +356,9 @@ def parse_non_streaming_output_data( # --------------------------- OpenAI Assistants API -------------------------- # -def trace_openai_assistant_thread_run(client: openai.OpenAI, run: "openai.types.beta.threads.run.Run") -> None: +def trace_openai_assistant_thread_run( + client: openai.OpenAI, run: "openai.types.beta.threads.run.Run" +) -> None: """Trace a run from an OpenAI assistant. Once the run is completed, the thread data is published to Openlayer, @@ -353,7 +375,9 @@ def trace_openai_assistant_thread_run(client: openai.OpenAI, run: "openai.types. metadata = _extract_run_metadata(run) # Convert thread to prompt - messages = client.beta.threads.messages.list(thread_id=run.thread_id, order="asc") + messages = client.beta.threads.messages.list( + thread_id=run.thread_id, order="asc" + ) prompt = _thread_messages_to_prompt(messages) # Add step to the trace