Skip to content

Commit acaea58

Browse files
committed
push
1 parent 3a7efae commit acaea58

File tree

5 files changed

+130
-85
lines changed

5 files changed

+130
-85
lines changed

langfuse/_client/client.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,22 +1304,20 @@ def create_score(
13041304
score_id = score_id or self._create_observation_id()
13051305

13061306
try:
1307-
score_event = {
1308-
"id": score_id,
1309-
"session_id": session_id,
1310-
"dataset_run_id": dataset_run_id,
1311-
"trace_id": trace_id,
1312-
"observation_id": observation_id,
1313-
"name": name,
1314-
"value": value,
1315-
"data_type": data_type,
1316-
"comment": comment,
1317-
"config_id": config_id,
1318-
"environment": self._environment,
1319-
"metadata": metadata,
1320-
}
1321-
1322-
new_body = ScoreBody(**score_event)
1307+
new_body = ScoreBody(
1308+
id=score_id,
1309+
session_id=session_id,
1310+
dataset_run_id=dataset_run_id,
1311+
trace_id=trace_id,
1312+
observation_id=observation_id,
1313+
name=name,
1314+
value=value,
1315+
data_type=data_type,
1316+
comment=comment,
1317+
config_id=config_id,
1318+
environment=self._environment,
1319+
metadata=metadata,
1320+
)
13231321

13241322
event = {
13251323
"id": self.create_trace_id(),
@@ -1960,7 +1958,7 @@ def get_prompt(
19601958
f"Returning fallback prompt for '{cache_key}' due to fetch error: {e}"
19611959
)
19621960

1963-
fallback_client_args = {
1961+
fallback_client_args: Dict[str, Any] = {
19641962
"name": name,
19651963
"prompt": fallback,
19661964
"type": type,

langfuse/_client/resource_manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import os
1919
import threading
2020
from queue import Full, Queue
21-
from typing import Dict, Optional, cast
21+
from typing import Dict, Optional, cast, Any
2222

2323
import httpx
2424
from opentelemetry import trace as otel_trace_api
@@ -148,7 +148,7 @@ def _initialize_instance(
148148
)
149149
tracer_provider.add_span_processor(langfuse_processor)
150150

151-
tracer_provider = otel_trace_api.get_tracer_provider()
151+
tracer_provider = cast(TracerProvider, otel_trace_api.get_tracer_provider())
152152
self._otel_tracer = tracer_provider.get_tracer(
153153
LANGFUSE_TRACER_NAME,
154154
langfuse_version,
@@ -195,7 +195,7 @@ def _initialize_instance(
195195
LANGFUSE_MEDIA_UPLOAD_ENABLED, "True"
196196
).lower() not in ("false", "0")
197197

198-
self._media_upload_queue = Queue(100_000)
198+
self._media_upload_queue: Queue[Any] = Queue(100_000)
199199
self._media_manager = MediaManager(
200200
api_client=self.api,
201201
media_upload_queue=self._media_upload_queue,
@@ -220,7 +220,7 @@ def _initialize_instance(
220220
self.prompt_cache = PromptCache()
221221

222222
# Score ingestion
223-
self._score_ingestion_queue = Queue(100_000)
223+
self._score_ingestion_queue: Queue[Any] = Queue(100_000)
224224
self._ingestion_consumers = []
225225

226226
ingestion_consumer = ScoreIngestionConsumer(

langfuse/_utils/serializer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,16 @@
2121
try:
2222
from langchain.load.serializable import Serializable
2323
except ImportError:
24-
# If Serializable is not available, set it to NoneType
25-
Serializable = type(None)
24+
# If Serializable is not available, set it to a placeholder type
25+
class Serializable: # type: ignore
26+
pass
27+
2628

2729
# Attempt to import numpy
2830
try:
2931
import numpy as np
3032
except ImportError:
31-
np = None
33+
np = None # type: ignore
3234

3335
logger = getLogger(__name__)
3436

langfuse/langchain/CallbackHandler.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ def __init__(self, *, public_key: Optional[str] = None) -> None:
6060
self.client = get_client(public_key=public_key)
6161

6262
self.runs: Dict[UUID, Union[LangfuseSpan, LangfuseGeneration]] = {}
63-
self.prompt_to_parent_run_map = {}
64-
self.updated_completion_start_time_memo = set()
63+
self.prompt_to_parent_run_map: Dict[UUID, Any] = {}
64+
self.updated_completion_start_time_memo: Set[UUID] = set()
6565

6666
def on_llm_new_token(
6767
self,
@@ -166,19 +166,26 @@ def on_chain_start(
166166
run_id=run_id, parent_run_id=parent_run_id, metadata=metadata
167167
)
168168

169-
content = {
170-
"name": self.get_langchain_run_name(serialized, **kwargs),
171-
"metadata": self.__join_tags_and_metadata(tags, metadata),
172-
"input": inputs,
173-
"level": "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None,
174-
}
169+
span_name = self.get_langchain_run_name(serialized, **kwargs)
170+
span_metadata = self.__join_tags_and_metadata(tags, metadata)
171+
span_level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None
175172

176173
if parent_run_id is None:
177-
self.runs[run_id] = self.client.start_span(**content)
174+
self.runs[run_id] = self.client.start_span(
175+
name=span_name,
176+
metadata=span_metadata,
177+
input=inputs,
178+
level=span_level,
179+
)
178180
else:
179181
self.runs[run_id] = cast(
180182
LangfuseSpan, self.runs[parent_run_id]
181-
).start_span(**content)
183+
).start_span(
184+
name=span_name,
185+
metadata=span_metadata,
186+
input=inputs,
187+
level=span_level,
188+
)
182189

183190
except Exception as e:
184191
langfuse_logger.exception(e)
@@ -431,23 +438,25 @@ def on_retriever_start(
431438
self._log_debug_event(
432439
"on_retriever_start", run_id, parent_run_id, query=query
433440
)
441+
span_name = self.get_langchain_run_name(serialized, **kwargs)
442+
span_metadata = self.__join_tags_and_metadata(tags, metadata)
443+
span_level = "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None
444+
434445
if parent_run_id is None:
435-
content = {
436-
"name": self.get_langchain_run_name(serialized, **kwargs),
437-
"metadata": self.__join_tags_and_metadata(tags, metadata),
438-
"input": query,
439-
"level": "DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None,
440-
}
441-
442-
self.runs[run_id] = self.client.start_span(**content)
446+
self.runs[run_id] = self.client.start_span(
447+
name=span_name,
448+
metadata=span_metadata,
449+
input=query,
450+
level=span_level,
451+
)
443452
else:
444453
self.runs[run_id] = cast(
445454
LangfuseSpan, self.runs[parent_run_id]
446455
).start_span(
447-
name=self.get_langchain_run_name(serialized, **kwargs),
456+
name=span_name,
448457
input=query,
449-
metadata=self.__join_tags_and_metadata(tags, metadata),
450-
level="DEBUG" if tags and LANGSMITH_TAG_HIDDEN in tags else None,
458+
metadata=span_metadata,
459+
level=span_level,
451460
)
452461

453462
except Exception as e:

langfuse/langchain/utils.py

Lines changed: 76 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""@private"""
22

33
import re
4-
from typing import Any, Dict, List, Literal, Optional
4+
from typing import Any, Dict, List, Literal, Optional, cast
55

66
# NOTE ON DEPENDENCIES:
77
# - since Jan 2024, there is https://pypi.org/project/langchain-openai/ which is a separate package and imports openai models.
@@ -12,7 +12,7 @@
1212
def _extract_model_name(
1313
serialized: Optional[Dict[str, Any]],
1414
**kwargs: Any,
15-
):
15+
) -> Optional[str]:
1616
"""Extracts the model name from the serialized or kwargs object. This is used to get the model names for Langfuse."""
1717
# In this function we return on the first match, so the order of operations is important
1818

@@ -39,39 +39,54 @@ def _extract_model_name(
3939

4040
for model_name, keys, select_from in models_by_id:
4141
model = _extract_model_by_path_for_id(
42-
model_name, serialized, kwargs, keys, select_from
42+
model_name,
43+
serialized,
44+
kwargs,
45+
keys,
46+
cast(Literal["serialized", "kwargs"], select_from),
4347
)
4448
if model:
4549
return model
4650

4751
# Second, we match AzureOpenAI as we need to extract the model name, deployment version and deployment name
48-
if serialized.get("id")[-1] == "AzureOpenAI":
49-
if kwargs.get("invocation_params").get("model"):
50-
return kwargs.get("invocation_params").get("model")
51-
52-
if kwargs.get("invocation_params").get("model_name"):
53-
return kwargs.get("invocation_params").get("model_name")
54-
55-
deployment_name = None
56-
deployment_version = None
57-
58-
if serialized.get("kwargs").get("openai_api_version"):
59-
deployment_version = serialized.get("kwargs").get("deployment_version")
60-
61-
if serialized.get("kwargs").get("deployment_name"):
62-
deployment_name = serialized.get("kwargs").get("deployment_name")
63-
64-
if not isinstance(deployment_name, str):
65-
return None
66-
67-
if not isinstance(deployment_version, str):
68-
return deployment_name
69-
70-
return (
71-
deployment_name + "-" + deployment_version
72-
if deployment_version not in deployment_name
73-
else deployment_name
74-
)
52+
if serialized:
53+
serialized_id = serialized.get("id")
54+
if (
55+
serialized_id
56+
and isinstance(serialized_id, list)
57+
and len(serialized_id) > 0
58+
and serialized_id[-1] == "AzureOpenAI"
59+
):
60+
invocation_params = kwargs.get("invocation_params")
61+
if invocation_params and isinstance(invocation_params, dict):
62+
if invocation_params.get("model"):
63+
return str(invocation_params.get("model"))
64+
65+
if invocation_params.get("model_name"):
66+
return str(invocation_params.get("model_name"))
67+
68+
deployment_name = None
69+
deployment_version = None
70+
71+
serialized_kwargs = serialized.get("kwargs")
72+
if serialized_kwargs and isinstance(serialized_kwargs, dict):
73+
if serialized_kwargs.get("openai_api_version"):
74+
deployment_version = serialized_kwargs.get("deployment_version")
75+
76+
if serialized_kwargs.get("deployment_name"):
77+
deployment_name = serialized_kwargs.get("deployment_name")
78+
79+
if not isinstance(deployment_name, str):
80+
return None
81+
82+
if not isinstance(deployment_version, str):
83+
return deployment_name
84+
85+
return (
86+
deployment_name + "-" + deployment_version
87+
if deployment_version not in deployment_name
88+
else deployment_name
89+
)
7590

7691
# Third, for some models, we are unable to extract the model by a path in an object. Langfuse provides us with a string representation of the model objects
7792
# We use regex to extract the model from the repr string
@@ -111,7 +126,9 @@ def _extract_model_name(
111126
]
112127
for select in ["kwargs", "serialized"]:
113128
for path in random_paths:
114-
model = _extract_model_by_path(serialized, kwargs, path, select)
129+
model = _extract_model_by_path(
130+
serialized, kwargs, path, cast(Literal["serialized", "kwargs"], select)
131+
)
115132
if model:
116133
return model
117134

@@ -123,13 +140,20 @@ def _extract_model_from_repr_by_pattern(
123140
serialized: Optional[Dict[str, Any]],
124141
pattern: str,
125142
default: Optional[str] = None,
126-
):
143+
) -> Optional[str]:
127144
if serialized is None:
128145
return None
129146

130-
if serialized.get("id")[-1] == id:
131-
if serialized.get("repr"):
132-
extracted = _extract_model_with_regex(pattern, serialized.get("repr"))
147+
serialized_id = serialized.get("id")
148+
if (
149+
serialized_id
150+
and isinstance(serialized_id, list)
151+
and len(serialized_id) > 0
152+
and serialized_id[-1] == id
153+
):
154+
repr_str = serialized.get("repr")
155+
if repr_str and isinstance(repr_str, str):
156+
extracted = _extract_model_with_regex(pattern, repr_str)
133157
return extracted if extracted else default if default else None
134158

135159
return None
@@ -145,15 +169,24 @@ def _extract_model_with_regex(pattern: str, text: str):
145169
def _extract_model_by_path_for_id(
146170
id: str,
147171
serialized: Optional[Dict[str, Any]],
148-
kwargs: dict,
172+
kwargs: Dict[str, Any],
149173
keys: List[str],
150174
select_from: Literal["serialized", "kwargs"],
151-
):
175+
) -> Optional[str]:
152176
if serialized is None and select_from == "serialized":
153177
return None
154178

155-
if serialized.get("id")[-1] == id:
156-
return _extract_model_by_path(serialized, kwargs, keys, select_from)
179+
if serialized:
180+
serialized_id = serialized.get("id")
181+
if (
182+
serialized_id
183+
and isinstance(serialized_id, list)
184+
and len(serialized_id) > 0
185+
and serialized_id[-1] == id
186+
):
187+
return _extract_model_by_path(serialized, kwargs, keys, select_from)
188+
189+
return None
157190

158191

159192
def _extract_model_by_path(
@@ -168,7 +201,10 @@ def _extract_model_by_path(
168201
current_obj = kwargs if select_from == "kwargs" else serialized
169202

170203
for key in keys:
171-
current_obj = current_obj.get(key)
204+
if current_obj and isinstance(current_obj, dict):
205+
current_obj = current_obj.get(key)
206+
else:
207+
return None
172208
if not current_obj:
173209
return None
174210

0 commit comments

Comments
 (0)