Skip to content

Commit cc62a03

Browse files
committed
add saml slo
Signed-off-by: Ivan Zvyagintsev <ivan.zvyagintsev@flant.com>
1 parent 410a58f commit cc62a03

3 files changed

Lines changed: 735 additions & 12 deletions

File tree

connector/saml/saml.go

Lines changed: 207 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,19 @@ package saml
33

44
import (
55
"bytes"
6+
"compress/flate"
67
"context"
8+
"crypto/rand"
79
"crypto/x509"
810
"encoding/base64"
911
"encoding/json"
1012
"encoding/pem"
1113
"encoding/xml"
1214
"fmt"
15+
"io"
1316
"log/slog"
17+
"net/http"
18+
"net/url"
1419
"os"
1520
"strings"
1621
"sync"
@@ -84,6 +89,15 @@ type Config struct {
8489

8590
InsecureSkipSignatureValidation bool `json:"insecureSkipSignatureValidation"`
8691

92+
// SLOURL is the IdP's Single Logout Service URL (HTTP-Redirect binding).
93+
// If empty, SLO is not available for this connector.
94+
SLOURL string `json:"sloURL"`
95+
96+
// InsecureSkipSLOSignatureValidation skips signature validation on SLO responses.
97+
// This is insecure and should only be used for testing or when the IdP
98+
// does not sign LogoutResponses.
99+
InsecureSkipSLOSignatureValidation bool `json:"insecureSkipSLOSignatureValidation"`
100+
87101
// Assertion attribute names to lookup various claims with.
88102
UsernameAttr string `json:"usernameAttr"`
89103
EmailAttr string `json:"emailAttr"`
@@ -164,6 +178,9 @@ func (c *Config) openConnector(logger *slog.Logger) (*provider, error) {
164178
logger: logger,
165179

166180
nameIDPolicyFormat: c.NameIDPolicyFormat,
181+
182+
sloURL: c.SLOURL,
183+
insecureSkipSLOSignatureValidation: c.InsecureSkipSLOSignatureValidation,
167184
}
168185

169186
if p.nameIDPolicyFormat == "" {
@@ -189,7 +206,8 @@ func (c *Config) openConnector(logger *slog.Logger) (*provider, error) {
189206
}
190207
}
191208

192-
if !c.InsecureSkipSignatureValidation {
209+
needsSLOSigValidation := c.SLOURL != "" && !c.InsecureSkipSLOSignatureValidation
210+
if !c.InsecureSkipSignatureValidation || needsSLOSigValidation {
193211
if (c.CA == "") == (c.CAData == nil) {
194212
return nil, errors.New("must provide either 'ca' or 'caData'")
195213
}
@@ -233,8 +251,9 @@ func (c *Config) openConnector(logger *slog.Logger) (*provider, error) {
233251
}
234252

235253
var (
236-
_ connector.SAMLConnector = (*provider)(nil)
237-
_ connector.RefreshConnector = (*provider)(nil)
254+
_ connector.SAMLConnector = (*provider)(nil)
255+
_ connector.RefreshConnector = (*provider)(nil)
256+
_ connector.LogoutCallbackConnector = (*provider)(nil)
238257
)
239258

240259
type provider struct {
@@ -259,30 +278,41 @@ type provider struct {
259278

260279
nameIDPolicyFormat string
261280

281+
sloURL string
282+
insecureSkipSLOSignatureValidation bool
283+
262284
logger *slog.Logger
263285
}
264286

265-
// cachedIdentity stores the identity from SAML assertion for refresh token support.
266-
// Since SAML has no native refresh mechanism, we cache the identity obtained during
267-
// the initial authentication and return it on subsequent refresh requests.
287+
// cachedIdentity stores the identity from SAML assertion for refresh token support
288+
// and SLO (Single Logout). The NameID/NameIDFormat/SessionIndex fields are used
289+
// to build a SAML LogoutRequest when the user logs out.
268290
type cachedIdentity struct {
269291
UserID string `json:"userId"`
270292
Username string `json:"username"`
271293
PreferredUsername string `json:"preferredUsername"`
272294
Email string `json:"email"`
273295
EmailVerified bool `json:"emailVerified"`
274296
Groups []string `json:"groups,omitempty"`
297+
NameID string `json:"nameId,omitempty"`
298+
NameIDFormat string `json:"nameIdFormat,omitempty"`
299+
SessionIndex string `json:"sessionIndex,omitempty"`
275300
}
276301

277-
// marshalCachedIdentity serializes the identity into ConnectorData for refresh token support.
278-
func marshalCachedIdentity(ident connector.Identity) (connector.Identity, error) {
302+
// marshalCachedIdentity serializes the identity along with SAML-specific SLO
303+
// fields into ConnectorData. The nameIDFormat and sessionIdx parameters come
304+
// from the parsed SAML assertion and are needed to construct a LogoutRequest.
305+
func marshalCachedIdentity(ident connector.Identity, nameIDFormat, sessionIdx string) (connector.Identity, error) {
279306
ci := cachedIdentity{
280307
UserID: ident.UserID,
281308
Username: ident.Username,
282309
PreferredUsername: ident.PreferredUsername,
283310
Email: ident.Email,
284311
EmailVerified: ident.EmailVerified,
285312
Groups: ident.Groups,
313+
NameID: ident.UserID,
314+
NameIDFormat: nameIDFormat,
315+
SessionIndex: sessionIdx,
286316
}
287317
connectorData, err := json.Marshal(ci)
288318
if err != nil {
@@ -407,15 +437,22 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
407437
}
408438
}
409439

440+
var nameIDFormat string
410441
switch {
411442
case subject.NameID != nil:
412443
if ident.UserID = subject.NameID.Value; ident.UserID == "" {
413444
return ident, fmt.Errorf("element NameID does not contain a value")
414445
}
446+
nameIDFormat = subject.NameID.Format
415447
default:
416448
return ident, fmt.Errorf("subject does not contain an NameID element")
417449
}
418450

451+
var sessionIdx string
452+
if len(assertion.AuthnStatements) > 0 {
453+
sessionIdx = assertion.AuthnStatements[0].SessionIndex
454+
}
455+
419456
// After verifying the assertion, map data in the attribute statements to
420457
// various user info.
421458
attributes := assertion.AttributeStatement
@@ -442,7 +479,7 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
442479

443480
if len(p.allowedGroups) == 0 && (!s.Groups || p.groupsAttr == "") {
444481
// Groups not requested or not configured. We're done.
445-
return marshalCachedIdentity(ident)
482+
return marshalCachedIdentity(ident, nameIDFormat, sessionIdx)
446483
}
447484

448485
if len(p.allowedGroups) > 0 && (!s.Groups || p.groupsAttr == "") {
@@ -468,7 +505,7 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
468505

469506
if len(p.allowedGroups) == 0 {
470507
// No allowed groups set, just return the ident
471-
return marshalCachedIdentity(ident)
508+
return marshalCachedIdentity(ident, nameIDFormat, sessionIdx)
472509
}
473510

474511
// Look for membership in one of the allowed groups
@@ -484,7 +521,7 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
484521
}
485522

486523
// Otherwise, we're good
487-
return marshalCachedIdentity(ident)
524+
return marshalCachedIdentity(ident, nameIDFormat, sessionIdx)
488525
}
489526

490527
// Refresh implements connector.RefreshConnector.
@@ -711,3 +748,162 @@ func before(now, notBefore time.Time) bool {
711748
func after(now, notOnOrAfter time.Time) bool {
712749
return now.After(notOnOrAfter.Add(allowedClockDrift))
713750
}
751+
752+
// newRequestID generates a random ID suitable for SAML request IDs.
753+
func newRequestID() string {
754+
buf := make([]byte, 16)
755+
if _, err := io.ReadFull(rand.Reader, buf); err != nil {
756+
panic("crypto/rand failed: " + err.Error())
757+
}
758+
return fmt.Sprintf("_%x", buf)
759+
}
760+
761+
// LogoutURL builds a SAML LogoutRequest and returns the IdP's SLO endpoint URL
762+
// with the request encoded using HTTP-Redirect binding (deflate + base64).
763+
//
764+
// See: https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf
765+
// "3.4 HTTP Redirect Binding"
766+
func (p *provider) LogoutURL(_ context.Context, connectorData []byte, postLogoutRedirectURI string) (string, error) {
767+
if p.sloURL == "" {
768+
return "", nil
769+
}
770+
771+
var ci cachedIdentity
772+
if len(connectorData) > 0 {
773+
if err := json.Unmarshal(connectorData, &ci); err != nil {
774+
return "", fmt.Errorf("saml: failed to unmarshal connector data for logout: %v", err)
775+
}
776+
}
777+
778+
if ci.NameID == "" {
779+
return "", nil
780+
}
781+
782+
req := &logoutRequest{
783+
ID: newRequestID(),
784+
IssueInstant: xmlTime(p.now()),
785+
Destination: p.sloURL,
786+
NameID: nameID{
787+
Format: ci.NameIDFormat,
788+
Value: ci.NameID,
789+
},
790+
}
791+
if p.entityIssuer != "" {
792+
req.Issuer = &issuer{Issuer: p.entityIssuer}
793+
}
794+
if ci.SessionIndex != "" {
795+
req.SessionIndex = []sessionIndex{{Value: ci.SessionIndex}}
796+
}
797+
798+
data, err := xml.Marshal(req)
799+
if err != nil {
800+
return "", fmt.Errorf("saml: failed to marshal LogoutRequest: %v", err)
801+
}
802+
803+
// HTTP-Redirect binding: deflate then base64-encode.
804+
var buf bytes.Buffer
805+
fw, err := flate.NewWriter(&buf, flate.DefaultCompression)
806+
if err != nil {
807+
return "", fmt.Errorf("saml: failed to create deflate writer: %v", err)
808+
}
809+
if _, err := fw.Write(data); err != nil {
810+
return "", fmt.Errorf("saml: failed to deflate LogoutRequest: %v", err)
811+
}
812+
if err := fw.Close(); err != nil {
813+
return "", fmt.Errorf("saml: failed to close deflate writer: %v", err)
814+
}
815+
816+
encoded := base64.StdEncoding.EncodeToString(buf.Bytes())
817+
818+
u, err := url.Parse(p.sloURL)
819+
if err != nil {
820+
return "", fmt.Errorf("saml: failed to parse SLO URL: %v", err)
821+
}
822+
q := u.Query()
823+
q.Set("SAMLRequest", encoded)
824+
if postLogoutRedirectURI != "" {
825+
q.Set("RelayState", postLogoutRedirectURI)
826+
}
827+
u.RawQuery = q.Encode()
828+
829+
return u.String(), nil
830+
}
831+
832+
// HandleLogoutCallback validates the IdP's LogoutResponse received after
833+
// an SP-initiated logout redirect. The response arrives as a SAMLResponse
834+
// parameter via either GET query (HTTP-Redirect binding: deflated + base64)
835+
// or POST form (HTTP-POST binding: base64 only).
836+
func (p *provider) HandleLogoutCallback(_ context.Context, r *http.Request) error {
837+
var samlResponse string
838+
if r.Method == http.MethodGet {
839+
samlResponse = r.URL.Query().Get("SAMLResponse")
840+
} else {
841+
if err := r.ParseForm(); err != nil {
842+
return fmt.Errorf("saml slo: failed to parse form: %v", err)
843+
}
844+
samlResponse = r.FormValue("SAMLResponse")
845+
}
846+
847+
if samlResponse == "" {
848+
return nil
849+
}
850+
851+
compressed, err := base64.StdEncoding.DecodeString(samlResponse)
852+
if err != nil {
853+
return fmt.Errorf("saml slo: failed to decode SAMLResponse: %v", err)
854+
}
855+
856+
// HTTP-Redirect binding uses DEFLATE compression; HTTP-POST does not.
857+
// Try to inflate; if it fails, treat the data as uncompressed XML.
858+
rawResp, err := io.ReadAll(flate.NewReader(bytes.NewReader(compressed)))
859+
if err != nil {
860+
rawResp = compressed
861+
}
862+
863+
byteReader := bytes.NewReader(rawResp)
864+
if xrvErr := xrv.Validate(byteReader); xrvErr != nil {
865+
return fmt.Errorf("saml slo: %w", xrvErr)
866+
}
867+
868+
if p.validator != nil && !p.insecureSkipSLOSignatureValidation {
869+
if _, err := p.validateSignature(rawResp); err != nil {
870+
return fmt.Errorf("saml slo: %v", err)
871+
}
872+
}
873+
874+
var resp logoutResponse
875+
if err := xml.Unmarshal(rawResp, &resp); err != nil {
876+
return fmt.Errorf("saml slo: failed to unmarshal LogoutResponse: %v", err)
877+
}
878+
879+
if resp.Status != nil {
880+
if err := p.validateStatus(resp.Status); err != nil {
881+
return fmt.Errorf("saml slo: %v", err)
882+
}
883+
}
884+
885+
return nil
886+
}
887+
888+
// validateSignature validates the XML digital signature of the given raw XML.
889+
func (p *provider) validateSignature(rawXML []byte) ([]byte, error) {
890+
if p.validator == nil {
891+
return nil, fmt.Errorf("signature validation unavailable (no validator configured)")
892+
}
893+
894+
doc := etree.NewDocument()
895+
if err := doc.ReadFromBytes(rawXML); err != nil {
896+
return nil, fmt.Errorf("failed to parse XML: %v", err)
897+
}
898+
899+
root := doc.Root()
900+
if root == nil {
901+
return nil, fmt.Errorf("empty XML document")
902+
}
903+
904+
if _, err := p.validator.Validate(root); err != nil {
905+
return nil, fmt.Errorf("signature validation failed: %v", err)
906+
}
907+
908+
return rawXML, nil
909+
}

0 commit comments

Comments
 (0)