forked from googleapis/python-spanner
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_atomic_counter.py
More file actions
78 lines (63 loc) · 2.11 KB
/
test_atomic_counter.py
File metadata and controls
78 lines (63 loc) · 2.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import threading
import unittest
from google.cloud.spanner_v1._helpers import AtomicCounter
class TestAtomicCounter(unittest.TestCase):
    """Unit tests for ``AtomicCounter`` from ``google.cloud.spanner_v1._helpers``."""

    # Upper bound (seconds) for joining each worker thread. A join that times
    # out is reported as a test failure instead of hanging the test run.
    _JOIN_TIMEOUT_SECONDS = 30

    def test_initialization(self):
        """The counter starts at 0 by default, or at the given value (negatives allowed)."""
        self.assertEqual(AtomicCounter().value, 0)
        self.assertEqual(AtomicCounter(1).value, 1)
        self.assertEqual(AtomicCounter(-1).value, -1)

    def test_increment(self):
        """``increment`` adds its argument (default 1) and returns the new value."""
        ac = AtomicCounter()

        self.assertEqual(ac.increment(), 1)
        self.assertEqual(ac.value, 1)

        self.assertEqual(ac.increment(2), 3)
        self.assertEqual(ac.value, 3)

        self.assertEqual(ac.increment(100), 103)
        # The stored value must match the returned value on every step,
        # including the last one.
        self.assertEqual(ac.value, 103)

    def test_plus_call(self):
        """``+=`` mutates the counter; ``+`` in either order returns a sum without mutating."""
        ac = AtomicCounter()
        ac += 1
        # ``+=`` must leave ``ac`` bound to a counter, not rebind it to a bare number.
        self.assertIsInstance(ac, AtomicCounter)
        self.assertEqual(ac.value, 1)

        n = ac + 2
        self.assertEqual(n, 3)
        self.assertEqual(ac.value, 1)  # ``counter + x`` leaves the counter unchanged

        n = 200 + ac
        self.assertEqual(n, 201)
        self.assertEqual(ac.value, 1)  # so does the reflected ``x + counter``

    def test_multiple_threads_incrementing(self):
        """Concurrent ``increment`` calls from many threads must not lose updates."""
        ac = AtomicCounter()
        num_threads = 200
        increments_per_thread = 10

        def do_work():
            for _ in range(increments_per_thread):
                ac.increment()

        # Daemon threads: if a worker gets stuck and the test fails, the
        # leftover thread does not keep the interpreter from exiting.
        threads = [
            threading.Thread(target=do_work, daemon=True) for _ in range(num_threads)
        ]
        for th in threads:
            th.start()

        # Join in a random order rather than in start order.
        random.shuffle(threads)
        for th in threads:
            th.join(timeout=self._JOIN_TIMEOUT_SECONDS)
            # Meaningful only because join() has a timeout; an unbounded
            # join() returns only once the thread has already finished.
            self.assertFalse(th.is_alive(), "worker thread did not finish in time")

        # Every increment from every thread must be reflected in the total.
        self.assertEqual(ac.value, num_threads * increments_per_thread)