Commit e3a11c9
feat: add graphgen reader
1 parent 33cc281 commit e3a11c9

6 files changed

Lines changed: 424 additions & 68 deletions

Lines changed: 2 additions & 0 deletions

python3 -m graphgen.run \
    --config_file examples/generate/generate_aggregated_qa/huggingface_config.yaml
examples/generate/generate_aggregated_qa/huggingface_config.yaml

Lines changed: 83 additions & 0 deletions

global_params:
  working_dir: cache
  graph_backend: networkx # graph database backend; supported: kuzu, networkx
  kv_backend: json_kv # key-value store backend; supported: rocksdb, json_kv

nodes:
  - id: read_hf_dataset # Read from the Hugging Face Hub
    op_name: read
    type: source
    dependencies: []
    params:
      input_path:
        - huggingface://wikitext:wikitext-103-v1:train # Format: huggingface://dataset_name:subset:split
      # Optional parameters for HuggingFaceReader:
      text_column: text # Column name containing the text content (default: content)
      # cache_dir: /path/to/cache # Optional: directory to cache downloaded datasets
      # trust_remote_code: false # Optional: whether to trust remote code in datasets

  - id: chunk_documents
    op_name: chunk
    type: map_batch
    dependencies:
      - read_hf_dataset
    execution_params:
      replicas: 4
    params:
      chunk_size: 1024
      chunk_overlap: 100

  - id: build_kg
    op_name: build_kg
    type: map_batch
    dependencies:
      - chunk_documents
    execution_params:
      replicas: 1
      batch_size: 128

  - id: quiz
    op_name: quiz
    type: map_batch
    dependencies:
      - build_kg
    execution_params:
      replicas: 1
      batch_size: 128
    params:
      quiz_samples: 2

  - id: judge
    op_name: judge
    type: map_batch
    dependencies:
      - quiz
    execution_params:
      replicas: 1
      batch_size: 128

  - id: partition
    op_name: partition
    type: aggregate
    dependencies:
      - judge
    params:
      method: ece
      method_params:
        max_units_per_community: 20
        min_units_per_community: 5
        max_tokens_per_community: 10240
        unit_sampling: max_loss

  - id: generate
    op_name: generate
    type: map_batch
    dependencies:
      - partition
    execution_params:
      replicas: 1
      batch_size: 128
      save_output: true
    params:
      method: aggregated
      data_format: ChatML
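The input_path entry uses a huggingface://dataset_name:subset:split URI. A small sketch of how the name:subset:split triple is taken apart, mirroring the parsing in HuggingFaceReader._load_single_dataset further down in this commit (one assumption: the pipeline strips the huggingface:// scheme before handing the path to the reader, since the reader itself splits on ":"):

# Parsing "dataset_name:subset:split" (scheme assumed already stripped; see note above)
path = "wikitext:wikitext-103-v1:train"
parts = path.split(":")
dataset_name = parts[0]                          # "wikitext"
subset = parts[1] if len(parts) > 1 else None    # "wikitext-103-v1"
split = parts[2] if len(parts) > 2 else "train"  # missing split defaults to "train"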

graphgen/models/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -33,6 +33,7 @@
 )
 from .reader import (
     CSVReader,
+    HuggingFaceReader,
     JSONReader,
     ParquetReader,
     PDFReader,
@@ -92,6 +93,7 @@
     "PickleReader": ".reader",
     "RDFReader": ".reader",
     "TXTReader": ".reader",
+    "HuggingFaceReader": ".reader",
     # Searcher
     "NCBISearch": ".searcher.db.ncbi_searcher",
     "RNACentralSearch": ".searcher.db.rnacentral_searcher",

graphgen/models/reader/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,4 +1,5 @@
 from .csv_reader import CSVReader
+from .huggingface_reader import HuggingFaceReader
 from .json_reader import JSONReader
 from .parquet_reader import ParquetReader
 from .pdf_reader import PDFReader
graphgen/models/reader/huggingface_reader.py

Lines changed: 194 additions & 0 deletions

"""
Hugging Face Datasets Reader

This module provides a reader for accessing datasets from the Hugging Face Hub.
"""

from typing import TYPE_CHECKING, List, Optional, Union

from graphgen.bases.base_reader import BaseReader

if TYPE_CHECKING:
    import ray
    from ray.data import Dataset


class HuggingFaceReader(BaseReader):
    """
    Reader for Hugging Face datasets.

    Supports loading datasets from the Hugging Face Hub. A dataset can be
    specified by name with an optional subset and split.

    Columns:
        - type: The type of the document (e.g., "text", "image", etc.)
        - if type is "text", a "content" column must be present (or specify one via text_column).

    Example:
        reader = HuggingFaceReader(text_column="text")
        ds = reader.read("wikitext")
        # or with subset and split
        ds = reader.read("wikitext:wikitext-103-v1:train")
    """

    def __init__(
        self,
        text_column: str = "content",
        modalities: Optional[list] = None,
        cache_dir: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        """
        Initialize HuggingFaceReader.

        :param text_column: Column name containing text content
        :param modalities: List of supported modalities
        :param cache_dir: Directory to cache downloaded datasets
        :param trust_remote_code: Whether to trust remote code in datasets
        """
        super().__init__(text_column=text_column, modalities=modalities)
        self.cache_dir = cache_dir
        self.trust_remote_code = trust_remote_code

    def read(
        self,
        input_path: Union[str, List[str]],
        split: Optional[str] = None,
        subset: Optional[str] = None,
        streaming: bool = False,
        limit: Optional[int] = None,
    ) -> "Dataset":
        """
        Read a dataset from the Hugging Face Hub.

        :param input_path: Dataset identifier(s) from the Hugging Face Hub.
            Format: "dataset_name" or "dataset_name:subset:split"
            Example: "wikitext" or "wikitext:wikitext-103-v1:train"
        :param split: Specific split to load (overrides the split in the path)
        :param subset: Specific subset/configuration to load (overrides the subset in the path)
        :param streaming: Whether to stream the dataset instead of downloading it
        :param limit: Maximum number of samples to load
        :return: Ray Dataset containing the data
        """
        try:
            import datasets as hf_datasets
        except ImportError as exc:
            raise ImportError(
                "The 'datasets' package is required to use HuggingFaceReader. "
                "Please install it with: pip install datasets"
            ) from exc

        if isinstance(input_path, list):
            # Handle multiple datasets
            all_dss = []
            for path in input_path:
                ds = self._load_single_dataset(
                    path,
                    split=split,
                    subset=subset,
                    streaming=streaming,
                    limit=limit,
                    hf_datasets=hf_datasets,
                )
                all_dss.append(ds)

            if len(all_dss) == 1:
                combined_ds = all_dss[0]
            else:
                combined_ds = all_dss[0].union(*all_dss[1:])
        else:
            combined_ds = self._load_single_dataset(
                input_path,
                split=split,
                subset=subset,
                streaming=streaming,
                limit=limit,
                hf_datasets=hf_datasets,
            )

        # Validate and filter
        combined_ds = combined_ds.map_batches(
            self._validate_batch, batch_format="pandas"
        )
        combined_ds = combined_ds.filter(self._should_keep_item)

        return combined_ds

    def _load_single_dataset(
        self,
        dataset_path: str,
        split: Optional[str] = None,
        subset: Optional[str] = None,
        streaming: bool = False,
        limit: Optional[int] = None,
        hf_datasets=None,
    ) -> "Dataset":
        """
        Load a single dataset from the Hugging Face Hub.

        :param dataset_path: Dataset path; can include subset and split
        :param split: Override split
        :param subset: Override subset
        :param streaming: Whether to stream
        :param limit: Max samples
        :param hf_datasets: Imported datasets module
        :return: Ray Dataset
        """
        import ray

        # Parse dataset path format: "dataset_name:subset:split"
        parts = dataset_path.split(":")
        dataset_name = parts[0]
        parsed_subset = parts[1] if len(parts) > 1 else None
        parsed_split = parts[2] if len(parts) > 2 else None

        # Explicit parameters override values parsed from the path
        final_subset = subset or parsed_subset
        final_split = split or parsed_split or "train"

        # Load dataset from Hugging Face
        load_kwargs = {
            "cache_dir": self.cache_dir,
            "trust_remote_code": self.trust_remote_code,
            "streaming": streaming,
        }

        if final_subset:
            load_kwargs["name"] = final_subset

        hf_dataset = hf_datasets.load_dataset(
            dataset_name, split=final_split, **load_kwargs
        )

        if streaming:
            # A streaming dataset is an IterableDataset with no to_dict();
            # materialize rows by iteration, honoring the limit
            import itertools

            rows = itertools.islice(hf_dataset, limit) if limit else hf_dataset
            data = [dict(row) for row in rows]
        else:
            # to_dict() returns a column-based mapping; convert it to a
            # row-based list of dicts for Ray
            dataset_dict = hf_dataset.to_dict()
            num_rows = len(next(iter(dataset_dict.values()))) if dataset_dict else 0
            data = [
                {key: dataset_dict[key][i] for key in dataset_dict}
                for i in range(num_rows)
            ]
            # Apply limit if specified
            if limit:
                data = data[:limit]

        for item in data:
            # Add a type field if not present
            if "type" not in item:
                item["type"] = "text"
            # Copy text_column into 'content' if a different column was configured
            if self.text_column != "content" and self.text_column in item:
                item["content"] = item[self.text_column]

        # Create the Ray dataset
        ray_ds = ray.data.from_items(data)

        return ray_ds
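For reference, a minimal usage sketch of the new reader (assumes ray[data] and datasets are installed and that Ray initializes automatically on first use; the dataset and limit are illustrative):

from graphgen.models import HuggingFaceReader

# Load a capped sample of WikiText-103 and inspect the normalized rows.
reader = HuggingFaceReader(text_column="text")
ds = reader.read("wikitext:wikitext-103-v1:train", limit=100)
print(ds.take(3))  # each row carries a "type" field and a "content" text column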
