Skip to content

Commit 7b86548

Browse files
committed
add saml slo
Signed-off-by: Ivan Zvyagintsev <ivan.zvyagintsev@flant.com>
1 parent cc62a03 commit 7b86548

3 files changed

Lines changed: 227 additions & 13 deletions

File tree

connector/saml/saml.go

Lines changed: 112 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ import (
55
"bytes"
66
"compress/flate"
77
"context"
8+
"crypto"
89
"crypto/rand"
10+
"crypto/rsa"
911
"crypto/x509"
1012
"encoding/base64"
1113
"encoding/json"
@@ -246,6 +248,7 @@ func (c *Config) openConnector(logger *slog.Logger) (*provider, error) {
246248
return nil, errors.New("no certificates found in ca data")
247249
}
248250
p.validator = dsig.NewDefaultValidationContext(certStore{certs})
251+
p.certs = certs
249252
}
250253
return p, nil
251254
}
@@ -265,6 +268,9 @@ type provider struct {
265268

266269
// If nil, don't do signature validation.
267270
validator *dsig.ValidationContext
271+
// Stored separately for HTTP-Redirect binding signature verification,
272+
// which uses raw RSA/ECDSA over query string rather than XML digital signatures.
273+
certs []*x509.Certificate
268274

269275
// Attribute mappings
270276
usernameAttr string
@@ -865,9 +871,17 @@ func (p *provider) HandleLogoutCallback(_ context.Context, r *http.Request) erro
865871
return fmt.Errorf("saml slo: %w", xrvErr)
866872
}
867873

