Skip to content

Commit 891e52f

Browse files
lhoestq and cfahlgren1 authored
Parse agent traces (#8113)
* parse agents traces
* fix typing
* add file_name
* support session_id for claude
* fix agent traces parquet row group size
* use full repo-relative file path instead of filename for agent traces (#8133)
* use full repo-relative file path instead of basename for agent traces file_name
* set base_path to hf://... instead of https://huggingface.co/...
* use relpath
---------
Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com>
* tag openclaw sessions via cwd in pi agent traces (#8143)
* update tests
* style
---------
Co-authored-by: Caleb Fahlgren <cfahlgren1@gmail.com>
1 parent 2724a65 commit 891e52f

6 files changed

Lines changed: 151 additions & 16 deletions

File tree

src/datasets/arrow_writer.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -652,15 +652,15 @@ def write_examples_on_file(self):
652652
row[0][col].to_pylist()[0] if isinstance(row[0][col], (pa.Array, pa.ChunkedArray)) else row[0][col]
653653
for row in self.current_examples
654654
]
655-
self.write_batch(batch_examples=batch_examples)
655+
self._write_batch(batch_examples=batch_examples)
656656
self.current_examples = []
657657

658658
def write_rows_on_file(self):
    """Flush the buffered single-row tables to the file.

    The pending rows are concatenated into one table and handed to
    `_write_table`; the buffer is then reset. Nothing happens when the
    buffer is empty.
    """
    pending_rows = self.current_rows
    if pending_rows:
        # `_write_table` is used (not `write_table`) so that flushing the
        # buffer does not re-enter this method.
        self._write_table(pa.concat_tables(pending_rows))
        self.current_rows = []
665665

666666
def write(
@@ -709,6 +709,15 @@ def write_batch(
709709
batch_examples: the batch of examples to add.
710710
try_original_type: use `try_type` when instantiating OptimizedTypedSequence if `True`, otherwise `try_type = None`.
711711
"""
712+
self.write_examples_on_file() # in case there are buffered examples to write first
713+
self._write_batch(batch_examples, writer_batch_size=writer_batch_size, try_original_type=try_original_type)
714+
715+
def _write_batch(
716+
self,
717+
batch_examples: dict[str, list],
718+
writer_batch_size: Optional[int] = None,
719+
try_original_type: Optional[bool] = True,
720+
):
712721
if batch_examples and len(next(iter(batch_examples.values()))) == 0:
713722
return
714723
features = None if self.pa_writer is None and self.update_features else self._features
@@ -752,6 +761,10 @@ def write_table(self, pa_table: pa.Table, writer_batch_size: Optional[int] = Non
752761
Args:
753762
example: the Table to add.
754763
"""
764+
self.write_rows_on_file() # in case there are buffered rows to write first
765+
self._write_table(pa_table, writer_batch_size=writer_batch_size)
766+
767+
def _write_table(self, pa_table: pa.Table, writer_batch_size: Optional[int] = None):
755768
if writer_batch_size is None:
756769
writer_batch_size = self.writer_batch_size
757770
if self.pa_writer is None:

src/datasets/builder.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1794,7 +1794,10 @@ def _prepare_split_single(
17941794
embed_local_files=embed_local_files,
17951795
)
17961796
try:
1797-
writer.write_table(table)
1797+
if len(table) == 1:
1798+
writer.write_row(table)
1799+
else:
1800+
writer.write_table(table)
17981801
except CastError as cast_error:
17991802
raise DatasetGenerationCastError.from_cast_error(
18001803
cast_error=cast_error,

src/datasets/load.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,7 @@ def get_module(self) -> DatasetModule:
669669
]
670670
default_config_name = None
671671
builder_kwargs = {
672-
"base_path": hf_dataset_url(self.name, "", revision=self.commit_hash).rstrip("/"),
672+
"base_path": base_path,
673673
"repo_id": self.name,
674674
"dataset_name": camelcase_to_snakecase(Path(self.name).name),
675675
}

src/datasets/packaged_modules/json/json.py

Lines changed: 120 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import io
2+
import os
23
from dataclasses import dataclass
34
from typing import Literal, Optional
45

@@ -50,6 +51,7 @@ class JsonConfig(datasets.BuilderConfig):
5051
chunksize: int = 10 << 20 # 10MB
5152
newlines_in_values: Optional[bool] = None
5253
on_mixed_types: Optional[Literal["use_json"]] = "use_json"
54+
parse_agent_traces: bool = True
5355

5456
def __post_init__(self):
5557
super().__post_init__()
@@ -83,13 +85,19 @@ def _split_generators(self, dl_manager):
8385
splits.append(
8486
datasets.SplitGenerator(
8587
name=split_name,
86-
gen_kwargs={"files_iterables": files_iterables, "base_files": base_data_files[split_name]},
88+
gen_kwargs={
89+
"files_iterables": files_iterables,
90+
"base_files": base_data_files[split_name],
91+
"original_files": self.config.data_files[split_name],
92+
},
8793
)
8894
)
8995
if self.info.features is None:
9096
try:
9197
pa_table = next(iter(self._generate_tables(**splits[0].gen_kwargs, allow_full_read=False)))[1]
9298
self.info.features = datasets.Features.from_arrow_schema(pa_table.schema)
99+
if self.config.parse_agent_traces and has_agent_traces_markers(self.info.features):
100+
self.info.features = AGENT_TRACES_FEATURES
93101
except FullReadDisallowed:
94102
pass
95103
return splits
@@ -124,14 +132,18 @@ def _cast_table(self, pa_table: pa.Table, json_field_paths=()) -> pa.Table:
124132
pa_table = table_cast(pa_table, features.arrow_schema)
125133
return pa_table
126134

127-
def _generate_shards(self, base_files, files_iterables):
135+
def _generate_shards(self, base_files, files_iterables, original_files):
128136
yield from base_files
129137

130-
def _generate_tables(self, base_files, files_iterables, allow_full_read=True):
138+
def _generate_tables(self, base_files, files_iterables, original_files, allow_full_read=True):
131139
json_field_paths = []
140+
is_agent_traces = False
132141

133142
if self.info.features is not None:
134-
json_field_paths = get_json_field_paths_from_feature(self.info.features)
143+
if self.info.features == AGENT_TRACES_FEATURES:
144+
is_agent_traces = True
145+
else:
146+
json_field_paths = get_json_field_paths_from_feature(self.info.features)
135147

136148
for shard_idx, files_iterable in enumerate(files_iterables):
137149
for file in files_iterable:
@@ -149,6 +161,24 @@ def _generate_tables(self, base_files, files_iterables, allow_full_read=True):
149161
pa_table = pa.Table.from_pandas(df, preserve_index=False)
150162
yield Key(shard_idx, 0), self._cast_table(pa_table)
151163

164+
# If the files are agent traces (one row = one file)
165+
elif is_agent_traces:
166+
with open(file, "r", encoding="utf-8") as f:
167+
traces = f.readlines()
168+
harness, session_id = parse_traces_info(traces)
169+
file_path = original_files[shard_idx]
170+
if file_path.startswith(self.base_path):
171+
file_path = os.path.relpath(file_path, self.base_path)
172+
pa_table = pa.Table.from_pydict(
173+
{
174+
"harness": [harness],
175+
"session_id": [session_id],
176+
"traces": [traces],
177+
"file_path": [file_path],
178+
}
179+
)
180+
yield Key(shard_idx, 0), self._cast_table(pa_table)
181+
152182
# If the file has one json object per line
153183
else:
154184
with open(file, "rb") as f:
@@ -265,3 +295,89 @@ def _generate_tables(self, base_files, files_iterables, allow_full_read=True):
265295
self._cast_table(pa_table, json_field_paths=json_field_paths),
266296
)
267297
batch_idx += 1
298+
299+
300+
# Values of the top-level "type" field, grouped by the agent harness they are
# attributed to. `parse_traces_info` uses these (through the reverse mapping
# below) to guess which harness produced a trace file.
AGENT_TRACES_TYPES_VALUES = {
    "claude_code": ["user", "assistant", "system"],
    "pi": ["session", "message"],
    "codex": ["session_meta", "turn_context", "response_item", "event_msg"],
}
# Reverse lookup: "type" value -> harness name.
AGENT_TRACES_TYPE_TO_HARNESS = {}
for _harness, _trace_types in AGENT_TRACES_TYPES_VALUES.items():
    for _trace_type in _trace_types:
        AGENT_TRACES_TYPE_TO_HARNESS[_trace_type] = _harness


# Per-harness column/feature pairs that `has_agent_traces_markers` looks for in
# the features inferred from a JSON Lines file. A file is treated as agent
# traces when all pairs of at least one harness are present with equal types.
AGENT_TRACES_FEATURES_MARKERS = {
    "claude_code": datasets.Features(
        {
            "type": datasets.Value("string"),
            "message": datasets.Json(),
        }
    ),
    "pi": datasets.Features(
        {
            "type": datasets.Value("string"),
            "message": datasets.Json(),
        }
    ),
    "codex": datasets.Features(
        {
            "type": datasets.Value("string"),
            "payload": datasets.Json(),
        }
    ),
}

# Schema of the table produced for agent traces: one row per trace file, with
# the detected harness, the session id, the file's lines and its path.
AGENT_TRACES_FEATURES = datasets.Features(
    {
        "harness": datasets.Value("string"),
        "session_id": datasets.Value("string"),
        "traces": datasets.List(datasets.Json()),
        "file_path": datasets.Value("string"),
    }
)
340+
341+
342+
def has_agent_traces_markers(features: datasets.Features) -> bool:
    """Tell whether `features` looks like the schema of an agent traces file.

    Returns True as soon as one harness in `AGENT_TRACES_FEATURES_MARKERS`
    has every one of its marker columns present in `features` with an equal
    feature type; returns False otherwise.
    """
    return any(
        all(features.get(column) == expected for column, expected in marker.items())
        for marker in AGENT_TRACES_FEATURES_MARKERS.values()
    )
347+
348+
349+
def parse_traces_info(traces: list[str]) -> tuple[Optional[str], Optional[str]]:
    """Detect the agent harness and the session id from the lines of a trace file.

    Lines are decoded one by one and scanning stops as soon as both values
    are known.

    Args:
        traces: the raw lines of the file, each expected to hold one JSON value.

    Returns:
        A `(harness, session_id)` tuple. Either element is `None` when it
        could not be determined.

    Raises:
        ValueError: if a non-blank line is not valid JSON (raised by the decoder).
    """
    harness, session_id = None, None
    for trace in traces:
        # Blank lines (e.g. a trailing newline at the end of the file) carry
        # no information and would make the JSON decoder fail.
        if not trace.strip():
            continue
        decoded_trace = ujson_loads(trace)
        # Only JSON objects can carry the fields inspected below; scalars and
        # arrays are ignored instead of raising on the key lookups.
        if not isinstance(decoded_trace, dict):
            continue
        trace_type = decoded_trace.get("type")
        if harness is None and isinstance(trace_type, str):
            harness = AGENT_TRACES_TYPE_TO_HARNESS.get(trace_type)
        if session_id is None:
            payload = decoded_trace.get("payload")
            # claude
            if isinstance(decoded_trace.get("sessionId"), str):
                session_id = decoded_trace["sessionId"]
            # claude (not sure but this format does exist online)
            elif isinstance(decoded_trace.get("session_id"), str):
                session_id = decoded_trace["session_id"]
            # codex
            elif isinstance(payload, dict) and isinstance(payload.get("id"), str):
                session_id = payload["id"]
            # pi / openclaw (openclaw embeds pi-agent; distinguish via cwd)
            elif trace_type == "session" and isinstance(decoded_trace.get("id"), str):
                session_id = decoded_trace["id"]
                cwd = decoded_trace.get("cwd")
                if isinstance(cwd, str) and "/.openclaw/" in cwd:
                    # Overrides the "pi" harness inferred from the "type" value.
                    harness = "openclaw"
        if harness and session_id:
            break
    return harness, session_id

src/datasets/utils/json.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def ujson_loads(*args, **kwargs):
2323
return pd.io.json.loads(*args, **kwargs)
2424

2525

26-
def json_encode_field(example: Any, json_field_path: str) -> Any:
26+
def json_encode_field(example: Any, json_field_path: list[str]) -> Any:
2727
if json_field_path:
2828
field, *json_field_path = json_field_path
2929
if example is None:
@@ -57,7 +57,7 @@ def json_decode_field(example: Any, json_field_path: str) -> Any:
5757
return example
5858

5959

60-
def find_mixed_struct_types_field_paths(examples: list, allow_root=False) -> list[str]:
60+
def find_mixed_struct_types_field_paths(examples: list, allow_root=False) -> list[list[str]]:
6161
mixed_struct_types_field_paths = []
6262
examples = [example for example in examples if example is not None]
6363
if not examples:
@@ -84,15 +84,15 @@ def find_mixed_struct_types_field_paths(examples: list, allow_root=False) -> lis
8484
return mixed_struct_types_field_paths
8585

8686

87-
def get_json_field_path_from_pyarrow_json_error(err_str: str) -> str:
87+
def get_json_field_path_from_pyarrow_json_error(err_str: str) -> list[str]:
    """Extract the path of the offending field from a PyArrow JSON error message.

    The message is expected to contain `Column(<path>) changed from`, where
    `<path>` is slash-separated and uses `[]` for "item of a list", e.g.
    "col/subfield_containing_a_list/[]/subsubfield_in_item_in_the_list".

    Returns the path as a list of segments, with every `[]` segment replaced
    by the integer 0, e.g.
    ["col", "subfield_containing_a_list", 0, "subsubfield_in_item_in_the_list"].
    """
    after_column = err_str.split("Column(", 1)[1]
    raw_path = after_column.rsplit(") changed from", 1)[0].strip("/")
    json_field_path = []
    for segment in raw_path.split("/"):
        json_field_path.append(0 if segment == "[]" else segment)
    return json_field_path
9393

9494

95-
def insert_json_field_path(json_field_paths: list[str], json_field_path: str) -> None:
95+
def insert_json_field_path(json_field_paths: list[list[str]], json_field_path: list[str]) -> None:
9696
# Add to list of json_field_paths and check if other share a common path
9797
for i in range(len(json_field_paths)):
9898
if json_field_paths[i][: len(json_field_path)] == json_field_path:
@@ -102,15 +102,15 @@ def insert_json_field_path(json_field_paths: list[str], json_field_path: str) ->
102102
json_field_paths.append(json_field_path)
103103

104104

105-
def json_encode_fields_in_json_lines(original_batch: bytes, json_field_paths: list[str]) -> bytes:
105+
def json_encode_fields_in_json_lines(original_batch: bytes, json_field_paths: list[list[str]]) -> bytes:
    """Re-serialize a batch of JSON Lines after applying `json_encode_field` for each path.

    Every line of `original_batch` is decoded, `json_encode_field` is applied
    to each decoded example once per entry of `json_field_paths` (in the
    given order), and the examples are dumped back as newline-separated JSON
    bytes.
    """
    decoded = [ujson_loads(raw_line) for raw_line in original_batch.splitlines()]
    for path in json_field_paths:
        decoded = [json_encode_field(item, path) for item in decoded]
    serialized = [ujson_dumps(item) for item in decoded]
    return "\n".join(serialized).encode()
111111

112112

113-
def get_json_field_paths_from_feature(feature: "FeatureType") -> list[str]:
113+
def get_json_field_paths_from_feature(feature: "FeatureType") -> list[list[str]]:
114114
from datasets.features.features import Json, _visit_with_path
115115

116116
json_field_paths = []
@@ -124,7 +124,7 @@ def get_json_type_path(_feature, feature_path):
124124
return json_field_paths
125125

126126

127-
def set_json_types_in_feature(feature: "FeatureType", json_field_paths: list[str]) -> None:
127+
def set_json_types_in_feature(feature: "FeatureType", json_field_paths: list[list[str]]) -> None:
128128
from datasets.features.features import Json, _visit_with_path
129129

130130
def set_json_type(feature, feature_path):

tests/packaged_modules/test_json.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,10 @@ def test_json_generate_tables(file_fixture, config_kwargs, expected, request):
326326
json = Json(**config_kwargs)
327327
base_files = [request.getfixturevalue(file_fixture)]
328328
files_iterables = [[file] for file in base_files]
329-
generator = json._generate_tables(base_files=base_files, files_iterables=files_iterables)
329+
original_files = list(base_files)
330+
generator = json._generate_tables(
331+
base_files=base_files, files_iterables=files_iterables, original_files=original_files
332+
)
330333
pa_table = pa.concat_tables([table for _, table in generator])
331334
out = Features.from_arrow_schema(pa_table.schema).decode_batch(pa_table.to_pydict())
332335
assert out == expected

0 commit comments

Comments
 (0)