Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
python3 -m graphgen.run \
--config_file examples/generate/generate_aggregated_qa/huggingface_config.yaml
83 changes: 83 additions & 0 deletions examples/generate/generate_aggregated_qa/huggingface_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
global_params:
working_dir: cache
graph_backend: networkx # graph database backend, support: kuzu, networkx
kv_backend: json_kv # key-value store backend, support: rocksdb, json_kv

nodes:
- id: read_hf_dataset # Read from Hugging Face Hub
op_name: read
type: source
dependencies: []
params:
input_path:
- huggingface://wikitext:wikitext-103-v1:train # Format: huggingface://dataset_name:subset:split
# Optional parameters for HuggingFaceReader:
text_column: text # Column name containing text content (default: content)
# cache_dir: /path/to/cache # Optional: directory to cache downloaded datasets
# trust_remote_code: false # Optional: whether to trust remote code in datasets

- id: chunk_documents
op_name: chunk
type: map_batch
dependencies:
- read_hf_dataset
execution_params:
replicas: 4
params:
chunk_size: 1024
chunk_overlap: 100

- id: build_kg
op_name: build_kg
type: map_batch
dependencies:
- chunk_documents
execution_params:
replicas: 1
batch_size: 128

- id: quiz
op_name: quiz
type: map_batch
dependencies:
- build_kg
execution_params:
replicas: 1
batch_size: 128
params:
quiz_samples: 2

- id: judge
op_name: judge
type: map_batch
dependencies:
- quiz
execution_params:
replicas: 1
batch_size: 128

- id: partition
op_name: partition
type: aggregate
dependencies:
- judge
params:
method: ece
method_params:
max_units_per_community: 20
min_units_per_community: 5
max_tokens_per_community: 10240
unit_sampling: max_loss

