Skip to content

Commit 376ab4b

Browse files
authored
Merge pull request #4 from yepcode/feature/YEP-2896
YEP-2896 Refresh access token on YepCode Run SDK
2 parents 9867b36 + b8fb79a commit 376ab4b

1 file changed

Lines changed: 36 additions & 10 deletions

File tree

yepcode_run/api/yepcode_api.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import base64
22
import json
3-
from typing import Optional, Dict, Any, List, Union, Tuple
4-
from datetime import datetime
3+
from typing import Optional, Dict, Any, List, Union
4+
from datetime import datetime, timezone
55
import requests
66
from urllib.parse import urljoin
77
import mimetypes
8+
import re
89

910
from .types import (
1011
YepCodeApiConfig,
@@ -135,14 +136,13 @@ def _client_id_from_access_token(self) -> str:
135136
payload = self.access_token.split(".")[1]
136137
payload += "=" * ((4 - len(payload) % 4) % 4)
137138
decoded_payload = json.loads(base64.b64decode(payload).decode())
138-
return decoded_payload["client_id"]
139+
return decoded_payload["clientId"]
139140
except Exception as e:
140141
raise ValueError(f"Failed to extract client_id from access token: {e}")
141142

142143
def _team_id_from_client_id(self) -> str:
143144
if not self.client_id:
144145
raise ValueError("Client ID is not set")
145-
import re
146146

147147
match = re.match(r"^sa-(.*)-[a-z0-9]{8}$", self.client_id)
148148
if not match:
@@ -155,6 +155,10 @@ def _get_base_url(self) -> str:
155155
return f"{self.api_host}/api/{self.team_id}/rest"
156156

157157
def _get_access_token(self) -> str:
158+
if not self.client_id or not self.client_secret:
159+
raise ValueError(
160+
"AccessToken has expired. Provide a new one or enable automatic refreshing by providing an apiToken or clientId and clientSecret."
161+
)
158162
try:
159163
auth_str = base64.b64encode(
160164
f"{self.client_id}:{self.client_secret}".encode()
@@ -183,13 +187,29 @@ def _get_access_token(self) -> str:
183187
except Exception as error:
184188
raise ValueError(f"Authentication failed: {str(error)}")
185189

190+
def _is_access_token_expired(self, access_token: str) -> bool:
191+
token_payload = access_token.split(".")[1]
192+
if not token_payload:
193+
return True
194+
195+
try:
196+
token_payload += "=" * ((4 - len(token_payload) % 4) % 4)
197+
decoded_token_payload = json.loads(base64.b64decode(token_payload).decode())
198+
expiration_time = decoded_token_payload["exp"]
199+
return (
200+
expiration_time is not None
201+
and expiration_time < datetime.now(timezone.utc).timestamp()
202+
)
203+
except Exception as e:
204+
return True
205+
186206
def _request(
187207
self, method: str, endpoint: str, options: Optional[Dict[str, Any]] = None
188208
) -> Any:
189209
if options is None:
190210
options = {}
191211

192-
if not self.access_token:
212+
if not self.access_token or self._is_access_token_expired(self.access_token):
193213
self._get_access_token()
194214

195215
headers = {
@@ -239,8 +259,8 @@ def _sanitize_date_param(date: Union[datetime, str, None]) -> Optional[str]:
239259
return None
240260
if isinstance(date, datetime):
241261
return date.isoformat().split(".")[0]
242-
if isinstance(date, str) and not date.match(
243-
r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$"
262+
if isinstance(date, str) and not re.match(
263+
r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$", date
244264
):
245265
raise ValueError(
246266
"Invalid date format. It must be a valid ISO 8601 date (ie: 2025-01-01T00:00:00)"
@@ -459,7 +479,9 @@ def get_object(self, name: str) -> requests.Response:
459479
}
460480
endpoint = f"/storage/objects/{name}"
461481
url = urljoin(f"{self._get_base_url()}/", endpoint.lstrip("/"))
462-
response = requests.get(url, headers=headers, stream=True, timeout=self.timeout / 1000)
482+
response = requests.get(
483+
url, headers=headers, stream=True, timeout=self.timeout / 1000
484+
)
463485
response.raise_for_status()
464486
return response
465487

@@ -475,8 +497,12 @@ def create_object(self, data: CreateStorageObjectInput) -> StorageObject:
475497
url = urljoin(f"{self._get_base_url()}/", endpoint.lstrip("/"))
476498
# Detect content type
477499
content_type, _ = mimetypes.guess_type(data.name)
478-
files = {"file": (data.name, data.file, content_type or "application/octet-stream")}
479-
response = requests.post(url, headers=headers, files=files, timeout=self.timeout / 1000)
500+
files = {
501+
"file": (data.name, data.file, content_type or "application/octet-stream")
502+
}
503+
response = requests.post(
504+
url, headers=headers, files=files, timeout=self.timeout / 1000
505+
)
480506
if not response.ok:
481507
try:
482508
error_response = response.json()

0 commit comments

Comments
 (0)