diff --git a/src/main/java/com/coveo/saml/SamlClient.java b/src/main/java/com/coveo/saml/SamlClient.java index 5b03fa4..4c34fa8 100644 --- a/src/main/java/com/coveo/saml/SamlClient.java +++ b/src/main/java/com/coveo/saml/SamlClient.java @@ -630,6 +630,22 @@ public void decodeAndValidateSamlLogoutRequest( ValidatorUtils.validate(logoutRequest, responseIssuer, credentials, nameID); } + /** + * Decodes and validates an SAML logout request send by an identity provider. + * + * @param encodedRequest the encoded request send by the identity provider. + * @param method The HTTP method used by the request + * @return An {@link LogoutRequest} object containing information decoded from the SAML Logout + * Request. + * @throws SamlException if the signature is invalid, or if any other error occurs. + */ + public LogoutRequest decodeSamlLogoutRequest(String encodedRequest, String method) + throws SamlException { + LogoutRequest logoutRequest = (LogoutRequest) parseResponse(encodedRequest, method); + ValidatorUtils.validate(logoutRequest, responseIssuer, credentials); + return logoutRequest; + } + /** * Set service provider keys. * @@ -648,7 +664,8 @@ public void setSPKeys(String publicKey, String privateKey) throws SamlException * @param privateKey the private key * @throws SamlException if publicKey and privateKey don't form a valid credential */ - private BasicX509Credential generateBasicX509Credential(String publicKey, String privateKey) throws SamlException { + private BasicX509Credential generateBasicX509Credential(String publicKey, String privateKey) + throws SamlException { if (publicKey == null || privateKey == null) { throw new SamlException("No credentials provided"); } @@ -689,14 +706,15 @@ public void addAdditionalSPKey(String publicKey, String privateKey) throws SamlE * @param privateKey the private key * @throws SamlException if publicKey and privateKey don't form a valid credential */ - public void addAdditionalSPKey(X509Certificate certificate, PrivateKey privateKey) throws SamlException { + public void addAdditionalSPKey(X509Certificate certificate, PrivateKey privateKey) + throws SamlException { additionalSpCredentials.add(new BasicX509Credential(certificate, privateKey)); } /** * Remove all additional service provider decryption certificate/key pairs. */ - public void clearAdditionalSPKeys() throws SamlException { + public void clearAdditionalSPKeys() { additionalSpCredentials = new ArrayList<>(); } @@ -873,6 +891,21 @@ public void processLogoutRequestPostFromIdentityProvider( String encodedResponse = request.getParameter(HTTP_REQ_SAML_PARAM); decodeAndValidateSamlLogoutRequest(encodedResponse, nameID, request.getMethod()); } + + /** + * Processes a POST containing the SAML logout request. + * + * @param request the {@link HttpServletRequest}. + * @return An {@link LogoutRequest} object containing information decoded from the SAML Logout + * Request. + * @throws SamlException thrown is an unexpected error occurs. + */ + public LogoutRequest processLogoutRequestPostFromIdentityProvider(HttpServletRequest request) + throws SamlException { + String encodedResponse = request.getParameter(HTTP_REQ_SAML_PARAM); + return decodeSamlLogoutRequest(encodedResponse, request.getMethod()); + } + /** * Processes a POST containing the SAML response. * @@ -942,11 +975,11 @@ private void decodeEncryptedAssertion(Response response) throws DecryptionExcept // Create a decrypter. List resolverChain = new ArrayList<>(); - if(spCredential != null) { + if (spCredential != null) { resolverChain.add(new StaticKeyInfoCredentialResolver(spCredential)); } - if(!additionalSpCredentials.isEmpty()) { + if (!additionalSpCredentials.isEmpty()) { resolverChain.add(new CollectionKeyInfoCredentialResolver(additionalSpCredentials)); } diff --git a/src/main/java/com/coveo/saml/ValidatorUtils.java b/src/main/java/com/coveo/saml/ValidatorUtils.java index 311adf0..af974a7 100644 --- a/src/main/java/com/coveo/saml/ValidatorUtils.java +++ b/src/main/java/com/coveo/saml/ValidatorUtils.java @@ -1,7 +1,6 @@ package com.coveo.saml; import java.util.List; - import org.joda.time.DateTime; import org.opensaml.saml.common.SignableSAMLObject; import org.opensaml.saml.saml2.core.Assertion; @@ -25,7 +24,7 @@ class ValidatorUtils { /** * Validate response. * - * @param response the response + * @param response the response * @param responseIssuer the response issuer * @throws SamlException the saml exception */ @@ -58,7 +57,7 @@ private static void validateStatus(StatusResponseType response) throws SamlExcep /** * Validate issuer. * - * @param response the response + * @param response the response * @param responseIssuer the response issuer * @throws SamlException the saml exception */ @@ -68,10 +67,11 @@ private static void validateIssuer(StatusResponseType response, String responseI throw new SamlException("The response issuer didn't match the expected value"); } } + /** * Validate issuer. * - * @param request the response + * @param request the response * @param requestIssuer the request issuer * @throws SamlException the saml exception */ @@ -81,13 +81,14 @@ private static void validateIssuer(RequestAbstractType request, String requestIs throw new SamlException("The request issuer didn't match the expected value"); } } + /** * Validate assertion. * - * @param response the response + * @param response the response * @param responseIssuer the response issuer - * @param now the current date time (for unit test only) - * @param notBeforeSkew the notBeforeSkew + * @param now the current date time (for unit test only) + * @param notBeforeSkew the notBeforeSkew * @throws SamlException the saml exception */ private static void validateAssertion( @@ -114,8 +115,8 @@ private static void validateAssertion( * Enforce conditions. * * @param conditions the conditions - * @param _now the current date time (for unit test only) - * @param notBeforeSkew the notBeforeSkew + * @param _now the current date time (for unit test only) + * @param notBeforeSkew the notBeforeSkew * @throws SamlException the saml exception */ private static void enforceConditions(Conditions conditions, DateTime _now, long notBeforeSkew) @@ -137,7 +138,7 @@ private static void enforceConditions(Conditions conditions, DateTime _now, long /** * Validate signature. * - * @param response the response + * @param response the response * @param credentials the credentials * @throws SamlException the saml exception */ @@ -151,7 +152,7 @@ private static void validateSignature(SignableSAMLObject response, List credential /** * Validate. * - * @param response the response + * @param response the response * @param responseIssuer the response issuer - * @param credentials the credentials - * @param now the current date time (for unit test only) - * @param notBeforeSkew the notBeforeSkew + * @param credentials the credentials + * @param now the current date time (for unit test only) + * @param notBeforeSkew the notBeforeSkew * @throws SamlException the saml exception */ public static void validate( @@ -216,12 +217,13 @@ public static void validate( validateSignature(response, credentials); validateAssertionSignature(response, credentials); } + /** * Validate. * - * @param logoutRequest the response + * @param logoutRequest the response * @param responseIssuer the response issuer - * @param credentials the credentials + * @param credentials the credentials * @throws SamlException the saml exception */ public static void validate( @@ -233,12 +235,28 @@ public static void validate( validateLogoutRequest(logoutRequest, responseIssuer, nameID); validateSignature(logoutRequest, credentials); } + + /** + * Validate. + * + * @param logoutRequest the response + * @param responseIssuer the response issuer + * @param credentials the credentials + * @throws SamlException the saml exception + */ + public static void validate( + LogoutRequest logoutRequest, String responseIssuer, List credentials) + throws SamlException { + validateLogoutRequest(logoutRequest, responseIssuer); + validateSignature(logoutRequest, credentials); + } + /** * Validate. * - * @param response the response + * @param response the response * @param responseIssuer the response issuer - * @param credentials the credentials + * @param credentials the credentials * @throws SamlException the saml exception */ public static void validate( @@ -251,7 +269,7 @@ public static void validate( /** * Validate response. * - * @param response the response + * @param response the response * @param responseIssuer the response issuer * @throws SamlException the saml exception */ @@ -265,10 +283,11 @@ private static void validateResponse(Response response, String responseIssuer) validateIssuer(response, responseIssuer); validateStatus(response); } + /** * Validate response. * - * @param request the request + * @param request the request * @param requestIssuer the response issuer * @throws SamlException the saml exception */ @@ -283,11 +302,28 @@ private static void validateLogoutRequest( validateNameId(request, nameID); } + /** + * Validate response. + * + * @param request the request + * @param requestIssuer the response issuer + * @throws SamlException the saml exception + */ + private static void validateLogoutRequest(LogoutRequest request, String requestIssuer) + throws SamlException { + try { + new LogoutRequestSchemaValidator().validate(request); + } catch (SamlException ex) { + throw new SamlException("The request schema validation failed", ex); + } + validateIssuer(request, requestIssuer); + } + /** * Validate the logout request name id. * * @param request the request - * @param nameID the name id + * @param nameID the name id * @throws SamlException the saml exception */ private static void validateNameId(LogoutRequest request, String nameID) throws SamlException {