Skip to content

Commit 1adfcb6

Browse files
alexfurmenkovRamilCDISCSFJohnson24
authored
691: Tests for CoW behavior in pandas (#1690)
* tests for CoW behavior in pandas * tested true CoW via shallow copy * added shallow copying for cached datasets * dask copy workaround * fix CoW tests and wrapper * added tests for cache methods. changed cache access to get() and get_dataset() methods * readme notice about CoW usage * fix filter_cache access to cache * edits * merge main * merge main --------- Co-authored-by: RamilCDISC <113539111+RamilCDISC@users.noreply.github.com> Co-authored-by: Samuel Johnson <96841389+SFJohnson24@users.noreply.github.com> Co-authored-by: Samuel Johnson <sfjohnson24@gmail.com>
1 parent ecb58e7 commit 1adfcb6

4 files changed

Lines changed: 255 additions & 14 deletions

File tree

cdisc_rules_engine/rules_engine.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import List, Union
33
from dateutil.parser._parser import ParserError
44
import traceback
5+
import pandas as pd
56

67
from business_rules import export_rule_data
78
from business_rules.engine import run
@@ -33,6 +34,7 @@
3334
DataServiceInterface,
3435
)
3536
from cdisc_rules_engine.models.actions import COREActions
37+
from cdisc_rules_engine.models.dataset import DaskDataset
3638
from cdisc_rules_engine.models.dataset.dataset_interface import DatasetInterface
3739
from cdisc_rules_engine.models.dataset_variable import DatasetVariable
3840
from cdisc_rules_engine.models.failed_validation_entity import FailedValidationEntity
@@ -59,6 +61,8 @@
5961
from cdisc_rules_engine.models.sdtm_dataset_metadata import SDTMDatasetMetadata
6062
from cdisc_rules_engine.enums.sensitivity import Sensitivity
6163

64+
pd.options.mode.copy_on_write = True
65+
6266

6367
class RulesEngine:
6468
def __init__(
@@ -375,9 +379,9 @@ def execute_rule(
375379
rule["conditions"], dataset.columns.to_list()
376380
)
377381
rule_copy["conditions"].set_conditions(updated_conditions)
378-
# Adding copy for now to avoid updating cached dataset
379-
dataset = deepcopy(dataset)
380382
# preprocess dataset
383+
if isinstance(dataset, DaskDataset):
384+
dataset = deepcopy(dataset)
381385
dataset_preprocessor = DatasetPreprocessor(
382386
dataset, dataset_metadata, self.data_service, self.cache
383387
)

cdisc_rules_engine/services/cache/in_memory_cache_service.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from cdisc_rules_engine.interfaces import (
66
CacheServiceInterface,
77
)
8-
from cdisc_rules_engine.models.dataset import DatasetInterface
8+
from cdisc_rules_engine.models.dataset import DatasetInterface, PandasDataset
99
from cachetools import LRUCache
1010
import psutil
1111
from multiprocessing import Lock
@@ -62,11 +62,16 @@ def add(self, cache_key, data):
6262
)
6363

6464
def add_dataset(self, cache_key, data):
65+
if get_data_size(data) > self.max_dataset_cache_size:
66+
return
6567
with self.dataset_cache_lock:
6668
self.dataset_cache[cache_key] = data
6769

6870
def get_dataset(self, cache_key):
69-
return self.dataset_cache.get(cache_key, None)
71+
cached = self.dataset_cache.get(cache_key)
72+
if type(cached) is PandasDataset:
73+
return PandasDataset(cached.data.copy(deep=False))
74+
return cached
7075

