Skip to content

Commit 6510202

Browse files
authored
Merge pull request #212 from akx/overrides
Overridability enhancements
2 parents 8d130bf + ca04d1d commit 6510202

File tree

6 files changed

+50
-35
lines changed

6 files changed

+50
-35
lines changed

src/onelogin/saml2/auth.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ class OneLogin_Saml2_Auth(object):
3434
SAML Response, a Logout Request or a Logout Response).
3535
"""
3636

37+
authn_request_class = OneLogin_Saml2_Authn_Request
38+
logout_request_class = OneLogin_Saml2_Logout_Request
39+
logout_response_class = OneLogin_Saml2_Logout_Response
40+
response_class = OneLogin_Saml2_Response
41+
3742
def __init__(self, request_data, old_settings=None, custom_base_path=None):
3843
"""
3944
Initializes the SP SAML instance.
@@ -103,7 +108,7 @@ def process_response(self, request_id=None):
103108

104109
if 'post_data' in self.__request_data and 'SAMLResponse' in self.__request_data['post_data']:
105110
# AuthnResponse -- HTTP_POST Binding
106-
response = OneLogin_Saml2_Response(self.__settings, self.__request_data['post_data']['SAMLResponse'])
111+
response = self.response_class(self.__settings, self.__request_data['post_data']['SAMLResponse'])
107112
self.__last_response = response.get_xml_document()
108113

