Skip to content

Commit ed44d5d

Browse files
Copilot, xadupre, and sdpython
authored
Fix four bugs in plot_template_data.py causing compute_oracle and pipeline to fail (#90)
* Initial plan
* Fix multiple bugs in plot_template_data.py

Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>

--------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
Co-authored-by: xavier dupré <sdpython@users.noreply.github.com>
1 parent 355cf1f commit ed44d5d

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

_doc/examples/ml/plot_template_data.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,10 @@ def select_variables_and_clean(df):
5252
columns = set(df.columns)
5353
assert set(keys) & set(columns) == set(
5454
keys
55-
), f"Missing columns {set(keys) - set(columns)} in {sorted(df.columns)}"
56-
groups = df[[*keys, cible]].groupby(keys).count()
57-
filtered = groups[groups[cible] > 1].reset_index(drop=False)
58-
59-
mask = filtered.duplicated(subset=keys, keep=False)
60-
return filtered[~mask][[*keys, cible]], cible
55+
), f"Missing columns {set(keys) - set(keys) & set(columns)} in {sorted(df.columns)}"
56+
subset = df[[*keys, cible]]
57+
mask = subset.duplicated(subset=keys, keep=False)
58+
return subset[~mask].reset_index(drop=True), cible
6159

6260

6361
def compute_oracle(table, cible):
@@ -72,6 +70,7 @@ def compute_oracle(table, cible):
7270
columns="Session",
7371
values=cible,
7472
)
73+
.dropna(axis=0)
7574
.sort_index()
7675
)
7776
# Keep only rows where both 2024 and 2025 have non-missing values
@@ -99,9 +98,7 @@ def split_train_test(table, cible):
9998

10099
def make_pipeline(table, cible):
101100
vars = [c for c in table.columns if c != cible]
102-
# Candidate numeric feature; include it only if it exists in the table to avoid KeyError.
103-
numeric_feature = "Capacité de l’établissement par formation"
104-
num_cols = [numeric_feature] if numeric_feature in table.columns else []
101+
num_cols = ["Capacité de l’établissement par formation"]
105102
cat_cols = [c for c in vars if c not in num_cols]
106103

107104
transformers = []

0 commit comments

Comments
 (0)