Skip to content

Commit 764e169

Browse files
Carl Hvarfnerfacebook-github-bot
authored andcommitted
Use get_heterogeneous_feature_mapping in LFI argparse dispatcher
Summary: Refactors the learned imputation argument dispatcher to delegate feature index computation to the dataset's built-in mapping utility. This eliminates duplicated feature-ordering logic and ensures consistency with the canonical ordering convention. Differential Revision: D102197138
1 parent 970c571 commit 764e169

2 files changed

Lines changed: 5 additions & 35 deletions

File tree

ax/generators/torch/botorch_modular/input_constructors/input_transforms.py

Lines changed: 4 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -357,40 +357,12 @@ def _input_transform_argparse_learned_feature_imputation(
357357
)
358358
input_transform_options = input_transform_options or {}
359359

360-
# Order datasets: target first, then remaining (same as HeterogeneousMTGP).
361-
child_datasets = dataset.datasets.copy()
362-
target_dataset = child_datasets.pop(dataset.target_outcome_name)
363-
all_datasets = [target_dataset] + list(child_datasets.values())
364-
365-
# The feature_names[:task_feature_index] slice only works when the task
366-
# column is the last column (index == -1). Guard against other positions
367-
# the same way ImputedMultiTaskGP.construct_inputs does.
360+
# Delegate feature ordering and index computation to MultiTaskDataset.
361+
all_datasets, feature_indices_list, d = dataset.get_heterogeneous_feature_mapping()
362+
feature_indices = dict(enumerate(feature_indices_list))
368363
task_feature_index = (
369-
dataset.task_feature_index if (dataset.task_feature_index is not None) else -1
364+
dataset.task_feature_index if dataset.task_feature_index is not None else -1
370365
)
371-
if task_feature_index != -1:
372-
raise NotImplementedError(
373-
"LearnedFeatureImputation argparse only supports "
374-
"task_feature_index == -1. Got "
375-
f"task_feature_index={task_feature_index}."
376-
)
377-
378-
# Use target's feature order as canonical (NO alphabetical sort).
379-
# Source-only features are appended at the end.
380-
all_features: list[str] = list(target_dataset.feature_names[:task_feature_index])
381-
for ds in all_datasets[1:]:
382-
for fn in ds.feature_names[:task_feature_index]:
383-
if fn not in all_features:
384-
all_features.append(fn)
385-
d = len(all_features)
386-
387-
# Map each task's features to indices in the global feature space.
388-
feature_indices = {
389-
task_idx: [
390-
all_features.index(fn) for fn in ds.feature_names[:task_feature_index]
391-
]
392-
for task_idx, ds in enumerate(all_datasets)
393-
}
394366

395367
dtype = torch_dtype or torch.float64
396368
# Constrain imputation values to [0, 1] since the preceding Normalize

ax/generators/torch/tests/test_input_transform_argparse.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -475,9 +475,7 @@ def test_argparse_learned_feature_imputation(self) -> None:
475475
target_outcome_name="y0",
476476
task_feature_index=0,
477477
)
478-
with self.assertRaisesRegex(
479-
NotImplementedError, "task_feature_index == -1"
480-
):
478+
with self.assertRaisesRegex(NotImplementedError, "task_feature_index.*-1"):
481479
input_transform_argparse(
482480
LearnedFeatureImputation,
483481
dataset=bad_ds,

0 commit comments

Comments
 (0)