109114
if response.is_valid(self.__request_data, request_id):
@@ -149,7 +154,7 @@ def process_slo(self, keep_local_session=False, request_id=None, delete_session_
149154

150155
get_data = 'get_data' in self.__request_data and self.__request_data['get_data']
151156
if get_data and 'SAMLResponse' in get_data:
152-
logout_response = OneLogin_Saml2_Logout_Response(self.__settings, get_data['SAMLResponse'])
157+
logout_response = self.logout_response_class(self.__settings, get_data['SAMLResponse'])
153158
self.__last_response = logout_response.get_xml()
154159
if not self.validate_response_signature(get_data):
155160
self.__errors.append('invalid_logout_response_signature')
@@ -165,7 +170,7 @@ def process_slo(self, keep_local_session=False, request_id=None, delete_session_
165170
OneLogin_Saml2_Utils.delete_local_session(delete_session_cb)
166171

167172
elif get_data and 'SAMLRequest' in get_data:
168-
logout_request = OneLogin_Saml2_Logout_Request(self.__settings, get_data['SAMLRequest'])
173+
logout_request = self.logout_request_class(self.__settings, get_data['SAMLRequest'])
169174
self.__last_request = logout_request.get_xml()
170175
if not self.validate_request_signature(get_data):
171176
self.__errors.append("invalid_logout_request_signature")
@@ -179,7 +184,7 @@ def process_slo(self, keep_local_session=False, request_id=None, delete_session_
179184

180185
in_response_to = logout_request.id
181186
self.__last_message_id = logout_request.id
182-
response_builder = OneLogin_Saml2_Logout_Response(self.__settings)
187+
response_builder = self.logout_response_class(self.__settings)
183188
response_builder.build(in_response_to)
184189
self.__last_response = response_builder.get_xml()
185190
logout_response = response_builder.get_response()
@@ -395,7 +400,7 @@ def login(self, return_to=None, force_authn=False, is_passive=False, set_nameid_
395400
:returns: Redirection URL
396401
:rtype: string
397402
"""
398-
authn_request = OneLogin_Saml2_Authn_Request(self.__settings, force_authn, is_passive, set_nameid_policy, name_id_value_req)
403+
authn_request = self.authn_request_class(self.__settings, force_authn, is_passive, set_nameid_policy, name_id_value_req)
399404
self.__last_request = authn_request.get_xml()
400405
self.__last_request_id = authn_request.get_id()
401406

@@ -449,7 +454,7 @@ def logout(self, return_to=None, name_id=None, session_index=None, nq=None, name
449454
if name_id_format is None and self.__nameid_format is not None:
450455
name_id_format = self.__nameid_format
451456

452-
logout_request = OneLogin_Saml2_Logout_Request(
457+
logout_request = self.logout_request_class(
453458
self.__settings,
454459
name_id=name_id,
455460
session_index=session_index,

src/onelogin/saml2/authn_request.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ def __init__(self, settings, force_authn=False, is_passive=False, set_nameid_pol
4747
idp_data = self.__settings.get_idp_data()
4848
security = self.__settings.get_security_data()
4949

50-
uid = OneLogin_Saml2_Utils.generate_unique_id()
51-
self.__id = uid
50+
self.__id = self._generate_request_id()
5251
issue_instant = OneLogin_Saml2_Utils.parse_time_to_SAML(OneLogin_Saml2_Utils.now())
5352

5453
destination = idp_data['singleSignOnService']['url']
@@ -113,7 +112,7 @@ def __init__(self, settings, force_authn=False, is_passive=False, set_nameid_pol
113112

114113
request = OneLogin_Saml2_Templates.AUTHN_REQUEST % \
115114
{
116-
'id': uid,
115+
'id': self.__id,
117116
'provider_name': provider_name_str,
118117
'force_authn_str': force_authn_str,
119118
'is_passive_str': is_passive_str,
@@ -129,6 +128,14 @@ def __init__(self, settings, force_authn=False, is_passive=False, set_nameid_pol
129128

130129
self.__authn_request = request
131130

131+
def _generate_request_id(self):
132+
"""
133+
Generate an unique request ID.
134+
135+
You can override this in a subclass.
136+
"""
137+
return OneLogin_Saml2_Utils.generate_unique_id()
138+
132139
def get_request(self, deflate=True):
133140
"""
134141
Returns unsigned AuthnRequest.

src/onelogin/saml2/idp_metadata_parser.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ class OneLogin_Saml2_IdPMetadataParser(object):
2525
A class that contain methods related to obtaining and parsing metadata from IdP
2626
"""
2727

28-
@staticmethod
29-
def get_metadata(url, validate_cert=True):
28+
@classmethod
29+
def get_metadata(cls, url, validate_cert=True):
3030
"""
3131
Gets the metadata XML from the provided URL
3232
:param url: Url where the XML of the Identity Provider Metadata is published.
@@ -63,8 +63,8 @@ def get_metadata(url, validate_cert=True):
6363

6464
return xml
6565

66-
@staticmethod
67-
def parse_remote(url, validate_cert=True, entity_id=None, **kwargs):
66+
@classmethod
67+
def parse_remote(cls, url, validate_cert=True, entity_id=None, **kwargs):
6868
"""
6969
Gets the metadata XML from the provided URL and parse it, returning a dict with extracted data
7070
:param url: Url where the XML of the Identity Provider Metadata is published.
@@ -80,11 +80,12 @@ def parse_remote(url, validate_cert=True, entity_id=None, **kwargs):
8080
:returns: settings dict with extracted data
8181
:rtype: dict
8282
"""
83-
idp_metadata = OneLogin_Saml2_IdPMetadataParser.get_metadata(url, validate_cert)
84-
return OneLogin_Saml2_IdPMetadataParser.parse(idp_metadata, entity_id=entity_id, **kwargs)
83+
idp_metadata = cls.get_metadata(url, validate_cert)
84+
return cls.parse(idp_metadata, entity_id=entity_id, **kwargs)
8585

86-
@staticmethod
86+
@classmethod
8787
def parse(
88+
cls,
8889
idp_metadata,
8990
required_sso_binding=OneLogin_Saml2_Constants.BINDING_HTTP_REDIRECT,
9091
required_slo_binding=OneLogin_Saml2_Constants.BINDING_HTTP_REDIRECT,

src/onelogin/saml2/logout_request.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,8 @@ def get_nameid_data(request, key=None):
203203

204204
return name_id_data
205205

206-
@staticmethod
207-
def get_nameid(request, key=None):
206+
@classmethod
207+
def get_nameid(cls, request, key=None):
208208
"""
209209
Gets the NameID of the Logout Request Message
210210
:param request: Logout Request Message
@@ -214,11 +214,11 @@ def get_nameid(request, key=None):
214214
:return: Name ID Value
215215
:rtype: string
216216
"""
217-
name_id = OneLogin_Saml2_Logout_Request.get_nameid_data(request, key)
217+
name_id = cls.get_nameid_data(request, key)
218218
return name_id['Value']
219219

220-
@staticmethod
221-
def get_nameid_format(request, key=None):
220+
@classmethod
221+
def get_nameid_format(cls, request, key=None):
222222
"""
223223
Gets the NameID Format of the Logout Request Message
224224
:param request: Logout Request Message
@@ -229,7 +229,7 @@ def get_nameid_format(request, key=None):
229229
:rtype: string
230230
"""
231231
name_id_format = None
232-
name_id_data = OneLogin_Saml2_Logout_Request.get_nameid_data(request, key)
232+
name_id_data = cls.get_nameid_data(request, key)
233233
if name_id_data and 'Format' in name_id_data.keys():
234234
name_id_format = name_id_data['Format']
235235
return name_id_format
@@ -325,7 +325,7 @@ def is_valid(self, request_data, raise_exceptions=False):
325325
)
326326

327327
# Check issuer
328-
issuer = OneLogin_Saml2_Logout_Request.get_issuer(root)
328+
issuer = self.get_issuer(root)
329329
if issuer is not None and issuer != idp_entity_id:
330330
raise OneLogin_Saml2_ValidationError(
331331
'Invalid issuer in the Logout Request (expected %(idpEntityId)s, got %(issuer)s)' %

src/onelogin/saml2/metadata.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ class OneLogin_Saml2_Metadata(object):
3434
TIME_VALID = 172800 # 2 days
3535
TIME_CACHED = 604800 # 1 week
3636

37-
@staticmethod
38-
def builder(sp, authnsign=False, wsign=False, valid_until=None, cache_duration=None, contacts=None, organization=None):
37+
@classmethod
38+
def builder(cls, sp, authnsign=False, wsign=False, valid_until=None, cache_duration=None, contacts=None, organization=None):
3939
"""
4040
Builds the metadata of the SP
4141
@@ -61,7 +61,7 @@ def builder(sp, authnsign=False, wsign=False, valid_until=None, cache_duration=N
6161
:type organization: dict
6262
"""
6363
if valid_until is None:
64-
valid_until = int(time()) + OneLogin_Saml2_Metadata.TIME_VALID
64+
valid_until = int(time()) + cls.TIME_VALID
6565
if not isinstance(valid_until, basestring):
6666
if isinstance(valid_until, datetime):
6767
valid_until_time = valid_until.timetuple()
@@ -72,7 +72,7 @@ def builder(sp, authnsign=False, wsign=False, valid_until=None, cache_duration=N
7272
valid_until_str = valid_until
7373

7474
if cache_duration is None:
75-
cache_duration = OneLogin_Saml2_Metadata.TIME_CACHED
75+
cache_duration = cls.TIME_CACHED
7676
if not isinstance(cache_duration, compat.str_type):
7777
cache_duration_str = 'PT%sS' % cache_duration # Period of Time x Seconds
7878
else:
@@ -228,8 +228,8 @@ def __add_x509_key_descriptors(root, cert, signing):
228228
x509_certificate.text = OneLogin_Saml2_Utils.format_cert(cert, False)
229229
key_descriptor.set('use', ('encryption', 'signing')[signing])
230230

231-
@staticmethod
232-
def add_x509_key_descriptors(metadata, cert=None, add_encryption=True):
231+
@classmethod
232+
def add_x509_key_descriptors(cls, metadata, cert=None, add_encryption=True):
233233
"""
234234
Adds the x509 descriptors (sign/encryption) to the metadata
235235
The same cert will be used for sign/encrypt
@@ -260,6 +260,6 @@ def add_x509_key_descriptors(metadata, cert=None, add_encryption=True):
260260
raise Exception('Malformed metadata.')
261261

262262
if add_encryption:
263-
OneLogin_Saml2_Metadata.__add_x509_key_descriptors(sp_sso_descriptor, cert, False)
264-
OneLogin_Saml2_Metadata.__add_x509_key_descriptors(sp_sso_descriptor, cert, True)
263+
cls.__add_x509_key_descriptors(sp_sso_descriptor, cert, False)
264+
cls.__add_x509_key_descriptors(sp_sso_descriptor, cert, True)
265265
return OneLogin_Saml2_XML.to_string(root)

src/onelogin/saml2/settings.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ class OneLogin_Saml2_Settings(object):
8181
8282
"""
8383

84+
metadata_class = OneLogin_Saml2_Metadata
85+
8486
def __init__(self, settings=None, custom_base_path=None, sp_validation_only=False):
8587
"""
8688
Initializes the settings:
@@ -661,7 +663,7 @@ def get_sp_metadata(self):
661663
:returns: SP metadata (xml)
662664
:rtype: string
663665
"""
664-
metadata = OneLogin_Saml2_Metadata.builder(
666+
metadata = self.metadata_class.builder(
665667
self.__sp, self.__security['authnRequestsSigned'],
666668
self.__security['wantAssertionsSigned'],
667669
self.__security['metadataValidUntil'],
@@ -672,10 +674,10 @@ def get_sp_metadata(self):
672674
add_encryption = self.__security['wantNameIdEncrypted'] or self.__security['wantAssertionsEncrypted']
673675

674676
cert_new = self.get_sp_cert_new()
675-
metadata = OneLogin_Saml2_Metadata.add_x509_key_descriptors(metadata, cert_new, add_encryption)
677+
metadata = self.metadata_class.add_x509_key_descriptors(metadata, cert_new, add_encryption)
676678

677679
cert = self.get_sp_cert()
678-
metadata = OneLogin_Saml2_Metadata.add_x509_key_descriptors(metadata, cert, add_encryption)
680+
metadata = self.metadata_class.add_x509_key_descriptors(metadata, cert, add_encryption)
679681

680682
# Sign metadata
681683
if 'signMetadata' in self.__security and self.__security['signMetadata'] is not False:
@@ -729,7 +731,7 @@ def get_sp_metadata(self):
729731
signature_algorithm = self.__security['signatureAlgorithm']
730732
digest_algorithm = self.__security['digestAlgorithm']
731733

732-
metadata = OneLogin_Saml2_Metadata.sign_metadata(metadata, key_metadata, cert_metadata, signature_algorithm, digest_algorithm)
734+
metadata = self.metadata_class.sign_metadata(metadata, key_metadata, cert_metadata, signature_algorithm, digest_algorithm)
733735

734736
return metadata
735737

0 commit comments

Comments
 (0)