Skip to content

Commit 62ab1d8

Browse files
committed
test: fix test
1 parent b0b707e commit 62ab1d8

2 files changed

Lines changed: 14 additions & 8 deletions

File tree

gokart/target.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,12 @@ def __init__(
8080
target: luigi.target.FileSystemTarget,
8181
processor: FileProcessor,
8282
task_lock_params: TaskLockParams,
83-
expected_dataframe_type: Optional[pa.DataFrameModel] = None,
83+
expected_dataframe_type: Optional[type[pa.DataFrameModel]] = None,
8484
) -> None:
8585
self._target = target
8686
self._processor = processor
8787
self._task_lock_params = task_lock_params
88+
self._expected_dataframe_type = expected_dataframe_type
8889

8990
def _exists(self) -> bool:
9091
return self._target.exists()
@@ -95,14 +96,14 @@ def _get_task_lock_params(self) -> TaskLockParams:
9596
def _load(self) -> Any:
9697
with self._target.open('r') as f:
9798
obj = self._processor.load(f)
98-
if self.expected_dataframe_type is not None:
99-
return self.expected_dataframe_type(obj)
99+
if self._expected_dataframe_type is not None:
100+
return pa.typing.DataFrame[self._expected_dataframe_type](obj)
100101

101102
return obj
102103

103104
def _dump(self, obj) -> None:
104-
if self.expected_dataframe_type is not None:
105-
self.expected_dataframe_type.validate(obj)
105+
if self._expected_dataframe_type is not None:
106+
self._expected_dataframe_type.validate(obj)
106107

107108
with self._target.open('w') as f:
108109
self._processor.dump(obj, f)

test/test_target.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from matplotlib import pyplot
1515
from moto import mock_s3
1616

17+
from gokart.conflict_prevention_lock.task_lock import make_task_lock_params
1718
from gokart.file_processor import _ChunkedLargeFileReader, make_file_processor
1819
from gokart.target import SingleFileTarget, make_model_target, make_target
1920

@@ -263,22 +264,26 @@ def test_typed_target(self):
263264
test_case = pd.DataFrame(dict(a=[1, 2]))
264265

265266
with tempfile.TemporaryDirectory() as temp_dir:
266-
_task_lock_params = None
267267
file_path = os.path.join(temp_dir, 'test.csv')
268+
unique_id = 'test_unique_id'
269+
_task_lock_params = make_task_lock_params(file_path=file_path, unique_id=unique_id)
268270
processor = make_file_processor(file_path, store_index_in_feather=False)
269271
file_system_target = luigi.LocalTarget(file_path, format=processor.format())
270272
file_target = SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params, expected_dataframe_type=self.DummyDataFrameSchema)
271273

272274
file_target.dump(test_case)
273275
dumped_data = file_target.load()
274-
self.assertIsInstance(dumped_data, self.DummyDataFrameSchema)
276+
self.assertIsInstance(dumped_data, pa.typing.DataFrame)
277+
self.DummyDataFrameSchema.validate(dumped_data)
278+
275279

276280
def test_invalid_typed_target(self):
277281
test_case = pd.DataFrame(dict(a=['1', '2']))
278282

279283
with tempfile.TemporaryDirectory() as temp_dir:
280-
_task_lock_params = None
281284
file_path = os.path.join(temp_dir, 'test.csv')
285+
unique_id = 'test_unique_id'
286+
_task_lock_params = make_task_lock_params(file_path=file_path, unique_id=unique_id)
282287
processor = make_file_processor(file_path, store_index_in_feather=False)
283288
file_system_target = luigi.LocalTarget(file_path, format=processor.format())
284289
file_target = SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params, expected_dataframe_type=self.DummyDataFrameSchema)

0 commit comments

Comments
 (0)