|
14 | 14 | from matplotlib import pyplot |
15 | 15 | from moto import mock_s3 |
16 | 16 |
|
| 17 | +from gokart.conflict_prevention_lock.task_lock import make_task_lock_params |
17 | 18 | from gokart.file_processor import _ChunkedLargeFileReader, make_file_processor |
18 | 19 | from gokart.target import SingleFileTarget, make_model_target, make_target |
19 | 20 |
|
@@ -263,22 +264,26 @@ def test_typed_target(self): |
263 | 264 | test_case = pd.DataFrame(dict(a=[1, 2])) |
264 | 265 |
|
265 | 266 | with tempfile.TemporaryDirectory() as temp_dir: |
266 | | - _task_lock_params = None |
267 | 267 | file_path = os.path.join(temp_dir, 'test.csv') |
| 268 | + unique_id = 'test_unique_id' |
| 269 | + _task_lock_params = make_task_lock_params(file_path=file_path, unique_id=unique_id) |
268 | 270 | processor = make_file_processor(file_path, store_index_in_feather=False) |
269 | 271 | file_system_target = luigi.LocalTarget(file_path, format=processor.format()) |
270 | 272 | file_target = SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params, expected_dataframe_type=self.DummyDataFrameSchema) |
271 | 273 |
|
272 | 274 | file_target.dump(test_case) |
273 | 275 | dumped_data = file_target.load() |
274 | | - self.assertIsInstance(dumped_data, self.DummyDataFrameSchema) |
| 276 | + self.assertIsInstance(dumped_data, pa.typing.DataFrame) |
| 277 | + self.DummyDataFrameSchema.validate(dumped_data) |
| 278 | + |
275 | 279 |
|
276 | 280 | def test_invalid_typed_target(self): |
277 | 281 | test_case = pd.DataFrame(dict(a=['1', '2'])) |
278 | 282 |
|
279 | 283 | with tempfile.TemporaryDirectory() as temp_dir: |
280 | | - _task_lock_params = None |
281 | 284 | file_path = os.path.join(temp_dir, 'test.csv') |
| 285 | + unique_id = 'test_unique_id' |
| 286 | + _task_lock_params = make_task_lock_params(file_path=file_path, unique_id=unique_id) |
282 | 287 | processor = make_file_processor(file_path, store_index_in_feather=False) |
283 | 288 | file_system_target = luigi.LocalTarget(file_path, format=processor.format()) |
284 | 289 | file_target = SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params, expected_dataframe_type=self.DummyDataFrameSchema) |
|
0 commit comments