Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@
"torchcodec>=0.7.0; python_version < '3.14'", # minium version to get windows support, torchcodec doesn't have wheels for 3.14 yet
"nibabel>=5.3.1",
"trimesh>=4.10.0",
"teich==0.1.5",
"teich==0.2.8",
]

NUMPY2_INCOMPATIBLE_LIBRARIES = [
Expand Down
105 changes: 104 additions & 1 deletion src/datasets/packaged_modules/json/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,11 @@ def _generate_tables(self, base_files, files_iterables, original_files, allow_fu
if self.base_path is not None and file_path.startswith(self.base_path):
file_path = os.path.relpath(file_path, self.base_path)
training_examples = convert_traces_to_training_data(trace_file)
is_native_cursor_trace = has_cursor_native_trace_markers(lines)
examples = []
for i, training_example in enumerate(training_examples):
if training_example["metadata"]["trace_type"] == "hermes":
trace_type = training_example["metadata"]["trace_type"]
if trace_type == "hermes":
timestamp = ujson_loads(lines[i])["started_at"]
milliseconds = timestamp if timestamp <= 10_000_000_000 else timestamp * 1_000
sent_at = (
Expand All @@ -195,6 +197,21 @@ def _generate_tables(self, base_files, files_iterables, original_files, allow_fu
"num_tool_calls": training_example["metadata"]["tool_call_count"],
"trace": lines[i],
}
elif trace_type == "cursor":
source_line = training_example["metadata"].get("source_line")
trace_line = lines[i]
if isinstance(source_line, int) and 1 <= source_line <= len(lines):
trace_line = lines[source_line - 1]
cursor_trace = trace if is_native_cursor_trace else trace_line
bonus_fields = {
"harness": trace_type,
"session_id": get_cursor_session_id(training_example, trace_line, trace_file),
"prompt": training_example.get("prompt"),
"sent_at": get_cursor_trace_timestamp(trace_line),
"num_user_messages": get_training_example_user_message_count(training_example),
"num_tool_calls": get_training_example_tool_call_count(training_example),
"trace": cursor_trace,
}
else:
harness, session_id, prompt, sent_at, num_user_messages, num_tool_calls = (
parse_traces_info(lines)
Expand Down Expand Up @@ -359,6 +376,14 @@ def _generate_tables(self, base_files, files_iterables, original_files, allow_fu


AGENT_TRACES_FEATURES_MARKERS = {
"cursor_native": datasets.Features(
{
"role": lambda f: f == Value("string"),
"message": lambda f: isinstance(f, (dict, datasets.Json)),
"type": lambda f: f == Value("string"),
"status": lambda f: f == Value("string"),
}
),
"claude_code_or_pi_or_openclaw": datasets.Features(
{
"type": lambda f: f == Value("string"),
Expand Down Expand Up @@ -388,6 +413,18 @@ def _generate_tables(self, base_files, files_iterables, original_files, allow_fu
"cwd": lambda f: f == Value("string"),
}
),
"cursor": datasets.Features(
{
"messages": lambda f: isinstance(f, datasets.List),
"metadata": lambda f: isinstance(f, dict)
and f.get("trace_type") == Value("string")
and f.get("source") == Value("string")
and f.get("cursor_composer_id") == Value("string"),
"raw_cursor": lambda f: isinstance(f, dict)
and isinstance(f.get("composer_data"), (dict, datasets.Json))
and isinstance(f.get("bubble_ids"), datasets.List),
}
),
}

AGENT_TRACES_FEATURES = datasets.Features(
Expand Down Expand Up @@ -417,6 +454,18 @@ def has_agent_traces_markers(features: datasets.Features) -> bool:
return False


def has_cursor_native_trace_markers(trace_events: list[str]) -> bool:
    """Return whether the given JSON lines look like a native Cursor trace.

    The lines are considered a native Cursor trace when they contain both:

    * a message event: ``role`` is ``"user"`` or ``"assistant"`` and
      ``message`` decodes to a JSON object, and
    * a turn-end event: ``type`` is ``"turn_ended"`` and ``status`` is a string.

    A single event is counted as a message event first; it is only checked
    as a turn-end event when it is not a message event.

    Lines that are not valid JSON, or that decode to something other than a
    JSON object, cannot carry either marker and are skipped.

    Args:
        trace_events: JSON-encoded lines, one event per entry.

    Returns:
        ``True`` if both markers were found, ``False`` otherwise (including
        for an empty list).
    """
    has_message = False
    has_turn_end = False
    for event in trace_events:
        try:
            decoded_event = ujson_loads(event)
        except ValueError:
            # Same tolerance as the other cursor helpers: a malformed line is not a marker.
            continue
        if not isinstance(decoded_event, dict):
            # Arrays, strings, numbers and null have no ``.get``.
            continue
        if decoded_event.get("role") in ("user", "assistant") and isinstance(decoded_event.get("message"), dict):
            has_message = True
        elif decoded_event.get("type") == "turn_ended" and isinstance(decoded_event.get("status"), str):
            has_turn_end = True
        if has_message and has_turn_end:
            # Both markers seen: no need to decode the remaining lines.
            return True
    return False


def parse_traces_info(
trace_events: list[str],
) -> tuple[Optional[str], Optional[str], Optional[str], Optional[str], int, int]:
Expand Down Expand Up @@ -519,6 +568,60 @@ def get_trace_event_timestamp(trace_event: dict) -> Optional[str]:
return None


def get_cursor_session_id(training_example: dict, trace_line: str, trace_file: Path) -> Optional[str]:
    """Resolve the session id of a Cursor trace.

    Lookup order:

    1. ``training_example["metadata"]["cursor_composer_id"]`` when it is a string.
    2. ``raw_cursor.composer_data.composerId`` decoded from ``trace_line``
       when it is a string.
    3. The stem of ``trace_file`` when ``raw_cursor`` is a JSON object but no
       string ``composerId`` is available.

    Args:
        training_example: Converted training example; only ``metadata`` is read.
        trace_line: JSON-encoded source line of the trace.
        trace_file: Path of the trace file, used for the file-stem fallback.

    Returns:
        The session id, or ``None`` when ``trace_line`` is not valid JSON,
        does not decode to a JSON object, or has no ``raw_cursor`` object.
    """
    metadata = training_example.get("metadata")
    if isinstance(metadata, dict) and isinstance(metadata.get("cursor_composer_id"), str):
        return metadata["cursor_composer_id"]
    try:
        trace_event = ujson_loads(trace_line)
    except ValueError:
        return None
    if not isinstance(trace_event, dict):
        # A line holding a JSON array/scalar has no ``.get``; treat it like a line without ``raw_cursor``.
        return None
    raw_cursor = trace_event.get("raw_cursor")
    if not isinstance(raw_cursor, dict):
        # NOTE(review): this path returns None while a missing composerId falls back to the
        # file stem below - confirm the asymmetry is intended.
        return None
    composer_data = raw_cursor.get("composer_data")
    if isinstance(composer_data, dict) and isinstance(composer_data.get("composerId"), str):
        return composer_data["composerId"]
    return trace_file.stem


def get_cursor_trace_timestamp(trace_line: str) -> Optional[str]:
    """Extract the creation time of a Cursor trace as an ISO-8601 UTC string.

    Reads ``raw_cursor.composer_data.createdAt`` from the JSON-encoded
    ``trace_line``. Values greater than 10_000_000_000 are interpreted as
    milliseconds since the epoch, smaller values as seconds.

    Args:
        trace_line: JSON-encoded source line of the trace.

    Returns:
        The timestamp formatted with millisecond precision and a ``Z``
        suffix (e.g. ``"2025-07-03T07:06:08.368Z"``), or ``None`` when the
        line is not a JSON object, the field is missing or not a real
        number, or the value cannot be represented as a datetime.
    """
    try:
        trace_event = ujson_loads(trace_line)
    except ValueError:
        return None
    if not isinstance(trace_event, dict):
        # A JSON array/scalar line has no ``.get``.
        return None
    raw_cursor = trace_event.get("raw_cursor")
    if not isinstance(raw_cursor, dict):
        return None
    composer_data = raw_cursor.get("composer_data")
    if not isinstance(composer_data, dict):
        return None
    created_at = composer_data.get("createdAt")
    # ``bool`` is a subclass of ``int``; a JSON true/false is not a timestamp.
    if isinstance(created_at, bool) or not isinstance(created_at, (int, float)):
        return None
    try:
        seconds = created_at / 1_000 if created_at > 10_000_000_000 else created_at
        created = datetime.fromtimestamp(seconds, tz=timezone.utc)
    except (OverflowError, OSError, ValueError):
        # Out-of-range, NaN or infinite values cannot be converted; report "no timestamp".
        return None
    return created.isoformat(timespec="milliseconds").replace("+00:00", "Z")


def get_training_example_user_message_count(training_example: dict) -> int:
    """Count the messages of a training example whose ``role`` is ``"user"``.

    Entries of ``messages`` that are not dicts are ignored. When ``messages``
    is missing or is not a list the count is 0.
    """
    user_turns = 0
    conversation = training_example.get("messages")
    if isinstance(conversation, list):
        for entry in conversation:
            if isinstance(entry, dict) and entry.get("role") == "user":
                user_turns += 1
    return user_turns


def get_training_example_tool_call_count(training_example: dict) -> int:
    """Sum the lengths of the ``tool_calls`` lists across a training example's messages.

    Only dict messages whose ``tool_calls`` value is a list contribute. When
    ``messages`` is missing or is not a list the count is 0.
    """
    conversation = training_example.get("messages")
    if not isinstance(conversation, list):
        return 0

    total_calls = 0
    for entry in conversation:
        if not isinstance(entry, dict):
            continue
        calls = entry.get("tool_calls")
        if isinstance(calls, list):
            total_calls += len(calls)
    return total_calls


def get_content_text(content) -> Optional[str]:
if isinstance(content, str):
return content
Expand Down
115 changes: 113 additions & 2 deletions tests/packaged_modules/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,71 @@ def assert_agent_traces_output(tmp_path, filename, rows, expected, num_sessions=
},
]

# Fixture: one already-converted Cursor session stored as a single JSONL row.
# It holds two user messages and one assistant tool call; the parametrized
# agent-trace tests expect ("cursor", "cursor-session", "cursor prompt",
# "2025-07-03T07:06:08.368Z", 2, 1) for it.
CURSOR_SESSION = {
"messages": [
{"role": "user", "content": "cursor prompt"},
{
"role": "assistant",
"content": "cursor response",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {"name": "read_file", "arguments": '{"path": "README.md"}'},
}
],
},
{"role": "tool", "tool_call_id": "call_1", "name": "read_file", "content": "done"},
{"role": "user", "content": "second cursor prompt"},
{"role": "assistant", "content": "second response"},
],
"prompt": "cursor prompt",
"response": "second response",
"model": "claude-4-sonnet",
"tools": [
{
"type": "function",
"function": {
"name": "read_file",
"description": "Read file",
"parameters": {"type": "object", "properties": {"path": {"type": "string"}}},
},
}
],
# Same value as raw_cursor.composer_data.composerId below.
"metadata": {
"trace_type": "cursor",
"source": "cursor",
"cursor_composer_id": "cursor-session",
"model": "claude-4-sonnet",
},
"raw_cursor": {
"composer_data": {
"composerId": "cursor-session",
# Epoch milliseconds; corresponds to the expected "2025-07-03T07:06:08.368Z".
"createdAt": 1751526368368,
"lastUpdatedAt": 1751526730702,
"status": "completed",
},
"bubble_ids": ["bubble-1"],
},
}

# Fixture: a native Cursor trace, one event per JSONL row - a user message,
# an assistant message containing one tool_use item, and a closing
# "turn_ended" event. The parametrized agent-trace tests expect
# ("cursor", "cursor-native-session", "Inspect index.html", None, 1, 1) for it.
CURSOR_NATIVE_SESSION = [
{
"role": "user",
"message": {"content": [{"type": "text", "text": "<user_query>\nInspect index.html\n</user_query>"}]},
},
{
"role": "assistant",
"message": {
"content": [
{"type": "text", "text": "I'll inspect it."},
{"type": "tool_use", "id": "read_file_0", "name": "read_file", "input": {"path": "index.html"}},
]
},
},
# Turn-end event; the load_dataset test asserts it is the last line of the stored trace.
{"type": "turn_ended", "status": "success"},
]


def test_config_raises_when_invalid_name() -> None:
with pytest.raises(InvalidConfigName, match="Bad characters"):
Expand Down Expand Up @@ -727,6 +792,18 @@ def test_json_generate_tables_with_sorted_columns(file_fixture, config_kwargs, r
("droid", "droid-session", "Inspect the project", "2026-06-02T18:55:30.274Z", 1, 1),
id="droid",
),
pytest.param(
"cursor.jsonl",
[CURSOR_SESSION],
("cursor", "cursor-session", "cursor prompt", "2025-07-03T07:06:08.368Z", 2, 1),
id="cursor",
),
pytest.param(
"cursor-native-session.jsonl",
CURSOR_NATIVE_SESSION,
("cursor", "cursor-native-session", "Inspect index.html", None, 1, 1),
id="cursor-native",
),
pytest.param(
"missing_prompt.jsonl",
[
Expand All @@ -745,8 +822,9 @@ def test_json_generate_tables_with_sorted_columns(file_fixture, config_kwargs, r
def test_json_generate_tables_with_agent_trace_metadata(tmp_path, filename, rows, expected):
num_sessions = 2 if filename == "hermes_two_sessions.jsonl" else 1
_, out = assert_agent_traces_output(tmp_path, filename, rows, expected, num_sessions=num_sessions)
if filename == "droid.jsonl":
assert out["metadata"][0]["trace_type"] == "droid"
if filename in ("cursor.jsonl", "cursor-native-session.jsonl", "droid.jsonl"):
expected_trace_type = "cursor" if filename.startswith("cursor") else "droid"
assert out["metadata"][0]["trace_type"] == expected_trace_type
assert "models" not in out


Expand All @@ -761,6 +839,18 @@ def test_json_generate_tables_with_agent_trace_metadata(tmp_path, filename, rows
("droid", "droid-session", "Inspect the project", "2026-06-02T18:55:30.274Z", 1, 1),
id="droid",
),
pytest.param(
"cursor.jsonl",
[CURSOR_SESSION],
("cursor", "cursor-session", "cursor prompt", "2025-07-03T07:06:08.368Z", 2, 1),
id="cursor",
),
pytest.param(
"cursor-native-session.jsonl",
CURSOR_NATIVE_SESSION,
("cursor", "cursor-native-session", "Inspect index.html", None, 1, 1),
id="cursor-native",
),
],
)
def test_json_load_dataset_with_agent_trace_metadata(tmp_path, filename, rows, expected):
Expand All @@ -787,6 +877,12 @@ def test_json_load_dataset_with_agent_trace_metadata(tmp_path, filename, rows, e
if filename == "droid.jsonl":
assert row["metadata"]["trace_type"] == "droid"
assert json.loads(row["trace"].splitlines()[0])["type"] == "session_start"
elif filename == "cursor.jsonl":
assert row["metadata"]["trace_type"] == "cursor"
assert row["trace"]["raw_cursor"]["composer_data"]["composerId"] == "cursor-session"
elif filename == "cursor-native-session.jsonl":
assert row["metadata"]["trace_type"] == "cursor"
assert json.loads(row["trace"].splitlines()[-1]) == {"type": "turn_ended", "status": "success"}


def test_json_load_dataset_without_droid_marker_stays_ordinary_json(tmp_path):
Expand All @@ -807,3 +903,18 @@ def test_json_load_dataset_without_droid_marker_stays_ordinary_json(tmp_path):

assert dataset.column_names == ["type", "id", "version", "timestamp", "message"]
assert dataset[0]["type"] == "session_start"


def test_json_load_dataset_without_cursor_native_marker_stays_ordinary_json(tmp_path):
    """Rows with only role/message events (no ``turn_ended`` row) load as plain JSON columns."""
    message_only_rows = [
        {"role": "user", "message": {"content": [{"type": "text", "text": "Inspect index.html"}]}},
        {"role": "assistant", "message": {"content": [{"type": "text", "text": "I'll inspect it."}]}},
    ]
    jsonl_path = write_jsonl(tmp_path / "cursor_missing_marker.jsonl", message_only_rows)
    cache_dir = str(tmp_path / "cache")

    loaded = load_dataset("json", data_files=jsonl_path, split="train", cache_dir=cache_dir)

    # The original columns must survive untouched, i.e. no agent-trace conversion happened.
    assert loaded.column_names == ["role", "message"]
    first_row = loaded[0]
    assert first_row["role"] == "user"