Skip to content

Commit 6c82b0d

Browse files
committed
Implement require_auth connection parameter
1 parent 32ba56b commit 6c82b0d

5 files changed

Lines changed: 158 additions & 13 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
unreleased
22
----------
33

4+
- Implement `require_auth` connection parameter ([#1310]).
5+
6+
[#1310]: https://github.com/lib/pq/pull/1310
47

58
v1.12.2 (2026-04-02)
69
--------------------

conn.go

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"net"
1616
"os"
1717
"reflect"
18+
"slices"
1819
"strconv"
1920
"strings"
2021
"sync"
@@ -1247,6 +1248,7 @@ func (cn *conn) startup(cfg Config) error {
12471248
return err
12481249
}
12491250

1251+
var didauth bool
12501252
for {
12511253
t, r, err := cn.recv()
12521254
if err != nil {
@@ -1263,7 +1265,11 @@ func (cn *conn) startup(cfg Config) error {
12631265
case proto.ParameterStatus:
12641266
cn.processParameterStatus(r)
12651267
case proto.AuthenticationRequest:
1266-
err := cn.auth(r, cfg)
1268+
code := proto.AuthCode(r.int32())
1269+
if code != proto.AuthReqOk {
1270+
didauth = true
1271+
}
1272+
err := cn.auth(code, r, cfg)
12671273
if err != nil {
12681274
return err
12691275
}
@@ -1274,6 +1280,9 @@ func (cn *conn) startup(cfg Config) error {
12741280
return fmt.Errorf("pq: protocol version mismatch: min_protocol_version=%s; server supports up to 3.%d", cfg.MinProtocolVersion, newestMinor)
12751281
}
12761282
case proto.ReadyForQuery:
1283+
if len(cn.cfg.RequireAuth) > 0 && !didauth && !slices.Contains(cn.cfg.RequireAuth, RequireAuthNone) {
1284+
return fmt.Errorf("pq: authentication method requirement %q failed: server did not perform any authentication", cn.cfg.RequireAuth)
1285+
}
12771286
cn.processReadyForQuery(r)
12781287
return nil
12791288
default:
@@ -1282,8 +1291,8 @@ func (cn *conn) startup(cfg Config) error {
12821291
}
12831292
}
12841293

1285-
func (cn *conn) auth(r *readBuf, cfg Config) error {
1286-
switch code := proto.AuthCode(r.int32()); code {
1294+
func (cn *conn) auth(code proto.AuthCode, r *readBuf, cfg Config) error {
1295+
switch code {
12871296
default:
12881297
return fmt.Errorf("pq: unknown authentication response: %s", code)
12891298
case proto.AuthReqKrb4, proto.AuthReqKrb5, proto.AuthReqCrypt, proto.AuthReqSSPI:
@@ -1292,13 +1301,19 @@ func (cn *conn) auth(r *readBuf, cfg Config) error {
12921301
return nil
12931302

12941303
case proto.AuthReqPassword:
1304+
if len(cn.cfg.RequireAuth) > 0 && !slices.Contains(cn.cfg.RequireAuth, RequireAuthPassword) && !slices.Contains(cn.cfg.RequireAuth, RequireAuthAny) {
1305+
return fmt.Errorf("pq: authentication method requirement %q failed: server requested %q", cn.cfg.RequireAuth, RequireAuthPassword)
1306+
}
12951307
w := cn.writeBuf(proto.PasswordMessage)
12961308
w.string(cfg.Password)
12971309
// Don't need to check AuthOk response here; auth() is called in a loop,
12981310
// which catches the errors and AuthReqOk responses.
12991311
return cn.send(w)
13001312

13011313
case proto.AuthReqMD5:
1314+
if len(cn.cfg.RequireAuth) > 0 && !slices.Contains(cn.cfg.RequireAuth, RequireAuthMD5) && !slices.Contains(cn.cfg.RequireAuth, RequireAuthAny) {
1315+
return fmt.Errorf("pq: authentication method requirement %q failed: server requested %q", cn.cfg.RequireAuth, RequireAuthMD5)
1316+
}
13021317
s := string(r.next(4))
13031318
w := cn.writeBuf(proto.PasswordMessage)
13041319
w.string("md5" + md5s(md5s(cfg.Password+cfg.User)+s))
@@ -1361,6 +1376,9 @@ func (cn *conn) auth(r *readBuf, cfg Config) error {
13611376
return nil
13621377

13631378
case proto.AuthReqSASL:
1379+
if len(cn.cfg.RequireAuth) > 0 && !slices.Contains(cn.cfg.RequireAuth, RequireAuthScramSHA256) && !slices.Contains(cn.cfg.RequireAuth, RequireAuthAny) {
1380+
return fmt.Errorf("pq: authentication method requirement %q failed: server requested %q", cn.cfg.RequireAuth, RequireAuthScramSHA256)
1381+
}
13641382
sc := scram.NewClient(sha256.New, cfg.User, cfg.Password)
13651383
sc.Step(nil)
13661384
if sc.Err() != nil {

conn_test.go

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1611,18 +1611,18 @@ func TestCommitInFailedTransactionWithCancelContext(t *testing.T) {
16111611

16121612
func TestAuth(t *testing.T) {
16131613
tests := []struct {
1614-
buf readBuf
1614+
code proto.AuthCode
16151615
wantErr string
16161616
}{
1617-
{readBuf{0, 0, 0, 9}, `pq: unsupported authentication method: SSPI (9)`},
1618-
{readBuf{0, 0, 0, 99}, `unknown authentication response: <unknown> (99)`},
1617+
{proto.AuthCode(9), `pq: unsupported authentication method: SSPI (9)`},
1618+
{proto.AuthCode(99), `unknown authentication response: <unknown> (99)`},
16191619
}
16201620

16211621
t.Parallel()
16221622
for _, tt := range tests {
16231623
t.Run("", func(t *testing.T) {
16241624
t.Run("unsupported auth", func(t *testing.T) {
1625-
err := (&conn{}).auth(&tt.buf, Config{})
1625+
err := (&conn{}).auth(tt.code, &readBuf{}, Config{})
16261626
if !pqtest.ErrorContains(err, tt.wantErr) {
16271627
t.Errorf("wrong error:\nhave: %s\nwant: %s", err, tt.wantErr)
16281628
}
@@ -1650,10 +1650,39 @@ func TestAuth(t *testing.T) {
16501650
{"user=pqgoscram password=wordpass", ``},
16511651

16521652
{"user=pqgounknown password=wordpass", `or:role "pqgounknown" does not exist|password authentication failed for user pqgounknown`},
1653+
{"user=pqgounknown password=wordpass require_auth=md5", `or:role "pqgounknown" does not exist|password authentication failed for user pqgounknown`},
1654+
1655+
// require_auth
1656+
{"user=pqgomd5 password=wordpass require_auth=md5,password", ``},
1657+
{"user=pqgopassword password=wordpass require_auth=md5,password", ``},
1658+
{"user=pqgoscram password=wordpass require_auth=md5,password,scram-sha-256", ``},
1659+
{"user=pqgomd5 password=wordpass require_auth=!none", ``},
1660+
{"user=pqgopassword password=wordpass require_auth=!none", ``},
1661+
{"user=pqgoscram password=wordpass require_auth=!none", ``},
1662+
1663+
{"user=pqgomd5 password=wordpass require_auth=password", `"password" failed: server requested "md5"`},
1664+
{"user=pqgopassword password=wordpass require_auth=md5", `"md5" failed: server requested "password"`},
1665+
{"user=pqgoscram password=wordpass require_auth=md5,password", `authentication method requirement "md5,password" failed: server requested "scram-sha-256"`},
1666+
{"user=pqgomd5 password=wordpass require_auth=!md5,!password", `"!md5,!password" failed: server requested "md5"`},
1667+
{"user=pqgopassword password=wordpass require_auth=!md5,!password", `"!md5,!password" failed: server requested "password"`},
1668+
{"user=pqgoscram password=wordpass require_auth=!md5,!password,!scram-sha-256", `"!md5,!password,!scram-sha-256" failed: server requested "scram-sha-256"`},
1669+
{"user=pqgomd5 password=wordpass require_auth=password", `"password" failed: server requested "md5"`},
1670+
1671+
{"user=pqgo password=unused require_auth=none", ``},
1672+
{"user=pqgo password=unused require_auth=!none", `"!none" failed: server did not perform any authentication`},
1673+
{"user=pqgo password=unused require_auth=md5,password,scram-sha-256", `"md5,password,scram-sha-256" failed: server did not perform any authentication`},
1674+
1675+
// authentication method requirement "md5" failed: server did not complete authentication
1676+
// authentication method requirement "md5,password" failed: server did not complete authentication
1677+
// authentication method requirement "md5" failed: server requested SASL authentication
1678+
// authentication method requirement "!none" failed: server did not complete authentication (no auth)
16531679
}
16541680

16551681
for _, tt := range tests {
16561682
t.Run(tt.conn, func(t *testing.T) {
1683+
if strings.Contains(tt.conn, "md5") {
1684+
pqtest.SkipCockroach(t) // md5 not supported
1685+
}
16571686
_, err := pqtest.DB(t, tt.conn)
16581687
if !pqtest.ErrorContains(err, tt.wantErr) {
16591688
t.Errorf("wrong error:\nhave: %s\nwant: %s", err, tt.wantErr)

connector.go

Lines changed: 92 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ type (
4444
// SSLProtocolVersion is a ssl_min_protocol_version or
4545
// ssl_max_protocol_version setting.
4646
SSLProtocolVersion string
47+
48+
// RequireAuth is a require_auth setting.
49+
RequireAuth string
50+
51+
// RequireAuths is a require_auth setting.
52+
RequireAuths []RequireAuth
4753
)
4854

4955
// Values for [SSLMode] that pq supports.
@@ -178,6 +184,41 @@ func (s SSLProtocolVersion) tlsconf() uint16 {
178184
}
179185
}
180186

187+
// Values for [RequireAuth] that pq supports.
188+
const (
189+
RequireAuthNone = RequireAuth("none")
190+
RequireAuthPassword = RequireAuth("password")
191+
RequireAuthMD5 = RequireAuth("md5")
192+
RequireAuthGSS = RequireAuth("gss")
193+
RequireAuthScramSHA256 = RequireAuth("scram-sha-256")
194+
RequireAuthAny = RequireAuth("!none")
195+
RequireAuthNotPassword = RequireAuth("!password")
196+
RequireAuthNotMD5 = RequireAuth("!md5")
197+
RequireAuthNotGSS = RequireAuth("!gss")
198+
RequireAuthNotScramSHA256 = RequireAuth("!scram-sha-256")
199+
200+
// Not (yet) supported by pq
201+
// RequireAuthSSPI = "sspi"
202+
// RequireAuthOAuth = "oauth"
203+
// RequireAuthNotSSPI = "!sspi"
204+
// RequireAuthNotOAuth = "!oauth"
205+
)
206+
207+
var requireAuths = []RequireAuth{RequireAuthNone, RequireAuthPassword, RequireAuthMD5,
208+
RequireAuthGSS, RequireAuthScramSHA256, RequireAuthAny, RequireAuthNotPassword,
209+
RequireAuthNotMD5, RequireAuthNotGSS, RequireAuthNotScramSHA256}
210+
211+
func (r RequireAuths) String() string {
212+
var b strings.Builder
213+
for i, rr := range r {
214+
if i > 0 {
215+
b.WriteString(",")
216+
}
217+
b.WriteString(string(rr))
218+
}
219+
return b.String()
220+
}
221+
181222
// Connector represents a fixed configuration for the pq driver with a given
182223
// dsn. Connector satisfies the [database/sql/driver.Connector] interface and
183224
// can be used to create any number of DB Conn's via [sql.OpenDB].
@@ -340,14 +381,14 @@ type Config struct {
340381
//
341382
// The default is determined by [tls.Config.MinVersion], which is TLSv1.2 at
342383
// the time of writing.
343-
SSLMinProtocolVersion SSLProtocolVersion `postgres:"ssl_min_protocol_version" env:"SSLPGMINPROTOCOLVERSION"`
384+
SSLMinProtocolVersion SSLProtocolVersion `postgres:"ssl_min_protocol_version" env:"PGSSLMINPROTOCOLVERSION"`
344385

345386
// Maximum SSL/TLS protocol version to allow for the connection. If not set,
346387
// this parameter is ignored and the connection will use the maximum bound
347388
// defined by the backend, if set. Setting the maximum protocol version is
348389
// mainly useful for testing or if some component has issues working with a
349390
// newer protocol.
350-
SSLMaxProtocolVersion SSLProtocolVersion `postgres:"ssl_max_protocol_version" env:"SSLPGMAXPROTOCOLVERSION"`
391+
SSLMaxProtocolVersion SSLProtocolVersion `postgres:"ssl_max_protocol_version" env:"PGSSLMAXPROTOCOLVERSION"`
351392

352393
// Interpert sslcert and sslkey as PEM encoded data, rather than a path to a
353394
// PEM file. This is a pq extension, not supported in libpq.
@@ -430,6 +471,25 @@ type Config struct {
430471
// Path to connection service file. Defaults to ~/.pg_service.conf.
431472
ServiceFile string `postgres:"-" env:"PGSERVICEFILE"`
432473

474+
// Require an authentication method from the server and refuse to connect if
475+
// the server does not use the requested method.
476+
//
477+
// This accepts a comma-separated list.
478+
//
479+
// Methods may be negated with a ! prefix, in which case the server must
480+
// *not* attempt the listed method, and the server is free not to
481+
// authenticate the client at all. Negated and non-negated forms may not be
482+
// combined in the same setting with a comma-separated list.
483+
//
484+
// As a special case the "none" method requires the server not to use an
485+
// authentication challenge. This does not prohibit client certificate
486+
// authentication via TLS or GSS authentication via its encrypted transport.
487+
// This can be negated to require some form of authentication.
488+
//
489+
// By default any authentication method is accepted and the server is free
490+
// to skip authentication altogether.
491+
RequireAuth RequireAuths `postgres:"require_auth" env:"PGREQUIREAUTH"`
492+
433493
// Runtime parameters: any unrecognized parameter in the DSN will be added
434494
// to this and sent to PostgreSQL during startup.
435495
Runtime map[string]string `postgres:"-" env:"-"`
@@ -674,8 +734,8 @@ func (cfg *Config) fromEnv(env []string) error {
674734
switch k {
675735
case "PGREQUIRESSL", "PGSSLCOMPRESSION", // Deprecated.
676736
"PGREALM", "PGGSSENCMODE", "PGGSSDELEGATION", "PGGSSLIB", // krb stuff
677-
"PGREQUIREAUTH", "PGCHANNELBINDING",
678-
"PGSSLCERTMODE", "PGSSLCRL", "PGSSLCRLDIR", "PGREQUIREPEER":
737+
"PGCHANNELBINDING", "PGSSLCRL", "PGSSLCRLDIR",
738+
"PGSSLCERTMODE", "PGREQUIREPEER":
679739
return fmt.Errorf("pq: environment variable $%s is not supported", k)
680740
case "PGKRBSRVNAME":
681741
if newGss == nil {
@@ -835,8 +895,9 @@ func (cfg *Config) setFromTag(o map[string]string, tag string, service bool) err
835895
loadbalancehosts = (tag == "postgres" && k == "load_balance_hosts") || (tag == "env" && k == "PGLOADBALANCEHOSTS")
836896
minprotocolversion = (tag == "postgres" && k == "min_protocol_version") || (tag == "env" && k == "PGMINPROTOCOLVERSION")
837897
maxprotocolversion = (tag == "postgres" && k == "max_protocol_version") || (tag == "env" && k == "PGMAXPROTOCOLVERSION")
838-
sslminprotocolversion = (tag == "postgres" && k == "ssl_min_protocol_version") || (tag == "env" && k == "SSLPGMINPROTOCOLVERSION")
839-
sslmaxprotocolversion = (tag == "postgres" && k == "ssl_max_protocol_version") || (tag == "env" && k == "SSLPGMAXPROTOCOLVERSION")
898+
sslminprotocolversion = (tag == "postgres" && k == "ssl_min_protocol_version") || (tag == "env" && k == "PGSSLMINPROTOCOLVERSION")
899+
sslmaxprotocolversion = (tag == "postgres" && k == "ssl_max_protocol_version") || (tag == "env" && k == "PGSSLMAXPROTOCOLVERSION")
900+
requireauth = (tag == "postgres" && k == "require_auth") || (tag == "env" && k == "PGREQUIREAUTH")
840901
)
841902
if k == "" || k == "-" {
842903
continue
@@ -910,6 +971,31 @@ func (cfg *Config) setFromTag(o map[string]string, tag string, service bool) err
910971
cfg.multiHost = append(cfg.multiHost, vv[1:]...)
911972
}
912973
rv.SetString(v)
974+
case reflect.Slice:
975+
if requireauth {
976+
if v == "" {
977+
rv.Set(reflect.ValueOf((RequireAuths)(nil)))
978+
continue
979+
}
980+
var (
981+
vv = strings.Split(v, ",")
982+
s = make(RequireAuths, len(vv))
983+
neg = len(vv) > 0 && strings.HasPrefix(vv[0], "!")
984+
)
985+
for i := range vv {
986+
if !slices.Contains(requireAuths, RequireAuth(vv[i])) {
987+
return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, vv[i], pqutil.Join(requireAuths))
988+
}
989+
if neg && !strings.HasPrefix(vv[i], "!") {
990+
return fmt.Errorf(f+`require_auth method %q cannot be mixed with negative methods`, k, vv[i])
991+
}
992+
if !neg && strings.HasPrefix(vv[i], "!") {
993+
return fmt.Errorf(f+`negative require_auth method %q cannot be mixed with non-negative methods`, k, vv[i])
994+
}
995+
s[i] = RequireAuth(vv[i])
996+
}
997+
rv.Set(reflect.ValueOf(s))
998+
}
913999
case reflect.Int64:
9141000
n, err := strconv.ParseInt(v, 10, 64)
9151001
if err != nil {

connector_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,15 @@ func TestNewConfig(t *testing.T) {
457457
{"", []string{"PGMINPROTOCOLVERSION=bogus"}, "", `pq: wrong value for $PGMINPROTOCOLVERSION: "bogus" is not supported`},
458458
{"", []string{"PGMAXPROTOCOLVERSION=bogus"}, "", `pq: wrong value for $PGMAXPROTOCOLVERSION: "bogus" is not supported`},
459459
{"min_protocol_version=3.2 max_protocol_version=3.0", nil, "", `min_protocol_version "3.2" cannot be greater than max_protocol_version "3.0"`},
460+
461+
// requireauth
462+
{"require_auth=", nil, "require_auth=''", ``},
463+
{"require_auth=none", nil, "require_auth=none", ""},
464+
{"require_auth=md5,scram-sha-256", nil, "require_auth=md5,scram-sha-256", ""},
465+
{"require_auth=md5,scram-sha256", nil, "", `wrong value for "require_auth": "scram-sha256" is not supported`},
466+
{"require_auth=!md5,!scram-sha-256", nil, "require_auth=!md5,!scram-sha-256", ""},
467+
{"require_auth=md5,!password", nil, "", `negative require_auth method "!password" cannot be mixed with non-negative methods`},
468+
{"require_auth=!md5,password", nil, "", `require_auth method "password" cannot be mixed with negative methods`},
460469
}
461470

462471
t.Parallel()

0 commit comments

Comments
 (0)