Skip to content

Commit 13f34b4

Browse files
marcin-nowicki-plspring-builds
authored andcommitted
GH-10931: Add ServerWebSocketContainer.setAllowedOriginPatterns()
Fixes: #10931 * Add `originPatterns` field to `ServerWebSocketContainer` * Add `setAllowedOriginPatterns(String...)` * Update `registerWebSocketHandlers()` to propagate `setAllowedOriginPatterns()` when configured * Add unit tests for both code paths Signed-off-by: Marcin Nowicki <marcin.nowicki.poczta@gmail.com> (cherry picked from commit d90215b)
1 parent 558f384 commit 13f34b4

2 files changed

Lines changed: 110 additions & 2 deletions

File tree

spring-integration-websocket/src/main/java/org/springframework/integration/websocket/ServerWebSocketContainer.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
* @author Gary Russell
5252
* @author Christian Tzolov
5353
* @author Jooyoung Pyoung
54+
* @author Marcin Nowicki
5455
*
5556
* @since 4.1
5657
*/
@@ -69,6 +70,8 @@ public class ServerWebSocketContainer extends IntegrationWebSocketContainer
6970

7071
private String[] origins = {};
7172

73+
private String[] originPatterns = {};
74+
7275
private boolean autoStartup = true;
7376

7477
private int phase = 0;
@@ -121,6 +124,18 @@ public ServerWebSocketContainer setAllowedOrigins(String... origins) {
121124
return this;
122125
}
123126

127+
/**
128+
* Specify origin patterns for which cross-origin requests are allowed from a browser.
129+
* @param originPatterns the origin patterns to allow.
130+
* @return the current ServerWebSocketContainer
131+
* @since 7.0.5
132+
* @see WebSocketHandlerRegistration#setAllowedOriginPatterns(String...)
133+
*/
134+
public ServerWebSocketContainer setAllowedOriginPatterns(String... originPatterns) {
135+
this.originPatterns = Arrays.copyOf(originPatterns, originPatterns.length);
136+
return this;
137+
}
138+
124139
public ServerWebSocketContainer withSockJs(SockJsServiceOptions... sockJsServiceOptions) {
125140
if (ObjectUtils.isEmpty(sockJsServiceOptions)) {
126141
setSockJsServiceOptions(new SockJsServiceOptions());
@@ -164,8 +179,14 @@ public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
164179

165180
WebSocketHandlerRegistration registration =
166181
registry.addHandler(webSocketHandler, this.paths)
167-
.addInterceptors(this.interceptors)
168-
.setAllowedOrigins(this.origins);
182+
.addInterceptors(this.interceptors);
183+
184+
if (!ObjectUtils.isEmpty(this.originPatterns)) {
185+
registration.setAllowedOriginPatterns(this.originPatterns);
186+
}
187+
if (!ObjectUtils.isEmpty(this.origins)) {
188+
registration.setAllowedOrigins(this.origins);
189+
}
169190

170191
if (this.handshakeHandler != null) {
171192
registration.setHandshakeHandler(this.handshakeHandler);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Copyright 2026-present the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.integration.websocket.server;
18+
19+
import org.junit.jupiter.api.Test;
20+
import org.mockito.Answers;
21+
22+
import org.springframework.integration.websocket.ServerWebSocketContainer;
23+
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistration;
24+
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
25+
26+
import static org.mockito.ArgumentMatchers.any;
27+
import static org.mockito.BDDMockito.given;
28+
import static org.mockito.Mockito.mock;
29+
import static org.mockito.Mockito.never;
30+
import static org.mockito.Mockito.verify;
31+
32+
/**
33+
* Tests for {@link ServerWebSocketContainer}.
34+
*
35+
* @author Marcin Nowicki
36+
*
37+
* @since 7.0.5
38+
*/
39+
public class ServerWebSocketContainerTests {
40+
41+
@Test
42+
public void setAllowedOriginPatternsIsUsedInRegistration() {
43+
WebSocketHandlerRegistry registry = mock();
44+
WebSocketHandlerRegistration registration = mock(Answers.RETURNS_SELF);
45+
given(registry.addHandler(any(), any(String[].class))).willReturn(registration);
46+
47+
ServerWebSocketContainer container = new ServerWebSocketContainer("/test")
48+
.setAllowedOriginPatterns("https://example.com");
49+
50+
container.registerWebSocketHandlers(registry);
51+
52+
verify(registration).setAllowedOriginPatterns("https://example.com");
53+
verify(registration, never()).setAllowedOrigins(any());
54+
}
55+
56+
@Test
57+
public void setAllowedOriginsIsUsedWhenNoPatternsConfigured() {
58+
WebSocketHandlerRegistry registry = mock();
59+
WebSocketHandlerRegistration registration = mock(Answers.RETURNS_SELF);
60+
given(registry.addHandler(any(), any(String[].class))).willReturn(registration);
61+
62+
ServerWebSocketContainer container = new ServerWebSocketContainer("/test")
63+
.setAllowedOrigins("https://example.com");
64+
65+
container.registerWebSocketHandlers(registry);
66+
67+
verify(registration).setAllowedOrigins("https://example.com");
68+
verify(registration, never()).setAllowedOriginPatterns(any());
69+
}
70+
71+
@Test
72+
public void bothOriginsAndOriginPatternsAreUsedWhenBothConfigured() {
73+
WebSocketHandlerRegistry registry = mock();
74+
WebSocketHandlerRegistration registration = mock(Answers.RETURNS_SELF);
75+
given(registry.addHandler(any(), any(String[].class))).willReturn(registration);
76+
77+
ServerWebSocketContainer container = new ServerWebSocketContainer("/test")
78+
.setAllowedOrigins("https://example.com")
79+
.setAllowedOriginPatterns("https://*.example.com");
80+
81+
container.registerWebSocketHandlers(registry);
82+
83+
verify(registration).setAllowedOriginPatterns("https://*.example.com");
84+
verify(registration).setAllowedOrigins("https://example.com");
85+
}
86+
87+
}

0 commit comments

Comments
 (0)