Skip to content

Commit e9ad537

Browse files
committed
fix: refactoring pilot logging code
1 parent f02efba commit e9ad537

6 files changed

Lines changed: 394 additions & 2 deletions

File tree

diracx-routers/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ types = [
4848
]
4949

5050
[project.entry-points."diracx.services"]
51-
pilotlogs = "diracx.routers.pilot_logging.remote_logger:router"
51+
pilots = "diracx.routers.pilots:router"
5252
jobs = "diracx.routers.jobs:router"
5353
config = "diracx.routers.configuration:router"
5454
auth = "diracx.routers.auth:router"
@@ -57,7 +57,7 @@ auth = "diracx.routers.auth:router"
5757
[project.entry-points."diracx.access_policies"]
5858
WMSAccessPolicy = "diracx.routers.jobs.access_policies:WMSAccessPolicy"
5959
SandboxAccessPolicy = "diracx.routers.jobs.access_policies:SandboxAccessPolicy"
60-
PilotLogsAccessPolicy = "diracx.routers.pilot_logging.access_policies:PilotLogsAccessPolicy"
60+
PilotLogsAccessPolicy = "diracx.routers.pilots.access_policies:PilotLogsAccessPolicy"
6161

6262
# Minimum version of the client supported
6363
[project.entry-points."diracx.min_client_version"]
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from __future__ import annotations
2+
3+
from logging import getLogger
4+
5+
from ..fastapi_classes import DiracxRouter
6+
from .logging import router as logging_router
7+
8+
logger = getLogger(__name__)
9+
10+
router = DiracxRouter()
11+
router.include_router(logging_router)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from __future__ import annotations
2+
3+
from enum import StrEnum, auto
4+
from typing import Annotated, Callable
5+
6+
from fastapi import Depends, HTTPException, status
7+
8+
from diracx.core.properties import (
9+
GENERIC_PILOT,
10+
NORMAL_USER,
11+
OPERATOR,
12+
PILOT,
13+
SERVICE_ADMINISTRATOR,
14+
)
15+
from diracx.routers.access_policies import BaseAccessPolicy
16+
17+
from ..utils.users import AuthorizedUserInfo
18+
19+
20+
class ActionType(StrEnum):
21+
#: Create/update pilot log records
22+
CREATE = auto()
23+
#: delete pilot logs
24+
DELETE = auto()
25+
#: Search
26+
QUERY = auto()
27+
28+
29+
class PilotLogsAccessPolicy(BaseAccessPolicy):
30+
"""Rules:
31+
Only PILOT, GENERIC_PILOT, SERVICE_ADMINISTRATOR and OPERATOR can process log records.
32+
Policies for other actions to be determined.
33+
"""
34+
35+
@staticmethod
36+
async def policy(
37+
policy_name: str,
38+
user_info: AuthorizedUserInfo,
39+
/,
40+
*,
41+
action: ActionType | None = None,
42+
):
43+
44+
if action is None:
45+
raise HTTPException(
46+
status.HTTP_400_BAD_REQUEST, detail="Action is a mandatory argument"
47+
)
48+
49+
if GENERIC_PILOT in user_info.properties and action == ActionType.CREATE:
50+
return user_info
51+
if PILOT in user_info.properties and action == ActionType.CREATE:
52+
return user_info
53+
if NORMAL_USER in user_info.properties and action == ActionType.QUERY:
54+
return user_info
55+
if SERVICE_ADMINISTRATOR in user_info.properties:
56+
return user_info
57+
if OPERATOR in user_info.properties:
58+
return user_info
59+
60+
raise HTTPException(status.HTTP_403_FORBIDDEN, detail=user_info.properties)
61+
62+
63+
CheckPilotLogsPolicyCallable = Annotated[Callable, Depends(PilotLogsAccessPolicy.check)]
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
from __future__ import annotations
2+
3+
import datetime
4+
import logging
5+
6+
from fastapi import HTTPException, status
7+
from pydantic import BaseModel
8+
from sqlalchemy import select
9+
from sqlalchemy.exc import NoResultFound
10+
11+
from diracx.core.exceptions import InvalidQueryError
12+
from diracx.core.properties import OPERATOR, SERVICE_ADMINISTRATOR
13+
from diracx.db.sql.pilot_agents.schema import PilotAgents
14+
from diracx.db.sql.utils import BaseSQLDB
15+
16+
from ..dependencies import PilotLogsDB
17+
from ..fastapi_classes import DiracxRouter
18+
from ..utils.users import AuthorizedUserInfo
19+
from .access_policies import ActionType, CheckPilotLogsPolicyCallable
20+
21+
logger = logging.getLogger(__name__)
22+
router = DiracxRouter()
23+
24+
25+
class LogLine(BaseModel):
26+
line_no: int
27+
line: str
28+
29+
30+
class LogMessage(BaseModel):
31+
pilot_stamp: str
32+
lines: list[LogLine]
33+
vo: str
34+
35+
36+
class DateRange(BaseModel):
37+
min: str | None = None # expects a string in ISO 8601 ("%Y-%m-%dT%H:%M:%S.%f%z")
38+
max: str | None = None # expects a string in ISO 8601 ("%Y-%m-%dT%H:%M:%S.%f%z")
39+
40+
41+
@router.post("/")
42+
async def send_message(
43+
data: LogMessage,
44+
pilot_logs_db: PilotLogsDB,
45+
check_permissions: CheckPilotLogsPolicyCallable,
46+
) -> int:
47+
48+
logger.warning(f"Message received '{data}'")
49+
user_info = await check_permissions(action=ActionType.CREATE)
50+
pilot_id = 0 # need to get pilot id from pilot_stamp (via PilotAgentsDB)
51+
# also add a timestamp to be able to select and delete logs based on pilot creation dates, even if corresponding
52+
# pilots have been already deleted from PilotAgentsDB (so the logs can live longer than pilots).
53+
submission_time = datetime.datetime.fromtimestamp(0, datetime.timezone.utc)
54+
pilot_agents_db = BaseSQLDB.available_implementations("PilotAgentsDB")[0]
55+
url = BaseSQLDB.available_urls()["PilotAgentsDB"]
56+
db = pilot_agents_db(url)
57+
58+
try:
59+
async with db.engine_context():
60+
async with db:
61+
stmt = select(PilotAgents.pilot_id, PilotAgents.submission_time).where(
62+
PilotAgents.pilot_stamp == data.pilot_stamp
63+
)
64+
pilot_id, submission_time = (await db.conn.execute(stmt)).one()
65+
except NoResultFound as exc:
66+
logger.error(
67+
f"Cannot determine PilotID for requested PilotStamp: {data.pilot_stamp}, Error: {exc}."
68+
)
69+
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
70+
71+
docs = []
72+
for line in data.lines:
73+
docs.append(
74+
{
75+
"PilotStamp": data.pilot_stamp,
76+
"PilotID": pilot_id,
77+
"SubmissionTime": submission_time,
78+
"VO": user_info.vo,
79+
"LineNumber": line.line_no,
80+
"Message": line.line,
81+
}
82+
)
83+
await pilot_logs_db.bulk_insert(pilot_logs_db.index_name(pilot_id), docs)
84+
return pilot_id
85+
86+
87+
@router.get("/logs")
88+
async def get_logs(
89+
pilot_id: int,
90+
db: PilotLogsDB,
91+
check_permissions: CheckPilotLogsPolicyCallable,
92+
) -> list[dict]:
93+
94+
logger.warning(f"Retrieving logs for pilot ID '{pilot_id}'")
95+
user_info = await check_permissions(action=ActionType.QUERY)
96+
97+
# here, users with privileged properties will see logs from all VOs. Is it what we want ?
98+
search_params = [{"parameter": "PilotID", "operator": "eq", "value": pilot_id}]
99+
if _non_privileged(user_info):
100+
search_params.append(
101+
{"parameter": "VO", "operator": "eq", "value": user_info.vo}
102+
)
103+
result = await db.search(
104+
["Message"],
105+
search_params,
106+
[{"parameter": "LineNumber", "direction": "asc"}],
107+
)
108+
if not result:
109+
return [{"Message": f"No logs for pilot ID = {pilot_id}"}]
110+
return result
111+
112+
113+
@router.delete("/logs")
114+
async def delete(
115+
pilot_id: int,
116+
data: DateRange,
117+
db: PilotLogsDB,
118+
check_permissions: CheckPilotLogsPolicyCallable,
119+
) -> str:
120+
"""Delete either logs for a specific PilotID or a creation date range.
121+
Non-privileged users can only delete log files within their own VO.
122+
"""
123+
message = "no-op"
124+
user_info = await check_permissions(action=ActionType.DELETE)
125+
non_privil_params = {"parameter": "VO", "operator": "eq", "value": user_info.vo}
126+
127+
# id pilot_id is provided we ignore data.min and data.max
128+
if data.min and data.max and not pilot_id:
129+
raise InvalidQueryError(
130+
"This query requires a range operator definition in DiracX"
131+
)
132+
133+
if pilot_id:
134+
search_params = [{"parameter": "PilotID", "operator": "eq", "value": pilot_id}]
135+
if _non_privileged(user_info):
136+
search_params.append(non_privil_params)
137+
await db.delete(search_params)
138+
message = f"Logs for pilot ID '{pilot_id}' successfully deleted"
139+
140+
elif data.min:
141+
logger.warning(f"Deleting logs for pilots with submission data >='{data.min}'")
142+
search_params = [
143+
{"parameter": "SubmissionTime", "operator": "gt", "value": data.min}
144+
]
145+
if _non_privileged(user_info):
146+
search_params.append(non_privil_params)
147+
await db.delete(search_params)
148+
message = f"Logs for for pilots with submission data >='{data.min}' successfully deleted"
149+
150+
return message
151+
152+
153+
def _non_privileged(user_info: AuthorizedUserInfo):
154+
return (
155+
SERVICE_ADMINISTRATOR not in user_info.properties
156+
and OPERATOR not in user_info.properties
157+
)
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from __future__ import annotations
2+
3+
from contextlib import nullcontext
4+
from unittest.mock import MagicMock
5+
6+
import pytest
7+
from fastapi import HTTPException
8+
9+
from diracx.core.properties import (
10+
GENERIC_PILOT,
11+
NORMAL_USER,
12+
OPERATOR,
13+
PILOT,
14+
SERVICE_ADMINISTRATOR,
15+
)
16+
from diracx.routers.pilots.access_policies import (
17+
ActionType,
18+
PilotLogsAccessPolicy,
19+
)
20+
21+
22+
@pytest.mark.parametrize(
23+
"user, action, expectation",
24+
[
25+
(PILOT, ActionType.CREATE, nullcontext()),
26+
(PILOT, ActionType.QUERY, pytest.raises(HTTPException, match="403")),
27+
(PILOT, ActionType.DELETE, pytest.raises(HTTPException, match="403")),
28+
(GENERIC_PILOT, ActionType.CREATE, nullcontext()),
29+
(GENERIC_PILOT, ActionType.QUERY, pytest.raises(HTTPException, match="403")),
30+
(GENERIC_PILOT, ActionType.DELETE, pytest.raises(HTTPException, match="403")),
31+
(SERVICE_ADMINISTRATOR, ActionType.CREATE, nullcontext()),
32+
(SERVICE_ADMINISTRATOR, ActionType.QUERY, nullcontext()),
33+
(SERVICE_ADMINISTRATOR, ActionType.DELETE, nullcontext()),
34+
(OPERATOR, ActionType.CREATE, nullcontext()),
35+
(OPERATOR, ActionType.QUERY, nullcontext()),
36+
(OPERATOR, ActionType.DELETE, nullcontext()),
37+
(NORMAL_USER, ActionType.CREATE, pytest.raises(HTTPException, match="403")),
38+
(NORMAL_USER, ActionType.QUERY, nullcontext()),
39+
(NORMAL_USER, ActionType.DELETE, pytest.raises(HTTPException, match="403")),
40+
(
41+
"malicious_user",
42+
ActionType.CREATE,
43+
pytest.raises(HTTPException, match="403"),
44+
),
45+
("malicious_user", ActionType.QUERY, pytest.raises(HTTPException, match="403")),
46+
(
47+
"malicious_user",
48+
ActionType.DELETE,
49+
pytest.raises(HTTPException, match="403"),
50+
),
51+
("any_user", None, pytest.raises(HTTPException, match="400")),
52+
],
53+
)
54+
async def test_access_policies(user, action, expectation):
55+
user_info = MagicMock()
56+
user_info.properties = [user]
57+
with expectation:
58+
ret = await PilotLogsAccessPolicy.policy(
59+
"PilotLogsAccessPolicy", user_info, action=action
60+
)
61+
assert user in ret.properties

0 commit comments

Comments
 (0)