Skip to content

Commit 7f909bd

Browse files
committed
Add NumPy hash coverage and benchmark against xxhash
1 parent 7523386 commit 7f909bd

File tree

3 files changed

+237
-6
lines changed

3 files changed

+237
-6
lines changed

scripts/benchmark_numpy_hash.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
"""Benchmark default Cachier hashing against xxhash for large NumPy arrays."""
2+
3+
from __future__ import annotations
4+
5+
import argparse
6+
import pickle
7+
import statistics
8+
import time
9+
from typing import Any, Callable, Dict, List
10+
11+
import numpy as np
12+
13+
from cachier.config import _default_hash_func
14+
15+
16+
def _xxhash_numpy_hash(args: tuple[Any, ...], kwds: dict[str, Any]) -> str:
17+
"""Hash call arguments with xxhash, optimized for NumPy arrays.
18+
19+
Parameters
20+
----------
21+
args : tuple[Any, ...]
22+
Positional arguments.
23+
kwds : dict[str, Any]
24+
Keyword arguments.
25+
26+
Returns
27+
-------
28+
str
29+
xxhash hex digest.
30+
31+
"""
32+
import xxhash
33+
34+
hasher = xxhash.xxh64()
35+
hasher.update(b"args")
36+
for value in args:
37+
if isinstance(value, np.ndarray):
38+
hasher.update(value.dtype.str.encode("utf-8"))
39+
hasher.update(str(value.shape).encode("utf-8"))
40+
hasher.update(value.tobytes(order="C"))
41+
else:
42+
hasher.update(pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL))
43+
44+
hasher.update(b"kwds")
45+
for key, value in sorted(kwds.items()):
46+
hasher.update(pickle.dumps(key, protocol=pickle.HIGHEST_PROTOCOL))
47+
if isinstance(value, np.ndarray):
48+
hasher.update(value.dtype.str.encode("utf-8"))
49+
hasher.update(str(value.shape).encode("utf-8"))
50+
hasher.update(value.tobytes(order="C"))
51+
else:
52+
hasher.update(pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL))
53+
54+
return hasher.hexdigest()
55+
56+
57+
def _benchmark(hash_func: Callable[[tuple[Any, ...], dict[str, Any]], str], args: tuple[Any, ...], runs: int) -> float:
58+
durations: List[float] = []
59+
for _ in range(runs):
60+
start = time.perf_counter()
61+
hash_func(args, {})
62+
durations.append(time.perf_counter() - start)
63+
return statistics.median(durations)
64+
65+
66+
def main() -> None:
67+
"""Run benchmark comparing cachier default hashing with xxhash."""
68+
parser = argparse.ArgumentParser(description=__doc__)
69+
parser.add_argument(
70+
"--elements",
71+
type=int,
72+
default=10_000_000,
73+
help="Number of float64 elements in the benchmark array",
74+
)
75+
parser.add_argument("--runs", type=int, default=7, help="Number of benchmark runs")
76+
parsed = parser.parse_args()
77+
78+
try:
79+
import xxhash # noqa: F401
80+
except ImportError as error:
81+
raise SystemExit("Missing dependency: xxhash. Install with `pip install xxhash`.") from error
82+
83+
array = np.arange(parsed.elements, dtype=np.float64)
84+
args = (array,)
85+
86+
results: Dict[str, float] = {
87+
"cachier_default": _benchmark(_default_hash_func, args, parsed.runs),
88+
"xxhash_reference": _benchmark(_xxhash_numpy_hash, args, parsed.runs),
89+
}
90+
91+
ratio = results["cachier_default"] / results["xxhash_reference"]
92+
93+
print(f"Array elements: {parsed.elements:,}")
94+
print(f"Array bytes: {array.nbytes:,}")
95+
print(f"Runs: {parsed.runs}")
96+
print(f"cachier_default median: {results['cachier_default']:.6f}s")
97+
print(f"xxhash_reference median: {results['xxhash_reference']:.6f}s")
98+
print(f"ratio (cachier_default / xxhash_reference): {ratio:.2f}x")
99+
100+
101+
if __name__ == "__main__":
102+
main()

src/cachier/config.py

Lines changed: 92 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,99 @@
99
from ._types import Backend, HashFunc, Mongetter
1010

1111

