Skip to content

Commit 177e72c

Browse files
committed
feat: add the new TargetOnKart entrypoint make_cache_target
feat: add the new parameter `cache_in_memory_by_default` to switch default Target style: update the variable name from `target_key` to `data_key` for code consistency test: add tests for `TaskOnKart`s with the `cache_in_memory` parameter
1 parent b9dbb4d commit 177e72c

4 files changed

Lines changed: 160 additions & 6 deletions

File tree

gokart/in_memory/target.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from datetime import datetime
22
from logging import warning
3-
from typing import Any
3+
from typing import Any, Optional
44

55
from gokart.in_memory.repository import InMemoryCacheRepository
66
from gokart.target import TargetOnKart, TaskLockParams
@@ -42,5 +42,12 @@ def _path(self) -> str:
4242
return self._data_key
4343

4444

45-
def make_inmemory_target(target_key: str, task_lock_params: TaskLockParams):
46-
return InMemoryTarget(target_key, task_lock_params)
45+
def _make_data_key(data_key: str, unique_id: Optional[str] = None):
46+
if not unique_id:
47+
return data_key
48+
return data_key + '_' + unique_id
49+
50+
51+
def make_inmemory_target(data_key: str, task_lock_params: TaskLockParams, unique_id: Optional[str] = None):
52+
_data_key = _make_data_key(data_key, unique_id)
53+
return InMemoryTarget(_data_key, task_lock_params)

gokart/task.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020

2121
import gokart
2222
import gokart.target
23-
from gokart.conflict_prevention_lock.task_lock import make_task_lock_params, make_task_lock_params_for_run
23+
from gokart.conflict_prevention_lock.task_lock import TaskLockParams, make_task_lock_params, make_task_lock_params_for_run
2424
from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_run_with_lock
2525
from gokart.file_processor import FileProcessor
26+
from gokart.in_memory.target import make_inmemory_target
2627
from gokart.pandas_type_config import PandasTypeConfigMap
2728
from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter
2829
from gokart.target import TargetOnKart
@@ -105,6 +106,9 @@ class TaskOnKart(luigi.Task, Generic[T]):
105106
default=True, description='Check if output file exists at run. If exists, run() will be skipped.', significant=False
106107
)
107108
should_lock_run: bool = ExplicitBoolParameter(default=False, significant=False, description='Whether to use redis lock or not at task run.')
109+
cache_in_memory_by_default: bool = ExplicitBoolParameter(
110+
default=False, significant=False, description='If `True`, output is stored on a memory instead of files unless specified.'
111+
)
108112

109113
@property
110114
def priority(self):
@@ -134,11 +138,13 @@ def __init__(self, *args, **kwargs):
134138
task_lock_params = make_task_lock_params_for_run(task_self=self)
135139
self.run = wrap_run_with_lock(run_func=self.run, task_lock_params=task_lock_params) # type: ignore
136140

141+
self.make_default_target = self.make_target if not self.cache_in_memory_by_default else self.make_cache_target
142+
137143
def input(self) -> FlattenableItems[TargetOnKart]:
138144
return super().input()
139145

140146
def output(self) -> FlattenableItems[TargetOnKart]:
141-
return self.make_target()
147+
return self.make_default_target()
142148

143149
def requires(self) -> FlattenableItems['TaskOnKart']:
144150
tasks = self.make_task_instance_dictionary()
@@ -210,11 +216,19 @@ def clone(self, cls=None, **kwargs):
210216
return cls(**new_k)
211217

212218
def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, processor: Optional[FileProcessor] = None) -> TargetOnKart:
219+
# if self.cache_in_memory and processor:
220+
# logger.warning(f"processor {type(processor)} never used.")
213221
formatted_relative_file_path = (
214222
relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.pkl')
215223
)
216224
file_path = os.path.join(self.workspace_directory, formatted_relative_file_path)
217225
unique_id = self.make_unique_id() if use_unique_id else None
226+
# if self.cache_in_memory:
227+
# from gokart.target import _make_file_path
228+
# return make_inmemory_target(
229+
# target_key=_make_file_path(file_path, unique_id),
230+
# task_lock_params=TaskLockParams(None, None, None, "hoge", False, False, 100)
231+
# )
218232

219233
task_lock_params = make_task_lock_params(
220234
file_path=file_path,
@@ -229,6 +243,21 @@ def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: b
229243
file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather
230244
)
231245

