Skip to content

Commit f37867f

Browse files
authored
Merge pull request #140 from PythonPredictions/fix/mutable_train_data_in_fit_transform
Fix/mutable train data in fit transform
2 parents e04ba90 + 1ffe9ed commit f37867f

2 files changed

Lines changed: 35 additions & 0 deletions

File tree

cobra/preprocessing/preprocessor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,9 @@ def transform(self, data: pd.DataFrame, continuous_vars: list,
293293

294294
start = time.time()
295295

296+
# Ensure to operate on separate copy of data
297+
data = data.copy()
298+
296299
if not self._is_fitted:
297300
msg = ("This {} instance is not fitted yet. Call 'fit' with "
298301
"appropriate arguments before using this method.")

tests/preprocessing/test_preprocessor.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11

22
from contextlib import contextmanager
33
from typing import Any
4+
from unittest.mock import MagicMock
45
import pytest
56
import numpy as np
67
import pandas as pd
8+
from pytest_mock import MockerFixture
79

810
from cobra.preprocessing.preprocessor import PreProcessor
911

@@ -146,3 +148,33 @@ def test_get_variable_list(self, continuous_vars: list,
146148
discrete_vars)
147149

148150
assert actual == expected
151+
152+
@staticmethod
153+
def mock_transform(df: pd.DataFrame, args):
154+
"""Mock the transform method."""
155+
df["new_column"] = "Hello World"
156+
return df
157+
158+
def test_mutable_train_data_fit_transform(self, mocker: MockerFixture):
159+
"""Test if the train_data input is not changed when performing fit_transform."""
160+
train_data = pd.DataFrame([[1, "2", 3], [10, "20", 30], [100, "200", 300]], columns=["foo", "bar", "baz"])
161+
preprocessor = PreProcessor.from_params(
162+
model_type="classification",
163+
n_bins=10,
164+
weight= 0.8
165+
)
166+
preprocessor._categorical_data_processor = MagicMock()
167+
preprocessor._categorical_data_processor.transform = self.mock_transform
168+
preprocessor._discretizer = MagicMock()
169+
preprocessor._discretizer.transform = self.mock_transform
170+
preprocessor._target_encoder = MagicMock()
171+
preprocessor._target_encoder.transform = self.mock_transform
172+
173+
result = preprocessor.fit_transform(
174+
train_data,
175+
continuous_vars=["foo"],
176+
discrete_vars=["bar"],
177+
target_column_name=["baz"]
178+
)
179+
assert "new_column" not in train_data.columns
180+
assert "new_column" in result.columns

0 commit comments

Comments
 (0)