Skip to content

Commit 1c8622c

Browse files
authored
feature(openai): support async functions for openai versions >=1.0.0 (#181)
1 parent 2bacb80 commit 1c8622c

2 files changed

Lines changed: 175 additions & 35 deletions

File tree

langfuse/openai.py

Lines changed: 101 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import types
44
from typing import Optional
55

6+
from packaging.version import Version
7+
68

79
from langfuse import Langfuse
810
from langfuse.client import InitialGeneration, CreateTrace, StatefulGenerationClient
911

10-
from distutils.version import StrictVersion
1112
import openai
1213
from wrapt import wrap_function_wrapper
1314

@@ -19,12 +20,14 @@ class OpenAiDefinition:
1920
object: str
2021
method: str
2122
type: str
23+
sync: bool
2224

23-
def __init__(self, module: str, object: str, method: str, type: str):
25+
def __init__(self, module: str, object: str, method: str, type: str, sync: bool):
2426
self.module = module
2527
self.object = object
2628
self.method = method
2729
self.type = type
30+
self.sync = sync
2831

2932

3033
OPENAI_METHODS_V0 = [
@@ -33,28 +36,34 @@ def __init__(self, module: str, object: str, method: str, type: str):
3336
object="ChatCompletion",
3437
method="create",
3538
type="chat",
39+
sync=True,
3640
),
3741
OpenAiDefinition(
3842
module="openai",
3943
object="Completion",
4044
method="create",
4145
type="completion",
46+
sync=True,
4247
),
4348
]
4449

4550

4651
# Methods patched when the installed openai SDK is >= 1.0.0.
# Each resource is registered twice: once for the blocking client class
# (sync=True) and once for its asyncio counterpart (sync=False), so that
# register_tracing can choose between the sync and the async wrapper.
OPENAI_METHODS_V1 = [
    OpenAiDefinition(
        module="openai.resources.chat.completions",
        object="Completions",
        method="create",
        type="chat",
        sync=True,
    ),
    OpenAiDefinition(
        module="openai.resources.completions",
        object="Completions",
        method="create",
        type="completion",
        sync=True,
    ),
    OpenAiDefinition(
        module="openai.resources.chat.completions",
        object="AsyncCompletions",
        method="create",
        type="chat",
        sync=False,
    ),
    OpenAiDefinition(
        module="openai.resources.completions",
        object="AsyncCompletions",
        method="create",
        type="completion",
        sync=False,
    ),
]
6069

@@ -75,9 +84,9 @@ def get_openai_args(self):
7584

7685

7786
def _langfuse_wrapper(func):
78-
def _with_langfuse(open_ai_definitions, langfuse, initialize):
87+
def _with_langfuse(open_ai_definitions, initialize):
7988
def wrapper(wrapped, instance, args, kwargs):
80-
return func(open_ai_definitions, langfuse, initialize, wrapped, instance, args, kwargs)
89+
return func(open_ai_definitions, initialize, wrapped, args, kwargs)
8190

8291
return wrapper
8392

