Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit fd16d8a

Browse files
committed
fix(metrics): resolve thread-safety race condition in MetricsInterceptor
1 parent 2c5eb96 commit fd16d8a

File tree

2 files changed

+54
-11
lines changed

2 files changed

+54
-11
lines changed

google/cloud/spanner_v1/metrics/metrics_interceptor.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,8 @@ def intercept(self, invoked_method, request_or_iterator, call_details):
126126
The RPC response
127127
"""
128128
factory = SpannerMetricsTracerFactory()
129-
if (
130-
SpannerMetricsTracerFactory.current_metrics_tracer is None
131-
or not factory.enabled
132-
):
129+
tracer = SpannerMetricsTracerFactory.current_metrics_tracer
130+
if tracer is None or not factory.enabled:
133131
return invoked_method(request_or_iterator, call_details)
134132

135133
# Setup Metric Tracer attributes from call details
@@ -142,15 +140,13 @@ def intercept(self, invoked_method, request_or_iterator, call_details):
142140
call_details.method, SPANNER_METHOD_PREFIX
143141
).replace("/", ".")
144142

145-
SpannerMetricsTracerFactory.current_metrics_tracer.set_method(method_name)
146-
SpannerMetricsTracerFactory.current_metrics_tracer.record_attempt_start()
143+
tracer.set_method(method_name)
144+
tracer.record_attempt_start()
147145
response = invoked_method(request_or_iterator, call_details)
148-
SpannerMetricsTracerFactory.current_metrics_tracer.record_attempt_completion()
146+
tracer.record_attempt_completion()
149147

150148
# Process and send GFE metrics if enabled
151-
if SpannerMetricsTracerFactory.current_metrics_tracer.gfe_enabled:
149+
if tracer.gfe_enabled:
152150
metadata = response.initial_metadata()
153-
SpannerMetricsTracerFactory.current_metrics_trace.record_gfe_metrics(
154-
metadata
155-
)
151+
tracer.record_gfe_metrics(metadata)
156152
return response

tests/unit/test_metrics_interceptor.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515
import pytest
16+
import threading
17+
import time
1618
from google.cloud.spanner_v1.metrics.metrics_interceptor import MetricsInterceptor
1719
from google.cloud.spanner_v1.metrics.spanner_metrics_tracer_factory import (
1820
SpannerMetricsTracerFactory,
@@ -102,6 +104,51 @@ def test_intercept_with_tracer(interceptor):
102104
mock_invoked_method.assert_called_once_with("request", call_details)
103105

104106

107+
def test_intercept_thread_safety(interceptor):
108+
# Regression test for race condition where current_metrics_tracer changes mid-call
109+
110+
# Mock tracers
111+
tracer_a = MagicMock()
112+
tracer_a.gfe_enabled = False
113+
tracer_b = MagicMock()
114+
tracer_a.gfe_enabled = False
115+
116+
call_details = MagicMock(
117+
method="spanner.Commit",
118+
metadata=[],
119+
)
120+
121+
def mock_invoked_method(*args, **kwargs):
122+
# Simulate network delay to allow thread switch
123+
time.sleep(0.1)
124+
return MagicMock()
125+
126+
def thread_a_func():
127+
# Set Tracer A
128+
SpannerMetricsTracerFactory.current_metrics_tracer = tracer_a
129+
# Call intercept
130+
interceptor.intercept(mock_invoked_method, None, call_details)
131+
132+
def thread_b_func():
133+
time.sleep(0.05) # Wait for A to start
134+
# Overwrite with Tracer B
135+
SpannerMetricsTracerFactory.current_metrics_tracer = tracer_b
136+
137+
t1 = threading.Thread(target=thread_a_func)
138+
t2 = threading.Thread(target=thread_b_func)
139+
140+
t1.start()
141+
t2.start()
142+
143+
t1.join()
144+
t2.join()
145+
146+
# Verify that Tracer A was used for completion, NOT Tracer B
147+
# Because Thread A started with Tracer A, it should finish with Tracer A
148+
tracer_a.record_attempt_completion.assert_called_once()
149+
tracer_b.record_attempt_completion.assert_not_called()
150+
151+
105152
class MockMetricTracer:
106153
def __init__(self):
107154
self.project = None

0 commit comments

Comments
 (0)