Skip to content

Commit 1f52008

Browse files
committed
feat(x-goog-spanner-request-id): introduce AtomicCounter
This change introduces AtomicCounter, a concurrency/thread-safe counter do deal with the multi-threaded nature of variables. It permits operations: * atomic_counter += 1 * value = atomic_counter + 1 * atomic_counter.value that'll be paramount to bringing in the logic for x-goog-spanner-request-id in much reduced changelists. Updates #1261 Carved out from PR googleapis#1264
1 parent ad69c48 commit 1f52008

File tree

3 files changed

+166
-0
lines changed

3 files changed

+166
-0
lines changed

google/cloud/spanner_v1/_helpers.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import math
2020
import time
2121
import base64
22+
import threading
2223

2324
from google.protobuf.struct_pb2 import ListValue
2425
from google.protobuf.struct_pb2 import Value
@@ -525,3 +526,45 @@ def _metadata_with_leader_aware_routing(value, **kw):
525526
List[Tuple[str, str]]: RPC metadata with leader aware routing header
526527
"""
527528
return ("x-goog-spanner-route-to-leader", str(value).lower())
529+
530+
531+
class AtomicCounter:
532+
def __init__(self, start_value=0):
533+
self.__lock = threading.Lock()
534+
self.__value = start_value
535+
536+
@property
537+
def value(self):
538+
with self.__lock:
539+
return self.__value
540+
541+
def increment(self, n=1):
542+
with self.__lock:
543+
self.__value += n
544+
return self.__value
545+
546+
def __iadd__(self, n):
547+
"""
548+
Defines the inplace += operator result.
549+
"""
550+
with self.__lock:
551+
self.__value += n
552+
return self
553+
554+
def __add__(self, n):
555+
"""
556+
Defines the result of invoking: value = AtomicCounter + addable
557+
"""
558+
with self.__lock:
559+
n += self.__value
560+
return n
561+
562+
def __radd__(self, n):
563+
"""
564+
Defines the result of invoking: value = addable + AtomicCounter
565+
"""
566+
return self.__add__(n)
567+
568+
569+
def _metadata_with_request_id(*args, **kwargs):
570+
return with_request_id(*args, **kwargs)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2024 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import threading
17+
18+
REQ_ID_VERSION = 1 # The version of the x-goog-spanner-request-id spec.
19+
REQ_ID_HEADER_KEY = "x-goog-spanner-request-id"
20+
21+
22+
def generate_rand_uint64():
23+
b = os.urandom(8)
24+
return (
25+
b[7] & 0xFF
26+
| (b[6] & 0xFF) << 8
27+
| (b[5] & 0xFF) << 16
28+
| (b[4] & 0xFF) << 24
29+
| (b[3] & 0xFF) << 32
30+
| (b[2] & 0xFF) << 36
31+
| (b[1] & 0xFF) << 48
32+
| (b[0] & 0xFF) << 56
33+
)
34+
35+
36+
REQ_RAND_PROCESS_ID = generate_rand_uint64()
37+
38+
39+
def with_request_id(client_id, channel_id, nth_request, attempt, other_metadata=[]):
40+
req_id = f"{REQ_ID_VERSION}.{REQ_RAND_PROCESS_ID}.{client_id}.{channel_id}.{nth_request}.{attempt}"
41+
other_metadata.append((REQ_ID_HEADER_KEY, req_id))
42+
return other_metadata

tests/unit/test_atomic_counter.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright 2024 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import time
16+
import random
17+
import threading
18+
import unittest
19+
from google.cloud.spanner_v1._helpers import AtomicCounter
20+
21+
22+
class TestAtomicCounter(unittest.TestCase):
23+
def test_initialization(self):
24+
ac_default = AtomicCounter()
25+
assert ac_default.value == 0
26+
27+
ac_1 = AtomicCounter(1)
28+
assert ac_1.value == 1
29+
30+
ac_negative_1 = AtomicCounter(-1)
31+
assert ac_negative_1.value == -1
32+
33+
def test_increment(self):
34+
ac = AtomicCounter()
35+
result_default = ac.increment()
36+
assert result_default == 1
37+
assert ac.value == 1
38+
39+
result_with_value = ac.increment(2)
40+
assert result_with_value == 3
41+
assert ac.value == 3
42+
result_plus_100 = ac.increment(100)
43+
assert result_plus_100 == 103
44+
45+
def test_plus_call(self):
46+
ac = AtomicCounter()
47+
ac += 1
48+
assert ac.value == 1
49+
50+
n = ac + 2
51+
assert n == 3
52+
assert ac.value == 1
53+
54+
n = 200 + ac
55+
assert n == 201
56+
assert ac.value == 1
57+
58+
def test_multiple_threads_incrementing(self):
59+
ac = AtomicCounter()
60+
n = 200
61+
m = 10
62+
63+
def do_work():
64+
for i in range(m):
65+
ac.increment()
66+
67+
threads = []
68+
for i in range(n):
69+
th = threading.Thread(target=do_work)
70+
threads.append(th)
71+
th.start()
72+
73+
time.sleep(0.3)
74+
75+
random.shuffle(threads)
76+
for th in threads:
77+
th.join()
78+
assert th.is_alive() == False
79+
80+
# Finally the result should be n*m
81+
assert ac.value == n * m

0 commit comments

Comments
 (0)