Skip to content

Commit bfb7be4

Browse files
committed
Refactor settings
1 parent e815da3 commit bfb7be4

1 file changed

Lines changed: 72 additions & 122 deletions

File tree

src/onelogin/saml2/settings.py

Lines changed: 72 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -206,22 +206,12 @@ def __load_settings_from_dict(self, settings):
206206
if len(errors) == 0:
207207
self.__errors = []
208208
self.__sp = settings['sp']
209-
210-
if 'idp' in settings:
211-
self.__idp = settings['idp']
212-
213-
if 'strict' in settings:
214-
self.__strict = settings['strict']
215-
if 'debug' in settings:
216-
self.__debug = settings['debug']
217-
if 'security' in settings:
218-
self.__security = settings['security']
219-
else:
220-
self.__security = {}
221-
if 'contactPerson' in settings:
222-
self.__contacts = settings['contactPerson']
223-
if 'organization' in settings:
224-
self.__organization = settings['organization']
209+
self.__idp = settings.get('idp', {})
210+
self.__strict = settings.get('strict', False)
211+
self.__debug = settings.get('debug', False)
212+
self.__security = settings.get('security', {})
213+
self.__contacts = settings.get('contactPerson', {})
214+
self.__organization = settings.get('organization', {})
225215

226216
self.__add_default_values()
227217
return True
@@ -261,79 +251,54 @@ def __add_default_values(self):
261251
"""
262252
Add default values if the settings info is not complete
263253
"""
264-
if 'assertionConsumerService' not in self.__sp:
265-
self.__sp['assertionConsumerService'] = {}
266-
if 'binding' not in self.__sp['assertionConsumerService']:
267-
self.__sp['assertionConsumerService']['binding'] = OneLogin_Saml2_Constants.BINDING_HTTP_POST
254+
self.__sp.setdefault('assertionConsumerService', {})
255+
self.__sp['assertionConsumerService'].setdefault('binding', OneLogin_Saml2_Constants.BINDING_HTTP_POST)
268256

269-
if 'singleLogoutService' not in self.__sp:
270-
self.__sp['singleLogoutService'] = {}
271-
if 'binding' not in self.__sp['singleLogoutService']:
272-
self.__sp['singleLogoutService']['binding'] = OneLogin_Saml2_Constants.BINDING_HTTP_REDIRECT
257+
self.__sp.setdefault('attributeConsumingService', {})
273258

274-
if 'singleLogoutService' not in self.__idp:
275-
self.__idp['singleLogoutService'] = {}
259+
self.__sp.setdefault('singleLogoutService', {})
276260

277-
# Related to nameID
278-
if 'NameIDFormat' not in self.__sp:
279-
self.__sp['NameIDFormat'] = OneLogin_Saml2_Constants.NAMEID_UNSPECIFIED
280-
if 'nameIdEncrypted' not in self.__security:
281-
self.__security['nameIdEncrypted'] = False
261+
self.__sp['singleLogoutService'].setdefault('binding', OneLogin_Saml2_Constants.BINDING_HTTP_REDIRECT)
282262

283-
# Sign provided
284-
if 'authnRequestsSigned' not in self.__security:
285-
self.__security['authnRequestsSigned'] = False
286-
if 'logoutRequestSigned' not in self.__security:
287-
self.__security['logoutRequestSigned'] = False
288-
if 'logoutResponseSigned' not in self.__security:
289-
self.__security['logoutResponseSigned'] = False
290-
if 'signMetadata' not in self.__security:
291-
self.__security['signMetadata'] = False
263+
# Related to nameID
264+
self.__sp.setdefault('NameIDFormat', OneLogin_Saml2_Constants.NAMEID_UNSPECIFIED)
265+
self.__security.setdefault('nameIdEncrypted', False)
292266

