diff --git a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java index dcc0f7f722a..a23bbaa7b0c 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java @@ -552,6 +552,7 @@ public void postWhenCustomRequestHandlerThenUsed() { given(this.csrfTokenRepository.loadToken(any(ServerWebExchange.class))).willReturn(Mono.just(csrfToken)); given(this.csrfTokenRepository.generateToken(any(ServerWebExchange.class))).willReturn(Mono.empty()); ServerCsrfTokenRequestHandler requestHandler = mock(ServerCsrfTokenRequestHandler.class); + given(requestHandler.handleAsync(any(ServerWebExchange.class), any())).willReturn(Mono.empty()); given(requestHandler.resolveCsrfTokenValue(any(ServerWebExchange.class), any(CsrfToken.class))) .willReturn(Mono.just(csrfToken.getToken())); // @formatter:off @@ -564,7 +565,7 @@ public void postWhenCustomRequestHandlerThenUsed() { client.post().uri("/").exchange().expectStatus().isOk(); verify(this.csrfTokenRepository, times(2)).loadToken(any(ServerWebExchange.class)); verify(this.csrfTokenRepository).generateToken(any(ServerWebExchange.class)); - verify(requestHandler).handle(any(ServerWebExchange.class), any()); + verify(requestHandler).handleAsync(any(ServerWebExchange.class), any()); verify(requestHandler).resolveCsrfTokenValue(any(ServerWebExchange.class), any()); } diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java index 2c8a97f56b4..c982dbae9eb 100644 --- a/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java +++ b/web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java @@ -60,6 +60,7 @@ * @author Rob Winch * @author Parikshit Dutta * @author Steve Riesenberg + * @author Andrey Litvitski * @since 5.0 */ public class CsrfWebFilter implements WebFilter { @@ -147,8 +148,7 @@ private Mono containsValidCsrfToken(ServerWebExchange exchange, CsrfTok private Mono continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) { return Mono.defer(() -> { Mono csrfToken = csrfToken(exchange); - this.requestHandler.handle(exchange, csrfToken); - return chain.filter(exchange); + return this.requestHandler.handleAsync(exchange, csrfToken).then(chain.filter(exchange)); }); } diff --git a/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestHandler.java b/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestHandler.java index 86cd026764d..50b3cc8e269 100644 --- a/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestHandler.java +++ b/web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRequestHandler.java @@ -29,6 +29,7 @@ * made available to the application through exchange attributes. * * @author Steve Riesenberg + * @author Andrey Litvitski * @since 5.8 * @see ServerCsrfTokenRequestAttributeHandler */ @@ -40,9 +41,23 @@ public interface ServerCsrfTokenRequestHandler extends ServerCsrfTokenRequestRes * @param exchange the {@code ServerWebExchange} with the request being handled * @param csrfToken the {@code Mono} created by the * {@link ServerCsrfTokenRepository} + * @deprecated since 7.0 in favor of {@link #handleAsync(ServerWebExchange, Mono)} */ + @Deprecated(since = "7.0", forRemoval = true) void handle(ServerWebExchange exchange, Mono csrfToken); + /** + * Handles a request using a {@link CsrfToken}. + * @param exchange the {@code ServerWebExchange} with the request being handled + * @param csrfToken the {@code Mono} created by the + * {@link ServerCsrfTokenRepository} + * @return a {@code Mono} that completes when handling is finished + */ + default Mono handleAsync(ServerWebExchange exchange, Mono csrfToken) { + handle(exchange, csrfToken); + return Mono.empty(); + } + @Override default Mono resolveCsrfTokenValue(ServerWebExchange exchange, CsrfToken csrfToken) { Assert.notNull(exchange, "exchange cannot be null"); diff --git a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java index 46be524d366..80f431fd437 100644 --- a/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java @@ -164,6 +164,7 @@ public void filterWhenPostAndEstablishedCsrfTokenAndHeaderValidTokenThenContinue @Test public void filterWhenRequestHandlerSetThenUsed() { ServerCsrfTokenRequestHandler requestHandler = mock(ServerCsrfTokenRequestHandler.class); + given(requestHandler.handleAsync(any(ServerWebExchange.class), any())).willReturn(Mono.empty()); given(requestHandler.resolveCsrfTokenValue(any(ServerWebExchange.class), any(CsrfToken.class))) .willReturn(Mono.just(this.token.getToken())); this.csrfFilter.setRequestHandler(requestHandler); @@ -179,7 +180,7 @@ public void filterWhenRequestHandlerSetThenUsed() { StepVerifier.create(result).verifyComplete(); chainResult.assertWasSubscribed(); - verify(requestHandler).handle(eq(this.post), any()); + verify(requestHandler).handleAsync(eq(this.post), any()); verify(requestHandler).resolveCsrfTokenValue(this.post, this.token); }