diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py
index e51fa3f8c7f..a2c555c8b27 100644
--- a/src/datasets/arrow_writer.py
+++ b/src/datasets/arrow_writer.py
@@ -652,7 +652,7 @@ def write_examples_on_file(self):
                     row[0][col].to_pylist()[0] if isinstance(row[0][col], (pa.Array, pa.ChunkedArray)) else row[0][col]
                     for row in self.current_examples
                 ]
-        self.write_batch(batch_examples=batch_examples)
+        self._write_batch(batch_examples=batch_examples)
         self.current_examples = []
 
     def write_rows_on_file(self):
@@ -660,7 +660,7 @@ def write_rows_on_file(self):
         if not self.current_rows:
             return
         table = pa.concat_tables(self.current_rows)
-        self.write_table(table)
+        self._write_table(table)
         self.current_rows = []
 
     def write(
@@ -709,6 +709,15 @@ def write_batch(
             batch_examples: the batch of examples to add.
             try_original_type: use `try_type` when instantiating OptimizedTypedSequence if `True`, otherwise `try_type = None`.
         """
+        self.write_examples_on_file()  # in case there are buffered examples to write first
+        self._write_batch(batch_examples, writer_batch_size=writer_batch_size, try_original_type=try_original_type)
+
+    def _write_batch(
+        self,
+        batch_examples: dict[str, list],
+        writer_batch_size: Optional[int] = None,
+        try_original_type: Optional[bool] = True,
+    ):
         if batch_examples and len(next(iter(batch_examples.values()))) == 0:
             return
         features = None if self.pa_writer is None and self.update_features else self._features
@@ -752,6 +761,10 @@ def write_table(self, pa_table: pa.Table, writer_batch_size: Optional[int] = Non
         Args:
            example: the Table to add.
        """
+        self.write_rows_on_file()  # in case there are buffered rows to write first
+        self._write_table(pa_table, writer_batch_size=writer_batch_size)
+
+    def _write_table(self, pa_table: pa.Table, writer_batch_size: Optional[int] = None):
         if writer_batch_size is None:
             writer_batch_size = self.writer_batch_size
         if self.pa_writer is None:
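A note on the `arrow_writer.py` refactor above: `write_batch` and `write_table` now flush the opposite in-memory buffer (`current_examples` / `current_rows`) before delegating to the new private `_write_batch` / `_write_table`, so interleaving `write` with `write_batch`/`write_table` can no longer reorder rows. A minimal sketch of the intended behavior (the output path and feature name are made up; `ArrowWriter`, `write`, `write_batch` and `finalize` are the existing public API):

```python
from datasets import Features, Value
from datasets.arrow_writer import ArrowWriter

writer = ArrowWriter(features=Features({"text": Value("string")}), path="out.arrow")
writer.write({"text": "a"})          # buffered in writer.current_examples
writer.write_batch({"text": ["b"]})  # flushes the buffered "a" first, so "a" lands before "b"
writer.finalize()
```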
diff --git a/src/datasets/builder.py b/src/datasets/builder.py
index 644447e5b71..b995242f016 100644
--- a/src/datasets/builder.py
+++ b/src/datasets/builder.py
@@ -1794,7 +1794,10 @@ def _prepare_split_single(
                         embed_local_files=embed_local_files,
                     )
                 try:
-                    writer.write_table(table)
+                    if len(table) == 1:
+                        writer.write_row(table)
+                    else:
+                        writer.write_table(table)
                 except CastError as cast_error:
                     raise DatasetGenerationCastError.from_cast_error(
                         cast_error=cast_error,
diff --git a/src/datasets/load.py b/src/datasets/load.py
index 212567684f8..560bcad3a44 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -669,7 +669,7 @@ def get_module(self) -> DatasetModule:
         ]
         default_config_name = None
         builder_kwargs = {
-            "base_path": hf_dataset_url(self.name, "", revision=self.commit_hash).rstrip("/"),
+            "base_path": base_path,
             "repo_id": self.name,
             "dataset_name": camelcase_to_snakecase(Path(self.name).name),
         }
diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py
index 243a1634612..db70ff03b0e 100644
--- a/src/datasets/packaged_modules/json/json.py
+++ b/src/datasets/packaged_modules/json/json.py
@@ -1,4 +1,5 @@
 import io
+import os
 from dataclasses import dataclass
 from typing import Literal, Optional
 
@@ -50,6 +51,7 @@ class JsonConfig(datasets.BuilderConfig):
     chunksize: int = 10 << 20  # 10MB
     newlines_in_values: Optional[bool] = None
     on_mixed_types: Optional[Literal["use_json"]] = "use_json"
+    parse_agent_traces: bool = True
 
     def __post_init__(self):
         super().__post_init__()
@@ -83,13 +85,19 @@ def _split_generators(self, dl_manager):
             splits.append(
                 datasets.SplitGenerator(
                     name=split_name,
-                    gen_kwargs={"files_iterables": files_iterables, "base_files": base_data_files[split_name]},
+                    gen_kwargs={
+                        "files_iterables": files_iterables,
+                        "base_files": base_data_files[split_name],
+                        "original_files": self.config.data_files[split_name],
+                    },
                 )
             )
         if self.info.features is None:
             try:
                 pa_table = next(iter(self._generate_tables(**splits[0].gen_kwargs, allow_full_read=False)))[1]
                 self.info.features = datasets.Features.from_arrow_schema(pa_table.schema)
+                if self.config.parse_agent_traces and has_agent_traces_markers(self.info.features):
+                    self.info.features = AGENT_TRACES_FEATURES
             except FullReadDisallowed:
                 pass
         return splits
@@ -124,14 +132,18 @@ def _cast_table(self, pa_table: pa.Table, json_field_paths=()) -> pa.Table:
             pa_table = table_cast(pa_table, features.arrow_schema)
         return pa_table
 
-    def _generate_shards(self, base_files, files_iterables):
+    def _generate_shards(self, base_files, files_iterables, original_files):
         yield from base_files
 
-    def _generate_tables(self, base_files, files_iterables, allow_full_read=True):
+    def _generate_tables(self, base_files, files_iterables, original_files, allow_full_read=True):
         json_field_paths = []
+        is_agent_traces = False
         if self.info.features is not None:
-            json_field_paths = get_json_field_paths_from_feature(self.info.features)
+            if self.info.features == AGENT_TRACES_FEATURES:
+                is_agent_traces = True
+            else:
+                json_field_paths = get_json_field_paths_from_feature(self.info.features)
 
         for shard_idx, files_iterable in enumerate(files_iterables):
             for file in files_iterable:
@@ -149,6 +161,24 @@ def _generate_tables(self, base_files, files_iterables, allow_full_read=True):
                     pa_table = pa.Table.from_pandas(df, preserve_index=False)
                     yield Key(shard_idx, 0), self._cast_table(pa_table)
 
+                # If the files are agent traces (one row = one file)
+                elif is_agent_traces:
+                    with open(file, "r", encoding="utf-8") as f:
+                        traces = f.readlines()
+                    harness, session_id = parse_traces_info(traces)
+                    file_path = original_files[shard_idx]
+                    if file_path.startswith(self.base_path):
+                        file_path = os.path.relpath(file_path, self.base_path)
+                    pa_table = pa.Table.from_pydict(
+                        {
+                            "harness": [harness],
+                            "session_id": [session_id],
+                            "traces": [traces],
+                            "file_path": [file_path],
+                        }
+                    )
+                    yield Key(shard_idx, 0), self._cast_table(pa_table)
+
                 # If the file has one json object per line
                 else:
                     with open(file, "rb") as f:
@@ -265,3 +295,89 @@ def _generate_tables(self, base_files, files_iterables, allow_full_read=True):
                             self._cast_table(pa_table, json_field_paths=json_field_paths),
                         )
                         batch_idx += 1
+
+
+AGENT_TRACES_TYPES_VALUES = {
+    "claude_code": ["user", "assistant", "system"],
+    "pi": ["session", "message"],
+    "codex": ["session_meta", "turn_context", "response_item", "event_msg"],
+}
+AGENT_TRACES_TYPE_TO_HARNESS = {}
+for _harness, _trace_types in AGENT_TRACES_TYPES_VALUES.items():
+    for _trace_type in _trace_types:
+        AGENT_TRACES_TYPE_TO_HARNESS[_trace_type] = _harness
+
+
+AGENT_TRACES_FEATURES_MARKERS = {
+    "claude_code": datasets.Features(
+        {
+            "type": datasets.Value("string"),
+            "message": datasets.Json(),
+        }
+    ),
+    "pi": datasets.Features(
+        {
+            "type": datasets.Value("string"),
+            "message": datasets.Json(),
+        }
+    ),
+    "codex": datasets.Features(
+        {
+            "type": datasets.Value("string"),
+            "payload": datasets.Json(),
+        }
+    ),
+}
+
+AGENT_TRACES_FEATURES = datasets.Features(
+    {
+        "harness": datasets.Value("string"),
+        "session_id": datasets.Value("string"),
+        "traces": datasets.List(datasets.Json()),
+        "file_path": datasets.Value("string"),
+    }
+)
+
+
+def has_agent_traces_markers(features: datasets.Features) -> bool:
+    for agent_traces_features_marker in AGENT_TRACES_FEATURES_MARKERS.values():
+        if all(features.get(key) == feature for key, feature in agent_traces_features_marker.items()):
+            return True
+    return False
+
+
+def parse_traces_info(traces: list[str]) -> tuple[Optional[str], Optional[str]]:
+    harness, session_id = None, None
+    for trace in traces:
+        decoded_trace = ujson_loads(trace)
+        if harness is None:
+            if "type" in decoded_trace and isinstance(decoded_trace["type"], str):
+                harness = AGENT_TRACES_TYPE_TO_HARNESS.get(decoded_trace["type"])
+        if session_id is None:
+            # claude
+            if "sessionId" in decoded_trace and isinstance(decoded_trace["sessionId"], str):
+                session_id = decoded_trace["sessionId"]
+            # claude (not sure but this format does exist online)
+            elif "session_id" in decoded_trace and isinstance(decoded_trace["session_id"], str):
+                session_id = decoded_trace["session_id"]
+            # codex
+            elif (
+                "payload" in decoded_trace
+                and isinstance(decoded_trace["payload"], dict)
+                and "id" in decoded_trace["payload"]
+                and isinstance(decoded_trace["payload"]["id"], str)
+            ):
+                session_id = decoded_trace["payload"]["id"]
+            # pi / openclaw (openclaw embeds pi-agent; distinguish via cwd)
+            elif (
+                "type" in decoded_trace
+                and decoded_trace["type"] == "session"
+                and "id" in decoded_trace
+                and isinstance(decoded_trace["id"], str)
+            ):
+                session_id = decoded_trace["id"]
+                if isinstance(decoded_trace.get("cwd"), str) and "/.openclaw/" in decoded_trace["cwd"]:
+                    harness = "openclaw"
+        if harness and session_id:
+            break
+    return harness, session_id
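The harness detection above keys off the per-line `type` values (`AGENT_TRACES_TYPE_TO_HARNESS`) and harness-specific session-id fields. A quick illustration of `parse_traces_info` on a hypothetical Claude Code session (the two JSONL lines are invented; the expected results follow from the `"user"` → `"claude_code"` mapping and the `sessionId` branch):

```python
from datasets.packaged_modules.json.json import parse_traces_info

traces = [
    '{"type": "user", "sessionId": "abc-123", "message": {"role": "user", "content": "hi"}}',
    '{"type": "assistant", "sessionId": "abc-123", "message": {"role": "assistant", "content": "hello"}}',
]
harness, session_id = parse_traces_info(traces)
assert harness == "claude_code"  # "user" is a claude_code type in AGENT_TRACES_TYPES_VALUES
assert session_id == "abc-123"   # taken from the "sessionId" field
```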
"file_path": datasets.Value("string"), + } +) + + +def has_agent_traces_markers(features: datasets.Features) -> bool: + for agent_traces_features_marker in AGENT_TRACES_FEATURES_MARKERS.values(): + if all(features.get(key) == feature for key, feature in agent_traces_features_marker.items()): + return True + return False + + +def parse_traces_info(traces: list[str]) -> tuple[Optional[str], Optional[str]]: + harness, session_id = None, None + for trace in traces: + decoded_trace = ujson_loads(trace) + if harness is None: + if "type" in decoded_trace and isinstance(decoded_trace["type"], str): + harness = AGENT_TRACES_TYPE_TO_HARNESS.get(decoded_trace["type"]) + if session_id is None: + # claude + if "sessionId" in decoded_trace and isinstance(decoded_trace["sessionId"], str): + session_id = decoded_trace["sessionId"] + # claude (not sure but this format does exist online) + elif "session_id" in decoded_trace and isinstance(decoded_trace["session_id"], str): + session_id = decoded_trace["session_id"] + # codex + elif ( + "payload" in decoded_trace + and isinstance(decoded_trace["payload"], dict) + and "id" in decoded_trace["payload"] + and isinstance(decoded_trace["payload"]["id"], str) + ): + session_id = decoded_trace["payload"]["id"] + # pi / openclaw (openclaw embeds pi-agent; distinguish via cwd) + elif ( + "type" in decoded_trace + and decoded_trace["type"] == "session" + and "id" in decoded_trace + and isinstance(decoded_trace["id"], str) + ): + session_id = decoded_trace["id"] + if isinstance(decoded_trace.get("cwd"), str) and "/.openclaw/" in decoded_trace["cwd"]: + harness = "openclaw" + if harness and session_id: + break + return harness, session_id diff --git a/src/datasets/utils/json.py b/src/datasets/utils/json.py index f9595b6c8b8..216338fd011 100644 --- a/src/datasets/utils/json.py +++ b/src/datasets/utils/json.py @@ -23,7 +23,7 @@ def ujson_loads(*args, **kwargs): return pd.io.json.loads(*args, **kwargs) -def json_encode_field(example: Any, json_field_path: str) -> Any: +def json_encode_field(example: Any, json_field_path: list[str]) -> Any: if json_field_path: field, *json_field_path = json_field_path if example is None: @@ -57,7 +57,7 @@ def json_decode_field(example: Any, json_field_path: str) -> Any: return example -def find_mixed_struct_types_field_paths(examples: list, allow_root=False) -> list[str]: +def find_mixed_struct_types_field_paths(examples: list, allow_root=False) -> list[list[str]]: mixed_struct_types_field_paths = [] examples = [example for example in examples if example is not None] if not examples: @@ -84,7 +84,7 @@ def find_mixed_struct_types_field_paths(examples: list, allow_root=False) -> lis return mixed_struct_types_field_paths -def get_json_field_path_from_pyarrow_json_error(err_str: str) -> str: +def get_json_field_path_from_pyarrow_json_error(err_str: str) -> list[str]: # e.g. json_field_path_str = "col/subfield_containing_a_list/[]/subsubfield_in_item_in_the_list" json_field_path_str = err_str.split("Column(", 1)[1].rsplit(") changed from", 1)[0].strip("/") # e.g. 
json_field_path = ["col", "subfield_containing_a_list", 0, "subsubfield_in_item_in_the_list"] @@ -92,7 +92,7 @@ def get_json_field_path_from_pyarrow_json_error(err_str: str) -> str: return json_field_path -def insert_json_field_path(json_field_paths: list[str], json_field_path: str) -> None: +def insert_json_field_path(json_field_paths: list[list[str]], json_field_path: list[str]) -> None: # Add to list of json_field_paths and check if other share a common path for i in range(len(json_field_paths)): if json_field_paths[i][: len(json_field_path)] == json_field_path: @@ -102,7 +102,7 @@ def insert_json_field_path(json_field_paths: list[str], json_field_path: str) -> json_field_paths.append(json_field_path) -def json_encode_fields_in_json_lines(original_batch: bytes, json_field_paths: list[str]) -> bytes: +def json_encode_fields_in_json_lines(original_batch: bytes, json_field_paths: list[list[str]]) -> bytes: examples = [ujson_loads(line) for line in original_batch.splitlines()] for json_field_path in json_field_paths: examples = [json_encode_field(example, json_field_path) for example in examples] @@ -110,7 +110,7 @@ def json_encode_fields_in_json_lines(original_batch: bytes, json_field_paths: li return batch -def get_json_field_paths_from_feature(feature: "FeatureType") -> list[str]: +def get_json_field_paths_from_feature(feature: "FeatureType") -> list[list[str]]: from datasets.features.features import Json, _visit_with_path json_field_paths = [] @@ -124,7 +124,7 @@ def get_json_type_path(_feature, feature_path): return json_field_paths -def set_json_types_in_feature(feature: "FeatureType", json_field_paths: list[str]) -> None: +def set_json_types_in_feature(feature: "FeatureType", json_field_paths: list[list[str]]) -> None: from datasets.features.features import Json, _visit_with_path def set_json_type(feature, feature_path): diff --git a/tests/packaged_modules/test_json.py b/tests/packaged_modules/test_json.py index 1ae094b425b..9ccb6fc0046 100644 --- a/tests/packaged_modules/test_json.py +++ b/tests/packaged_modules/test_json.py @@ -326,7 +326,10 @@ def test_json_generate_tables(file_fixture, config_kwargs, expected, request): json = Json(**config_kwargs) base_files = [request.getfixturevalue(file_fixture)] files_iterables = [[file] for file in base_files] - generator = json._generate_tables(base_files=base_files, files_iterables=files_iterables) + original_files = list(base_files) + generator = json._generate_tables( + base_files=base_files, files_iterables=files_iterables, original_files=original_files + ) pa_table = pa.concat_tables([table for _, table in generator]) out = Features.from_arrow_schema(pa_table.schema).decode_batch(pa_table.to_pydict()) assert out == expected