@@ -433,6 +433,9 @@ const char* GetErrorString(int err)
433433 case WS_SFTP_NOT_FILE_E :
434434 return "not a regular file" ;
435435
436+ case WS_MSGID_NOT_ALLOWED_E :
437+ return "message not allowed before user authentication" ;
438+
436439 default :
437440 return "Unknown error code" ;
438441 }
@@ -545,6 +548,84 @@ static void HandshakeInfoFree(HandshakeInfo* hs, void* heap)
545548}
546549
547550
551+ #ifndef NO_WOLFSSH_SERVER
552+ INLINE static int IsMessageAllowedServer (WOLFSSH * ssh , byte msg )
553+ {
554+ /* Has client userauth started? */
555+ if (ssh -> acceptState < ACCEPT_KEYED ) {
556+ if (msg > MSGID_KEXDH_LIMIT ) {
557+ return 0 ;
558+ }
559+ }
560+ /* Is server userauth complete? */
561+ if (ssh -> acceptState < ACCEPT_SERVER_USERAUTH_SENT ) {
562+ /* Explicitly check for messages not allowed before user
563+ * authentication has comleted. */
564+ if (msg >= MSGID_USERAUTH_LIMIT ) {
565+ WLOG (WS_LOG_DEBUG , "Message ID %u not allowed by server "
566+ "before user authentication is complete" , msg );
567+ return 0 ;
568+ }
569+ /* Explicitly check for the user authentication messages that
570+ * only the server sends, it shouldn't receive them. */
571+ if (msg > MSGID_USERAUTH_RESTRICT ) {
572+ WLOG (WS_LOG_DEBUG , "Message ID %u not allowed by server "
573+ "during user authentication" , msg );
574+ return 0 ;
575+ }
576+ }
577+ return 1 ;
578+ }
579+ #endif /* NO_WOLFSSH_SERVER */
580+
581+
582+ #ifndef NO_WOLFSSH_CLIENT
583+ INLINE static int IsMessageAllowedClient (WOLFSSH * ssh , byte msg )
584+ {
585+ /* Has client userauth started? */
586+ if (ssh -> connectState < CONNECT_CLIENT_KEXDH_INIT_SENT ) {
587+ if (msg >= MSGID_KEXDH_LIMIT ) {
588+ return 0 ;
589+ }
590+ }
591+ /* Is client userauth complete? */
592+ if (ssh -> connectState < CONNECT_SERVER_USERAUTH_ACCEPT_DONE ) {
593+ /* Explicitly check for messages not allowed before user
594+ * authentication has comleted. */
595+ if (msg >= MSGID_USERAUTH_LIMIT ) {
596+ WLOG (WS_LOG_DEBUG , "Message ID %u not allowed by client "
597+ "before user authentication is complete" , msg );
598+ return 0 ;
599+ }
600+ /* Explicitly check for the user authentication message that
601+ * only the client sends, it shouldn't receive it. */
602+ if (msg == MSGID_USERAUTH_RESTRICT ) {
603+ WLOG (WS_LOG_DEBUG , "Message ID %u not allowed by client "
604+ "during user authentication" , msg );
605+ return 0 ;
606+ }
607+ }
608+ return 1 ;
609+ }
610+ #endif /* NO_WOLFSSH_CLIENT */
611+
612+
613+ INLINE static int IsMessageAllowed (WOLFSSH * ssh , byte msg )
614+ {
615+ #ifndef NO_WOLFSSH_SERVER
616+ if (ssh -> ctx -> side == WOLFSSH_ENDPOINT_SERVER ) {
617+ return IsMessageAllowedServer (ssh , msg );
618+ }
619+ #endif /* NO_WOLFSSH_SERVER */
620+ #ifndef NO_WOLFSSH_CLIENT
621+ if (ssh -> ctx -> side == WOLFSSH_ENDPOINT_CLIENT ) {
622+ return IsMessageAllowedClient (ssh , msg );
623+ }
624+ #endif /* NO_WOLFSSH_CLIENT */
625+ return 0 ;
626+ }
627+
628+
548629#ifdef DEBUG_WOLFSSH
549630
550631static const char cannedBanner [] =
@@ -7526,6 +7607,10 @@ static int DoPacket(WOLFSSH* ssh, byte* bufferConsumed)
75267607 return WS_OVERFLOW_E ;
75277608 }
75287609
7610+ if (!IsMessageAllowed (ssh , msg )) {
7611+ return WS_MSGID_NOT_ALLOWED_E ;
7612+ }
7613+
75297614 switch (msg ) {
75307615
75317616 case MSGID_DISCONNECT :
0 commit comments