Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions src/lmms_engine/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,34 @@
FPS_MAX_FRAMES = 768


def _safe_concatenate_datasets(data_list):
"""Concatenate datasets with automatic schema alignment.

When loading multiple parquet files, schema mismatches can occur if columns
have different inferred types (e.g., a struct column is null in one file but
has nested fields in another). This function handles such cases by falling
back to row-wise concatenation when direct concatenation fails.

Args:
data_list: List of HuggingFace Dataset objects to concatenate.

Returns:
A single concatenated Dataset.
"""
if len(data_list) <= 1:
return concatenate_datasets(data_list)
try:
return concatenate_datasets(data_list)
except Exception as e:
logger.warning(
f"Direct concatenation failed due to schema mismatch: {e}. " f"Falling back to row-wise concatenation."
)
all_rows = []
for ds in data_list:
all_rows.extend(ds.to_list())
return Dataset.from_list(all_rows)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be cause extremely slow and large memory needed under large scale data. Recommend to preprocess or cast dataset before training to ensure the dataset is in the same features.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point — I've replaced the to_list() fallback with cast()-based schema alignment in e28d773. The new approach:

  1. Tries cast() to the first dataset's features (fast, zero-copy for compatible schemas)
  2. If that fails, builds a merged feature set preferring non-null types and casts all datasets to it

This keeps the fix zero-overhead on the happy path and avoids materializing the full dataset into memory on schema mismatch.



class DataUtilities:
@staticmethod
def load_json(path: str) -> List[Dict[str, List]]:
Expand Down Expand Up @@ -104,7 +132,7 @@ def load_yaml(path: str) -> Tuple[List[Dict[str, List]], List[str]]:
data_list.append(data)
logger.info(f"Dataset size: {len(data)}")
data_folder_list.extend([data_folder] * len(data))
data_list = concatenate_datasets(data_list)
data_list = _safe_concatenate_datasets(data_list)
return data_list, data_folder_list

@staticmethod
Expand Down Expand Up @@ -222,7 +250,7 @@ def load_inline_datasets(
data_list.append(data)
logger.info(f"Dataset size: {len(data)}")
data_folder_list.extend([data_folder] * len(data))
data_list = concatenate_datasets(data_list)
data_list = _safe_concatenate_datasets(data_list)

return data_list, data_folder_list

Expand Down
Loading