Skip to content

Commit efbafa9

Browse files
committed
refactor: simplify DataSaver key reconstruction
1 parent df5bbaf commit efbafa9

1 file changed

Lines changed: 15 additions & 13 deletions

File tree

adaptive/learner/data_saver.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@
1717
with_pandas = False
1818

1919

20-
def _to_key(x, use_tuple=False):
21-
if x.values.size > 1 or use_tuple:
22-
return tuple(x.values)
23-
return x.item()
20+
def _mapping_uses_tuple_keys(mapping):
21+
return bool(mapping) and isinstance(next(iter(mapping)), tuple)
22+
23+
24+
def _row_to_key(row, force_tuple=False):
25+
if row.values.size > 1 or force_tuple:
26+
return tuple(row.values)
27+
return row.item()
2428

2529

2630
class DataSaver(BaseLearner):
@@ -111,10 +115,9 @@ def to_dataframe( # type: ignore[override]
111115
**kwargs,
112116
)
113117

114-
# Detect if the learner uses tuple keys even for single inputs (e.g., LearnerND 1D)
115-
use_tuple = self.extra_data and isinstance(next(iter(self.extra_data)), tuple)
118+
force_tuple = _mapping_uses_tuple_keys(self.extra_data)
116119
df[extra_data_name] = [
117-
self.extra_data[_to_key(x, use_tuple=use_tuple)]
120+
self.extra_data[_row_to_key(x, force_tuple=force_tuple)]
118121
for _, x in df[df.attrs["inputs"]].iterrows()
119122
]
120123
return df
@@ -151,12 +154,11 @@ def load_dataframe( # type: ignore[override]
151154
function_prefix=function_prefix,
152155
**kwargs,
153156
)
154-
keys = df.attrs.get("inputs", list(input_names))
155-
# Detect if the learner uses tuple keys even for single inputs
156-
use_tuple = self.data and isinstance(next(iter(self.data)), tuple)
157-
for _, x in df[keys + [extra_data_name]].iterrows():
158-
key = _to_key(x.iloc[:-1], use_tuple=use_tuple)
159-
self.extra_data[key] = x.iloc[-1]
157+
input_columns = df.attrs.get("inputs", list(input_names))
158+
force_tuple = _mapping_uses_tuple_keys(self.learner.data)
159+
for _, row in df[input_columns + [extra_data_name]].iterrows():
160+
key = _row_to_key(row.iloc[:-1], force_tuple=force_tuple)
161+
self.extra_data[key] = row.iloc[-1]
160162

161163
def _get_data(self) -> tuple[Any, OrderedDict[Any, Any]]:
162164
return self.learner._get_data(), self.extra_data

0 commit comments

Comments
 (0)