Skip to content
This repository was archived by the owner on Mar 26, 2026. It is now read-only.

Commit fdce328

Browse files
committed
chore: cache proto context to optimize generator performance
1 parent 72db9d0 commit fdce328

File tree

6 files changed

+160
-8
lines changed

6 files changed

+160
-8
lines changed

gapic/cli/generate.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from gapic import generator
2424
from gapic.schema import api
2525
from gapic.utils import Options
26+
from gapic.utils.cache import generation_cache_context
2627

2728

2829
@click.command()
@@ -56,15 +57,23 @@ def generate(request: typing.BinaryIO, output: typing.BinaryIO) -> None:
5657
[p.package for p in req.proto_file if p.name in req.file_to_generate]
5758
).rstrip(".")
5859

59-
# Build the API model object.
60-
# This object is a frozen representation of the whole API, and is sent
61-
# to each template in the rendering step.
62-
api_schema = api.API.build(req.proto_file, opts=opts, package=package)
60+
# Create the generation cache context.
61+
# This provides the shared storage for the @cached_proto_context decorator.
62+
# 1. Performance: Memoizes `with_context` calls, speeding up generation significantly.
63+
# 2. Safety: The decorator uses this storage to "pin" Proto objects in memory.
64+
# This prevents Python's Garbage Collector from deleting objects created during
65+
# `API.build` while `Generator.get_response` is still using their IDs.
66+
# (See `gapic.utils.cache.cached_proto_context` for the specific pinning logic).
67+
with generation_cache_context():
68+
# Build the API model object.
69+
# This object is a frozen representation of the whole API, and is sent
70+
# to each template in the rendering step.
71+
api_schema = api.API.build(req.proto_file, opts=opts, package=package)
6372

64-
# Translate into a protobuf CodeGeneratorResponse; this reads the
65-
# individual templates and renders them.
66-
# If there are issues, error out appropriately.
67-
res = generator.Generator(opts).get_response(api_schema, opts)
73+
# Translate into a protobuf CodeGeneratorResponse; this reads the
74+
# individual templates and renders them.
75+
# If there are issues, error out appropriately.
76+
res = generator.Generator(opts).get_response(api_schema, opts)
6877

6978
# Output the serialized response.
7079
output.write(res.SerializeToString())

gapic/schema/metadata.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from gapic.schema import imp
3636
from gapic.schema import naming
3737
from gapic.utils import cached_property
38+
from gapic.utils import cached_proto_context
3839
from gapic.utils import RESERVED_NAMES
3940

4041
# This class is a minor hack to optimize Address's __eq__ method.
@@ -359,6 +360,7 @@ def resolve(self, selector: str) -> str:
359360
return f'{".".join(self.package)}.{selector}'
360361
return selector
361362

363+
@cached_proto_context
362364
def with_context(self, *, collisions: Set[str]) -> "Address":
363365
"""Return a derivative of this address with the provided context.
364366
@@ -398,6 +400,7 @@ def doc(self):
398400
return "\n\n".join(self.documentation.leading_detached_comments)
399401
return ""
400402

403+
@cached_proto_context
401404
def with_context(self, *, collisions: Set[str]) -> "Metadata":
402405
"""Return a derivative of this metadata with the provided context.
403406

gapic/schema/wrappers.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767

6868
from gapic import utils
6969
from gapic.schema import metadata
70+
from gapic.utils import cached_proto_context
7071
from gapic.utils import uri_sample
7172
from gapic.utils import make_private
7273

@@ -410,6 +411,7 @@ def type(self) -> Union["MessageType", "EnumType", "PrimitiveType"]:
410411
"This code should not be reachable; please file a bug."
411412
)
412413

