Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit 519755b

Browse files
committed
fix: Metrics thread-safety refactor and Batch.commit idempotency fix
1 parent 67c682e commit 519755b

File tree

9 files changed

+195
-80
lines changed

9 files changed

+195
-80
lines changed

google/cloud/spanner_v1/batch.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
"""Context manager for Cloud Spanner batched writes."""
16+
1617
import functools
1718
from typing import List, Optional
1819

@@ -242,6 +243,8 @@ def commit(
242243
observability_options=getattr(database, "observability_options", None),
243244
metadata=metadata,
244245
) as span, MetricsCapture():
246+
nth_request = getattr(database, "_next_nth_request", 0)
247+
attempt = AtomicCounter(0)
245248

246249
def wrapped_method():
247250
commit_request = CommitRequest(
@@ -256,8 +259,8 @@ def wrapped_method():
256259
# should be increased. attempt can only be increased if
257260
# we encounter UNAVAILABLE or INTERNAL.
258261
call_metadata, error_augmenter = database.with_error_augmentation(
259-
getattr(database, "_next_nth_request", 0),
260-
1,
262+
nth_request,
263+
attempt.increment(),
261264
metadata,
262265
span,
263266
)

google/cloud/spanner_v1/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
* a :class:`~google.cloud.spanner_v1.instance.Instance` owns a
2424
:class:`~google.cloud.spanner_v1.database.Database`
2525
"""
26+
2627
import grpc
2728
import os
2829
import logging

google/cloud/spanner_v1/metrics/metrics_capture.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,18 @@
2323
from .spanner_metrics_tracer_factory import SpannerMetricsTracerFactory
2424

2525

26+
from contextvars import Token
27+
28+
2629
class MetricsCapture:
2730
"""Context manager for capturing metrics in Cloud Spanner operations.
2831
2932
This class provides a context manager interface to automatically handle
3033
the start and completion of metrics tracing for a given operation.
3134
"""
3235

36+
_token: Token
37+
3338
def __enter__(self):
3439
"""Enter the runtime context related to this object.
3540
@@ -45,11 +50,13 @@ def __enter__(self):
4550
return self
4651

4752
# Define a new metrics tracer for the new operation
48-
SpannerMetricsTracerFactory.current_metrics_tracer = (
49-
factory.create_metrics_tracer()
53+
# Set the context var and keep the token for reset
54+
tracer = factory.create_metrics_tracer()
55+
self._token = SpannerMetricsTracerFactory._current_metrics_tracer_ctx.set(
56+
tracer
5057
)
51-
if SpannerMetricsTracerFactory.current_metrics_tracer:
52-
SpannerMetricsTracerFactory.current_metrics_tracer.record_operation_start()
58+
if tracer:
59+
tracer.record_operation_start()
5360
return self
5461

5562
def __exit__(self, exc_type, exc_value, traceback):
@@ -70,6 +77,11 @@ def __exit__(self, exc_type, exc_value, traceback):
7077
if not SpannerMetricsTracerFactory().enabled:
7178
return False
7279

73-
if SpannerMetricsTracerFactory.current_metrics_tracer:
74-
SpannerMetricsTracerFactory.current_metrics_tracer.record_operation_completion()
80+
tracer = SpannerMetricsTracerFactory._current_metrics_tracer_ctx.get()
81+
if tracer:
82+
tracer.record_operation_completion()
83+
84+
# Reset the context var using the token
85+
if getattr(self, "_token", None):
86+
SpannerMetricsTracerFactory._current_metrics_tracer_ctx.reset(self._token)
7587
return False # Propagate the exception if any

google/cloud/spanner_v1/metrics/metrics_interceptor.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -97,22 +97,17 @@ def _set_metrics_tracer_attributes(self, resources: Dict[str, str]) -> None:
9797
Args:
9898
resources (Dict[str, str]): A dictionary containing project, instance, and database information.
9999
"""
100-
if SpannerMetricsTracerFactory.current_metrics_tracer is None:
100+
tracer = SpannerMetricsTracerFactory.get_current_tracer()
101+
if tracer is None:
101102
return
102103

103104
if resources:
104105
if "project" in resources:
105-
SpannerMetricsTracerFactory.current_metrics_tracer.set_project(
106-
resources["project"]
107-
)
106+
tracer.set_project(resources["project"])
108107
if "instance" in resources:
109-
SpannerMetricsTracerFactory.current_metrics_tracer.set_instance(
110-
resources["instance"]
111-
)
108+
tracer.set_instance(resources["instance"])
112109
if "database" in resources:
113-
SpannerMetricsTracerFactory.current_metrics_tracer.set_database(
114-
resources["database"]
115-
)
110+
tracer.set_database(resources["database"])
116111

117112
def intercept(self, invoked_method, request_or_iterator, call_details):
118113
"""Intercept gRPC calls to collect metrics.
@@ -126,10 +121,8 @@ def intercept(self, invoked_method, request_or_iterator, call_details):
126121
The RPC response
127122
"""
128123
factory = SpannerMetricsTracerFactory()
129-
if (
130-
SpannerMetricsTracerFactory.current_metrics_tracer is None
131-
or not factory.enabled
132-
):
124+
tracer = SpannerMetricsTracerFactory.get_current_tracer()
125+
if tracer is None or not factory.enabled:
133126
return invoked_method(request_or_iterator, call_details)
134127

135128
# Setup Metric Tracer attributes from call details
@@ -142,15 +135,13 @@ def intercept(self, invoked_method, request_or_iterator, call_details):
142135
call_details.method, SPANNER_METHOD_PREFIX
143136
).replace("/", ".")
144137

145-
SpannerMetricsTracerFactory.current_metrics_tracer.set_method(method_name)
146-
SpannerMetricsTracerFactory.current_metrics_tracer.record_attempt_start()
138+
tracer.set_method(method_name)
139+
tracer.record_attempt_start()
147140
response = invoked_method(request_or_iterator, call_details)
148-
SpannerMetricsTracerFactory.current_metrics_tracer.record_attempt_completion()
141+
tracer.record_attempt_completion()
149142

150143
# Process and send GFE metrics if enabled
151-
if SpannerMetricsTracerFactory.current_metrics_tracer.gfe_enabled:
144+
if tracer.gfe_enabled:
152145
metadata = response.initial_metadata()
153-
SpannerMetricsTracerFactory.current_metrics_trace.record_gfe_metrics(
154-
metadata
155-
)
146+
tracer.record_gfe_metrics(metadata)
156147
return response

google/cloud/spanner_v1/metrics/spanner_metrics_tracer_factory.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
import logging
2121
from .constants import SPANNER_SERVICE_NAME
22+
import contextvars
2223

2324
try:
2425
import mmh3
@@ -43,7 +44,9 @@ class SpannerMetricsTracerFactory(MetricsTracerFactory):
4344
"""A factory for creating SpannerMetricsTracer instances."""
4445

4546
_metrics_tracer_factory: "SpannerMetricsTracerFactory" = None
46-
current_metrics_tracer: MetricsTracer = None
47+
_current_metrics_tracer_ctx = contextvars.ContextVar(
48+
"current_metrics_tracer", default=None
49+
)
4750

4851
def __new__(
4952
cls, enabled: bool = True, gfe_enabled: bool = False
@@ -80,10 +83,18 @@ def __new__(
8083
cls._metrics_tracer_factory.gfe_enabled = gfe_enabled
8184

8285
if cls._metrics_tracer_factory.enabled != enabled:
83-
cls._metrics_tracer_factory.enabeld = enabled
86+
cls._metrics_tracer_factory.enabled = enabled
8487

8588
return cls._metrics_tracer_factory
8689

90+
@staticmethod
91+
def get_current_tracer() -> MetricsTracer:
92+
return SpannerMetricsTracerFactory._current_metrics_tracer_ctx.get()
93+
94+
@property
95+
def current_metrics_tracer(self) -> MetricsTracer:
96+
return SpannerMetricsTracerFactory._current_metrics_tracer_ctx.get()
97+
8798
@staticmethod
8899
def _generate_client_uid() -> str:
89100
"""Generate a client UID in the form of uuidv4@pid@hostname.

tests/unit/conftest.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import pytest
2+
from unittest.mock import patch
3+
4+
5+
@pytest.fixture(autouse=True)
6+
def mock_periodic_exporting_metric_reader():
7+
"""Globally mock PeriodicExportingMetricReader to prevent real network calls."""
8+
with patch(
9+
"google.cloud.spanner_v1.client.PeriodicExportingMetricReader"
10+
) as mock_client_reader, patch(
11+
"opentelemetry.sdk.metrics.export.PeriodicExportingMetricReader"
12+
):
13+
yield mock_client_reader

tests/unit/test_metrics.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,19 +65,25 @@ def patched_client(monkeypatch):
6565

6666
client_module._metrics_monitor_initialized = False
6767

68-
with patch("google.cloud.spanner_v1.client.CloudMonitoringMetricsExporter"):
68+
with patch(
69+
"google.cloud.spanner_v1.metrics.metrics_exporter.MetricServiceClient"
70+
), patch(
71+
"google.cloud.spanner_v1.metrics.metrics_exporter.CloudMonitoringMetricsExporter"
72+
), patch(
73+
"opentelemetry.sdk.metrics.export.PeriodicExportingMetricReader"
74+
):
6975
client = Client(
7076
project="test",
7177
credentials=TestCredentials(),
72-
# client_options={"api_endpoint": "none"}
7378
)
7479
yield client
7580

7681
# Resetting
7782
metrics.set_meter_provider(metrics.NoOpMeterProvider())
7883
SpannerMetricsTracerFactory._metrics_tracer_factory = None
79-
SpannerMetricsTracerFactory.current_metrics_tracer = None
80-
client_module._metrics_monitor_initialized = False
84+
# Reset context var
85+
ctx = SpannerMetricsTracerFactory._current_metrics_tracer_ctx
86+
ctx.set(None)
8187

8288

8389
def test_metrics_emission_with_failure_attempt(patched_client):
@@ -92,10 +98,14 @@ def test_metrics_emission_with_failure_attempt(patched_client):
9298
original_intercept = metrics_interceptor.intercept
9399
first_attempt = True
94100

101+
captured_tracer_list = []
102+
95103
def mocked_raise(*args, **kwargs):
96104
raise ServiceUnavailable("Service Unavailable")
97105

98106
def mocked_call(*args, **kwargs):
107+
# Capture the tracer while it is active
108+
captured_tracer_list.append(SpannerMetricsTracerFactory.get_current_tracer())
99109
return _UnaryOutcome(MagicMock(), MagicMock())
100110

101111
def intercept_wrapper(invoked_method, request_or_iterator, call_details):
@@ -113,11 +123,14 @@ def intercept_wrapper(invoked_method, request_or_iterator, call_details):
113123

114124
metrics_interceptor.intercept = intercept_wrapper
115125
patch_path = "google.cloud.spanner_v1.metrics.metrics_exporter.CloudMonitoringMetricsExporter.export"
126+
116127
with patch(patch_path):
117128
with database.snapshot():
118129
pass
119130

120131
# Verify that the attempt count increased from the failed initial attempt
121-
assert (
122-
SpannerMetricsTracerFactory.current_metrics_tracer.current_op.attempt_count
123-
) == 2
132+
# We use the captured tracer from the SUCCESSFUL attempt (the second one)
133+
assert len(captured_tracer_list) > 0
134+
tracer = captured_tracer_list[0]
135+
assert tracer is not None
136+
assert tracer.current_op.attempt_count == 2
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import threading
2+
import time
3+
import unittest
4+
from google.cloud.spanner_v1.metrics.spanner_metrics_tracer_factory import (
5+
SpannerMetricsTracerFactory,
6+
)
7+
from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture
8+
9+
10+
class TestMetricsConcurrency(unittest.TestCase):
11+
def setUp(self):
12+
# Reset factory singleton
13+
SpannerMetricsTracerFactory._metrics_tracer_factory = None
14+
15+
def test_concurrent_tracers(self):
16+
"""Verify that concurrent threads have isolated tracers."""
17+
factory = SpannerMetricsTracerFactory(enabled=True)
18+
# Ensure enabled
19+
factory.enabled = True
20+
21+
errors = []
22+
23+
def worker(idx):
24+
try:
25+
# Simulate a request workflow
26+
with MetricsCapture():
27+
# Capture should have set a tracer
28+
tracer = SpannerMetricsTracerFactory.get_current_tracer()
29+
if tracer is None:
30+
errors.append(f"Thread {idx}: Tracer is None inside Capture")
31+
return
32+
33+
# Set a unique attribute for this thread
34+
project_name = f"project-{idx}"
35+
tracer.set_project(project_name)
36+
37+
# Simulate some work
38+
time.sleep(0.01)
39+
40+
# Verify we still have OUR tracer
41+
current_tracer = SpannerMetricsTracerFactory.get_current_tracer()
42+
if current_tracer.client_attributes["project_id"] != project_name:
43+
errors.append(
44+
f"Thread {idx}: Tracer project mismatch. Expected {project_name}, got {current_tracer.client_attributes.get('project_id')}"
45+
)
46+
47+
# Check interceptor logic (simulated)
48+
# Interceptor reads the tracer via SpannerMetricsTracerFactory.get_current_tracer()
49+
interceptor_tracer = (
50+
SpannerMetricsTracerFactory.get_current_tracer()
51+
)
52+
if interceptor_tracer is not tracer:
53+
errors.append(f"Thread {idx}: Interceptor tracer mismatch")
54+
55+
except Exception as e:
56+
errors.append(f"Thread {idx}: Exception {e}")
57+
58+
threads = []
59+
for i in range(10):
60+
t = threading.Thread(target=worker, args=(i,))
61+
threads.append(t)
62+
t.start()
63+
64+
for t in threads:
65+
t.join()
66+
67+
self.assertEqual(errors, [], f"Concurrency errors found: {errors}")
68+
69+
def test_context_var_cleanup(self):
70+
"""Verify tracer is cleaned up after ContextVar reset."""
71+
SpannerMetricsTracerFactory(enabled=True)
72+
73+
with MetricsCapture():
74+
self.assertIsNotNone(SpannerMetricsTracerFactory.get_current_tracer())
75+
76+
self.assertIsNone(SpannerMetricsTracerFactory.get_current_tracer())
77+
78+
79+
if __name__ == "__main__":
80+
unittest.main()

0 commit comments

Comments
 (0)