Skip to content

Commit 48f1b30

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Simplify data retrieved handling of ask_data_agent tool and ask_data_insights tool
PiperOrigin-RevId: 917326615
1 parent 59f7347 commit 48f1b30

7 files changed

Lines changed: 460 additions & 924 deletions

File tree

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
import json
17+
from typing import Any
18+
19+
import requests
20+
21+
22+
def get_stream(
23+
url: str,
24+
ca_payload: dict[str, Any],
25+
headers: dict[str, str],
26+
max_query_result_rows: int,
27+
) -> list[dict[str, Any]]:
28+
"""Sends a JSON request to a streaming API and returns a list of messages."""
29+
with requests.Session() as s:
30+
accumulator = ""
31+
messages = []
32+
data_msg_idx = -1
33+
34+
with s.post(url, json=ca_payload, headers=headers, stream=True) as resp:
35+
resp.raise_for_status()
36+
for line in resp.iter_lines():
37+
if not line:
38+
continue
39+
40+
decoded_line = line.decode("utf-8")
41+
42+
if decoded_line == "[{":
43+
accumulator = "{"
44+
elif decoded_line == "}]":
45+
accumulator += "}"
46+
elif decoded_line == ",":
47+
continue
48+
else:
49+
accumulator += decoded_line
50+
51+
try:
52+
data_json = json.loads(accumulator)
53+
except ValueError:
54+
continue
55+
56+
accumulator = ""
57+
58+
if not isinstance(data_json, dict):
59+
messages.append(data_json)
60+
continue
61+
62+
processed_msg = None
63+
data_result = _extract_data_result(data_json)
64+
if data_result is not None:
65+
processed_msg = _format_data_retrieved(
66+
data_result, max_query_result_rows
67+
)
68+
if data_msg_idx >= 0:
69+
messages[data_msg_idx] = {
70+
"Data Retrieved": "Intermediate result omitted"
71+
}
72+
data_msg_idx = len(messages)
73+
elif isinstance(data_json.get("systemMessage"), dict):
74+
processed_msg = data_json["systemMessage"]
75+
else:
76+
processed_msg = data_json
77+
78+
if processed_msg is not None:
79+
messages.append(processed_msg)
80+
81+
return messages
82+
83+
84+
def _extract_data_result(msg: dict[str, Any]) -> dict[str, Any] | None:
85+
"""Attempts to find the result.data deep inside the generic dict."""
86+
sm = msg.get("systemMessage")
87+
if not isinstance(sm, dict):
88+
return None
89+
data = sm.get("data")
90+
if not isinstance(data, dict):
91+
return None
92+
result = data.get("result")
93+
if not isinstance(result, dict):
94+
return None
95+
if "data" in result and isinstance(result["data"], list):
96+
return result
97+
return None
98+
99+
100+
def _format_data_retrieved(
101+
result: dict[str, Any], max_rows: int
102+
) -> dict[str, Any]:
103+
"""Transforms the raw result dict into the simplified Toolbox format."""
104+
raw_data = result.get("data", [])
105+
106+
fields = []
107+
schema = result.get("schema")
108+
if isinstance(schema, dict):
109+
schema_fields = schema.get("fields")
110+
if isinstance(schema_fields, list):
111+
fields = schema_fields
112+
113+
headers = []
114+
for f in fields:
115+
if isinstance(f, dict):
116+
name = f.get("name")
117+
if isinstance(name, str):
118+
headers.append(name)
119+
120+
if not headers and raw_data:
121+
first_row = raw_data[0]
122+
if isinstance(first_row, dict):
123+
headers = list(first_row.keys())
124+
125+
total_rows = len(raw_data)
126+
num_to_display = min(total_rows, max_rows)
127+
128+
rows = []
129+
for r in raw_data[:num_to_display]:
130+
if isinstance(r, dict):
131+
row = [r.get(h) for h in headers]
132+
rows.append(row)
133+
134+
summary = f"Showing all {total_rows} rows."
135+
if total_rows > max_rows:
136+
summary = f"Showing the first {num_to_display} of {total_rows} total rows."
137+
138+
return {
139+
"Data Retrieved": {
140+
"headers": headers,
141+
"rows": rows,
142+
"summary": summary,
143+
}
144+
}

0 commit comments

Comments
 (0)