414+
@cached_proto_context
413415
def with_context(
414416
self,
415417
*,
@@ -805,6 +807,7 @@ def get_field(
805807
# message.
806808
return cursor.message.get_field(*field_path[1:], collisions=collisions)
807809

810+
@cached_proto_context
808811
def with_context(
809812
self,
810813
*,
@@ -937,6 +940,7 @@ def ident(self) -> metadata.Address:
937940
"""Return the identifier data to be used in templates."""
938941
return self.meta.address
939942

943+
@cached_proto_context
940944
def with_context(self, *, collisions: Set[str]) -> "EnumType":
941945
"""Return a derivative of this enum with the provided context.
942946
@@ -1058,6 +1062,7 @@ class ExtendedOperationInfo:
10581062
request_type: MessageType
10591063
operation_type: MessageType
10601064

1065+
@cached_proto_context
10611066
def with_context(
10621067
self,
10631068
*,
@@ -1127,6 +1132,7 @@ class OperationInfo:
11271132
response_type: MessageType
11281133
metadata_type: MessageType
11291134

1135+
@cached_proto_context
11301136
def with_context(
11311137
self,
11321138
*,
@@ -1937,6 +1943,7 @@ def void(self) -> bool:
19371943
"""Return True if this method has no return value, False otherwise."""
19381944
return self.output.ident.proto == "google.protobuf.Empty"
19391945

1946+
@cached_proto_context
19401947
def with_context(
19411948
self,
19421949
*,
@@ -2357,6 +2364,7 @@ def operation_polling_method(self) -> Optional[Method]:
23572364
def is_internal(self) -> bool:
23582365
return any(m.is_internal for m in self.methods.values())
23592366

2367+
@cached_proto_context
23602368
def with_context(
23612369
self,
23622370
*,

gapic/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from gapic.utils.cache import cached_property
16+
from gapic.utils.cache import cached_proto_context
1617
from gapic.utils.case import to_snake_case
1718
from gapic.utils.case import to_camel_case
1819
from gapic.utils.checks import is_msg_field_pb
@@ -34,6 +35,7 @@
3435

3536
__all__ = (
3637
"cached_property",
38+
"cached_proto_context",
3739
"convert_uri_fieldnames",
3840
"doc",
3941
"empty",

gapic/utils/cache.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515
import functools
16+
import contextlib
17+
import threading
1618

1719

1820
def cached_property(fx):
@@ -43,3 +45,91 @@ def inner(self):
4345
return self._cached_values[fx.__name__]
4446

4547
return property(inner)
48+
49+
50+
# Thread-local storage for the generation cache dictionary.
# Thread-local scoping ensures that parallel generation tasks (if any)
# do not corrupt each other's cache.
_thread_local = threading.local()


@contextlib.contextmanager
def generation_cache_context():
    """Context manager owning the lifecycle of the generation cache.

    Entering the context installs a fresh dictionary in thread-local
    storage; exiting removes it again. If a cache was already active
    (nested use), the previous dictionary is restored on exit instead of
    being destroyed, so re-entrant use is safe.

    **Memory Management:**
    The cache stores strong references to Proto objects to "pin" them in
    memory (see `cached_proto_context`). It is critical that the `finally`
    block drops the dictionary: doing so breaks the reference chain,
    allowing Python's Garbage Collector to finally free all the large
    Proto objects that were pinned during generation.
    """
    # Sentinel distinguishing "no cache installed" from "cache is None".
    _missing = object()
    # Remember any cache already active so nested contexts restore it
    # rather than raising AttributeError on the outer exit.
    previous = getattr(_thread_local, "cache", _missing)
    # Initialize a fresh cache for this context.
    _thread_local.cache = {}
    try:
        yield
    finally:
        # Drop (or restore) the dictionary to free all pinned objects.
        # This is essential to prevent memory leaks in long-running processes.
        if previous is _missing:
            del _thread_local.cache
        else:
            _thread_local.cache = previous
77+
78+
79+
def cached_proto_context(func):
    """Decorator to memoize `with_context` calls on Proto wrapper objects.

    This mechanism provides a significant performance boost by preventing
    redundant recalculations of naming collisions during template rendering.

    Since the Proto wrapper objects are unhashable (mutable), we use `id(self)`
    as the primary cache key. Normally, this is dangerous: if the object is
    garbage collected, Python might reuse its memory address for a *new*
    object, leading to a cache collision (the "Zombie ID" bug).

    To prevent this, this decorator stores the value as a tuple:
    `(result, self)`. By keeping a reference to `self` in the cache value, we
    "pin" the object in memory. This forces the Garbage Collector to keep the
    object alive, guaranteeing that `id(self)` remains unique for the entire
    lifespan of the `generation_cache_context`.

    Args:
        func (Callable): The function to decorate (usually `with_context`).

    Returns:
        Callable: The wrapped function with caching and pinning logic.
    """

    @functools.wraps(func)
    def wrapper(self, *, collisions, **kwargs):

        # Look up the active cache. If we are not inside a
        # generation_cache_context (e.g. unit tests), bypass caching entirely.
        context_cache = getattr(_thread_local, "cache", None)
        if context_cache is None:
            return func(self, collisions=collisions, **kwargs)

        # Build the cache key:
        #   - id(self): 'self' is not hashable, so identity is the handle
        #     (the pinning below keeps that identity unique).
        #   - frozenset(collisions): hashable snapshot of the collisions set.
        #   - sorted kwargs: extra keyword arguments (e.g. skip_fields) can
        #     change the result, so they must participate in the key —
        #     otherwise two calls differing only in kwargs would share one
        #     (stale) cached value.
        collisions_key = frozenset(collisions) if collisions else None
        try:
            key = (id(self), collisions_key, tuple(sorted(kwargs.items())))
            hit = key in context_cache
        except TypeError:
            # An unhashable kwarg value (e.g. a set): skip caching rather
            # than risk returning a result computed for different arguments.
            return func(self, collisions=collisions, **kwargs)

        if hit:
            # The cache stores (result, pinned_object). Return just the result.
            return context_cache[key][0]

        # Execute the actual function.
        result = func(self, collisions=collisions, **kwargs)

        # Update cache & pin the object.
        # We store (result, self). The reference to 'self' prevents garbage
        # collection, ensuring that 'id(self)' cannot be reused for a new
        # object while this cache entry exists.
        context_cache[key] = (result, self)
        return result

    return wrapper

tests/unit/utils/test_cache.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,43 @@ def bar(self):
3131
assert foo.call_count == 1
3232
assert foo.bar == 42
3333
assert foo.call_count == 1
34+
35+
36+
def test_cached_proto_context():
    class Fake:
        def __init__(self):
            self.call_count = 0

        # Mirror the real Proto.with_context signature so that argument
        # propagation through the decorator is exercised.
        @cache.cached_proto_context
        def with_context(self, collisions, *, skip_fields=False, visited_messages=None):
            self.call_count += 1
            return f"val-{self.call_count}"

    fake = Fake()

    # Outside a generation_cache_context the decorator must be a no-op:
    # every call reaches the wrapped function.
    assert fake.with_context(collisions={"a"}) == "val-1"
    assert fake.with_context(collisions={"a"}) == "val-2"

    with cache.generation_cache_context():
        # Reset the counter so cache hits are easy to track.
        fake.call_count = 0

        # Repeating the same (object, collisions) pair hits the cache.
        assert fake.with_context(collisions={"a"}) == "val-1", "a"
        assert fake.with_context(collisions={"a"}) == "val-1"
        assert fake.call_count == 1

        # A different collisions set forms a distinct cache key.
        assert fake.with_context(collisions={"b"}) == "val-2"
        assert fake.call_count == 2

    # Leaving the context tears the cache down completely, so the next
    # call executes the wrapped function again.
    assert getattr(cache._thread_local, "cache", None) is None
    assert fake.with_context(collisions={"a"}) == "val-3"

0 commit comments

Comments
 (0)