Skip to content

Commit bcdd752

Browse files
authored
Merge branch 'main' into mvick/agent-config-schema
2 parents 3d5df6b + 18602dd commit bcdd752

64 files changed

Lines changed: 424 additions & 378 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

newrelic/hooks/mlmodel_gemini.py

Lines changed: 107 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -197,15 +197,18 @@ def _record_embedding_success(*, transaction, embedding_id, linking_metadata, kw
197197
embedding_content = str(embedding_content)
198198
request_model = kwargs.get("model")
199199

200+
embedding_token_count = (
201+
settings.ai_monitoring.llm_token_count_callback(request_model, embedding_content)
202+
if settings.ai_monitoring.llm_token_count_callback
203+
else None
204+
)
205+
200206
full_embedding_response_dict = {
201207
"id": embedding_id,
202208
"span_id": span_id,
203209
"trace_id": trace_id,
204-
"token_count": (
205-
settings.ai_monitoring.llm_token_count_callback(request_model, embedding_content)
206-
if settings.ai_monitoring.llm_token_count_callback
207-
else None
208-
),
210+
# Replace values of 0 for token counts with None
211+
"response.usage.total_tokens": embedding_token_count or None,
209212
"request.model": request_model,
210213
"duration": ft.duration * 1000,
211214
"vendor": "gemini",
@@ -492,14 +495,9 @@ def _record_generation_error(*, transaction, linking_metadata, completion_id, kw
492495
"Unable to parse input message to Gemini LLM. Message content and role will be omitted from "
493496
"corresponding LlmChatCompletionMessage event. "
494497
)
498+
input_message_content, input_role = _parse_input_message(input_message)
495499

496-
generation_config = kwargs.get("config")
497-
if generation_config:
498-
request_temperature = getattr(generation_config, "temperature", None)
499-
request_max_tokens = getattr(generation_config, "max_output_tokens", None)
500-
else:
501-
request_temperature = None
502-
request_max_tokens = None
500+
request_temperature, request_max_tokens = _extract_generation_config(kwargs)
503501

504502
notice_error_attributes = {
505503
"http.statusCode": getattr(exc, "code", None),
@@ -540,17 +538,19 @@ def _record_generation_error(*, transaction, linking_metadata, completion_id, kw
540538
output_message_list = []
541539

542540
create_chat_completion_message_event(
543-
transaction,
544-
input_message,
545-
completion_id,
546-
span_id,
547-
trace_id,
541+
transaction=transaction,
542+
input_message_content=input_message_content,
543+
input_role=input_role,
544+
chat_completion_id=completion_id,
545+
span_id=span_id,
546+
trace_id=trace_id,
548547
# Passing the request model as the response model here since we do not have access to a response model
549-
request_model,
550-
request_model,
551-
llm_metadata,
552-
output_message_list,
553-
request_timestamp,
548+
response_model=request_model,
549+
llm_metadata=llm_metadata,
550+
output_message_list=output_message_list,
551+
# We do not record token counts in error cases, so set all_token_counts to True so the pipeline tokenizer does not run
552+
all_token_counts=True,
553+
request_timestamp=request_timestamp,
554554
)
555555
except Exception:
556556
_logger.warning(RECORD_EVENTS_FAILURE_LOG_MESSAGE, exc_info=True)
@@ -610,19 +610,22 @@ def _record_generation_success(
610610
request_timestamp=None,
611611
time_to_first_token=None,
612612
):
613+
settings = transaction.settings or global_settings()
613614
span_id = linking_metadata.get("span.id")
614615
trace_id = linking_metadata.get("trace.id")
615616
try:
616617
if response:
617618
response_model = response.get("model_version")
618619
# finish_reason is an enum, so grab just the stringified value from it to report
619620
finish_reason = response.get("candidates")[0].get("finish_reason").value
621+
token_usage = response.get("usage_metadata") or {}
620622
else:
621623
# Set all values to NoneTypes since we cannot access them through kwargs or another method that doesn't
622624
# require the response object
623625
response_model = None
624626
output_message_list = []
625627
finish_reason = None
628+
token_usage = {}
626629

627630
request_model = kwargs.get("model")
628631

@@ -644,13 +647,36 @@ def _record_generation_success(
644647
"corresponding LlmChatCompletionMessage event. "
645648
)
646649

647-
generation_config = kwargs.get("config")
648-
if generation_config:
649-
request_temperature = getattr(generation_config, "temperature", None)
650-
request_max_tokens = getattr(generation_config, "max_output_tokens", None)
651-
else:
652-
request_temperature = None
653-
request_max_tokens = None
650+
input_message_content, input_role = _parse_input_message(input_message)
651+
652+
# Parse output message content
653+
# This list should have a length of 1 to represent the output message
654+
# Parse the message text out to pass to any registered token counting callback
655+
output_message_content = output_message_list[0].get("parts")[0].get("text") if output_message_list else None
656+
657+
# Token counts default to those reported in the response object if available,
658+
# but the user registered callback below may override them.
659+
response_prompt_tokens = token_usage.get("prompt_token_count")
660+
response_completion_tokens = token_usage.get("candidates_token_count")
661+
response_total_tokens = token_usage.get("total_token_count")
662+
663+
# If the user has registered a callback to compute token counts it should always be preferred.
664+
token_count_callback = settings.ai_monitoring.llm_token_count_callback
665+
if token_count_callback:
666+
if input_message_content:
667+
response_prompt_tokens = token_count_callback(request_model, input_message_content)
668+
if output_message_content:
669+
response_completion_tokens = token_count_callback(response_model, output_message_content)
670+
671+
# Prefer the sum of individual counts as the total whenever both are available.
672+
# This ensures consistency in the event that the token counting callback has reported
673+
# different values for prompt or completion tokens.
674+
if response_prompt_tokens and response_completion_tokens:
675+
response_total_tokens = response_prompt_tokens + response_completion_tokens
676+
677+
all_token_counts = bool(response_prompt_tokens and response_completion_tokens and response_total_tokens)
678+
679+
request_temperature, request_max_tokens = _extract_generation_config(kwargs)
654680

655681
full_chat_completion_summary_dict = {
656682
"id": completion_id,
@@ -672,26 +698,57 @@ def _record_generation_success(
672698
"time_to_first_token": time_to_first_token,
673699
}
674700

701+
if all_token_counts:
702+
full_chat_completion_summary_dict["response.usage.prompt_tokens"] = response_prompt_tokens
703+
full_chat_completion_summary_dict["response.usage.completion_tokens"] = response_completion_tokens
704+
full_chat_completion_summary_dict["response.usage.total_tokens"] = response_total_tokens
705+
675706
llm_metadata = _get_llm_attributes(transaction)
676707
full_chat_completion_summary_dict.update(llm_metadata)
677708
transaction.record_custom_event("LlmChatCompletionSummary", full_chat_completion_summary_dict)
678709

679710
create_chat_completion_message_event(
680-
transaction,
681-
input_message,
682-
completion_id,
683-
span_id,
684-
trace_id,
685-
response_model,
686-
request_model,
687-
llm_metadata,
688-
output_message_list,
689-
request_timestamp,
711+
transaction=transaction,
712+
input_message_content=input_message_content,
713+
input_role=input_role,
714+
chat_completion_id=completion_id,
715+
span_id=span_id,
716+
trace_id=trace_id,
717+
response_model=response_model,
718+
llm_metadata=llm_metadata,
719+
output_message_list=output_message_list,
720+
all_token_counts=all_token_counts,
721+
request_timestamp=request_timestamp,
690722
)
691723
except Exception:
692724
_logger.warning(RECORD_EVENTS_FAILURE_LOG_MESSAGE, exc_info=True)
693725

694726

727+
def _parse_input_message(input_message):
728+
# The input_message will be a string if generate_content was called directly. In this case, we don't have
729+
# access to the role, so we default to user since this was an input message
730+
if isinstance(input_message, str):
731+
return input_message, "user"
732+
# The input_message will be a Google Content type if send_message was called, so we parse out the message
733+
# text and role (which should be "user")
734+
elif isinstance(input_message, google.genai.types.Content):
735+
return input_message.parts[0].text, input_message.role
736+
else:
737+
return None, None
738+
739+
740+
def _extract_generation_config(kwargs):
741+
generation_config = kwargs.get("config")
742+
if generation_config:
743+
request_temperature = getattr(generation_config, "temperature", None)
744+
request_max_tokens = getattr(generation_config, "max_output_tokens", None)
745+
else:
746+
request_temperature = None
747+
request_max_tokens = None
748+
749+
return request_temperature, request_max_tokens
750+
751+
695752
def _handle_streaming_generation_success(
696753
*, linking_metadata, completion_id, kwargs, ft, streaming_events, request_timestamp=None
697754
):
@@ -745,47 +802,29 @@ def _on_stream_chunk(self, chunk):
745802

746803

747804
def create_chat_completion_message_event(
805+
*,
748806
transaction,
749-
input_message,
807+
input_message_content,
808+
input_role,
750809
chat_completion_id,
751810
span_id,
752811
trace_id,
753812
response_model,
754-
request_model,
755813
llm_metadata,
756814
output_message_list,
815+
all_token_counts,
757816
request_timestamp=None,
758817
):
759818
try:
760819
settings = transaction.settings or global_settings()
761820

762-
if input_message:
763-
# The input_message will be a string if generate_content was called directly. In this case, we don't have
764-
# access to the role, so we default to user since this was an input message
765-
if isinstance(input_message, str):
766-
input_message_content = input_message
767-
input_role = "user"
768-
# The input_message will be a Google Content type if send_message was called, so we parse out the message
769-
# text and role (which should be "user")
770-
elif isinstance(input_message, google.genai.types.Content):
771-
input_message_content = input_message.parts[0].text
772-
input_role = input_message.role
773-
# Set input data to NoneTypes to ensure token_count callback is not called
774-
else:
775-
input_message_content = None
776-
input_role = None
777-
821+
if input_message_content:
778822
message_id = str(uuid.uuid4())
779823

780824
chat_completion_input_message_dict = {
781825
"id": message_id,
782826
"span_id": span_id,
783827
"trace_id": trace_id,
784-
"token_count": (
785-
settings.ai_monitoring.llm_token_count_callback(request_model, input_message_content)
786-
if settings.ai_monitoring.llm_token_count_callback and input_message_content
787-
else None
788-
),
789828
"role": input_role,
790829
"completion_id": chat_completion_id,
791830
# The input message will always be the first message in our request/ response sequence so this will
@@ -795,6 +834,8 @@ def create_chat_completion_message_event(
795834
"vendor": "gemini",
796835
"ingest_source": "Python",
797836
}
837+
if all_token_counts:
838+
chat_completion_input_message_dict["token_count"] = 0
798839

799840
if settings.ai_monitoring.record_content.enabled:
800841
chat_completion_input_message_dict["content"] = input_message_content
@@ -813,7 +854,7 @@ def create_chat_completion_message_event(
813854

814855
# Add one to the index to account for the single input message so our sequence value is accurate for
815856
# the output message
816-
if input_message:
857+
if input_message_content:
817858
index += 1
818859

819860
message_id = str(uuid.uuid4())
@@ -822,11 +863,6 @@ def create_chat_completion_message_event(
822863
"id": message_id,
823864
"span_id": span_id,
824865
"trace_id": trace_id,
825-
"token_count": (
826-
settings.ai_monitoring.llm_token_count_callback(response_model, message_content)
827-
if settings.ai_monitoring.llm_token_count_callback
828-
else None
829-
),
830866
"role": message.get("role"),
831867
"completion_id": chat_completion_id,
832868
"sequence": index,
@@ -836,6 +872,9 @@ def create_chat_completion_message_event(
836872
"is_response": True,
837873
}
838874

875+
if all_token_counts:
876+
chat_completion_output_message_dict["token_count"] = 0
877+
839878
if settings.ai_monitoring.record_content.enabled:
840879
chat_completion_output_message_dict["content"] = message_content
841880

tests/mlmodel_gemini/replays/test_text_generation/test_gemini_multi_text_generation/invoke-chat-standard.json

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,14 @@
6565
"usageMetadata": {
6666
"promptTokenCount": 10,
6767
"candidatesTokenCount": 12,
68-
"totalTokenCount": 159,
68+
"totalTokenCount": 122,
6969
"promptTokensDetails": [
7070
{
7171
"modality": "TEXT",
7272
"tokenCount": 10
7373
}
7474
],
75-
"thoughtsTokenCount": 137
75+
"thoughtsTokenCount": 100
7676
},
7777
"modelVersion": "gemini-2.5-flash",
7878
"responseId": "EbPzafvvHoS4sOIPsYORsQw"
@@ -122,8 +122,8 @@
122122
"token_count": 10
123123
}
124124
],
125-
"thoughts_token_count": 137,
126-
"total_token_count": 159
125+
"thoughts_token_count": 100,
126+
"total_token_count": 122
127127
}
128128
}
129129
]
@@ -192,8 +192,8 @@
192192
],
193193
"usageMetadata": {
194194
"promptTokenCount": 10,
195-
"candidatesTokenCount": 11,
196-
"totalTokenCount": 121,
195+
"candidatesTokenCount": 12,
196+
"totalTokenCount": 122,
197197
"promptTokensDetails": [
198198
{
199199
"modality": "TEXT",
@@ -242,7 +242,7 @@
242242
"model_version": "gemini-2.5-flash",
243243
"response_id": "ErPzaazGNZ6OjrEPnZ_t8QE",
244244
"usage_metadata": {
245-
"candidates_token_count": 11,
245+
"candidates_token_count": 12,
246246
"prompt_token_count": 10,
247247
"prompt_tokens_details": [
248248
{
@@ -251,11 +251,11 @@
251251
}
252252
],
253253
"thoughts_token_count": 100,
254-
"total_token_count": 121
254+
"total_token_count": 122
255255
}
256256
}
257257
]
258258
}
259259
}
260260
]
261-
}
261+
}

0 commit comments

Comments
 (0)