Commit 155d896

fix: fix huggingface_reader.py
1 parent c7198d8 commit 155d896

1 file changed

Lines changed: 36 additions & 27 deletions

File tree

graphgen/models/reader/huggingface_reader.py

@@ -8,6 +8,7 @@
 from graphgen.bases.base_reader import BaseReader
 
 if TYPE_CHECKING:
+    import numpy as np
     import ray
     from ray.data import Dataset
 
@@ -133,6 +134,7 @@ def _load_single_dataset(
         :param hf_datasets: Imported datasets module
         :return: Ray Dataset
         """
+        import numpy as np
         import ray
 
         # Parse dataset path format: "dataset_name:subset:split"
@@ -159,36 +161,43 @@ def _load_single_dataset(
             dataset_name, split=final_split, **load_kwargs
         )
 
-        # Convert to pandas and then to Ray dataset
-        # Add type column if not present
-        dataset_dict = hf_dataset.to_dict()
-
-        # Ensure data is in list of dicts format
-        if isinstance(dataset_dict, dict) and all(
-            isinstance(v, list) for v in dataset_dict.values()
-        ):
-            # Convert from column-based to row-based format
-            num_rows = len(next(iter(dataset_dict.values())))
-            data = [
-                {key: dataset_dict[key][i] for key in dataset_dict}
-                for i in range(num_rows)
-            ]
-        else:
-            data = dataset_dict
+        # Apply limit before converting to Ray dataset for memory efficiency
+        if limit:
+            if streaming:
+                hf_dataset = hf_dataset.take(limit)
+            else:
+                hf_dataset = hf_dataset.select(range(limit))
+
+        # Convert to Ray dataset using lazy evaluation
+        ray_ds = ray.data.from_huggingface(hf_dataset)
+
+        # Define batch processing function for lazy evaluation
+        def _process_batch(batch: dict[str, "np.ndarray"]) -> dict[str, "np.ndarray"]:
+            """
+            Process a batch of data to add type field and rename text column.
+
+            :param batch: A dictionary with column names as keys and numpy arrays
+            :return: Processed batch dictionary with numpy arrays
+            """
+            if not batch:
+                return {}
+
+            # Get the number of rows in the batch
+            num_rows = len(next(iter(batch.values())))
+
+            # Add type field if not present
+            if "type" not in batch:
+                batch["type"] = np.array(["text"] * num_rows)
 
-        # Add type field if not present
-        for item in data:
-            if "type" not in item:
-                item["type"] = "text"
             # Rename text_column to 'content' if different
-            if self.text_column != "content" and self.text_column in item:
-                item["content"] = item[self.text_column]
+            if self.text_column != "content" and self.text_column in batch:
+                batch["content"] = batch[self.text_column]
+                # Optional: delete old key to avoid duplication
+                # del batch[self.text_column]
 
-        # Apply limit if specified
-        if limit:
-            data = data[:limit]
+            return batch
 
-        # Create Ray dataset
-        ray_ds = ray.data.from_items(data)
+        # Apply post-processing using map_batches for distributed lazy evaluation
+        ray_ds = ray_ds.map_batches(_process_batch)
 
         return ray_ds
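For context, the pattern the new code relies on, trimming the Hugging Face dataset first, wrapping it with ray.data.from_huggingface, and deferring the per-row work to map_batches, can be sketched standalone as below. This is a minimal sketch of that pattern, not the repository's reader: the dataset name ("ag_news"), the limit, and the text column ("text") are illustrative assumptions.

# Minimal sketch of the lazy-processing pattern used in this commit.
# Assumptions: "ag_news", limit=100, and text_column="text" are illustrative,
# not values taken from GraphGen.
import numpy as np
import ray
from datasets import load_dataset

limit = 100
text_column = "text"

# Load and trim the Hugging Face dataset before handing it to Ray,
# mirroring the select()/take() logic in the diff.
hf_dataset = load_dataset("ag_news", split="train")
hf_dataset = hf_dataset.select(range(limit))

# Convert to a Ray Dataset; rows are not materialized yet.
ray_ds = ray.data.from_huggingface(hf_dataset)


def process_batch(batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Add a 'type' column and mirror the text column into 'content'."""
    num_rows = len(next(iter(batch.values())))
    if "type" not in batch:
        batch["type"] = np.array(["text"] * num_rows)
    if text_column != "content" and text_column in batch:
        batch["content"] = batch[text_column]
    return batch


# map_batches is lazy: process_batch runs in parallel only when the
# dataset is consumed, e.g. by take(), show(), or a write_*() call.
ray_ds = ray_ds.map_batches(process_batch, batch_format="numpy")
print(ray_ds.take(2))

Compared with the removed from_items() path, which built a Python list of row dicts eagerly, this keeps the limit, the type-column default, and the content rename inside Ray's batch pipeline.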
