Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ show_error_codes = true
warn_return_any = true
strict_optional = true
disallow_incomplete_defs = true
exclude_gitignore = true
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's set as this by default

exclude = ["tests"]

[tool.ruff]
Expand Down
1 change: 1 addition & 0 deletions src/pardner/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
from pardner.services.base import (
UnsupportedVerticalException as UnsupportedVerticalException,
)
from pardner.services.strava import StravaTransferService as StravaTransferService
from pardner.services.tumblr import TumblrTransferService as TumblrTransferService
51 changes: 37 additions & 14 deletions src/pardner/services/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from requests_oauthlib import OAuth2Session

from pardner.services.utils import scope_as_set, scope_as_string
from pardner.verticals import Vertical


Expand Down Expand Up @@ -47,6 +48,7 @@ def __init__(
client_secret: str,
redirect_uri: str,
supported_verticals: set[Vertical],
state: Optional[str] = None,
verticals: set[Vertical] = set(),
) -> None:
"""
Expand All @@ -58,28 +60,42 @@ def __init__(
:param client_secret: The `client_secret` paired to the `client_id`.
:param redirect_uri: The registered callback URI.
:param supported_verticals: The `Vertical`s that can be fetched on the service.
:param state: State string used to prevent CSRF and identify flow.
:param verticals: The `Vertical`s for which the transfer service has
appropriate scope to fetch.
"""
self._oAuth2Session = OAuth2Session(
client_id=client_id, redirect_uri=redirect_uri
)
self._client_secret = client_secret
self._supported_verticals = supported_verticals
self._service_name = service_name
self._verticals = verticals
self._oAuth2Session = OAuth2Session(
client_id=client_id, redirect_uri=redirect_uri, state=state
)
self.scope = self.scope_for_verticals(verticals)

@property
def name(self) -> str:
return self._service_name

@property
def scope(self) -> set[str]:
return self._oAuth2Session.scope if self._oAuth2Session.scope else set()
return (
scope_as_set(self._oAuth2Session.scope)
if self._oAuth2Session.scope
else set()
)

@scope.setter
def scope(self, new_scope: Iterable[str]) -> None:
self._oAuth2Session.scope = set(new_scope)
"""
Sets the scope of the transfer service flow.
Some services have specific requirements for the format of the scope
string (e.g., scopes have to be comma separated, or `+` separated).

:param new_scope: The new scopes that should be set for the transfer
service.
"""
self._oAuth2Session.scope = scope_as_string(new_scope)

