Skip to content

Commit 8d68e3f

Browse files
committed
add prometheus exporter
Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent e134e15 commit 8d68e3f

12 files changed

Lines changed: 881 additions & 0 deletions

File tree

.env.example

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,8 @@ ALPHATRION_ARTIFACT_INSECURE=false
1616
ALPHATRION_ENABLE_TRACING=true
1717
ALPHATRION_CLICKHOUSE_INIT_TABLES=true
1818
ALPHATRION_CLICKHOUSE_ENABLE_BATCH=true
19+
20+
# Prometheus push gateway configurations
21+
ALPHATRION_ENABLE_PROMETHEUS=false
22+
ALPHATRION_PROMETHEUS_PUSHGATEWAY_URL=localhost:9091
23+
ALPHATRION_PROMETHEUS_JOB_NAME=alphatrion

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ Save the generated user ID — you'll need it to track experiments.
6565
**Optional Tools:**
6666
- pgAdmin: `http://localhost:8081` (alphatrion@inftyai.com / alphatr1on)
6767
- Registry UI: `http://localhost:80`
68+
- Grafana: `http://localhost:3000` (admin / admin) - LLM metrics dashboard
69+
- Prometheus: `http://localhost:9090` - Metrics explorer
6870

6971
### 3. Track Your First Experiment
7072

alphatrion/envs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
CLICKHOUSE_PASSWORD = "ALPHATRION_CLICKHOUSE_PASSWORD"
1717
CLICKHOUSE_ENABLE_BATCH = "ALPHATRION_CLICKHOUSE_ENABLE_BATCH"
1818

19+
# Prometheus push gateway related envs
20+
ENABLE_PROMETHEUS = "ALPHATRION_ENABLE_PROMETHEUS"
21+
PROMETHEUS_PUSHGATEWAY_URL = "ALPHATRION_PROMETHEUS_PUSHGATEWAY_URL"
22+
PROMETHEUS_JOB_NAME = "ALPHATRION_PROMETHEUS_JOB_NAME"
23+
1924
# Dashboard only related envs
2025
DASHBOARD_USER_ID = "ALPHATRION_DASHBOARD_USER_ID"
2126

alphatrion/storage/runtime.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from alphatrion.storage.sqlstore import SQLStore
1111
from alphatrion.storage.tracestore import TraceStore
1212
from alphatrion.tracing.clickhouse_exporter import ClickHouseSpanExporter
13+
from alphatrion.tracing.prometheus_span_processor import PrometheusSpanProcessor
1314
from alphatrion.tracing.span_processor import ContextAttributesSpanProcessor
1415

1516
__STORAGE_RUNTIME__ = None
@@ -55,6 +56,19 @@ def __init__(self):
5556
tracer_provider = trace.get_tracer_provider()
5657
tracer_provider.add_span_processor(ContextAttributesSpanProcessor())
5758

