@@ -445,6 +445,9 @@ const char* GetErrorString(int err)
445445 case WS_SFTP_NOT_FILE_E :
446446 return "not a regular file" ;
447447
448+ case WS_MSGID_NOT_ALLOWED_E :
449+ return "message not allowed before user authentication" ;
450+
448451 default :
449452 return "Unknown error code" ;
450453 }
@@ -557,6 +560,84 @@ static void HandshakeInfoFree(HandshakeInfo* hs, void* heap)
557560}
558561
559562
563+ #ifndef NO_WOLFSSH_SERVER
564+ INLINE static int IsMessageAllowedServer (WOLFSSH * ssh , byte msg )
565+ {
566+ /* Has client userauth started? */
567+ if (ssh -> acceptState < ACCEPT_KEYED ) {
568+ if (msg > MSGID_KEXDH_LIMIT ) {
569+ return 0 ;
570+ }
571+ }
572+ /* Is server userauth complete? */
573+ if (ssh -> acceptState < ACCEPT_SERVER_USERAUTH_SENT ) {
574+ /* Explicitly check for messages not allowed before user
575+ * authentication has comleted. */
576+ if (msg >= MSGID_USERAUTH_LIMIT ) {
577+ WLOG (WS_LOG_DEBUG , "Message ID %u not allowed by server "
578+ "before user authentication is complete" , msg );
579+ return 0 ;
580+ }
581+ /* Explicitly check for the user authentication messages that
582+ * only the server sends, it shouldn't receive them. */
583+ if (msg > MSGID_USERAUTH_RESTRICT ) {
584+ WLOG (WS_LOG_DEBUG , "Message ID %u not allowed by server "
585+ "during user authentication" , msg );
586+ return 0 ;
587+ }
588+ }
589+ return 1 ;
590+ }
591+ #endif /* NO_WOLFSSH_SERVER */
592+
593+
594+ #ifndef NO_WOLFSSH_CLIENT
595+ INLINE static int IsMessageAllowedClient (WOLFSSH * ssh , byte msg )
596+ {
597+ /* Has client userauth started? */
598+ if (ssh -> connectState < CONNECT_CLIENT_KEXDH_INIT_SENT ) {
599+ if (msg >= MSGID_KEXDH_LIMIT ) {
600+ return 0 ;
601+ }
602+ }
603+ /* Is client userauth complete? */
604+ if (ssh -> connectState < CONNECT_SERVER_USERAUTH_ACCEPT_DONE ) {
605+ /* Explicitly check for messages not allowed before user
606+ * authentication has comleted. */
607+ if (msg >= MSGID_USERAUTH_LIMIT ) {
608+ WLOG (WS_LOG_DEBUG , "Message ID %u not allowed by client "
609+ "before user authentication is complete" , msg );
610+ return 0 ;
611+ }
612+ /* Explicitly check for the user authentication message that
613+ * only the client sends, it shouldn't receive it. */
614+ if (msg == MSGID_USERAUTH_RESTRICT ) {
615+ WLOG (WS_LOG_DEBUG , "Message ID %u not allowed by client "
616+ "during user authentication" , msg );
617+ return 0 ;
618+ }
619+ }
620+ return 1 ;
621+ }
622+ #endif /* NO_WOLFSSH_CLIENT */
623+
624+
625+ INLINE static int IsMessageAllowed (WOLFSSH * ssh , byte msg )
626+ {
627+ #ifndef NO_WOLFSSH_SERVER
628+ if (ssh -> ctx -> side == WOLFSSH_ENDPOINT_SERVER ) {
629+ return IsMessageAllowedServer (ssh , msg );
630+ }
631+ #endif /* NO_WOLFSSH_SERVER */
632+ #ifndef NO_WOLFSSH_CLIENT
633+ if (ssh -> ctx -> side == WOLFSSH_ENDPOINT_CLIENT ) {
634+ return IsMessageAllowedClient (ssh , msg );
635+ }
636+ #endif /* NO_WOLFSSH_CLIENT */
637+ return 0 ;
638+ }
639+
640+
560641#ifdef DEBUG_WOLFSSH
561642
562643static const char cannedBanner [] =
@@ -8019,6 +8100,10 @@ static int DoPacket(WOLFSSH* ssh, byte* bufferConsumed)
80198100 return WS_OVERFLOW_E ;
80208101 }
80218102
8103+ if (!IsMessageAllowed (ssh , msg )) {
8104+ return WS_MSGID_NOT_ALLOWED_E ;
8105+ }
8106+
80228107 switch (msg ) {
80238108
80248109 case MSGID_DISCONNECT :
0 commit comments