|
45 | 45 | LOGGER = logging.getLogger(__name__) |
46 | 46 |
|
47 | 47 | ATTRS_TO_CHECK = { |
48 | | - "event_name": (list, (str, np.str_, int, np.integer)), # int for backward compatibility |
| 48 | + "event_name": (list, (str, np.str_)), # int for backward compatibility |
49 | 49 | "event_id": (np.ndarray, None), |
50 | 50 | "frequency": (np.ndarray, float), |
51 | 51 | "frequency_unit": (str, None), |
52 | | - "date": (np.ndarray, (int, np.integer)), |
| 52 | + "date": (np.ndarray, (int, np.int64)), |
53 | 53 | "orig": (np.ndarray, (bool, np.bool_)), |
54 | 54 | "unit": (str, None), # For backward compatibility. Replaced by units. |
55 | 55 | "units": (str, None), |
@@ -706,11 +706,13 @@ def _check_and_cast_elements( |
706 | 706 | # Perform type checking and casting of elements |
707 | 707 | if isinstance(attr_value, (list, np.ndarray)): |
708 | 708 | if not all(isinstance(val, expected_dtype) for val in attr_value): |
| 709 | + provided_types = set(type(val) for val in attr_value) |
709 | 710 | warnings.warn( |
710 | | - f"Not all values are of type {expected_dtype}. Casting values.", |
| 711 | + f"Not all values are type {expected_dtype}. Provided type(s): {provided_types}. Casting values.", |
711 | 712 | UserWarning, |
712 | 713 | ) |
713 | | - casted_values = [expected_dtype(val) for val in attr_value] |
| 714 | + cast_dtype = expected_dtype if not isinstance(expected_dtype, tuple) else expected_dtype[0] |
| 715 | + casted_values = [cast_dtype(val) for val in attr_value] |
714 | 716 | # Return the casted values in the same container type |
715 | 717 | if container_type is list: |
716 | 718 | return casted_values |
|
0 commit comments