Skip to content

Commit b97131a

Browse files
lukebaumanncopybara-github
authored andcommitted
Added LRU cache and tests to pathwaysutils.
PiperOrigin-RevId: 768240563
1 parent 3e77709 commit b97131a

2 files changed

Lines changed: 129 additions & 0 deletions

File tree

pathwaysutils/lru_cache.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""An LRU cache that will be cleared when JAX clears its internal cache."""
15+
16+
import functools
17+
from typing import Any, Callable
18+
19+
import jax.extend
20+
21+
22+
def lru_cache(
23+
maxsize: int = 4096,
24+
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
25+
"""An LRU cache that will be cleared when JAX clears its internal cache.
26+
27+
Args:
28+
maxsize: The maximum number of entries to keep in the cache. When this limit
29+
is reached, the least recently used entry will be evicted.
30+
31+
Returns:
32+
A function that can be used to decorate a function to cache its results.
33+
"""
34+
35+
def wrap(f):
36+
cached = functools.lru_cache(maxsize=maxsize)(f)
37+
wrapper = functools.wraps(f)(cached)
38+
39+
wrapper.cache_clear = cached.cache_clear
40+
wrapper.cache_info = cached.cache_info
41+
jax.extend.backend.add_clear_backends_callback(wrapper.cache_clear)
42+
return wrapper
43+
44+
return wrap
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import jax.extend
16+
from pathwaysutils import lru_cache
17+
from absl.testing import absltest
18+
19+
20+
class LruCacheTest(absltest.TestCase):
21+
22+
def test_cache_hits(self):
23+
x = [100]
24+
25+
@lru_cache.lru_cache(maxsize=1)
26+
def f(i):
27+
x[i] += 1
28+
return x[i]
29+
30+
self.assertEqual(f(0), 101) # Miss
31+
self.assertEqual(f(0), 101) # Hit
32+
33+
def test_cache_hits_and_misses_by_arguments(self):
34+
x = [100, 200]
35+
36+
@lru_cache.lru_cache(maxsize=2)
37+
def f(i):
38+
x[i] += 1
39+
return x[i]
40+
41+
self.assertEqual(f(0), 101) # Miss
42+
self.assertEqual(f(0), 101) # Hit
43+
44+
self.assertEqual(f(1), 201) # Miss
45+
self.assertEqual(f(1), 201) # Hit
46+
47+
self.assertEqual(f(0), 101) # Hit
48+
self.assertEqual(f(0), 101) # Hit
49+
50+
def test_cache_lru_eviction(self):
51+
x = [100, 200]
52+
53+
@lru_cache.lru_cache(maxsize=1)
54+
def f(i):
55+
x[i] += 1
56+
return x[i]
57+
58+
self.assertEqual(f(0), 101) # Miss
59+
self.assertEqual(f(0), 101) # Hit
60+
61+
self.assertEqual(f(1), 201) # Miss
62+
self.assertEqual(f(1), 201) # Hit
63+
64+
self.assertEqual(f(0), 102) # Miss
65+
self.assertEqual(f(0), 102) # Hit
66+
67+
def test_clear_cache_via_jax_clear_backend_cache(self):
68+
x = [100]
69+
70+
@lru_cache.lru_cache(maxsize=1)
71+
def f(i):
72+
x[i] += 1
73+
return x[i]
74+
75+
self.assertEqual(f(0), 101) # Miss
76+
self.assertEqual(f(0), 101) # Hit
77+
78+
jax.extend.backend.clear_backends()
79+
80+
self.assertEqual(f(0), 102) # Miss
81+
self.assertEqual(f(0), 102) # Hit
82+
83+
84+
if __name__ == "__main__":
85+
absltest.main()

0 commit comments

Comments
 (0)