77import java .lang .reflect .Type ;
88import java .util .ArrayList ;
99import java .util .Collections ;
10+ import java .util .HashMap ;
1011import java .util .List ;
12+ import java .util .Map ;
1113import java .util .concurrent .BlockingQueue ;
1214import java .util .concurrent .ExecutionException ;
1315import java .util .concurrent .LinkedBlockingQueue ;
2022import org .springframework .beans .factory .annotation .Autowired ;
2123import org .springframework .boot .test .context .SpringBootTest ;
2224import org .springframework .boot .test .web .server .LocalServerPort ;
25+ import org .springframework .http .HttpHeaders ;
2326import org .springframework .messaging .converter .MappingJackson2MessageConverter ;
2427import org .springframework .messaging .simp .stomp .StompFrameHandler ;
2528import org .springframework .messaging .simp .stomp .StompHeaders ;
2831import org .springframework .test .annotation .DirtiesContext ;
2932import org .springframework .test .context .ActiveProfiles ;
3033import org .springframework .test .context .bean .override .mockito .MockitoBean ;
34+ import org .springframework .test .context .bean .override .mockito .MockitoSpyBean ;
3135import org .springframework .web .socket .WebSocketHttpHeaders ;
3236import org .springframework .web .socket .client .standard .StandardWebSocketClient ;
3337import org .springframework .web .socket .messaging .WebSocketStompClient ;
3842import com .fasterxml .jackson .databind .ObjectMapper ;
3943import com .fasterxml .jackson .datatype .jsr310 .JavaTimeModule ;
4044import com .sofa .linkiving .domain .chat .ai .AnswerClient ;
41- import com .sofa .linkiving .domain .chat .dto .request .AnswerCancelReq ;
42- import com .sofa .linkiving .domain .chat .dto .request .AnswerReq ;
43- import com .sofa .linkiving .domain .chat .dto .response .AnswerRes ;
4445import com .sofa .linkiving .domain .chat .dto .response .RagAnswerRes ;
4546import com .sofa .linkiving .domain .chat .entity .Chat ;
4647import com .sofa .linkiving .domain .chat .repository .ChatRepository ;
5152import com .sofa .linkiving .domain .member .repository .MemberRepository ;
5253import com .sofa .linkiving .infra .redis .RedisService ;
5354import com .sofa .linkiving .security .jwt .JwtTokenProvider ;
55+ import com .sofa .linkiving .security .jwt .error .CustomJwtException ;
56+ import com .sofa .linkiving .security .jwt .error .JwtErrorCode ;
57+
58+ import jakarta .annotation .Nonnull ;
5459
5560@ SpringBootTest (webEnvironment = SpringBootTest .WebEnvironment .RANDOM_PORT )
5661@ DirtiesContext (classMode = DirtiesContext .ClassMode .AFTER_EACH_TEST_METHOD )
@@ -78,15 +83,16 @@ public class WebSocketChatIntegrationTest {
7883 @ MockitoBean
7984 private RedisService redisService ;
8085
81- @ Autowired
86+ @ MockitoSpyBean
8287 private JwtTokenProvider jwtTokenProvider ;
8388
8489 @ MockitoBean
8590 private AnswerClient answerClient ;
8691
8792 private StompSession stompSession ;
88- private BlockingQueue <AnswerRes > blockingQueue ;
93+ private BlockingQueue <Map < String , Object > > blockingQueue ;
8994 private Chat testChat ;
95+ private Member testMember ;
9096
9197 @ BeforeEach
9298 void setUp () throws ExecutionException , InterruptedException , TimeoutException {
@@ -97,7 +103,7 @@ void setUp() throws ExecutionException, InterruptedException, TimeoutException {
97103 memberRepository .deleteAllInBatch ();
98104
99105 // 1. 데이터 초기화
100- Member testMember = memberRepository .save (Member .builder ()
106+ testMember = memberRepository .save (Member .builder ()
101107 .email ("test@test.com" )
102108 .password ("password" )
103109 .build ());
@@ -127,14 +133,19 @@ void setUp() throws ExecutionException, InterruptedException, TimeoutException {
127133
128134 // 3. WebSocket 연결
129135 String wsUrl = "ws://localhost:" + port + "/ws/chat" ;
130- StompHeaders headers = new StompHeaders ();
131136
132137 String validToken = jwtTokenProvider .createAccessToken (testMember .getEmail ());
133- headers .add ("Authorization" , "Bearer " + validToken );
138+ String authHeaderValue = "Bearer " + validToken ;
139+
140+ StompHeaders headers = new StompHeaders ();
141+ headers .add ("Authorization" , authHeaderValue );
142+
143+ WebSocketHttpHeaders webSocketHttpHeaders = new WebSocketHttpHeaders ();
144+ webSocketHttpHeaders .add (HttpHeaders .COOKIE , "accessToken=" + validToken );
134145
135146 this .stompSession = stompClient .connectAsync (
136147 wsUrl ,
137- new WebSocketHttpHeaders () ,
148+ webSocketHttpHeaders ,
138149 headers ,
139150 new StompSessionHandlerAdapter () {
140151 }
@@ -158,38 +169,112 @@ private List<Transport> createTransportClient() {
158169
159170 private void subscribeToChatQueue () {
160171 stompSession .subscribe ("/user/queue/chat" , new StompFrameHandler () {
172+ @ Nonnull
161173 @ Override
162- public Type getPayloadType (StompHeaders headers ) {
163- return AnswerRes .class ;
174+ public Type getPayloadType (@ Nonnull StompHeaders headers ) {
175+ return Map .class ;
164176 }
165177
166178 @ Override
167- public void handleFrame (StompHeaders headers , Object payload ) {
168- blockingQueue .offer ((AnswerRes )payload );
179+ public void handleFrame (@ Nonnull StompHeaders headers , Object payload ) {
180+ if (payload != null ) {
181+ blockingQueue .add ((Map <String , Object >)payload );
182+ }
169183 }
170184 });
171185 }
172186
173187 @ Test
174- @ DisplayName ("유저가 메시지를 보내면 AI 응답이 Queue로 수신된다" )
188+ @ DisplayName ("쿠키에 담긴 accessToken으로 정상적으로 웹소켓 연결이 가능하다" )
189+ void shouldConnectSuccessfullyWithCookie () throws Exception {
190+ // given
191+ WebSocketStompClient customClient = new WebSocketStompClient (new SockJsClient (createTransportClient ()));
192+ String validToken = jwtTokenProvider .createAccessToken (testMember .getEmail ());
193+
194+ WebSocketHttpHeaders httpHeaders = new WebSocketHttpHeaders ();
195+ httpHeaders .add (HttpHeaders .COOKIE , "accessToken=" + validToken );
196+
197+ // when
198+ StompSession session = customClient .connectAsync (
199+ "ws://localhost:" + port + "/ws/chat" ,
200+ httpHeaders ,
201+ new StompHeaders (),
202+ new StompSessionHandlerAdapter () {
203+ }
204+ ).get (5 , SECONDS );
205+
206+ // then
207+ assertThat (session .isConnected ()).isTrue ();
208+ session .disconnect ();
209+ }
210+
211+ @ Test
212+ @ DisplayName ("쿠키(토큰)가 없거나 유효하지 않으면 웹소켓 연결에 실패한다 (401)" )
213+ void shouldFailToConnectWithInvalidToken () {
214+ // given
215+ WebSocketStompClient customClient = new WebSocketStompClient (new SockJsClient (createTransportClient ()));
216+
217+ // 잘못된 쿠키 세팅
218+ WebSocketHttpHeaders httpHeaders = new WebSocketHttpHeaders ();
219+ httpHeaders .add (HttpHeaders .COOKIE , "accessToken=invalid_token_value" );
220+
221+ // when & then
222+ assertThatThrownBy (() -> customClient .connectAsync (
223+ "ws://localhost:" + port + "/ws/chat" ,
224+ httpHeaders ,
225+ new StompHeaders (),
226+ new StompSessionHandlerAdapter () {
227+ }
228+ ).get (5 , SECONDS ))
229+ .isInstanceOf (ExecutionException .class );
230+ }
231+
232+ @ Test
233+ @ DisplayName ("연결 이후 SEND 요청 시 인증 토큰이 만료/조작된 경우 메시지를 차단하고 ERROR 프레임을 반환한다" )
234+ void shouldFailToSendWhenTokenIsInvalid () throws Exception {
235+ // given
236+ subscribeToChatQueue ();
237+ Thread .sleep (1000 );
238+
239+ // when
240+ doThrow (new CustomJwtException (JwtErrorCode .EXPIRED_JWT_TOKEN ))
241+ .when (jwtTokenProvider ).validateAccessToken (anyString ());
242+
243+ Long chatId = testChat .getId ();
244+ Map <String , Object > req = new HashMap <>();
245+ req .put ("chatId" , chatId );
246+ req .put ("message" , "만료된 토큰으로 전송 시도" );
247+
248+ stompSession .send ("/ws/chat/send" , req );
249+
250+ // then
251+ Thread .sleep (1000 );
252+ verify (answerClient , never ()).generateAnswer (any ());
253+ }
254+
255+ @ Test
256+ @ DisplayName ("정상적인 유저가 메시지를 보내면 AI 응답이 Queue로 수신된다" )
175257 void shouldReceiveAnswerWhenMessageSent () throws InterruptedException {
176258 // given
177259 Long chatId = testChat .getId ();
178260 String userMessage = "Gemini에 대해 알려줘" ;
179- AnswerReq req = new AnswerReq (chatId , userMessage );
261+
262+ Map <String , Object > req = new HashMap <>();
263+ req .put ("chatId" , chatId );
264+ req .put ("message" , userMessage );
180265
181266 subscribeToChatQueue ();
267+ Thread .sleep (1000 );
182268
183269 // when
184270 stompSession .send ("/ws/chat/send" , req );
185271
186272 // then
187- AnswerRes received = blockingQueue .poll (10 , SECONDS );
273+ Map < String , Object > received = blockingQueue .poll (10 , SECONDS );
188274
189275 assertThat (received ).isNotNull ();
190- assertThat (received .chatId ()).isEqualTo (chatId );
191- assertThat (received .success ()).isTrue ();
192- assertThat (received .content ()).contains ("Gemini와 관련된 내용" );
276+ assertThat (received .get ("success" )).isEqualTo (true );
277+ assertThat (String .valueOf (received .get ("content" ))).contains ("Gemini와 관련된 내용" );
193278 }
194279
195280 @ Test
@@ -198,29 +283,34 @@ void shouldReceiveErrorMessageWhenCancelled() throws InterruptedException {
198283 // given
199284 Long chatId = testChat .getId ();
200285 String userMessage = "취소될 질문" ;
201- AnswerReq sendReq = new AnswerReq (chatId , userMessage );
202- AnswerCancelReq cancelReq = new AnswerCancelReq (chatId );
286+
287+ Map <String , Object > sendReq = new HashMap <>();
288+ sendReq .put ("chatId" , chatId );
289+ sendReq .put ("message" , userMessage );
290+
291+ Map <String , Object > cancelReq = new HashMap <>();
292+ cancelReq .put ("chatId" , chatId );
203293
204294 given (answerClient .generateAnswer (any ())).willAnswer (invocation -> {
205- Thread .sleep (500 );
295+ Thread .sleep (1500 );
206296 return new RagAnswerRes ("지연된 답변" , Collections .emptyList (), Collections .emptyList (), Collections .emptyList (),
207297 true );
208298 });
209299
210300 subscribeToChatQueue ();
301+ Thread .sleep (1000 );
211302
212303 // when
213304 stompSession .send ("/ws/chat/send" , sendReq );
214- Thread .sleep (50 );
305+ Thread .sleep (300 );
215306 stompSession .send ("/ws/chat/cancel" , cancelReq );
216307
217308 // then
218- AnswerRes received = blockingQueue .poll (5 , SECONDS );
309+ Map < String , Object > received = blockingQueue .poll (10 , SECONDS );
219310
220311 assertThat (received ).isNotNull ();
221- assertThat (received .chatId ()).isEqualTo (chatId );
222- assertThat (received .success ()).isFalse ();
223- assertThat (received .content ()).isEqualTo (userMessage );
312+ assertThat (received .get ("success" )).isEqualTo (false );
313+ assertThat (String .valueOf (received .get ("content" ))).isEqualTo (userMessage );
224314 }
225-
226315}
316+
0 commit comments