Skip to content

Commit 67c5af3

Browse files
committed
allow session build with tokens
1 parent 4d4874f commit 67c5af3

1 file changed

Lines changed: 39 additions & 8 deletions

File tree

firstrade/account.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,19 @@ class FTSession:
5858
5959
"""
6060

61-
def __init__(self, username: str, password: str, pin: str = "", email: str = "", phone: str = "", mfa_secret: str = "", profile_path: str | None = None, *, save_session: bool = False, debug: bool = False) -> None:
61+
def __init__(
62+
self,
63+
username: str = "",
64+
password: str = "",
65+
pin: str = "",
66+
email: str = "",
67+
phone: str = "",
68+
mfa_secret: str = "",
69+
profile_path: str | None = None,
70+
*,
71+
save_session: bool = False,
72+
debug: bool = False
73+
) -> None:
6274
"""Initialize a new instance of the FTSession class.
6375
6476
Args:
@@ -92,7 +104,7 @@ def __init__(self, username: str, password: str, pin: str = "", email: str = "",
92104
logging.getLogger("requests.packages.urllib3").setLevel(logging.DEBUG)
93105
logging.getLogger("requests.packages.urllib3").propagate = True
94106
self.t_token: str | None = None
95-
self.otp_options: list[dict[str, str]] | None = None
107+
self.otp_options: str | list[dict[str, str]] | None = None
96108
self.login_json: dict[str, str] = {}
97109
self.session = requests.Session()
98110

@@ -111,7 +123,7 @@ def login(self) -> bool:
111123
ftat: str = self._load_cookies()
112124
if ftat:
113125
self.session.headers["ftat"] = ftat
114-
response: requests.Response = self._request("get", url="https://api3x.firstrade.com/", timeout=10)
126+
response: requests.Response = self._request("get", url="https://api3x.firstrade.com/", timeout=10) # type: ignore[arg-type]
115127
self.session.headers["access-token"] = urls.access_token()
116128

117129
data: dict[str, str] = {
@@ -134,7 +146,7 @@ def login(self) -> bool:
134146
return False
135147
self.t_token: str | None = self.login_json.get("t_token")
136148
if not self.login_json.get("mfa"):
137-
self.otp_options: str | None = self.login_json.get("otp")
149+
self.otp_options = self.login_json.get("otp")
138150
if response.status_code != 200:
139151
raise LoginRequestError(response.status_code)
140152
if self.login_json["error"]:
@@ -191,7 +203,25 @@ def get_tokens(self) -> dict[str, str | bytes | dict[str, str] | None]:
191203
"cookies": cookies or "",
192204
}
193205

194-
def _load_cookies(self) -> str:
206+
def build_session_from_tokens(self, tokens: dict[str, str | bytes | dict[str, str] | None]) -> None:
207+
"""Build the session headers and cookies from provided tokens."""
208+
self.session.headers.update(urls.session_headers())
209+
if tokens:
210+
access_token = tokens.get("access-token")
211+
ftat_token = tokens.get("ftat")
212+
sid_token = tokens.get("sid")
213+
214+
if isinstance(access_token, (str, bytes)):
215+
self.session.headers.update({"access-token": access_token})
216+
if isinstance(ftat_token, (str, bytes)):
217+
self.session.headers.update({"ftat": ftat_token})
218+
if isinstance(sid_token, (str, bytes)):
219+
self.session.headers.update({"sid": sid_token})
220+
cookies = tokens.get("cookies")
221+
if isinstance(cookies, dict):
222+
self.session.cookies.update(cookies) # type: ignore[arg-type]
223+
224+
def _load_cookies(self) -> str | None:
195225
"""Check if session cookies were saved.
196226
197227
Returns
@@ -311,9 +341,9 @@ def _handle_secret_mfa(self, data: dict[str, str | None]) -> requests.Response:
311341
})
312342
return self._request("post", urls.verify_pin(), data=data)
313343

314-
def _request(self, method, url, **kwargs):
344+
def _request(self, method: str, url: str, **kwargs: object) -> requests.Response:
315345
"""Send HTTP request and log the full response content if debug=True."""
316-
resp = self.session.request(method, url, **kwargs)
346+
resp = self.session.request(method, url, **kwargs) # type: ignore[no-untyped-call]
317347

318348
if self.debug:
319349
# Suppress urllib3 / http.client debug so we only see this log
@@ -359,6 +389,7 @@ def __getattr__(self, name: str) -> object:
359389
"""
360390
return getattr(self.session, name)
361391

392+
362393
class FTAccountData:
363394
"""Dataclass for storing account information."""
364395

@@ -376,7 +407,7 @@ def __init__(self, session: requests.Session) -> None:
376407
response: requests.Response = self.session._request("get", url=urls.user_info())
377408
self.user_info: dict[str, object] = response.json()
378409
response: requests.Response = self.session._request("get", urls.account_list())
379-
if response.status_code != 200 or response.json()["error"] != "":
410+
if response.status_code != 200 or response.json()["error"]:
380411
raise AccountResponseError(response.json()["error"])
381412
self.all_accounts = response.json()
382413
for item in self.all_accounts["items"]:

0 commit comments

Comments
 (0)