Skip to content

Commit 8883b89

Browse files
committed
fix(spanner): preserve all async cache updates
1 parent 26fb4c6 commit 8883b89

File tree

2 files changed

+282
-20
lines changed

2 files changed

+282
-20
lines changed

java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/ChannelFinder.java

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,14 @@
3636
import java.util.List;
3737
import java.util.Objects;
3838
import java.util.Set;
39+
import java.util.concurrent.ConcurrentLinkedQueue;
3940
import java.util.concurrent.ExecutorService;
4041
import java.util.concurrent.LinkedBlockingQueue;
4142
import java.util.concurrent.ThreadLocalRandom;
4243
import java.util.concurrent.ThreadPoolExecutor;
4344
import java.util.concurrent.TimeUnit;
45+
import java.util.concurrent.atomic.AtomicBoolean;
4446
import java.util.concurrent.atomic.AtomicLong;
45-
import java.util.concurrent.atomic.AtomicReference;
4647
import java.util.function.Predicate;
4748
import javax.annotation.Nullable;
4849

@@ -62,8 +63,11 @@ public final class ChannelFinder {
6263
private final AtomicLong databaseId = new AtomicLong();
6364
private final KeyRecipeCache recipeCache = new KeyRecipeCache();
6465
private final KeyRangeCache rangeCache;
65-
private final AtomicReference<CacheUpdate> pendingUpdate = new AtomicReference<>();
66-
private volatile java.util.concurrent.CountDownLatch drainingLatch;
66+
private final ConcurrentLinkedQueue<PendingCacheUpdate> pendingUpdates =
67+
new ConcurrentLinkedQueue<>();
68+
private final AtomicBoolean drainScheduled = new AtomicBoolean();
69+
private volatile java.util.concurrent.CountDownLatch drainingLatch =
70+
new java.util.concurrent.CountDownLatch(0);
6771
@Nullable private final EndpointLifecycleManager lifecycleManager;
6872
@Nullable private final String finderKey;
6973

@@ -105,15 +109,24 @@ private static ExecutorService createCacheUpdatePool() {
105109
return executor;
106110
}
107111

112+
private static final class PendingCacheUpdate {
113+
private final CacheUpdate update;
114+
115+
private PendingCacheUpdate(CacheUpdate update) {
116+
this.update = update;
117+
}
118+
}
119+
108120
public void update(CacheUpdate update) {
109121
synchronized (updateLock) {
110122
long currentId = databaseId.get();
111-
if (currentId != update.getDatabaseId()) {
123+
long updateDatabaseId = update.getDatabaseId();
124+
if (updateDatabaseId != 0 && currentId != updateDatabaseId) {
112125
if (currentId != 0) {
113126
recipeCache.clear();
114127
rangeCache.clear();
115128
}
116-
databaseId.set(update.getDatabaseId());
129+
databaseId.set(updateDatabaseId);
117130
}
118131
if (update.hasKeyRecipes()) {
119132
recipeCache.addRecipes(update.getKeyRecipes());
@@ -141,10 +154,8 @@ public void update(CacheUpdate update) {
141154
}
142155

143156
public void updateAsync(CacheUpdate update) {
144-
// Replace any pending update atomically. Each CacheUpdate contains the full current state,
145-
// so intermediate updates can be safely dropped to prevent unbounded queue growth.
146-
if (pendingUpdate.getAndSet(update) == null) {
147-
// No previous pending update means no drain task is scheduled yet — submit one.
157+
pendingUpdates.add(new PendingCacheUpdate(update));
158+
if (drainScheduled.compareAndSet(false, true)) {
148159
java.util.concurrent.CountDownLatch latch = new java.util.concurrent.CountDownLatch(1);
149160
drainingLatch = latch;
150161
CACHE_UPDATE_POOL.execute(
@@ -159,9 +170,15 @@ public void updateAsync(CacheUpdate update) {
159170
}
160171

161172
private void drainPendingUpdate() {
162-
CacheUpdate toApply;
163-
while ((toApply = pendingUpdate.getAndSet(null)) != null) {
164-
update(toApply);
173+
while (true) {
174+
PendingCacheUpdate toApply;
175+
while ((toApply = pendingUpdates.poll()) != null) {
176+
update(toApply.update);
177+
}
178+
drainScheduled.set(false);
179+
if (pendingUpdates.isEmpty() || !drainScheduled.compareAndSet(false, true)) {
180+
return;
181+
}
165182
}
166183
}
167184

@@ -171,15 +188,19 @@ private void drainPendingUpdate() {
171188
*/
172189
@VisibleForTesting
173190
void awaitPendingUpdates() throws InterruptedException {
174-
// Spin until no pending update remains.
175191
long deadline = System.nanoTime() + java.util.concurrent.TimeUnit.SECONDS.toNanos(5);
176-
while (pendingUpdate.get() != null && System.nanoTime() < deadline) {
177-
Thread.sleep(1);
178-
}
179-
// Wait for the drain task to fully complete (including the update() call).
180-
java.util.concurrent.CountDownLatch latch = drainingLatch;
181-
if (latch != null) {
182-
latch.await(5, java.util.concurrent.TimeUnit.SECONDS);
192+
while (System.nanoTime() < deadline) {
193+
java.util.concurrent.CountDownLatch latch = drainingLatch;
194+
if (latch != null) {
195+
long remainingNanos = deadline - System.nanoTime();
196+
if (remainingNanos <= 0) {
197+
break;
198+
}
199+
latch.await(remainingNanos, java.util.concurrent.TimeUnit.NANOSECONDS);
200+
}
201+
if (pendingUpdates.isEmpty() && !drainScheduled.get()) {
202+
return;
203+
}
183204
}
184205
}
185206

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
/*
2+
* Copyright 2026 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.cloud.spanner.spi.v1;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
21+
import com.google.protobuf.ByteString;
22+
import com.google.spanner.v1.CacheUpdate;
23+
import com.google.spanner.v1.Group;
24+
import com.google.spanner.v1.Range;
25+
import com.google.spanner.v1.Tablet;
26+
import io.grpc.CallOptions;
27+
import io.grpc.ClientCall;
28+
import io.grpc.ManagedChannel;
29+
import io.grpc.MethodDescriptor;
30+
import java.lang.reflect.Field;
31+
import java.util.Map;
32+
import java.util.concurrent.ConcurrentHashMap;
33+
import java.util.concurrent.CountDownLatch;
34+
import java.util.concurrent.ExecutorService;
35+
import java.util.concurrent.TimeUnit;
36+
import java.util.concurrent.atomic.AtomicLong;
37+
import org.junit.Test;
38+
import org.junit.runner.RunWith;
39+
import org.junit.runners.JUnit4;
40+
41+
@RunWith(JUnit4.class)
42+
public class ChannelFinderTest {
43+
44+
@Test
45+
public void updateAsyncDrainsQueuedUpdatesInOrderWithoutDroppingAny() throws Exception {
46+
ExecutorService executor = cacheUpdatePool();
47+
int threadCount = maxCacheUpdateThreads();
48+
CountDownLatch workersStarted = new CountDownLatch(threadCount);
49+
CountDownLatch releaseWorkers = new CountDownLatch(1);
50+
51+
try {
52+
for (int i = 0; i < threadCount; i++) {
53+
executor.execute(
54+
() -> {
55+
workersStarted.countDown();
56+
try {
57+
releaseWorkers.await(5, TimeUnit.SECONDS);
58+
} catch (InterruptedException e) {
59+
Thread.currentThread().interrupt();
60+
}
61+
});
62+
}
63+
assertThat(workersStarted.await(5, TimeUnit.SECONDS)).isTrue();
64+
65+
ChannelFinder finder = new ChannelFinder(new FakeEndpointCache());
66+
int updateCount = 64;
67+
for (int i = 0; i < updateCount; i++) {
68+
finder.updateAsync(singleRangeUpdate(i));
69+
}
70+
71+
releaseWorkers.countDown();
72+
finder.awaitPendingUpdates();
73+
74+
assertThat(rangeCache(finder).size()).isEqualTo(updateCount);
75+
} finally {
76+
releaseWorkers.countDown();
77+
}
78+
}
79+
80+
@Test
81+
public void updateIgnoresZeroDatabaseIdAndKeepsExistingCache() throws Exception {
82+
ChannelFinder finder = new ChannelFinder(new FakeEndpointCache());
83+
finder.update(singleRangeUpdate(0));
84+
85+
finder.update(CacheUpdate.newBuilder().setDatabaseId(0L).build());
86+
87+
assertThat(databaseId(finder)).isEqualTo(7L);
88+
assertThat(rangeCache(finder).size()).isEqualTo(1);
89+
}
90+
91+
private static CacheUpdate singleRangeUpdate(int index) {
92+
String startKey = String.format("k%05d", index);
93+
String limitKey = String.format("k%05d", index + 1);
94+
long groupUid = index + 1L;
95+
return CacheUpdate.newBuilder()
96+
.setDatabaseId(7L)
97+
.addRange(
98+
Range.newBuilder()
99+
.setStartKey(bytes(startKey))
100+
.setLimitKey(bytes(limitKey))
101+
.setGroupUid(groupUid)
102+
.setSplitId(groupUid)
103+
.setGeneration(bytes("g")))
104+
.addGroup(
105+
Group.newBuilder()
106+
.setGroupUid(groupUid)
107+
.setGeneration(bytes("g"))
108+
.addTablets(
109+
Tablet.newBuilder()
110+
.setTabletUid(groupUid)
111+
.setServerAddress("server-" + index + ":1234")
112+
.setIncarnation(bytes("i"))
113+
.setDistance(0)))
114+
.build();
115+
}
116+
117+
private static ByteString bytes(String value) {
118+
return ByteString.copyFromUtf8(value);
119+
}
120+
121+
private static ExecutorService cacheUpdatePool() throws Exception {
122+
Field field = ChannelFinder.class.getDeclaredField("CACHE_UPDATE_POOL");
123+
field.setAccessible(true);
124+
return (ExecutorService) field.get(null);
125+
}
126+
127+
private static int maxCacheUpdateThreads() throws Exception {
128+
Field field = ChannelFinder.class.getDeclaredField("MAX_CACHE_UPDATE_THREADS");
129+
field.setAccessible(true);
130+
return field.getInt(null);
131+
}
132+
133+
private static long databaseId(ChannelFinder finder) throws Exception {
134+
Field field = ChannelFinder.class.getDeclaredField("databaseId");
135+
field.setAccessible(true);
136+
return ((AtomicLong) field.get(finder)).get();
137+
}
138+
139+
private static KeyRangeCache rangeCache(ChannelFinder finder) throws Exception {
140+
Field field = ChannelFinder.class.getDeclaredField("rangeCache");
141+
field.setAccessible(true);
142+
return (KeyRangeCache) field.get(finder);
143+
}
144+
145+
private static final class FakeEndpointCache implements ChannelEndpointCache {
146+
private final Map<String, FakeEndpoint> endpoints = new ConcurrentHashMap<>();
147+
private final FakeEndpoint defaultEndpoint = new FakeEndpoint("default");
148+
149+
@Override
150+
public ChannelEndpoint defaultChannel() {
151+
return defaultEndpoint;
152+
}
153+
154+
@Override
155+
public ChannelEndpoint get(String address) {
156+
return endpoints.computeIfAbsent(address, FakeEndpoint::new);
157+
}
158+
159+
@Override
160+
public ChannelEndpoint getIfPresent(String address) {
161+
return endpoints.computeIfAbsent(address, FakeEndpoint::new);
162+
}
163+
164+
@Override
165+
public void evict(String address) {
166+
endpoints.remove(address);
167+
}
168+
169+
@Override
170+
public void shutdown() {
171+
endpoints.clear();
172+
}
173+
}
174+
175+
private static final class FakeEndpoint implements ChannelEndpoint {
176+
private final String address;
177+
private final ManagedChannel channel = new FakeManagedChannel();
178+
179+
private FakeEndpoint(String address) {
180+
this.address = address;
181+
}
182+
183+
@Override
184+
public String getAddress() {
185+
return address;
186+
}
187+
188+
@Override
189+
public boolean isHealthy() {
190+
return true;
191+
}
192+
193+
@Override
194+
public boolean isTransientFailure() {
195+
return false;
196+
}
197+
198+
@Override
199+
public ManagedChannel getChannel() {
200+
return channel;
201+
}
202+
}
203+
204+
private static final class FakeManagedChannel extends ManagedChannel {
205+
@Override
206+
public ManagedChannel shutdown() {
207+
return this;
208+
}
209+
210+
@Override
211+
public ManagedChannel shutdownNow() {
212+
return this;
213+
}
214+
215+
@Override
216+
public boolean isShutdown() {
217+
return false;
218+
}
219+
220+
@Override
221+
public boolean isTerminated() {
222+
return false;
223+
}
224+
225+
@Override
226+
public boolean awaitTermination(long timeout, TimeUnit unit) {
227+
return true;
228+
}
229+
230+
@Override
231+
public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
232+
MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) {
233+
throw new UnsupportedOperationException();
234+
}
235+
236+
@Override
237+
public String authority() {
238+
return "fake";
239+
}
240+
}
241+
}

0 commit comments

Comments
 (0)