4747 wrap_function_wrapper , # type: ignore[reportUnknownVariableType]
4848)
4949
50- from opentelemetry ._logs import get_logger
51- from opentelemetry .instrumentation ._semconv import (
52- _OpenTelemetrySemanticConventionStability ,
53- _OpenTelemetryStabilitySignalType ,
54- _StabilityMode ,
55- )
5650from opentelemetry .instrumentation .instrumentor import BaseInstrumentor
5751from opentelemetry .instrumentation .utils import unwrap
5852from opentelemetry .instrumentation .vertexai .package import _instruments
59- from opentelemetry .instrumentation .vertexai .patch import MethodWrappers
53+ from opentelemetry .instrumentation .vertexai .patch import (
54+ agenerate_content ,
55+ generate_content ,
56+ )
6057from opentelemetry .instrumentation .vertexai .utils import is_content_enabled
61- from opentelemetry .semconv .schemas import Schemas
62- from opentelemetry .trace import get_tracer
63- from opentelemetry .util .genai .completion_hook import load_completion_hook
64-
65-
66- def _methods_to_wrap (
67- method_wrappers : MethodWrappers ,
68- ):
69- # This import is very slow, do it lazily in case instrument() is not called
70- # pylint: disable=import-outside-toplevel
71- from google .cloud .aiplatform_v1 .services .prediction_service import ( # noqa: PLC0415
72- async_client ,
73- client ,
74- )
75- from google .cloud .aiplatform_v1beta1 .services .prediction_service import ( # noqa: PLC0415
76- async_client as async_client_v1beta1 ,
77- )
78- from google .cloud .aiplatform_v1beta1 .services .prediction_service import ( # noqa: PLC0415
79- client as client_v1beta1 ,
80- )
81-
82- for client_class in (
83- client .PredictionServiceClient ,
84- client_v1beta1 .PredictionServiceClient ,
85- ):
86- yield (
87- client_class ,
88- client_class .generate_content .__name__ , # type: ignore[reportUnknownMemberType]
89- method_wrappers .generate_content ,
90- )
91-
92- for client_class in (
93- async_client .PredictionServiceAsyncClient ,
94- async_client_v1beta1 .PredictionServiceAsyncClient ,
95- ):
96- yield (
97- client_class ,
98- client_class .generate_content .__name__ , # type: ignore[reportUnknownMemberType]
99- method_wrappers .agenerate_content ,
100- )
58+ from opentelemetry .util .genai .handler import get_telemetry_handler
10159
10260
10361class VertexAIInstrumentor (BaseInstrumentor ):
@@ -110,61 +68,55 @@ def instrumentation_dependencies(self) -> Collection[str]:
11068
11169 def _instrument (self , ** kwargs : Any ):
11270 """Enable VertexAI instrumentation."""
113- completion_hook = (
114- kwargs .get ("completion_hook" ) or load_completion_hook ()
115- )
116- sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability ._get_opentelemetry_stability_opt_in_mode (
117- _OpenTelemetryStabilitySignalType .GEN_AI ,
118- )
11971 tracer_provider = kwargs .get ("tracer_provider" )
120- schema = (
121- Schemas .V1_28_0 .value
122- if sem_conv_opt_in_mode == _StabilityMode .DEFAULT
123- else Schemas .V1_36_0 .value
124- )
125- tracer = get_tracer (
126- __name__ ,
127- "" ,
128- tracer_provider ,
129- schema_url = schema ,
130- )
13172 logger_provider = kwargs .get ("logger_provider" )
132- logger = get_logger (
133- __name__ ,
134- "" ,
73+ meter_provider = kwargs .get ("meter_provider" )
74+
75+ handler = get_telemetry_handler (
76+ tracer_provider = tracer_provider ,
77+ meter_provider = meter_provider ,
13578 logger_provider = logger_provider ,
136- schema_url = schema ,
13779 )
138- sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability ._get_opentelemetry_stability_opt_in_mode (
139- _OpenTelemetryStabilitySignalType .GEN_AI ,
80+
81+ capture_content = is_content_enabled ()
82+
83+ # This import is very slow, do it lazily in case instrument() is not called
84+ # pylint: disable=import-outside-toplevel
85+ from google .cloud .aiplatform_v1 .services .prediction_service import ( # noqa: PLC0415
86+ async_client ,
87+ client ,
14088 )
141- if sem_conv_opt_in_mode == _StabilityMode .DEFAULT :
142- # Type checker now knows sem_conv_opt_in_mode is a Literal[_StabilityMode.DEFAULT]
143- method_wrappers = MethodWrappers (
144- tracer ,
145- logger ,
146- is_content_enabled (sem_conv_opt_in_mode ),
147- sem_conv_opt_in_mode ,
148- completion_hook ,
149- )
150- elif sem_conv_opt_in_mode == _StabilityMode .GEN_AI_LATEST_EXPERIMENTAL :
151- # Type checker now knows it's the other literal
152- method_wrappers = MethodWrappers (
153- tracer ,
154- logger ,
155- is_content_enabled (sem_conv_opt_in_mode ),
156- sem_conv_opt_in_mode ,
157- completion_hook ,
89+ from google .cloud .aiplatform_v1beta1 .services .prediction_service import ( # noqa: PLC0415
90+ async_client as async_client_v1beta1 ,
91+ )
92+ from google .cloud .aiplatform_v1beta1 .services .prediction_service import ( # noqa: PLC0415
93+ client as client_v1beta1 ,
94+ )
95+
96+ sync_wrapper = generate_content (capture_content , handler )
97+ async_wrapper = agenerate_content (capture_content , handler )
98+
99+ for client_class in (
100+ client .PredictionServiceClient ,
101+ client_v1beta1 .PredictionServiceClient ,
102+ ):
103+ method_name = client_class .generate_content .__name__ # type: ignore[reportUnknownMemberType]
104+ wrap_function_wrapper (
105+ client_class ,
106+ name = method_name ,
107+ wrapper = sync_wrapper ,
158108 )
159- else :
160- raise RuntimeError (f"{ sem_conv_opt_in_mode } mode not supported" )
161- for client_class , method_name , wrapper in _methods_to_wrap (
162- method_wrappers
109+ self ._methods_to_unwrap .append ((client_class , method_name ))
110+
111+ for client_class in (
112+ async_client .PredictionServiceAsyncClient ,
113+ async_client_v1beta1 .PredictionServiceAsyncClient ,
163114 ):
115+ method_name = client_class .generate_content .__name__ # type: ignore[reportUnknownMemberType]
164116 wrap_function_wrapper (
165117 client_class ,
166118 name = method_name ,
167- wrapper = wrapper ,
119+ wrapper = async_wrapper ,
168120 )
169121 self ._methods_to_unwrap .append ((client_class , method_name ))
170122
0 commit comments