diff --git a/ssh/server.go b/ssh/server.go index 3c0fcc953e..c01a8f59c6 100644 --- a/ssh/server.go +++ b/ssh/server.go @@ -804,7 +804,8 @@ userAuthLoop: if len(payload) > 0 { return nil, parseError(msgUserAuthRequest) } - _, isPartialSuccessError := candidate.result.(*PartialSuccessError) + var _partialSuccessForQuery *PartialSuccessError + isPartialSuccessError := errors.As(candidate.result, &_partialSuccessForQuery) if candidate.result == nil || isPartialSuccessError { okMsg := userAuthPubKeyOkMsg{ Algo: algo, @@ -946,7 +947,8 @@ userAuthLoop: var failureMsg userAuthFailureMsg - if partialSuccess, ok := authErr.(*PartialSuccessError); ok { + var partialSuccess *PartialSuccessError + if errors.As(authErr, &partialSuccess) { // Permissions are not preserved between authentication steps. To // avoid confusion about the final state of the connection, we // disallow returning non-nil Permissions combined with diff --git a/ssh/server_multi_auth_test.go b/ssh/server_multi_auth_test.go index 3b39802437..0722b8a3cd 100644 --- a/ssh/server_multi_auth_test.go +++ b/ssh/server_multi_auth_test.go @@ -410,3 +410,79 @@ func TestDynamicAuthCallbacks(t *testing.T) { t.Fatal("server not returned partial success") } } + +// TestBannerWrappingPartialSuccess verifies that a PartialSuccessError wrapped +// inside a BannerError is correctly processed: the banner is sent and the +// partial success is honoured (i.e. the Next auth callbacks are used for the +// second step). Prior to the fix, the direct type-assertion +// authErr.(*PartialSuccessError) missed the wrapped value while +// errors.As(authErr, &bannerErr) had already unwrapped it, causing the +// PartialSuccess to be silently discarded and authFailures to be incremented +// instead. +func TestBannerWrappingPartialSuccess(t *testing.T) { + var bannerReceived string + + serverConfig := &ServerConfig{ + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + if !bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { + return nil, errors.New("unknown key") + } + // Return BannerError wrapping PartialSuccessError. + // The server wants to send a banner AND require a second factor. + return nil, &BannerError{ + Err: &PartialSuccessError{ + Next: ServerAuthCallbacks{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + if string(password) == clientPassword { + return &Permissions{}, nil + } + return nil, errors.New("wrong password") + }, + }, + }, + Message: "Two-factor auth required\n", + } + }, + BannerCallback: func(conn ConnMetadata) string { return "" }, + } + + clientConfig := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + BannerCallback: func(msg string) error { + bannerReceived = msg + return nil + }, + } + + serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("expected successful login, got: %v", err) + } + + // Verify banner was sent to the client. + if bannerReceived != "Two-factor auth required\n" { + t.Errorf("banner not received; got %q", bannerReceived) + } + + // Auth log entries: + // 0: ErrNoAuth from the initial 'none' attempt (always happens) + // 1: BannerError wrapping PartialSuccessError (publickey step) + // 2: nil (password step succeeds) + if len(serverAuthErrors) != 3 { + t.Fatalf("expected 3 auth log entries, got %d: %v", len(serverAuthErrors), serverAuthErrors) + } + // The second entry wraps a *PartialSuccessError inside a *BannerError. + var partialSuccessErr *PartialSuccessError + if !errors.As(serverAuthErrors[1], &partialSuccessErr) { + t.Errorf("second auth log entry should wrap *PartialSuccessError, got %T: %v", + serverAuthErrors[1], serverAuthErrors[1]) + } + if serverAuthErrors[2] != nil { + t.Errorf("third auth log entry should be nil (success), got: %v", serverAuthErrors[2]) + } +}