diff --git a/setup.py b/setup.py index 4ce18d37ce3..56f6e7b32b5 100644 --- a/setup.py +++ b/setup.py @@ -194,7 +194,7 @@ "torchcodec>=0.7.0; python_version < '3.14'", # minium version to get windows support, torchcodec doesn't have wheels for 3.14 yet "nibabel>=5.3.1", "trimesh>=4.10.0", - "teich==0.1.5", + "teich==0.2.8", ] NUMPY2_INCOMPATIBLE_LIBRARIES = [ diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py index 9845dd676d4..e88507ee85a 100644 --- a/src/datasets/packaged_modules/json/json.py +++ b/src/datasets/packaged_modules/json/json.py @@ -177,9 +177,11 @@ def _generate_tables(self, base_files, files_iterables, original_files, allow_fu if self.base_path is not None and file_path.startswith(self.base_path): file_path = os.path.relpath(file_path, self.base_path) training_examples = convert_traces_to_training_data(trace_file) + is_native_cursor_trace = has_cursor_native_trace_markers(lines) examples = [] for i, training_example in enumerate(training_examples): - if training_example["metadata"]["trace_type"] == "hermes": + trace_type = training_example["metadata"]["trace_type"] + if trace_type == "hermes": timestamp = ujson_loads(lines[i])["started_at"] milliseconds = timestamp if timestamp <= 10_000_000_000 else timestamp * 1_000 sent_at = ( @@ -195,6 +197,21 @@ def _generate_tables(self, base_files, files_iterables, original_files, allow_fu "num_tool_calls": training_example["metadata"]["tool_call_count"], "trace": lines[i], } + elif trace_type == "cursor": + source_line = training_example["metadata"].get("source_line") + trace_line = lines[i] + if isinstance(source_line, int) and 1 <= source_line <= len(lines): + trace_line = lines[source_line - 1] + cursor_trace = trace if is_native_cursor_trace else trace_line + bonus_fields = { + "harness": trace_type, + "session_id": get_cursor_session_id(training_example, trace_line, trace_file), + "prompt": training_example.get("prompt"), + "sent_at": get_cursor_trace_timestamp(trace_line), + "num_user_messages": get_training_example_user_message_count(training_example), + "num_tool_calls": get_training_example_tool_call_count(training_example), + "trace": cursor_trace, + } else: harness, session_id, prompt, sent_at, num_user_messages, num_tool_calls = ( parse_traces_info(lines) @@ -359,6 +376,14 @@ def _generate_tables(self, base_files, files_iterables, original_files, allow_fu AGENT_TRACES_FEATURES_MARKERS = { + "cursor_native": datasets.Features( + { + "role": lambda f: f == Value("string"), + "message": lambda f: isinstance(f, (dict, datasets.Json)), + "type": lambda f: f == Value("string"), + "status": lambda f: f == Value("string"), + } + ), "claude_code_or_pi_or_openclaw": datasets.Features( { "type": lambda f: f == Value("string"), @@ -388,6 +413,18 @@ def _generate_tables(self, base_files, files_iterables, original_files, allow_fu "cwd": lambda f: f == Value("string"), } ), + "cursor": datasets.Features( + { + "messages": lambda f: isinstance(f, datasets.List), + "metadata": lambda f: isinstance(f, dict) + and f.get("trace_type") == Value("string") + and f.get("source") == Value("string") + and f.get("cursor_composer_id") == Value("string"), + "raw_cursor": lambda f: isinstance(f, dict) + and isinstance(f.get("composer_data"), (dict, datasets.Json)) + and isinstance(f.get("bubble_ids"), datasets.List), + } + ), } AGENT_TRACES_FEATURES = datasets.Features( @@ -417,6 +454,18 @@ def has_agent_traces_markers(features: datasets.Features) -> bool: return False +def has_cursor_native_trace_markers(trace_events: list[str]) -> bool: + has_message = False + has_turn_end = False + for event in trace_events: + decoded_event = ujson_loads(event) + if decoded_event.get("role") in ("user", "assistant") and isinstance(decoded_event.get("message"), dict): + has_message = True + elif decoded_event.get("type") == "turn_ended" and isinstance(decoded_event.get("status"), str): + has_turn_end = True + return has_message and has_turn_end + + def parse_traces_info( trace_events: list[str], ) -> tuple[Optional[str], Optional[str], Optional[str], Optional[str], int, int]: @@ -519,6 +568,60 @@ def get_trace_event_timestamp(trace_event: dict) -> Optional[str]: return None +def get_cursor_session_id(training_example: dict, trace_line: str, trace_file: Path) -> Optional[str]: + metadata = training_example.get("metadata") + if isinstance(metadata, dict) and isinstance(metadata.get("cursor_composer_id"), str): + return metadata["cursor_composer_id"] + try: + trace_event = ujson_loads(trace_line) + except ValueError: + return None + raw_cursor = trace_event.get("raw_cursor") + if not isinstance(raw_cursor, dict): + return None + composer_data = raw_cursor.get("composer_data") + if isinstance(composer_data, dict) and isinstance(composer_data.get("composerId"), str): + return composer_data["composerId"] + return trace_file.stem + + +def get_cursor_trace_timestamp(trace_line: str) -> Optional[str]: + try: + trace_event = ujson_loads(trace_line) + except ValueError: + return None + raw_cursor = trace_event.get("raw_cursor") + if not isinstance(raw_cursor, dict): + return None + composer_data = raw_cursor.get("composer_data") + if not isinstance(composer_data, dict): + return None + created_at = composer_data.get("createdAt") + if not isinstance(created_at, (int, float)): + return None + seconds = created_at / 1_000 if created_at > 10_000_000_000 else created_at + return datetime.fromtimestamp(seconds, tz=timezone.utc).isoformat(timespec="milliseconds").replace("+00:00", "Z") + + +def get_training_example_user_message_count(training_example: dict) -> int: + messages = training_example.get("messages") + if not isinstance(messages, list): + return 0 + return sum(1 for message in messages if isinstance(message, dict) and message.get("role") == "user") + + +def get_training_example_tool_call_count(training_example: dict) -> int: + messages = training_example.get("messages") + if not isinstance(messages, list): + return 0 + return sum( + len(tool_calls) + for message in messages + if isinstance(message, dict) + and isinstance(tool_calls := message.get("tool_calls"), list) + ) + + def get_content_text(content) -> Optional[str]: if isinstance(content, str): return content diff --git a/tests/packaged_modules/test_json.py b/tests/packaged_modules/test_json.py index 2d6fd5b3849..3a4a55c4920 100644 --- a/tests/packaged_modules/test_json.py +++ b/tests/packaged_modules/test_json.py @@ -515,6 +515,71 @@ def assert_agent_traces_output(tmp_path, filename, rows, expected, num_sessions= }, ] +CURSOR_SESSION = { + "messages": [ + {"role": "user", "content": "cursor prompt"}, + { + "role": "assistant", + "content": "cursor response", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "read_file", "arguments": '{"path": "README.md"}'}, + } + ], + }, + {"role": "tool", "tool_call_id": "call_1", "name": "read_file", "content": "done"}, + {"role": "user", "content": "second cursor prompt"}, + {"role": "assistant", "content": "second response"}, + ], + "prompt": "cursor prompt", + "response": "second response", + "model": "claude-4-sonnet", + "tools": [ + { + "type": "function", + "function": { + "name": "read_file", + "description": "Read file", + "parameters": {"type": "object", "properties": {"path": {"type": "string"}}}, + }, + } + ], + "metadata": { + "trace_type": "cursor", + "source": "cursor", + "cursor_composer_id": "cursor-session", + "model": "claude-4-sonnet", + }, + "raw_cursor": { + "composer_data": { + "composerId": "cursor-session", + "createdAt": 1751526368368, + "lastUpdatedAt": 1751526730702, + "status": "completed", + }, + "bubble_ids": ["bubble-1"], + }, +} + +CURSOR_NATIVE_SESSION = [ + { + "role": "user", + "message": {"content": [{"type": "text", "text": "\nInspect index.html\n"}]}, + }, + { + "role": "assistant", + "message": { + "content": [ + {"type": "text", "text": "I'll inspect it."}, + {"type": "tool_use", "id": "read_file_0", "name": "read_file", "input": {"path": "index.html"}}, + ] + }, + }, + {"type": "turn_ended", "status": "success"}, +] + def test_config_raises_when_invalid_name() -> None: with pytest.raises(InvalidConfigName, match="Bad characters"): @@ -727,6 +792,18 @@ def test_json_generate_tables_with_sorted_columns(file_fixture, config_kwargs, r ("droid", "droid-session", "Inspect the project", "2026-06-02T18:55:30.274Z", 1, 1), id="droid", ), + pytest.param( + "cursor.jsonl", + [CURSOR_SESSION], + ("cursor", "cursor-session", "cursor prompt", "2025-07-03T07:06:08.368Z", 2, 1), + id="cursor", + ), + pytest.param( + "cursor-native-session.jsonl", + CURSOR_NATIVE_SESSION, + ("cursor", "cursor-native-session", "Inspect index.html", None, 1, 1), + id="cursor-native", + ), pytest.param( "missing_prompt.jsonl", [ @@ -745,8 +822,9 @@ def test_json_generate_tables_with_sorted_columns(file_fixture, config_kwargs, r def test_json_generate_tables_with_agent_trace_metadata(tmp_path, filename, rows, expected): num_sessions = 2 if filename == "hermes_two_sessions.jsonl" else 1 _, out = assert_agent_traces_output(tmp_path, filename, rows, expected, num_sessions=num_sessions) - if filename == "droid.jsonl": - assert out["metadata"][0]["trace_type"] == "droid" + if filename in ("cursor.jsonl", "cursor-native-session.jsonl", "droid.jsonl"): + expected_trace_type = "cursor" if filename.startswith("cursor") else "droid" + assert out["metadata"][0]["trace_type"] == expected_trace_type assert "models" not in out @@ -761,6 +839,18 @@ def test_json_generate_tables_with_agent_trace_metadata(tmp_path, filename, rows ("droid", "droid-session", "Inspect the project", "2026-06-02T18:55:30.274Z", 1, 1), id="droid", ), + pytest.param( + "cursor.jsonl", + [CURSOR_SESSION], + ("cursor", "cursor-session", "cursor prompt", "2025-07-03T07:06:08.368Z", 2, 1), + id="cursor", + ), + pytest.param( + "cursor-native-session.jsonl", + CURSOR_NATIVE_SESSION, + ("cursor", "cursor-native-session", "Inspect index.html", None, 1, 1), + id="cursor-native", + ), ], ) def test_json_load_dataset_with_agent_trace_metadata(tmp_path, filename, rows, expected): @@ -787,6 +877,12 @@ def test_json_load_dataset_with_agent_trace_metadata(tmp_path, filename, rows, e if filename == "droid.jsonl": assert row["metadata"]["trace_type"] == "droid" assert json.loads(row["trace"].splitlines()[0])["type"] == "session_start" + elif filename == "cursor.jsonl": + assert row["metadata"]["trace_type"] == "cursor" + assert row["trace"]["raw_cursor"]["composer_data"]["composerId"] == "cursor-session" + elif filename == "cursor-native-session.jsonl": + assert row["metadata"]["trace_type"] == "cursor" + assert json.loads(row["trace"].splitlines()[-1]) == {"type": "turn_ended", "status": "success"} def test_json_load_dataset_without_droid_marker_stays_ordinary_json(tmp_path): @@ -807,3 +903,18 @@ def test_json_load_dataset_without_droid_marker_stays_ordinary_json(tmp_path): assert dataset.column_names == ["type", "id", "version", "timestamp", "message"] assert dataset[0]["type"] == "session_start" + + +def test_json_load_dataset_without_cursor_native_marker_stays_ordinary_json(tmp_path): + trace_file = write_jsonl( + tmp_path / "cursor_missing_marker.jsonl", + [ + {"role": "user", "message": {"content": [{"type": "text", "text": "Inspect index.html"}]}}, + {"role": "assistant", "message": {"content": [{"type": "text", "text": "I'll inspect it."}]}}, + ], + ) + + dataset = load_dataset("json", data_files=trace_file, split="train", cache_dir=str(tmp_path / "cache")) + + assert dataset.column_names == ["role", "message"] + assert dataset[0]["role"] == "user"