@@ -44,6 +44,12 @@ type (
4444 // SSLProtocolVersion is a ssl_min_protocol_version or
4545 // ssl_max_protocol_version setting.
4646 SSLProtocolVersion string
47+
48+ // RequireAuth is a require_auth setting.
49+ RequireAuth string
50+
51+ // RequireAuths is a require_auth setting.
52+ RequireAuths []RequireAuth
4753)
4854
4955// Values for [SSLMode] that pq supports.
@@ -178,6 +184,41 @@ func (s SSLProtocolVersion) tlsconf() uint16 {
178184 }
179185}
180186
187+ // Values for [RequireAuth] that pq supports.
188+ const (
189+ RequireAuthNone = RequireAuth ("none" )
190+ RequireAuthPassword = RequireAuth ("password" )
191+ RequireAuthMD5 = RequireAuth ("md5" )
192+ RequireAuthGSS = RequireAuth ("gss" )
193+ RequireAuthScramSHA256 = RequireAuth ("scram-sha-256" )
194+ RequireAuthAny = RequireAuth ("!none" )
195+ RequireAuthNotPassword = RequireAuth ("!password" )
196+ RequireAuthNotMD5 = RequireAuth ("!md5" )
197+ RequireAuthNotGSS = RequireAuth ("!gss" )
198+ RequireAuthNotScramSHA256 = RequireAuth ("!scram-sha-256" )
199+
200+ // Not (yet) supported by pq
201+ // RequireAuthSSPI = "sspi"
202+ // RequireAuthOAuth = "oauth"
203+ // RequireAuthNotSSPI = "!sspi"
204+ // RequireAuthNotOAuth = "!oauth"
205+ )
206+
207+ var requireAuths = []RequireAuth {RequireAuthNone , RequireAuthPassword , RequireAuthMD5 ,
208+ RequireAuthGSS , RequireAuthScramSHA256 , RequireAuthAny , RequireAuthNotPassword ,
209+ RequireAuthNotMD5 , RequireAuthNotGSS , RequireAuthNotScramSHA256 }
210+
211+ func (r RequireAuths ) String () string {
212+ var b strings.Builder
213+ for i , rr := range r {
214+ if i > 0 {
215+ b .WriteString ("," )
216+ }
217+ b .WriteString (string (rr ))
218+ }
219+ return b .String ()
220+ }
221+
181222// Connector represents a fixed configuration for the pq driver with a given
182223// dsn. Connector satisfies the [database/sql/driver.Connector] interface and
183224// can be used to create any number of DB Conn's via [sql.OpenDB].
@@ -340,14 +381,14 @@ type Config struct {
340381 //
341382 // The default is determined by [tls.Config.MinVersion], which is TLSv1.2 at
342383 // the time of writing.
343- SSLMinProtocolVersion SSLProtocolVersion `postgres:"ssl_min_protocol_version" env:"SSLPGMINPROTOCOLVERSION "`
384+ SSLMinProtocolVersion SSLProtocolVersion `postgres:"ssl_min_protocol_version" env:"PGSSLMINPROTOCOLVERSION "`
344385
345386 // Maximum SSL/TLS protocol version to allow for the connection. If not set,
346387 // this parameter is ignored and the connection will use the maximum bound
347388 // defined by the backend, if set. Setting the maximum protocol version is
348389 // mainly useful for testing or if some component has issues working with a
349390 // newer protocol.
350- SSLMaxProtocolVersion SSLProtocolVersion `postgres:"ssl_max_protocol_version" env:"SSLPGMAXPROTOCOLVERSION "`
391+ SSLMaxProtocolVersion SSLProtocolVersion `postgres:"ssl_max_protocol_version" env:"PGSSLMAXPROTOCOLVERSION "`
351392
352393 // Interpert sslcert and sslkey as PEM encoded data, rather than a path to a
353394 // PEM file. This is a pq extension, not supported in libpq.
@@ -430,6 +471,25 @@ type Config struct {
430471 // Path to connection service file. Defaults to ~/.pg_service.conf.
431472 ServiceFile string `postgres:"-" env:"PGSERVICEFILE"`
432473
474+ // Require an authentication method from the server and refuse to connect if
475+ // the server does not use the requested method.
476+ //
477+ // This accepts a comma-separated list.
478+ //
479+ // Methods may be negated with a ! prefix, in which case the server must
480+ // *not* attempt the listed method, and the server is free not to
481+ // authenticate the client at all. Negated and non-negated forms may not be
482+ // combined in the same setting with a comma-separated list.
483+ //
484+ // As a special case the "none" method requires the server not to use an
485+ // authentication challenge. This does not prohibit client certificate
486+ // authentication via TLS or GSS authentication via its encrypted transport.
487+ // This can be negated to require some form of authentication.
488+ //
489+ // By default any authentication method is accepted and the server is free
490+ // to skip authentication altogether.
491+ RequireAuth RequireAuths `postgres:"require_auth" env:"PGREQUIREAUTH"`
492+
433493 // Runtime parameters: any unrecognized parameter in the DSN will be added
434494 // to this and sent to PostgreSQL during startup.
435495 Runtime map [string ]string `postgres:"-" env:"-"`
@@ -674,8 +734,8 @@ func (cfg *Config) fromEnv(env []string) error {
674734 switch k {
675735 case "PGREQUIRESSL" , "PGSSLCOMPRESSION" , // Deprecated.
676736 "PGREALM" , "PGGSSENCMODE" , "PGGSSDELEGATION" , "PGGSSLIB" , // krb stuff
677- "PGREQUIREAUTH " , "PGCHANNELBINDING " ,
678- "PGSSLCERTMODE" , "PGSSLCRL" , "PGSSLCRLDIR" , " PGREQUIREPEER" :
737+ "PGCHANNELBINDING " , "PGSSLCRL" , "PGSSLCRLDIR " ,
738+ "PGSSLCERTMODE" , "PGREQUIREPEER" :
679739 return fmt .Errorf ("pq: environment variable $%s is not supported" , k )
680740 case "PGKRBSRVNAME" :
681741 if newGss == nil {
@@ -835,8 +895,9 @@ func (cfg *Config) setFromTag(o map[string]string, tag string, service bool) err
835895 loadbalancehosts = (tag == "postgres" && k == "load_balance_hosts" ) || (tag == "env" && k == "PGLOADBALANCEHOSTS" )
836896 minprotocolversion = (tag == "postgres" && k == "min_protocol_version" ) || (tag == "env" && k == "PGMINPROTOCOLVERSION" )
837897 maxprotocolversion = (tag == "postgres" && k == "max_protocol_version" ) || (tag == "env" && k == "PGMAXPROTOCOLVERSION" )
838- sslminprotocolversion = (tag == "postgres" && k == "ssl_min_protocol_version" ) || (tag == "env" && k == "SSLPGMINPROTOCOLVERSION" )
839- sslmaxprotocolversion = (tag == "postgres" && k == "ssl_max_protocol_version" ) || (tag == "env" && k == "SSLPGMAXPROTOCOLVERSION" )
898+ sslminprotocolversion = (tag == "postgres" && k == "ssl_min_protocol_version" ) || (tag == "env" && k == "PGSSLMINPROTOCOLVERSION" )
899+ sslmaxprotocolversion = (tag == "postgres" && k == "ssl_max_protocol_version" ) || (tag == "env" && k == "PGSSLMAXPROTOCOLVERSION" )
900+ requireauth = (tag == "postgres" && k == "require_auth" ) || (tag == "env" && k == "PGREQUIREAUTH" )
840901 )
841902 if k == "" || k == "-" {
842903 continue
@@ -910,6 +971,31 @@ func (cfg *Config) setFromTag(o map[string]string, tag string, service bool) err
910971 cfg .multiHost = append (cfg .multiHost , vv [1 :]... )
911972 }
912973 rv .SetString (v )
974+ case reflect .Slice :
975+ if requireauth {
976+ if v == "" {
977+ rv .Set (reflect .ValueOf ((RequireAuths )(nil )))
978+ continue
979+ }
980+ var (
981+ vv = strings .Split (v , "," )
982+ s = make (RequireAuths , len (vv ))
983+ neg = len (vv ) > 0 && strings .HasPrefix (vv [0 ], "!" )
984+ )
985+ for i := range vv {
986+ if ! slices .Contains (requireAuths , RequireAuth (vv [i ])) {
987+ return fmt .Errorf (f + `%q is not supported; supported values are %s` , k , vv [i ], pqutil .Join (requireAuths ))
988+ }
989+ if neg && ! strings .HasPrefix (vv [i ], "!" ) {
990+ return fmt .Errorf (f + `require_auth method %q cannot be mixed with negative methods` , k , vv [i ])
991+ }
992+ if ! neg && strings .HasPrefix (vv [i ], "!" ) {
993+ return fmt .Errorf (f + `negative require_auth method %q cannot be mixed with non-negative methods` , k , vv [i ])
994+ }
995+ s [i ] = RequireAuth (vv [i ])
996+ }
997+ rv .Set (reflect .ValueOf (s ))
998+ }
913999 case reflect .Int64 :
9141000 n , err := strconv .ParseInt (v , 10 , 64 )
9151001 if err != nil {
0 commit comments