Skip to content

Commit c281170

Browse files
committed
Split token auth mechanisms for two different groups of endpoints
1 parent 74032e1 commit c281170

2 files changed

Lines changed: 44 additions & 44 deletions

File tree

pyproject.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,6 @@ GitHub = "https://github.com/DiamondLightSource/python-murfey"
9696
"murfey.spa_inject" = "murfey.cli.inject_spa_processing:run"
9797
"murfey.spa_ispyb_entries" = "murfey.cli.spa_ispyb_messages:run"
9898
"murfey.transfer" = "murfey.cli.transfer:run"
99-
[project.entry-points."murfey.auth.token_validation"]
100-
"password" = "murfey.server.api.auth:password_token_validation"
10199
[project.entry-points."murfey.config.extraction"]
102100
"murfey_machine" = "murfey.util.config:get_extended_machine_config"
103101
[project.entry-points."murfey.workflows"]

src/murfey/server/api/auth.py

Lines changed: 44 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,6 @@ def check_user(username: str) -> bool:
127127
return username in [u.username for u in users]
128128

129129

130-
def validate_instrument_server_token(timestamp: float) -> bool:
131-
return timestamp in instrument_server_tokens.keys()
132-
133-
134130
def validate_instrument_server_session_token(session_id: int, visit: str):
135131
with Session(engine) as murfey_db:
136132
session_data = murfey_db.exec(
@@ -141,46 +137,28 @@ def validate_instrument_server_session_token(session_id: int, visit: str):
141137
return visit == session_data[0].visit
142138

143139

144-
def password_token_validation(token: str):
145-
decoded_data = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
146-
# first check if the token has expired
147-
if expiry_time := decoded_data.get("expiry_time"):
148-
if expiry_time < time.time():
149-
raise JWTError
150-
if decoded_data.get("user"):
151-
if not check_user(decoded_data["user"]):
152-
raise JWTError
153-
elif decoded_data.get("session") is not None:
154-
if not validate_instrument_server_session_token(
155-
decoded_data["session"], decoded_data["visit"]
156-
):
157-
raise JWTError
158-
else:
159-
raise JWTError
160-
161-
162140
async def validate_token(token: Annotated[str, Depends(oauth2_scheme)]):
163141
try:
164-
if auth_url:
165-
headers = (
166-
{}
167-
if security_config.auth_type == "cookie"
168-
else {"Authorization": f"Bearer {token}"}
169-
)
170-
cookies = (
171-
{security_config.cookie_key: token}
172-
if security_config.auth_type == "cookie"
173-
else {}
174-
)
175-
async with aiohttp.ClientSession(cookies=cookies) as session:
176-
async with session.get(
177-
auth_url,
178-
headers=headers,
179-
) as response:
180-
success = response.status == 200
181-
validation_outcome = await response.json()
182-
if not (success and validation_outcome.get("valid")):
142+
try:
143+
if security_config.auth_type == "password":
144+
await validate_password_token(token)
145+
except JWTError:
146+
await validate_instrument_token(token)
147+
decoded_data = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
148+
# first check if the token has expired
149+
if expiry_time := decoded_data.get("expiry_time"):
150+
if expiry_time < time.time():
151+
raise JWTError
152+
if decoded_data.get("user"):
153+
if not check_user(decoded_data["user"]):
154+
raise JWTError
155+
elif decoded_data.get("session") is not None:
156+
if not validate_instrument_server_session_token(
157+
decoded_data["session"], decoded_data["visit"]
158+
):
183159
raise JWTError
160+
else:
161+
raise JWTError
184162
except JWTError:
185163
raise HTTPException(
186164
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -190,6 +168,30 @@ async def validate_token(token: Annotated[str, Depends(oauth2_scheme)]):
190168
return None
191169

192170

171+
async def validate_password_token(token: Annotated[str, Depends(oauth2_scheme)]):
172+
if auth_url:
173+
headers = (
174+
{}
175+
if security_config.auth_type == "cookie"
176+
else {"Authorization": f"Bearer {token}"}
177+
)
178+
cookies = (
179+
{security_config.cookie_key: token}
180+
if security_config.auth_type == "cookie"
181+
else {}
182+
)
183+
async with aiohttp.ClientSession(cookies=cookies) as session:
184+
async with session.get(
185+
f"{auth_url}/validate_token",
186+
headers=headers,
187+
) as response:
188+
success = response.status == 200
189+
validation_outcome = await response.json()
190+
if not (success and validation_outcome.get("valid")):
191+
raise JWTError
192+
return None
193+
194+
193195
async def validate_instrument_token(
194196
token: Annotated[str, Depends(instrument_oauth2_scheme)]
195197
):
@@ -202,7 +204,7 @@ async def validate_instrument_token(
202204
else {"Authorization": f"Bearer {token}"}
203205
)
204206
async with session.get(
205-
security_config.instrument_auth_url,
207+
f"{security_config.instrument_auth_url}/validate_token",
206208
headers=headers,
207209
) as response:
208210
success = response.status == 200

0 commit comments

Comments
 (0)