Commit 73ac00a

[fix] handle parquet schema mismatch in dataset concatenation (#146)
* [fix] handle parquet schema mismatch in dataset concatenation

  When loading multiple parquet files via YAML config, concatenate_datasets()
  fails if columns have different inferred Arrow types. Add a safe wrapper
  that falls back to row-wise concatenation on schema mismatch.

* refactor: use cast() instead of to_list() for schema alignment

  Replace the to_list() + from_list() fallback with cast()-based schema
  alignment. This avoids materializing the entire dataset into memory,
  making it safe for large-scale data.

  Strategy:
  1. Try cast() to the first dataset's features (fast, zero-copy)
  2. If that fails, build a merged feature set preferring non-null types
     and cast all datasets to the merged schema

---------

Co-authored-by: mwxely <mwxely@users.noreply.github.com>
1 parent 87a1f86 commit 73ac00a

1 file changed: src/lmms_engine/utils/data_utils.py

Lines changed: 47 additions & 2 deletions
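
For context, here is a minimal, self-contained sketch of the failure mode the commit message describes. The shards and column names are invented for illustration, not taken from this repo. Whether the direct concatenate_datasets() call actually raises depends on the installed datasets version (recent releases can auto-align a Value("null") column on their own), so the repro is written to work either way:

from datasets import Dataset, concatenate_datasets

# Hypothetical shards: "meta" is a typed struct in one parquet file but
# entirely missing in the other, so Arrow infers Value("null") for it.
shard_a = Dataset.from_dict({"text": ["a"], "meta": [{"source": "web"}]})
shard_b = Dataset.from_dict({"text": ["b"], "meta": [None]})

try:
    combined = concatenate_datasets([shard_a, shard_b])
except ValueError as err:
    print(f"direct concatenation failed: {err}")
    # Step 1 of the commit's strategy: cast the mismatched shard to the
    # first shard's features. A null column can be cast to any type.
    combined = concatenate_datasets([shard_a, shard_b.cast(shard_a.features)])

print(len(combined))  # 2

The cast stays at the Arrow level, which is why the commit message prefers it over the earlier to_list() + from_list() fallback: nothing is materialized into Python objects along the way.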
@@ -21,6 +21,51 @@
 FPS_MAX_FRAMES = 768
 
 
+def _safe_concatenate_datasets(data_list):
+    """Concatenate datasets with automatic schema alignment.
+
+    When loading multiple parquet files, schema mismatches can occur if columns
+    have different inferred types (e.g., a struct column is null in one file but
+    has nested fields in another). This function handles such cases by casting
+    all datasets to a unified schema before concatenation.
+
+    Args:
+        data_list: List of HuggingFace Dataset objects to concatenate.
+
+    Returns:
+        A single concatenated Dataset.
+    """
+    if len(data_list) <= 1:
+        return concatenate_datasets(data_list)
+    try:
+        return concatenate_datasets(data_list)
+    except Exception as e:
+        logger.warning(
+            f"Direct concatenation failed due to schema mismatch: {e}. " f"Attempting schema alignment via cast."
+        )
+        # Use the first dataset's features as the target schema and cast others to match.
+        # This avoids the memory overhead of to_list() on large datasets.
+        target_features = data_list[0].features
+        aligned = [data_list[0]]
+        for ds in data_list[1:]:
+            try:
+                aligned.append(ds.cast(target_features))
+            except Exception:
+                # If cast to first dataset's schema fails, build a merged
+                # feature set from all datasets, preferring non-null types.
+                from datasets import Features
+
+                merged = {}
+                for d in data_list:
+                    for col_name, col_type in d.features.items():
+                        if col_name not in merged or str(merged[col_name]) == "Value(dtype='null')":
+                            merged[col_name] = col_type
+                merged_features = Features(merged)
+                aligned = [d.cast(merged_features) for d in data_list]
+                break
+        return concatenate_datasets(aligned)
+
+
 class DataUtilities:
     @staticmethod
     def load_json(path: str) -> List[Dict[str, List]]:
@@ -104,7 +149,7 @@ def load_yaml(path: str) -> Tuple[List[Dict[str, List]], List[str]]:
         data_list.append(data)
         logger.info(f"Dataset size: {len(data)}")
         data_folder_list.extend([data_folder] * len(data))
-        data_list = concatenate_datasets(data_list)
+        data_list = _safe_concatenate_datasets(data_list)
         return data_list, data_folder_list
 
     @staticmethod
@@ -222,7 +267,7 @@ def load_inline_datasets(
         data_list.append(data)
         logger.info(f"Dataset size: {len(data)}")
         data_folder_list.extend([data_folder] * len(data))
-        data_list = concatenate_datasets(data_list)
+        data_list = _safe_concatenate_datasets(data_list)
 
         return data_list, data_folder_list
 
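
To round out the picture, a sketch of the step-2 fallback in isolation, again with invented shards: when casting to the first dataset's features fails, the merged feature set keeps the first non-null type seen for each column, and every shard is cast to that schema. (This sketch tests for the null type via the feature's dtype attribute; the committed code checks the same condition by comparing against the feature's repr string.)

from datasets import Dataset, Features, concatenate_datasets

shard_a = Dataset.from_dict({"id": [1], "meta": [None]})            # meta: Value("null")
shard_b = Dataset.from_dict({"id": [2], "meta": [{"lang": "en"}]})  # meta: struct

# Mirror the merge loop from the commit: keep the first type seen per
# column, but let any non-null type overwrite a Value("null") placeholder.
merged = {}
for shard in (shard_a, shard_b):
    for name, ftype in shard.features.items():
        if name not in merged or getattr(merged[name], "dtype", None) == "null":
            merged[name] = ftype

merged_features = Features(merged)
aligned = [shard.cast(merged_features) for shard in (shard_a, shard_b)]
print(len(concatenate_datasets(aligned)))  # 2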