|
| 1 | +""" |
| 2 | +TAU2-style evaluation JSON converters. |
| 3 | +
|
| 4 | +This module contains utilities to convert TAU2-ish evaluation outputs (as seen in |
| 5 | +`data/tau2/*.json`) into StringSight's expected tidy dataframe format. |
| 6 | +""" |
| 7 | + |
| 8 | +from __future__ import annotations |
| 9 | + |
| 10 | +import json |
| 11 | +from typing import Any, Dict, List, Mapping, Sequence, TypedDict |
| 12 | + |
| 13 | +import pandas as pd |
| 14 | + |
| 15 | + |
| 16 | +class OAIMessage(TypedDict, total=False): |
| 17 | + """ |
| 18 | + A minimal OpenAI Chat Completions message structure. |
| 19 | +
|
| 20 | + Keys: |
| 21 | + - role: "system" | "user" | "assistant" | "tool" |
| 22 | + - content: string content (may be omitted for tool-calling assistant turns) |
| 23 | + - tool_calls: OpenAI tool call list for assistant messages |
| 24 | + - tool_call_id: tool call id for tool messages |
| 25 | + - name: tool name for tool messages (optional) |
| 26 | + """ |
| 27 | + |
| 28 | + role: str |
| 29 | + content: str |
| 30 | + tool_calls: List[Dict[str, Any]] |
| 31 | + tool_call_id: str |
| 32 | + name: str |
| 33 | + |
| 34 | + |
| 35 | +def tau2_json_to_stringsight_df( |
| 36 | + data: Mapping[str, Any], |
| 37 | + *, |
| 38 | + prompt_fields: Sequence[str] = ("known_info",), |
| 39 | + include_system_message: bool = True, |
| 40 | + include_scenario_message: bool = True, |
| 41 | + scenario_fields: Sequence[str] = ("domain", "reason_for_call", "known_info", "task_instructions"), |
| 42 | +) -> pd.DataFrame: |
| 43 | + """ |
| 44 | + Convert a TAU2-style evaluation JSON into a StringSight tidy dataframe. |
| 45 | +
|
| 46 | + Expected input schema (minimum required): |
| 47 | + - data["info"]["agent_info"]["llm"]: str |
| 48 | + - data["info"]["environment_info"]["policy"]: str |
| 49 | + - data["info"]["user_info"]["global_simulation_guidelines"]: str |
| 50 | + - data["tasks"]: list of tasks, each containing: |
| 51 | + - task["id"]: str |
| 52 | + - task["user_scenario"]["instructions"]: dict containing at least: |
| 53 | + - "known_info": str (used as the output dataframe's `prompt`) |
| 54 | + - plus any fields listed in `scenario_fields` (if include_scenario_message=True) |
| 55 | + - data["simulations"]: list of simulations, each containing: |
| 56 | + - sim["task_id"]: str (matches a task id) |
| 57 | + - sim["messages"]: list of message dicts in TAU2 format |
| 58 | + - sim["reward_info"]["reward"]: float |
| 59 | +
|
| 60 | + Output dataframe columns: |
| 61 | + - prompt: str |
| 62 | + - By default: the scenario "known_info" field. |
| 63 | + - If prompt_fields contains multiple keys: a labeled, concatenated string containing those fields. |
| 64 | + - model: str |
| 65 | + - data["info"]["agent_info"]["llm"] |
| 66 | + - model_response: list[dict] |
| 67 | + - OpenAI-conversation-style messages: |
| 68 | + - optionally prepends a system message (policy + guidelines) |
| 69 | + - optionally prepends a synthetic user message containing scenario fields |
| 70 | + - followed by the recorded simulation trace converted to OAI format |
| 71 | + - reward: float |
| 72 | + - sim["reward_info"]["reward"] |
| 73 | + """ |
| 74 | + |
| 75 | + tasks_by_id: Dict[str, Mapping[str, Any]] = {t["id"]: t for t in data["tasks"]} |
| 76 | + |
| 77 | + model_name: str = data["info"]["agent_info"]["llm"] |
| 78 | + policy: str = data["info"]["environment_info"]["policy"] |
| 79 | + guidelines: str = data["info"]["user_info"]["global_simulation_guidelines"] |
| 80 | + |
| 81 | + rows: List[Dict[str, Any]] = [] |
| 82 | + for sim in data["simulations"]: |
| 83 | + task_id: str = sim["task_id"] |
| 84 | + task = tasks_by_id[task_id] |
| 85 | + instr: Mapping[str, Any] = task["user_scenario"]["instructions"] |
| 86 | + |
| 87 | + prompt: str = _format_prompt(instr, prompt_fields=prompt_fields) |
| 88 | + reward: float = sim["reward_info"]["reward"] |
| 89 | + |
| 90 | + oai_messages: List[OAIMessage] = [] |
| 91 | + |
| 92 | + if include_system_message: |
| 93 | + oai_messages.append( |
| 94 | + { |
| 95 | + "role": "system", |
| 96 | + "content": f"{policy}\n\n{guidelines}", |
| 97 | + } |
| 98 | + ) |
| 99 | + |
| 100 | + if include_scenario_message: |
| 101 | + scenario_text = _format_scenario_message(instr, scenario_fields=scenario_fields) |
| 102 | + oai_messages.append({"role": "user", "content": scenario_text}) |
| 103 | + |
| 104 | + oai_messages.extend(_tau2_messages_to_oai(sim["messages"])) |
| 105 | + |
| 106 | + rows.append( |
| 107 | + { |
| 108 | + "prompt": prompt, |
| 109 | + "model": model_name, |
| 110 | + "model_response": oai_messages, |
| 111 | + "reward": reward, |
| 112 | + } |
| 113 | + ) |
| 114 | + |
| 115 | + return pd.DataFrame(rows, columns=["prompt", "model", "model_response", "reward"]) |
| 116 | + |
| 117 | + |
| 118 | +def _format_scenario_message(instr: Mapping[str, Any], *, scenario_fields: Sequence[str]) -> str: |
| 119 | + """ |
| 120 | + Format a synthetic user message from TAU2 scenario instructions. |
| 121 | +
|
| 122 | + Args: |
| 123 | + instr: Task instructions mapping (e.g., task["user_scenario"]["instructions"]). |
| 124 | + Expected to contain keys listed in scenario_fields. |
| 125 | + scenario_fields: Ordered keys to include in the rendered message. |
| 126 | +
|
| 127 | + Returns: |
| 128 | + A single string suitable for an OpenAI-format user message. |
| 129 | + """ |
| 130 | + |
| 131 | + return _format_keyed_sections(instr, fields=scenario_fields, title="TAU2 scenario info:") |
| 132 | + |
| 133 | + |
| 134 | +def _format_prompt(instr: Mapping[str, Any], *, prompt_fields: Sequence[str]) -> str: |
| 135 | + """ |
| 136 | + Format the output dataframe's `prompt` field from TAU2 scenario instructions. |
| 137 | +
|
| 138 | + Args: |
| 139 | + instr: Task instructions mapping (e.g., task["user_scenario"]["instructions"]). |
| 140 | + prompt_fields: Ordered keys to include in the prompt. |
| 141 | + - If exactly one key is provided, the prompt is that raw string (no labels), |
| 142 | + preserving the common "prompt == known_info" behavior. |
| 143 | + - If multiple keys are provided, the prompt is a labeled, concatenated string. |
| 144 | +
|
| 145 | + Returns: |
| 146 | + Prompt string for the output dataframe. |
| 147 | + """ |
| 148 | + |
| 149 | + if len(prompt_fields) == 1: |
| 150 | + only_key = prompt_fields[0] |
| 151 | + return str(instr[only_key]).strip() |
| 152 | + |
| 153 | + return _format_keyed_sections(instr, fields=prompt_fields, title=None) |
| 154 | + |
| 155 | + |
| 156 | +def _format_keyed_sections( |
| 157 | + instr: Mapping[str, Any], |
| 158 | + *, |
| 159 | + fields: Sequence[str], |
| 160 | + title: str | None, |
| 161 | +) -> str: |
| 162 | + """ |
| 163 | + Render selected instruction fields into a single text block. |
| 164 | +
|
| 165 | + Args: |
| 166 | + instr: Task instructions mapping (e.g., task["user_scenario"]["instructions"]). |
| 167 | + fields: Ordered keys to include. |
| 168 | + title: Optional title line prepended to the block. |
| 169 | +
|
| 170 | + Returns: |
| 171 | + String with optional title and [key] sections. |
| 172 | + """ |
| 173 | + |
| 174 | + lines: List[str] = [] |
| 175 | + if title is not None: |
| 176 | + lines.append(title) |
| 177 | + for key in fields: |
| 178 | + value = instr[key] |
| 179 | + lines.append(f"\n[{key}]\n{value}") |
| 180 | + return "\n".join(lines).strip() |
| 181 | + |
| 182 | + |
| 183 | +def _tau2_messages_to_oai(messages: List[Mapping[str, Any]]) -> List[OAIMessage]: |
| 184 | + """ |
| 185 | + Convert TAU2 recorded messages into a minimal OAI chat format. |
| 186 | +
|
| 187 | + TAU2 message shape (observed): |
| 188 | + - role: str |
| 189 | + - content: str |
| 190 | + - tool_calls: optional list of {"id": str, "name": str, "arguments": dict, ...} for assistant turns |
| 191 | + - id: str (for tool messages; corresponds to the tool call id) |
| 192 | +
|
| 193 | + Returns: |
| 194 | + List of OAI-format messages. |
| 195 | + """ |
| 196 | + |
| 197 | + out: List[OAIMessage] = [] |
| 198 | + for m in messages: |
| 199 | + role = m["role"] |
| 200 | + |
| 201 | + if role == "tool": |
| 202 | + tool_call_id = m["id"] |
| 203 | + content = m.get("content", "") |
| 204 | + out.append({"role": "tool", "tool_call_id": tool_call_id, "content": content}) |
| 205 | + continue |
| 206 | + |
| 207 | + msg: OAIMessage = {"role": role, "content": m.get("content", "")} |
| 208 | + |
| 209 | + tool_calls = m.get("tool_calls") |
| 210 | + if role == "assistant" and tool_calls: |
| 211 | + msg["tool_calls"] = [ |
| 212 | + { |
| 213 | + "id": tc["id"], |
| 214 | + "type": "function", |
| 215 | + "function": { |
| 216 | + "name": tc["name"], |
| 217 | + "arguments": json.dumps(tc["arguments"], ensure_ascii=False), |
| 218 | + }, |
| 219 | + } |
| 220 | + for tc in tool_calls |
| 221 | + ] |
| 222 | + # For tool-calling messages, OpenAI often uses `content: None`. StringSight accepts either. |
| 223 | + if msg.get("content", "") == "": |
| 224 | + msg.pop("content", None) |
| 225 | + |
| 226 | + out.append(msg) |
| 227 | + |
| 228 | + return out |
0 commit comments