59+
# Add Prometheus span processor if enabled
60+
if os.getenv(envs.ENABLE_PROMETHEUS, "false").lower() == "true":
61+
pushgateway_url = os.getenv(
62+
envs.PROMETHEUS_PUSHGATEWAY_URL, "localhost:9091"
63+
)
64+
job_name = os.getenv(envs.PROMETHEUS_JOB_NAME, "alphatrion")
65+
66+
prometheus_processor = PrometheusSpanProcessor(
67+
pushgateway_url=pushgateway_url,
68+
job_name=job_name,
69+
)
70+
tracer_provider.add_span_processor(prometheus_processor)
71+
5872
artifact_insecure = os.getenv(envs.ARTIFACT_INSECURE, "false").lower() == "true"
5973
if artifact_storage_enabled():
6074
self._artifact = Artifact(insecure=artifact_insecure)
Lines changed: 322 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,322 @@
1+
# ruff: noqa: PLR0911
2+
3+
import logging
4+
import socket
5+
import uuid
6+
7+
from opentelemetry.context import Context
8+
from opentelemetry.sdk.trace import ReadableSpan
9+
from opentelemetry.sdk.trace.export import SpanProcessor
10+
from prometheus_client import (
11+
CollectorRegistry,
12+
Counter,
13+
Histogram,
14+
pushadd_to_gateway,
15+
)
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
class PrometheusSpanProcessor(SpanProcessor):
    """
    Span processor that turns finished LLM spans into Prometheus metrics and
    pushes them to a Prometheus push gateway.

    Only spans that carry both a ``traceloop.workflow.name`` attribute and an
    ``llm.usage.total_tokens`` attribute are processed. For each such span the
    processor updates token counters, a request counter, a latency histogram
    and (for spans with ERROR status) an error counter, then pushes the whole
    registry to the gateway.

    NOTE(review): the push happens synchronously inside ``on_end`` for every
    matching span, so a slow or unreachable gateway delays the thread that
    ends the span. Consider batching if that becomes a problem.
    """

    # Maps the numeric value of the span status code to the label value used
    # on the request counter. Unknown values fall back to "UNSET".
    _STATUS_NAMES = {0: "UNSET", 1: "OK", 2: "ERROR"}

    # Ordered (error_type, substrings) table used by _classify_error. Order
    # matters: the first entry with a matching substring wins.
    _ERROR_PATTERNS = (
        ("timeout", ("timeout", "timed out")),
        ("rate_limit", ("rate limit", "429")),
        ("auth_error", ("auth", "401", "403")),
        ("not_found", ("not found", "404")),
        ("invalid_request", ("invalid", "400")),
        ("connection_error", ("connection", "network")),
        ("server_error", ("500", "502", "503")),
    )

    _NANOS_PER_SECOND = 1_000_000_000

    def __init__(
        self,
        pushgateway_url: str,
        job_name: str = "alphatrion",
        grouping_key: dict[str, str] | None = None,
    ):
        """
        Initialize the Prometheus span processor.

        Args:
            pushgateway_url: URL of the Prometheus push gateway
                (e.g., "localhost:9091")
            job_name: Job name for the metrics in Prometheus
            grouping_key: Additional grouping labels
                (e.g., {"instance": "app-1"}). When omitted, a unique
                ``instance`` label is generated for this processor.
        """
        self.pushgateway_url = pushgateway_url
        self.job_name = job_name

        if grouping_key is None:
            # Generate a unique instance identifier so that pushes from
            # different processors do not overwrite each other in the gateway.
            # Hostname is included for traceability, the UUID for uniqueness.
            try:
                hostname = socket.gethostname()
            except OSError:
                hostname = ""
            unique_suffix = uuid.uuid4().hex
            instance_id = f"{hostname}-{unique_suffix}" if hostname else unique_suffix
            self.grouping_key = {"instance": instance_id}
        else:
            # Copy so later mutation of the caller's dict cannot silently
            # change the grouping key used for pushes.
            self.grouping_key = dict(grouping_key)

        # Dedicated registry: only the metrics defined here are pushed.
        self.registry = CollectorRegistry()
        self._init_metrics()

        logger.info(
            "PrometheusSpanProcessor initialized: pushgateway=%s, job=%s",
            pushgateway_url,
            job_name,
        )

    def _init_metrics(self):
        """Create all Prometheus metrics and register them in ``self.registry``."""

        # LLM token usage, split by token_type ("total", "input", "output")
        self.llm_tokens_total = Counter(
            "alphatrion_llm_tokens_total",
            "Total LLM tokens consumed",
            ["team_id", "experiment_id", "model", "token_type"],
            registry=self.registry,
        )

        self.llm_input_tokens_total = Counter(
            "alphatrion_llm_input_tokens_total",
            "Total LLM input tokens consumed",
            ["team_id", "experiment_id", "model"],
            registry=self.registry,
        )

        self.llm_output_tokens_total = Counter(
            "alphatrion_llm_output_tokens_total",
            "Total LLM output tokens consumed",
            ["team_id", "experiment_id", "model"],
            registry=self.registry,
        )

        # LLM request count, labelled by span status
        self.llm_requests_total = Counter(
            "alphatrion_llm_requests_total",
            "Total number of LLM requests",
            ["team_id", "experiment_id", "model", "status"],
            registry=self.registry,
        )

        # LLM latency
        self.llm_duration_seconds = Histogram(
            "alphatrion_llm_duration_seconds",
            "LLM request duration in seconds",
            ["team_id", "experiment_id", "model"],
            buckets=[0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0],
            registry=self.registry,
        )

        # Error tracking
        self.llm_errors_total = Counter(
            "alphatrion_llm_errors_total",
            "Total LLM errors by type",
            ["error_type"],
            registry=self.registry,
        )

    def on_start(self, span: ReadableSpan, parent_context: Context | None = None):
        """Called when a span is started. No-op for this processor."""
        pass

    def on_end(self, span: ReadableSpan):
        """
        Called when a span ends. Extract metrics and push to Prometheus.

        Any exception raised while processing is logged and suppressed so a
        metrics failure never propagates to the code that ended the span.

        Args:
            span: The completed span
        """
        try:
            raw_attributes = span.attributes
            # Only process spans with traceloop attributes
            # (same filter as ClickHouse exporter)
            if not raw_attributes or "traceloop.workflow.name" not in raw_attributes:
                return

            # Only process LLM spans
            if "llm.usage.total_tokens" not in raw_attributes:
                return

            # Label values must be strings, so stringify everything once.
            attributes = {key: str(value) for key, value in raw_attributes.items()}
            team_id = attributes.get("team_id", "unknown")
            experiment_id = attributes.get("experiment_id", "unknown")

            # Span timestamps are in nanoseconds. If either one is missing the
            # duration is unknown; skip only the latency observation instead
            # of dropping every metric for this span.
            start_ns, end_ns = span.start_time, span.end_time
            if start_ns is None or end_ns is None:
                duration = None
            else:
                duration = (end_ns - start_ns) / self._NANOS_PER_SECOND

            status = self._STATUS_NAMES.get(span.status.status_code.value, "UNSET")

            if status == "ERROR":
                error_type = self._classify_error(span, attributes)
                self.llm_errors_total.labels(error_type=error_type).inc()

            self._process_llm_span(
                span, attributes, team_id, experiment_id, duration, status
            )

            self._push_metrics()

        except Exception as e:
            logger.error("Failed to process span metrics: %s", e, exc_info=True)

    def _classify_error(self, span: ReadableSpan, attributes: dict[str, str]) -> str:
        """
        Classify error type from span.

        The span's status description is lower-cased and matched against the
        substrings in ``_ERROR_PATTERNS``; the first matching entry wins.

        Args:
            span: The span with error
            attributes: Span attributes (currently unused)

        Returns:
            Error type string, or "unknown" when nothing matches
        """
        message = (span.status.description or "").lower()

        for error_type, needles in self._ERROR_PATTERNS:
            if any(needle in message for needle in needles):
                return error_type
        return "unknown"

    @staticmethod
    def _parse_token_count(raw: str | None) -> int:
        """
        Convert a stringified token-count attribute to an int.

        Accepts integer strings ("12") and float strings ("12.0"); anything
        missing or unparseable yields 0 so that one malformed attribute does
        not abort metric recording for the whole span.
        """
        if raw is None:
            return 0
        try:
            return int(raw)
        except (TypeError, ValueError):
            pass
        try:
            # int(float("inf")) raises OverflowError, int(float("nan")) ValueError
            return int(float(raw))
        except (TypeError, ValueError, OverflowError):
            return 0

    def _process_llm_span(
        self,
        span: ReadableSpan,
        attributes: dict[str, str],
        team_id: str,
        experiment_id: str,
        duration: float | None,
        status: str,
    ):
        """
        Process LLM-specific metrics from a span.

        Args:
            span: The completed span (currently unused)
            attributes: Stringified span attributes
            team_id: Value for the ``team_id`` label
            experiment_id: Value for the ``experiment_id`` label
            duration: Span duration in seconds, or None when unknown
                (the latency histogram is then left untouched)
            status: Value for the ``status`` label on the request counter
        """
        # Prefer the requested model, fall back to the one in the response.
        model = (
            attributes.get("gen_ai.request.model")
            or attributes.get("gen_ai.response.model")
            or "unknown"
        )
        common_labels = {
            "team_id": team_id,
            "experiment_id": experiment_id,
            "model": model,
        }

        total_tokens = self._parse_token_count(
            attributes.get("llm.usage.total_tokens")
        )
        input_tokens = self._parse_token_count(
            attributes.get("gen_ai.usage.input_tokens")
        )
        output_tokens = self._parse_token_count(
            attributes.get("gen_ai.usage.output_tokens")
        )

        # Counters reject negative increments, hence the > 0 guards.
        if total_tokens > 0:
            self.llm_tokens_total.labels(token_type="total", **common_labels).inc(
                total_tokens
            )

        if input_tokens > 0:
            self.llm_input_tokens_total.labels(**common_labels).inc(input_tokens)
            self.llm_tokens_total.labels(token_type="input", **common_labels).inc(
                input_tokens
            )

        if output_tokens > 0:
            self.llm_output_tokens_total.labels(**common_labels).inc(output_tokens)
            self.llm_tokens_total.labels(token_type="output", **common_labels).inc(
                output_tokens
            )

        # Request count
        self.llm_requests_total.labels(status=status, **common_labels).inc()

        # Duration
        if duration is not None:
            self.llm_duration_seconds.labels(**common_labels).observe(duration)

    def _push_metrics(self, timeout: float | None = None) -> bool:
        """
        Push metrics to Prometheus push gateway.

        Failures are logged and reported through the return value rather than
        raised, so callers can treat the push as best-effort.

        Args:
            timeout: Timeout in seconds for the push request. When None the
                library's default timeout is used.

        Returns:
            True if the push succeeded, False otherwise
        """
        push_kwargs = {} if timeout is None else {"timeout": timeout}
        try:
            # The counters accumulate in this process's registry; each push
            # sends their current values. pushadd_to_gateway (unlike
            # push_to_gateway) only replaces metrics with the same name and
            # grouping key, leaving other metrics in the group untouched.
            pushadd_to_gateway(
                self.pushgateway_url,
                job=self.job_name,
                registry=self.registry,
                grouping_key=self.grouping_key,
                **push_kwargs,
            )
        except Exception as e:
            logger.warning("Failed to push metrics to Prometheus: %s", e)
            return False

        logger.debug("Successfully pushed metrics to Prometheus push gateway")
        return True

    def shutdown(self):
        """Shutdown the processor and perform final push."""
        if self._push_metrics():
            logger.info("PrometheusSpanProcessor shut down successfully")
        else:
            logger.error(
                "Error during PrometheusSpanProcessor shutdown: "
                "final metrics push failed"
            )

    def force_flush(self, timeout_millis: int = 30000) -> bool:
        """
        Force flush metrics to push gateway.

        Args:
            timeout_millis: Timeout in milliseconds for the push request.
                Non-positive or missing values fall back to the library's
                default timeout.

        Returns:
            True if successful, False otherwise
        """
        timeout_seconds = (
            timeout_millis / 1000 if timeout_millis and timeout_millis > 0 else None
        )
        if self._push_metrics(timeout=timeout_seconds):
            return True

        logger.error("Failed to force flush metrics")
        return False

0 commit comments

Comments
 (0)