forked from GoogleCloudPlatform/BigQuery-Agent-Analytics-SDK
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdispatch.py
More file actions
209 lines (176 loc) · 6.49 KB
/
Copy pathdispatch.py
File metadata and controls
209 lines (176 loc) · 6.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Core dispatch logic for the BigQuery Remote Function.
This module contains the business logic (dispatch routing, evaluator
factories, filter construction) that is independent of Flask /
functions-framework so it can be tested without those dependencies.
"""
from __future__ import annotations
import json
from typing import Any
from bigquery_agent_analytics import Client
from bigquery_agent_analytics import SystemEvaluator
from bigquery_agent_analytics import LLMAsJudge
from bigquery_agent_analytics import serialize
from bigquery_agent_analytics import TraceFilter
from bigquery_agent_analytics._deploy_runtime import resolve_client_options
def build_client_from_context(
    user_defined_context: dict[str, Any],
) -> Client:
    """Create the SDK ``Client`` used to serve one Remote Function request.

    The request's ``userDefinedContext`` mapping is handed to
    ``resolve_client_options``, whose result is expanded as keyword
    arguments to ``Client``. (Per the original contract, environment
    variables are also taken into account by that helper.)
    """
    client_options = resolve_client_options(user_defined_context)
    return Client(**client_options)
def process_calls(
    client: Client,
    calls: list[list[Any]],
) -> list[dict[str, Any]]:
    """Process a batch of Remote Function calls.

    Args:
        client: An initialized SDK Client.
        calls: List of [operation, params_json] pairs. ``params_json`` may
            be a JSON-encoded string or an already-decoded value; either
            way it must represent a JSON object.

    Returns:
        List of JSON-safe dicts, one per call (partial failure safe). A
        failing call yields ``{"_error": {"code": <exception class name>,
        "message": <str(exception)>}, "_version": "1.0"}`` instead of
        aborting the batch; successful results get ``"_version": "1.0"``
        added.

        The caller (main.py) serializes the whole ``{"replies": [...]}``
        response once via ``jsonify``, so replies must be dicts — not
        pre-serialized JSON strings — to avoid double encoding.
    """
    replies: list[dict[str, Any]] = []
    for call in calls:
        try:
            operation, params_json = call[0], call[1]
            params = (
                json.loads(params_json)
                if isinstance(params_json, str)
                else params_json
            )
            # dispatch() and the builders look params up by key; reject
            # arrays, scalars and null here so the reply carries a clear
            # message instead of an opaque TypeError/AttributeError.
            if not isinstance(params, dict):
                raise ValueError(
                    "params must be a JSON object, got "
                    f"{type(params).__name__}"
                )
            result = dispatch(client, operation, params)
            result["_version"] = "1.0"
            replies.append(result)
        except Exception as e:  # pylint: disable=broad-except
            # Per-call isolation: one bad row must not fail the whole batch.
            replies.append(
                {
                    "_error": {
                        "code": type(e).__name__,
                        "message": str(e),
                    },
                    "_version": "1.0",
                }
            )
    return replies
def dispatch(client, operation, params):
    """Route an operation to its SDK method and return a JSON-safe dict.

    Args:
        client: An initialized SDK Client.
        operation: Operation name: ``"analyze"``, ``"evaluate"``,
            ``"judge"``, ``"insights"`` or ``"drift"``.
        params: Decoded params dict for this call.

    Returns:
        The value produced by ``serialize`` for the SDK call's result.

    Raises:
        ValueError: For an unknown ``operation``, or when a required param
            is missing or empty (``session_id`` for ``analyze``,
            ``golden_dataset`` for ``drift``). ``build_evaluator`` /
            ``build_judge`` may raise their own ``ValueError`` for an
            unsupported ``metric`` / ``criterion``.
    """
    if operation == "analyze":
        session_id = params.get("session_id")
        # Validate explicitly (mirroring ``drift`` below) so callers get a
        # readable ValueError rather than a bare KeyError('session_id').
        if not session_id:
            raise ValueError("analyze operation requires 'session_id' param")
        trace = client.get_session_trace(session_id)
        return serialize(trace)
    if operation == "evaluate":
        evaluator = build_evaluator(params)
        filters = build_filters(params)
        report = client.evaluate(evaluator=evaluator, filters=filters)
        return serialize(report)
    if operation == "judge":
        judge = build_judge(params)
        filters = build_filters(params)
        report = client.evaluate(evaluator=judge, filters=filters)
        return serialize(report)
    if operation == "insights":
        filters = build_filters(params)
        report = client.insights(filters=filters)
        return serialize(report)
    if operation == "drift":
        golden_dataset = params.get("golden_dataset")
        if not golden_dataset:
            raise ValueError("drift operation requires 'golden_dataset' param")
        filters = build_filters(params)
        report = client.drift_detection(
            golden_dataset=golden_dataset,
            filters=filters,
        )
        return serialize(report)
    raise ValueError(f"Unknown operation: {operation!r}")
def _bool_param(value: Any) -> bool:
"""Parse boolean-ish JSON or string values from remote function params."""
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.strip().lower() in ("1", "true", "yes", "on")
return bool(value)
def build_filters(params):
    """Translate Remote Function params into a ``TraceFilter``.

    Param keys are forwarded to ``TraceFilter.from_cli_args`` as:
    ``session_id`` -> ``session_id``, ``agent_filter`` -> ``agent_id``,
    ``last`` -> ``last`` and ``limit`` -> ``limit``. Absent keys become
    ``None``, except ``limit``, which defaults to 100.
    """
    cli_kwargs = {
        "session_id": params.get("session_id"),
        "agent_id": params.get("agent_filter"),
        "last": params.get("last"),
        "limit": params.get("limit", 100),
    }
    return TraceFilter.from_cli_args(**cli_kwargs)
def build_evaluator(params):
    """Construct the ``SystemEvaluator`` described by ``params``.

    Recognised keys:
      * ``metric``: evaluator name, default ``"latency"``.
      * ``threshold``: optional limit. When present (not None) it is passed
        to the factory under a metric-specific keyword (cast to ``int`` for
        ``turn_count`` and ``token_efficiency``). When absent, the factory
        is called without arguments.
      * ``fail_on_missing_telemetry``: boolean-ish flag, forwarded only for
        ``context_cache_hit_rate``.

    Raises:
        ValueError: If ``metric`` is not a supported name.
    """
    metric = params.get("metric", "latency")
    threshold = params.get("threshold")
    fail_on_missing_telemetry = _bool_param(
        params.get("fail_on_missing_telemetry", False)
    )

    # metric -> (factory, keyword receiving the threshold, threshold cast)
    specs = {
        "latency": (SystemEvaluator.latency, "threshold_ms", None),
        "error_rate": (SystemEvaluator.error_rate, "max_error_rate", None),
        "turn_count": (SystemEvaluator.turn_count, "max_turns", int),
        "token_efficiency": (
            SystemEvaluator.token_efficiency,
            "max_tokens",
            int,
        ),
        "ttft": (SystemEvaluator.ttft, "threshold_ms", None),
        "cost": (SystemEvaluator.cost_per_session, "max_cost_usd", None),
    }

    # The cache-hit-rate evaluator takes an extra flag, so it is built
    # separately from the uniform table above.
    if metric == "context_cache_hit_rate":
        cache_kwargs = {"fail_on_missing_telemetry": fail_on_missing_telemetry}
        if threshold is not None:
            cache_kwargs["min_hit_rate"] = threshold
        return SystemEvaluator.context_cache_hit_rate(**cache_kwargs)

    if metric not in specs:
        raise ValueError(f"Unknown metric: {metric!r}")

    factory, keyword, cast = specs[metric]
    if threshold is None:
        return factory()
    limit = threshold if cast is None else cast(threshold)
    return factory(**{keyword: limit})
def build_judge(params):
    """Construct the ``LLMAsJudge`` described by ``params``.

    ``criterion`` selects the judge (default ``"correctness"``; also
    ``"hallucination"`` and ``"sentiment"``). A non-None ``threshold`` is
    forwarded as the ``threshold`` keyword; otherwise the factory is called
    without arguments.

    Raises:
        ValueError: If ``criterion`` is not a supported name.
    """
    criterion = params.get("criterion", "correctness")
    threshold = params.get("threshold")
    judge_factories = {
        "correctness": LLMAsJudge.correctness,
        "hallucination": LLMAsJudge.hallucination,
        "sentiment": LLMAsJudge.sentiment,
    }
    if criterion not in judge_factories:
        raise ValueError(f"Unknown criterion: {criterion!r}")
    make_judge = judge_factories[criterion]
    if threshold is None:
        return make_judge()
    return make_judge(threshold=threshold)