Skip to content

Commit aba20d2

Browse files
authored
Merge pull request #93 from datakind/Validation-Errors
feat: added option for api auth
2 parents 5e87ef8 + 19aba9d commit aba20d2

4 files changed

Lines changed: 25 additions & 10 deletions

File tree

src/webapp/authn.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ def get_api_key(
5353
)
5454

5555

56+
def check_creds(username: str, password: str) -> bool:
57+
return username == env_vars.get("USERNAME") and password == env_vars.get("PASSWORD")
58+
59+
5660
def verify_password(plain_password: str, hashed_password: str) -> bool:
5761
"""Verify a plain password against a hash. Includes a 2y/2b replacement since Laravel
5862
Generates hashes that start with 2y. The hashing scheme recognizes both."""

src/webapp/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
"INITIAL_API_KEY_ID": "",
1616
"CATALOG_NAME": "",
1717
"SQL_WAREHOUSE_ID": "",
18+
"USERNAME": "",
19+
"PASSWORD": "",
1820
}
1921

2022
# The INSTANCE_HOST is the private IP of CLoudSQL instance e.g. '127.0.0.1' ('172.17.0.1' if deployed to GAE Flex)

src/webapp/main.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import secrets
77
from fastapi import FastAPI, Depends, HTTPException, status, Security
88
from fastapi.responses import FileResponse
9+
from fastapi.security import OAuth2PasswordRequestForm
910
from pydantic import BaseModel
1011
from sqlalchemy.future import select
1112
from sqlalchemy import update
@@ -37,6 +38,7 @@
3738
create_access_token,
3839
get_api_key,
3940
get_api_key_hash,
41+
check_creds,
4042
)
4143

4244
# Set the logging
@@ -95,22 +97,27 @@ def read_root() -> Any:
9597
@app.post("/token-from-api-key")
9698
async def access_token_from_api_key(
9799
sql_session: Annotated[Session, Depends(get_session)],
100+
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
98101
api_key_enduser_tuple: str = Security(get_api_key),
99102
) -> Token:
100103
"""Generate a token from an API key."""
101104
local_session.set(sql_session)
105+
102106
user = authenticate_api_key(api_key_enduser_tuple, local_session.get())
103-
if not user:
107+
valid = check_creds(form_data.username, form_data.password)
108+
109+
if not user and not valid:
104110
raise HTTPException(
105111
status_code=status.HTTP_401_UNAUTHORIZED,
106-
detail="API key not valid",
112+
detail="Invalid API key and credentials",
107113
headers={"WWW-Authenticate": "X-API-KEY"},
108114
)
115+
email = user.email if user else form_data.username
109116
access_token_expires = timedelta(
110117
minutes=int(env_vars["ACCESS_TOKEN_EXPIRE_MINUTES"])
111118
)
112119
access_token = create_access_token(
113-
data={"sub": user.email}, expires_delta=access_token_expires
120+
data={"sub": email}, expires_delta=access_token_expires
114121
)
115122
return Token(access_token=access_token, token_type="bearer")
116123

src/webapp/main_test.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
get_session,
1414
ApiKeyTable,
1515
)
16+
from unittest.mock import patch
1617
from .authn import get_password_hash, get_api_key_hash
1718
from .test_helper import (
1819
DATAKINDER,
@@ -145,13 +146,14 @@ def test_get_root(client: TestClient):
145146

146147

147148
def test_retrieve_token_gen_from_api_key(client: TestClient):
148-
"""Test POST /token-from-api-key."""
149-
response = client.post(
150-
"/token-from-api-key",
151-
headers={"X-API-KEY": "key_1"},
152-
)
153-
assert response.status_code == 200
154-
assert response.json()["token_type"] == "bearer"
149+
with patch.dict("os.environ", {"USERNAME": "fake", "PASSWORD": "fake"}):
150+
response = client.post(
151+
"/token-from-api-key",
152+
headers={"X-API-KEY": "key_1"},
153+
data={"username": "fake", "password": "fake"},
154+
)
155+
assert response.status_code == 200
156+
assert response.json()["token_type"] == "bearer"
155157

156158

157159
def test_get_cross_isnt_users(client: TestClient):

0 commit comments

Comments
 (0)