@property
def verticals(self) -> set[Vertical]:
Expand Down Expand Up @@ -118,22 +134,23 @@ def add_verticals(
"""
new_verticals = set(verticals) - self.verticals
new_scopes = self.scope_for_verticals(new_verticals)
original_scopes: set[str] = self.scope if self.scope else set()

if not new_scopes.issubset(original_scopes) and not should_reauth:
if not new_scopes.issubset(self.scope) and not should_reauth:
raise InsufficientScopeException(verticals, self.name)
elif not new_scopes.issubset(original_scopes):
elif not new_scopes.issubset(self.scope):
self.verticals = new_verticals | self.verticals
del self._oAuth2Session.access_token
self.scope = original_scopes | new_scopes
self.scope = self.scope | new_scopes
return False

self.verticals = new_verticals | self.verticals
return True

@abstractmethod
def fetch_token(
self, code: Optional[str] = None, authorization_response: Optional[str] = None
self,
code: Optional[str] = None,
authorization_response: Optional[str] = None,
include_client_id: bool = False,
) -> dict[str, Any]:
"""
Once the end-user authorizes the application to access their data, the
Expand All @@ -147,20 +164,26 @@ def fetch_token(
browser redirected to.
:param authorization_response: the URL (with parameters) the end-user's browser
redirected to after authorization.
:param include_client_id: whether or not to send the client ID with the token request

:returns: the authorization URL and state, respectively.
"""
pass
return self._oAuth2Session.fetch_token(
token_url=self._token_url,
code=code,
authorization_response=authorization_response,
include_client_id=include_client_id,
client_secret=self._client_secret,
)

@abstractmethod
def authorization_url(self) -> tuple[str, str]:
"""
Builds the authorization URL and state. Once the end-user (i.e., resource owner)
navigates to the authorization URL they can begin the authorization flow.

:returns: the authorization URL and state, respectively.
"""
pass
return self._oAuth2Session.authorization_url(self._authorization_url)

@abstractmethod
def scope_for_verticals(self, verticals: Iterable[Vertical]) -> set[str]:
Expand Down
61 changes: 61 additions & 0 deletions src/pardner/services/strava.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import Any, Iterable, Optional, override

from pardner.services.base import BaseTransferService, UnsupportedVerticalException
from pardner.services.utils import scope_as_set, scope_as_string
from pardner.verticals import Vertical


class StravaTransferService(BaseTransferService):
"""
Class responsible for obtaining end-user authorization to make requests to
Strava's API.
See API documentation: https://developers.strava.com/docs/reference/
"""

_authorization_url = 'https://www.strava.com/oauth/authorize'
_token_url = 'https://www.strava.com/oauth/token'

def __init__(
self,
client_id: str,
client_secret: str,
redirect_uri: str,
state: Optional[str] = None,
verticals: set[Vertical] = set(),
) -> None:
super().__init__(
service_name='Strava',
client_id=client_id,
client_secret=client_secret,
redirect_uri=redirect_uri,
state=state,
supported_verticals={Vertical.FeedPost},
verticals=verticals,
)

@property
def scope(self) -> set[str]:
return scope_as_set(self._oAuth2Session.scope, delimiter=',')

@scope.setter
def scope(self, new_scope: Iterable[str] | str) -> None:
self._oAuth2Session.scope = scope_as_string(new_scope, delimiter=',')

@override
def fetch_token(
self,
code: Optional[str] = None,
authorization_response: Optional[str] = None,
include_client_id: bool = True,
) -> dict[str, Any]:
return super().fetch_token(code, authorization_response, include_client_id)

@override
def scope_for_verticals(self, verticals: Iterable[Vertical]) -> set[str]:
sub_scopes: set[str] = set()
for vertical in verticals:
if vertical not in self._supported_verticals:
raise UnsupportedVerticalException([vertical], self._service_name)
if vertical == Vertical.FeedPost:
sub_scopes.update(['activity:read', 'profile:read_all'])
return sub_scopes
23 changes: 10 additions & 13 deletions src/pardner/services/tumblr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Iterable, Optional
from typing import Any, Iterable, Optional, override

from pardner.services import BaseTransferService
from pardner.verticals import Vertical
Expand All @@ -19,32 +19,29 @@ def __init__(
client_id: str,
client_secret: str,
redirect_uri: str,
state: Optional[str] = None,
verticals: set[Vertical] = set(),
) -> None:
super().__init__(
service_name='Tumblr',
client_id=client_id,
client_secret=client_secret,
redirect_uri=redirect_uri,
state=state,
supported_verticals={Vertical.FeedPost},
verticals=verticals,
)

@override
def scope_for_verticals(self, verticals: Iterable[Vertical]) -> set[str]:
# Tumblr only needs 'base' for read access requests
return {'base'}

def authorization_url(self) -> tuple[str, str]:
return self._oAuth2Session.authorization_url(self._authorization_url)

@override
def fetch_token(
self, code: Optional[str] = None, authorization_response: Optional[str] = None
self,
code: Optional[str] = None,
authorization_response: Optional[str] = None,
include_client_id: bool = True,
) -> dict[str, Any]:
# Requires client_id
return self._oAuth2Session.fetch_token(
token_url=self._token_url,
code=code,
authorization_response=authorization_response,
include_client_id=True,
client_secret=self._client_secret,
)
return super().fetch_token(code, authorization_response, include_client_id)
34 changes: 34 additions & 0 deletions src/pardner/services/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Any


def scope_as_string(scopes: Any, delimiter: str = ' ') -> str | None:
"""
Converts a sequence of individual scopes into a single scope string.

:param scopes: a sequence of scopes as strings or a scope string.
:param delimiter: the string used to separate individual scopes. Defaults to single space.

:returns: a string containing all scopes.
:raises :class:ValueError: if `scopes` is neither a string nor a sequence of strings
"""
if isinstance(scopes, str) or scopes is None:
return scopes
elif isinstance(scopes, (set, tuple, list)):
return delimiter.join([str(s) for s in sorted(scopes)])
raise ValueError(f'Invalid scope ({scopes}), must be string, tuple, set, or list.')


def scope_as_set(scope: Any, delimiter: str = ' ') -> set[str]:
"""
Splits a scope with potentially more than one scope into a set of scopes.

:param scope: a string with one or more scopes.
:param delimiter: the string used to separate individual scopes. Defaults to single space.

:returns: a set of scopes.
"""
if isinstance(scope, (tuple, list, set)):
return {str(s) for s in scope}
elif scope is None:
return set()
return set(scope.strip().split(delimiter))
Empty file added tests/__init__.py
Empty file.
Empty file.
34 changes: 34 additions & 0 deletions tests/test_transfer_services/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pytest

from pardner.services.strava import StravaTransferService
from pardner.services.tumblr import TumblrTransferService
from pardner.verticals.base import Vertical


@pytest.fixture
def mock_oAuth2Session(mocker):
mock_oauth2session_request = mocker.patch('requests_oauthlib.OAuth2Session.request')
mock_client_parse_request_body_response = mocker.patch(
'oauthlib.oauth2.rfc6749.clients.WebApplicationClient.parse_request_body_response'
)
return [mock_oauth2session_request, mock_client_parse_request_body_response]


@pytest.fixture
def mock_vertical():
Vertical.NEW_VERTICAL = 'new_vertical'
Vertical.NEW_VERTICAL_EXTRA_SCOPE = 'new_vertical_unsupported'


@pytest.fixture
def mock_tumblr_transfer_service(verticals=[Vertical.FeedPost]):
return TumblrTransferService(
'fake_client_id', 'fake_client_secret', 'https://redirect_uri', None, verticals
)


@pytest.fixture
def mock_strava_transfer_service(verticals=[Vertical.FeedPost]):
return StravaTransferService(
'fake_client_id', 'fake_client_secret', 'https://redirect_uri', None, verticals
)
Loading