-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdataset.py
More file actions
140 lines (111 loc) · 5.23 KB
/
dataset.py
File metadata and controls
140 lines (111 loc) · 5.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import json
import os
from collections import Counter
from typing import Any, Dict, List, Optional
class ProgramSynthesisDatapoint:
    """A single program-synthesis problem parsed from one JSONL record."""

    def __init__(self, data: Dict[str, Any]):
        """
        Build a datapoint from one parsed JSONL record.

        Args:
            data: Dictionary containing problem data from JSONL file.
                  Missing keys fall back to empty/default values.
        """
        get = data.get
        # Problem statement and I/O channel description.
        self.description = get("description", "")
        self.input_from = get("input_from", "")
        self.output_to = get("output_to", "")
        # Resource limits; time defaults to 5.0 seconds when absent.
        self.time_limit = get("time_limit", 5.0)
        self.memory_limit = get("memory_limit", "")
        # Formal input/output specification and free-form notes.
        self.input_spec = get("input_spec", "")
        self.output_spec = get("output_spec", "")
        self.notes = get("notes", "")
        # Example test cases.
        self.sample_inputs = get("sample_inputs", [])
        self.sample_outputs = get("sample_outputs", [])
        # Classification metadata.
        self.tags = get("tags", [])
        self.src_uid = get("src_uid", "")
        self.difficulty = get("difficulty")  # None when the record carries no rating

    def __repr__(self):
        return f"ProgramSynthesisDatapoint(difficulty={self.difficulty}, tags={self.tags})"

    def __str__(self):
        return self.__repr__()
class ProgramSynthesisDataset:
    """List-like collection of ProgramSynthesisDatapoint objects loaded from a JSONL file."""

    def __init__(self,
                 data_file: str = "data/dataset.jsonl",
                 difficulty_cutoff: Optional[int] = None,
                 max_samples: int = 20):
        """
        Initialize a program synthesis dataset.

        Args:
            data_file: Path to the JSONL data file (relative to this file;
                an absolute path is used as-is).
            difficulty_cutoff: Maximum difficulty level to include (None for no filter)
            max_samples: Maximum number of samples to load (default: 20)

        Raises:
            FileNotFoundError: If the resolved data file does not exist.
        """
        self.data_file = data_file
        self.difficulty_cutoff = difficulty_cutoff
        self.max_samples = max_samples
        self.datapoints: List[ProgramSynthesisDatapoint] = []
        self._load_data()

    def _load_data(self) -> None:
        """Load data from the JSONL file, applying the difficulty filter and sample cap."""
        # Resolve the data file relative to this module so loading works
        # regardless of the caller's working directory. os.path.join returns
        # self.data_file unchanged when it is already an absolute path.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        file_path = os.path.join(current_dir, self.data_file)
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Data file not found: {file_path}")
        loaded_samples = 0
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if loaded_samples >= self.max_samples:
                    break
                line = line.strip()
                if not line:
                    continue  # tolerate blank lines
                # Keep the try body minimal: only json.loads can raise
                # JSONDecodeError; skip malformed lines rather than aborting.
                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"Warning: Skipping invalid JSON line: {e}")
                    continue
                # Apply difficulty filter if specified; records with no
                # difficulty rating are excluded when a cutoff is active.
                if self.difficulty_cutoff is not None:
                    difficulty = data.get("difficulty")
                    if difficulty is None or difficulty > self.difficulty_cutoff:
                        continue
                self.datapoints.append(ProgramSynthesisDatapoint(data))
                loaded_samples += 1

    def __len__(self) -> int:
        """Number of loaded datapoints."""
        return len(self.datapoints)

    def __getitem__(self, idx):
        """Return the datapoint(s) at idx (supports slices, like a list)."""
        return self.datapoints[idx]

    def __iter__(self):
        """Iterate over the loaded datapoints in load order."""
        return iter(self.datapoints)

    def get_stats(self) -> Dict[str, Any]:
        """Get statistics about the loaded dataset.

        Returns:
            A dict with total_samples, difficulty_range, avg_difficulty,
            unique_tags and most_common_tags; an empty dataset yields only
            {"total_samples": 0}.
        """
        if not self.datapoints:
            return {"total_samples": 0}
        difficulties = [dp.difficulty for dp in self.datapoints if dp.difficulty is not None]
        # Flatten every datapoint's tag list into one sequence for counting.
        all_tags = [tag for dp in self.datapoints for tag in dp.tags]
        return {
            "total_samples": len(self.datapoints),
            "difficulty_range": (min(difficulties), max(difficulties)) if difficulties else None,
            "avg_difficulty": sum(difficulties) / len(difficulties) if difficulties else None,
            "unique_tags": len(set(all_tags)),
            "most_common_tags": self._get_most_common_tags(all_tags),
        }

    def _get_most_common_tags(self, all_tags: List[str], top_k: int = 5) -> List[tuple]:
        """Get the top_k most common tags as (tag, count) pairs."""
        return Counter(all_tags).most_common(top_k)

    def filter_by_tags(self, tags: List[str]) -> 'ProgramSynthesisDataset':
        """Create a new dataset keeping datapoints that share at least one of *tags*."""
        wanted = set(tags)  # set for O(1) membership tests
        filtered_datapoints = [dp for dp in self.datapoints
                               if wanted.intersection(dp.tags)]
        # Bypass __init__ so the copy inherits configuration without
        # re-reading (and re-filtering) the data file.
        new_dataset = ProgramSynthesisDataset.__new__(ProgramSynthesisDataset)
        new_dataset.data_file = self.data_file
        new_dataset.difficulty_cutoff = self.difficulty_cutoff
        new_dataset.max_samples = self.max_samples
        new_dataset.datapoints = filtered_datapoints
        return new_dataset