17 changes: 15 additions & 2 deletions src/datasets/arrow_writer.py
@@ -652,15 +652,15 @@ def write_examples_on_file(self):
                     row[0][col].to_pylist()[0] if isinstance(row[0][col], (pa.Array, pa.ChunkedArray)) else row[0][col]
                     for row in self.current_examples
                 ]
-        self.write_batch(batch_examples=batch_examples)
+        self._write_batch(batch_examples=batch_examples)
         self.current_examples = []

     def write_rows_on_file(self):
         """Write stored rows from the write-pool of rows. It concatenates the single-row tables and it writes the resulting table."""
         if not self.current_rows:
             return
         table = pa.concat_tables(self.current_rows)
-        self.write_table(table)
+        self._write_table(table)
         self.current_rows = []

     def write(
@@ -709,6 +709,15 @@ def write_batch(
             batch_examples: the batch of examples to add.
             try_original_type: use `try_type` when instantiating OptimizedTypedSequence if `True`, otherwise `try_type = None`.
         """
+        self.write_examples_on_file()  # in case there are buffered examples to write first
+        self._write_batch(batch_examples, writer_batch_size=writer_batch_size, try_original_type=try_original_type)
+
+    def _write_batch(
+        self,
+        batch_examples: dict[str, list],
+        writer_batch_size: Optional[int] = None,
+        try_original_type: Optional[bool] = True,
+    ):
         if batch_examples and len(next(iter(batch_examples.values()))) == 0:
             return
         features = None if self.pa_writer is None and self.update_features else self._features
@@ -752,6 +761,10 @@ def write_table(self, pa_table: pa.Table, writer_batch_size: Optional[int] = Non
         Args:
             example: the Table to add.
         """
+        self.write_rows_on_file()  # in case there are buffered rows to write first
+        self._write_table(pa_table, writer_batch_size=writer_batch_size)
+
+    def _write_table(self, pa_table: pa.Table, writer_batch_size: Optional[int] = None):
         if writer_batch_size is None:
             writer_batch_size = self.writer_batch_size
         if self.pa_writer is None:
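Note on the arrow_writer.py change: each public write method becomes a thin wrapper that first flushes the pending write pool (buffered examples before a batch, buffered rows before a table), then delegates to a new private worker (_write_batch / _write_table). The flush helpers call the workers directly, so a flush can never re-enter the public method. A minimal, runnable sketch of that flush-then-delegate shape; SketchWriter is hypothetical, not the real ArrowWriter:

class SketchWriter:
    def __init__(self):
        self.current_examples = []  # write pool of buffered example dicts
        self.written = []           # stands in for the underlying Arrow stream

    def write_batch(self, batch_examples):
        self.write_examples_on_file()  # flush buffered examples first, preserving order
        self._write_batch(batch_examples)

    def write_examples_on_file(self):
        if not self.current_examples:
            return
        cols = list(self.current_examples[0])
        batch = {col: [ex[col] for ex in self.current_examples] for col in cols}
        self._write_batch(batch)  # private call: calling write_batch here would recurse
        self.current_examples = []

    def _write_batch(self, batch_examples):
        self.written.append(batch_examples)  # the real worker builds a pa.RecordBatch

writer = SketchWriter()
writer.current_examples = [{"a": 1}, {"a": 2}]
writer.write_batch({"a": [3]})
assert writer.written == [{"a": [1, 2]}, {"a": [3]}]  # pooled examples land before the new batch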
5 changes: 4 additions & 1 deletion src/datasets/builder.py
@@ -1794,7 +1794,10 @@ def _prepare_split_single(
                         embed_local_files=embed_local_files,
                     )
                 try:
-                    writer.write_table(table)
+                    if len(table) == 1:
+                        writer.write_row(table)
+                    else:
+                        writer.write_table(table)
                 except CastError as cast_error:
                     raise DatasetGenerationCastError.from_cast_error(
                         cast_error=cast_error,
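For context on why builder.py special-cases one-row tables: write_row feeds the writer's row pool, which write_rows_on_file later concatenates into a single table (see the arrow_writer.py hunk above), so generators that yield one example per table no longer trigger a full table write per row. A standalone pyarrow illustration of the batching this enables, with made-up data:

import pyarrow as pa

# e.g. a generator that yields one single-row table per example
rows = [pa.table({"a": [i]}) for i in range(1000)]

# Pooled rows are concatenated once and written as a single table,
# instead of issuing 1000 one-row writes to the Arrow stream.
combined = pa.concat_tables(rows)
assert combined.num_rows == 1000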
2 changes: 1 addition & 1 deletion src/datasets/load.py
@@ -669,7 +669,7 @@ def get_module(self) -> DatasetModule:
             ]
             default_config_name = None
         builder_kwargs = {
-            "base_path": hf_dataset_url(self.name, "", revision=self.commit_hash).rstrip("/"),
+            "base_path": base_path,
             "repo_id": self.name,
             "dataset_name": camelcase_to_snakecase(Path(self.name).name),
         }
124 changes: 120 additions & 4 deletions src/datasets/packaged_modules/json/json.py
@@ -1,4 +1,5 @@
 import io
+import os
 from dataclasses import dataclass
 from typing import Literal, Optional

@@ -50,6 +51,7 @@ class JsonConfig(datasets.BuilderConfig):
     chunksize: int = 10 << 20  # 10MB
     newlines_in_values: Optional[bool] = None
     on_mixed_types: Optional[Literal["use_json"]] = "use_json"
+    parse_agent_traces: bool = True

     def __post_init__(self):
         super().__post_init__()
@@ -83,13 +85,19 @@ def _split_generators(self, dl_manager):
             splits.append(
                 datasets.SplitGenerator(
                     name=split_name,
-                    gen_kwargs={"files_iterables": files_iterables, "base_files": base_data_files[split_name]},
+                    gen_kwargs={
+                        "files_iterables": files_iterables,
+                        "base_files": base_data_files[split_name],
+                        "original_files": self.config.data_files[split_name],
+                    },
                 )
             )
         if self.info.features is None:
             try:
                 pa_table = next(iter(self._generate_tables(**splits[0].gen_kwargs, allow_full_read=False)))[1]
                 self.info.features = datasets.Features.from_arrow_schema(pa_table.schema)
+                if self.config.parse_agent_traces and has_agent_traces_markers(self.info.features):
+                    self.info.features = AGENT_TRACES_FEATURES
             except FullReadDisallowed:
                 pass
         return splits
@@ -124,14 +132,18 @@ def _cast_table(self, pa_table: pa.Table, json_field_paths=()) -> pa.Table:
             pa_table = table_cast(pa_table, features.arrow_schema)
         return pa_table

-    def _generate_shards(self, base_files, files_iterables):
+    def _generate_shards(self, base_files, files_iterables, original_files):
         yield from base_files

-    def _generate_tables(self, base_files, files_iterables, allow_full_read=True):
+    def _generate_tables(self, base_files, files_iterables, original_files, allow_full_read=True):
         json_field_paths = []
+        is_agent_traces = False

         if self.info.features is not None:
-            json_field_paths = get_json_field_paths_from_feature(self.info.features)
+            if self.info.features == AGENT_TRACES_FEATURES:
+                is_agent_traces = True
+            else:
+                json_field_paths = get_json_field_paths_from_feature(self.info.features)

         for shard_idx, files_iterable in enumerate(files_iterables):
             for file in files_iterable:
@@ -149,6 +161,24 @@ def _generate_tables(self, base_files, files_iterables, allow_full_read=True):
                     pa_table = pa.Table.from_pandas(df, preserve_index=False)
                     yield Key(shard_idx, 0), self._cast_table(pa_table)

+                # If the files are agent traces (one row = one file)
+                elif is_agent_traces:
+                    with open(file, "r", encoding="utf-8") as f:
+                        traces = f.readlines()
+                    harness, session_id = parse_traces_info(traces)
+                    file_path = original_files[shard_idx]
+                    if file_path.startswith(self.base_path):
+                        file_path = os.path.relpath(file_path, self.base_path)
+                    pa_table = pa.Table.from_pydict(
+                        {
+                            "harness": [harness],
+                            "session_id": [session_id],
+                            "traces": [traces],
+                            "file_path": [file_path],
+                        }
+                    )
+                    yield Key(shard_idx, 0), self._cast_table(pa_table)
+
                 # If the file has one json object per line
                 else:
                     with open(file, "rb") as f:
@@ -265,3 +295,89 @@ def _generate_tables(self, base_files, files_iterables, allow_full_read=True):
                                 self._cast_table(pa_table, json_field_paths=json_field_paths),
                             )
                             batch_idx += 1
+
+
+AGENT_TRACES_TYPES_VALUES = {
+    "claude_code": ["user", "assistant", "system"],
+    "pi": ["session", "message"],
+    "codex": ["session_meta", "turn_context", "response_item", "event_msg"],
+}
+AGENT_TRACES_TYPE_TO_HARNESS = {}
+for _harness, _trace_types in AGENT_TRACES_TYPES_VALUES.items():
+    for _trace_type in _trace_types:
+        AGENT_TRACES_TYPE_TO_HARNESS[_trace_type] = _harness
+
+
+AGENT_TRACES_FEATURES_MARKERS = {
+    "claude_code": datasets.Features(
+        {
+            "type": datasets.Value("string"),
+            "message": datasets.Json(),
+        }
+    ),
+    "pi": datasets.Features(
+        {
+            "type": datasets.Value("string"),
+            "message": datasets.Json(),
+        }
+    ),
+    "codex": datasets.Features(
+        {
+            "type": datasets.Value("string"),
+            "payload": datasets.Json(),
+        }
+    ),
+}
+
+AGENT_TRACES_FEATURES = datasets.Features(
+    {
+        "harness": datasets.Value("string"),
+        "session_id": datasets.Value("string"),
+        "traces": datasets.List(datasets.Json()),
+        "file_path": datasets.Value("string"),
+    }
+)
+
+
+def has_agent_traces_markers(features: datasets.Features) -> bool:
+    for agent_traces_features_marker in AGENT_TRACES_FEATURES_MARKERS.values():
+        if all(features.get(key) == feature for key, feature in agent_traces_features_marker.items()):
+            return True
+    return False
+
+
+def parse_traces_info(traces: list[str]) -> tuple[Optional[str], Optional[str]]:
+    harness, session_id = None, None
+    for trace in traces:
+        decoded_trace = ujson_loads(trace)
+        if harness is None:
+            if "type" in decoded_trace and isinstance(decoded_trace["type"], str):
+                harness = AGENT_TRACES_TYPE_TO_HARNESS.get(decoded_trace["type"])
+        if session_id is None:
+            # claude
+            if "sessionId" in decoded_trace and isinstance(decoded_trace["sessionId"], str):
+                session_id = decoded_trace["sessionId"]
+            # claude (not sure but this format does exist online)
+            elif "session_id" in decoded_trace and isinstance(decoded_trace["session_id"], str):
+                session_id = decoded_trace["session_id"]
+            # codex
+            elif (
+                "payload" in decoded_trace
+                and isinstance(decoded_trace["payload"], dict)
+                and "id" in decoded_trace["payload"]
+                and isinstance(decoded_trace["payload"]["id"], str)
+            ):
+                session_id = decoded_trace["payload"]["id"]
+            # pi / openclaw (openclaw embeds pi-agent; distinguish via cwd)
+            elif (
+                "type" in decoded_trace
+                and decoded_trace["type"] == "session"
+                and "id" in decoded_trace
+                and isinstance(decoded_trace["id"], str)
+            ):
+                session_id = decoded_trace["id"]
+                if isinstance(decoded_trace.get("cwd"), str) and "/.openclaw/" in decoded_trace["cwd"]:
+                    harness = "openclaw"
+        if harness and session_id:
+            break
+    return harness, session_id
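As an illustration of the detection logic above, here is a hypothetical two-line Claude Code JSONL trace (the payloads are invented for the example): the "user" type maps to the claude_code harness via AGENT_TRACES_TYPE_TO_HARNESS, the sessionId key supplies the session id, and the loop breaks after the first line since both are found:

lines = [
    '{"type": "user", "sessionId": "abc-123", "message": {"role": "user", "content": "hi"}}',
    '{"type": "assistant", "sessionId": "abc-123", "message": {"role": "assistant", "content": "hello"}}',
]
assert parse_traces_info(lines) == ("claude_code", "abc-123")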
14 changes: 7 additions & 7 deletions src/datasets/utils/json.py
@@ -23,7 +23,7 @@ def ujson_loads(*args, **kwargs):
     return pd.io.json.loads(*args, **kwargs)


-def json_encode_field(example: Any, json_field_path: str) -> Any:
+def json_encode_field(example: Any, json_field_path: list[str]) -> Any:
     if json_field_path:
         field, *json_field_path = json_field_path
         if example is None:
@@ -57,7 +57,7 @@ def json_decode_field(example: Any, json_field_path: str) -> Any:
     return example


-def find_mixed_struct_types_field_paths(examples: list, allow_root=False) -> list[str]:
+def find_mixed_struct_types_field_paths(examples: list, allow_root=False) -> list[list[str]]:
     mixed_struct_types_field_paths = []
     examples = [example for example in examples if example is not None]
     if not examples:
@@ -84,15 +84,15 @@ def find_mixed_struct_types_field_paths(examples: list, allow_root=False) -> lis
     return mixed_struct_types_field_paths


-def get_json_field_path_from_pyarrow_json_error(err_str: str) -> str:
+def get_json_field_path_from_pyarrow_json_error(err_str: str) -> list[str]:
     # e.g. json_field_path_str = "col/subfield_containing_a_list/[]/subsubfield_in_item_in_the_list"
     json_field_path_str = err_str.split("Column(", 1)[1].rsplit(") changed from", 1)[0].strip("/")
     # e.g. json_field_path = ["col", "subfield_containing_a_list", 0, "subsubfield_in_item_in_the_list"]
     json_field_path = [0 if seg == "[]" else seg for seg in json_field_path_str.split("/")]
     return json_field_path


-def insert_json_field_path(json_field_paths: list[str], json_field_path: str) -> None:
+def insert_json_field_path(json_field_paths: list[list[str]], json_field_path: list[str]) -> None:
     # Add to list of json_field_paths and check if other share a common path
     for i in range(len(json_field_paths)):
         if json_field_paths[i][: len(json_field_path)] == json_field_path:
@@ -102,15 +102,15 @@ def insert_json_field_path(json_field_paths: list[str], json_field_path: str) ->
     json_field_paths.append(json_field_path)


-def json_encode_fields_in_json_lines(original_batch: bytes, json_field_paths: list[str]) -> bytes:
+def json_encode_fields_in_json_lines(original_batch: bytes, json_field_paths: list[list[str]]) -> bytes:
     examples = [ujson_loads(line) for line in original_batch.splitlines()]
     for json_field_path in json_field_paths:
         examples = [json_encode_field(example, json_field_path) for example in examples]
     batch = "\n".join([ujson_dumps(example) for example in examples]).encode()
     return batch


-def get_json_field_paths_from_feature(feature: "FeatureType") -> list[str]:
+def get_json_field_paths_from_feature(feature: "FeatureType") -> list[list[str]]:
     from datasets.features.features import Json, _visit_with_path

     json_field_paths = []
@@ -124,7 +124,7 @@ def get_json_type_path(_feature, feature_path):
     return json_field_paths


-def set_json_types_in_feature(feature: "FeatureType", json_field_paths: list[str]) -> None:
+def set_json_types_in_feature(feature: "FeatureType", json_field_paths: list[list[str]]) -> None:
     from datasets.features.features import Json, _visit_with_path

     def set_json_type(feature, feature_path):
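The utils/json.py changes only tighten the type hints: a single field path was always a list of keys (with 0 standing in for "any list element"), so it is now annotated list[str] rather than str, and collections of paths become list[list[str]]. The error-message parser above can be exercised with a synthetic pyarrow-style string shaped like the examples in its own comments (real error wording may differ):

err = "Column(/col/subfield_containing_a_list/[]/subsubfield_in_item_in_the_list) changed from string to number"
assert get_json_field_path_from_pyarrow_json_error(err) == [
    "col",
    "subfield_containing_a_list",
    0,
    "subsubfield_in_item_in_the_list",
]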
5 changes: 4 additions & 1 deletion tests/packaged_modules/test_json.py
@@ -326,7 +326,10 @@ def test_json_generate_tables(file_fixture, config_kwargs, expected, request):
     json = Json(**config_kwargs)
     base_files = [request.getfixturevalue(file_fixture)]
     files_iterables = [[file] for file in base_files]
-    generator = json._generate_tables(base_files=base_files, files_iterables=files_iterables)
+    original_files = list(base_files)
+    generator = json._generate_tables(
+        base_files=base_files, files_iterables=files_iterables, original_files=original_files
+    )
     pa_table = pa.concat_tables([table for _, table in generator])
     out = Features.from_arrow_schema(pa_table.schema).decode_batch(pa_table.to_pydict())
     assert out == expected