293267
# Metadata format
294-
if 'metadataValidUntil' not in self.__security.keys():
295-
self.__security['metadataValidUntil'] = None # None means use default
296-
if 'metadataCacheDuration' not in self.__security.keys():
297-
self.__security['metadataCacheDuration'] = None # None means use default
268+
self.__security.setdefault('metadataValidUntil', None) # None means use default
269+
self.__security.setdefault('metadataCacheDuration', None) # None means use default
270+
271+
# Sign provided
272+
self.__security.setdefault('authnRequestsSigned', False)
273+
self.__security.setdefault('logoutRequestSigned', False)
274+
self.__security.setdefault('logoutResponseSigned', False)
275+
self.__security.setdefault('signMetadata', False)
298276

299277
# Sign expected
300-
if 'wantMessagesSigned' not in self.__security:
301-
self.__security['wantMessagesSigned'] = False
302-
if 'wantAssertionsSigned' not in self.__security:
303-
self.__security['wantAssertionsSigned'] = False
278+
self.__security.setdefault('wantMessagesSigned', False)
279+
self.__security.setdefault('wantAssertionsSigned', False)
304280

305281
# NameID element expected
306-
if 'wantNameId' not in self.__security.keys():
307-
self.__security['wantNameId'] = True
282+
self.__security.setdefault('wantNameId', True)
308283

309284
# Encrypt expected
310-
if 'wantAssertionsEncrypted' not in self.__security:
311-
self.__security['wantAssertionsEncrypted'] = False
312-
if 'wantNameIdEncrypted' not in self.__security:
313-
self.__security['wantNameIdEncrypted'] = False
285+
self.__security.setdefault('wantAssertionsEncrypted', False)
286+
self.__security.setdefault('wantNameIdEncrypted', False)
314287

315288
# Signature Algorithm
316-
if 'signatureAlgorithm' not in self.__security.keys():
317-
self.__security['signatureAlgorithm'] = OneLogin_Saml2_Constants.RSA_SHA1
289+
self.__security.setdefault('signatureAlgorithm', OneLogin_Saml2_Constants.RSA_SHA1)
318290

319291
# AttributeStatement required by default
320-
if 'wantAttributeStatement' not in self.__security.keys():
321-
self.__security['wantAttributeStatement'] = True
292+
self.__security.setdefault('wantAttributeStatement', True)
322293

323-
if 'x509cert' not in self.__idp:
324-
self.__idp['x509cert'] = ''
325-
if 'certFingerprint' not in self.__idp:
326-
self.__idp['certFingerprint'] = ''
327-
if 'certFingerprintAlgorithm' not in self.__idp:
328-
self.__idp['certFingerprintAlgorithm'] = 'sha1'
294+
self.__idp.setdefault('x509cert', '')
295+
self.__idp.setdefault('certFingerprint', '')
296+
self.__idp.setdefault('certFingerprintAlgorithm', 'sha1')
329297

330-
if 'x509cert' not in self.__sp:
331-
self.__sp['x509cert'] = ''
332-
if 'privateKey' not in self.__sp:
333-
self.__sp['privateKey'] = ''
298+
self.__sp.setdefault('x509cert', '')
299+
self.__sp.setdefault('privateKey', '')
334300

335-
if 'requestedAuthnContext' not in self.__security:
336-
self.__security['requestedAuthnContext'] = True
301+
self.__security.setdefault('requestedAuthnContext', True)
337302