12+
def _is_numpy_array(value: Any) -> bool:
13+
"""Check whether a value is a NumPy ndarray without importing NumPy eagerly.
14+
15+
Parameters
16+
----------
17+
value : Any
18+
The value to inspect.
19+
20+
Returns
21+
-------
22+
bool
23+
True when ``value`` is a NumPy ndarray instance.
24+
25+
"""
26+
return type(value).__module__ == "numpy" and type(value).__name__ == "ndarray"
27+
28+
29+
def _hash_numpy_array(hasher: "hashlib._Hash", value: Any) -> None:
30+
"""Update hasher with NumPy array metadata and buffer content.
31+
32+
Parameters
33+
----------
34+
hasher : hashlib._Hash
35+
The hasher to update.
36+
value : Any
37+
A NumPy ndarray instance.
38+
39+
"""
40+
hasher.update(b"numpy.ndarray")
41+
hasher.update(value.dtype.str.encode("utf-8"))
42+
hasher.update(str(value.shape).encode("utf-8"))
43+
hasher.update(value.tobytes(order="C"))
44+
45+
46+
def _update_hash_for_value(hasher: "hashlib._Hash", value: Any) -> None:
47+
"""Update hasher with a stable representation of a Python value.
48+
49+
Parameters
50+
----------
51+
hasher : hashlib._Hash
52+
The hasher to update.
53+
value : Any
54+
Value to encode.
55+
56+
"""
57+
if _is_numpy_array(value):
58+
_hash_numpy_array(hasher, value)
59+
return
60+
61+
if isinstance(value, tuple):
62+
hasher.update(b"tuple")
63+
for item in value:
64+
_update_hash_for_value(hasher, item)
65+
return
66+
67+
if isinstance(value, list):
68+
hasher.update(b"list")
69+
for item in value:
70+
_update_hash_for_value(hasher, item)
71+
return
72+
73+
if isinstance(value, dict):
74+
hasher.update(b"dict")
75+
for dict_key in sorted(value):
76+
_update_hash_for_value(hasher, dict_key)
77+
_update_hash_for_value(hasher, value[dict_key])
78+
return
79+
80+
hasher.update(pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL))
81+
82+
1283
def _default_hash_func(args, kwds):
13-
# Sort the kwargs to ensure consistent ordering
14-
sorted_kwargs = sorted(kwds.items())
15-
# Serialize args and sorted_kwargs using pickle or similar
16-
serialized = pickle.dumps((args, sorted_kwargs))
17-
# Create a hash of the serialized data
18-
return hashlib.sha256(serialized).hexdigest()
84+
"""Compute a stable hash key for function arguments.
85+
86+
Parameters
87+
----------
88+
args : tuple
89+
Positional arguments.
90+
kwds : dict
91+
Keyword arguments.
92+
93+
Returns
94+
-------
95+
str
96+
A hex digest representing the call arguments.
97+
98+
"""
99+
hasher = hashlib.blake2b(digest_size=32)
100+
hasher.update(b"args")
101+
_update_hash_for_value(hasher, args)
102+
hasher.update(b"kwds")
103+
_update_hash_for_value(hasher, dict(sorted(kwds.items())))
104+
return hasher.hexdigest()
19105

20106

21107
def _default_cache_dir():

tests/test_numpy_hash.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""Tests for NumPy-aware default hash behavior."""
2+
3+
from datetime import timedelta
4+
5+
import pytest
6+
7+
from cachier import cachier
8+
9+
np = pytest.importorskip("numpy")
10+
11+
12+
@pytest.mark.parametrize("backend", ["memory", "pickle"])
13+
def test_default_hash_func_uses_array_content_for_cache_keys(backend, tmp_path):
14+
"""Verify equal arrays map to a cache hit and different arrays miss."""
15+
call_count = 0
16+
17+
decorator_kwargs = {"backend": backend, "stale_after": timedelta(seconds=120)}
18+
if backend == "pickle":
19+
decorator_kwargs["cache_dir"] = tmp_path
20+
21+
@cachier(**decorator_kwargs)
22+
def array_sum(values):
23+
nonlocal call_count
24+
call_count += 1
25+
return int(values.sum())
26+
27+
arr = np.arange(100_000, dtype=np.int64)
28+
arr_copy = arr.copy()
29+
changed = arr.copy()
30+
changed[-1] = -1
31+
32+
first = array_sum(arr)
33+
assert call_count == 1
34+
35+
second = array_sum(arr_copy)
36+
assert second == first
37+
assert call_count == 1
38+
39+
third = array_sum(changed)
40+
assert third != first
41+
assert call_count == 2
42+
43+
array_sum.clear_cache()

0 commit comments

Comments
 (0)