Skip to content

Commit 807632a

Browse files
committed
Implement require_auth connection parameter
Also correct the PGSSLMINPROTOCOLVERSION and PGSSLMAXPROTOCOLVERSION env var names >_<
1 parent 1b41ea6 commit 807632a

5 files changed

Lines changed: 158 additions & 14 deletions

File tree

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,17 @@ unreleased
99

1010
### Features
1111

12+
- Implement `require_auth` connection parameter ([#1310]).
13+
1214
### Fixes
1315

1416
- Add Redshift-specific OID mappings ([#1291], [#1317]).
1517

18+
- Use correct environment variable name for `PGSSLMINPROTOCOLVERSION` and
19+
`PGSSLMAXPROTOCOLVERSION` ([#1310]).
20+
1621
[#1291]: https://github.com/lib/pq/pull/1291
22+
[#1310]: https://github.com/lib/pq/pull/1310
1723
[#1317]: https://github.com/lib/pq/pull/1317
1824

1925

conn.go

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"net"
1616
"os"
1717
"reflect"
18+
"slices"
1819
"strconv"
1920
"strings"
2021
"sync"
@@ -1255,6 +1256,7 @@ func (cn *conn) startup(cfg Config) error {
12551256
return err
12561257
}
12571258

1259+
var didauth bool
12581260
for {
12591261
t, r, err := cn.recv()
12601262
if err != nil {
@@ -1271,7 +1273,11 @@ func (cn *conn) startup(cfg Config) error {
12711273
case proto.ParameterStatus:
12721274
cn.processParameterStatus(r)
12731275
case proto.AuthenticationRequest:
1274-
err := cn.auth(r, cfg)
1276+
code := proto.AuthCode(r.int32())
1277+
if code != proto.AuthReqOk {
1278+
didauth = true
1279+
}
1280+
err := cn.auth(code, r, cfg)
12751281
if err != nil {
12761282
return err
12771283
}
@@ -1282,6 +1288,9 @@ func (cn *conn) startup(cfg Config) error {
12821288
return fmt.Errorf("pq: protocol version mismatch: min_protocol_version=%s; server supports up to 3.%d", cfg.MinProtocolVersion, newestMinor)
12831289
}
12841290
case proto.ReadyForQuery:
1291+
if len(cn.cfg.RequireAuth) > 0 && !didauth && !slices.Contains(cn.cfg.RequireAuth, RequireAuthNone) {
1292+
return fmt.Errorf("pq: authentication method requirement %q failed: server did not perform any authentication", cn.cfg.RequireAuth)
1293+
}
12851294
cn.processReadyForQuery(r)
12861295
return nil
12871296
default:
@@ -1290,8 +1299,8 @@ func (cn *conn) startup(cfg Config) error {
12901299
}
12911300
}
12921301

1293-
func (cn *conn) auth(r *readBuf, cfg Config) error {
1294-
switch code := proto.AuthCode(r.int32()); code {
1302+
func (cn *conn) auth(code proto.AuthCode, r *readBuf, cfg Config) error {
1303+
switch code {
12951304
default:
12961305
return fmt.Errorf("pq: unknown authentication response: %s", code)
12971306
case proto.AuthReqKrb4, proto.AuthReqKrb5, proto.AuthReqCrypt, proto.AuthReqSSPI:
@@ -1300,13 +1309,19 @@ func (cn *conn) auth(r *readBuf, cfg Config) error {
13001309
return nil
13011310

13021311
case proto.AuthReqPassword:
1312+
if len(cn.cfg.RequireAuth) > 0 && !slices.Contains(cn.cfg.RequireAuth, RequireAuthPassword) && !slices.Contains(cn.cfg.RequireAuth, RequireAuthAny) {
1313+
return fmt.Errorf("pq: authentication method requirement %q failed: server requested %q", cn.cfg.RequireAuth, RequireAuthPassword)
1314+
}
13031315
w := cn.writeBuf(proto.PasswordMessage)
13041316
w.string(cfg.Password)
13051317
// Don't need to check AuthOk response here; auth() is called in a loop,
13061318
// which catches the errors and AuthReqOk responses.
13071319
return cn.send(w)
13081320

13091321
case proto.AuthReqMD5:
1322+
if len(cn.cfg.RequireAuth) > 0 && !slices.Contains(cn.cfg.RequireAuth, RequireAuthMD5) && !slices.Contains(cn.cfg.RequireAuth, RequireAuthAny) {
1323+
return fmt.Errorf("pq: authentication method requirement %q failed: server requested %q", cn.cfg.RequireAuth, RequireAuthMD5)
1324+
}
13101325
s := string(r.next(4))
13111326
w := cn.writeBuf(proto.PasswordMessage)
13121327
w.string("md5" + md5s(md5s(cfg.Password+cfg.User)+s))
@@ -1369,6 +1384,9 @@ func (cn *conn) auth(r *readBuf, cfg Config) error {
13691384
return nil
13701385

13711386
case proto.AuthReqSASL:
1387+
if len(cn.cfg.RequireAuth) > 0 && !slices.Contains(cn.cfg.RequireAuth, RequireAuthScramSHA256) && !slices.Contains(cn.cfg.RequireAuth, RequireAuthAny) {
1388+
return fmt.Errorf("pq: authentication method requirement %q failed: server requested %q", cn.cfg.RequireAuth, RequireAuthScramSHA256)
1389+
}
13721390
sc := scram.NewClient(sha256.New, cfg.User, cfg.Password)
13731391
sc.Step(nil)
13741392
if sc.Err() != nil {

conn_test.go

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1607,18 +1607,18 @@ func TestCommitInFailedTransactionWithCancelContext(t *testing.T) {
16071607

16081608
func TestAuth(t *testing.T) {
16091609
tests := []struct {
1610-
buf readBuf
1610+
code proto.AuthCode
16111611
wantErr string
16121612
}{
1613-
{readBuf{0, 0, 0, 9}, `pq: unsupported authentication method: SSPI (9)`},
1614-
{readBuf{0, 0, 0, 99}, `unknown authentication response: <unknown> (99)`},
1613+
{proto.AuthCode(9), `pq: unsupported authentication method: SSPI (9)`},
1614+
{proto.AuthCode(99), `unknown authentication response: <unknown> (99)`},
16151615
}
16161616

16171617
t.Parallel()
16181618
for _, tt := range tests {
16191619
t.Run("", func(t *testing.T) {
16201620
t.Run("unsupported auth", func(t *testing.T) {
1621-
err := (&conn{}).auth(&tt.buf, Config{})
1621+
err := (&conn{}).auth(tt.code, &readBuf{}, Config{})
16221622
if !pqtest.ErrorContains(err, tt.wantErr) {
16231623
t.Errorf("wrong error:\nhave: %s\nwant: %s", err, tt.wantErr)
16241624
}
@@ -1646,10 +1646,34 @@ func TestAuth(t *testing.T) {
16461646
{"user=pqgoscram password=wordpass", ``},
16471647

16481648
{"user=pqgounknown password=wordpass", `or:role "pqgounknown" does not exist|password authentication failed for user pqgounknown`},
1649+
{"user=pqgounknown password=wordpass require_auth=md5", `or:role "pqgounknown" does not exist|password authentication failed for user pqgounknown`},
1650+
1651+
// require_auth
1652+
{"user=pqgomd5 password=wordpass require_auth=md5,password", ``},
1653+
{"user=pqgopassword password=wordpass require_auth=md5,password", ``},
1654+
{"user=pqgoscram password=wordpass require_auth=md5,password,scram-sha-256", ``},
1655+
{"user=pqgomd5 password=wordpass require_auth=!none", ``},
1656+
{"user=pqgopassword password=wordpass require_auth=!none", ``},
1657+
{"user=pqgoscram password=wordpass require_auth=!none", ``},
1658+
1659+
{"user=pqgomd5 password=wordpass require_auth=password", `"password" failed: server requested "md5"`},
1660+
{"user=pqgopassword password=wordpass require_auth=md5", `"md5" failed: server requested "password"`},
1661+
{"user=pqgoscram password=wordpass require_auth=md5,password", `authentication method requirement "md5,password" failed: server requested "scram-sha-256"`},
1662+
{"user=pqgomd5 password=wordpass require_auth=!md5,!password", `"!md5,!password" failed: server requested "md5"`},
1663+
{"user=pqgopassword password=wordpass require_auth=!md5,!password", `"!md5,!password" failed: server requested "password"`},
1664+
{"user=pqgoscram password=wordpass require_auth=!md5,!password,!scram-sha-256", `"!md5,!password,!scram-sha-256" failed: server requested "scram-sha-256"`},
1665+
{"user=pqgomd5 password=wordpass require_auth=password", `"password" failed: server requested "md5"`},
1666+
1667+
{"user=pqgo password=unused require_auth=none", ``},
1668+
{"user=pqgo password=unused require_auth=!none", `"!none" failed: server did not perform any authentication`},
1669+
{"user=pqgo password=unused require_auth=md5,password,scram-sha-256", `"md5,password,scram-sha-256" failed: server did not perform any authentication`},
16491670
}
16501671

16511672
for _, tt := range tests {
16521673
t.Run(tt.conn, func(t *testing.T) {
1674+
if strings.Contains(tt.conn, "md5") {
1675+
pqtest.SkipCockroach(t) // md5 not supported
1676+
}
16531677
_, err := pqtest.DB(t, tt.conn)
16541678
if !pqtest.ErrorContains(err, tt.wantErr) {
16551679
t.Errorf("wrong error:\nhave: %s\nwant: %s", err, tt.wantErr)

connector.go

Lines changed: 94 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ type (
4545
// SSLProtocolVersion is a ssl_min_protocol_version or
4646
// ssl_max_protocol_version setting.
4747
SSLProtocolVersion string
48+
49+
// RequireAuth is a require_auth setting.
50+
RequireAuth string
51+
52+
// RequireAuths is a require_auth setting.
53+
RequireAuths []RequireAuth
4854
)
4955

5056
// Values for [SSLMode] that pq supports.
@@ -179,6 +185,41 @@ func (s SSLProtocolVersion) tlsconf() uint16 {
179185
}
180186
}
181187

188+
// Values for [RequireAuth] that pq supports.
189+
const (
190+
RequireAuthNone = RequireAuth("none")
191+
RequireAuthPassword = RequireAuth("password")
192+
RequireAuthMD5 = RequireAuth("md5")
193+
RequireAuthGSS = RequireAuth("gss")
194+
RequireAuthScramSHA256 = RequireAuth("scram-sha-256")
195+
RequireAuthAny = RequireAuth("!none")
196+
RequireAuthNotPassword = RequireAuth("!password")
197+
RequireAuthNotMD5 = RequireAuth("!md5")
198+
RequireAuthNotGSS = RequireAuth("!gss")
199+
RequireAuthNotScramSHA256 = RequireAuth("!scram-sha-256")
200+
201+
// Not (yet) supported by pq
202+
// RequireAuthSSPI = "sspi"
203+
// RequireAuthOAuth = "oauth"
204+
// RequireAuthNotSSPI = "!sspi"
205+
// RequireAuthNotOAuth = "!oauth"
206+
)
207+
208+
var requireAuths = []RequireAuth{RequireAuthNone, RequireAuthPassword, RequireAuthMD5,
209+
RequireAuthGSS, RequireAuthScramSHA256, RequireAuthAny, RequireAuthNotPassword,
210+
RequireAuthNotMD5, RequireAuthNotGSS, RequireAuthNotScramSHA256}
211+
212+
func (r RequireAuths) String() string {
213+
var b strings.Builder
214+
for i, rr := range r {
215+
if i > 0 {
216+
b.WriteString(",")
217+
}
218+
b.WriteString(string(rr))
219+
}
220+
return b.String()
221+
}
222+
182223
// Connector represents a fixed configuration for the pq driver with a given
183224
// dsn. Connector satisfies the [database/sql/driver.Connector] interface and
184225
// can be used to create any number of DB Conn's via [sql.OpenDB].
@@ -341,14 +382,14 @@ type Config struct {
341382
//
342383
// The default is determined by [tls.Config.MinVersion], which is TLSv1.2 at
343384
// the time of writing.
344-
SSLMinProtocolVersion SSLProtocolVersion `postgres:"ssl_min_protocol_version" env:"SSLPGMINPROTOCOLVERSION"`
385+
SSLMinProtocolVersion SSLProtocolVersion `postgres:"ssl_min_protocol_version" env:"PGSSLMINPROTOCOLVERSION"`
345386

346387
// Maximum SSL/TLS protocol version to allow for the connection. If not set,
347388
// this parameter is ignored and the connection will use the maximum bound
348389
// defined by the backend, if set. Setting the maximum protocol version is
349390
// mainly useful for testing or if some component has issues working with a
350391
// newer protocol.
351-
SSLMaxProtocolVersion SSLProtocolVersion `postgres:"ssl_max_protocol_version" env:"SSLPGMAXPROTOCOLVERSION"`
392+
SSLMaxProtocolVersion SSLProtocolVersion `postgres:"ssl_max_protocol_version" env:"PGSSLMAXPROTOCOLVERSION"`
352393

353394
// Interpert sslcert and sslkey as PEM encoded data, rather than a path to a
354395
// PEM file. This is a pq extension, not supported in libpq.
@@ -431,6 +472,25 @@ type Config struct {
431472
// Path to connection service file. Defaults to ~/.pg_service.conf.
432473
ServiceFile string `postgres:"-" env:"PGSERVICEFILE"`
433474

475+
// Require an authentication method from the server and refuse to connect if
476+
// the server does not use the requested method.
477+
//
478+
// This accepts a comma-separated list.
479+
//
480+
// Methods may be negated with a ! prefix, in which case the server must
481+
// *not* attempt the listed method, and the server is free not to
482+
// authenticate the client at all. Negated and non-negated forms may not be
483+
// combined in the same setting with a comma-separated list.
484+
//
485+
// As a special case the "none" method requires the server not to use an
486+
// authentication challenge. This does not prohibit client certificate
487+
// authentication via TLS or GSS authentication via its encrypted transport.
488+
// This can be negated to require some form of authentication.
489+
//
490+
// By default any authentication method is accepted and the server is free
491+
// to skip authentication altogether.
492+
RequireAuth RequireAuths `postgres:"require_auth" env:"PGREQUIREAUTH"`
493+
434494
// Runtime parameters: any unrecognized parameter in the DSN will be added
435495
// to this and sent to PostgreSQL during startup.
436496
Runtime map[string]string `postgres:"-" env:"-"`
@@ -517,7 +577,8 @@ func NewConfig(dsn string) (Config, error) {
517577
// Clone returns a copy of the [Config].
518578
func (cfg Config) Clone() Config {
519579
c := cfg
520-
c.Runtime, c.Multi, c.set = maps.Clone(cfg.Runtime), slices.Clone(cfg.Multi), slices.Clone(cfg.set)
580+
c.Runtime, c.Multi, c.RequireAuth, c.set = maps.Clone(cfg.Runtime), slices.Clone(cfg.Multi),
581+
slices.Clone(cfg.RequireAuth), slices.Clone(cfg.set)
521582
return c
522583
}
523584

@@ -672,8 +733,8 @@ func (cfg *Config) fromEnv(env []string) error {
672733
switch k {
673734
case "PGREQUIRESSL", "PGSSLCOMPRESSION", // Deprecated.
674735
"PGREALM", "PGGSSENCMODE", "PGGSSDELEGATION", "PGGSSLIB", // krb stuff
675-
"PGREQUIREAUTH", "PGCHANNELBINDING",
676-
"PGSSLCERTMODE", "PGSSLCRL", "PGSSLCRLDIR", "PGREQUIREPEER":
736+
"PGCHANNELBINDING", "PGSSLCRL", "PGSSLCRLDIR",
737+
"PGSSLCERTMODE", "PGREQUIREPEER":
677738
return fmt.Errorf("pq: environment variable $%s is not supported", k)
678739
case "PGKRBSRVNAME":
679740
if newGss == nil {
@@ -833,8 +894,9 @@ func (cfg *Config) setFromTag(o map[string]string, tag string, service bool) err
833894
loadbalancehosts = (tag == "postgres" && k == "load_balance_hosts") || (tag == "env" && k == "PGLOADBALANCEHOSTS")
834895
minprotocolversion = (tag == "postgres" && k == "min_protocol_version") || (tag == "env" && k == "PGMINPROTOCOLVERSION")
835896
maxprotocolversion = (tag == "postgres" && k == "max_protocol_version") || (tag == "env" && k == "PGMAXPROTOCOLVERSION")
836-
sslminprotocolversion = (tag == "postgres" && k == "ssl_min_protocol_version") || (tag == "env" && k == "SSLPGMINPROTOCOLVERSION")
837-
sslmaxprotocolversion = (tag == "postgres" && k == "ssl_max_protocol_version") || (tag == "env" && k == "SSLPGMAXPROTOCOLVERSION")
897+
sslminprotocolversion = (tag == "postgres" && k == "ssl_min_protocol_version") || (tag == "env" && k == "PGSSLMINPROTOCOLVERSION")
898+
sslmaxprotocolversion = (tag == "postgres" && k == "ssl_max_protocol_version") || (tag == "env" && k == "PGSSLMAXPROTOCOLVERSION")
899+
requireauth = (tag == "postgres" && k == "require_auth") || (tag == "env" && k == "PGREQUIREAUTH")
838900
)
839901
if k == "" || k == "-" {
840902
continue
@@ -908,6 +970,31 @@ func (cfg *Config) setFromTag(o map[string]string, tag string, service bool) err
908970
cfg.multiHost = append(cfg.multiHost, vv[1:]...)
909971
}
910972
rv.SetString(v)
973+
case reflect.Slice:
974+
if requireauth {
975+
if v == "" {
976+
rv.Set(reflect.ValueOf((RequireAuths)(nil)))
977+
continue
978+
}
979+
var (
980+
vv = strings.Split(v, ",")
981+
s = make(RequireAuths, len(vv))
982+
neg = len(vv) > 0 && strings.HasPrefix(vv[0], "!")
983+
)
984+
for i := range vv {
985+
if !slices.Contains(requireAuths, RequireAuth(vv[i])) {
986+
return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, vv[i], pqutil.Join(requireAuths))
987+
}
988+
if neg && !strings.HasPrefix(vv[i], "!") {
989+
return fmt.Errorf(f+`require_auth method %q cannot be mixed with negative methods`, k, vv[i])
990+
}
991+
if !neg && strings.HasPrefix(vv[i], "!") {
992+
return fmt.Errorf(f+`negative require_auth method %q cannot be mixed with non-negative methods`, k, vv[i])
993+
}
994+
s[i] = RequireAuth(vv[i])
995+
}
996+
rv.Set(reflect.ValueOf(s))
997+
}
911998
case reflect.Int64:
912999
n, err := strconv.ParseInt(v, 10, 64)
9131000
if err != nil {

connector_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,15 @@ func TestNewConfig(t *testing.T) {
435435
{"", []string{"PGMINPROTOCOLVERSION=bogus"}, "", `pq: wrong value for $PGMINPROTOCOLVERSION: "bogus" is not supported`},
436436
{"", []string{"PGMAXPROTOCOLVERSION=bogus"}, "", `pq: wrong value for $PGMAXPROTOCOLVERSION: "bogus" is not supported`},
437437
{"min_protocol_version=3.2 max_protocol_version=3.0", nil, "", `min_protocol_version "3.2" cannot be greater than max_protocol_version "3.0"`},
438+
439+
// requireauth
440+
{"require_auth=", nil, "require_auth=''", ``},
441+
{"require_auth=none", nil, "require_auth=none", ""},
442+
{"require_auth=md5,scram-sha-256", nil, "require_auth=md5,scram-sha-256", ""},
443+
{"require_auth=md5,scram-sha256", nil, "", `wrong value for "require_auth": "scram-sha256" is not supported`},
444+
{"require_auth=!md5,!scram-sha-256", nil, "require_auth=!md5,!scram-sha-256", ""},
445+
{"require_auth=md5,!password", nil, "", `negative require_auth method "!password" cannot be mixed with non-negative methods`},
446+
{"require_auth=!md5,password", nil, "", `require_auth method "password" cannot be mixed with negative methods`},
438447
}
439448

440449
t.Parallel()

0 commit comments

Comments
 (0)