|
5 | 5 | import csv |
6 | 6 |
|
7 | 7 | from collections import Counter |
8 | | -from typing import Final, Any |
| 8 | +from typing import Final, Any, NamedTuple |
9 | 9 |
|
10 | 10 | from .utilities import SchemaType |
11 | 11 |
|
|
260 | 260 | } |
261 | 261 |
|
262 | 262 |
|
| 263 | +class ColumnValidationResult(NamedTuple): |
| 264 | + is_valid: bool |
| 265 | + unexpected_columns: list[str] |
| 266 | + missing_required_columns: list[str] |
| 267 | + |
| 268 | + |
263 | 269 | def validate_file(filename: str, allowed_types: set[SchemaType]) -> set[SchemaType]: |
264 | 270 | """Validates given a filename.""" |
265 | 271 | with open(filename) as f: |
@@ -299,26 +305,55 @@ def get_col_names(f: Any) -> Any: |
299 | 305 |
|
300 | 306 | def detect_file_type(col_names: list[str]) -> set[SchemaType]: |
301 | 307 | """Returns all schemas that match for a list of col names.""" |
302 | | - res = set() |
303 | | - for schema, schema_cols in SCHEMA_TYPE_TO_COLS.items(): |
| 308 | + matches = set() |
| 309 | + errors = {} |
| 310 | + |
| 311 | + for schema, expected_cols in SCHEMA_TYPE_TO_COLS.items(): |
304 | 312 | optional_cols = SCHEMA_TYPE_TO_OPTIONAL_COLS[schema] |
305 | | - if valid_subset_lists(schema_cols, col_names, optional_cols): |
306 | | - res.add(schema) |
307 | | - if not res: |
308 | | - # If it doesn't match any, it will match unknown. |
309 | | - res.add(SchemaType.UNKNOWN) |
310 | | - return res |
| 313 | + result = valid_subset_lists(expected_cols, col_names, optional_cols) |
| 314 | + |
| 315 | + if result.is_valid: |
| 316 | + matches.add(schema) |
| 317 | + else: |
| 318 | + errors[schema.name] = result |
| 319 | + |
| 320 | + if matches: |
| 321 | + return matches |
| 322 | + |
| 323 | + error_msgs = [] |
| 324 | + for schema_name, res in errors.items(): |
| 325 | + msg = f"\nSchema: {schema_name}" |
| 326 | + if res.unexpected_columns: |
| 327 | + msg += f"\n Unexpected columns: {res.unexpected_columns}" |
| 328 | + if res.missing_required_columns: |
| 329 | + msg += f"\n Missing required columns: {res.missing_required_columns}" |
| 330 | + error_msgs.append(msg) |
| 331 | + |
| 332 | + raise ValueError( |
| 333 | + "No valid schema matched. Details of mismatches:\n" + "\n".join(error_msgs) |
| 334 | + ) |
311 | 335 |
|
312 | 336 |
|
313 | 337 | def valid_subset_lists( |
314 | | - superset_list: list[str], subset_list: list[str], optional_list: list[str] |
315 | | -) -> bool: |
| 338 | + expected: list[str], actual: list[str], optional_list: list[str] |
| 339 | +) -> ColumnValidationResult: |
316 | 340 | """Checks if the actual list is a subset of or equivalent to the expected list. And if so, |
317 | 341 | whether the missing values are all present in the optional list. This method disregards order |
318 | 342 | but cares about duplicates.""" |
319 | 343 | # Checks if any value in the actual list is NOT present in the expected list. |
320 | | - if Counter(subset_list) - Counter(superset_list): |
321 | | - # This is not a valid state, users should not be passing in unrecognized columns. |
322 | | - return False |
323 | | - missing_vals = Counter(superset_list) - Counter(subset_list) |
324 | | - return not Counter(missing_vals) - Counter(optional_list) |
| 344 | + expected_counter = Counter(expected) |
| 345 | + actual_counter = Counter(actual) |
| 346 | + |
| 347 | + unexpected = list((actual_counter - expected_counter).elements()) |
| 348 | + |
| 349 | + # Columns expected but missing (excluding optional) |
| 350 | + missing_total = list((expected_counter - actual_counter).elements()) |
| 351 | + missing_required = [col for col in missing_total if col not in optional_list] |
| 352 | + |
| 353 | + is_valid = not unexpected and not missing_required |
| 354 | + |
| 355 | + return ColumnValidationResult( |
| 356 | + is_valid=is_valid, |
| 357 | + unexpected_columns=unexpected, |
| 358 | + missing_required_columns=missing_required, |
| 359 | + ) |
0 commit comments