@@ -3,14 +3,19 @@ package saml
33
44import (
55 "bytes"
6+ "compress/flate"
67 "context"
8+ "crypto/rand"
79 "crypto/x509"
810 "encoding/base64"
911 "encoding/json"
1012 "encoding/pem"
1113 "encoding/xml"
1214 "fmt"
15+ "io"
1316 "log/slog"
17+ "net/http"
18+ "net/url"
1419 "os"
1520 "strings"
1621 "sync"
@@ -84,6 +89,15 @@ type Config struct {
8489
8590 InsecureSkipSignatureValidation bool `json:"insecureSkipSignatureValidation"`
8691
92+ // SLOURL is the IdP's Single Logout Service URL (HTTP-Redirect binding).
93+ // If empty, SLO is not available for this connector.
94+ SLOURL string `json:"sloURL"`
95+
96+ // InsecureSkipSLOSignatureValidation skips signature validation on SLO responses.
97+ // This is insecure and should only be used for testing or when the IdP
98+ // does not sign LogoutResponses.
99+ InsecureSkipSLOSignatureValidation bool `json:"insecureSkipSLOSignatureValidation"`
100+
87101 // Assertion attribute names to lookup various claims with.
88102 UsernameAttr string `json:"usernameAttr"`
89103 EmailAttr string `json:"emailAttr"`
@@ -164,6 +178,9 @@ func (c *Config) openConnector(logger *slog.Logger) (*provider, error) {
164178 logger : logger ,
165179
166180 nameIDPolicyFormat : c .NameIDPolicyFormat ,
181+
182+ sloURL : c .SLOURL ,
183+ insecureSkipSLOSignatureValidation : c .InsecureSkipSLOSignatureValidation ,
167184 }
168185
169186 if p .nameIDPolicyFormat == "" {
@@ -189,7 +206,8 @@ func (c *Config) openConnector(logger *slog.Logger) (*provider, error) {
189206 }
190207 }
191208
192- if ! c .InsecureSkipSignatureValidation {
209+ needsSLOSigValidation := c .SLOURL != "" && ! c .InsecureSkipSLOSignatureValidation
210+ if ! c .InsecureSkipSignatureValidation || needsSLOSigValidation {
193211 if (c .CA == "" ) == (c .CAData == nil ) {
194212 return nil , errors .New ("must provide either 'ca' or 'caData'" )
195213 }
@@ -233,8 +251,9 @@ func (c *Config) openConnector(logger *slog.Logger) (*provider, error) {
233251}
234252
235253var (
236- _ connector.SAMLConnector = (* provider )(nil )
237- _ connector.RefreshConnector = (* provider )(nil )
254+ _ connector.SAMLConnector = (* provider )(nil )
255+ _ connector.RefreshConnector = (* provider )(nil )
256+ _ connector.LogoutCallbackConnector = (* provider )(nil )
238257)
239258
240259type provider struct {
@@ -259,30 +278,41 @@ type provider struct {
259278
260279 nameIDPolicyFormat string
261280
281+ sloURL string
282+ insecureSkipSLOSignatureValidation bool
283+
262284 logger * slog.Logger
263285}
264286
265- // cachedIdentity stores the identity from SAML assertion for refresh token support.
266- // Since SAML has no native refresh mechanism, we cache the identity obtained during
267- // the initial authentication and return it on subsequent refresh requests .
287+ // cachedIdentity stores the identity from SAML assertion for refresh token support
288+ // and SLO (Single Logout). The NameID/NameIDFormat/SessionIndex fields are used
289+ // to build a SAML LogoutRequest when the user logs out .
268290type cachedIdentity struct {
269291 UserID string `json:"userId"`
270292 Username string `json:"username"`
271293 PreferredUsername string `json:"preferredUsername"`
272294 Email string `json:"email"`
273295 EmailVerified bool `json:"emailVerified"`
274296 Groups []string `json:"groups,omitempty"`
297+ NameID string `json:"nameId,omitempty"`
298+ NameIDFormat string `json:"nameIdFormat,omitempty"`
299+ SessionIndex string `json:"sessionIndex,omitempty"`
275300}
276301
277- // marshalCachedIdentity serializes the identity into ConnectorData for refresh token support.
278- func marshalCachedIdentity (ident connector.Identity ) (connector.Identity , error ) {
302+ // marshalCachedIdentity serializes the identity along with SAML-specific SLO
303+ // fields into ConnectorData. The nameIDFormat and sessionIdx parameters come
304+ // from the parsed SAML assertion and are needed to construct a LogoutRequest.
305+ func marshalCachedIdentity (ident connector.Identity , nameIDFormat , sessionIdx string ) (connector.Identity , error ) {
279306 ci := cachedIdentity {
280307 UserID : ident .UserID ,
281308 Username : ident .Username ,
282309 PreferredUsername : ident .PreferredUsername ,
283310 Email : ident .Email ,
284311 EmailVerified : ident .EmailVerified ,
285312 Groups : ident .Groups ,
313+ NameID : ident .UserID ,
314+ NameIDFormat : nameIDFormat ,
315+ SessionIndex : sessionIdx ,
286316 }
287317 connectorData , err := json .Marshal (ci )
288318 if err != nil {
@@ -407,15 +437,22 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
407437 }
408438 }
409439
440+ var nameIDFormat string
410441 switch {
411442 case subject .NameID != nil :
412443 if ident .UserID = subject .NameID .Value ; ident .UserID == "" {
413444 return ident , fmt .Errorf ("element NameID does not contain a value" )
414445 }
446+ nameIDFormat = subject .NameID .Format
415447 default :
416448 return ident , fmt .Errorf ("subject does not contain an NameID element" )
417449 }
418450
451+ var sessionIdx string
452+ if len (assertion .AuthnStatements ) > 0 {
453+ sessionIdx = assertion .AuthnStatements [0 ].SessionIndex
454+ }
455+
419456 // After verifying the assertion, map data in the attribute statements to
420457 // various user info.
421458 attributes := assertion .AttributeStatement
@@ -442,7 +479,7 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
442479
443480 if len (p .allowedGroups ) == 0 && (! s .Groups || p .groupsAttr == "" ) {
444481 // Groups not requested or not configured. We're done.
445- return marshalCachedIdentity (ident )
482+ return marshalCachedIdentity (ident , nameIDFormat , sessionIdx )
446483 }
447484
448485 if len (p .allowedGroups ) > 0 && (! s .Groups || p .groupsAttr == "" ) {
@@ -468,7 +505,7 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
468505
469506 if len (p .allowedGroups ) == 0 {
470507 // No allowed groups set, just return the ident
471- return marshalCachedIdentity (ident )
508+ return marshalCachedIdentity (ident , nameIDFormat , sessionIdx )
472509 }
473510
474511 // Look for membership in one of the allowed groups
@@ -484,7 +521,7 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
484521 }
485522
486523 // Otherwise, we're good
487- return marshalCachedIdentity (ident )
524+ return marshalCachedIdentity (ident , nameIDFormat , sessionIdx )
488525}
489526
490527// Refresh implements connector.RefreshConnector.
@@ -711,3 +748,162 @@ func before(now, notBefore time.Time) bool {
711748func after (now , notOnOrAfter time.Time ) bool {
712749 return now .After (notOnOrAfter .Add (allowedClockDrift ))
713750}
751+
752+ // newRequestID generates a random ID suitable for SAML request IDs.
753+ func newRequestID () string {
754+ buf := make ([]byte , 16 )
755+ if _ , err := io .ReadFull (rand .Reader , buf ); err != nil {
756+ panic ("crypto/rand failed: " + err .Error ())
757+ }
758+ return fmt .Sprintf ("_%x" , buf )
759+ }
760+
761+ // LogoutURL builds a SAML LogoutRequest and returns the IdP's SLO endpoint URL
762+ // with the request encoded using HTTP-Redirect binding (deflate + base64).
763+ //
764+ // See: https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf
765+ // "3.4 HTTP Redirect Binding"
766+ func (p * provider ) LogoutURL (_ context.Context , connectorData []byte , postLogoutRedirectURI string ) (string , error ) {
767+ if p .sloURL == "" {
768+ return "" , nil
769+ }
770+
771+ var ci cachedIdentity
772+ if len (connectorData ) > 0 {
773+ if err := json .Unmarshal (connectorData , & ci ); err != nil {
774+ return "" , fmt .Errorf ("saml: failed to unmarshal connector data for logout: %v" , err )
775+ }
776+ }
777+
778+ if ci .NameID == "" {
779+ return "" , nil
780+ }
781+
782+ req := & logoutRequest {
783+ ID : newRequestID (),
784+ IssueInstant : xmlTime (p .now ()),
785+ Destination : p .sloURL ,
786+ NameID : nameID {
787+ Format : ci .NameIDFormat ,
788+ Value : ci .NameID ,
789+ },
790+ }
791+ if p .entityIssuer != "" {
792+ req .Issuer = & issuer {Issuer : p .entityIssuer }
793+ }
794+ if ci .SessionIndex != "" {
795+ req .SessionIndex = []sessionIndex {{Value : ci .SessionIndex }}
796+ }
797+
798+ data , err := xml .Marshal (req )
799+ if err != nil {
800+ return "" , fmt .Errorf ("saml: failed to marshal LogoutRequest: %v" , err )
801+ }
802+
803+ // HTTP-Redirect binding: deflate then base64-encode.
804+ var buf bytes.Buffer
805+ fw , err := flate .NewWriter (& buf , flate .DefaultCompression )
806+ if err != nil {
807+ return "" , fmt .Errorf ("saml: failed to create deflate writer: %v" , err )
808+ }
809+ if _ , err := fw .Write (data ); err != nil {
810+ return "" , fmt .Errorf ("saml: failed to deflate LogoutRequest: %v" , err )
811+ }
812+ if err := fw .Close (); err != nil {
813+ return "" , fmt .Errorf ("saml: failed to close deflate writer: %v" , err )
814+ }
815+
816+ encoded := base64 .StdEncoding .EncodeToString (buf .Bytes ())
817+
818+ u , err := url .Parse (p .sloURL )
819+ if err != nil {
820+ return "" , fmt .Errorf ("saml: failed to parse SLO URL: %v" , err )
821+ }
822+ q := u .Query ()
823+ q .Set ("SAMLRequest" , encoded )
824+ if postLogoutRedirectURI != "" {
825+ q .Set ("RelayState" , postLogoutRedirectURI )
826+ }
827+ u .RawQuery = q .Encode ()
828+
829+ return u .String (), nil
830+ }
831+
832+ // HandleLogoutCallback validates the IdP's LogoutResponse received after
833+ // an SP-initiated logout redirect. The response arrives as a SAMLResponse
834+ // parameter via either GET query (HTTP-Redirect binding: deflated + base64)
835+ // or POST form (HTTP-POST binding: base64 only).
836+ func (p * provider ) HandleLogoutCallback (_ context.Context , r * http.Request ) error {
837+ var samlResponse string
838+ if r .Method == http .MethodGet {
839+ samlResponse = r .URL .Query ().Get ("SAMLResponse" )
840+ } else {
841+ if err := r .ParseForm (); err != nil {
842+ return fmt .Errorf ("saml slo: failed to parse form: %v" , err )
843+ }
844+ samlResponse = r .FormValue ("SAMLResponse" )
845+ }
846+
847+ if samlResponse == "" {
848+ return nil
849+ }
850+
851+ compressed , err := base64 .StdEncoding .DecodeString (samlResponse )
852+ if err != nil {
853+ return fmt .Errorf ("saml slo: failed to decode SAMLResponse: %v" , err )
854+ }
855+
856+ // HTTP-Redirect binding uses DEFLATE compression; HTTP-POST does not.
857+ // Try to inflate; if it fails, treat the data as uncompressed XML.
858+ rawResp , err := io .ReadAll (flate .NewReader (bytes .NewReader (compressed )))
859+ if err != nil {
860+ rawResp = compressed
861+ }
862+
863+ byteReader := bytes .NewReader (rawResp )
864+ if xrvErr := xrv .Validate (byteReader ); xrvErr != nil {
865+ return fmt .Errorf ("saml slo: %w" , xrvErr )
866+ }
867+
868+ if p .validator != nil && ! p .insecureSkipSLOSignatureValidation {
869+ if _ , err := p .validateSignature (rawResp ); err != nil {
870+ return fmt .Errorf ("saml slo: %v" , err )
871+ }
872+ }
873+
874+ var resp logoutResponse
875+ if err := xml .Unmarshal (rawResp , & resp ); err != nil {
876+ return fmt .Errorf ("saml slo: failed to unmarshal LogoutResponse: %v" , err )
877+ }
878+
879+ if resp .Status != nil {
880+ if err := p .validateStatus (resp .Status ); err != nil {
881+ return fmt .Errorf ("saml slo: %v" , err )
882+ }
883+ }
884+
885+ return nil
886+ }
887+
888+ // validateSignature validates the XML digital signature of the given raw XML.
889+ func (p * provider ) validateSignature (rawXML []byte ) ([]byte , error ) {
890+ if p .validator == nil {
891+ return nil , fmt .Errorf ("signature validation unavailable (no validator configured)" )
892+ }
893+
894+ doc := etree .NewDocument ()
895+ if err := doc .ReadFromBytes (rawXML ); err != nil {
896+ return nil , fmt .Errorf ("failed to parse XML: %v" , err )
897+ }
898+
899+ root := doc .Root ()
900+ if root == nil {
901+ return nil , fmt .Errorf ("empty XML document" )
902+ }
903+
904+ if _ , err := p .validator .Validate (root ); err != nil {
905+ return nil , fmt .Errorf ("signature validation failed: %v" , err )
906+ }
907+
908+ return rawXML , nil
909+ }
0 commit comments