@@ -55,7 +55,7 @@ class InferenceInvocation(GenAIInvocation):
     context manager rather than constructing this directly.
     """

-    def __init__(  # pylint: disable=too-many-locals
+    def __init__(
         self,
         tracer: Tracer,
         metrics_recorder: InvocationMetricsRecorder,
@@ -64,25 +64,8 @@ def __init__( # pylint: disable=too-many-locals
         provider: str,
         *,
         request_model: str | None = None,
-        input_messages: list[InputMessage] | None = None,
-        output_messages: list[OutputMessage] | None = None,
-        system_instruction: list[MessagePart] | None = None,
-        response_model_name: str | None = None,
-        response_id: str | None = None,
-        finish_reasons: list[str] | None = None,
-        input_tokens: int | None = None,
-        output_tokens: int | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        frequency_penalty: float | None = None,
-        presence_penalty: float | None = None,
-        max_tokens: int | None = None,
-        stop_sequences: list[str] | None = None,
-        seed: int | None = None,
         server_address: str | None = None,
         server_port: int | None = None,
-        attributes: dict[str, Any] | None = None,
-        metric_attributes: dict[str, Any] | None = None,
     ) -> None:
         """Use handler.start_inference(provider) or handler.inference(provider) instead of calling this directly."""
         _operation_name = GenAI.GenAiOperationNameValues.CHAT.value
@@ -96,38 +79,31 @@ def __init__( # pylint: disable=too-many-locals
             if request_model
             else _operation_name,
             span_kind=SpanKind.CLIENT,
-            attributes=attributes,
-            metric_attributes=metric_attributes,
         )
         self.provider = provider
         self.request_model = request_model
-        self.input_messages: list[InputMessage] = (
-            [] if input_messages is None else input_messages
-        )
-        self.output_messages: list[OutputMessage] = (
-            [] if output_messages is None else output_messages
-        )
-        self.system_instruction: list[MessagePart] = (
-            [] if system_instruction is None else system_instruction
-        )
-        self.response_model_name = response_model_name
-        self.response_id = response_id
-        self.finish_reasons = finish_reasons
-        self.input_tokens = input_tokens
-        self.output_tokens = output_tokens
-        self.temperature = temperature
-        self.top_p = top_p
-        self.frequency_penalty = frequency_penalty
-        self.presence_penalty = presence_penalty
-        self.max_tokens = max_tokens
-        self.stop_sequences = stop_sequences
-        self.seed = seed
         self.server_address = server_address
         self.server_port = server_port
+
+        self.input_messages: list[InputMessage] = []
+        self.output_messages: list[OutputMessage] = []
+        self.system_instruction: list[MessagePart] = []
+        self.response_model_name: str | None = None
+        self.response_id: str | None = None
+        self.finish_reasons: list[str] | None = None
+        self.input_tokens: int | None = None
+        self.output_tokens: int | None = None
+        self.temperature: float | None = None
+        self.top_p: float | None = None
+        self.frequency_penalty: float | None = None
+        self.presence_penalty: float | None = None
+        self.max_tokens: int | None = None
+        self.stop_sequences: list[str] | None = None
+        self.seed: int | None = None
         self.cache_creation_input_tokens: int | None = None
         self.cache_read_input_tokens: int | None = None
         self.tool_definitions: list[ToolDefinition] | None = None
-        self._start()
+        self._start(self._get_base_attributes())

     def _get_message_attributes(self, *, for_span: bool) -> dict[str, Any]:
         return get_content_attributes(
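Taken together, the `__init__` hunks stop threading request- and response-time data through the constructor: those fields become plain attributes with `[]`/`None` defaults, assigned after the invocation exists, while `request_model`, `server_address`, and `server_port` stay as kwargs because they are needed when the span starts (the base span/metric attributes now come from `_get_base_attributes()`). A rough sketch of the resulting call pattern, following the docstring's `handler.inference(provider)`; the `client` call and the `response` fields are illustrative assumptions, not part of this diff:

```python
# Sketch only: everything except the attribute names on `invocation`
# is an assumption, not shown in this diff.
with handler.inference("openai") as invocation:
    # Request parameters are now set post-construction instead of
    # being passed to __init__.
    invocation.temperature = 0.2
    invocation.max_tokens = 256

    response = client.chat(...)  # hypothetical provider call

    # Response data is recorded on the same object before the context
    # manager ends the span and emits metrics.
    invocation.response_id = response.id
    invocation.finish_reasons = ["stop"]
    invocation.input_tokens = response.usage.input_tokens
    invocation.output_tokens = response.usage.output_tokens
```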
@@ -288,33 +264,34 @@ def _start_with_handler(
         completion_hook: CompletionHook,
     ) -> None:
         """Create and start an InferenceInvocation from this data container. Called by handler.start_llm()."""
-        self._inference_invocation = InferenceInvocation(
+        inv = InferenceInvocation(
             tracer,
             metrics_recorder,
             logger,
             completion_hook,
             self.provider or "",
             request_model=self.request_model,
-            input_messages=self.input_messages,
-            output_messages=self.output_messages,
-            system_instruction=self.system_instruction,
-            response_model_name=self.response_model_name,
-            response_id=self.response_id,
-            finish_reasons=self.finish_reasons,
-            input_tokens=self.input_tokens,
-            output_tokens=self.output_tokens,
-            temperature=self.temperature,
-            top_p=self.top_p,
-            frequency_penalty=self.frequency_penalty,
-            presence_penalty=self.presence_penalty,
-            max_tokens=self.max_tokens,
-            stop_sequences=self.stop_sequences,
-            seed=self.seed,
             server_address=self.server_address,
             server_port=self.server_port,
-            attributes=self.attributes,
-            metric_attributes=self.metric_attributes,
         )
+        inv.input_messages = self.input_messages
+        inv.output_messages = self.output_messages
+        inv.system_instruction = self.system_instruction
+        inv.response_model_name = self.response_model_name
+        inv.response_id = self.response_id
+        inv.finish_reasons = self.finish_reasons
+        inv.input_tokens = self.input_tokens
+        inv.output_tokens = self.output_tokens
+        inv.temperature = self.temperature
+        inv.top_p = self.top_p
+        inv.frequency_penalty = self.frequency_penalty
+        inv.presence_penalty = self.presence_penalty
+        inv.max_tokens = self.max_tokens
+        inv.stop_sequences = self.stop_sequences
+        inv.seed = self.seed
+        inv.attributes.update(self.attributes)
+        inv.metric_attributes.update(self.metric_attributes)
+        self._inference_invocation = inv

     def _sync_to_invocation(self) -> None:
         inv = self._inference_invocation
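The `_start_with_handler` hunk mirrors the constructor change from the other side: the data container builds the invocation with only the surviving kwargs, copies the remaining fields onto it one by one, and merges rather than reassigns the two attribute dicts. A self-contained sketch of why `update()` matters there, with hypothetical keys not taken from the real classes:

```python
# Hypothetical attribute values, purely to illustrate the merge semantics
# that `inv.attributes.update(self.attributes)` relies on.
base_attributes = {"gen_ai.operation.name": "chat"}  # assumed to be set by
                                                     # the invocation itself
container_attributes = {"custom.tag": "checkout"}    # user-supplied extras

merged = dict(base_attributes)
merged.update(container_attributes)  # both keys survive; container wins ties
assert merged == {
    "gen_ai.operation.name": "chat",
    "custom.tag": "checkout",
}

# Plain reassignment (`inv.attributes = self.attributes`) would instead drop
# the invocation's own entries and alias the container's dict, so later
# mutations on either object would be visible on both.
```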