Skip to content

Commit a36f712

Browse files
committed
fixed side by side
1 parent d08b248 commit a36f712

4 files changed

Lines changed: 247 additions & 5 deletions

File tree

stringsight/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,16 @@
55
property extraction, clustering, and metrics computation.
66
"""
77

8-
from .public import explain, explain_side_by_side, explain_single_model, explain_with_custom_pipeline, compute_metrics_only, label, extract_properties_only
8+
from .public import (
9+
explain,
10+
explain_side_by_side,
11+
explain_single_model,
12+
explain_with_custom_pipeline,
13+
compute_metrics_only,
14+
label,
15+
extract_properties_only,
16+
)
17+
from .utils.tau2 import tau2_json_to_stringsight_df
918

1019

1120
__version__ = "0.3.6"
@@ -17,4 +26,5 @@
1726
"compute_metrics_only",
1827
"label",
1928
"extract_properties_only",
29+
"tau2_json_to_stringsight_df",
2030
]

stringsight/core/data_objects.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ class Property:
135135
"""An extracted behavioral property from a model response."""
136136
id: str # unique id for the property
137137
question_id: str
138-
model: str
138+
model: str | list[str]
139139
# Parsed fields (filled by LLMJsonParser)
140140
property_description: str | None = None
141141
category: str | None = None

stringsight/extractors/openai.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,13 @@ def _build_prompt(idx: int, conv):
141141
continue
142142

143143
# We don't yet know which model(s) the individual properties will
144-
# belong to; parser will figure it out. Use a placeholder model
145-
# name so that validation passes.
146-
model_name = conv.model if isinstance(conv.model, str) else conv.model[0] if isinstance(conv.model, list) and conv.model else "unknown"
144+
# belong to; the parser will figure it out from the model label in
145+
# each extracted property JSON.
146+
#
147+
# Important for side-by-side: preserve the model pair on the
148+
# placeholder property so `LLMJsonParser` can map "Model A"/"Model B"
149+
# (or equivalent) onto the correct concrete model name.
150+
model_name = conv.model
147151
prop = Property(
148152
id=str(uuid.uuid4()),
149153
question_id=conv.question_id,

stringsight/utils/tau2.py

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
"""
2+
TAU2-style evaluation JSON converters.
3+
4+
This module contains utilities to convert TAU2-ish evaluation outputs (as seen in
5+
`data/tau2/*.json`) into StringSight's expected tidy dataframe format.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import json
11+
from typing import Any, Dict, List, Mapping, Sequence, TypedDict
12+
13+
import pandas as pd
14+
15+
16+
class OAIMessage(TypedDict, total=False):
17+
"""
18+
A minimal OpenAI Chat Completions message structure.
19+
20+
Keys:
21+
- role: "system" | "user" | "assistant" | "tool"
22+
- content: string content (may be omitted for tool-calling assistant turns)
23+
- tool_calls: OpenAI tool call list for assistant messages
24+
- tool_call_id: tool call id for tool messages
25+
- name: tool name for tool messages (optional)
26+
"""
27+
28+
role: str
29+
content: str
30+
tool_calls: List[Dict[str, Any]]
31+
tool_call_id: str
32+
name: str
33+
34+
35+
def tau2_json_to_stringsight_df(
36+
data: Mapping[str, Any],
37+
*,
38+
prompt_fields: Sequence[str] = ("known_info",),
39+
include_system_message: bool = True,
40+
include_scenario_message: bool = True,
41+
scenario_fields: Sequence[str] = ("domain", "reason_for_call", "known_info", "task_instructions"),
42+
) -> pd.DataFrame:
43+
"""
44+
Convert a TAU2-style evaluation JSON into a StringSight tidy dataframe.
45+
46+
Expected input schema (minimum required):
47+
- data["info"]["agent_info"]["llm"]: str
48+
- data["info"]["environment_info"]["policy"]: str
49+
- data["info"]["user_info"]["global_simulation_guidelines"]: str
50+
- data["tasks"]: list of tasks, each containing:
51+
- task["id"]: str
52+
- task["user_scenario"]["instructions"]: dict containing at least:
53+
- "known_info": str (used as the output dataframe's `prompt`)
54+
- plus any fields listed in `scenario_fields` (if include_scenario_message=True)
55+
- data["simulations"]: list of simulations, each containing:
56+
- sim["task_id"]: str (matches a task id)
57+
- sim["messages"]: list of message dicts in TAU2 format
58+
- sim["reward_info"]["reward"]: float
59+
60+
Output dataframe columns:
61+
- prompt: str
62+
- By default: the scenario "known_info" field.
63+
- If prompt_fields contains multiple keys: a labeled, concatenated string containing those fields.
64+
- model: str
65+
- data["info"]["agent_info"]["llm"]
66+
- model_response: list[dict]
67+
- OpenAI-conversation-style messages:
68+
- optionally prepends a system message (policy + guidelines)
69+
- optionally prepends a synthetic user message containing scenario fields
70+
- followed by the recorded simulation trace converted to OAI format
71+
- reward: float
72+
- sim["reward_info"]["reward"]
73+
"""
74+
75+
tasks_by_id: Dict[str, Mapping[str, Any]] = {t["id"]: t for t in data["tasks"]}
76+
77+
model_name: str = data["info"]["agent_info"]["llm"]
78+
policy: str = data["info"]["environment_info"]["policy"]
79+
guidelines: str = data["info"]["user_info"]["global_simulation_guidelines"]
80+
81+
rows: List[Dict[str, Any]] = []
82+
for sim in data["simulations"]:
83+
task_id: str = sim["task_id"]
84+
task = tasks_by_id[task_id]
85+
instr: Mapping[str, Any] = task["user_scenario"]["instructions"]
86+
87+
prompt: str = _format_prompt(instr, prompt_fields=prompt_fields)
88+
reward: float = sim["reward_info"]["reward"]
89+
90+
oai_messages: List[OAIMessage] = []
91+
92+
if include_system_message:
93+
oai_messages.append(
94+
{
95+
"role": "system",
96+
"content": f"{policy}\n\n{guidelines}",
97+
}
98+
)
99+
100+
if include_scenario_message:
101+
scenario_text = _format_scenario_message(instr, scenario_fields=scenario_fields)
102+
oai_messages.append({"role": "user", "content": scenario_text})
103+
104+
oai_messages.extend(_tau2_messages_to_oai(sim["messages"]))
105+
106+
rows.append(
107+
{
108+
"prompt": prompt,
109+
"model": model_name,
110+
"model_response": oai_messages,
111+
"reward": reward,
112+
}
113+
)
114+
115+
return pd.DataFrame(rows, columns=["prompt", "model", "model_response", "reward"])
116+
117+
118+
def _format_scenario_message(instr: Mapping[str, Any], *, scenario_fields: Sequence[str]) -> str:
119+
"""
120+
Format a synthetic user message from TAU2 scenario instructions.
121+
122+
Args:
123+
instr: Task instructions mapping (e.g., task["user_scenario"]["instructions"]).
124+
Expected to contain keys listed in scenario_fields.
125+
scenario_fields: Ordered keys to include in the rendered message.
126+
127+
Returns:
128+
A single string suitable for an OpenAI-format user message.
129+
"""
130+
131+
return _format_keyed_sections(instr, fields=scenario_fields, title="TAU2 scenario info:")
132+
133+
134+
def _format_prompt(instr: Mapping[str, Any], *, prompt_fields: Sequence[str]) -> str:
135+
"""
136+
Format the output dataframe's `prompt` field from TAU2 scenario instructions.
137+
138+
Args:
139+
instr: Task instructions mapping (e.g., task["user_scenario"]["instructions"]).
140+
prompt_fields: Ordered keys to include in the prompt.
141+
- If exactly one key is provided, the prompt is that raw string (no labels),
142+
preserving the common "prompt == known_info" behavior.
143+
- If multiple keys are provided, the prompt is a labeled, concatenated string.
144+
145+
Returns:
146+
Prompt string for the output dataframe.
147+
"""
148+
149+
if len(prompt_fields) == 1:
150+
only_key = prompt_fields[0]
151+
return str(instr[only_key]).strip()
152+
153+
return _format_keyed_sections(instr, fields=prompt_fields, title=None)
154+
155+
156+
def _format_keyed_sections(
157+
instr: Mapping[str, Any],
158+
*,
159+
fields: Sequence[str],
160+
title: str | None,
161+
) -> str:
162+
"""
163+
Render selected instruction fields into a single text block.
164+
165+
Args:
166+
instr: Task instructions mapping (e.g., task["user_scenario"]["instructions"]).
167+
fields: Ordered keys to include.
168+
title: Optional title line prepended to the block.
169+
170+
Returns:
171+
String with optional title and [key] sections.
172+
"""
173+
174+
lines: List[str] = []
175+
if title is not None:
176+
lines.append(title)
177+
for key in fields:
178+
value = instr[key]
179+
lines.append(f"\n[{key}]\n{value}")
180+
return "\n".join(lines).strip()
181+
182+
183+
def _tau2_messages_to_oai(messages: List[Mapping[str, Any]]) -> List[OAIMessage]:
184+
"""
185+
Convert TAU2 recorded messages into a minimal OAI chat format.
186+
187+
TAU2 message shape (observed):
188+
- role: str
189+
- content: str
190+
- tool_calls: optional list of {"id": str, "name": str, "arguments": dict, ...} for assistant turns
191+
- id: str (for tool messages; corresponds to the tool call id)
192+
193+
Returns:
194+
List of OAI-format messages.
195+
"""
196+
197+
out: List[OAIMessage] = []
198+
for m in messages:
199+
role = m["role"]
200+
201+
if role == "tool":
202+
tool_call_id = m["id"]
203+
content = m.get("content", "")
204+
out.append({"role": "tool", "tool_call_id": tool_call_id, "content": content})
205+
continue
206+
207+
msg: OAIMessage = {"role": role, "content": m.get("content", "")}
208+
209+
tool_calls = m.get("tool_calls")
210+
if role == "assistant" and tool_calls:
211+
msg["tool_calls"] = [
212+
{
213+
"id": tc["id"],
214+
"type": "function",
215+
"function": {
216+
"name": tc["name"],
217+
"arguments": json.dumps(tc["arguments"], ensure_ascii=False),
218+
},
219+
}
220+
for tc in tool_calls
221+
]
222+
# For tool-calling messages, OpenAI often uses `content: None`. StringSight accepts either.
223+
if msg.get("content", "") == "":
224+
msg.pop("content", None)
225+
226+
out.append(msg)
227+
228+
return out

0 commit comments

Comments
 (0)