forked from jeroenrinzema/psql-wire
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathauth_test.go
More file actions
130 lines (100 loc) · 3.83 KB
/
auth_test.go
File metadata and controls
130 lines (100 loc) · 3.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
package wire
import (
"bytes"
"context"
"fmt"
"strconv"
"testing"
"github.com/jeroenrinzema/psql-wire/pkg/buffer"
"github.com/jeroenrinzema/psql-wire/pkg/types"
"github.com/neilotoole/slogt"
"github.com/stretchr/testify/require"
)
// TestDefaultHandleAuth verifies that a Server with no Auth strategy set
// completes handleAuth without error, even though the client sent nothing,
// and writes back a typed 'R' message whose status is authOK.
func TestDefaultHandleAuth(t *testing.T) {
	// The client side is empty: handleAuth must succeed without any client
	// message being available.
	input := bytes.NewBuffer([]byte{})
	sink := bytes.NewBuffer([]byte{})
	ctx := context.Background()

	reader := buffer.NewReader(slogt.New(t), input, buffer.DefaultBufferSize)
	writer := buffer.NewWriter(slogt.New(t), sink)

	server := &Server{logger: slogt.New(t)}
	_, err := server.handleAuth(ctx, reader, writer)
	require.NoError(t, err)

	// Decode what the server wrote to the client.
	result := buffer.NewReader(slogt.New(t), sink, buffer.DefaultBufferSize)
	ty, ln, err := result.ReadTypedMsg()
	require.NoError(t, err)

	if ln == 0 {
		t.Error("unexpected length, expected typed message length to be greater than 0")
	}

	if ty != 'R' {
		t.Errorf("unexpected message type %s, expected 'R'", strconv.QuoteRune(rune(ty)))
	}

	// The first field of the 'R' message is the 32-bit auth status.
	status, err := result.GetUint32()
	require.NoError(t, err)

	if authType(status) != authOK {
		t.Errorf("unexpected auth status %d, expected OK", status)
	}
}
// TestClearTextPassword verifies that the ClearTextPassword strategy accepts
// a client whose password message matches what the validator expects.
// Because the validator returns its input context unmodified, the context
// returned by handleAuth must equal the one passed in.
func TestClearTextPassword(t *testing.T) {
	expected := "password"

	// Pre-load the client side of the connection with a password message.
	input := bytes.NewBuffer([]byte{})
	incoming := buffer.NewWriter(slogt.New(t), input)

	// NOTE: we could reuse the server buffered writer to write client messages.
	// Start takes a types.ServerMessage, hence the conversion of the client
	// password message type.
	incoming.Start(types.ServerMessage(types.ClientPassword))
	incoming.AddString(expected)
	incoming.AddNullTerminate()
	// Fail fast if the fixture could not be written, instead of surfacing a
	// confusing error from handleAuth later on.
	require.NoError(t, incoming.End())

	validate := func(ctx context.Context, database, username, password string) (context.Context, bool, error) {
		if password != expected {
			return ctx, false, fmt.Errorf("unexpected password: %s", password)
		}

		return ctx, true, nil
	}

	sink := bytes.NewBuffer([]byte{})
	ctx := context.Background()

	reader := buffer.NewReader(slogt.New(t), input, buffer.DefaultBufferSize)
	writer := buffer.NewWriter(slogt.New(t), sink)

	server := &Server{logger: slogt.New(t), Auth: ClearTextPassword(validate)}
	out, err := server.handleAuth(ctx, reader, writer)
	require.NoError(t, err)
	require.Equal(t, ctx, out)
}
// TestClearTextPasswordIncorrect verifies that the ClearTextPassword strategy
// rejects a wrong password. It checks three things:
//   - handleAuth returns an error mentioning "invalid username/password";
//   - the client receives an auth request followed by an error response;
//   - no further message (such as ReadyForQuery) is written afterwards.
func TestClearTextPasswordIncorrect(t *testing.T) {
	correctPassword := "correct-password"
	incorrectPassword := "wrong-password"

	input := bytes.NewBuffer([]byte{})
	incoming := buffer.NewWriter(slogt.New(t), input)

	// Client sends the incorrect password.
	incoming.Start(types.ServerMessage(types.ClientPassword))
	incoming.AddString(incorrectPassword)
	incoming.AddNullTerminate()
	// Fail fast if the fixture could not be written, so a fixture problem is
	// not mistaken for the expected authentication failure below.
	require.NoError(t, incoming.End())

	validate := func(ctx context.Context, database, username, password string) (context.Context, bool, error) {
		// Only accept the correct password; reject everything else without an
		// error so the strategy itself has to report the failure.
		if password == correctPassword {
			return ctx, true, nil
		}

		return ctx, false, nil
	}

	sink := bytes.NewBuffer([]byte{})
	ctx := context.Background()

	reader := buffer.NewReader(slogt.New(t), input, buffer.DefaultBufferSize)
	writer := buffer.NewWriter(slogt.New(t), sink)

	server := &Server{logger: slogt.New(t), Auth: ClearTextPassword(validate)}
	_, err := server.handleAuth(ctx, reader, writer)

	// Authentication should fail with an error.
	require.Error(t, err)
	require.Contains(t, err.Error(), "invalid username/password")

	// Verify what was written to the client.
	result := buffer.NewReader(slogt.New(t), sink, buffer.DefaultBufferSize)

	// First message should be the auth request (asking for password).
	ty, _, err := result.ReadTypedMsg()
	require.NoError(t, err)
	require.Equal(t, types.ServerAuth, types.ServerMessage(ty))

	// The client SHOULD receive an error response message.
	ty, _, err = result.ReadTypedMsg()
	require.NoError(t, err)
	require.Equal(t, types.ServerErrorResponse, types.ServerMessage(ty))

	// No ready for query message should follow (connection will be closed).
	_, _, err = result.ReadTypedMsg()
	require.Error(t, err, "Expected no ready for query message after auth failure")
}