2222import unittest
2323
2424import mock
25+ import pytest
2526
2627import endpoints .api_config as api_config
2728
@@ -369,8 +370,9 @@ def testEmptyAudience(self):
369370 parsed_token , users_id_token ._ISSUERS , [], self ._SAMPLE_ALLOWED_CLIENT_IDS )
370371 self .assertEqual (False , result )
371372
373+ @mock .patch .object (oauth , 'get_authorized_scopes' )
372374 @mock .patch .object (oauth , 'get_client_id' )
373- def AttemptOauth (self , client_id , mock_get_client_id , allowed_client_ids = None ):
375+ def AttemptOauth (self , client_id , mock_get_client_id , mock_get_authorized_scopes , allowed_client_ids = None ):
374376 if allowed_client_ids is None :
375377 allowed_client_ids = self ._SAMPLE_ALLOWED_CLIENT_IDS
376378 # We have four cases:
@@ -381,20 +383,20 @@ def AttemptOauth(self, client_id, mock_get_client_id, allowed_client_ids=None):
381383 # mock call for every scope.
382384 if client_id is None :
383385 mock_get_client_id .side_effect = oauth .Error
386+ mock_get_authorized_scopes .side_effect = oauth .Error
384387 else :
385388 mock_get_client_id .return_value = client_id
389+ mock_get_authorized_scopes .return_value = self ._SAMPLE_OAUTH_SCOPES
386390 users_id_token ._set_bearer_user_vars (allowed_client_ids ,
387391 self ._SAMPLE_OAUTH_SCOPES )
388392 if client_id is None :
389- for scope in self ._SAMPLE_OAUTH_SCOPES :
390- mock_get_client_id .assert_called_with (scope )
393+ mock_get_authorized_scopes .assert_called_with (self ._SAMPLE_OAUTH_SCOPES )
391394 elif (list (allowed_client_ids ) == users_id_token .SKIP_CLIENT_ID_CHECK or
392395 client_id in allowed_client_ids ):
393396 scope = self ._SAMPLE_OAUTH_SCOPES [0 ]
394- mock_get_client_id .assert_called_with (scope )
397+ mock_get_client_id .assert_called_with ([ scope ] )
395398 else :
396- for scope in self ._SAMPLE_OAUTH_SCOPES :
397- mock_get_client_id .assert_called_with (scope )
399+ mock_get_client_id .assert_called_with (self ._SAMPLE_OAUTH_SCOPES )
398400
399401
400402 def assertOauthSucceeded (self , client_id ):
@@ -487,10 +489,10 @@ def testGetCurrentUserEmailAndAuth(self):
487489 def testGetCurrentUserOauth (self , mock_get_current_user ):
488490 mock_get_current_user .return_value = users .User ('test@gmail.com' )
489491
490- os .environ ['ENDPOINTS_USE_OAUTH_SCOPE' ] = 'scope '
492+ os .environ ['ENDPOINTS_USE_OAUTH_SCOPE' ] = 'scope1 scope2 '
491493 user = users_id_token .get_current_user ()
492494 self .assertEqual (user .email (), 'test@gmail.com' )
493- mock_get_current_user .assert_called_once_with ('scope' )
495+ mock_get_current_user .assert_called_once_with ([ 'scope1' , 'scope2' ] )
494496
495497 def testGetTokenQueryParamOauthHeader (self ):
496498 os .environ ['HTTP_AUTHORIZATION' ] = 'OAuth ' + self ._SAMPLE_TOKEN
@@ -631,9 +633,10 @@ def testMethodCallParsesIdToken(self):
631633 self .VerifyIdToken (self .TestApiAnnotatedAtApi (),
632634 message_types .VoidMessage ())
633635
636+ @mock .patch .object (oauth , 'get_authorized_scopes' )
634637 @mock .patch .object (oauth , 'get_client_id' )
635638 @mock .patch .object (users_id_token , '_is_local_dev' )
636- def testMaybeSetVarsWithActualRequestAccessToken (self , mock_local , mock_get_client_id ):
639+ def testMaybeSetVarsWithActualRequestAccessToken (self , mock_local , mock_get_client_id , mock_get_authorized_scopes ):
637640 dummy_scope = 'scope'
638641 dummy_token = 'dummy_token'
639642 dummy_email = 'test@gmail.com'
@@ -656,13 +659,15 @@ def method(self, request):
656659
657660 mock_local .return_value = False
658661 mock_get_client_id .return_value = dummy_client_id
662+ mock_get_authorized_scopes .return_value = [dummy_scope ]
659663
660664 api_instance = TestApiScopes ()
661665 os .environ ['HTTP_AUTHORIZATION' ] = 'Bearer ' + dummy_token
662666 api_instance .method (message_types .VoidMessage ())
663- self . assertEqual ( os .getenv ('ENDPOINTS_USE_OAUTH_SCOPE' ), dummy_scope )
667+ assert os .getenv ('ENDPOINTS_USE_OAUTH_SCOPE' ) == dummy_scope
664668 mock_local .assert_has_calls ([mock .call (), mock .call ()])
665- mock_get_client_id .assert_called_once_with (dummy_scope )
669+ mock_get_client_id .assert_called_once_with ([dummy_scope ])
670+ mock_get_authorized_scopes .assert_called_once_with ([dummy_scope ])
666671
667672 @mock .patch .object (users_id_token , '_get_id_token_user' )
668673 @mock .patch .object (time , 'time' )
@@ -891,5 +896,28 @@ def testBadBase64(self):
891896 self ._SAMPLE_CERT_URI , self .cache )
892897 self .assertIsNone (parsed_token )
893898
899+
900+ @pytest .mark .parametrize (('scopelist' , 'all_scopes' , 'sufficient_scopes' ), [
901+ (('scope1' , 'scope2' ), {'scope1' , 'scope2' }, {frozenset (['scope1' ]), frozenset (['scope2' ])}),
902+ (('scope1' , 'scope2 scope3' ), {'scope1' , 'scope2' , 'scope3' }, {frozenset (['scope1' ]), frozenset (['scope2' , 'scope3' ])}),
903+ (('scope1 scope2' , 'scope1 scope3' ), {'scope1' , 'scope2' , 'scope3' }, {frozenset (['scope1' , 'scope2' ]), frozenset (['scope1' , 'scope3' ])}),
904+ ])
905+ def test_process_scopes (scopelist , all_scopes , sufficient_scopes ):
906+ result = users_id_token ._process_scopes (scopelist )
907+ assert result == (all_scopes , sufficient_scopes )
908+
909+ @pytest .mark .parametrize (('authorized_scopes' , 'sufficient_scopes' , 'is_valid' ), [
910+ (['scope1' ], {frozenset (['scope1' ])}, True ),
911+ (['scope1' ], {frozenset (['scope1' , 'scope2' ])}, False ),
912+ (['scope1' , 'scope2' ], {frozenset (['scope1' ])}, True ),
913+ (['scope1' , 'scope2' ], {frozenset (['scope1' ]), frozenset (['scope2' ])}, True ),
914+ (['scope1' , 'scope2' ], {frozenset (['scope1' , 'scope2' ])}, True ),
915+ (['scope1' ], {frozenset (['scope1' ]), frozenset (['scope2' , 'scope3' ])}, True ),
916+ (['scope2' ], {frozenset (['scope1' ]), frozenset (['scope2' , 'scope3' ])}, False ),
917+ (['scope2' , 'scope3' ], {frozenset (['scope1' ]), frozenset (['scope2' , 'scope3' ])}, True ),
918+ ])
919+ def test_are_scopes_sufficient (authorized_scopes , sufficient_scopes , is_valid ):
920+ assert users_id_token ._are_scopes_sufficient (authorized_scopes , sufficient_scopes ) is is_valid
921+
894922if __name__ == '__main__' :
895923 unittest .main ()
0 commit comments