Skip to content

Commit 461a65b

Browse files
authored
Merge commit from fork
[v9.1] do not use eval in RequestDB
2 parents b0eca8f + 8ec1107 commit 461a65b

2 files changed

Lines changed: 90 additions & 23 deletions

File tree

src/DIRAC/RequestManagementSystem/DB/RequestDB.py

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import datetime
1616
import errno
1717
import random
18-
1918
from urllib.parse import quote_plus
2019

2120
from sqlalchemy import (
@@ -32,6 +31,7 @@
3231
create_engine,
3332
distinct,
3433
func,
34+
inspect,
3535
)
3636
from sqlalchemy.exc import SQLAlchemyError
3737
from sqlalchemy.orm import backref, joinedload, registry, relationship, sessionmaker
@@ -187,6 +187,38 @@ class RequestDB:
187187
db holding requests
188188
"""
189189

190+
@staticmethod
191+
def _get_column(table_name, column_name):
192+
"""Resolve supported ORM column attributes without evaluating input."""
193+
194+
models = {"Request": Request, "Operation": Operation}
195+
aliases = {"Status": "_Status"}
196+
197+
model = models.get(table_name)
198+
if model is None:
199+
raise ValueError(f"Unknown table '{table_name}'")
200+
201+
resolved_name = aliases.get(column_name, column_name)
202+
if resolved_name not in inspect(model).column_attrs:
203+
raise ValueError(f"Unknown {table_name} attribute '{column_name}'")
204+
205+
return getattr(model, resolved_name)
206+
207+
@classmethod
208+
def _apply_web_filter(cls, query, table_name, column_name, value):
209+
column = cls._get_column(table_name, column_name)
210+
if isinstance(value, list):
211+
return query.filter(column.in_(value))
212+
return query.filter(column == value)
213+
214+
@classmethod
215+
def _get_order_expression(cls, table_name, column_name, direction):
216+
column = cls._get_column(table_name, column_name)
217+
normalized_direction = direction.lower()
218+
if normalized_direction not in {"asc", "desc"}:
219+
raise ValueError(f"Unknown sort direction '{direction}'")
220+
return getattr(column, normalized_direction)()
221+
190222
def __getDBConnectionInfo(self, fullname):
191223
"""Collect from the CS all the info needed to connect to the DB.
192224
This should be in a base class eventually
@@ -704,13 +736,12 @@ def getRequestSummaryWeb(self, selectDict, sortList, startItem, maxItems):
704736
elif key == "Status":
705737
key = "_Status"
706738

707-
if isinstance(value, list):
708-
summaryQuery = summaryQuery.filter(eval(f"{tableName}.{key}.in_({value})"))
709-
else:
710-
summaryQuery = summaryQuery.filter(eval(f"{tableName}.{key}") == value)
739+
summaryQuery = self._apply_web_filter(summaryQuery, tableName, key, value)
711740

712741
if sortList:
713-
summaryQuery = summaryQuery.order_by(eval(f"Request.{sortList[0][0]}.{sortList[0][1].lower()}()"))
742+
summaryQuery = summaryQuery.order_by(
743+
self._get_order_expression("Request", sortList[0][0], sortList[0][1])
744+
)
714745

715746
try:
716747
requestLists = summaryQuery.all()
@@ -744,6 +775,8 @@ def getRequestSummaryWeb(self, selectDict, sortList, startItem, maxItems):
744775
resultDict["TotalRecords"] = nRequests
745776

746777
return S_OK(resultDict)
778+
except ValueError as e:
779+
return S_ERROR(str(e))
747780
#
748781
except Exception as e:
749782
self.log.exception("getRequestSummaryWeb: unexpected exception", lException=e)
@@ -763,17 +796,15 @@ def getRequestCountersWeb(self, groupingAttribute, selectDict):
763796

764797
session = self.DBSession()
765798

766-
if groupingAttribute == "Type":
767-
groupingAttribute = "Operation.Type"
768-
elif groupingAttribute == "Status":
769-
groupingAttribute = "Request._Status"
770-
else:
771-
groupingAttribute = f"Request.{groupingAttribute}"
772-
773799
try:
800+
if groupingAttribute == "Type":
801+
groupingColumn = self._get_column("Operation", "Type")
802+
else:
803+
groupingColumn = self._get_column("Request", groupingAttribute)
804+
774805
summaryQuery = session.query(
775-
eval(groupingAttribute), func.count(Request.RequestID) # pylint: disable=not-callable,no-member
776-
)
806+
groupingColumn, func.count(Request.RequestID)
807+
) # pylint: disable=not-callable,no-member
777808

