Skip to content

Commit 9f7a57f

Browse files
committed
feat: 웹소켓 메시지 전송 시 JWT 만료 여부 실시간 검증 추가 (#223)
1 parent 891b0aa commit 9f7a57f

2 files changed

Lines changed: 156 additions & 51 deletions

File tree

src/main/java/com/sofa/linkiving/security/config/StompHandler.java

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
import com.sofa.linkiving.security.jwt.error.JwtErrorCode;
2121

2222
import lombok.RequiredArgsConstructor;
23+
import lombok.extern.slf4j.Slf4j;
2324

25+
@Slf4j
2426
@Component
2527
@RequiredArgsConstructor
2628
@Order(Ordered.HIGHEST_PRECEDENCE + 99)
@@ -32,37 +34,50 @@ public class StompHandler implements ChannelInterceptor {
3234
public Message<?> preSend(Message<?> message, MessageChannel channel) {
3335
StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
3436

35-
if (accessor != null && StompCommand.CONNECT.equals(accessor.getCommand())) {
36-
String token = null;
37+
if (accessor != null) {
38+
StompCommand command = accessor.getCommand();
3739

38-
Map<String, Object> sessionAttributes = accessor.getSessionAttributes();
39-
if (sessionAttributes != null && sessionAttributes.containsKey("accessToken")) {
40-
token = (String)sessionAttributes.get("accessToken");
41-
}
42-
43-
if (token == null) {
44-
String authorizationHeader = accessor.getFirstNativeHeader(JwtKeys.Headers.AUTHORIZATION);
40+
if (StompCommand.CONNECT.equals(command) || StompCommand.SEND.equals(command)) {
41+
String token = null;
42+
Map<String, Object> sessionAttributes = accessor.getSessionAttributes();
4543

46-
if (authorizationHeader != null && authorizationHeader.startsWith(JwtKeys.Headers.BEARER_PREFIX)) {
47-
token = authorizationHeader.substring(JwtKeys.Headers.BEARER_PREFIX.length());
44+
if (sessionAttributes != null && sessionAttributes.containsKey("accessToken")) {
45+
token = (String)sessionAttributes.get("accessToken");
4846
}
49-
}
5047

51-
try {
52-
if (token == null) {
53-
throw new BusinessException(JwtErrorCode.EMPTY_TOKEN);
54-
}
48+
if (token == null && StompCommand.CONNECT.equals(command)) {
49+
String authorizationHeader = accessor.getFirstNativeHeader(JwtKeys.Headers.AUTHORIZATION);
50+
51+
if (authorizationHeader != null && authorizationHeader.startsWith(JwtKeys.Headers.BEARER_PREFIX)) {
52+
token = authorizationHeader.substring(JwtKeys.Headers.BEARER_PREFIX.length());
5553

56-
if (jwtTokenProvider.validateAccessToken(token)) {
57-
Authentication authentication = jwtTokenProvider.getAuthentication(token);
58-
accessor.setUser(authentication);
54+
if (sessionAttributes != null) {
55+
sessionAttributes.put("accessToken", token);
56+
}
57+
}
5958
}
6059

61-
} catch (BusinessException e) {
62-
throw new MessagingException(e.getMessage());
60+
try {
61+
if (token == null) {
62+
throw new BusinessException(JwtErrorCode.EMPTY_TOKEN);
63+
}
6364

64-
} catch (Exception e) {
65-
throw new MessagingException("서버 내부 오류로 연결에 실패했습니다.");
65+
if (jwtTokenProvider.validateAccessToken(token)) {
66+
67+
if (StompCommand.CONNECT.equals(command)) {
68+
Authentication authentication = jwtTokenProvider.getAuthentication(token);
69+
accessor.setUser(authentication);
70+
}
71+
}
72+
73+
} catch (BusinessException e) {
74+
log.warn("웹소켓 인증/만료 에러 차단: {}", e.getMessage());
75+
throw new MessagingException(e.getMessage());
76+
77+
} catch (Exception e) {
78+
log.error("웹소켓 서버 내부 오류", e);
79+
throw new MessagingException("서버 내부 오류로 연결에 실패했습니다.");
80+
}
6681
}
6782
}
6883

src/test/java/com/sofa/linkiving/domain/chat/integration/WebSocketChatIntegrationTest.java

Lines changed: 118 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
import java.lang.reflect.Type;
88
import java.util.ArrayList;
99
import java.util.Collections;
10+
import java.util.HashMap;
1011
import java.util.List;
12+
import java.util.Map;
1113
import java.util.concurrent.BlockingQueue;
1214
import java.util.concurrent.ExecutionException;
1315
import java.util.concurrent.LinkedBlockingQueue;
@@ -20,6 +22,7 @@
2022
import org.springframework.beans.factory.annotation.Autowired;
2123
import org.springframework.boot.test.context.SpringBootTest;
2224
import org.springframework.boot.test.web.server.LocalServerPort;
25+
import org.springframework.http.HttpHeaders;
2326
import org.springframework.messaging.converter.MappingJackson2MessageConverter;
2427
import org.springframework.messaging.simp.stomp.StompFrameHandler;
2528
import org.springframework.messaging.simp.stomp.StompHeaders;
@@ -28,6 +31,7 @@
2831
import org.springframework.test.annotation.DirtiesContext;
2932
import org.springframework.test.context.ActiveProfiles;
3033
import org.springframework.test.context.bean.override.mockito.MockitoBean;
34+
import org.springframework.test.context.bean.override.mockito.MockitoSpyBean;
3135
import org.springframework.web.socket.WebSocketHttpHeaders;
3236
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
3337
import org.springframework.web.socket.messaging.WebSocketStompClient;
@@ -38,9 +42,6 @@
3842
import com.fasterxml.jackson.databind.ObjectMapper;
3943
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
4044
import com.sofa.linkiving.domain.chat.ai.AnswerClient;
41-
import com.sofa.linkiving.domain.chat.dto.request.AnswerCancelReq;
42-
import com.sofa.linkiving.domain.chat.dto.request.AnswerReq;
43-
import com.sofa.linkiving.domain.chat.dto.response.AnswerRes;
4445
import com.sofa.linkiving.domain.chat.dto.response.RagAnswerRes;
4546
import com.sofa.linkiving.domain.chat.entity.Chat;
4647
import com.sofa.linkiving.domain.chat.repository.ChatRepository;
@@ -51,6 +52,10 @@
5152
import com.sofa.linkiving.domain.member.repository.MemberRepository;
5253
import com.sofa.linkiving.infra.redis.RedisService;
5354
import com.sofa.linkiving.security.jwt.JwtTokenProvider;
55+
import com.sofa.linkiving.security.jwt.error.CustomJwtException;
56+
import com.sofa.linkiving.security.jwt.error.JwtErrorCode;
57+
58+
import jakarta.annotation.Nonnull;
5459

5560
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
5661
@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_EACH_TEST_METHOD)
@@ -78,15 +83,16 @@ public class WebSocketChatIntegrationTest {
7883
@MockitoBean
7984
private RedisService redisService;
8085

81-
@Autowired
86+
@MockitoSpyBean
8287
private JwtTokenProvider jwtTokenProvider;
8388

8489
@MockitoBean
8590
private AnswerClient answerClient;
8691

8792
private StompSession stompSession;
88-
private BlockingQueue<AnswerRes> blockingQueue;
93+
private BlockingQueue<Map<String, Object>> blockingQueue;
8994
private Chat testChat;
95+
private Member testMember;
9096

9197
@BeforeEach
9298
void setUp() throws ExecutionException, InterruptedException, TimeoutException {
@@ -97,7 +103,7 @@ void setUp() throws ExecutionException, InterruptedException, TimeoutException {
97103
memberRepository.deleteAllInBatch();
98104

99105
// 1. 데이터 초기화
100-
Member testMember = memberRepository.save(Member.builder()
106+
testMember = memberRepository.save(Member.builder()
101107
.email("test@test.com")
102108
.password("password")
103109
.build());
@@ -127,14 +133,19 @@ void setUp() throws ExecutionException, InterruptedException, TimeoutException {
127133

128134
// 3. WebSocket 연결
129135
String wsUrl = "ws://localhost:" + port + "/ws/chat";
130-
StompHeaders headers = new StompHeaders();
131136

132137
String validToken = jwtTokenProvider.createAccessToken(testMember.getEmail());
133-
headers.add("Authorization", "Bearer " + validToken);
138+
String authHeaderValue = "Bearer " + validToken;
139+
140+
StompHeaders headers = new StompHeaders();
141+
headers.add("Authorization", authHeaderValue);
142+
143+
WebSocketHttpHeaders webSocketHttpHeaders = new WebSocketHttpHeaders();
144+
webSocketHttpHeaders.add(HttpHeaders.COOKIE, "accessToken=" + validToken);
134145

135146
this.stompSession = stompClient.connectAsync(
136147
wsUrl,
137-
new WebSocketHttpHeaders(),
148+
webSocketHttpHeaders,
138149
headers,
139150
new StompSessionHandlerAdapter() {
140151
}
@@ -158,38 +169,112 @@ private List<Transport> createTransportClient() {
158169

159170
private void subscribeToChatQueue() {
160171
stompSession.subscribe("/user/queue/chat", new StompFrameHandler() {
172+
@Nonnull
161173
@Override
162-
public Type getPayloadType(StompHeaders headers) {
163-
return AnswerRes.class;
174+
public Type getPayloadType(@Nonnull StompHeaders headers) {
175+
return Map.class;
164176
}
165177

166178
@Override
167-
public void handleFrame(StompHeaders headers, Object payload) {
168-
blockingQueue.offer((AnswerRes)payload);
179+
public void handleFrame(@Nonnull StompHeaders headers, Object payload) {
180+
if (payload != null) {
181+
blockingQueue.add((Map<String, Object>)payload);
182+
}
169183
}
170184
});
171185
}
172186

173187
@Test
174-
@DisplayName("유저가 메시지를 보내면 AI 응답이 Queue로 수신된다")
188+
@DisplayName("쿠키에 담긴 accessToken으로 정상적으로 웹소켓 연결이 가능하다")
189+
void shouldConnectSuccessfullyWithCookie() throws Exception {
190+
// given
191+
WebSocketStompClient customClient = new WebSocketStompClient(new SockJsClient(createTransportClient()));
192+
String validToken = jwtTokenProvider.createAccessToken(testMember.getEmail());
193+
194+
WebSocketHttpHeaders httpHeaders = new WebSocketHttpHeaders();
195+
httpHeaders.add(HttpHeaders.COOKIE, "accessToken=" + validToken);
196+
197+
// when
198+
StompSession session = customClient.connectAsync(
199+
"ws://localhost:" + port + "/ws/chat",
200+
httpHeaders,
201+
new StompHeaders(),
202+
new StompSessionHandlerAdapter() {
203+
}
204+
).get(5, SECONDS);
205+
206+
// then
207+
assertThat(session.isConnected()).isTrue();
208+
session.disconnect();
209+
}
210+
211+
@Test
212+
@DisplayName("쿠키(토큰)가 없거나 유효하지 않으면 웹소켓 연결에 실패한다 (401)")
213+
void shouldFailToConnectWithInvalidToken() {
214+
// given
215+
WebSocketStompClient customClient = new WebSocketStompClient(new SockJsClient(createTransportClient()));
216+
217+
// 잘못된 쿠키 세팅
218+
WebSocketHttpHeaders httpHeaders = new WebSocketHttpHeaders();
219+
httpHeaders.add(HttpHeaders.COOKIE, "accessToken=invalid_token_value");
220+
221+
// when & then
222+
assertThatThrownBy(() -> customClient.connectAsync(
223+
"ws://localhost:" + port + "/ws/chat",
224+
httpHeaders,
225+
new StompHeaders(),
226+
new StompSessionHandlerAdapter() {
227+
}
228+
).get(5, SECONDS))
229+
.isInstanceOf(ExecutionException.class);
230+
}
231+
232+
@Test
233+
@DisplayName("연결 이후 SEND 요청 시 인증 토큰이 만료/조작된 경우 메시지를 차단하고 ERROR 프레임을 반환한다")
234+
void shouldFailToSendWhenTokenIsInvalid() throws Exception {
235+
// given
236+
subscribeToChatQueue();
237+
Thread.sleep(1000);
238+
239+
// when
240+
doThrow(new CustomJwtException(JwtErrorCode.EXPIRED_JWT_TOKEN))
241+
.when(jwtTokenProvider).validateAccessToken(anyString());
242+
243+
Long chatId = testChat.getId();
244+
Map<String, Object> req = new HashMap<>();
245+
req.put("chatId", chatId);
246+
req.put("message", "만료된 토큰으로 전송 시도");
247+
248+
stompSession.send("/ws/chat/send", req);
249+
250+
// then
251+
Thread.sleep(1000);
252+
verify(answerClient, never()).generateAnswer(any());
253+
}
254+
255+
@Test
256+
@DisplayName("정상적인 유저가 메시지를 보내면 AI 응답이 Queue로 수신된다")
175257
void shouldReceiveAnswerWhenMessageSent() throws InterruptedException {
176258
// given
177259
Long chatId = testChat.getId();
178260
String userMessage = "Gemini에 대해 알려줘";
179-
AnswerReq req = new AnswerReq(chatId, userMessage);
261+
262+
Map<String, Object> req = new HashMap<>();
263+
req.put("chatId", chatId);
264+
req.put("message", userMessage);
180265

181266
subscribeToChatQueue();
267+
Thread.sleep(1000);
182268

183269
// when
184270
stompSession.send("/ws/chat/send", req);
185271

186272
// then
187-
AnswerRes received = blockingQueue.poll(10, SECONDS);
273+
Map<String, Object> received = blockingQueue.poll(10, SECONDS);
188274

189275
assertThat(received).isNotNull();
190-
assertThat(received.chatId()).isEqualTo(chatId);
191-
assertThat(received.success()).isTrue();
192-
assertThat(received.content()).contains("Gemini와 관련된 내용");
276+
assertThat(received.get("success")).isEqualTo(true);
277+
assertThat(String.valueOf(received.get("content"))).contains("Gemini와 관련된 내용");
193278
}
194279

195280
@Test
@@ -198,29 +283,34 @@ void shouldReceiveErrorMessageWhenCancelled() throws InterruptedException {
198283
// given
199284
Long chatId = testChat.getId();
200285
String userMessage = "취소될 질문";
201-
AnswerReq sendReq = new AnswerReq(chatId, userMessage);
202-
AnswerCancelReq cancelReq = new AnswerCancelReq(chatId);
286+
287+
Map<String, Object> sendReq = new HashMap<>();
288+
sendReq.put("chatId", chatId);
289+
sendReq.put("message", userMessage);
290+
291+
Map<String, Object> cancelReq = new HashMap<>();
292+
cancelReq.put("chatId", chatId);
203293

204294
given(answerClient.generateAnswer(any())).willAnswer(invocation -> {
205-
Thread.sleep(500);
295+
Thread.sleep(1500);
206296
return new RagAnswerRes("지연된 답변", Collections.emptyList(), Collections.emptyList(), Collections.emptyList(),
207297
true);
208298
});
209299

210300
subscribeToChatQueue();
301+
Thread.sleep(1000);
211302

212303
// when
213304
stompSession.send("/ws/chat/send", sendReq);
214-
Thread.sleep(50);
305+
Thread.sleep(300);
215306
stompSession.send("/ws/chat/cancel", cancelReq);
216307

217308
// then
218-
AnswerRes received = blockingQueue.poll(5, SECONDS);
309+
Map<String, Object> received = blockingQueue.poll(10, SECONDS);
219310

220311
assertThat(received).isNotNull();
221-
assertThat(received.chatId()).isEqualTo(chatId);
222-
assertThat(received.success()).isFalse();
223-
assertThat(received.content()).isEqualTo(userMessage);
312+
assertThat(received.get("success")).isEqualTo(false);
313+
assertThat(String.valueOf(received.get("content"))).isEqualTo(userMessage);
224314
}
225-
226315
}
316+

0 commit comments

Comments
 (0)