
Commit 3a49eef

add prometheus exporter

Signed-off-by: kerthcet <kerthcet@gmail.com>

1 parent e134e15 commit 3a49eef

12 files changed: 872 additions & 0 deletions

.env.example

Lines changed: 5 additions & 0 deletions

@@ -16,3 +16,8 @@ ALPHATRION_ARTIFACT_INSECURE=false
 ALPHATRION_ENABLE_TRACING=true
 ALPHATRION_CLICKHOUSE_INIT_TABLES=true
 ALPHATRION_CLICKHOUSE_ENABLE_BATCH=true
+
+# Prometheus push gateway configurations
+ALPHATRION_ENABLE_PROMETHEUS=false
+ALPHATRION_PROMETHEUS_PUSHGATEWAY_URL=localhost:9091
+ALPHATRION_PROMETHEUS_JOB_NAME=alphatrion
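
These two gateway settings can be smoke-tested before enabling the exporter. A minimal sketch, assuming the prometheus_client package is installed and a push gateway is listening on localhost:9091; the alphatrion_smoke_check gauge is a hypothetical metric used only for this check:

from prometheus_client import CollectorRegistry, Gauge, push_to_gateway

registry = CollectorRegistry()
# Hypothetical metric, used only to verify the gateway is reachable.
check = Gauge("alphatrion_smoke_check", "Push gateway smoke test", registry=registry)
check.set(1)
push_to_gateway("localhost:9091", job="alphatrion", registry=registry)

prometheus_client also ships delete_from_gateway, which can remove the test group afterwards.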

README.md

Lines changed: 2 additions & 0 deletions

@@ -65,6 +65,8 @@ Save the generated user ID — you'll need it to track experiments.
 **Optional Tools:**
 - pgAdmin: `http://localhost:8081` (alphatrion@inftyai.com / alphatr1on)
 - Registry UI: `http://localhost:80`
+- Grafana: `http://localhost:3000` (admin / admin) - LLM metrics dashboard
+- Prometheus: `http://localhost:9090` - Metrics explorer

 ### 3. Track Your First Experiment
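
Beyond the explorer UI, the Prometheus endpoint above can also be queried over its standard /api/v1/query HTTP API. A hedged sketch using the requests package (an assumption, not a documented project dependency); the counter name comes from the span processor added later in this commit:

import requests

resp = requests.get(
    "http://localhost:9090/api/v1/query",
    params={"query": "sum by (model, token_type) (alphatrion_llm_tokens_total)"},
    timeout=10,
)
for series in resp.json()["data"]["result"]:
    print(series["metric"], series["value"])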

alphatrion/envs.py

Lines changed: 5 additions & 0 deletions

@@ -16,6 +16,11 @@
 CLICKHOUSE_PASSWORD = "ALPHATRION_CLICKHOUSE_PASSWORD"
 CLICKHOUSE_ENABLE_BATCH = "ALPHATRION_CLICKHOUSE_ENABLE_BATCH"

+# Prometheus push gateway related envs
+ENABLE_PROMETHEUS = "ALPHATRION_ENABLE_PROMETHEUS"
+PROMETHEUS_PUSHGATEWAY_URL = "ALPHATRION_PROMETHEUS_PUSHGATEWAY_URL"
+PROMETHEUS_JOB_NAME = "ALPHATRION_PROMETHEUS_JOB_NAME"
+
 # Dashboard only related envs
 DASHBOARD_USER_ID = "ALPHATRION_DASHBOARD_USER_ID"

alphatrion/storage/runtime.py

Lines changed: 14 additions & 0 deletions

@@ -10,6 +10,7 @@
 from alphatrion.storage.sqlstore import SQLStore
 from alphatrion.storage.tracestore import TraceStore
 from alphatrion.tracing.clickhouse_exporter import ClickHouseSpanExporter
+from alphatrion.tracing.prometheus_span_processor import PrometheusSpanProcessor
 from alphatrion.tracing.span_processor import ContextAttributesSpanProcessor

 __STORAGE_RUNTIME__ = None

@@ -55,6 +56,19 @@ def __init__(self):
         tracer_provider = trace.get_tracer_provider()
         tracer_provider.add_span_processor(ContextAttributesSpanProcessor())

+        # Add Prometheus span processor if enabled
+        if os.getenv(envs.ENABLE_PROMETHEUS, "false").lower() == "true":
+            pushgateway_url = os.getenv(
+                envs.PROMETHEUS_PUSHGATEWAY_URL, "localhost:9091"
+            )
+            job_name = os.getenv(envs.PROMETHEUS_JOB_NAME, "alphatrion")
+
+            prometheus_processor = PrometheusSpanProcessor(
+                pushgateway_url=pushgateway_url,
+                job_name=job_name,
+            )
+            tracer_provider.add_span_processor(prometheus_processor)
+
         artifact_insecure = os.getenv(envs.ARTIFACT_INSECURE, "false").lower() == "true"
         if artifact_storage_enabled():
             self._artifact = Artifact(insecure=artifact_insecure)
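
For reference, the same wiring can be reproduced without the AlphaTrion runtime by constructing the processor directly on an OpenTelemetry TracerProvider. A sketch assuming opentelemetry-sdk is installed; the values mirror the env defaults above:

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider

from alphatrion.tracing.prometheus_span_processor import PrometheusSpanProcessor

provider = TracerProvider()
provider.add_span_processor(
    PrometheusSpanProcessor(pushgateway_url="localhost:9091", job_name="alphatrion")
)
trace.set_tracer_provider(provider)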

alphatrion/tracing/prometheus_span_processor.py

Lines changed: 313 additions & 0 deletions

@@ -0,0 +1,313 @@
+# ruff: noqa: PLR0911
+
+import logging
+import socket
+import uuid
+
+from opentelemetry.context import Context
+from opentelemetry.sdk.trace import ReadableSpan
+from opentelemetry.sdk.trace.export import SpanProcessor
+from prometheus_client import (
+    CollectorRegistry,
+    Counter,
+    Histogram,
+    pushadd_to_gateway,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class PrometheusSpanProcessor(SpanProcessor):
+    """
+    Span processor that exports OpenTelemetry span metrics to a Prometheus push gateway.
+
+    This processor extracts metrics from LLM spans (tokens, latency, etc.) and pushes
+    them to a Prometheus push gateway, making them available for scraping by Prometheus.
+    """
+
+    def __init__(
+        self,
+        pushgateway_url: str,
+        job_name: str = "alphatrion",
+        grouping_key: dict[str, str] | None = None,
+    ):
+        """
+        Initialize the Prometheus span processor.
+
+        Args:
+            pushgateway_url: URL of the Prometheus push gateway (e.g., "localhost:9091")
+            job_name: Job name for the metrics in Prometheus
+            grouping_key: Additional grouping labels (e.g., {"instance": "app-1"})
+        """
+        self.pushgateway_url = pushgateway_url
+        self.job_name = job_name
+
+        # Generate a unique instance identifier to prevent metrics from being
+        # overwritten; otherwise keep the caller-provided grouping labels.
+        if grouping_key is None:
+            self.grouping_key = {
+                "instance": socket.gethostname(),
+                "uid": uuid.uuid4().hex,
+            }
+        else:
+            self.grouping_key = grouping_key
+
+        # Create a separate registry for push gateway metrics
+        self.registry = CollectorRegistry()
+
+        # Define metrics
+        self._init_metrics()
+
+        logger.info(
+            f"PrometheusSpanProcessor initialized: pushgateway={pushgateway_url}, "
+            f"job={job_name}"
+        )
+
+    def _init_metrics(self):
+        """Initialize Prometheus metrics."""
+
+        # LLM Token usage metrics
+        self.llm_tokens_total = Counter(
+            "alphatrion_llm_tokens_total",
+            "Total LLM tokens consumed",
+            ["team_id", "experiment_id", "model", "token_type"],
+            registry=self.registry,
+        )
+
+        self.llm_input_tokens_total = Counter(
+            "alphatrion_llm_input_tokens_total",
+            "Total LLM input tokens consumed",
+            ["team_id", "experiment_id", "model"],
+            registry=self.registry,
+        )
+
+        self.llm_output_tokens_total = Counter(
+            "alphatrion_llm_output_tokens_total",
+            "Total LLM output tokens consumed",
+            ["team_id", "experiment_id", "model"],
+            registry=self.registry,
+        )
+
+        # LLM Request metrics
+        self.llm_requests_total = Counter(
+            "alphatrion_llm_requests_total",
+            "Total number of LLM requests",
+            ["team_id", "experiment_id", "model", "status"],
+            registry=self.registry,
+        )
+
+        # LLM Latency metrics
+        self.llm_duration_seconds = Histogram(
+            "alphatrion_llm_duration_seconds",
+            "LLM request duration in seconds",
+            ["team_id", "experiment_id", "model"],
+            buckets=[0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0],
+            registry=self.registry,
+        )
+
+        # Error tracking
+        self.llm_errors_total = Counter(
+            "alphatrion_llm_errors_total",
+            "Total LLM errors by type",
+            ["error_type"],
+            registry=self.registry,
+        )
+
+    def on_start(self, span: ReadableSpan, parent_context: Context | None = None):
+        """Called when a span is started. No-op for this processor."""
+        pass
+
+    def on_end(self, span: ReadableSpan):
+        """
+        Called when a span ends. Extract metrics and push to Prometheus.
+
+        Args:
+            span: The completed span
+        """
+        try:
+            # Only process spans with traceloop attributes
+            # (same filter as the ClickHouse exporter)
+            if not span.attributes or "traceloop.workflow.name" not in span.attributes:
+                return
+
+            # Extract common attributes
+            attributes = {k: str(v) for k, v in span.attributes.items()}
+            team_id = attributes.get("team_id", "unknown")
+            experiment_id = attributes.get("experiment_id", "unknown")
+
+            # Only process LLM spans
+            if "llm.usage.total_tokens" not in attributes:
+                return
+
+            # Calculate duration in seconds (span timestamps are in nanoseconds)
+            duration = (span.end_time - span.start_time) / 1_000_000_000
+
+            # Status
+            status_map = {0: "UNSET", 1: "OK", 2: "ERROR"}
+            status = status_map.get(span.status.status_code.value, "UNSET")
+
+            # Track errors
+            if status == "ERROR":
+                error_type = self._classify_error(span, attributes)
+                self.llm_errors_total.labels(
+                    error_type=error_type,
+                ).inc()
+
+            # Process LLM-specific metrics
+            self._process_llm_span(
+                span, attributes, team_id, experiment_id, duration, status
+            )
+
+            # Push to gateway
+            self._push_metrics()
+
+        except Exception as e:
+            logger.error(f"Failed to process span metrics: {e}", exc_info=True)
+
+    def _classify_error(self, span: ReadableSpan, attributes: dict[str, str]) -> str:
+        """
+        Classify the error type from a span.
+
+        Args:
+            span: The span with an error
+            attributes: Span attributes
+
+        Returns:
+            Error type string
+        """
+        # Check the status message for common error patterns
+        status_msg = span.status.description or ""
+        status_msg_lower = status_msg.lower()
+
+        # Common error patterns
+        if "timeout" in status_msg_lower or "timed out" in status_msg_lower:
+            return "timeout"
+        elif "rate limit" in status_msg_lower or "429" in status_msg_lower:
+            return "rate_limit"
+        elif (
+            "auth" in status_msg_lower
+            or "401" in status_msg_lower
+            or "403" in status_msg_lower
+        ):
+            return "auth_error"
+        elif "not found" in status_msg_lower or "404" in status_msg_lower:
+            return "not_found"
+        elif "invalid" in status_msg_lower or "400" in status_msg_lower:
+            return "invalid_request"
+        elif "connection" in status_msg_lower or "network" in status_msg_lower:
+            return "connection_error"
+        elif (
+            "500" in status_msg_lower
+            or "502" in status_msg_lower
+            or "503" in status_msg_lower
+        ):
+            return "server_error"
+        else:
+            return "unknown"
+
+    def _process_llm_span(
+        self,
+        span: ReadableSpan,
+        attributes: dict[str, str],
+        team_id: str,
+        experiment_id: str,
+        duration: float,
+        status: str,
+    ):
+        """Process LLM-specific metrics from a span."""
+        # Extract model name
+        model = attributes.get(
+            "gen_ai.request.model", attributes.get("gen_ai.response.model", "unknown")
+        )
+
+        # Token metrics
+        total_tokens = int(attributes.get("llm.usage.total_tokens", 0))
+        input_tokens = int(attributes.get("gen_ai.usage.input_tokens", 0))
+        output_tokens = int(attributes.get("gen_ai.usage.output_tokens", 0))
+
+        if total_tokens > 0:
+            self.llm_tokens_total.labels(
+                team_id=team_id,
+                experiment_id=experiment_id,
+                model=model,
+                token_type="total",
+            ).inc(total_tokens)
+
+        if input_tokens > 0:
+            self.llm_input_tokens_total.labels(
+                team_id=team_id,
+                experiment_id=experiment_id,
+                model=model,
+            ).inc(input_tokens)
+
+            self.llm_tokens_total.labels(
+                team_id=team_id,
+                experiment_id=experiment_id,
+                model=model,
+                token_type="input",
+            ).inc(input_tokens)
+
+        if output_tokens > 0:
+            self.llm_output_tokens_total.labels(
+                team_id=team_id,
+                experiment_id=experiment_id,
+                model=model,
+            ).inc(output_tokens)
+
+            self.llm_tokens_total.labels(
+                team_id=team_id,
+                experiment_id=experiment_id,
+                model=model,
+                token_type="output",
+            ).inc(output_tokens)
+
+        # Request count
+        self.llm_requests_total.labels(
+            team_id=team_id,
+            experiment_id=experiment_id,
+            model=model,
+            status=status,
+        ).inc()
+
+        # Duration
+        self.llm_duration_seconds.labels(
+            team_id=team_id,
+            experiment_id=experiment_id,
+            model=model,
+        ).observe(duration)
+
+    def _push_metrics(self):
+        """Push metrics to the Prometheus push gateway."""
+        try:
+            # Use pushadd_to_gateway to accumulate counters instead of replacing them
+            pushadd_to_gateway(
+                self.pushgateway_url,
+                job=self.job_name,
+                registry=self.registry,
+                grouping_key=self.grouping_key,
+            )
+            logger.debug("Successfully pushed metrics to Prometheus push gateway")
+        except Exception as e:
+            logger.warning(f"Failed to push metrics to Prometheus: {e}")
+
+    def shutdown(self):
+        """Shut down the processor and perform a final push."""
+        try:
+            self._push_metrics()
+            logger.info("PrometheusSpanProcessor shut down successfully")
+        except Exception as e:
+            logger.error(f"Error during PrometheusSpanProcessor shutdown: {e}")
+
+    def force_flush(self, timeout_millis: int = 30000) -> bool:
+        """
+        Force flush metrics to the push gateway.
+
+        Args:
+            timeout_millis: Timeout in milliseconds (not used)
+
+        Returns:
+            True if successful, False otherwise
+        """
+        try:
+            self._push_metrics()
+            return True
+        except Exception as e:
+            logger.error(f"Failed to force flush metrics: {e}")
+            return False
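
To exercise the processor end to end, it suffices to end one span carrying the attributes on_end filters on (traceloop.workflow.name and llm.usage.total_tokens). A sketch with illustrative attribute values, assuming a tracer provider has already been wired as in runtime.py above:

from opentelemetry import trace

tracer = trace.get_tracer("alphatrion-demo")
with tracer.start_as_current_span("llm-call") as span:
    # Required by the processor's filters in on_end.
    span.set_attribute("traceloop.workflow.name", "demo-workflow")
    span.set_attribute("llm.usage.total_tokens", 42)
    # Labels and per-direction token counts (values are illustrative).
    span.set_attribute("team_id", "team-1")
    span.set_attribute("experiment_id", "exp-1")
    span.set_attribute("gen_ai.request.model", "gpt-4o")
    span.set_attribute("gen_ai.usage.input_tokens", 30)
    span.set_attribute("gen_ai.usage.output_tokens", 12)
# When the span ends, on_end runs and pushadd_to_gateway sends the counters.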