@@ -130,12 +139,41 @@ def _get_langfuse_data_from_kwargs(resource: OpenAiDefinition, langfuse: Langfus
130139
return InitialGeneration(name=name, metadata=metadata, trace_id=trace_id, start_time=start_time, prompt=prompt, modelParameters=modelParameters, model=model)
131140

132141

133-
def _get_lagnfuse_data_from_streaming_response(resource: OpenAiDefinition, response, generation: StatefulGenerationClient, langfuse: Langfuse):
134-
final_response = [] if resource.type == "chat" else ""
142+
def _get_lagnfuse_data_from_sync_streaming_response(resource: OpenAiDefinition, response, generation: StatefulGenerationClient, langfuse: Langfuse):
    """Relay a synchronous OpenAI stream to the caller and record it in Langfuse.

    This is a generator: every chunk read from ``response`` is yielded
    unchanged and also buffered. Once ``response`` is exhausted, the buffered
    chunks are reduced with ``_extract_data`` and the result is written to
    ``generation`` via ``_create_langfuse_update``. Nothing is reported if the
    consumer stops iterating before the stream ends.

    Args:
        resource: Definition of the wrapped OpenAI method ("chat" or "completion").
        response: Iterable of streamed chunks returned by the OpenAI call.
        generation: Langfuse generation to update after the stream finishes.
        langfuse: Not used by this function; kept for signature compatibility.

    Yields:
        Each chunk of ``response``, in order.
    """
    responses = []
    first_chunk_time = None
    for chunk in response:
        if first_chunk_time is None:
            # Stamp the arrival of the first chunk here. _extract_data only
            # sees the buffered list after the stream has ended, so the time
            # it reports is effectively the end of the stream.
            first_chunk_time = datetime.now()
        responses.append(chunk)
        yield chunk

    model, _, completion = _extract_data(resource, responses)

    _create_langfuse_update(completion, generation, first_chunk_time, model=model)
151+
152+
153+
async def _get_lagnfuse_data_from_async_streaming_response(resource: OpenAiDefinition, response, generation: StatefulGenerationClient, langfuse: Langfuse):
    """Relay an asynchronous OpenAI stream to the caller and record it in Langfuse.

    This is an async generator: every chunk read from ``response`` is yielded
    unchanged and also buffered. Once ``response`` is exhausted, the buffered
    chunks are reduced with ``_extract_data`` and the result is written to
    ``generation`` via ``_create_langfuse_update``. Nothing is reported if the
    consumer stops iterating before the stream ends.

    Args:
        resource: Definition of the wrapped OpenAI method ("chat" or "completion").
        response: Async iterable of streamed chunks returned by the OpenAI call.
        generation: Langfuse generation to update after the stream finishes.
        langfuse: Not used by this function; kept for signature compatibility.

    Yields:
        Each chunk of ``response``, in order.
    """
    responses = []
    first_chunk_time = None
    async for chunk in response:
        if first_chunk_time is None:
            # Stamp the arrival of the first chunk here. _extract_data only
            # sees the buffered list after the stream has ended, so the time
            # it reports is effectively the end of the stream.
            first_chunk_time = datetime.now()
        responses.append(chunk)
        yield chunk

    model, _, completion = _extract_data(resource, responses)

    _create_langfuse_update(completion, generation, first_chunk_time, model=model)
162+
163+
164+
def _create_langfuse_update(completion, generation: StatefulGenerationClient, completion_start_time, model=None):
    """Send the final state of a finished call to its Langfuse generation.

    Builds an ``UpdateGeneration`` stamped with the current time as
    ``end_time`` and passes it to ``generation.update``. The ``model`` field
    is only set on the payload when a model name is supplied.

    Args:
        completion: Completion payload to store on the generation.
        generation: Langfuse generation client that receives the update.
        completion_start_time: Timestamp stored as the completion start time.
        model: Optional model name; left off the payload when ``None``.
    """
    payload = UpdateGeneration(
        end_time=datetime.now(),
        completion=completion,
        completion_start_time=completion_start_time,
    )
    generation.update(payload if model is None else payload.copy(update={"model": model}))
169+
170+
171+
def _extract_data(resource, responses):
172+
completion = [] if resource.type == "chat" else ""
135173
model = None
136174
completion_start_time = None
137-
for index, i in enumerate(response):
138-
print(index)
175+
176+
for index, i in enumerate(responses):
139177
if index == 0:
140178
completion_start_time = datetime.now()
141179

@@ -156,36 +194,31 @@ def _get_lagnfuse_data_from_streaming_response(resource: OpenAiDefinition, respo
156194
delta = delta.__dict__
157195

158196
if delta.get("role", None) is not None:
159-
final_response.append({"role": delta.get("role", None), "function_call": None, "tool_calls": None, "content": None})
197+
completion.append({"role": delta.get("role", None), "function_call": None, "tool_calls": None, "content": None})
160198

161199
elif delta.get("content", None) is not None:
162-
final_response[-1]["content"] = delta.get("content", None) if final_response[-1]["content"] is None else final_response[-1]["content"] + delta.get("content", None)
200+
completion[-1]["content"] = delta.get("content", None) if completion[-1]["content"] is None else completion[-1]["content"] + delta.get("content", None)
163201

164202
elif delta.get("function_call", None) is not None:
165-
final_response[-1]["function_call"] = (
166-
delta.get("function_call", None) if final_response[-1]["function_call"] is None else final_response[-1]["function_call"] + delta.get("function_call", None)
203+
completion[-1]["function_call"] = (
204+
delta.get("function_call", None) if completion[-1]["function_call"] is None else completion[-1]["function_call"] + delta.get("function_call", None)
167205
)
168206
elif delta.get("tools_call", None) is not None:
169-
final_response[-1]["tool_calls"] = delta.get("tools_call", None) if final_response[-1]["tool_calls"] is None else final_response[-1]["tool_calls"] + delta.get("tools_call", None)
207+
completion[-1]["tool_calls"] = delta.get("tools_call", None) if completion[-1]["tool_calls"] is None else completion[-1]["tool_calls"] + delta.get("tools_call", None)
170208
if resource.type == "completion":
171-
final_response += choice.get("text", None)
172-
173-
yield i
209+
completion += choice.get("text", None)
174210

175211
def get_response_for_chat():
176-
if len(final_response) > 0:
177-
if final_response[-1].get("content", None) is not None:
178-
return final_response[-1]["content"]
179-
elif final_response[-1].get("function_call", None) is not None:
180-
return final_response[-1]["function_call"]
181-
elif final_response[-1].get("tool_calls", None) is not None:
182-
return final_response[-1]["tool_calls"]
212+
if len(completion) > 0:
213+
if completion[-1].get("content", None) is not None:
214+
return completion[-1]["content"]
215+
elif completion[-1].get("function_call", None) is not None:
216+
return completion[-1]["function_call"]
217+
elif completion[-1].get("tool_calls", None) is not None:
218+
return completion[-1]["tool_calls"]
183219
return None
184220

185-
update = UpdateGeneration(end_time=datetime.now(), completion=get_response_for_chat() if resource.type == "chat" else final_response, completion_start_time=completion_start_time)
186-
if model is not None:
187-
update = update.copy(update={"model": model})
188-
generation.update(update)
221+
return model, completion_start_time, get_response_for_chat() if resource.type == "chat" else completion
189222

190223

191224
def _get_langfuse_data_from_default_response(resource: OpenAiDefinition, response):
@@ -210,15 +243,15 @@ def _get_langfuse_data_from_default_response(resource: OpenAiDefinition, respons
210243

211244

212245
def _is_openai_v1():
    """Return True when the installed openai SDK is version 1.0.0 or newer.

    The comparison uses ``packaging.version.Version``; the earlier
    ``distutils.version.StrictVersion`` comparison is dropped because its
    import is no longer present in this module.
    """
    return Version(openai.__version__) >= Version("1.0.0")
214247

215248

216249
def _is_streaming_response(response):
    """Return True when ``response`` is a streamed OpenAI result.

    A plain Python generator always counts as a stream. On openai >= 1.0.0 an
    ``openai.Stream`` or ``openai.AsyncStream`` instance counts as well.

    Args:
        response: The object returned by the wrapped OpenAI call.

    Returns:
        bool: Whether the response must be consumed chunk by chunk.
    """
    if isinstance(response, types.GeneratorType):
        return True
    # The version check runs first so the stream classes are only looked up on
    # the v1 SDK (presumably they do not exist on older releases - confirm).
    return _is_openai_v1() and isinstance(response, (openai.Stream, openai.AsyncStream))
218251

219252

220253
@_langfuse_wrapper
221-
def _wrap(open_ai_resource: OpenAiDefinition, langfuse: Langfuse, initialize, wrapped, instance, args, kwargs):
254+
def _wrap(open_ai_resource: OpenAiDefinition, initialize, wrapped, args, kwargs):
222255
new_langfuse = initialize()
223256

224257
start_time = datetime.now()
@@ -230,7 +263,31 @@ def _wrap(open_ai_resource: OpenAiDefinition, langfuse: Langfuse, initialize, wr
230263
openai_response = wrapped(**arg_extractor.get_openai_args())
231264

232265
if _is_streaming_response(openai_response):
233-
return _get_lagnfuse_data_from_streaming_response(open_ai_resource, openai_response, generation, new_langfuse)
266+
return _get_lagnfuse_data_from_sync_streaming_response(open_ai_resource, openai_response, generation, new_langfuse)
267+
268+
else:
269+
model, completion, usage = _get_langfuse_data_from_default_response(open_ai_resource, openai_response.__dict__ if _is_openai_v1() else openai_response)
270+
generation.update(UpdateGeneration(model=model, completion=completion, end_time=datetime.now(), usage=usage))
271+
return openai_response
272+
except Exception as ex:
273+
model = kwargs.get("model", None)
274+
generation.update(UpdateGeneration(endTime=datetime.now(), statusMessage=str(ex), level="ERROR", model=model))
275+
raise ex
276+
277+
278+
@_langfuse_wrapper
279+
async def _wrap_async(open_ai_resource: OpenAiDefinition, initialize, wrapped, args, kwargs):
280+
new_langfuse = initialize()
281+
start_time = datetime.now()
282+
arg_extractor = OpenAiArgsExtractor(*args, **kwargs)
283+
284+
generation = _get_langfuse_data_from_kwargs(open_ai_resource, new_langfuse, start_time, arg_extractor.get_langfuse_args())
285+
generation = new_langfuse.generation(generation)
286+
try:
287+
openai_response = await wrapped(**arg_extractor.get_openai_args())
288+
289+
if _is_streaming_response(openai_response):
290+
return _get_lagnfuse_data_from_async_streaming_response(open_ai_resource, openai_response, generation, new_langfuse)
234291

235292
else:
236293
model, completion, usage = _get_langfuse_data_from_default_response(open_ai_resource, openai_response.__dict__ if _is_openai_v1() else openai_response)
@@ -271,15 +328,24 @@ def register_tracing(self):
271328
wrap_function_wrapper(
272329
resource.module,
273330
f"{resource.object}.{resource.method}",
274-
_wrap(resource, self._langfuse, self.initialize),
331+
_wrap(resource, self.initialize) if resource.sync else _wrap_async(resource, self.initialize),
275332
)
276333

277334
setattr(openai, "langfuse_public_key", None)
278335
setattr(openai, "langfuse_secret_key", None)
279336
setattr(openai, "langfuse_host", None)
280-
281337
setattr(openai, "flush_langfuse", self.flush)
282338

339+
setattr(openai.AsyncOpenAI, "langfuse_public_key", None)
340+
setattr(openai.AsyncOpenAI, "langfuse_secret_key", None)
341+
setattr(openai.AsyncOpenAI, "langfuse_host", None)
342+
setattr(openai.AsyncOpenAI, "flush_langfuse", self.flush)
343+
344+
setattr(openai.OpenAI, "langfuse_public_key", None)
345+
setattr(openai.OpenAI, "langfuse_secret_key", None)
346+
setattr(openai.OpenAI, "langfuse_host", None)
347+
setattr(openai.OpenAI, "flush_langfuse", self.flush)
348+
283349

284350
modifier = OpenAILangfuse()
285351
modifier.register_tracing()

tests/test_openai.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from langfuse.openai import _is_openai_v1, _is_streaming_response, openai
66

77
from tests.utils import create_uuid, get_api
8+
from openai import AsyncOpenAI
89

910

1011
chat_func = openai.chat.completions.create if _is_openai_v1() else openai.ChatCompletion.create
@@ -460,3 +461,76 @@ def test_fails_wrong_trace_id():
460461
prompt="1 + 1 = ",
461462
temperature=0,
462463
)
464+
465+
466+
@pytest.mark.asyncio
async def test_async_chat():
    """A non-streaming AsyncOpenAI chat completion is traced as a Langfuse generation."""
    api = get_api()
    client = AsyncOpenAI()
    generation_name = create_uuid()

    completion = await client.chat.completions.create(messages=[{"role": "user", "content": "1 + 1 = "}], model="gpt-3.5-turbo", name=generation_name)

    # Flush so the generation can be read back through the API below.
    client.flush_langfuse()

    generation = api.observations.get_many(name=generation_name, type="GENERATION")

    assert len(generation.data) != 0
    assert generation.data[0].name == generation_name
    assert len(completion.choices) != 0
    assert completion.choices[0].message.content == generation.data[0].output
    assert generation.data[0].input == [{"content": "1 + 1 = ", "role": "user"}]
    assert generation.data[0].type == "GENERATION"
    # "gpt-3.5-turbo" is an alias; match the family rather than pinning one
    # dated snapshot such as "gpt-3.5-turbo-0613".
    assert generation.data[0].model.startswith("gpt-3.5-turbo")
    assert generation.data[0].start_time is not None
    assert generation.data[0].end_time is not None
    assert generation.data[0].start_time < generation.data[0].end_time
    assert generation.data[0].model_parameters == {
        "temperature": 1,
        "top_p": 1,
        "frequency_penalty": 0,
        "maxTokens": "inf",
        "presence_penalty": 0,
    }
    assert generation.data[0].prompt_tokens is not None
    assert generation.data[0].completion_tokens is not None
    assert generation.data[0].total_tokens is not None
    assert generation.data[0].output == "2"
500+
501+
502+
@pytest.mark.asyncio
async def test_async_chat_stream():
    """A streamed AsyncOpenAI chat completion is traced once the stream is consumed."""
    api = get_api()
    client = AsyncOpenAI()
    generation_name = create_uuid()

    completion = await client.chat.completions.create(messages=[{"role": "user", "content": "1 + 1 = "}], model="gpt-3.5-turbo", name=generation_name, stream=True)

    # The generation is only updated after the stream has been read to the
    # end, so drain it completely before flushing.
    chunks = [chunk async for chunk in completion]

    client.flush_langfuse()

    generation = api.observations.get_many(name=generation_name, type="GENERATION")

    assert len(chunks) != 0
    assert len(generation.data) != 0
    assert generation.data[0].name == generation_name
    assert generation.data[0].input == [{"content": "1 + 1 = ", "role": "user"}]
    assert generation.data[0].type == "GENERATION"
    # "gpt-3.5-turbo" is an alias; match the family rather than pinning one
    # dated snapshot such as "gpt-3.5-turbo-0613".
    assert generation.data[0].model.startswith("gpt-3.5-turbo")
    assert generation.data[0].start_time is not None
    assert generation.data[0].end_time is not None
    assert generation.data[0].start_time < generation.data[0].end_time
    assert generation.data[0].model_parameters == {
        "temperature": 1,
        "top_p": 1,
        "frequency_penalty": 0,
        "maxTokens": "inf",
        "presence_penalty": 0,
    }
    assert generation.data[0].prompt_tokens is not None
    assert generation.data[0].completion_tokens is not None
    assert generation.data[0].total_tokens is not None
    assert generation.data[0].output == "2"

0 commit comments

Comments
 (0)