- id: generate
op_name: generate
type: map_batch
dependencies:
- partition
execution_params:
replicas: 1
batch_size: 128
save_output: true
params:
method: aggregated
data_format: ChatML
2 changes: 2 additions & 0 deletions graphgen/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
)
from .reader import (
CSVReader,
HuggingFaceReader,
JSONReader,
ParquetReader,
PDFReader,
Expand Down Expand Up @@ -92,6 +93,7 @@
"PickleReader": ".reader",
"RDFReader": ".reader",
"TXTReader": ".reader",
"HuggingFaceReader": ".reader",
# Searcher
"NCBISearch": ".searcher.db.ncbi_searcher",
"RNACentralSearch": ".searcher.db.rnacentral_searcher",
Expand Down
1 change: 1 addition & 0 deletions graphgen/models/reader/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .csv_reader import CSVReader
from .huggingface_reader import HuggingFaceReader
from .json_reader import JSONReader
from .parquet_reader import ParquetReader
from .pdf_reader import PDFReader
Expand Down
194 changes: 194 additions & 0 deletions graphgen/models/reader/huggingface_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
"""
Hugging Face Datasets Reader
This module provides a reader for accessing datasets from Hugging Face Hub.
"""

from typing import TYPE_CHECKING, List, Optional, Union

from graphgen.bases.base_reader import BaseReader

if TYPE_CHECKING:
import ray
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed
from ray.data import Dataset


class HuggingFaceReader(BaseReader):
"""
Reader for Hugging Face Datasets.

Supports loading datasets from the Hugging Face Hub.
Can specify a dataset by name and optional subset/split.

Columns:
- type: The type of the document (e.g., "text", "image", etc.)
- if type is "text", "content" column must be present (or specify via text_column).

Example:
reader = HuggingFaceReader(text_column="text")
ds = reader.read("wikitext")
# or with split and subset
ds = reader.read("wikitext:wikitext-103-v1:train")
"""

def __init__(
self,
text_column: str = "content",
modalities: Optional[list] = None,
cache_dir: Optional[str] = None,
trust_remote_code: bool = False,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The HuggingFaceReader class introduces the trust_remote_code parameter which is passed directly to the Hugging Face datasets.load_dataset function. When set to True, this allows the execution of arbitrary Python code contained within the dataset repository (e.g., in the loading script). Since this parameter is exposed to the end-user via the configuration file (through reader_kwargs), it creates a significant risk of Remote Code Execution (RCE) if an attacker can provide or influence the configuration. While the default is False, exposing this dangerous functionality to the configuration without adequate warnings or restrictions is a security concern. Consider removing this parameter from the configuration or implementing a strict allow-list for trusted datasets.

):
"""
Initialize HuggingFaceReader.

:param text_column: Column name containing text content
:param modalities: List of supported modalities
:param cache_dir: Directory to cache downloaded datasets
:param trust_remote_code: Whether to trust remote code in datasets
"""
super().__init__(text_column=text_column, modalities=modalities)
self.cache_dir = cache_dir
self.trust_remote_code = trust_remote_code

def read(
self,
input_path: Union[str, List[str]],
split: Optional[str] = None,
subset: Optional[str] = None,
streaming: bool = False,
limit: Optional[int] = None,
) -> "Dataset":
"""
Read dataset from Hugging Face Hub.

:param input_path: Dataset identifier(s) from Hugging Face Hub
Format: "dataset_name" or "dataset_name:subset:split"
Example: "wikitext" or "wikitext:wikitext-103-v1:train"
:param split: Specific split to load (overrides split in path)
:param subset: Specific subset/configuration to load (overrides subset in path)
:param streaming: Whether to stream the dataset instead of downloading
:param limit: Maximum number of samples to load
:return: Ray Dataset containing the data
"""
try:
import datasets as hf_datasets
except ImportError as exc:
raise ImportError(
"The 'datasets' package is required to use HuggingFaceReader. "
"Please install it with: pip install datasets"
) from exc

if isinstance(input_path, list):
# Handle multiple datasets
all_dss = []
for path in input_path:
ds = self._load_single_dataset(
path,
split=split,
subset=subset,
streaming=streaming,
limit=limit,
hf_datasets=hf_datasets,
)
all_dss.append(ds)

if len(all_dss) == 1:
combined_ds = all_dss[0]
else:
combined_ds = all_dss[0].union(*all_dss[1:])
Comment on lines +95 to +98
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If input_path is an empty list, all_dss will also be empty. This will cause an IndexError on line 97 when trying to access all_dss[0]. You should handle the case of an empty list of datasets to avoid this crash.

            if not all_dss:
                import ray

                return ray.data.from_items([])

            if len(all_dss) == 1:
                combined_ds = all_dss[0]
            else:
                combined_ds = all_dss[0].union(*all_dss[1:])

Comment on lines +81 to +98
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If input_path is an empty list, all_dss will also be empty. This will cause an IndexError on line 98 (or 96) when trying to access all_dss[0]. You should handle the case of an empty input_path list to prevent a crash.

        if isinstance(input_path, list):
            if not input_path:
                import ray

                return ray.data.from_items([])

            # Handle multiple datasets
            all_dss = []
            for path in input_path:
                ds = self._load_single_dataset(
                    path,
                    split=split,
                    subset=subset,
                    streaming=streaming,
                    limit=limit,
                    hf_datasets=hf_datasets,
                )
                all_dss.append(ds)

            if len(all_dss) == 1:
                combined_ds = all_dss[0]
            else:
                combined_ds = all_dss[0].union(*all_dss[1:])

else:
combined_ds = self._load_single_dataset(
input_path,
split=split,
subset=subset,
streaming=streaming,
limit=limit,
hf_datasets=hf_datasets,
)

# Validate and filter
combined_ds = combined_ds.map_batches(
self._validate_batch, batch_format="pandas"
)
combined_ds = combined_ds.filter(self._should_keep_item)

return combined_ds

def _load_single_dataset(
self,
dataset_path: str,
split: Optional[str] = None,
subset: Optional[str] = None,
streaming: bool = False,
limit: Optional[int] = None,
hf_datasets=None,
) -> "Dataset":
"""
Load a single dataset from Hugging Face Hub.

:param dataset_path: Dataset path, can include subset and split
:param split: Override split
:param subset: Override subset
:param streaming: Whether to stream
:param limit: Max samples
:param hf_datasets: Imported datasets module
:return: Ray Dataset
"""
import ray

# Parse dataset path format: "dataset_name:subset:split"
parts = dataset_path.split(":")
dataset_name = parts[0]
parsed_subset = parts[1] if len(parts) > 1 else None
parsed_split = parts[2] if len(parts) > 2 else None

# Override with explicit parameters
final_subset = subset or parsed_subset
final_split = split or parsed_split or "train"

# Load dataset from Hugging Face
load_kwargs = {
"cache_dir": self.cache_dir,
"trust_remote_code": self.trust_remote_code,
"streaming": streaming,
}

if final_subset:
load_kwargs["name"] = final_subset

hf_dataset = hf_datasets.load_dataset(
dataset_name, split=final_split, **load_kwargs
)

# Convert to pandas and then to Ray dataset
# Add type column if not present
dataset_dict = hf_dataset.to_dict()

# Ensure data is in list of dicts format
if isinstance(dataset_dict, dict) and all(
isinstance(v, list) for v in dataset_dict.values()
):
# Convert from column-based to row-based format
num_rows = len(next(iter(dataset_dict.values())))
data = [
{key: dataset_dict[key][i] for key in dataset_dict}
for i in range(num_rows)
]
else:
data = dataset_dict

# Add type field if not present
for item in data:
if "type" not in item:
item["type"] = "text"
# Rename text_column to 'content' if different
if self.text_column != "content" and self.text_column in item:
item["content"] = item[self.text_column]

# Apply limit if specified
if limit:
data = data[:limit]

# Create Ray dataset
ray_ds = ray.data.from_items(data)

return ray_ds
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The current implementation loads the entire Hugging Face dataset into memory using hf_dataset.to_dict(). This is highly inefficient for large datasets and can easily lead to out-of-memory errors. It also defeats the purpose of using streaming=True, as it will still try to materialize the entire dataset.

A much more scalable approach is to use ray.data.from_huggingface() to convert the Hugging Face dataset object directly into a Ray dataset. Post-processing steps like adding columns or renaming them should be done using map_batches on the Ray dataset. This will allow for lazy evaluation and distributed processing, which is crucial for handling large-scale data.

        # Convert to Ray dataset
        if limit:
            if streaming:
                hf_dataset = hf_dataset.take(limit)
            else:
                hf_dataset = hf_dataset.select(range(limit))

        ray_ds = ray.data.from_huggingface(hf_dataset)

        def _process_batch(batch: dict[str, list]) -> dict[str, list]:
            if not batch:
                return {}
            num_rows = len(next(iter(batch.values())))
            if "type" not in batch:
                batch["type"] = ["text"] * num_rows

            if self.text_column != "content" and self.text_column in batch:
                batch["content"] = batch[self.text_column]

            return batch

        # Add type field and rename text_column in a scalable way
        ray_ds = ray_ds.map_batches(_process_batch)

        return ray_ds

Loading
Loading