Skip to content

Commit 05f4308

Browse files
Proxy MongoCluster for passing in ClientSession if needed
1 parent 51041f4 commit 05f4308

5 files changed

Lines changed: 36 additions & 11 deletions

File tree

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/SessionAwareMethodInterceptor.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ public class SessionAwareMethodInterceptor<D, C> implements MethodInterceptor {
5454
private final ClientSessionOperator databaseDecorator;
5555
private final Object target;
5656
private final Class<?> targetType;
57+
private final Class<?> clientType;
5758
private final Class<?> collectionType;
5859
private final Class<?> databaseType;
5960
private final Class<? extends ClientSession> sessionType;
@@ -71,7 +72,7 @@ public class SessionAwareMethodInterceptor<D, C> implements MethodInterceptor {
7172
* {@code MongoCollection}.
7273
* @param <T> target object type.
7374
*/
74-
public <T> SessionAwareMethodInterceptor(ClientSession session, T target, Class<? extends ClientSession> sessionType,
75+
public <T> SessionAwareMethodInterceptor(ClientSession session, T target, Class<?> clientType, Class<? extends ClientSession> sessionType,
7576
Class<D> databaseType, ClientSessionOperator<D> databaseDecorator, Class<C> collectionType,
7677
ClientSessionOperator<C> collectionDecorator) {
7778

@@ -85,15 +86,24 @@ public <T> SessionAwareMethodInterceptor(ClientSession session, T target, Class<
8586

8687
this.session = session;
8788
this.target = target;
89+
this.clientType = ClassUtils.getUserClass(clientType);
8890
this.databaseType = ClassUtils.getUserClass(databaseType);
8991
this.collectionType = ClassUtils.getUserClass(collectionType);
9092
this.collectionDecorator = collectionDecorator;
9193
this.databaseDecorator = databaseDecorator;
9294

93-
this.targetType = ClassUtils.isAssignable(databaseType, target.getClass()) ? databaseType : collectionType;
95+
this.targetType = targetType(target.getClass());
9496
this.sessionType = sessionType;
9597
}
9698

99+
Class<?> targetType(@Nullable Class<?> targetType) {
100+
101+
if(ClassUtils.isAssignable(clientType, targetType)) {
102+
return clientType;
103+
}
104+
return ClassUtils.isAssignable(databaseType, target.getClass()) ? databaseType : collectionType;
105+
}
106+
97107
@Override
98108
public @Nullable Object invoke(MethodInvocation methodInvocation) throws Throwable {
99109

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoDatabaseFactorySupport.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515
*/
1616
package org.springframework.data.mongodb.core;
1717

18-
import com.mongodb.client.MongoCluster;
1918
import org.jspecify.annotations.Nullable;
20-
2119
import org.springframework.aop.framework.ProxyFactory;
2220
import org.springframework.dao.DataAccessException;
2321
import org.springframework.dao.support.PersistenceExceptionTranslator;
@@ -29,6 +27,7 @@
2927
import com.mongodb.ClientSessionOptions;
3028
import com.mongodb.WriteConcern;
3129
import com.mongodb.client.ClientSession;
30+
import com.mongodb.client.MongoCluster;
3231
import com.mongodb.client.MongoCollection;
3332
import com.mongodb.client.MongoDatabase;
3433

@@ -126,7 +125,6 @@ public MongoDatabase getMongoDatabase(String dbName) throws DataAccessException
126125
*/
127126
protected abstract MongoDatabase doGetMongoDatabase(String dbName);
128127

129-
130128
public void destroy() throws Exception {
131129
if (mongoInstanceCreated) {
132130
closeClient();
@@ -170,9 +168,7 @@ record ClientSessionBoundMongoDbFactory(ClientSession session,
170168

171169
@Override
172170
public MongoCluster getCluster() {
173-
174-
// TODO: we need to proxy the cluster and methods that accept a client session
175-
return delegate.getCluster();
171+
return proxyMongoCluster(delegate.getCluster());
176172
}
177173

178174
@Override
@@ -209,6 +205,10 @@ private MongoDatabase proxyMongoDatabase(MongoDatabase database) {
209205
return createProxyInstance(session, database, MongoDatabase.class);
210206
}
211207

208+
private MongoCluster proxyMongoCluster(MongoCluster cluster) {
209+
return createProxyInstance(session, cluster, MongoCluster.class);
210+
}
211+
212212
private MongoDatabase proxyDatabase(com.mongodb.session.ClientSession session, MongoDatabase database) {
213213
return createProxyInstance(session, database, MongoDatabase.class);
214214
}
@@ -225,7 +225,7 @@ private <T> T createProxyInstance(com.mongodb.session.ClientSession session, T t
225225
factory.setInterfaces(targetType);
226226
factory.setOpaque(true);
227227

228-
factory.addAdvice(new SessionAwareMethodInterceptor<>(session, target, ClientSession.class, MongoDatabase.class,
228+
factory.addAdvice(new SessionAwareMethodInterceptor<>(session, target, MongoCluster.class, ClientSession.class, MongoDatabase.class,
229229
this::proxyDatabase, MongoCollection.class, this::proxyCollection));
230230

231231
return targetType.cast(factory.getProxy(target.getClass().getClassLoader()));

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/SimpleReactiveMongoDatabaseFactory.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
package org.springframework.data.mongodb.core;
1717

18+
import com.mongodb.reactivestreams.client.MongoCluster;
1819
import reactor.core.publisher.Mono;
1920

2021
import org.bson.codecs.configuration.CodecRegistry;
@@ -226,7 +227,7 @@ private <T> T createProxyInstance(com.mongodb.session.ClientSession session, T t
226227
factory.setInterfaces(targetType);
227228
factory.setOpaque(true);
228229

229-
factory.addAdvice(new SessionAwareMethodInterceptor<>(session, target, ClientSession.class, MongoDatabase.class,
230+
factory.addAdvice(new SessionAwareMethodInterceptor<>(session, target, MongoCluster.class, ClientSession.class, MongoDatabase.class,
230231
this::proxyDatabase, MongoCollection.class, this::proxyCollection));
231232

232233
return targetType.cast(factory.getProxy(target.getClass().getClassLoader()));

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/SessionAwareMethodInterceptorUnitTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.lang.reflect.Method;
2323
import java.lang.reflect.Proxy;
2424

25+
import com.mongodb.client.MongoCluster;
2526
import org.bson.Document;
2627
import org.junit.jupiter.api.BeforeEach;
2728
import org.junit.jupiter.api.Test;
@@ -175,7 +176,7 @@ private <T> T createProxyInstance(com.mongodb.session.ClientSession session, T t
175176
factory.setInterfaces(targetType);
176177
factory.setOpaque(true);
177178

178-
factory.addAdvice(new SessionAwareMethodInterceptor<>(session, target, ClientSession.class, MongoDatabase.class,
179+
factory.addAdvice(new SessionAwareMethodInterceptor<>(session, target, MongoCluster.class, ClientSession.class, MongoDatabase.class,
179180
this::proxyDatabase, MongoCollection.class, this::proxyCollection));
180181

181182
return targetType.cast(factory.getProxy());

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/SimpleMongoClientDatabaseFactoryUnitTests.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import java.lang.reflect.InvocationHandler;
2222
import java.lang.reflect.Proxy;
23+
import java.util.List;
2324

2425
import org.junit.jupiter.api.Test;
2526
import org.junit.jupiter.api.extension.ExtendWith;
@@ -103,6 +104,18 @@ void cascadedWithSessionUsesRootFactory() {
103104
assertThat(singletonTarget).isSameAs(database);
104105
}
105106

107+
@Test // GH-5087
108+
void passesOnClientSessionWhenInvokingMethodsOnMongoCluster() {
109+
110+
MongoDatabaseFactory factory = MongoDatabaseFactory.create(mongo, "foo");
111+
MongoDatabaseFactory wrapped = factory.withSession(clientSession);
112+
113+
assertThat(wrapped.getCluster()).isInstanceOf(Proxy.class);
114+
wrapped.getCluster().bulkWrite(List.of());
115+
116+
verify(mongo).bulkWrite(eq(clientSession), any());
117+
}
118+
106119
private void rejectsDatabaseName(String databaseName) {
107120
assertThatThrownBy(() -> new SimpleMongoClientDatabaseFactory(mongo, databaseName))
108121
.isInstanceOf(IllegalArgumentException.class);

0 commit comments

Comments
 (0)