|
1 | 1 | """Test embedding instrumentation for LlamaIndex.""" |
2 | 2 |
|
3 | | -import pytest |
4 | | - |
5 | | -import os |
6 | | - |
7 | 3 | from llama_index.core import Settings |
8 | 4 | from llama_index.core.callbacks import CallbackManager |
9 | | -from llama_index.embeddings.openai import OpenAIEmbedding |
10 | | - |
11 | | -from opentelemetry import metrics, trace |
12 | | -from opentelemetry.instrumentation.llamaindex import LlamaindexInstrumentor |
13 | | -from opentelemetry.sdk.metrics import MeterProvider |
14 | | -from opentelemetry.sdk.metrics.export import InMemoryMetricReader |
15 | | -from opentelemetry.sdk.trace import TracerProvider |
16 | | -from opentelemetry.sdk.trace.export import ( |
17 | | - ConsoleSpanExporter, |
18 | | - SimpleSpanProcessor, |
19 | | -) |
20 | | - |
21 | | -pytestmark = pytest.mark.skip( |
22 | | - reason="Requires live OpenAI API key; needs VCR cassettes" |
23 | | -) |
24 | | - |
25 | | -# Global setup - shared across tests |
26 | | -metric_reader = None |
27 | | -instrumentor = None |
28 | | - |
29 | | - |
30 | | -def setup_telemetry(): |
31 | | - """Setup OpenTelemetry with span and metric exporters (once).""" |
32 | | - global metric_reader, instrumentor |
33 | | - |
34 | | - if metric_reader is not None: |
35 | | - return metric_reader |
36 | | - |
37 | | - # Enable metrics |
38 | | - os.environ["OTEL_INSTRUMENTATION_GENAI_EMITTERS"] = "span_metric" |
39 | | - |
40 | | - # Setup tracing |
41 | | - trace.set_tracer_provider(TracerProvider()) |
42 | | - trace.get_tracer_provider().add_span_processor( |
43 | | - SimpleSpanProcessor(ConsoleSpanExporter()) |
44 | | - ) |
45 | | - |
46 | | - # Setup metrics with InMemoryMetricReader |
47 | | - metric_reader = InMemoryMetricReader() |
48 | | - meter_provider = MeterProvider(metric_readers=[metric_reader]) |
49 | | - metrics.set_meter_provider(meter_provider) |
50 | | - |
51 | | - # Enable instrumentation once |
52 | | - instrumentor = LlamaindexInstrumentor() |
53 | | - instrumentor.instrument( |
54 | | - tracer_provider=trace.get_tracer_provider(), |
55 | | - meter_provider=metrics.get_meter_provider(), |
56 | | - ) |
57 | | - |
58 | | - return metric_reader |
| 5 | +from llama_index.core.embeddings import MockEmbedding |
59 | 6 |
|
60 | 7 |
|
61 | | -def test_embedding_single_text(): |
62 | | - """Test single text embedding instrumentation.""" |
63 | | - print("\nTest: Single Text Embedding") |
64 | | - print("=" * 60) |
65 | | - |
66 | | - metric_reader = setup_telemetry() |
67 | | - |
68 | | - # Configure embedding model |
69 | | - embed_model = OpenAIEmbedding( |
70 | | - model="text-embedding-3-small", |
71 | | - api_key=os.environ.get("OPENAI_API_KEY"), |
72 | | - ) |
| 8 | +def test_embedding_single_text(span_exporter, instrument): |
| 9 | + """Test single text embedding produces spans.""" |
| 10 | + embed_model = MockEmbedding(embed_dim=8) |
73 | 11 | Settings.embed_model = embed_model |
74 | | - |
75 | | - # Make sure callback manager is initialized |
76 | 12 | if Settings.callback_manager is None: |
77 | 13 | Settings.callback_manager = CallbackManager() |
78 | 14 |
|
79 | | - # Generate single embedding |
80 | | - text = "LlamaIndex is a data framework for LLM applications" |
81 | | - embedding = embed_model.get_text_embedding(text) |
82 | | - |
83 | | - print(f"\nText: {text}") |
84 | | - print(f"Embedding dimension: {len(embedding)}") |
85 | | - print(f"First 5 values: {embedding[:5]}") |
86 | | - |
87 | | - # Validate metrics |
88 | | - print("\nMetrics:") |
89 | | - metrics_data = metric_reader.get_metrics_data() |
90 | | - for resource_metric in metrics_data.resource_metrics: |
91 | | - for scope_metric in resource_metric.scope_metrics: |
92 | | - for metric in scope_metric.metrics: |
93 | | - print(f"\nMetric: {metric.name}") |
94 | | - for data_point in metric.data.data_points: |
95 | | - if hasattr(data_point, "bucket_counts"): |
96 | | - # Histogram |
97 | | - print(f" Count: {sum(data_point.bucket_counts)}") |
98 | | - else: |
99 | | - # Counter |
100 | | - print(f" Value: {data_point.value}") |
101 | | - |
102 | | - print("\nTest completed successfully") |
| 15 | + embedding = embed_model.get_text_embedding( |
| 16 | + "LlamaIndex is a data framework for LLM applications" |
| 17 | + ) |
103 | 18 |
|
| 19 | + assert len(embedding) == 8 |
104 | 20 |
|
105 | | -def test_embedding_batch(): |
106 | | - """Test batch embedding instrumentation.""" |
107 | | - print("\nTest: Batch Embeddings") |
108 | | - print("=" * 60) |
| 21 | + spans = span_exporter.get_finished_spans() |
| 22 | + assert len(spans) >= 1 |
109 | 23 |
|
110 | | - metric_reader = setup_telemetry() |
111 | 24 |
|
112 | | - # Configure embedding model |
113 | | - embed_model = OpenAIEmbedding( |
114 | | - model="text-embedding-3-small", |
115 | | - api_key=os.environ.get("OPENAI_API_KEY"), |
116 | | - ) |
| 25 | +def test_embedding_batch(span_exporter, instrument): |
| 26 | + """Test batch embedding produces spans.""" |
| 27 | + embed_model = MockEmbedding(embed_dim=8) |
117 | 28 | Settings.embed_model = embed_model |
118 | | - |
119 | | - # Make sure callback manager is initialized |
120 | 29 | if Settings.callback_manager is None: |
121 | 30 | Settings.callback_manager = CallbackManager() |
122 | 31 |
|
123 | | - # Generate batch embeddings |
124 | 32 | texts = [ |
125 | 33 | "Paris is the capital of France", |
126 | 34 | "Berlin is the capital of Germany", |
127 | 35 | "Rome is the capital of Italy", |
128 | 36 | ] |
129 | 37 | embeddings = embed_model.get_text_embedding_batch(texts) |
130 | 38 |
|
131 | | - print(f"\nEmbedded {len(embeddings)} texts") |
132 | | - print(f"Dimension: {len(embeddings[0])}") |
133 | | - |
134 | | - # Validate metrics |
135 | | - print("\nMetrics:") |
136 | | - metrics_data = metric_reader.get_metrics_data() |
137 | | - for resource_metric in metrics_data.resource_metrics: |
138 | | - for scope_metric in resource_metric.scope_metrics: |
139 | | - for metric in scope_metric.metrics: |
140 | | - print(f"\nMetric: {metric.name}") |
141 | | - for data_point in metric.data.data_points: |
142 | | - if hasattr(data_point, "bucket_counts"): |
143 | | - # Histogram |
144 | | - print(f" Count: {sum(data_point.bucket_counts)}") |
145 | | - else: |
146 | | - # Counter |
147 | | - print(f" Value: {data_point.value}") |
148 | | - |
149 | | - print("\nTest completed successfully") |
150 | | - |
151 | | - |
152 | | -if __name__ == "__main__": |
153 | | - test_embedding_single_text() |
154 | | - print("\n" + "=" * 60 + "\n") |
155 | | - test_embedding_batch() |
| 39 | + assert len(embeddings) == 3 |
| 40 | + assert all(len(e) == 8 for e in embeddings) |
156 | 41 |
|
157 | | - # Cleanup |
158 | | - if instrumentor: |
159 | | - instrumentor.uninstrument() |
| 42 | + spans = span_exporter.get_finished_spans() |
| 43 | + assert len(spans) >= 1 |
0 commit comments