Skip to content

Commit 53db80d

Browse files
committed
Make request/response/metadata classes overridable by subclassing
1 parent b192e2a commit 53db80d

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

src/onelogin/saml2/auth.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ class OneLogin_Saml2_Auth(object):
3434
SAML Response, a Logout Request or a Logout Response).
3535
"""
3636

37+
authn_request_class = OneLogin_Saml2_Authn_Request
38+
logout_request_class = OneLogin_Saml2_Logout_Request
39+
logout_response_class = OneLogin_Saml2_Logout_Response
40+
response_class = OneLogin_Saml2_Response
41+
3742
def __init__(self, request_data, old_settings=None, custom_base_path=None):
3843
"""
3944
Initializes the SP SAML instance.
@@ -102,7 +107,7 @@ def process_response(self, request_id=None):
102107

103108
if 'post_data' in self.__request_data and 'SAMLResponse' in self.__request_data['post_data']:
104109
# AuthnResponse -- HTTP_POST Binding
105-
response = OneLogin_Saml2_Response(self.__settings, self.__request_data['post_data']['SAMLResponse'])
110+
response = self.response_class(self.__settings, self.__request_data['post_data']['SAMLResponse'])
106111
self.__last_response = response.get_xml_document()
107112

108113
if response.is_valid(self.__request_data, request_id):
@@ -147,7 +152,7 @@ def process_slo(self, keep_local_session=False, request_id=None, delete_session_
147152

148153
get_data = 'get_data' in self.__request_data and self.__request_data['get_data']
149154
if get_data and 'SAMLResponse' in get_data:
150-
logout_response = OneLogin_Saml2_Logout_Response(self.__settings, get_data['SAMLResponse'])
155+
logout_response = self.logout_response_class(self.__settings, get_data['SAMLResponse'])
151156
self.__last_response = logout_response.get_xml()
152157
if not self.validate_response_signature(get_data):
153158
self.__errors.append('invalid_logout_response_signature')
@@ -163,7 +168,7 @@ def process_slo(self, keep_local_session=False, request_id=None, delete_session_
163168
OneLogin_Saml2_Utils.delete_local_session(delete_session_cb)
164169

165170
elif get_data and 'SAMLRequest' in get_data:
166-
logout_request = OneLogin_Saml2_Logout_Request(self.__settings, get_data['SAMLRequest'])
171+
logout_request = self.logout_request_class(self.__settings, get_data['SAMLRequest'])
167172
self.__last_request = logout_request.get_xml()
168173
if not self.validate_request_signature(get_data):
169174
self.__errors.append("invalid_logout_request_signature")
@@ -177,7 +182,7 @@ def process_slo(self, keep_local_session=False, request_id=None, delete_session_
177182

178183
in_response_to = logout_request.id
179184
self.__last_message_id = logout_request.id
180-
response_builder = OneLogin_Saml2_Logout_Response(self.__settings)
185+
response_builder = self.logout_response_class(self.__settings)
181186
response_builder.build(in_response_to)
182187
self.__last_response = response_builder.get_xml()
183188
logout_response = response_builder.get_response()
@@ -371,7 +376,7 @@ def login(self, return_to=None, force_authn=False, is_passive=False, set_nameid_
371376
:returns: Redirection URL
372377
:rtype: string
373378
"""
374-
authn_request = OneLogin_Saml2_Authn_Request(self.__settings, force_authn, is_passive, set_nameid_policy, name_id_value_req)
379+
authn_request = self.authn_request_class(self.__settings, force_authn, is_passive, set_nameid_policy, name_id_value_req)
375380
self.__last_request = authn_request.get_xml()
376381
self.__last_request_id = authn_request.get_id()
377382

@@ -425,7 +430,7 @@ def logout(self, return_to=None, name_id=None, session_index=None, nq=None, name
425430
if name_id_format is None and self.__nameid_format is not None:
426431
name_id_format = self.__nameid_format
427432

428-
logout_request = OneLogin_Saml2_Logout_Request(
433+
logout_request = self.logout_request_class(
429434
self.__settings,
430435
name_id=name_id,
431436
session_index=session_index,

src/onelogin/saml2/settings.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ class OneLogin_Saml2_Settings(object):
6666
6767
"""
6868

69+
metadata_class = OneLogin_Saml2_Metadata
70+
6971
def __init__(self, settings=None, custom_base_path=None, sp_validation_only=False):
7072
"""
7173
Initializes the settings:
@@ -616,7 +618,7 @@ def get_sp_metadata(self):
616618
:returns: SP metadata (xml)
617619
:rtype: string
618620
"""
619-
metadata = OneLogin_Saml2_Metadata.builder(
621+
metadata = self.metadata_class.builder(
620622
self.__sp, self.__security['authnRequestsSigned'],
621623
self.__security['wantAssertionsSigned'],
622624
self.__security['metadataValidUntil'],
@@ -627,10 +629,10 @@ def get_sp_metadata(self):
627629
add_encryption = self.__security['wantNameIdEncrypted'] or self.__security['wantAssertionsEncrypted']
628630

629631
cert_new = self.get_sp_cert_new()
630-
metadata = OneLogin_Saml2_Metadata.add_x509_key_descriptors(metadata, cert_new, add_encryption)
632+
metadata = self.metadata_class.add_x509_key_descriptors(metadata, cert_new, add_encryption)
631633

632634
cert = self.get_sp_cert()
633-
metadata = OneLogin_Saml2_Metadata.add_x509_key_descriptors(metadata, cert, add_encryption)
635+
metadata = self.metadata_class.add_x509_key_descriptors(metadata, cert, add_encryption)
634636

635637
# Sign metadata
636638
if 'signMetadata' in self.__security and self.__security['signMetadata'] is not False:
@@ -684,7 +686,7 @@ def get_sp_metadata(self):
684686
signature_algorithm = self.__security['signatureAlgorithm']
685687
digest_algorithm = self.__security['digestAlgorithm']
686688

687-
metadata = OneLogin_Saml2_Metadata.sign_metadata(metadata, key_metadata, cert_metadata, signature_algorithm, digest_algorithm)
689+
metadata = self.metadata_class.sign_metadata(metadata, key_metadata, cert_metadata, signature_algorithm, digest_algorithm)
688690

689691
return metadata
690692

0 commit comments

Comments
 (0)