338303
def check_settings(self, settings):
339304
"""
@@ -372,37 +337,31 @@ def check_idp_settings(self, settings):
372337
if not isinstance(settings, dict) or len(settings) == 0:
373338
errors.append('invalid_syntax')
374339
else:
375-
if 'idp' not in settings or len(settings['idp']) == 0:
340+
if not settings.get('idp'):
376341
errors.append('idp_not_found')
377342
else:
378343
idp = settings['idp']
379-
if 'entityId' not in idp or len(idp['entityId']) == 0:
344+
if not idp.get('entityId'):
380345
errors.append('idp_entityId_not_found')
381346

382-
if 'singleSignOnService' not in idp or \
383-
'url' not in idp['singleSignOnService'] or \
384-
len(idp['singleSignOnService']['url']) == 0:
347+
if not idp.get('singleSignOnService', {}).get('url'):
385348
errors.append('idp_sso_not_found')
386349
elif not validate_url(idp['singleSignOnService']['url']):
387350
errors.append('idp_sso_url_invalid')
388351

389-
if 'singleLogoutService' in idp and \
390-
'url' in idp['singleLogoutService'] and \
391-
len(idp['singleLogoutService']['url']) > 0 and \
392-
not validate_url(idp['singleLogoutService']['url']):
352+
slo_url = idp.get('singleLogoutService', {}).get('url')
353+
if slo_url and not validate_url(slo_url):
393354
errors.append('idp_slo_url_invalid')
394355

395356
if 'security' in settings:
396357
security = settings['security']
397358

398-
exists_x509 = ('x509cert' in idp and
399-
len(idp['x509cert']) > 0)
400-
exists_fingerprint = ('certFingerprint' in idp and
401-
len(idp['certFingerprint']) > 0)
359+
exists_x509 = bool(idp.get('x509cert'))
360+
exists_fingerprint = bool(idp.get('certFingerprint'))
402361

403-
want_assert_sign = 'wantAssertionsSigned' in security.keys() and security['wantAssertionsSigned']
404-
want_mes_signed = 'wantMessagesSigned' in security.keys() and security['wantMessagesSigned']
405-
nameid_enc = 'nameIdEncrypted' in security.keys() and security['nameIdEncrypted']
362+
want_assert_sign = bool(security.get('wantAssertionsSigned'))
363+
want_mes_signed = bool(security.get('wantMessagesSigned'))
364+
nameid_enc = bool(security.get('nameIdEncrypted'))
406365

407366
if (want_assert_sign or want_mes_signed) and \
408367
not(exists_x509 or exists_fingerprint):
@@ -422,32 +381,28 @@ def check_sp_settings(self, settings):
422381
assert isinstance(settings, dict)
423382

424383
errors = []
425-
if not isinstance(settings, dict) or len(settings) == 0:
384+
if not isinstance(settings, dict) or not settings:
426385
errors.append('invalid_syntax')
427386
else:
428-
if 'sp' not in settings or len(settings['sp']) == 0:
387+
if not settings.get('sp'):
429388
errors.append('sp_not_found')
430389
else:
431390
# check_sp_certs uses self.__sp so I add it
432391
old_sp = self.__sp
433392
self.__sp = settings['sp']
434393

435394
sp = settings['sp']
436-
security = {}
437-
if 'security' in settings:
438-
security = settings['security']
395+
security = settings.get('security', {})
439396

440-
if 'entityId' not in sp or len(sp['entityId']) == 0:
397+
if not sp.get('entityId'):
441398
errors.append('sp_entityId_not_found')
442399

443-
if 'assertionConsumerService' not in sp or \
444-
'url' not in sp['assertionConsumerService'] or \
445-
len(sp['assertionConsumerService']['url']) == 0:
400+
if not sp.get('assertionConsumerService', {}).get('url'):
446401
errors.append('sp_acs_not_found')
447402
elif not validate_url(sp['assertionConsumerService']['url']):
448403
errors.append('sp_acs_url_invalid')
449404

450-
if 'attributeConsumingService' in sp and len(sp['attributeConsumingService']):
405+
if sp.get('attributeConsumingService'):
451406
attributeConsumingService = sp['attributeConsumingService']
452407
if 'serviceName' not in attributeConsumingService:
453408
errors.append('sp_attributeConsumingService_serviceName_not_found')
@@ -472,22 +427,20 @@ def check_sp_settings(self, settings):
472427
if "serviceDescription" in attributeConsumingService and not isinstance(attributeConsumingService['serviceDescription'], basestring):
473428
errors.append('sp_attributeConsumingService_serviceDescription_type_invalid')
474429

475-
if 'singleLogoutService' in sp and \
476-
'url' in sp['singleLogoutService'] and \
477-
len(sp['singleLogoutService']['url']) > 0 and \
478-
not validate_url(sp['singleLogoutService']['url']):
430+
slo_url = sp.get('singleLogoutService', {}).get('url')
431+
if slo_url and not validate_url(slo_url):
479432
errors.append('sp_sls_url_invalid')
480433

481434
if 'signMetadata' in security and isinstance(security['signMetadata'], dict):
482435
if 'keyFileName' not in security['signMetadata'] or \
483436
'certFileName' not in security['signMetadata']:
484437
errors.append('sp_signMetadata_invalid')
485438

486-
authn_sign = 'authnRequestsSigned' in security and security['authnRequestsSigned']
487-
logout_req_sign = 'logoutRequestSigned' in security and security['logoutRequestSigned']
488-
logout_res_sign = 'logoutResponseSigned' in security and security['logoutResponseSigned']
489-
want_assert_enc = 'wantAssertionsEncrypted' in security and security['wantAssertionsEncrypted']
490-
want_nameid_enc = 'wantNameIdEncrypted' in security and security['wantNameIdEncrypted']
439+
authn_sign = bool(security.get('authnRequestsSigned'))
440+
logout_req_sign = bool(security.get('logoutRequestSigned'))
441+
logout_res_sign = bool(security.get('logoutResponseSigned'))
442+
want_assert_enc = bool(security.get('wantAssertionsEncrypted'))
443+
want_nameid_enc = bool(security.get('wantNameIdEncrypted'))
491444

492445
if not self.check_sp_certs():
493446
if authn_sign or logout_req_sign or logout_res_sign or \
@@ -526,7 +479,6 @@ def check_sp_settings(self, settings):
526479
def check_sp_certs(self):
527480
"""
528481
Checks if the x509 certs of the SP exists and are valid.
529-
530482
:returns: If the x509 certs of the SP exists and are valid
531483
:rtype: boolean
532484
"""
@@ -537,42 +489,40 @@ def check_sp_certs(self):
537489
def get_sp_key(self):
538490
"""
539491
Returns the x509 private key of the SP.
540-
541492
:returns: SP private key
542-
:rtype: string
493+
:rtype: string or None
543494
"""
544495
key = self.__sp.get('privateKey')
545-
if not key:
546-
key_file_name = self.__paths['cert'] + 'sp.key'
496+
key_file_name = self.__paths['cert'] + 'sp.key'
497+
498+
if not key and exists(key_file_name):
499+
with open(key_file_name) as f:
500+
key = f.read()
547501

548-
if exists(key_file_name):
549-
with open(key_file_name, 'r') as f_key:
550-
key = self.__sp['privateKey'] = f_key.read()
551502
return key or None
552503

553504
def get_sp_cert(self):
554505
"""
555506
Returns the x509 public cert of the SP.
556-
557507
:returns: SP public cert
558-
:rtype: string
508+
:rtype: string or None
559509
"""
560510
cert = self.__sp.get('x509cert')
561-
if not cert:
562-
cert_file_name = self.__paths['cert'] + 'sp.crt'
563-
if exists(cert_file_name):
564-
with open(cert_file_name, 'r') as f_cert:
565-
cert = self.__sp['x509cert'] = f_cert.read()
511+
cert_file_name = self.__paths['cert'] + 'sp.crt'
512+
513+
if not cert and exists(cert_file_name):
514+
with open(cert_file_name) as f:
515+
cert = f.read()
516+
566517
return cert or None
567518

568519
def get_idp_cert(self):
569520
"""
570521
Returns the x509 public cert of the IdP.
571-
572522
:returns: IdP public cert
573523
:rtype: string
574524
"""
575-
return self.__idp['x509cert'] or None
525+
return self.__idp.get('x509cert')
576526

577527
def get_idp_data(self):
578528
"""

0 commit comments

Comments
 (0)