778809
for key, value in selectDict.items():
779810
if key == "ToDate":
@@ -788,12 +819,9 @@ def getRequestCountersWeb(self, groupingAttribute, selectDict):
788819
elif key == "Status":
789820
key = "_Status"
790821

791-
if isinstance(value, list):
792-
summaryQuery = summaryQuery.filter(eval(f"{objectType}.{key}.in_({value})"))
793-
else:
794-
summaryQuery = summaryQuery.filter(eval(f"{objectType}.{key}") == value)
822+
summaryQuery = self._apply_web_filter(summaryQuery, objectType, key, value)
795823

796-
summaryQuery = summaryQuery.group_by(eval(groupingAttribute))
824+
summaryQuery = summaryQuery.group_by(groupingColumn)
797825

798826
try:
799827
requestLists = summaryQuery.all()
@@ -805,6 +833,8 @@ def getRequestCountersWeb(self, groupingAttribute, selectDict):
805833

806834
return S_OK(resultDict)
807835

836+
except ValueError as e:
837+
return S_ERROR(str(e))
808838
except Exception as e:
809839
self.log.exception("getRequestSummaryWeb: unexpected exception", lException=e)
810840
return S_ERROR(f"getRequestSummaryWeb: unexpected exception : {e}")
@@ -817,11 +847,11 @@ def getDistinctValues(self, tableName, columnName):
817847

818848
session = self.DBSession()
819849
distinctValues = []
820-
if columnName == "Status":
821-
columnName = "_Status"
822850
try:
823-
result = session.query(distinct(eval(f"{tableName}.{columnName}"))).all()
851+
result = session.query(distinct(self._get_column(tableName, columnName))).all()
824852
distinctValues = [dist[0] for dist in result]
853+
except ValueError as e:
854+
return S_ERROR(str(e))
825855
except NoResultFound:
826856
pass
827857
except Exception as e:

src/DIRAC/RequestManagementSystem/DB/test/Test_RequestDB.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
from DIRAC import gLogger, S_OK
1313

1414
from DIRAC.RequestManagementSystem.DB import RequestDB
15+
from DIRAC.RequestManagementSystem.Client.File import File
16+
from DIRAC.RequestManagementSystem.Client.Operation import Operation
17+
from DIRAC.RequestManagementSystem.Client.Request import Request
1518

1619
from DIRAC.RequestManagementSystem.DB.test.RMSTestScenari import ( # pylint: disable=unused-import
1720
test_dirty,
@@ -39,3 +42,37 @@ def mock_requestDB__init__(self):
3942
db.createTables()
4043

4144
yield db
45+
46+
47+
def test_web_queries_reject_unknown_attributes(reqDB):
48+
request = Request({"RequestName": "web-summary"})
49+
operation = Operation({"Type": "RemoveReplica", "TargetSE": "CERN-USER"})
50+
operation += File({"LFN": "/lhcb/user/c/cibak/web-summary"})
51+
request += operation
52+
53+
put = reqDB.putRequest(request)
54+
assert put["OK"], put
55+
56+
summary = reqDB.getRequestSummaryWeb({"Type": "RemoveReplica"}, [("RequestID", "ASC")], 0, 10)
57+
assert summary["OK"], summary
58+
assert summary["Value"]["TotalRecords"] == 1, summary
59+
60+
counters = reqDB.getRequestCountersWeb("Type", {"Status": "Waiting"})
61+
assert counters["OK"], counters
62+
assert counters["Value"] == {"RemoveReplica": 1}, counters
63+
64+
distinct = reqDB.getDistinctValues("Operation", "Type")
65+
assert distinct["OK"], distinct
66+
assert distinct["Value"] == ["RemoveReplica"], distinct
67+
68+
invalid_summary = reqDB.getRequestSummaryWeb({"__class__": "Request"}, [], 0, 10)
69+
assert not invalid_summary["OK"], invalid_summary
70+
assert invalid_summary["Message"] == "Unknown Request attribute '__class__'"
71+
72+
invalid_counters = reqDB.getRequestCountersWeb("__class__", {})
73+
assert not invalid_counters["OK"], invalid_counters
74+
assert invalid_counters["Message"] == "Unknown Request attribute '__class__'"
75+
76+
invalid_distinct = reqDB.getDistinctValues("Request", "__class__")
77+
assert not invalid_distinct["OK"], invalid_distinct
78+
assert invalid_distinct["Message"] == "Unknown Request attribute '__class__'"

0 commit comments

Comments
 (0)