|
17 | 17 | with_pandas = False |
18 | 18 |
|
19 | 19 |
|
20 | | -def _to_key(x, use_tuple=False): |
21 | | - if x.values.size > 1 or use_tuple: |
22 | | - return tuple(x.values) |
23 | | - return x.item() |
| 20 | +def _mapping_uses_tuple_keys(mapping): |
| 21 | + return bool(mapping) and isinstance(next(iter(mapping)), tuple) |
| 22 | + |
| 23 | + |
| 24 | +def _row_to_key(row, force_tuple=False): |
| 25 | + if row.values.size > 1 or force_tuple: |
| 26 | + return tuple(row.values) |
| 27 | + return row.item() |
24 | 28 |
|
25 | 29 |
|
26 | 30 | class DataSaver(BaseLearner): |
@@ -111,10 +115,9 @@ def to_dataframe( # type: ignore[override] |
111 | 115 | **kwargs, |
112 | 116 | ) |
113 | 117 |
|
114 | | - # Detect if the learner uses tuple keys even for single inputs (e.g., LearnerND 1D) |
115 | | - use_tuple = self.extra_data and isinstance(next(iter(self.extra_data)), tuple) |
| 118 | + force_tuple = _mapping_uses_tuple_keys(self.extra_data) |
116 | 119 | df[extra_data_name] = [ |
117 | | - self.extra_data[_to_key(x, use_tuple=use_tuple)] |
| 120 | + self.extra_data[_row_to_key(x, force_tuple=force_tuple)] |
118 | 121 | for _, x in df[df.attrs["inputs"]].iterrows() |
119 | 122 | ] |
120 | 123 | return df |
@@ -151,12 +154,11 @@ def load_dataframe( # type: ignore[override] |
151 | 154 | function_prefix=function_prefix, |
152 | 155 | **kwargs, |
153 | 156 | ) |
154 | | - keys = df.attrs.get("inputs", list(input_names)) |
155 | | - # Detect if the learner uses tuple keys even for single inputs |
156 | | - use_tuple = self.data and isinstance(next(iter(self.data)), tuple) |
157 | | - for _, x in df[keys + [extra_data_name]].iterrows(): |
158 | | - key = _to_key(x.iloc[:-1], use_tuple=use_tuple) |
159 | | - self.extra_data[key] = x.iloc[-1] |
| 157 | + input_columns = df.attrs.get("inputs", list(input_names)) |
| 158 | + force_tuple = _mapping_uses_tuple_keys(self.learner.data) |
| 159 | + for _, row in df[input_columns + [extra_data_name]].iterrows(): |
| 160 | + key = _row_to_key(row.iloc[:-1], force_tuple=force_tuple) |
| 161 | + self.extra_data[key] = row.iloc[-1] |
160 | 162 |
|
161 | 163 | def _get_data(self) -> tuple[Any, OrderedDict[Any, Any]]: |
162 | 164 | return self.learner._get_data(), self.extra_data |
|
0 commit comments