Skip to content

Commit 8903d80

Browse files
committed
add sampling
1 parent de21c5f commit 8903d80

3 files changed

Lines changed: 55 additions & 20 deletions

File tree

langfuse/_task_manager/media_upload_consumer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
class MediaUploadConsumer(threading.Thread):
8-
_log = logging.getLogger(__name__)
8+
_log = logging.getLogger("langfuse")
99
_identifier: int
1010
_max_retries: int
1111
_media_manager: MediaManager

langfuse/otel/__init__.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(
4848
environment: Optional[str] = None,
4949
release: Optional[str] = None,
5050
media_upload_thread_count: Optional[int] = None,
51-
# sample_rate: Optional[float] = None, # TODO: Implement sampling
51+
sample_rate: Optional[float] = None,
5252
# mask: Optional[MaskFunction] = None, # TODO: implement masking
5353
# sdk_integration: Optional[str] = "default", -> TO BE DEPRECATED
5454
# threads: Optional[int] = None, -> TO BE DEPRECATED
@@ -101,6 +101,7 @@ def __init__(
101101
flush_interval=flush_interval,
102102
httpx_client=httpx_client,
103103
media_upload_thread_count=media_upload_thread_count,
104+
sample_rate=sample_rate,
104105
)
105106

106107
self.tracer = (
@@ -184,13 +185,15 @@ def start_span(
184185

185186
span = self.tracer.start_span(name=name, attributes=attributes)
186187

187-
self._process_media_span_attributes(
188-
span=span,
189-
as_type=as_type,
190-
input=input,
191-
output=output,
192-
metadata=metadata,
193-
)
188+
# Process media only if span is sampled
189+
if span.is_recording:
190+
self._process_media_span_attributes(
191+
span=span,
192+
as_type=as_type,
193+
input=input,
194+
output=output,
195+
metadata=metadata,
196+
)
194197

195198
return span
196199

@@ -308,13 +311,15 @@ def _start_as_current_span_with_processed_media(
308311
with self.tracer.start_as_current_span(
309312
name=name, attributes=attributes
310313
) as span:
311-
self._process_media_span_attributes(
312-
span=span,
313-
as_type=as_type,
314-
input=input,
315-
output=output,
316-
metadata=metadata,
317-
)
314+
# Process media only if span is sampled
315+
if span.is_recording():
316+
self._process_media_span_attributes(
317+
span=span,
318+
as_type=as_type,
319+
input=input,
320+
output=output,
321+
metadata=metadata,
322+
)
318323

319324
yield span
320325

@@ -666,7 +671,6 @@ def create_score(
666671
"environment": self.environment,
667672
}
668673

669-
langfuse_logger.debug(f"Creating score {score_event}...")
670674
new_body = ScoreBody(**score_event)
671675

672676
event = {

langfuse/otel/_tracer.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from opentelemetry import trace as otel_trace_api
99
from opentelemetry.sdk.resources import Resource
1010
from opentelemetry.sdk.trace import TracerProvider
11+
from opentelemetry.sdk.trace.sampling import Decision, TraceIdRatioBased
1112

1213
from langfuse._task_manager.media_manager import MediaManager
1314
from langfuse._task_manager.media_upload_consumer import MediaUploadConsumer
@@ -47,6 +48,7 @@ def __new__(
4748
flush_interval: Optional[float] = None,
4849
httpx_client: Optional[httpx.Client] = None,
4950
media_upload_thread_count: Optional[int] = None,
51+
sample_rate: Optional[float] = None,
5052
) -> "LangfuseTracer":
5153
if public_key in cls._instances:
5254
return cls._instances[public_key]
@@ -66,6 +68,7 @@ def __new__(
6668
flush_interval=flush_interval,
6769
httpx_client=httpx_client,
6870
media_upload_thread_count=media_upload_thread_count,
71+
sample_rate=sample_rate,
6972
)
7073

7174
cls._instances[public_key] = instance
@@ -85,10 +88,11 @@ def _initialize_instance(
8588
flush_interval: Optional[float] = None,
8689
media_upload_thread_count: Optional[int] = None,
8790
httpx_client: Optional[httpx.Client] = None,
91+
sample_rate: Optional[float] = None,
8892
):
8993
# OTEL Tracer
9094
tracer_provider = _init_tracer_provider(
91-
environment=environment, release=release
95+
environment=environment, release=release, sample_rate=sample_rate
9296
)
9397

9498
langfuse_processor = LangfuseSpanProcessor(
@@ -191,6 +195,14 @@ def _initialize_instance(
191195
# Register shutdown handler
192196
atexit.register(self.shutdown)
193197

198+
langfuse_logger.info(
199+
f"Initialized Langfuse tracer with "
200+
f"public_key={public_key}, "
201+
f"host={host}, "
202+
f"environment={environment}, "
203+
f"sample_rate={sample_rate}"
204+
)
205+
194206
def _fetch_project_id_background(self):
195207
try:
196208
projects = self.api.projects.get(
@@ -231,7 +243,22 @@ def project_id(self):
231243

232244
def add_score_task(self, event: dict):
233245
try:
234-
self._score_ingestion_queue.put(event, block=False)
246+
# Sample scores with the same sampler that is used for tracing
247+
tracer_provider = cast(TracerProvider, otel_trace_api.get_tracer_provider())
248+
should_sample = (
249+
tracer_provider.sampler.should_sample(
250+
parent_context=None,
251+
trace_id=int(event["body"].trace_id, 16),
252+
name="score",
253+
).decision
254+
== Decision.RECORD_AND_SAMPLE
255+
if hasattr(event["body"], "trace_id")
256+
else True
257+
)
258+
259+
if should_sample:
260+
langfuse_logger.debug(f"Enqueuing score event: {event}")
261+
self._score_ingestion_queue.put(event, block=False)
235262

236263
except Full:
237264
langfuse_logger.warning("Score ingestion queue is full")
@@ -313,6 +340,7 @@ def _init_tracer_provider(
313340
*,
314341
environment: Optional[str] = None,
315342
release: Optional[str] = None,
343+
sample_rate: Optional[float] = None,
316344
) -> TracerProvider:
317345
environment = environment or os.environ.get(LANGFUSE_TRACING_ENVIRONMENT)
318346
release = release or os.environ.get(LANGFUSE_RELEASE) or get_common_release_envs()
@@ -330,7 +358,10 @@ def _init_tracer_provider(
330358
default_provider = cast(TracerProvider, otel_trace_api.get_tracer_provider())
331359

332360
if isinstance(default_provider, otel_trace_api.ProxyTracerProvider):
333-
provider = TracerProvider(resource=resource)
361+
provider = TracerProvider(
362+
resource=resource,
363+
sampler=TraceIdRatioBased(sample_rate) if sample_rate else None,
364+
)
334365
otel_trace_api.set_tracer_provider(provider)
335366

336367
else:

0 commit comments

Comments
 (0)