7176
def add_batch(
7277
self,
@@ -82,27 +87,32 @@ def add_batch(
8287
self.add(prefix + cache_key, item)
8388

8489
def get(self, cache_key):
85-
return self.cache.get(cache_key, None)
90+
cached = self.cache.get(cache_key)
91+
if type(cached) is PandasDataset:
92+
return PandasDataset(cached.data.copy(deep=False))
93+
return cached
8694

8795
def get_all(self, cache_keys: List[str]):
88-
return [self.cache.get(key) for key in cache_keys]
96+
return [self.get(key) for key in cache_keys]
8997

9098
def get_all_by_prefix(self, prefix):
91-
items = []
92-
for key in self.cache:
93-
if key.startswith(prefix):
94-
items.append(self.cache[key])
95-
return items
99+
with self.cache_lock:
100+
keys = [key for key in self.cache.keys() if key.startswith(prefix)]
101+
return [self.get(key) for key in keys]
96102

97103
def dataset_keys(self):
98104
return self.dataset_cache.keys()
99105

100106
def filter_cache(self, prefix: str) -> dict:
101-
return {k: self.cache[k] for k in self.cache.keys() if k.startswith(prefix)}
107+
with self.cache_lock:
108+
keys = [k for k in self.cache.keys() if k.startswith(prefix)]
109+
return {k: self.get(k) for k in keys}
102110

103111
def get_by_regex(self, regex: str) -> dict:
104112
regex = regex.replace("*", ".*")
105-
return {k: self.cache[k] for k in self.cache.keys() if re.search(regex, k)}
113+
with self.cache_lock:
114+
keys = [k for k in self.cache.keys() if re.search(regex, k)]
115+
return {k: self.get(k) for k in keys}
106116

107117
def exists(self, cache_key):
108118
return cache_key in self.cache
@@ -119,7 +129,7 @@ def clear_all(self, prefix: str = None):
119129
for key in keys_to_remove:
120130
self.clear(key)
121131
else:
122-
self.cache = LRUCache(maxsize=self.max_size, getsizeof=asizeof.asizeof)
132+
self.cache = LRUCache(maxsize=self.max_size, getsizeof=cust_asizeof)
123133

124134
def add_all(self, data: dict):
125135
for key, val in data.items():

docs/cli-reference.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88

99
Run conformance validation against a CDISC standard.
1010

11+
Validation runs with pandas Copy-on-Write (CoW) enabled globally whenever the rules engine is used.
12+
**Note**: In pandas 2.x, CoW is an opt-in feature; in pandas 3.x, it is enabled by default.
13+
1114
```bash
1215
python core.py validate --help
1316
```
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
5+
from cdisc_rules_engine.models.dataset.pandas_dataset import PandasDataset
6+
from cdisc_rules_engine.services.cache.in_memory_cache_service import (
7+
InMemoryCacheService,
8+
)
9+
10+
11+
@pytest.fixture(autouse=True)
12+
def reset_singleton():
13+
InMemoryCacheService._instance = None
14+
yield
15+
InMemoryCacheService._instance = None
16+
17+
18+
@pytest.fixture
19+
def cache():
20+
return InMemoryCacheService()
21+
22+
23+
@pytest.fixture
24+
def sample_dataset():
25+
return PandasDataset(pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]}))
26+
27+
28+
class TestGet:
29+
def test_returns_new_wrapper_not_cached_object(self, cache, sample_dataset):
30+
cache.add("x", sample_dataset)
31+
result = cache.get("x")
32+
assert result is not cache.cache["x"]
33+
assert result.data is not cache.cache["x"].data
34+
35+
def test_cow_does_not_modify_cache_on_write(self, cache, sample_dataset):
36+
pd.options.mode.copy_on_write = True
37+
cache.add("x", sample_dataset)
38+
retrieved = cache.get("x")
39+
retrieved.data.loc[0, "A"] = 999
40+
assert cache.cache["x"].data.loc[0, "A"] == 1
41+
42+
def test_shares_memory_before_write(self, cache, sample_dataset):
43+
pd.options.mode.copy_on_write = True
44+
cache.add("x", sample_dataset)
45+
retrieved = cache.get("x")
46+
assert np.shares_memory(retrieved.data["A"], cache.cache["x"].data["A"])
47+
48+
def test_add_rows_does_not_affect_cache(self, cache, sample_dataset):
49+
pd.options.mode.copy_on_write = True
50+
cache.add("x", sample_dataset)
51+
retrieved = cache.get("x")
52+
retrieved.data = pd.concat(
53+
[retrieved.data, pd.DataFrame({"A": [999], "B": [999]})],
54+
ignore_index=True,
55+
)
56+
assert len(cache.cache["x"].data) == 3
57+
assert len(retrieved.data) == 4
58+
59+
def test_drop_rows_does_not_affect_cache(self, cache, sample_dataset):
60+
pd.options.mode.copy_on_write = True
61+
cache.add("x", sample_dataset)
62+
retrieved = cache.get("x")
63+
retrieved.data = retrieved.data.drop(index=0).reset_index(drop=True)
64+
assert len(cache.cache["x"].data) == 3
65+
assert len(retrieved.data) == 2
66+
67+
def test_filter_rows_does_not_affect_cache(self, cache, sample_dataset):
68+
pd.options.mode.copy_on_write = True
69+
cache.add("x", sample_dataset)
70+
retrieved = cache.get("x")
71+
retrieved.data = retrieved.data[retrieved.data["A"] > 1].reset_index(drop=True)
72+
assert len(cache.cache["x"].data) == 3
73+
assert cache.cache["x"].data["A"].tolist() == [1, 2, 3]
74+
75+
def test_multiple_gets_are_independent(self, cache, sample_dataset):
76+
pd.options.mode.copy_on_write = True
77+
cache.add("x", sample_dataset)
78+
first = cache.get("x")
79+
second = cache.get("x")
80+
first.data = first.data.drop(index=0).reset_index(drop=True)
81+
assert len(second.data) == 3
82+
assert len(cache.cache["x"].data) == 3
83+
84+
def test_non_dataset_returns_as_is(self, cache):
85+
cache.add("key", {"some": "dict"})
86+
assert cache.get("key") == {"some": "dict"}
87+
88+
def test_object_dtype_nested_mutation_affects_cache(self, cache):
89+
"""CoW can't protect in nested mutations"""
90+
df = pd.DataFrame({"A": [[1], [2], [3]]})
91+
cache.add("x", PandasDataset(df))
92+
retrieved = cache.get("x")
93+
retrieved.data.loc[0, "A"].append(999)
94+
assert cache.cache["x"].data.loc[0, "A"] == [1, 999]
95+
96+
97+
class TestGetDataset:
98+
def test_returns_new_wrapper_not_cached_object(self, cache, sample_dataset):
99+
cache.add_dataset("x", sample_dataset)
100+
result = cache.get_dataset("x")
101+
assert result is not cache.dataset_cache["x"]
102+
assert result.data is not cache.dataset_cache["x"].data
103+
104+
def test_cow_does_not_modify_cache_on_write(self, cache, sample_dataset):
105+
pd.options.mode.copy_on_write = True
106+
cache.add_dataset("x", sample_dataset)
107+
retrieved = cache.get_dataset("x")
108+
retrieved.data.loc[0, "A"] = 999
109+
assert cache.dataset_cache["x"].data.loc[0, "A"] == 1
110+
111+
def test_add_rows_does_not_affect_cache(self, cache, sample_dataset):
112+
pd.options.mode.copy_on_write = True
113+
cache.add_dataset("x", sample_dataset)
114+
retrieved = cache.get_dataset("x")
115+
retrieved.data = pd.concat(
116+
[retrieved.data, pd.DataFrame({"A": [999], "B": [999]})],
117+
ignore_index=True,
118+
)
119+
assert len(cache.dataset_cache["x"].data) == 3
120+
assert len(retrieved.data) == 4
121+
122+
def test_drop_rows_does_not_affect_cache(self, cache, sample_dataset):
123+
pd.options.mode.copy_on_write = True
124+
cache.add_dataset("x", sample_dataset)
125+
retrieved = cache.get_dataset("x")
126+
retrieved.data = retrieved.data.drop(index=0).reset_index(drop=True)
127+
assert len(cache.dataset_cache["x"].data) == 3
128+
assert len(retrieved.data) == 2
129+
130+
131+
class TestGetAll:
132+
def test_returns_new_wrappers(self, cache, sample_dataset):
133+
cache.add("x", sample_dataset)
134+
cache.add("y", sample_dataset)
135+
results = cache.get_all(["x", "y"])
136+
assert all(r is not cache.cache["x"] for r in results)
137+
assert all(r.data is not cache.cache["x"].data for r in results)
138+
139+
def test_results_are_independent(self, cache, sample_dataset):
140+
pd.options.mode.copy_on_write = True
141+
cache.add("x", sample_dataset)
142+
cache.add("y", sample_dataset)
143+
first, second = cache.get_all(["x", "y"])
144+
first.data = first.data.drop(index=0).reset_index(drop=True)
145+
assert len(second.data) == 3
146+
assert len(cache.cache["x"].data) == 3
147+
148+
def test_cow_does_not_modify_cache_on_write(self, cache, sample_dataset):
149+
pd.options.mode.copy_on_write = True
150+
cache.add("x", sample_dataset)
151+
results = cache.get_all(["x"])
152+
results[0].data.loc[0, "A"] = 999
153+
assert cache.cache["x"].data.loc[0, "A"] == 1
154+
155+
def test_missing_key_returns_none(self, cache):
156+
assert cache.get_all(["missing"]) == [None]
157+
158+
159+
class TestGetAllByPrefix:
160+
def test_returns_only_matching_keys(self, cache, sample_dataset):
161+
cache.add("ds/ae", sample_dataset)
162+
cache.add("ds/lb", sample_dataset)
163+
cache.add("other/ae", sample_dataset)
164+
results = cache.get_all_by_prefix("ds/")
165+
assert len(results) == 2
166+
167+
def test_returns_new_wrappers(self, cache, sample_dataset):
168+
cache.add("ds/ae", sample_dataset)
169+
results = cache.get_all_by_prefix("ds/")
170+
assert results[0] is not cache.cache["ds/ae"]
171+
assert results[0].data is not cache.cache["ds/ae"].data
172+
173+
def test_cow_does_not_modify_cache_on_write(self, cache, sample_dataset):
174+
pd.options.mode.copy_on_write = True
175+
cache.add("ds/ae", sample_dataset)
176+
results = cache.get_all_by_prefix("ds/")
177+
results[0].data.loc[0, "A"] = 999
178+
assert cache.cache["ds/ae"].data.loc[0, "A"] == 1
179+
180+
def test_drop_rows_does_not_affect_cache(self, cache, sample_dataset):
181+
pd.options.mode.copy_on_write = True
182+
cache.add("ds/ae", sample_dataset)
183+
results = cache.get_all_by_prefix("ds/")
184+
results[0].data = results[0].data.drop(index=0).reset_index(drop=True)
185+
assert len(cache.cache["ds/ae"].data) == 3
186+
187+
def test_no_match_returns_empty(self, cache, sample_dataset):
188+
cache.add("ds/ae", sample_dataset)
189+
assert cache.get_all_by_prefix("other/") == []
190+
191+
192+
class TestGetByRegex:
193+
def test_returns_matching_keys(self, cache, sample_dataset):
194+
cache.add("ae_data", sample_dataset)
195+
cache.add("lb_data", sample_dataset)
196+
cache.add("ae_meta", sample_dataset)
197+
result = cache.get_by_regex("ae_*")
198+
assert set(result.keys()) == {"ae_data", "ae_meta"}
199+
200+
def test_returns_new_wrappers(self, cache, sample_dataset):
201+
cache.add("ae_data", sample_dataset)
202+
result = cache.get_by_regex("ae_*")
203+
assert result["ae_data"] is not cache.cache["ae_data"]
204+
assert result["ae_data"].data is not cache.cache["ae_data"].data
205+
206+
def test_cow_does_not_modify_cache_on_write(self, cache, sample_dataset):
207+
pd.options.mode.copy_on_write = True
208+
cache.add("ae_data", sample_dataset)
209+
result = cache.get_by_regex("ae_*")
210+
result["ae_data"].data.loc[0, "A"] = 999
211+
assert cache.cache["ae_data"].data.loc[0, "A"] == 1
212+
213+
def test_drop_rows_does_not_affect_cache(self, cache, sample_dataset):
214+
pd.options.mode.copy_on_write = True
215+
cache.add("ae_data", sample_dataset)
216+
result = cache.get_by_regex("ae_*")
217+
result["ae_data"].data = (
218+
result["ae_data"].data.drop(index=0).reset_index(drop=True)
219+
)
220+
assert len(cache.cache["ae_data"].data) == 3
221+
222+
def test_no_match_returns_empty_dict(self, cache, sample_dataset):
223+
cache.add("ae_data", sample_dataset)
224+
assert cache.get_by_regex("lb_*") == {}

0 commit comments

Comments
 (0)