@@ -156,6 +156,7 @@ word32 kbResponseCount;
156156byte kbMultiRound = 0 ;
157157byte currentRound = 0 ;
158158byte unbalanced = 0 ;
159+ byte useUserAuthCb = 0 ; /* Flag to test userAuthCb for keyboard-interactive */
159160
160161WS_UserAuthData_Keyboard promptData ;
161162
@@ -223,32 +224,48 @@ static int load_key(byte isEcc, byte* buf, word32 bufSz)
223224static int serverUserAuth (byte authType , WS_UserAuthData * authData , void * ctx )
224225{
225226 (void ) ctx ;
226- if (authType != WOLFSSH_USERAUTH_KEYBOARD ) {
227- return WOLFSSH_USERAUTH_FAILURE ;
228- }
229-
230- if (authData -> sf .keyboard .responseCount != kbResponseCount ) {
231- return WOLFSSH_USERAUTH_FAILURE ;
232- }
233-
234- for (word32 resp = 0 ; resp < kbResponseCount ; resp ++ ) {
235- if (authData -> sf .keyboard .responseLengths [resp ] !=
236- kbResponseLengths [resp ]) {
227+
228+ /* Handle keyboard-interactive auth */
229+ if (authType == WOLFSSH_USERAUTH_KEYBOARD ) {
230+ /* If responseCount is 0, this is a prompt setup call */
231+ if (authData -> sf .keyboard .responseCount == 0 ) {
232+ /* Set up prompts */
233+ WMEMCPY (& authData -> sf .keyboard , & promptData , sizeof (WS_UserAuthData_Keyboard ));
234+
235+ /* Return SUCCESS_ANOTHER to proceed with sending prompts */
236+ if (useUserAuthCb ) {
237+ return WOLFSSH_USERAUTH_SUCCESS_ANOTHER ;
238+ }
239+ /* When not testing userAuthCb, return FAILURE to fall back to keyboardAuthCb */
237240 return WOLFSSH_USERAUTH_FAILURE ;
238-
239241 }
240- if (WSTRCMP ((const char * )authData -> sf .keyboard .responses [resp ],
241- (const char * )kbResponses [resp ]) != 0 ) {
242+
243+ /* Validate responses */
244+ if (authData -> sf .keyboard .responseCount != kbResponseCount ) {
242245 return WOLFSSH_USERAUTH_FAILURE ;
243246 }
247+
248+ for (word32 resp = 0 ; resp < kbResponseCount ; resp ++ ) {
249+ if (authData -> sf .keyboard .responseLengths [resp ] !=
250+ kbResponseLengths [resp ]) {
251+ return WOLFSSH_USERAUTH_FAILURE ;
252+
253+ }
254+ if (WSTRCMP ((const char * )authData -> sf .keyboard .responses [resp ],
255+ (const char * )kbResponses [resp ]) != 0 ) {
256+ return WOLFSSH_USERAUTH_FAILURE ;
257+ }
258+ }
259+ if (kbMultiRound && currentRound == 0 ) {
260+ currentRound ++ ;
261+ kbResponses [0 ] = (byte * )testText2 ;
262+ kbResponseLengths [0 ] = 8 ;
263+ return WOLFSSH_USERAUTH_SUCCESS_ANOTHER ;
264+ }
265+ return WOLFSSH_USERAUTH_SUCCESS ;
244266 }
245- if (kbMultiRound && currentRound == 0 ) {
246- currentRound ++ ;
247- kbResponses [0 ] = (byte * )testText2 ;
248- kbResponseLengths [0 ] = 8 ;
249- return WOLFSSH_USERAUTH_SUCCESS_ANOTHER ;
250- }
251- return WOLFSSH_USERAUTH_SUCCESS ;
267+
268+ return WOLFSSH_USERAUTH_FAILURE ;
252269}
253270
254271static int serverKeyboardCallback (WS_UserAuthData_Keyboard * kbAuth , void * ctx )
@@ -332,7 +349,12 @@ static THREAD_RETURN WOLFSSH_THREAD server_thread(void* args)
332349 }
333350
334351 wolfSSH_SetUserAuth (ctx , serverUserAuth );
335- wolfSSH_SetKeyboardAuthPrompts (ctx , serverKeyboardCallback );
352+
353+ /* Only set keyboard auth callback when not testing userAuthCb */
354+ if (!useUserAuthCb ) {
355+ wolfSSH_SetKeyboardAuthPrompts (ctx , serverKeyboardCallback );
356+ }
357+
336358 ssh = wolfSSH_new (ctx );
337359 if (ssh == NULL ) {
338360 ES_ERROR ("Couldn't allocate SSH data.\n" );
@@ -574,6 +596,34 @@ static void test_unbalanced_client_KeyboardInteractive(void)
574596 test_client ();
575597 unbalanced = 0 ;
576598}
599+
600+ static void test_userAuthCb_KeyboardInteractive (void )
601+ {
602+ printf ("Testing keyboard-interactive auth via userAuthCb\n" );
603+ kbResponses [0 ] = (byte * )testText1 ;
604+ kbResponseLengths [0 ] = 4 ;
605+ kbResponseCount = 1 ;
606+ useUserAuthCb = 1 ;
607+
608+ test_client ();
609+ useUserAuthCb = 0 ;
610+ }
611+
612+ static void test_userAuthCb_multi_round_KeyboardInteractive (void )
613+ {
614+ printf ("Testing multiple prompt rounds via userAuthCb\n" );
615+ kbResponses [0 ] = (byte * )testText1 ;
616+ kbResponseLengths [0 ] = 4 ;
617+ kbResponseCount = 1 ;
618+ kbMultiRound = 1 ;
619+ useUserAuthCb = 1 ;
620+
621+ test_client ();
622+ AssertIntEQ (currentRound , 1 );
623+ currentRound = 0 ;
624+ kbMultiRound = 0 ;
625+ useUserAuthCb = 0 ;
626+ }
577627#endif /* WOLFSSH_TEST_BLOCK */
578628
579629int wolfSSH_AuthTest (int argc , char * * argv )
@@ -603,6 +653,8 @@ int wolfSSH_AuthTest(int argc, char** argv)
603653 test_multi_prompt_KeyboardInteractive ();
604654 test_multi_round_KeyboardInteractive ();
605655 test_unbalanced_client_KeyboardInteractive ();
656+ test_userAuthCb_KeyboardInteractive ();
657+ test_userAuthCb_multi_round_KeyboardInteractive ();
606658
607659 AssertIntEQ (wolfSSH_Cleanup (), WS_SUCCESS );
608660
@@ -616,5 +668,3 @@ int main(int argc, char** argv)
616668 return wolfSSH_AuthTest (argc , argv );
617669}
618670#endif
619-
620-
0 commit comments