Skip to content

Commit a4c8d63

Browse files
authored
Creates Strava transfer service class and adds configuration to base transfer service class (#39)
Closes #18 **State as arg**: When adding `pardner` to `pardner-site`, I realized it's necessary to instantiate a transfer service with an existing `state` variable defined, so I added that as an argument to the constructor. **Configuring scope**: `requests_oauthlib` always sends requests with scopes separated with a space, but Strava (for example) specifies the scopes need to be comma-separated. **Adding default logic to base class**: The `fetch_token` and `authorization_url` methods will almost always have the same functionality regardless of service, so I filled them out in the base class as well. **Tests**: extracted common test fixtures and moved some tests to base transfer service test file.
1 parent 236b309 commit a4c8d63

14 files changed

Lines changed: 307 additions & 99 deletions

File tree

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ show_error_codes = true
3131
warn_return_any = true
3232
strict_optional = true
3333
disallow_incomplete_defs = true
34-
exclude_gitignore = true
3534
exclude = ["tests"]
3635

3736
[tool.ruff]

src/pardner/services/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@
55
from pardner.services.base import (
66
UnsupportedVerticalException as UnsupportedVerticalException,
77
)
8+
from pardner.services.strava import StravaTransferService as StravaTransferService
89
from pardner.services.tumblr import TumblrTransferService as TumblrTransferService

src/pardner/services/base.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from requests_oauthlib import OAuth2Session
55

6+
from pardner.services.utils import scope_as_set, scope_as_string
67
from pardner.verticals import Vertical
78

89

@@ -47,6 +48,7 @@ def __init__(
4748
client_secret: str,
4849
redirect_uri: str,
4950
supported_verticals: set[Vertical],
51+
state: Optional[str] = None,
5052
verticals: set[Vertical] = set(),
5153
) -> None:
5254
"""
@@ -58,28 +60,42 @@ def __init__(
5860
:param client_secret: The `client_secret` paired to the `client_id`.
5961
:param redirect_uri: The registered callback URI.
6062
:param supported_verticals: The `Vertical`s that can be fetched on the service.
63+
:param state: State string used to prevent CSRF and identify flow.
6164
:param verticals: The `Vertical`s for which the transfer service has
6265
appropriate scope to fetch.
6366
"""
64-
self._oAuth2Session = OAuth2Session(
65-
client_id=client_id, redirect_uri=redirect_uri
66-
)
6767
self._client_secret = client_secret
6868
self._supported_verticals = supported_verticals
6969
self._service_name = service_name
7070
self._verticals = verticals
71+
self._oAuth2Session = OAuth2Session(
72+
client_id=client_id, redirect_uri=redirect_uri, state=state
73+
)
74+
self.scope = self.scope_for_verticals(verticals)
7175

7276
@property
7377
def name(self) -> str:
7478
return self._service_name
7579

7680
@property
7781
def scope(self) -> set[str]:
78-
return self._oAuth2Session.scope if self._oAuth2Session.scope else set()
82+
return (
83+
scope_as_set(self._oAuth2Session.scope)
84+
if self._oAuth2Session.scope
85+
else set()
86+
)
7987

8088
@scope.setter
8189
def scope(self, new_scope: Iterable[str]) -> None:
82-
self._oAuth2Session.scope = set(new_scope)
90+
"""
91+
Sets the scope of the transfer service flow.
92+
Some services have specific requirements for the format of the scope
93+
string (e.g., scopes have to be comma separated, or `+` separated).
94+
95+
:param new_scope: The new scopes that should be set for the transfer
96+
service.
97+
"""
98+
self._oAuth2Session.scope = scope_as_string(new_scope)
8399

84100
@property
85101
def verticals(self) -> set[Vertical]:
@@ -118,22 +134,23 @@ def add_verticals(
118134
"""
119135
new_verticals = set(verticals) - self.verticals
120136
new_scopes = self.scope_for_verticals(new_verticals)
121-
original_scopes: set[str] = self.scope if self.scope else set()
122137

123-
if not new_scopes.issubset(original_scopes) and not should_reauth:
138+
if not new_scopes.issubset(self.scope) and not should_reauth:
124139
raise InsufficientScopeException(verticals, self.name)
125-
elif not new_scopes.issubset(original_scopes):
140+
elif not new_scopes.issubset(self.scope):
126141
self.verticals = new_verticals | self.verticals
127142
del self._oAuth2Session.access_token
128-
self.scope = original_scopes | new_scopes
143+
self.scope = self.scope | new_scopes
129144
return False
130145

131146
self.verticals = new_verticals | self.verticals
132147
return True
133148

134-
@abstractmethod
135149
def fetch_token(
136-
self, code: Optional[str] = None, authorization_response: Optional[str] = None
150+
self,
151+
code: Optional[str] = None,
152+
authorization_response: Optional[str] = None,
153+
include_client_id: bool = False,
137154
) -> dict[str, Any]:
138155
"""
139156
Once the end-user authorizes the application to access their data, the
@@ -147,20 +164,26 @@ def fetch_token(
147164
browser redirected to.
148165
:param authorization_response: the URL (with parameters) the end-user's browser
149166
redirected to after authorization.
167+
:param include_client_id: whether or not to send the client ID with the token request
150168
151169
:returns: the authorization URL and state, respectively.
152170
"""
153-
pass
171+
return self._oAuth2Session.fetch_token(
172+
token_url=self._token_url,
173+
code=code,
174+
authorization_response=authorization_response,
175+
include_client_id=include_client_id,
176+
client_secret=self._client_secret,
177+
)
154178

155-
@abstractmethod
156179
def authorization_url(self) -> tuple[str, str]:
157180
"""
158181
Builds the authorization URL and state. Once the end-user (i.e., resource owner)
159182
navigates to the authorization URL they can begin the authorization flow.
160183
161184
:returns: the authorization URL and state, respectively.
162185
"""
163-
pass
186+
return self._oAuth2Session.authorization_url(self._authorization_url)
164187

165188
@abstractmethod
166189
def scope_for_verticals(self, verticals: Iterable[Vertical]) -> set[str]:

src/pardner/services/strava.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from typing import Any, Iterable, Optional, override
2+
3+
from pardner.services.base import BaseTransferService, UnsupportedVerticalException
4+
from pardner.services.utils import scope_as_set, scope_as_string
5+
from pardner.verticals import Vertical
6+
7+
8+
class StravaTransferService(BaseTransferService):
9+
"""
10+
Class responsible for obtaining end-user authorization to make requests to
11+
Strava's API.
12+
See API documentation: https://developers.strava.com/docs/reference/
13+
"""
14+
15+
_authorization_url = 'https://www.strava.com/oauth/authorize'
16+
_token_url = 'https://www.strava.com/oauth/token'
17+
18+
def __init__(
19+
self,
20+
client_id: str,
21+
client_secret: str,
22+
redirect_uri: str,
23+
state: Optional[str] = None,
24+
verticals: set[Vertical] = set(),
25+
) -> None:
26+
super().__init__(
27+
service_name='Strava',
28+
client_id=client_id,
29+
client_secret=client_secret,
30+
redirect_uri=redirect_uri,
31+
state=state,
32+
supported_verticals={Vertical.FeedPost},
33+
verticals=verticals,
34+
)
35+
36+
@property
37+
def scope(self) -> set[str]:
38+
return scope_as_set(self._oAuth2Session.scope, delimiter=',')
39+
40+
@scope.setter
41+
def scope(self, new_scope: Iterable[str] | str) -> None:
42+
self._oAuth2Session.scope = scope_as_string(new_scope, delimiter=',')
43+
44+
@override
45+
def fetch_token(
46+
self,
47+
code: Optional[str] = None,
48+
authorization_response: Optional[str] = None,
49+
include_client_id: bool = True,
50+
) -> dict[str, Any]:
51+
return super().fetch_token(code, authorization_response, include_client_id)
52+
53+
@override
54+
def scope_for_verticals(self, verticals: Iterable[Vertical]) -> set[str]:
55+
sub_scopes: set[str] = set()
56+
for vertical in verticals:
57+
if vertical not in self._supported_verticals:
58+
raise UnsupportedVerticalException([vertical], self._service_name)
59+
if vertical == Vertical.FeedPost:
60+
sub_scopes.update(['activity:read', 'profile:read_all'])
61+
return sub_scopes

src/pardner/services/tumblr.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Iterable, Optional
1+
from typing import Any, Iterable, Optional, override
22

33
from pardner.services import BaseTransferService
44
from pardner.verticals import Vertical
@@ -19,32 +19,29 @@ def __init__(
1919
client_id: str,
2020
client_secret: str,
2121
redirect_uri: str,
22+
state: Optional[str] = None,
2223
verticals: set[Vertical] = set(),
2324
) -> None:
2425
super().__init__(
2526
service_name='Tumblr',
2627
client_id=client_id,
2728
client_secret=client_secret,
2829
redirect_uri=redirect_uri,
30+
state=state,
2931
supported_verticals={Vertical.FeedPost},
3032
verticals=verticals,
3133
)
3234

35+
@override
3336
def scope_for_verticals(self, verticals: Iterable[Vertical]) -> set[str]:
3437
# Tumblr only needs 'base' for read access requests
3538
return {'base'}
3639

37-
def authorization_url(self) -> tuple[str, str]:
38-
return self._oAuth2Session.authorization_url(self._authorization_url)
39-
40+
@override
4041
def fetch_token(
41-
self, code: Optional[str] = None, authorization_response: Optional[str] = None
42+
self,
43+
code: Optional[str] = None,
44+
authorization_response: Optional[str] = None,
45+
include_client_id: bool = True,
4246
) -> dict[str, Any]:
43-
# Requires client_id
44-
return self._oAuth2Session.fetch_token(
45-
token_url=self._token_url,
46-
code=code,
47-
authorization_response=authorization_response,
48-
include_client_id=True,
49-
client_secret=self._client_secret,
50-
)
47+
return super().fetch_token(code, authorization_response, include_client_id)

src/pardner/services/utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from typing import Any
2+
3+
4+
def scope_as_string(scopes: Any, delimiter: str = ' ') -> str | None:
5+
"""
6+
Converts a sequence of individual scopes into a single scope string.
7+
8+
:param scopes: a sequence of scopes as strings or a scope string.
9+
:param delimiter: the string used to separate individual scopes. Defaults to single space.
10+
11+
:returns: a string containing all scopes.
12+
:raises :class:ValueError: if `scopes` is neither a string nor a sequence of strings
13+
"""
14+
if isinstance(scopes, str) or scopes is None:
15+
return scopes
16+
elif isinstance(scopes, (set, tuple, list)):
17+
return delimiter.join([str(s) for s in sorted(scopes)])
18+
raise ValueError(f'Invalid scope ({scopes}), must be string, tuple, set, or list.')
19+
20+
21+
def scope_as_set(scope: Any, delimiter: str = ' ') -> set[str]:
22+
"""
23+
Splits a scope with potentially more than one scope into a set of scopes.
24+
25+
:param scope: a string with one or more scopes.
26+
:param delimiter: the string used to separate individual scopes. Defaults to single space.
27+
28+
:returns: a set of scopes.
29+
"""
30+
if isinstance(scope, (tuple, list, set)):
31+
return {str(s) for s in scope}
32+
elif scope is None:
33+
return set()
34+
return set(scope.strip().split(delimiter))

tests/__init__.py

Whitespace-only changes.

tests/test_transfer_services/__init__.py

Whitespace-only changes.
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import pytest
2+
3+
from pardner.services.strava import StravaTransferService
4+
from pardner.services.tumblr import TumblrTransferService
5+
from pardner.verticals.base import Vertical
6+
7+
8+
@pytest.fixture
9+
def mock_oAuth2Session(mocker):
10+
mock_oauth2session_request = mocker.patch('requests_oauthlib.OAuth2Session.request')
11+
mock_client_parse_request_body_response = mocker.patch(
12+
'oauthlib.oauth2.rfc6749.clients.WebApplicationClient.parse_request_body_response'
13+
)
14+
return [mock_oauth2session_request, mock_client_parse_request_body_response]
15+
16+
17+
@pytest.fixture
18+
def mock_vertical():
19+
Vertical.NEW_VERTICAL = 'new_vertical'
20+
Vertical.NEW_VERTICAL_EXTRA_SCOPE = 'new_vertical_unsupported'
21+
22+
23+
@pytest.fixture
24+
def mock_tumblr_transfer_service(verticals=[Vertical.FeedPost]):
25+
return TumblrTransferService(
26+
'fake_client_id', 'fake_client_secret', 'https://redirect_uri', None, verticals
27+
)
28+
29+
30+
@pytest.fixture
31+
def mock_strava_transfer_service(verticals=[Vertical.FeedPost]):
32+
return StravaTransferService(
33+
'fake_client_id', 'fake_client_secret', 'https://redirect_uri', None, verticals
34+
)

0 commit comments

Comments
 (0)