Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions src/lmms_engine/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,51 @@
FPS_MAX_FRAMES = 768


def _safe_concatenate_datasets(data_list):
    """Concatenate datasets with automatic schema alignment.

    When loading multiple parquet files, schema mismatches can occur if columns
    have different inferred types (e.g., a struct column is null in one file but
    has nested fields in another). This function handles such cases by casting
    all datasets to a unified schema before concatenation.

    Args:
        data_list: List of HuggingFace Dataset objects to concatenate.

    Returns:
        A single concatenated Dataset.
    """
    if len(data_list) <= 1:
        return concatenate_datasets(data_list)
    try:
        return concatenate_datasets(data_list)
    except Exception as e:
        logger.warning(
            f"Direct concatenation failed due to schema mismatch: {e}. Attempting schema alignment via cast."
        )
        # Use the first dataset's features as the target schema and cast others to match.
        # This avoids the memory overhead of to_list() on large datasets.
        target_features = data_list[0].features
        aligned = [data_list[0]]
        for ds in data_list[1:]:
            try:
                aligned.append(ds.cast(target_features))
            except Exception:
                # If cast to first dataset's schema fails, build a merged
                # feature set from all datasets, preferring non-null types.
                from datasets import Features

                def _is_null_feature(feature):
                    # A column that is entirely None in one parquet shard is
                    # inferred as Value("null"); any concrete type seen in
                    # another shard should replace it. Check the dtype
                    # attribute directly instead of comparing repr strings,
                    # which differ across datasets versions
                    # (e.g. "Value(dtype='null', id=None)" vs "Value('null')").
                    return getattr(feature, "dtype", None) == "null"

                merged = {}
                for d in data_list:
                    for col_name, col_type in d.features.items():
                        if col_name not in merged or _is_null_feature(merged[col_name]):
                            merged[col_name] = col_type
                merged_features = Features(merged)
                # NOTE(review): assumes all datasets share the same column
                # names; cast() raises if column sets differ — TODO confirm
                # upstream loaders guarantee this.
                aligned = [d.cast(merged_features) for d in data_list]
                break
        return concatenate_datasets(aligned)


class DataUtilities:
@staticmethod
def load_json(path: str) -> List[Dict[str, List]]:
Expand Down Expand Up @@ -104,7 +149,7 @@ def load_yaml(path: str) -> Tuple[List[Dict[str, List]], List[str]]:
data_list.append(data)
logger.info(f"Dataset size: {len(data)}")
data_folder_list.extend([data_folder] * len(data))
data_list = concatenate_datasets(data_list)
data_list = _safe_concatenate_datasets(data_list)
return data_list, data_folder_list

@staticmethod
Expand Down Expand Up @@ -222,7 +267,7 @@ def load_inline_datasets(
data_list.append(data)
logger.info(f"Dataset size: {len(data)}")
data_folder_list.extend([data_folder] * len(data))
data_list = concatenate_datasets(data_list)
data_list = _safe_concatenate_datasets(data_list)

return data_list, data_folder_list

Expand Down
Loading