diff --git a/src/DIRAC/WorkloadManagementSystem/Client/JobMonitoringClient.py b/src/DIRAC/WorkloadManagementSystem/Client/JobMonitoringClient.py index e899c0154d1..e0eda68d8e7 100755 --- a/src/DIRAC/WorkloadManagementSystem/Client/JobMonitoringClient.py +++ b/src/DIRAC/WorkloadManagementSystem/Client/JobMonitoringClient.py @@ -79,3 +79,9 @@ def getJobsStates(self, jobIDs): if res["OK"]: res["Value"] = strToIntDict(res["Value"]) return res + + def getInputData(self, jobIDs): + res = self._getRPC().getInputData(jobIDs) + if res["OK"] and isinstance(res["Value"], dict): + res["Value"] = strToIntDict(res["Value"]) + return res diff --git a/src/DIRAC/WorkloadManagementSystem/DB/JobDB.py b/src/DIRAC/WorkloadManagementSystem/DB/JobDB.py index 90f95a7f7fc..0d82be10fd6 100755 --- a/src/DIRAC/WorkloadManagementSystem/DB/JobDB.py +++ b/src/DIRAC/WorkloadManagementSystem/DB/JobDB.py @@ -11,8 +11,11 @@ * *CompressJDLs*: Enable compression of JDLs when they are stored in the database, default *False*. """ +from __future__ import annotations + import datetime import operator +from typing import overload from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOForGroup from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getSiteTier @@ -20,7 +23,14 @@ from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd from DIRAC.Core.Utilities.Decorators import deprecated from DIRAC.Core.Utilities.DErrno import EWMSJMAN, EWMSSUBM, cmpError -from DIRAC.Core.Utilities.ReturnValues import S_ERROR, S_OK, convertToReturnValue, returnValueOrRaise, SErrorException +from DIRAC.Core.Utilities.ReturnValues import ( + S_ERROR, + S_OK, + convertToReturnValue, + returnValueOrRaise, + SErrorException, + DReturnType, +) from DIRAC.FrameworkSystem.Client.Logger import contextLogger from DIRAC.ResourceStatusSystem.Client.SiteStatus import SiteStatus from DIRAC.WorkloadManagementSystem.Client import JobMinorStatus, JobStatus @@ -320,23 +330,42 @@ def getJobOptParameters(self, jobID, paramList=None): ############################################################################# - def getInputData(self, jobID): + @overload + def getInputData(self, jobID: int | str) -> DReturnType[list[str]]: + ... + + @overload + def getInputData(self, jobID: list[int | str]) -> DReturnType[dict[int, list[str]]]: + ... + + def getInputData(self, jobID: int | str | list[int | str]) -> DReturnType[list[str] | dict[int, list[str]]]: """Get input data for the given job""" - ret = self._escapeString(jobID) - if not ret["OK"]: - return ret - jobID = ret["Value"] - cmd = f"SELECT LFN FROM InputData WHERE JobID={jobID}" + if isinstance(jobID, (int, str)): + ret = self._escapeString(jobID) + if not ret["OK"]: + return ret + jobID = ret["Value"] + query = f"JobID={jobID}" + result = [] + else: + job_ids = {int(i) for i in jobID} + query = f"JobID IN ({','.join(map(str, job_ids))})" + result = {i: [] for i in job_ids} + cmd = f"SELECT JobID, LFN FROM InputData WHERE {query}" res = self._query(cmd) if not res["OK"]: return res - inputData = [i[0] for i in res["Value"] if i[0].strip()] - for index, lfn in enumerate(inputData): + for jid, lfn in res["Value"]: + lfn = lfn.strip() if lfn.lower().startswith("lfn:"): - inputData[index] = lfn[4:] + lfn = lfn[4:] + if isinstance(result, list): + result.append(lfn) + else: + result[jid].append(lfn) - return S_OK(inputData) + return S_OK(result) ############################################################################# def setInputData(self, jobID, inputData): diff --git a/src/DIRAC/WorkloadManagementSystem/DB/tests/Test_JobDB.py b/src/DIRAC/WorkloadManagementSystem/DB/tests/Test_JobDB.py index ffa4405a57c..97c9c0e4969 100644 --- a/src/DIRAC/WorkloadManagementSystem/DB/tests/Test_JobDB.py +++ b/src/DIRAC/WorkloadManagementSystem/DB/tests/Test_JobDB.py @@ -28,7 +28,7 @@ def test_getInputData(jobDB: JobDB): """Test the getInputData method from JobDB""" # Arrange jobDB._escapeString = MagicMock(return_value=S_OK()) - jobDB._query = MagicMock(return_value=S_OK((("/vo/user/lfn1",), ("LFN:/vo/user/lfn2",)))) + jobDB._query = MagicMock(return_value=S_OK([(1234, "/vo/user/lfn1"), (1234, "LFN:/vo/user/lfn2")])) # Act res = jobDB.getInputData(1234) diff --git a/src/DIRAC/WorkloadManagementSystem/Service/JobMonitoringHandler.py b/src/DIRAC/WorkloadManagementSystem/Service/JobMonitoringHandler.py index f97667fd460..5c05740ff1c 100755 --- a/src/DIRAC/WorkloadManagementSystem/Service/JobMonitoringHandler.py +++ b/src/DIRAC/WorkloadManagementSystem/Service/JobMonitoringHandler.py @@ -514,7 +514,7 @@ def export_getJobHeartBeatData(cls, jobID): return cls.jobDB.getHeartBeatData(jobID) ############################################################################## - types_getInputData = [int] + types_getInputData = [(int, list)] @classmethod def export_getInputData(cls, jobID):