Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,9 @@ def getJobsStates(self, jobIDs):
if res["OK"]:
res["Value"] = strToIntDict(res["Value"])
return res

def getInputData(self, jobIDs):
res = self._getRPC().getInputData(jobIDs)
if res["OK"] and isinstance(res["Value"], dict):
res["Value"] = strToIntDict(res["Value"])
return res
51 changes: 40 additions & 11 deletions src/DIRAC/WorkloadManagementSystem/DB/JobDB.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,26 @@
* *CompressJDLs*: Enable compression of JDLs when they are stored in the database, default *False*.

"""
from __future__ import annotations

import datetime
import operator
from typing import overload

from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getVOForGroup
from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getSiteTier
from DIRAC.Core.Base.DB import DB
from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd
from DIRAC.Core.Utilities.Decorators import deprecated
from DIRAC.Core.Utilities.DErrno import EWMSJMAN, EWMSSUBM, cmpError
from DIRAC.Core.Utilities.ReturnValues import S_ERROR, S_OK, convertToReturnValue, returnValueOrRaise, SErrorException
from DIRAC.Core.Utilities.ReturnValues import (
S_ERROR,
S_OK,
convertToReturnValue,
returnValueOrRaise,
SErrorException,
DReturnType,
)
from DIRAC.FrameworkSystem.Client.Logger import contextLogger
from DIRAC.ResourceStatusSystem.Client.SiteStatus import SiteStatus
from DIRAC.WorkloadManagementSystem.Client import JobMinorStatus, JobStatus
Expand Down Expand Up @@ -320,23 +330,42 @@ def getJobOptParameters(self, jobID, paramList=None):

#############################################################################

def getInputData(self, jobID):
@overload
def getInputData(self, jobID: int | str) -> DReturnType[list[str]]:
...

@overload
def getInputData(self, jobID: list[int | str]) -> DReturnType[dict[int, list[str]]]:
...

def getInputData(self, jobID: int | str | list[int | str]) -> DReturnType[list[str] | dict[int, list[str]]]:
"""Get input data for the given job"""
ret = self._escapeString(jobID)
if not ret["OK"]:
return ret
jobID = ret["Value"]
cmd = f"SELECT LFN FROM InputData WHERE JobID={jobID}"
if isinstance(jobID, (int, str)):
ret = self._escapeString(jobID)
if not ret["OK"]:
return ret
jobID = ret["Value"]
query = f"JobID={jobID}"
result = []
else:
job_ids = {int(i) for i in jobID}
query = f"JobID IN ({','.join(map(str, job_ids))})"
result = {i: [] for i in job_ids}
cmd = f"SELECT JobID, LFN FROM InputData WHERE {query}"
res = self._query(cmd)
if not res["OK"]:
return res

inputData = [i[0] for i in res["Value"] if i[0].strip()]
for index, lfn in enumerate(inputData):
for jid, lfn in res["Value"]:
lfn = lfn.strip()
if lfn.lower().startswith("lfn:"):
inputData[index] = lfn[4:]
lfn = lfn[4:]
if isinstance(result, list):
result.append(lfn)
else:
result[jid].append(lfn)

return S_OK(inputData)
return S_OK(result)

#############################################################################
def setInputData(self, jobID, inputData):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_getInputData(jobDB: JobDB):
"""Test the getInputData method from JobDB"""
# Arrange
jobDB._escapeString = MagicMock(return_value=S_OK())
jobDB._query = MagicMock(return_value=S_OK((("/vo/user/lfn1",), ("LFN:/vo/user/lfn2",))))
jobDB._query = MagicMock(return_value=S_OK([(1234, "/vo/user/lfn1"), (1234, "LFN:/vo/user/lfn2")]))

# Act
res = jobDB.getInputData(1234)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def export_getJobHeartBeatData(cls, jobID):
return cls.jobDB.getHeartBeatData(jobID)

##############################################################################
types_getInputData = [int]
types_getInputData = [(int, list)]

@classmethod
def export_getInputData(cls, jobID):
Expand Down
Loading