Skip to content

Commit 786e8ea

Browse files
authored
feat(openai): add support for tags (#400)
1 parent edd633c commit 786e8ea

2 files changed

Lines changed: 22 additions & 7 deletions

File tree

langfuse/openai.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def __init__(
8989
trace_id=None,
9090
session_id=None,
9191
user_id=None,
92+
tags=None,
9293
parent_observation_id=None,
9394
**kwargs,
9495
):
@@ -98,6 +99,7 @@ def __init__(
9899
self.args["trace_id"] = trace_id
99100
self.args["session_id"] = session_id
100101
self.args["user_id"] = user_id
102+
self.args["tags"] = tags
101103
self.args["parent_observation_id"] = parent_observation_id
102104
self.kwargs = kwargs
103105

@@ -187,17 +189,24 @@ def _get_langfuse_data_from_kwargs(
187189
if user_id is not None and not isinstance(user_id, str):
188190
raise TypeError("user_id must be a string")
189191

192+
tags = kwargs.get("tags", None)
193+
if tags is not None and (
194+
not isinstance(tags, list) or not all(isinstance(tag, str) for tag in tags)
195+
):
196+
raise TypeError("tags must be a list of strings")
197+
190198
parent_observation_id = kwargs.get("parent_observation_id", None)
191199
if parent_observation_id is not None and not isinstance(parent_observation_id, str):
192200
raise TypeError("parent_observation_id must be a string")
193201
if parent_observation_id is not None and trace_id is None:
194202
raise ValueError("parent_observation_id requires trace_id to be set")
195203

196204
if trace_id:
197-
langfuse.trace(id=trace_id, session_id=session_id, user_id=user_id)
198-
elif session_id:
199-
# If a session_id is provided but no trace_id, we should create a trace using the SDK and then use its trace_id
200-
trace_id = langfuse.trace(session_id=session_id, user_id=user_id).id
205+
langfuse.trace(id=trace_id, session_id=session_id, user_id=user_id, tags=tags)
206+
else:
207+
trace_id = langfuse.trace(
208+
session_id=session_id, user_id=user_id, tags=tags, name=name
209+
).id
201210

202211
metadata = kwargs.get("metadata", {})
203212

tests/test_openai.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,9 +255,11 @@ def test_openai_chat_completion_fail():
255255
openai.api_key = os.environ["OPENAI_API_KEY"]
256256

257257

258-
def test_openai_chat_completion_with_user_id():
258+
def test_openai_chat_completion_with_additional_params():
259259
api = get_api()
260260
user_id = create_uuid()
261+
session_id = create_uuid()
262+
tags = ["tag1", "tag2"]
261263
trace_id = create_uuid()
262264
completion = chat_func(
263265
name="user-creation",
@@ -267,14 +269,18 @@ def test_openai_chat_completion_with_user_id():
267269
metadata={"someKey": "someResponse"},
268270
user_id=user_id,
269271
trace_id=trace_id,
272+
session_id=session_id,
273+
tags=tags,
270274
)
271275

272276
openai.flush_langfuse()
273277

274278
assert len(completion.choices) != 0
275-
traces = api.trace.get(trace_id)
279+
trace = api.trace.get(trace_id)
276280

277-
assert traces.user_id == user_id
281+
assert trace.user_id == user_id
282+
assert trace.session_id == session_id
283+
assert trace.tags == tags
278284

279285

280286
def test_openai_chat_completion_without_extra_param():

0 commit comments

Comments
 (0)