Skip to content

Commit 87eb152

Browse files
authored
Openai v1 (#169)
1 parent 909712d commit 87eb152

7 files changed

Lines changed: 551 additions & 364 deletions

File tree

.github/workflows/ci.yml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,6 @@ jobs:
100100
- name: Install the project dependencies
101101
run: poetry install
102102

103-
# - name: Lint
104-
# run: |
105-
# poetry run black --check .
106-
# poetry run isort --check .
107-
# poetry run flake8 .
108-
109103
- name: Run the automated tests
110104
run: |
111105
python --version

.tox/.package.lock

Whitespace-only changes.

langfuse/openai.py

Lines changed: 227 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,61 @@
11
import threading
2-
import functools
32
from datetime import datetime
43

5-
import openai
6-
from openai.api_resources import ChatCompletion, Completion
74

85
from langfuse import Langfuse
96
from langfuse.client import InitialGeneration, CreateTrace
10-
from langfuse.api.resources.commons.types.llm_usage import LlmUsage
7+
8+
from distutils.version import StrictVersion
9+
import openai
10+
from wrapt import wrap_function_wrapper
11+
12+
13+
class OpenAiDefinition:
14+
module: str
15+
object: str
16+
method: str
17+
type: str
18+
19+
def __init__(self, module: str, object: str, method: str, type: str):
20+
self.module = module
21+
self.object = object
22+
self.method = method
23+
self.type = type
1124

1225

13-
class CreateArgsExtractor:
26+
OPENAI_METHODS_V0 = [
27+
OpenAiDefinition(
28+
module="openai",
29+
object="ChatCompletion",
30+
method="create",
31+
type="chat",
32+
),
33+
OpenAiDefinition(
34+
module="openai",
35+
object="Completion",
36+
method="create",
37+
type="completion",
38+
),
39+
]
40+
41+
42+
OPENAI_METHODS_V1 = [
43+
OpenAiDefinition(
44+
module="openai.resources.chat.completions",
45+
object="Completions",
46+
method="create",
47+
type="chat",
48+
),
49+
OpenAiDefinition(
50+
module="openai.resources.completions",
51+
object="Completions",
52+
method="create",
53+
type="completion",
54+
),
55+
]
56+
57+
58+
class OpenAiArgsExtractor:
1459
def __init__(self, name=None, metadata=None, trace_id=None, **kwargs):
1560
self.args = {}
1661
self.args["name"] = name
@@ -25,6 +70,105 @@ def get_openai_args(self):
2570
return self.kwargs
2671

2772

73+
def _with_tracer_wrapper(func):
74+
"""Helper for providing tracer for wrapper functions."""
75+
76+
def _with_tracer(open_ai_definitions, langfuse):
77+
def wrapper(wrapped, instance, args, kwargs):
78+
return func(open_ai_definitions, langfuse, wrapped, instance, args, kwargs)
79+
80+
return wrapper
81+
82+
return _with_tracer
83+
84+
85+
def _get_langfuse_data_from_kwargs(resource: OpenAiDefinition, langfuse: Langfuse, start_time, kwargs):
86+
name = kwargs.get("name", "OpenAI-generation")
87+
88+
if name is not None and not isinstance(name, str):
89+
raise TypeError("name must be a string")
90+
91+
trace_id = kwargs.get("trace_id", "OpenAI-generation")
92+
if trace_id is not None and not isinstance(trace_id, str):
93+
raise TypeError("trace_id must be a string")
94+
95+
if trace_id:
96+
langfuse.trace(CreateTrace(id=trace_id))
97+
98+
metadata = kwargs.get("metadata", {})
99+
100+
if metadata is not None and not isinstance(metadata, dict):
101+
raise TypeError("metadata must be a dictionary")
102+
103+
prompt = None
104+
if resource.type == "completion":
105+
prompt = kwargs.get("prompt", None)
106+
elif resource.type == "chat":
107+
prompt = (
108+
{
109+
"messages": kwargs.get("messages", [{}]),
110+
"functions": kwargs.get("functions", [{}]),
111+
"function_call": kwargs.get("function_call", {}),
112+
}
113+
if kwargs.get("functions", None) is not None
114+
else kwargs.get("messages", [{}])
115+
)
116+
117+
modelParameters = {
118+
"temperature": kwargs.get("temperature", 1),
119+
"maxTokens": kwargs.get("max_tokens", float("inf")),
120+
"top_p": kwargs.get("top_p", 1),
121+
"frequency_penalty": kwargs.get("frequency_penalty", 0),
122+
"presence_penalty": kwargs.get("presence_penalty", 0),
123+
}
124+
125+
return InitialGeneration(name=name, metadata=metadata, trace_id=trace_id, start_time=start_time, prompt=prompt, modelParameters=modelParameters)
126+
127+
128+
def _get_langfuse_data_from_response(resource: OpenAiDefinition, response):
129+
model = response.get("model", None)
130+
131+
completion = None
132+
if resource.type == "completion":
133+
choices = response.get("choices", [])
134+
if len(choices) > 0:
135+
choice = choices[-1]
136+
137+
completion = choice.text if _is_openai_v1() else choice.get("text", None)
138+
elif resource.type == "chat":
139+
choices = response.get("choices", [])
140+
if len(choices) > 0:
141+
choice = choices[-1]
142+
completion = choice.message.content if _is_openai_v1() else choice.get("message", None).get("content", None)
143+
144+
usage = response.get("usage", None)
145+
146+
return model, completion, usage
147+
148+
149+
def _is_openai_v1():
150+
return StrictVersion(openai.__version__) >= StrictVersion("1.0.0")
151+
152+
153+
@_with_tracer_wrapper
154+
def _wrap(open_ai_resource: OpenAiDefinition, langfuse: Langfuse, wrapped, instance, args, kwargs):
155+
start_time = datetime.now()
156+
arg_extractor = OpenAiArgsExtractor(*args, **kwargs)
157+
158+
generation = _get_langfuse_data_from_kwargs(open_ai_resource, langfuse, start_time, arg_extractor.get_langfuse_args())
159+
updated_generation = generation
160+
try:
161+
result = wrapped(**arg_extractor.get_openai_args())
162+
model, completion, usage = _get_langfuse_data_from_response(open_ai_resource, result.__dict__ if _is_openai_v1() else result)
163+
updated_generation = generation.copy(update={"model": model, "completion": completion, "end_time": datetime.now(), "usage": usage})
164+
langfuse.generation(updated_generation)
165+
return result
166+
except Exception as ex:
167+
model = kwargs.get("model", None)
168+
langfuse.generation(updated_generation.copy(update={"end_time": datetime.now(), "status_message": str(ex), "level": "ERROR", "model": model}))
169+
raise ex
170+
171+
28172
class OpenAILangfuse:
29173
_instance = None
30174
_lock = threading.Lock()
@@ -44,109 +188,95 @@ def initialize(self):
44188
def flush(cls):
45189
cls._instance.langfuse.flush()
46190

47-
def _get_call_details(self, result, api_resource_class, **kwargs):
48-
name = kwargs.get("name", "OpenAI-generation")
191+
# def legacy(self, result, api_resource_class, **kwargs):
192+
# completion = None
49193

50-
if name is not None and not isinstance(name, str):
51-
raise TypeError("name must be a string")
194+
# if api_resource_class == chat_completions:
195+
# prompt = (
196+
# {
197+
# "messages": kwargs.get("messages", [{}]),
198+
# "functions": kwargs.get("functions", [{}]),
199+
# "function_call": kwargs.get("function_call", {}),
200+
# }
201+
# if kwargs.get("functions", None) is not None
202+
# else kwargs.get("messages", [{}])
203+
# )
204+
# if not isinstance(result, Exception):
205+
# completion = result.choices[-1].message.content
206+
# if completion is None:
207+
# completion = result.choices[-1].message.function_call
52208

53-
trace_id = kwargs.get("trace_id", "OpenAI-generation")
54-
if trace_id is not None and not isinstance(trace_id, str):
55-
raise TypeError("trace_id must be a string")
209+
# elif api_resource_class == completions:
210+
# prompt = kwargs.get("prompt", "")
211+
# if not isinstance(result, Exception):
212+
# completion = result.choices[-1].text
213+
# else:
214+
# completion = None
56215

57-
metadata = kwargs.get("metadata", {})
216+
# model = kwargs.get("model", None) if isinstance(result, Exception) else result.model
58217

59-
if metadata is not None and not isinstance(metadata, dict):
60-
raise TypeError("metadata must be a dictionary")
218+
# usage = None if isinstance(result, Exception) or result.usage is None else LlmUsage(**result.usage.dict())
219+
# endTime = datetime.now()
220+
# modelParameters = {
221+
# "temperature": kwargs.get("temperature", 1),
222+
# "maxTokens": kwargs.get("max_tokens", float("inf")),
223+
# "top_p": kwargs.get("top_p", 1),
224+
# "frequency_penalty": kwargs.get("frequency_penalty", 0),
225+
# "presence_penalty": kwargs.get("presence_penalty", 0),
226+
# }
227+
# all_details = {
228+
# "status_message": str(result) if isinstance(result, Exception) else None,
229+
# "name": name,
230+
# "prompt": prompt,
231+
# "completion": completion,
232+
# "endTime": endTime,
233+
# "model": model,
234+
# "modelParameters": modelParameters,
235+
# "usage": usage,
236+
# "metadata": metadata,
237+
# "level": "ERROR" if isinstance(result, Exception) else "DEFAULT",
238+
# "trace_id": trace_id,
239+
# }
240+
# return all_details
61241

62-
completion = None
242+
# def _log_result(self, call_details):
243+
# generation = InitialGeneration(**call_details)
244+
# if call_details["trace_id"] is not None:
245+
# self.langfuse.trace(CreateTrace(id=call_details["trace_id"]))
246+
# self.langfuse.generation(generation)
63247

64-
if api_resource_class == ChatCompletion:
65-
prompt = (
66-
{
67-
"messages": kwargs.get("messages", [{}]),
68-
"functions": kwargs.get("functions", [{}]),
69-
"function_call": kwargs.get("function_call", {}),
70-
}
71-
if kwargs.get("functions", None) is not None
72-
else kwargs.get("messages", [{}])
73-
)
74-
if not isinstance(result, Exception):
75-
completion = result.choices[-1].message.content
76-
if completion is None:
77-
completion = result.choices[-1].message.function_call
78-
79-
elif api_resource_class == Completion:
80-
prompt = kwargs.get("prompt", "")
81-
if not isinstance(result, Exception):
82-
completion = result.choices[-1].text
83-
else:
84-
completion = None
85-
86-
model = kwargs.get("model", None) if isinstance(result, Exception) else result.model
87-
88-
usage = None if isinstance(result, Exception) or result.usage is None else LlmUsage(**result.usage)
89-
endTime = datetime.now()
90-
modelParameters = {
91-
"temperature": kwargs.get("temperature", 1),
92-
"maxTokens": kwargs.get("max_tokens", float("inf")),
93-
"top_p": kwargs.get("top_p", 1),
94-
"frequency_penalty": kwargs.get("frequency_penalty", 0),
95-
"presence_penalty": kwargs.get("presence_penalty", 0),
96-
}
97-
all_details = {
98-
"status_message": str(result) if isinstance(result, Exception) else None,
99-
"name": name,
100-
"prompt": prompt,
101-
"completion": completion,
102-
"endTime": endTime,
103-
"model": model,
104-
"modelParameters": modelParameters,
105-
"usage": usage,
106-
"metadata": metadata,
107-
"level": "ERROR" if isinstance(result, Exception) else "DEFAULT",
108-
"trace_id": trace_id,
109-
}
110-
return all_details
111-
112-
def _log_result(self, call_details):
113-
generation = InitialGeneration(**call_details)
114-
if call_details["trace_id"] is not None:
115-
self.langfuse.trace(CreateTrace(id=call_details["trace_id"]))
116-
self.langfuse.generation(generation)
117-
118-
def langfuse_modified(self, func, api_resource_class):
119-
@functools.wraps(func)
120-
def wrapper(*args, **kwargs):
121-
try:
122-
startTime = datetime.now()
123-
arg_extractor = CreateArgsExtractor(*args, **kwargs)
124-
result = func(**arg_extractor.get_openai_args())
125-
call_details = self._get_call_details(result, api_resource_class, **arg_extractor.get_langfuse_args())
126-
call_details["startTime"] = startTime
127-
self._log_result(call_details)
128-
except Exception as ex:
129-
call_details = self._get_call_details(ex, api_resource_class, **arg_extractor.get_langfuse_args())
130-
call_details["startTime"] = startTime
131-
self._log_result(call_details)
132-
raise ex
133-
134-
return result
248+
# def langfuse_modified(self, func, api_resource_class):
249+
# @functools.wraps(func)
250+
# def wrapper(*args, **kwargs):
251+
# try:
252+
# startTime = datetime.now()
253+
# arg_extractor = OpenAiArgsExtractor(*args, **kwargs)
254+
# result = func(**arg_extractor.get_openai_args())
255+
# call_details = self._get_langfuse_data_from_kwargs(result, api_resource_class, **arg_extractor.get_langfuse_args())
256+
# call_details["startTime"] = startTime
257+
# self._log_result(call_details)
258+
# except Exception as ex:
259+
# call_details = self._get_langfuse_data_from_kwargs(ex, api_resource_class, **arg_extractor.get_langfuse_args())
260+
# call_details["startTime"] = startTime
261+
# self._log_result(call_details)
262+
# raise ex
135263

136-
return wrapper
264+
# return result
137265

138-
def replace_openai_funcs(self):
139-
api_resources_classes = [
140-
(ChatCompletion, "create"),
141-
(Completion, "create"),
142-
]
266+
# return wrapper
143267

144-
for api_resource_class, method in api_resources_classes:
145-
create_method = getattr(api_resource_class, method)
146-
setattr(api_resource_class, method, self.langfuse_modified(create_method, api_resource_class))
268+
def register_tracing(self):
269+
resources = OPENAI_METHODS_V1 if _is_openai_v1() else OPENAI_METHODS_V0
270+
271+
for resource in resources:
272+
wrap_function_wrapper(
273+
resource.module,
274+
f"{resource.object}.{resource.method}",
275+
_wrap(resource, self.langfuse),
276+
)
147277

148278
setattr(openai, "flush_langfuse", self.flush)
149279

150280

151281
modifier = OpenAILangfuse()
152-
modifier.replace_openai_funcs()
282+
modifier.register_tracing()

langfuse/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.1.15"
1+
__version__ = "1.1.16a2"

0 commit comments

Comments
 (0)