246+
def make_cache_target(self, data_key: Optional[str] = None, use_unique_id: bool = True):
247+
_data_key = data_key if data_key else os.path.join(self.__module__.replace('.', '/'), type(self).__name__)
248+
unique_id = self.make_unique_id() if use_unique_id else None
249+
# TODO: combine with redis
250+
task_lock_params = TaskLockParams(
251+
redis_host=None,
252+
redis_port=None,
253+
redis_timeout=None,
254+
redis_key='redis_key',
255+
should_task_lock=False,
256+
raise_task_lock_exception_on_collision=False,
257+
lock_extend_seconds=-1,
258+
)
259+
return make_inmemory_target(_data_key, task_lock_params, unique_id)
260+
232261
def make_large_data_frame_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart:
233262
formatted_relative_file_path = (
234263
relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.zip')

test/in_memory/test_in_memory_target.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def task_lock_params(self):
2222

2323
@pytest.fixture
2424
def target(self, task_lock_params: TaskLockParams):
25-
return make_inmemory_target(target_key='dummy_key', task_lock_params=task_lock_params)
25+
return make_inmemory_target(data_key='dummy_key', task_lock_params=task_lock_params)
2626

2727
@pytest.fixture(autouse=True)
2828
def clear_repo(self):
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from typing import Optional, Type, Union
2+
3+
import luigi
4+
import pytest
5+
6+
import gokart
7+
from gokart.in_memory import InMemoryCacheRepository, InMemoryTarget
8+
from gokart.target import SingleFileTarget
9+
10+
11+
class DummyTask(gokart.TaskOnKart):
12+
task_namespace = __name__
13+
param: str = luigi.Parameter()
14+
15+
def run(self):
16+
self.dump(self.param)
17+
18+
19+
class DummyTaskWithDependencies(gokart.TaskOnKart):
20+
task_namespace = __name__
21+
task: list[gokart.TaskOnKart[str]] = gokart.ListTaskInstanceParameter()
22+
23+
def run(self):
24+
result = ','.join(self.load())
25+
self.dump(result)
26+
27+
28+
class DumpIntTask(gokart.TaskOnKart[int]):
29+
task_namespace = __name__
30+
value: int = luigi.IntParameter()
31+
32+
def run(self):
33+
self.dump(self.value)
34+
35+
36+
class AddTask(gokart.TaskOnKart[Union[int, float]]):
37+
a: gokart.TaskOnKart[int] = gokart.TaskInstanceParameter()
38+
b: gokart.TaskOnKart[int] = gokart.TaskInstanceParameter()
39+
40+
def requires(self):
41+
return dict(a=self.a, b=self.b)
42+
43+
def run(self):
44+
a = self.load(self.a)
45+
b = self.load(self.b)
46+
self.dump(a + b)
47+
48+
49+
class TestTaskOnKartWithCache:
50+
@pytest.fixture(autouse=True)
51+
def clear_repository(slef):
52+
InMemoryCacheRepository().clear()
53+
54+
@pytest.mark.parametrize('data_key', ['sample_key', None])
55+
@pytest.mark.parametrize('use_unique_id', [True, False])
56+
def test_key_identity(self, data_key: Optional[str], use_unique_id: bool):
57+
task = DummyTask(param='param')
58+
ext = '.pkl'
59+
relative_file_path = data_key + ext if data_key else None
60+
target = task.make_target(relative_file_path=relative_file_path, use_unique_id=use_unique_id)
61+
cached_target = task.make_cache_target(data_key=data_key, use_unique_id=use_unique_id)
62+
63+
target_path = target.path().removeprefix(task.workspace_directory).removesuffix(ext).strip('/')
64+
assert cached_target.path() == target_path
65+
66+
def test_make_cached_target(self):
67+
task = DummyTask(param='param')
68+
target = task.make_cache_target()
69+
assert isinstance(target, InMemoryTarget)
70+
71+
@pytest.mark.parametrize(['cache_in_memory_by_default', 'target_type'], [[True, InMemoryTarget], [False, SingleFileTarget]])
72+
def test_make_default_target(self, cache_in_memory_by_default: bool, target_type: Type[gokart.TaskOnKart]):
73+
task = DummyTask(param='param', cache_in_memory_by_default=cache_in_memory_by_default)
74+
target = task.output()
75+
assert isinstance(target, target_type)
76+
77+
def test_complete_with_cache_in_memory_flag(self, tmpdir):
78+
task = DummyTask(param='param', cache_in_memory_by_default=True, workspace_directory=tmpdir)
79+
assert not task.complete()
80+
file_target = task.make_target()
81+
file_target.dump('data')
82+
assert not task.complete()
83+
cache_target = task.make_cache_target()
84+
cache_target.dump('data')
85+
assert task.complete()
86+
87+
def test_complete_without_cache_in_memory_flag(self, tmpdir):
88+
task = DummyTask(param='param', workspace_directory=tmpdir)
89+
assert not task.complete()
90+
cache_target = task.make_cache_target()
91+
cache_target.dump('data')
92+
assert not task.complete()
93+
file_target = task.make_target()
94+
file_target.dump('data')
95+
assert task.complete()
96+
97+
def test_dump_with_cache_in_memory_flag(self, tmpdir):
98+
task = DummyTask(param='param', cache_in_memory_by_default=True, workspace_directory=tmpdir)
99+
file_target = task.make_target()
100+
cache_target = task.make_cache_target()
101+
task.dump('data')
102+
assert not file_target.exists()
103+
assert cache_target.exists()
104+
105+
def test_dump_without_cache_in_memory_flag(self, tmpdir):
106+
task = DummyTask(param='param', workspace_directory=tmpdir)
107+
file_target = task.make_target()
108+
cache_target = task.make_cache_target()
109+
task.dump('data')
110+
assert file_target.exists()
111+
assert not cache_target.exists()
112+
113+
def test_gokart_build(self):
114+
task = AddTask(
115+
a=DumpIntTask(value=2, cache_in_memory_by_default=True), b=DumpIntTask(value=3, cache_in_memory_by_default=True), cache_in_memory_by_default=True
116+
)
117+
output = gokart.build(task, reset_register=False)
118+
assert output == 5

0 commit comments

Comments
 (0)