868-
if p.validator != nil && !p.insecureSkipSLOSignatureValidation {
869-
if _, err := p.validateSignature(rawResp); err != nil {
870-
return fmt.Errorf("saml slo: %v", err)
874+
if !p.insecureSkipSLOSignatureValidation {
875+
if r.Method == http.MethodGet && len(p.certs) > 0 {
876+
// HTTP-Redirect binding: signature is in query parameters.
877+
if err := p.validateRedirectSignature(r); err != nil {
878+
return fmt.Errorf("saml slo: %v", err)
879+
}
880+
} else if r.Method != http.MethodGet && p.validator != nil {
881+
// HTTP-POST binding: signature is embedded in XML.
882+
if _, err := p.validateSignature(rawResp); err != nil {
883+
return fmt.Errorf("saml slo: %v", err)
884+
}
871885
}
872886
}
873887

@@ -885,6 +899,101 @@ func (p *provider) HandleLogoutCallback(_ context.Context, r *http.Request) erro
885899
return nil
886900
}
887901

902+
// redirectSigAlgToHash maps XML Signature algorithm URIs used in SAML HTTP-Redirect
903+
// binding to Go crypto.Hash values. Only RSA algorithms are supported.
904+
// See: https://www.w3.org/TR/xmldsig-core1/#sec-AlgID
905+
var redirectSigAlgToHash = map[string]crypto.Hash{
906+
"http://www.w3.org/2000/09/xmldsig#rsa-sha1": crypto.SHA1,
907+
"http://www.w3.org/2001/04/xmldsig-more#rsa-sha256": crypto.SHA256,
908+
"http://www.w3.org/2001/04/xmldsig-more#rsa-sha384": crypto.SHA384,
909+
"http://www.w3.org/2001/04/xmldsig-more#rsa-sha512": crypto.SHA512,
910+
}
911+
912+
// rawQueryParam extracts the raw (still URL-encoded) value of a query parameter
913+
// from a raw query string. This is needed for SAML HTTP-Redirect binding signature
914+
// validation, which signs over the URL-encoded parameter values.
915+
func rawQueryParam(rawQuery, key string) (string, bool) {
916+
prefix := key + "="
917+
for rawQuery != "" {
918+
var pair string
919+
if i := strings.IndexByte(rawQuery, '&'); i >= 0 {
920+
pair, rawQuery = rawQuery[:i], rawQuery[i+1:]
921+
} else {
922+
pair, rawQuery = rawQuery, ""
923+
}
924+
if strings.HasPrefix(pair, prefix) {
925+
return pair[len(prefix):], true
926+
}
927+
}
928+
return "", false
929+
}
930+
931+
// validateRedirectSignature verifies the query-string signature used in SAML
932+
// HTTP-Redirect binding. Unlike HTTP-POST where the signature is embedded in
933+
// the XML (<ds:Signature>), HTTP-Redirect carries it as Signature and SigAlg
934+
// query parameters. The signed content is reconstructed per SAML 2.0 Bindings
935+
// Section 3.4.4.1: SAMLResponse=value&RelayState=value&SigAlg=value (using
936+
// the original URL-encoded values).
937+
func (p *provider) validateRedirectSignature(r *http.Request) error {
938+
rawQuery := r.URL.RawQuery
939+
940+
sigEncoded, ok := rawQueryParam(rawQuery, "Signature")
941+
if !ok || sigEncoded == "" {
942+
return fmt.Errorf("missing Signature query parameter")
943+
}
944+
945+
sigAlgEncoded, ok := rawQueryParam(rawQuery, "SigAlg")
946+
if !ok || sigAlgEncoded == "" {
947+
return fmt.Errorf("missing SigAlg query parameter")
948+
}
949+
950+
sigAlg, err := url.QueryUnescape(sigAlgEncoded)
951+
if err != nil {
952+
return fmt.Errorf("failed to decode SigAlg: %v", err)
953+
}
954+
955+
hashAlg, ok := redirectSigAlgToHash[sigAlg]
956+
if !ok {
957+
return fmt.Errorf("unsupported signature algorithm: %s", sigAlg)
958+
}
959+
960+
// Reconstruct the signed content in the spec-mandated order.
961+
var parts []string
962+
if v, ok := rawQueryParam(rawQuery, "SAMLResponse"); ok {
963+
parts = append(parts, "SAMLResponse="+v)
964+
}
965+
if v, ok := rawQueryParam(rawQuery, "RelayState"); ok {
966+
parts = append(parts, "RelayState="+v)
967+
}
968+
parts = append(parts, "SigAlg="+sigAlgEncoded)
969+
signedContent := strings.Join(parts, "&")
970+
971+
sigB64, err := url.QueryUnescape(sigEncoded)
972+
if err != nil {
973+
return fmt.Errorf("failed to URL-decode Signature: %v", err)
974+
}
975+
sig, err := base64.StdEncoding.DecodeString(sigB64)
976+
if err != nil {
977+
return fmt.Errorf("failed to base64-decode Signature: %v", err)
978+
}
979+
980+
h := hashAlg.New()
981+
h.Write([]byte(signedContent))
982+
hashed := h.Sum(nil)
983+
984+
for _, cert := range p.certs {
985+
rsaPub, ok := cert.PublicKey.(*rsa.PublicKey)
986+
if !ok {
987+
continue
988+
}
989+
if rsa.VerifyPKCS1v15(rsaPub, hashAlg, hashed, sig) == nil {
990+
return nil
991+
}
992+
}
993+
994+
return fmt.Errorf("redirect binding signature validation failed")
995+
}
996+
888997
// validateSignature validates the XML digital signature of the given raw XML.
889998
func (p *provider) validateSignature(rawXML []byte) ([]byte, error) {
890999
if p.validator == nil {

connector/saml/saml_test.go

Lines changed: 114 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ import (
44
"bytes"
55
"compress/flate"
66
"context"
7+
"crypto"
8+
"crypto/rand"
9+
"crypto/rsa"
710
"crypto/tls"
811
"crypto/x509"
912
"encoding/base64"
@@ -1268,7 +1271,14 @@ func signXMLDocument(t *testing.T, doc *etree.Document) []byte {
12681271
return out
12691272
}
12701273

1271-
func TestHandleLogoutCallbackSignatureValidation(t *testing.T) {
1274+
func postSAMLResponse(encoded string) *http.Request {
1275+
form := url.Values{"SAMLResponse": {encoded}}
1276+
req := httptest.NewRequest(http.MethodPost, "/logout/callback", strings.NewReader(form.Encode()))
1277+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
1278+
return req
1279+
}
1280+
1281+
func TestHandleLogoutCallbackPOSTSignatureValidation(t *testing.T) {
12721282
conn, err := (&Config{
12731283
CA: "testdata/ca.crt",
12741284
UsernameAttr: "Name",
@@ -1287,19 +1297,16 @@ func TestHandleLogoutCallbackSignatureValidation(t *testing.T) {
12871297
t.Fatal(err)
12881298
}
12891299
signedXML := signXMLDocument(t, doc)
1290-
12911300
encoded := base64.StdEncoding.EncodeToString(signedXML)
1292-
req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil)
1293-
if err := conn.HandleLogoutCallback(context.Background(), req); err != nil {
1301+
1302+
if err := conn.HandleLogoutCallback(context.Background(), postSAMLResponse(encoded)); err != nil {
12941303
t.Errorf("expected no error for validly signed response, got: %v", err)
12951304
}
12961305
})
12971306

12981307
t.Run("InvalidSignature", func(t *testing.T) {
1299-
// Use unsigned XML — should fail signature validation
13001308
encoded := base64.StdEncoding.EncodeToString([]byte(successLogoutResponseXML))
1301-
req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil)
1302-
if err := conn.HandleLogoutCallback(context.Background(), req); err == nil {
1309+
if err := conn.HandleLogoutCallback(context.Background(), postSAMLResponse(encoded)); err == nil {
13031310
t.Error("expected error for unsigned response when signature validation is enabled")
13041311
}
13051312
})
@@ -1322,15 +1329,112 @@ func TestHandleLogoutCallbackSignatureValidation(t *testing.T) {
13221329
t.Fatal(err)
13231330
}
13241331
signedXML := signXMLDocument(t, doc)
1325-
13261332
encoded := base64.StdEncoding.EncodeToString(signedXML)
1327-
req := httptest.NewRequest(http.MethodGet, "/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil)
1328-
if err := connBadCA.HandleLogoutCallback(context.Background(), req); err == nil {
1333+
1334+
if err := connBadCA.HandleLogoutCallback(context.Background(), postSAMLResponse(encoded)); err == nil {
13291335
t.Error("expected error when response is signed with different CA")
13301336
}
13311337
})
13321338
}
13331339

1340+
// signRedirectBinding builds a complete URL for a GET LogoutResponse with
1341+
// SAML HTTP-Redirect binding signature. The XML is deflated, base64-encoded,
1342+
// and a query-string RSA-SHA256 signature is appended.
1343+
func signRedirectBinding(t *testing.T, xmlPayload string, keyFile, certFile string) string {
1344+
t.Helper()
1345+
1346+
var buf bytes.Buffer
1347+
fw, err := flate.NewWriter(&buf, flate.DefaultCompression)
1348+
if err != nil {
1349+
t.Fatalf("deflate writer: %v", err)
1350+
}
1351+
if _, err := fw.Write([]byte(xmlPayload)); err != nil {
1352+
t.Fatalf("deflate write: %v", err)
1353+
}
1354+
if err := fw.Close(); err != nil {
1355+
t.Fatalf("deflate close: %v", err)
1356+
}
1357+
samlResp := base64.StdEncoding.EncodeToString(buf.Bytes())
1358+
1359+
sigAlg := "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256"
1360+
signedContent := "SAMLResponse=" + url.QueryEscape(samlResp) +
1361+
"&SigAlg=" + url.QueryEscape(sigAlg)
1362+
1363+
tlsCert, err := tls.LoadX509KeyPair(certFile, keyFile)
1364+
if err != nil {
1365+
t.Fatalf("load key pair: %v", err)
1366+
}
1367+
rsaKey, ok := tlsCert.PrivateKey.(*rsa.PrivateKey)
1368+
if !ok {
1369+
t.Fatal("test key is not RSA")
1370+
}
1371+
1372+
h := crypto.SHA256.New()
1373+
h.Write([]byte(signedContent))
1374+
sig, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, crypto.SHA256, h.Sum(nil))
1375+
if err != nil {
1376+
t.Fatalf("sign: %v", err)
1377+
}
1378+
sigB64 := base64.StdEncoding.EncodeToString(sig)
1379+
1380+
return "/logout/callback?" + signedContent +
1381+
"&Signature=" + url.QueryEscape(sigB64)
1382+
}
1383+
1384+
func TestHandleLogoutCallbackRedirectSignatureValidation(t *testing.T) {
1385+
conn, err := (&Config{
1386+
CA: "testdata/ca.crt",
1387+
UsernameAttr: "Name",
1388+
EmailAttr: "email",
1389+
RedirectURI: "http://127.0.0.1:5556/dex/callback",
1390+
SSOURL: "http://foo.bar/",
1391+
InsecureSkipSLOSignatureValidation: false,
1392+
}).openConnector(slog.New(slog.DiscardHandler))
1393+
if err != nil {
1394+
t.Fatal(err)
1395+
}
1396+
1397+
t.Run("ValidSignature", func(t *testing.T) {
1398+
u := signRedirectBinding(t, successLogoutResponseXML, "testdata/ca.key", "testdata/ca.crt")
1399+
req := httptest.NewRequest(http.MethodGet, u, nil)
1400+
if err := conn.HandleLogoutCallback(context.Background(), req); err != nil {
1401+
t.Errorf("expected no error, got: %v", err)
1402+
}
1403+
})
1404+
1405+
t.Run("MissingSignature", func(t *testing.T) {
1406+
var buf bytes.Buffer
1407+
fw, _ := flate.NewWriter(&buf, flate.DefaultCompression)
1408+
fw.Write([]byte(successLogoutResponseXML))
1409+
fw.Close()
1410+
encoded := base64.StdEncoding.EncodeToString(buf.Bytes())
1411+
1412+
req := httptest.NewRequest(http.MethodGet,
1413+
"/logout/callback?SAMLResponse="+url.QueryEscape(encoded), nil)
1414+
if err := conn.HandleLogoutCallback(context.Background(), req); err == nil {
1415+
t.Error("expected error for missing Signature parameter")
1416+
}
1417+
})
1418+
1419+
t.Run("WrongCA", func(t *testing.T) {
1420+
u := signRedirectBinding(t, successLogoutResponseXML, "testdata/bad-ca.key", "testdata/bad-ca.crt")
1421+
req := httptest.NewRequest(http.MethodGet, u, nil)
1422+
if err := conn.HandleLogoutCallback(context.Background(), req); err == nil {
1423+
t.Error("expected error when signed with wrong CA")
1424+
}
1425+
})
1426+
1427+
t.Run("TamperedPayload", func(t *testing.T) {
1428+
u := signRedirectBinding(t, successLogoutResponseXML, "testdata/ca.key", "testdata/ca.crt")
1429+
// Replace part of the SAMLResponse value to simulate tampering.
1430+
u = strings.Replace(u, "SAMLResponse=", "SAMLResponse=AAAA", 1)
1431+
req := httptest.NewRequest(http.MethodGet, u, nil)
1432+
if err := conn.HandleLogoutCallback(context.Background(), req); err == nil {
1433+
t.Error("expected error for tampered payload")
1434+
}
1435+
})
1436+
}
1437+
13341438
func TestSLOEndToEnd(t *testing.T) {
13351439
c := Config{
13361440
CA: "testdata/ca.crt",

connector/saml/types.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ type logoutResponse struct {
306306
ID string `xml:"ID,attr"`
307307
InResponseTo string `xml:"InResponseTo,attr,omitempty"`
308308
Version samlVersion `xml:"Version,attr"`
309+
IssueInstant xmlTime `xml:"IssueInstant,attr,omitempty"`
309310
Destination string `xml:"Destination,attr,omitempty"`
310311

311312
Issuer *issuer `xml:"Issuer,omitempty"`

0 commit comments

Comments
 (0)