"""Transformation classes around the JDL format."""

from diraccfg import CFG
from pydantic import ValidationError

from DIRACCommon.Core.Utilities.ReturnValues import S_OK, S_ERROR
from DIRACCommon.Core.Utilities import List
from DIRACCommon.Core.Utilities.ClassAd.ClassAdLight import ClassAd
from DIRACCommon.WorkloadManagementSystem.Utilities.JobModel import BaseJobDescriptionModel

# Canonical JDL attribute names, used so the rest of the code never spells
# the raw strings directly.
ARGUMENTS = "Arguments"
BANNED_SITES = "BannedSites"
CPU_TIME = "CPUTime"
EXECUTABLE = "Executable"
EXECUTION_ENVIRONMENT = "ExecutionEnvironment"
GRID_CE = "GridCE"
INPUT_DATA = "InputData"
INPUT_DATA_POLICY = "InputDataPolicy"
INPUT_SANDBOX = "InputSandbox"
JOB_CONFIG_ARGS = "JobConfigArgs"
JOB_TYPE = "JobType"
JOB_GROUP = "JobGroup"
LOG_LEVEL = "LogLevel"
NUMBER_OF_PROCESSORS = "NumberOfProcessors"
MAX_NUMBER_OF_PROCESSORS = "MaxNumberOfProcessors"
MIN_NUMBER_OF_PROCESSORS = "MinNumberOfProcessors"
OUTPUT_DATA = "OutputData"
OUTPUT_PATH = "OutputPath"
OUTPUT_SE = "OutputSE"
PLATFORM = "Platform"
PRIORITY = "Priority"
STD_ERROR = "StdError"
STD_OUTPUT = "StdOutput"
OUTPUT_SANDBOX = "OutputSandbox"
JOB_NAME = "JobName"
SITE = "Site"
TAGS = "Tags"

OWNER = "Owner"
OWNER_GROUP = "OwnerGroup"
VO = "VirtualOrganization"

# Attributes that identify the submitter rather than describe the job itself.
CREDENTIALS_FIELDS = {OWNER, OWNER_GROUP, VO}


def loadJDLAsCFG(jdl):
    """
    Load a JDL as CFG.

    Parses a JDL string (optionally wrapped in ``[ ... ]``) character by
    character and builds a :class:`diraccfg.CFG` tree: plain
    ``key = value;`` pairs become CFG options, and nested ``key = [ ... ]``
    sub-JDLs become CFG sections (parsed recursively).

    :param str jdl: JDL text to parse
    :return: S_OK((cfg, iPos)) where iPos is the index of the last consumed
             character (used by the recursive sub-JDL calls) || S_ERROR
    """

    def cleanValue(value):
        # Normalise a single scalar value:
        #  - a quoted value may actually be a comma-separated list of quoted
        #    strings; re-join the quoted pieces with ", "
        #  - an unquoted value just has stray double quotes stripped
        value = value.strip()
        if value[0] == '"':
            entries = []
            iPos = 1
            current = ""
            # state "in"  -> currently inside a quoted string
            # state "out" -> between quoted strings (only "," allowed there)
            state = "in"
            while iPos < len(value):
                if value[iPos] == '"':
                    if state == "in":
                        entries.append(current)
                        current = ""
                        state = "out"
                    elif state == "out":
                        current = current.strip()
                        if current not in (",",):
                            return S_ERROR("value seems a list but is not separated in commas")
                        current = ""
                        state = "in"
                else:
                    current += value[iPos]
                iPos += 1
            if state == "in":
                # Reached end of value while still inside an open quote.
                return S_ERROR('value is opened with " but is not closed')
            return S_OK(", ".join(entries))
        else:
            return S_OK(value.replace('"', ""))

    def assignValue(key, value, cfg):
        # Validate and store one key/value pair into cfg.
        # A value starting with "{" is a JDL list: split it, clean each item,
        # and store the items as a single comma-separated CFG option.
        key = key.strip()
        if len(key) == 0:
            return S_ERROR("Invalid key name")
        value = value.strip()
        if not value:
            return S_ERROR(f"No value for key {key}")
        if value[0] == "{":
            if value[-1] != "}":
                return S_ERROR("Value '%s' seems a list but does not end in '}'" % (value))
            valList = List.fromChar(value[1:-1])
            for i in range(len(valList)):
                result = cleanValue(valList[i])
                if not result["OK"]:
                    return S_ERROR(f"Var {key} : {result['Message']}")
                valList[i] = result["Value"]
                if valList[i] is None:
                    return S_ERROR(f"List value '{value}' seems invalid for item {i}")
            value = ", ".join(valList)
        else:
            result = cleanValue(value)
            if not result["OK"]:
                return S_ERROR(f"Var {key} : {result['Message']}")
            nV = result["Value"]
            if nV is None:
                return S_ERROR(f"Value '{value} seems invalid")
            value = nV
        cfg.setOption(key, value)
        return S_OK()

    # Skip the optional opening bracket of the (sub-)JDL.
    if jdl[0] == "[":
        iPos = 1
    else:
        iPos = 0
    key = ""
    value = ""
    # action is "key" while accumulating the attribute name, "value" after "=".
    action = "key"
    # insideLiteral is True while scanning inside a double-quoted string, where
    # the structural characters ; [ = ] must be treated as plain text.
    insideLiteral = False
    cfg = CFG()
    while iPos < len(jdl):
        char = jdl[iPos]
        if char == ";" and not insideLiteral:
            # End of one key = value; statement.
            if key.strip():
                result = assignValue(key, value, cfg)
                if not result["OK"]:
                    return result
            key = ""
            value = ""
            action = "key"
        elif char == "[" and not insideLiteral:
            # Nested sub-JDL: recurse and graft the result as a CFG section.
            key = key.strip()
            if not key:
                return S_ERROR("Invalid key in JDL")
            if value.strip():
                return S_ERROR(f"Key {key} seems to have a value and open a sub JDL at the same time")
            result = loadJDLAsCFG(jdl[iPos:])
            if not result["OK"]:
                return result
            subCfg, subPos = result["Value"]
            cfg.createNewSection(key, contents=subCfg)
            key = ""
            value = ""
            action = "key"
            insideLiteral = False
            # Jump past the characters consumed by the recursive call.
            iPos += subPos
        elif char == "=" and not insideLiteral:
            if action == "key":
                action = "value"
                insideLiteral = False
            else:
                # "=" inside a value (e.g. in arguments) is plain text.
                value += char
        elif char == "]" and not insideLiteral:
            # End of this (sub-)JDL; flush a possibly pending assignment.
            key = key.strip()
            if len(key) > 0:
                result = assignValue(key, value, cfg)
                if not result["OK"]:
                    return result
            return S_OK((cfg, iPos))
        else:
            if action == "key":
                key += char
            else:
                value += char
            if char == '"':
                insideLiteral = not insideLiteral
        iPos += 1

    # Input exhausted without a closing "]" (tolerated for bracket-less JDLs).
    return S_OK((cfg, iPos))


def dumpCFGAsJDL(cfg, level=1, tab=" "):
    """
    Serialise a CFG tree back into JDL text.

    Sections become nested ``[ ... ]`` blocks, multi-valued options become
    ``{ ... }`` lists, and non-numeric scalars are double-quoted.

    :param cfg: diraccfg.CFG object to dump
    :param int level: current nesting depth (controls indentation)
    :param str tab: indentation unit
    :return: str with the JDL representation
    """
    indent = tab * level
    contents = [f"{tab * (level - 1)}["]
    sections = cfg.listSections()

    for key in cfg:
        if key in sections:
            contents.append(f"{indent}{key} =")
            contents.append(f"{dumpCFGAsJDL(cfg[key], level + 1, tab)};")
        else:
            val = List.fromChar(cfg[key])
            # Some attributes are never lists
            if len(val) < 2 or key in [ARGUMENTS, EXECUTABLE, STD_OUTPUT, STD_ERROR]:
                value = cfg[key]
                try:
                    # Only used as a numeric-ness probe: if float() succeeds the
                    # value is emitted unquoted, otherwise the except quotes it.
                    try_value = float(value)
                    contents.append(f"{tab * level}{key} = {value};")
                except Exception:
                    contents.append(f'{tab * level}{key} = "{value}";')
            else:
                contents.append(f"{indent}{key} =")
                contents.append("%s{" % indent)
                for iPos in range(len(val)):
                    try:
                        # Same numeric probe per list item; non-numeric items
                        # are rewritten in place with surrounding quotes.
                        value = float(val[iPos])
                    except Exception:
                        val[iPos] = f'"{val[iPos]}"'
                contents.append(",\n".join([f"{tab * (level + 1)}{value}" for value in val]))
                contents.append("%s};" % indent)
    contents.append(f"{tab * (level - 1)}]")
    return "\n".join(contents)
"""Collection of DIRAC useful list related modules.

By default on Error they return None.
"""
import random
import sys
from typing import Any, TypeVar
from collections.abc import Iterable

T = TypeVar("T")


def uniqueElements(aList: list) -> list:
    """Return the unique elements of *aList*, keeping first-occurrence order.

    :param aList: list possibly containing duplicates
    :return: new list with duplicates removed, order preserved
    """
    # dict.fromkeys (unlike set) preserves insertion order
    return list(dict.fromkeys(aList))


def appendUnique(aList: list, anObject: Any) -> None:
    """Append *anObject* to *aList* only if it is not already present.

    :param aList: list of elements (modified in place)
    :param anObject: object you want to append
    """
    if anObject not in aList:
        aList.append(anObject)


def fromChar(inputString: str, sepChar: str = ","):
    """Generate a list by splitting a string on the required character(s);
    resulting items are stripped and empty items are removed.

    :param inputString: list serialised to string
    :param sepChar: separator (must be a non-empty string)
    :return: list of strings, or None if either argument has a wrong type
    """
    # Guard against wrong types and an empty separator (str.split would raise);
    # by module convention we return None on error instead of raising.
    if not (isinstance(inputString, str) and isinstance(sepChar, str) and sepChar):
        return None
    return [fieldString.strip() for fieldString in inputString.split(sepChar) if fieldString.strip()]


def randomize(aList: Iterable[T]) -> list[T]:
    """Return a randomly shuffled copy of the input.

    :param aList: iterable to permute (left untouched)
    """
    tmpList = list(aList)
    random.shuffle(tmpList)
    return tmpList


def pop(aList, popElement):
    """Pop the first element equal to *popElement* from the list.

    :param aList: list (modified in place)
    :param popElement: element to pop
    :return: the popped element, or None if it is not in the list
    """
    if popElement in aList:
        return aList.pop(aList.index(popElement))


def stringListToString(aList: list) -> str:
    """Serialize a list of strings for use in MySQL queries (each item quoted).

    :param aList: list to be serialized to string for making queries
    """
    return ",".join(f"'{x}'" for x in aList)


def intListToString(aList: list) -> str:
    """Serialize a list of ints for use in MySQL queries.

    :param aList: list to be serialized to string for making queries
    """
    return ",".join(str(x) for x in aList)


def getChunk(aList: list, chunkSize: int):
    """Generator yielding successive chunks of *chunkSize* elements from *aList*.

    :param aList: list to be split
    :param chunkSize: length of one chunk

    Usage:

    >>> for chunk in getChunk( aList, chunkSize=10):
          process( chunk )

    """
    chunkSize = int(chunkSize)
    for i in range(0, len(aList), chunkSize):
        yield aList[i : i + chunkSize]


def breakListIntoChunks(aList: list, chunkSize: int) -> list:
    """Break an iterable into a list of lists of size *chunkSize*.

    :param aList: iterable of elements
    :param chunkSize: length of a single chunk
    :return: list of lists, each of length at most chunkSize
    :raise RuntimeError: if chunkSize is less than 1
    """
    if chunkSize < 1:
        raise RuntimeError("chunkSize cannot be less than 1")
    # Generalised from an explicit tuple of types (set, dict, tuple, dict
    # views): materialise ANY non-list iterable so getChunk can slice it.
    if not isinstance(aList, list):
        aList = list(aList)
    return list(getChunk(aList, chunkSize))


def getIndexInList(anItem: Any, aList: list) -> int:
    """Return the index of *anItem* in *aList*, or sys.maxsize if absent.

    :param anItem: element to look for
    :param aList: list to look into
    :return: the index, or sys.maxsize when the element is not found
    """
    # EAFP: a single .index() call instead of a membership test plus lookup.
    try:
        return aList.index(anItem)
    except ValueError:
        return sys.maxsize
""" StateMachine

    This module contains the basic blocks to build a state machine (State and StateMachine)
"""
from DIRACCommon.Core.Utilities.ReturnValues import S_OK, S_ERROR


class State:
    """
    State class that represents a single step on a StateMachine, with all the
    possible transitions, the default transition and an ordering level.


    examples:
      >>> s0 = State(100)
      >>> s1 = State(0, ['StateName1', 'StateName2'], defState='StateName1')
      >>> s2 = State(0, ['StateName1', 'StateName2'])
      # this example is tricky. The transition rule says that will go to
      # nextState, e.g. 'StateNext'. But, it is not on the stateMap, and there
      # is no default defined, so it will end up going to StateNext anyway. You
      # must be careful while defining states and their stateMaps and defaults.
    """

    def __init__(self, level, stateMap=None, defState=None):
        """
        :param int level: each state is mapped to an integer, which is used to sort the states according to that integer.
        :param list stateMap: it is a list (of strings) with the reachable states from this particular status.
            If not defined, we assume there are no restrictions.
        :param str defState: default state used in case the next state is not in stateMap (not defined or simply not there).
        """

        self.level = level
        # An empty stateMap means "no movement restrictions" (see transitionRule).
        self.stateMap = stateMap if stateMap else []
        self.default = defState

    def transitionRule(self, nextState):
        """
        Method that selects the next state, knowing the default and the
        transitions map, and the proposed next state. If nextState is in
        stateMap, goes there. If not, goes to the default state if one is
        defined. Otherwise, goes to nextState anyway.

        examples:
          >>> s0.transitionRule('nextState')
              'nextState'
          >>> s1.transitionRule('StateName2')
              'StateName2'
          >>> s1.transitionRule('StateNameNotInMap')
              'StateName1'
          >>> s2.transitionRule('StateNameNotInMap')
              'StateNameNotInMap'

        :param str nextState: name of the state in the stateMap
        :return: state name
        :rtype: str
        """

        # If next state is on the list of next states, go ahead.
        if nextState in self.stateMap:
            return nextState

        # If not, calculate defaultState:
        #   if there is a default, that one
        #   otherwise is nextState (states with empty list have no movement restrictions)
        defaultNext = self.default if self.default else nextState
        return defaultNext


class StateMachine:
    """
    StateMachine class that represents the whole state machine with all transitions.

    examples:
      >>> sm0 = StateMachine()
      >>> sm1 = StateMachine(state = 'Active')

    :param state: current state of the StateMachine, could be None if we do not use the
        StateMachine to calculate transitions. Beware, it is not checked if the
        state is on the states map !
    :type state: None or str

    """

    def __init__(self, state=None):
        """
        Constructor.
        """

        self.state = state
        # To be overwritten by child classes, unless you like Nirvana state that much.
        self.states = {"Nirvana": State(100)}

    def getLevelOfState(self, state):
        """
        Given a state name, it returns its level (integer), which defines the hierarchy.

        >>> sm0.getLevelOfState('Nirvana')
            100
        >>> sm0.getLevelOfState('AnotherState')
            -1

        :param str state: name of the state, it should be on the states key set
        :return: `int` || -1 (if not in self.states)
        """

        if state not in self.states:
            return -1
        return self.states[state].level

    def setState(self, candidateState, noWarn=False, *, logger_warn=None):
        """Makes sure the state is either None or known to the machine, and that it is a valid state to move into.
        Final states are also checked.

        examples:
          >>> sm0.setState(None)['OK']
              True
          >>> sm0.setState('Nirvana')['OK']
              True
          >>> sm0.setState('AnotherState')['OK']
              False

        :param candidateState: state which will be set as current state of the StateMachine
        :type candidateState: None or str
        :param bool noWarn: suppress the "final state" warning
        :param logger_warn: optional callable used to emit warnings (keyword-only)
        :return: S_OK || S_ERROR
        """
        if candidateState == self.state:
            # No-op transition: already there.
            return S_OK(candidateState)

        if not candidateState:
            # None/empty resets the machine's current state.
            self.state = candidateState
        elif candidateState in self.states:
            # NOTE(review): this branch reads self.states[self.state], so it
            # presumably assumes self.state is a known state here — confirm
            # callers never reach this with self.state set to None.
            if not self.states[self.state].stateMap:
                # Empty stateMap on the *current* state is treated here as a
                # final state: refuse to move.
                if not noWarn and logger_warn:
                    logger_warn("Final state, won't move", f"({self.state}, asked to move to {candidateState})")
                return S_OK(self.state)
            if candidateState not in self.states[self.state].stateMap and logger_warn:
                logger_warn(f"Can't move from {self.state} to {candidateState}, choosing a good one")
            result = self.getNextState(candidateState)
            if not result["OK"]:
                return result
            self.state = result["Value"]
            # If the StateMachine does not accept the candidate, return error message
        else:
            return S_ERROR(f"setState: {candidateState!r} is not a valid state")

        return S_OK(self.state)

    def getStates(self):
        """
        Returns all possible states in the state map

        examples:
          >>> sm0.getStates()
              [ 'Nirvana' ]

        :return: list(stateNames)
        """

        return list(self.states)

    def getNextState(self, candidateState):
        """
        Method that gets the next state, given the proposed transition to
        candidateState. If candidateState is not on the state map, it is
        rejected. Otherwise, we have two options: if the current state is
        None, the next state is candidateState. Otherwise, the current state
        applies its own transition rule to decide.

        examples:
          >>> sm0.getNextState(None)
              S_OK(None)
          >>> sm0.getNextState('NextState')
              S_OK('NextState')

        :param str candidateState: name of the next state
        :return: S_OK(nextState) || S_ERROR
        """
        if candidateState not in self.states:
            return S_ERROR(f"getNextState: {candidateState!r} is not a valid state")

        # FIXME: do we need this anymore ?
        if self.state is None:
            return S_OK(candidateState)

        return S_OK(self.states[self.state].transitionRule(candidateState))
"""
DIRAC TimeUtilities module
Support for basic Date and Time operations
based on system datetime module.

It provides common interface to UTC timestamps,
converter to string types and back.

Useful timedelta constant are also provided to
define time intervals.

Notice: datetime.timedelta objects allow multiplication and division by integer
but not by float. Thus:

  - DIRAC.TimeUtilities.second * 1.5 is not allowed
  - DIRAC.TimeUtilities.second * 3 / 2 is allowed

A timeInterval class provides a method to check
if a given datetime is in the defined interval.

"""
import datetime
import sys
import time

# Some useful constants for time operations
microsecond = datetime.timedelta(microseconds=1)
second = datetime.timedelta(seconds=1)
minute = datetime.timedelta(minutes=1)
hour = datetime.timedelta(hours=1)
day = datetime.timedelta(days=1)
week = datetime.timedelta(days=7)


def timeThis(method, *, logger_info=None):
    """Function to be used as a decorator for timing other functions/methods.

    :param method: callable to wrap
    :param logger_info: optional callable used to emit the timing line
        (keyword-only); if None, nothing is logged
    """

    def timed(*args, **kw):
        """What actually times"""
        ts = time.time()
        result = method(*args, **kw)
        if sys.stdout.isatty():
            # Interactive session: skip timing output entirely.
            return result
        te = time.time()

        # datetime.timezone.utc instead of datetime.UTC: same object on 3.11+,
        # but also works on older interpreters (consistent with the rest of
        # this module, which still uses utcnow()).
        pre = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC ")

        # Best-effort prefix using attributes of the bound object (args[0]),
        # degrading gracefully when .log / .transString are absent.
        try:
            pre += args[0].log.getName() + "/" + args[0].log.getSubName() + " TIME: " + args[0].transString
        except AttributeError:
            try:
                pre += args[0].log.getName() + " TIME: " + args[0].transString
            except AttributeError:
                try:
                    pre += args[0].log.getName() + "/" + args[0].log.getSubName() + " TIME: "
                except AttributeError:
                    pre += "TIME: "
        except IndexError:
            pre += "TIME: "

        # Report the size of the first "real" argument when it is a collection.
        argsLen = ""
        if args:
            try:
                if isinstance(args[1], (list, dict)):
                    argsLen = f"arguments len: {len(args[1])}"
            except IndexError:
                if kw:
                    try:
                        if isinstance(list(list(kw.items())[0])[1], (list, dict)):
                            argsLen = f"arguments len: {len(list(list(kw.items())[0])[1])}"
                    except IndexError:
                        argsLen = ""

        if logger_info is not None:
            logger_info(f"{pre} Exec time ===> function {method.__name__!r} {argsLen} -> {te - ts:2.2f} sec")
        return result

    return timed


def toEpoch(dateTimeObject=None):
    """
    Get seconds since epoch. Accepts datetime or date objects.

    :param dateTimeObject: datetime/date to convert; current UTC time if None
    :return: int seconds since the Unix epoch
    """
    return toEpochMilliSeconds(dateTimeObject) // 1000


def toEpochMilliSeconds(dateTimeObject=None):
    """
    Get milliseconds since epoch.

    :param dateTimeObject: datetime/date to convert; current UTC time if None
    :return: int milliseconds since the Unix epoch
    """
    if dateTimeObject is None:
        dateTimeObject = datetime.datetime.utcnow()
    if dateTimeObject.resolution == datetime.timedelta(days=1):
        # Add time information corresponding to midnight UTC if it's a datetime.date
        dateTimeObject = datetime.datetime.combine(
            dateTimeObject, datetime.time.min.replace(tzinfo=datetime.timezone.utc)
        )
    posixTime = dateTimeObject.replace(tzinfo=datetime.timezone.utc).timestamp()
    return int(posixTime * 1000)


def fromEpoch(epoch):
    """
    Get naive UTC datetime object from an epoch timestamp, auto-detecting
    whether it is expressed in seconds, milli-, micro- or nanoseconds.

    :param epoch: numeric timestamp
    :return: datetime.datetime
    """
    # Heuristic unit detection based on magnitude.
    if epoch > 10**17:  # nanoseconds
        epoch /= 1000**3
    elif epoch > 10**14:  # microseconds
        epoch /= 1000**2
    elif epoch > 10**11:  # milliseconds
        epoch /= 1000
    return datetime.datetime.utcfromtimestamp(epoch)


def toString(myDate=None):
    """
    Convert to String.
    If the argument is neither a date/datetime nor a timedelta,
    the current dateTime converted to String is returned instead.

    Notice: datetime.timedelta are converted to strings using the format:
      [hour]:[min]:[sec]:[microsec]
    where min, sec, microsec are always positive integers and hour carries the sign.
    """
    if isinstance(myDate, datetime.date):
        # datetime.datetime is a subclass of datetime.date, so this branch
        # covers both dates and datetimes.
        return str(myDate)

    elif isinstance(myDate, datetime.timedelta):
        # BUGFIX: this branch previously tested datetime.time, but it reads
        # .days/.seconds/.microseconds which only exist on timedelta, so it
        # unconditionally raised AttributeError for time objects. The
        # docstring above describes timedelta formatting.
        return "%02d:%02d:%02d.%06d" % (
            myDate.days * 24 + myDate.seconds / 3600,
            myDate.seconds % 3600 / 60,
            myDate.seconds % 60,
            myDate.microseconds,
        )
    else:
        return toString(datetime.datetime.utcnow())


def fromString(myDate=None):
    """
    Convert date/time/datetime String back to the appropriate objects.

    The format of the string is assumed to be that returned by the toString
    method. See the notice on the toString method.
    On Error, return None.

    :param myDate: the date string to be converted
    :type myDate: str or datetime.datetime
    """
    if isinstance(myDate, datetime.datetime):
        return myDate
    if isinstance(myDate, str):
        if myDate.find(" ") > 0:
            # "<date> <time>" -> parse the date part, then add the time part
            # (recursively parsed as a timedelta).
            dateTimeTuple = myDate.split(" ")
            dateTuple = dateTimeTuple[0].split("-")
            try:
                # Previously a first attempt passed the raw strings to
                # datetime.datetime(), which always raised TypeError; go
                # straight to the int() conversion.
                return datetime.datetime(
                    year=int(dateTuple[0]), month=int(dateTuple[1]), day=int(dateTuple[2])
                ) + fromString(dateTimeTuple[1])
            except (ValueError, TypeError):
                # TypeError also covers fromString() returning None for a bad
                # time part (None cannot be added to a datetime).
                return None
        elif myDate.find(":") > 0:
            # Time-of-interval string "[hh]:[mm]:[ss].[us]" -> timedelta.
            timeTuple = myDate.replace(".", ":").split(":")
            try:
                if len(timeTuple) == 4:
                    return datetime.timedelta(
                        hours=int(timeTuple[0]),
                        minutes=int(timeTuple[1]),
                        seconds=int(timeTuple[2]),
                        microseconds=int(timeTuple[3]),
                    )
                elif len(timeTuple) == 3:
                    return datetime.timedelta(
                        hours=int(timeTuple[0]),
                        minutes=int(timeTuple[1]),
                        seconds=int(timeTuple[2]),
                        microseconds=0,
                    )
                else:
                    return None
            except ValueError:
                return None
        elif myDate.find("-") > 0:
            # Plain ISO-like date "YYYY-MM-DD".
            dateTuple = myDate.split("-")
            try:
                return datetime.date(int(dateTuple[0]), int(dateTuple[1]), int(dateTuple[2]))
            except (ValueError, TypeError):
                return None

    return None


class timeInterval:
    """
    Simple class to define a timeInterval object able to check if a given
    dateTime is inside
    """

    def __init__(self, initialDateTime, intervalTimeDelta):
        """
        Initialization method, it requires the initial dateTime and the
        timedelta that define the limits.
        The upper limit is not included thus it is [begin,end)
        If not properly initialized an error flag is set, and subsequent calls
        to any method will return None
        """
        if not isinstance(initialDateTime, datetime.datetime) or not isinstance(intervalTimeDelta, datetime.timedelta):
            self.__error = True
            return
        self.__error = False
        if intervalTimeDelta.days < 0:
            # Negative interval: the given datetime is the END of the window.
            self.__startDateTime = initialDateTime + intervalTimeDelta
            self.__endDateTime = initialDateTime
        else:
            self.__startDateTime = initialDateTime
            self.__endDateTime = initialDateTime + intervalTimeDelta

    def includes(self, myDateTime):
        """Return True/False if *myDateTime* is inside [begin, end);
        None on a mis-initialised interval or a non-datetime argument."""
        if self.__error:
            return None
        if not isinstance(myDateTime, datetime.datetime):
            return None
        if myDateTime < self.__startDateTime:
            return False
        if myDateTime >= self.__endDateTime:
            return False
        return True


def queryTime(f):
    """Decorator to measure the function call time.

    Adds a "QueryTime" key (seconds, float) to a successful S_OK-style result
    dict, unless the callee already set one.
    """

    def measureQueryTime(*args, **kwargs):
        start = time.time()
        result = f(*args, **kwargs)
        if result["OK"] and "QueryTime" not in result:
            result["QueryTime"] = time.time() - start
        return result

    return measureQueryTime
from __future__ import annotations

from typing import Literal, TypedDict

from diraccfg import CFG

from DIRACCommon.Core.Utilities import List
from DIRACCommon.Core.Utilities.JDL import dumpCFGAsJDL, loadJDLAsCFG
from DIRACCommon.Core.Utilities.ReturnValues import S_ERROR, S_OK


class JobManifestNumericalVar(TypedDict):
    # Numerical job attributes that carry per-group default/min/max policies.
    CPUTime: int
    Priority: int


class JobManifestConfig(TypedDict):
    """Dictionary type for defining the information JobManifest needs from the CS"""

    defaultForGroup: JobManifestNumericalVar
    minForGroup: JobManifestNumericalVar
    maxForGroup: JobManifestNumericalVar
    allowedJobTypesForGroup: list[str]

    maxInputData: int


class JobManifest:
    """In-memory representation of a job description, stored as a
    diraccfg.CFG tree and loadable from / dumpable to both JDL and CFG
    formats. A dirty flag tracks unsaved modifications."""

    def __init__(self, manifest=""):
        """
        :param manifest: optional JDL or CFG string to load immediately
        :raises Exception: if the initial manifest fails to parse
        """
        self.__manifest = CFG()
        self.__dirty = False
        if manifest:
            result = self.load(manifest)
            if not result["OK"]:
                raise Exception(result["Message"])

    def isDirty(self):
        """Return True if the manifest has unsaved modifications."""
        return self.__dirty

    def setDirty(self):
        """Mark the manifest as modified."""
        self.__dirty = True

    def clearDirty(self):
        """Mark the manifest as clean (saved)."""
        self.__dirty = False

    def load(self, dataString):
        """
        Auto discover format type based on [ .. ] of JDL

        :param str dataString: manifest in JDL or CFG format
        :return: S_OK || S_ERROR
        """
        dataString = dataString.strip()
        # A JDL is wrapped in square brackets; anything else is treated as CFG.
        if dataString[0] == "[" and dataString[-1] == "]":
            return self.loadJDL(dataString)
        else:
            return self.loadCFG(dataString)

    def loadJDL(self, jdlString):
        """
        Load job manifest from JDL format

        :param str jdlString: JDL text
        :return: S_OK || S_ERROR (on error the manifest is reset to empty)
        """
        result = loadJDLAsCFG(jdlString.strip())
        if not result["OK"]:
            self.__manifest = CFG()
            return result
        self.__manifest = result["Value"][0]
        return S_OK()

    def loadCFG(self, cfgString):
        """
        Load job manifest from CFG format

        :param str cfgString: CFG text
        :return: S_OK || S_ERROR
        """
        try:
            self.__manifest.loadFromBuffer(cfgString)
        except Exception as e:
            return S_ERROR(f"Can't load manifest from cfg: {str(e)}")
        return S_OK()

    def dumpAsCFG(self):
        """Return the manifest serialised as a CFG string."""
        return str(self.__manifest)

    def getAsCFG(self):
        """Return a deep copy of the underlying CFG object."""
        return self.__manifest.clone()

    def dumpAsJDL(self):
        """Return the manifest serialised as a JDL string."""
        return dumpCFGAsJDL(self.__manifest)

    def _checkNumericalVar(
        self,
        varName: Literal["CPUTime", "Priority"],
        defaultVal: int,
        minVal: int,
        maxVal: int,
        config: JobManifestConfig,
    ):
        """
        Check a numerical var: read it (or the per-group default), coerce to
        int, and clamp it into the per-group [min, max] range from *config*.

        NOTE: the minVal/maxVal parameters are immediately overwritten by
        config["minForGroup"]/config["maxForGroup"] below, so only defaultVal
        is actually used from the positional arguments.

        :return: S_OK(clampedValue) || S_ERROR
        """
        if varName in self.__manifest:
            varValue = self.__manifest[varName]
        else:
            varValue = config["defaultForGroup"].get(varName, defaultVal)
        try:
            varValue = int(varValue)
        except ValueError:
            return S_ERROR(f"{varName} must be a number")
        minVal = config["minForGroup"][varName]
        maxVal = config["maxForGroup"][varName]
        # Clamp into [minVal, maxVal].
        varValue = max(minVal, min(varValue, maxVal))
        if varName not in self.__manifest:
            # Persist the computed value so later dumps include it.
            self.__manifest.setOption(varName, varValue)
        return S_OK(varValue)

    def __contains__(self, key):
        """Check if the manifest has the required key"""
        return key in self.__manifest

    def setOptionsFromDict(self, varDict):
        """Set several options at once (sorted by key for determinism)."""
        for k in sorted(varDict):
            self.setOption(k, varDict[k])

    def check(self, *, config: JobManifestConfig):
        """
        Check that the manifest is OK: mandatory credential fields are
        present, CPUTime/Priority are valid numbers within group policy,
        InputData does not exceed the limit, and JobType is allowed.

        :param config: policy values coming from the CS (keyword-only)
        :return: S_OK || S_ERROR with the first problem found
        """
        for k in ["Owner", "OwnerGroup"]:
            if k not in self.__manifest:
                return S_ERROR(f"Missing var {k} in manifest")

        # Check CPUTime
        result = self._checkNumericalVar("CPUTime", 86400, 100, 500000, config=config)
        if not result["OK"]:
            return result

        result = self._checkNumericalVar("Priority", 1, 0, 10, config=config)
        if not result["OK"]:
            return result

        if "InputData" in self.__manifest:
            nInput = len(List.fromChar(self.__manifest["InputData"]))
            if nInput > config["maxInputData"]:
                return S_ERROR(
                    f"Number of Input Data Files ({nInput}) greater than current limit: {config['maxInputData']}"
                )

        if "JobType" in self.__manifest:
            varValue = self.__manifest["JobType"]
            for v in List.fromChar(varValue):
                if v not in config["allowedJobTypesForGroup"]:
                    return S_ERROR(f"{v} is not a valid value for JobType")

        return S_OK()

    def createSection(self, secName, contents=False):
        """Create a new CFG section *secName* (optionally from an existing CFG).

        :return: S_OK(section) || S_ERROR if it already exists / bad contents
        """
        if secName not in self.__manifest:
            if contents and not isinstance(contents, CFG):
                return S_ERROR(f"Contents for section {secName} is not a cfg object")
            self.__dirty = True
            return S_OK(self.__manifest.createNewSection(secName, contents=contents))
        return S_ERROR(f"Section {secName} already exists")

    def getSection(self, secName):
        """Return a non-empty section by name.

        NOTE: this marks the manifest dirty because the returned CFG section
        is mutable and callers may modify it in place.
        """
        self.__dirty = True
        if secName not in self.__manifest:
            return S_ERROR(f"{secName} does not exist")
        sec = self.__manifest[secName]
        if not sec:
            return S_ERROR(f"{secName} section empty")
        return S_OK(sec)

    def setSectionContents(self, secName, contents):
        """Replace (or create) section *secName* with the given CFG contents.

        :return: S_ERROR on bad contents; otherwise None (no S_OK is returned)
        """
        if contents and not isinstance(contents, CFG):
            return S_ERROR(f"Contents for section {secName} is not a cfg object")
        self.__dirty = True
        if secName in self.__manifest:
            self.__manifest[secName].reset()
            self.__manifest[secName].mergeWith(contents)
        else:
            self.__manifest.createNewSection(secName, contents=contents)

    def setOption(self, varName, varValue):
        """
        Set a var in job manifest. "/"-separated names create intermediate
        sections as needed (e.g. "A/B/C" sets option C in section A/B).
        """
        self.__dirty = True
        levels = List.fromChar(varName, "/")
        cfg = self.__manifest
        for l in levels[:-1]:
            if l not in cfg:
                cfg.createNewSection(l)
            cfg = cfg[l]
        cfg.setOption(levels[-1], varValue)

    def remove(self, opName):
        """Delete option *opName* ("/"-separated path supported).

        :return: S_OK || S_ERROR if the path or the option does not exist
        """
        levels = List.fromChar(opName, "/")
        cfg = self.__manifest
        for l in levels[:-1]:
            if l not in cfg:
                return S_ERROR(f"{opName} does not exist")
            cfg = cfg[l]
        if cfg.deleteKey(levels[-1]):
            self.__dirty = True
            return S_OK()
        return S_ERROR(f"{opName} does not exist")

    def getOption(self, varName, defaultValue=None):
        """
        Get a variable from the job manifest
        """
        cfg = self.__manifest
        return cfg.getOption(varName, defaultValue)

    def getOptionList(self, section=""):
        """
        Get a list of variables in a section of the job manifest
        """
        cfg = self.__manifest.getRecursive(section)
        if not cfg or "value" not in cfg:
            return []
        cfg = cfg["value"]
        return cfg.listOptions()

    def isOption(self, opName):
        """
        Check if it is a valid option
        """
        return self.__manifest.isOption(opName)

    def getSectionList(self, section=""):
        """
        Get a list of sections in the job manifest
        """
        cfg = self.__manifest.getRecursive(section)
        if not cfg or "value" not in cfg:
            return []
        cfg = cfg["value"]
        return cfg.listSections()
+""" + +from DIRACCommon.Core.Utilities.StateMachine import State, StateMachine + +#: +SUBMITTING = "Submitting" +#: +RECEIVED = "Received" +#: +CHECKING = "Checking" +#: +STAGING = "Staging" +#: +SCOUTING = "Scouting" +#: +WAITING = "Waiting" +#: +MATCHED = "Matched" +#: The Rescheduled status is effectively never stored in the DB. +#: It could be considered a "virtual" status, and might even be dropped. +RESCHEDULED = "Rescheduled" +#: +RUNNING = "Running" +#: +STALLED = "Stalled" +#: +COMPLETING = "Completing" +#: +DONE = "Done" +#: +COMPLETED = "Completed" +#: +FAILED = "Failed" +#: +DELETED = "Deleted" +#: +KILLED = "Killed" + +#: Possible job states +JOB_STATES = [ + SUBMITTING, + RECEIVED, + CHECKING, + SCOUTING, + STAGING, + WAITING, + MATCHED, + RESCHEDULED, + RUNNING, + STALLED, + COMPLETING, + DONE, + COMPLETED, + FAILED, + DELETED, + KILLED, +] + +# Job States when the payload work has finished +JOB_FINAL_STATES = [DONE, COMPLETED, FAILED, KILLED] + +# WMS internal job States indicating the job object won't be updated +JOB_REALLY_FINAL_STATES = [DELETED] + + +class JobsStateMachine(StateMachine): + """Jobs state machine""" + + def __init__(self, state): + """c'tor + Defines the state machine transactions + """ + super().__init__(state) + + # States transitions + self.states = { + DELETED: State(15), # final state + KILLED: State(14, [DELETED], defState=KILLED), + FAILED: State(13, [RESCHEDULED, DELETED], defState=FAILED), + DONE: State(12, [DELETED], defState=DONE), + COMPLETED: State(11, [DONE, FAILED], defState=COMPLETED), + COMPLETING: State(10, [DONE, FAILED, COMPLETED, STALLED, KILLED], defState=COMPLETING), + STALLED: State(9, [RUNNING, FAILED, KILLED], defState=STALLED), + RUNNING: State(8, [STALLED, DONE, FAILED, RESCHEDULED, COMPLETING, KILLED, RECEIVED], defState=RUNNING), + RESCHEDULED: State(7, [WAITING, RECEIVED, DELETED, FAILED, KILLED], defState=RESCHEDULED), + MATCHED: State(6, [RUNNING, FAILED, RESCHEDULED, KILLED], defState=MATCHED), + 
WAITING: State(5, [MATCHED, RESCHEDULED, DELETED, KILLED], defState=WAITING), + STAGING: State(4, [CHECKING, WAITING, FAILED, KILLED], defState=STAGING), + SCOUTING: State(3, [CHECKING, FAILED, STALLED, KILLED], defState=SCOUTING), + CHECKING: State(2, [SCOUTING, STAGING, WAITING, RESCHEDULED, FAILED, DELETED, KILLED], defState=CHECKING), + RECEIVED: State(1, [SCOUTING, CHECKING, STAGING, WAITING, FAILED, DELETED, KILLED], defState=RECEIVED), + SUBMITTING: State(0, [RECEIVED, CHECKING, DELETED, KILLED], defState=SUBMITTING), # initial state + } diff --git a/dirac-common/src/DIRACCommon/WorkloadManagementSystem/Client/__init__.py b/dirac-common/src/DIRACCommon/WorkloadManagementSystem/Client/__init__.py new file mode 100644 index 00000000000..3de496ff9e0 --- /dev/null +++ b/dirac-common/src/DIRACCommon/WorkloadManagementSystem/Client/__init__.py @@ -0,0 +1 @@ +"""DIRACCommon WorkloadManagementSystem client utilities""" diff --git a/dirac-common/src/DIRACCommon/WorkloadManagementSystem/DB/JobDBUtils.py b/dirac-common/src/DIRACCommon/WorkloadManagementSystem/DB/JobDBUtils.py index bfca1a587dc..eccdba2a42e 100644 --- a/dirac-common/src/DIRACCommon/WorkloadManagementSystem/DB/JobDBUtils.py +++ b/dirac-common/src/DIRACCommon/WorkloadManagementSystem/DB/JobDBUtils.py @@ -4,6 +4,12 @@ import base64 import zlib +from typing import TypedDict + +from DIRACCommon.Core.Utilities.ReturnValues import S_OK, DOKReturnType, S_ERROR +from DIRACCommon.Core.Utilities.DErrno import EWMSSUBM +from DIRACCommon.WorkloadManagementSystem.Client import JobStatus +from DIRACCommon.WorkloadManagementSystem.Client.JobState.JobManifest import JobManifest, JobManifestConfig def compressJDL(jdl): @@ -31,3 +37,134 @@ def fixJDL(jdl: str) -> str: if jdl.strip()[0].find("[") != 0: jdl = "[" + jdl + "]" return jdl + + +class CheckAndPrepareJobConfig(TypedDict): + """Dictionary type for defining the information checkAndPrepareJob needs from the CS""" + + inputDataPolicyForVO: str + 
softwareDistModuleForVO: str + defaultCPUTimeForOwnerGroup: int + getDIRACPlatform: callable[list[str], DOKReturnType[list[str]]] + + +def checkAndPrepareJob( + jobID, classAdJob, classAdReq, owner, ownerGroup, jobAttrs, vo, *, config: CheckAndPrepareJobConfig +): + error = "" + + jdlOwner = classAdJob.getAttributeString("Owner") + jdlOwnerGroup = classAdJob.getAttributeString("OwnerGroup") + jdlVO = classAdJob.getAttributeString("VirtualOrganization") + + # The below is commented out since this is always overwritten by the submitter IDs + # but the check allows to findout inconsistent client environments + if jdlOwner and jdlOwner != owner: + error = "Wrong Owner in JDL" + elif jdlOwnerGroup and jdlOwnerGroup != ownerGroup: + error = "Wrong Owner Group in JDL" + elif jdlVO and jdlVO != vo: + error = "Wrong Virtual Organization in JDL" + + classAdJob.insertAttributeString("Owner", owner) + classAdJob.insertAttributeString("OwnerGroup", ownerGroup) + + if vo: + classAdJob.insertAttributeString("VirtualOrganization", vo) + + classAdReq.insertAttributeString("Owner", owner) + classAdReq.insertAttributeString("OwnerGroup", ownerGroup) + if vo: + classAdReq.insertAttributeString("VirtualOrganization", vo) + + if config["inputDataPolicyForVO"] and not classAdJob.lookupAttribute("InputDataModule"): + classAdJob.insertAttributeString("InputDataModule", config["inputDataPolicyForVO"]) + + if config["softwareDistModuleForVO"] and not classAdJob.lookupAttribute("SoftwareDistModule"): + classAdJob.insertAttributeString("SoftwareDistModule", config["softwareDistModuleForVO"]) + + # priority + priority = classAdJob.getAttributeInt("Priority") + if priority is None: + priority = 0 + classAdReq.insertAttributeInt("UserPriority", priority) + + # CPU time + cpuTime = classAdJob.getAttributeInt("CPUTime") + if cpuTime is None: + cpuTime = config["defaultCPUTimeForOwnerGroup"] + classAdReq.insertAttributeInt("CPUTime", cpuTime) + + # platform(s) + platformList = 
classAdJob.getListFromExpression("Platform") + if platformList: + result = config["getDIRACPlatform"](platformList) + if not result["OK"]: + return result + if result["Value"]: + classAdReq.insertAttributeVectorString("Platforms", result["Value"]) + else: + error = "OS compatibility info not found" + if error: + retVal = S_ERROR(EWMSSUBM, error) + retVal["JobId"] = jobID + retVal["Status"] = JobStatus.FAILED + retVal["MinorStatus"] = error + + jobAttrs["Status"] = JobStatus.FAILED + + jobAttrs["MinorStatus"] = error + return retVal + return S_OK() + + +def checkAndAddOwner( + jdl: str, owner: str, ownerGroup: str, *, job_manifest_config: JobManifestConfig +) -> DOKReturnType[JobManifest]: + jobManifest = JobManifest() + res = jobManifest.load(jdl) + if not res["OK"]: + return res + + jobManifest.setOptionsFromDict({"Owner": owner, "OwnerGroup": ownerGroup}) + res = jobManifest.check(config=job_manifest_config) + if not res["OK"]: + return res + + return S_OK(jobManifest) + + +def createJDLWithInitialStatus( + classAdJob, classAdReq, jdl2DBParameters, jobAttrs, initialStatus, initialMinorStatus, *, modern=False +): + """ + :param modern: if True, store boolean instead of string for VerifiedFlag (used by diracx only) + """ + priority = classAdJob.getAttributeInt("Priority") + if priority is None: + priority = 0 + jobAttrs["UserPriority"] = priority + + for jdlName in jdl2DBParameters: + # Defaults are set by the DB. 
+ jdlValue = classAdJob.getAttributeString(jdlName) + if jdlValue: + jobAttrs[jdlName] = jdlValue + + jdlValue = classAdJob.getAttributeString("Site") + if jdlValue: + if jdlValue.find(",") != -1: + jobAttrs["Site"] = "Multiple" + else: + jobAttrs["Site"] = jdlValue + + jobAttrs["VerifiedFlag"] = True if modern else "True" + + jobAttrs["Status"] = initialStatus + + jobAttrs["MinorStatus"] = initialMinorStatus + + reqJDL = classAdReq.asJDL() + classAdJob.insertAttributeInt("JobRequirements", reqJDL) + + return classAdJob.asJDL() diff --git a/dirac-common/src/DIRACCommon/WorkloadManagementSystem/Utilities/JobModel.py b/dirac-common/src/DIRACCommon/WorkloadManagementSystem/Utilities/JobModel.py new file mode 100644 index 00000000000..6e52a3c2874 --- /dev/null +++ b/dirac-common/src/DIRACCommon/WorkloadManagementSystem/Utilities/JobModel.py @@ -0,0 +1,236 @@ +""" This module contains the JobModel class, which is used to validate the job description """ + +# pylint: disable=no-self-argument, no-self-use, invalid-name, missing-function-docstring + +from collections.abc import Iterable +from typing import Annotated, Any, Callable, ClassVar, Self, TypeAlias, TypedDict + +from pydantic import BaseModel, BeforeValidator, ConfigDict, field_validator, model_validator + +from DIRACCommon.Core.Utilities.ReturnValues import DErrorReturnType + + +# HACK: Convert appropriate iterables into sets +def default_set_validator(value): + if value is None: + return set() + elif not isinstance(value, Iterable): + return value + elif isinstance(value, (str, bytes, bytearray)): + return value + else: + return set(value) + + +CoercibleSetStr: TypeAlias = Annotated[set[str], BeforeValidator(default_set_validator)] + + +class BaseJobDescriptionModelConfg(TypedDict): + """Dictionary type for defining the information JobDescriptionModel needs from the CS""" + + # Default values + cpuTime: int + priority: int + # Bounds + minCPUTime: int + maxCPUTime: int + allowedJobTypes: list[str] + 
maxInputDataFiles: int + minNumberOfProcessors: int + maxNumberOfProcessors: int + minPriority: int + maxPriority: int + possibleLogLevels: list[str] + sites: DErrorReturnType[list[str]] + + +class BaseJobDescriptionModel(BaseModel): + """Base model for the job description (not parametric)""" + + model_config = ConfigDict(validate_assignment=True) + + # This must be overridden in subclasses + _config_builder: ClassVar[Callable[[], BaseJobDescriptionModelConfg] | None] = None + + @model_validator(mode="before") + def injectDefaultValues(cls, values: dict[str, Any]) -> dict[str, Any]: + if cls._config_builder is None: + raise NotImplementedError("You must define a _config_builder class attribute") + config = cls._config_builder() + values.setdefault("cpuTime", config["cpuTime"]) + values.setdefault("priority", config["priority"]) + return values + + arguments: str = "" + bannedSites: CoercibleSetStr = set() + # TODO: This should use a field factory + cpuTime: int + executable: str + executionEnvironment: dict = None + gridCE: str = "" + inputSandbox: CoercibleSetStr = set() + inputData: CoercibleSetStr = set() + inputDataPolicy: str = "" + jobConfigArgs: str = "" + jobGroup: str = "" + jobType: str = "User" + jobName: str = "Name" + # TODO: This should be an StrEnum + logLevel: str = "INFO" + # TODO: This can't be None with this type hint + maxNumberOfProcessors: int = None + minNumberOfProcessors: int = 1 + outputData: CoercibleSetStr = set() + outputPath: str = "" + outputSandbox: CoercibleSetStr = set() + outputSE: str = "" + platform: str = "" + # TODO: This should use a field factory + priority: int + sites: CoercibleSetStr = set() + stderr: str = "std.err" + stdout: str = "std.out" + tags: CoercibleSetStr = set() + extraFields: dict[str, Any] = {} + + @field_validator("cpuTime") + def checkCPUTimeBounds(cls, v): + minCPUTime = cls._config_builder()["minCPUTime"] + maxCPUTime = cls._config_builder()["maxCPUTime"] + if not minCPUTime <= v <= maxCPUTime: + raise 
ValueError(f"cpuTime out of bounds (must be between {minCPUTime} and {maxCPUTime})") + return v + + @field_validator("executable") + def checkExecutableIsNotAnEmptyString(cls, v: str): + if not v: + raise ValueError("executable must not be an empty string") + return v + + @field_validator("jobType") + def checkJobTypeIsAllowed(cls, v: str): + allowedTypes = cls._config_builder()["allowedJobTypes"] + if v not in allowedTypes: + raise ValueError(f"jobType '{v}' is not allowed for this kind of user (must be in {allowedTypes})") + return v + + @field_validator("inputData") + def checkInputDataDoesntContainDoubleSlashes(cls, v): + if v: + for lfn in v: + if lfn.find("//") > -1: + raise ValueError("Input data contains //") + return v + + @field_validator("inputData") + def addLFNPrefixIfStringStartsWithASlash(cls, v: set[str]): + if v: + v = {lfn.strip() for lfn in v if lfn.strip()} + v = {f"LFN:{lfn}" if lfn.startswith("/") else lfn for lfn in v} + + for lfn in v: + if not lfn.startswith("LFN:/"): + raise ValueError("Input data files must start with LFN:/") + return v + + @model_validator(mode="after") + def checkNumberOfInputDataFiles(self) -> Self: + if self.inputData: + maxInputDataFiles = self._config_builder()["maxInputDataFiles"] + if self.jobType == "User" and len(self.inputData) >= maxInputDataFiles: + raise ValueError(f"inputData contains too many files (must contain at most {maxInputDataFiles})") + return self + + @field_validator("inputSandbox") + def checkLFNSandboxesAreWellFormated(cls, v: set[str]): + for inputSandbox in v: + if inputSandbox.startswith("LFN:") and not inputSandbox.startswith("LFN:/"): + raise ValueError("LFN files must start by LFN:/") + return v + + @field_validator("logLevel") + def checkLogLevelIsValid(cls, v: str): + v = v.upper() + possibleLogLevels = cls._config_builder()["possibleLogLevels"] + if v not in possibleLogLevels: + raise ValueError(f"Log level {v} not in {possibleLogLevels}") + return v + + 
@field_validator("minNumberOfProcessors") + def checkMinNumberOfProcessorsBounds(cls, v): + minNumberOfProcessors = cls._config_builder()["minNumberOfProcessors"] + maxNumberOfProcessors = cls._config_builder()["maxNumberOfProcessors"] + if not minNumberOfProcessors <= v <= maxNumberOfProcessors: + raise ValueError( + f"minNumberOfProcessors out of bounds (must be between {minNumberOfProcessors} and {maxNumberOfProcessors})" + ) + return v + + @field_validator("maxNumberOfProcessors") + def checkMaxNumberOfProcessorsBounds(cls, v): + minNumberOfProcessors = cls._config_builder()["minNumberOfProcessors"] + maxNumberOfProcessors = cls._config_builder()["maxNumberOfProcessors"] + if not minNumberOfProcessors <= v <= maxNumberOfProcessors: + raise ValueError( + f"maxNumberOfProcessors out of bounds (must be between {minNumberOfProcessors} and {maxNumberOfProcessors})" + ) + return v + + @model_validator(mode="after") + def checkThatMaxNumberOfProcessorsIsGreaterThanMinNumberOfProcessors(self) -> Self: + if self.maxNumberOfProcessors: + if self.maxNumberOfProcessors < self.minNumberOfProcessors: + raise ValueError("maxNumberOfProcessors must be greater than minNumberOfProcessors") + return self + + @model_validator(mode="after") + def addTagsDependingOnNumberOfProcessors(self) -> Self: + if self.minNumberOfProcessors == self.maxNumberOfProcessors: + self.tags.add(f"{self.minNumberOfProcessors}Processors") + if self.minNumberOfProcessors > 1: + self.tags.add("MultiProcessor") + return self + + @field_validator("sites") + def checkSites(cls, v: set[str]): + if v: + res = cls._config_builder()["sites"] + if not res["OK"]: + raise ValueError(res["Message"]) + invalidSites = v - set(res["Value"]).union({"ANY"}) + if invalidSites: + raise ValueError(f"Invalid sites: {' '.join(invalidSites)}") + return v + + @model_validator(mode="after") + def checkThatSitesAndBannedSitesAreNotMutuallyExclusive(self) -> Self: + if self.sites and self.bannedSites: + while self.bannedSites: + 
self.sites.discard(self.bannedSites.pop()) + if not self.sites: + raise ValueError("sites and bannedSites are mutually exclusive") + return self + + @field_validator("priority") + def checkPriorityBounds(cls, v): + minPriority = cls._config_builder()["minPriority"] + maxPriority = cls._config_builder()["maxPriority"] + if not minPriority <= v <= maxPriority: + raise ValueError(f"priority out of bounds (must be between {minPriority} and {maxPriority})") + return v + + +class JobDescriptionModel(BaseJobDescriptionModel): + """Model for the job description (non parametric job with user credentials, i.e server side)""" + + owner: str + ownerGroup: str + vo: str + + @model_validator(mode="after") + def checkLFNMatchesREGEX(self) -> Self: + if self.inputData: + for lfn in self.inputData: + if not lfn.startswith(f"LFN:/{self.vo}/"): + raise ValueError(f"Input data not correctly specified (must start with LFN:/{self.vo}/)") + return self diff --git a/dirac-common/src/DIRACCommon/WorkloadManagementSystem/Utilities/JobStatusUtility.py b/dirac-common/src/DIRACCommon/WorkloadManagementSystem/Utilities/JobStatusUtility.py new file mode 100644 index 00000000000..51c919adb45 --- /dev/null +++ b/dirac-common/src/DIRACCommon/WorkloadManagementSystem/Utilities/JobStatusUtility.py @@ -0,0 +1,93 @@ +"""Stateless job status utility functions""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from DIRACCommon.Core.Utilities.ReturnValues import S_OK +from DIRACCommon.Core.Utilities.TimeUtilities import toEpoch, fromString +from DIRACCommon.WorkloadManagementSystem.Client.JobStatus import RUNNING, JOB_FINAL_STATES, JobsStateMachine + + +def getStartAndEndTime(startTime, endTime, updateTimes, timeStamps, statusDict): + """Get start and end times from job status updates + + :param startTime: current start time + :param endTime: current end time + :param updateTimes: list of update times + :param timeStamps: list of (timestamp, status) tuples 
+ :param statusDict: dictionary mapping update times to status dictionaries + :return: tuple of (newStartTime, newEndTime) + """ + newStat = "" + firstUpdate = toEpoch(fromString(updateTimes[0])) + for ts, st in timeStamps: + if firstUpdate >= ts: + newStat = st + # Pick up start and end times from all updates + for updTime in updateTimes: + sDict = statusDict[updTime] + newStat = sDict.get("Status", newStat) + + if not startTime and newStat == RUNNING: + # Pick up the start date when the job starts running if not existing + startTime = updTime + elif not endTime and newStat in JOB_FINAL_STATES: + # Pick up the end time when the job is in a final status + endTime = updTime + + return startTime, endTime + + +def getNewStatus( + jobID: int, + updateTimes: list[datetime], + lastTime: datetime, + statusDict: dict[datetime, Any], + currentStatus, + force: bool, + log, +): + """Get new job status from status updates + + :param jobID: job ID + :param updateTimes: list of update times + :param lastTime: last update time + :param statusDict: dictionary mapping update times to status dictionaries + :param currentStatus: current job status + :param force: whether to force status update without state machine validation + :param log: logger object + :return: S_OK((status, minor, application)) or S_ERROR + """ + status = "" + minor = "" + application = "" + # Get the last status values looping on the most recent upupdateTimes in chronological order + for updTime in [dt for dt in updateTimes if dt >= lastTime]: + sDict = statusDict[updTime] + log.debug(f"\tTime {updTime} - Statuses {str(sDict)}") + status = sDict.get("Status", currentStatus) + # evaluate the state machine if the status is changing + if not force and status != currentStatus: + res = JobsStateMachine(currentStatus).getNextState(status) + if not res["OK"]: + return res + newStat = res["Value"] + # If the JobsStateMachine does not accept the candidate, don't update + if newStat != status: + # keeping the same status 
+ log.error( + f"Job Status Error: {jobID} can't move from {currentStatus} to {status}: using {newStat}", + ) + status = newStat + sDict["Status"] = newStat + # Change the source to indicate this is not what was requested + source = sDict.get("Source", "") + sDict["Source"] = source + "(SM)" + # at this stage status == newStat. Set currentStatus to this new status + currentStatus = newStat + + minor = sDict.get("MinorStatus", minor) + application = sDict.get("ApplicationStatus", application) + return S_OK((status, minor, application)) diff --git a/dirac-common/tests/Core/Utilities/test_JDL.py b/dirac-common/tests/Core/Utilities/test_JDL.py new file mode 100644 index 00000000000..1692f97f740 --- /dev/null +++ b/dirac-common/tests/Core/Utilities/test_JDL.py @@ -0,0 +1,113 @@ +"""Unit tests for pure JDL parsing utilities""" + +import unittest + +from diraccfg import CFG +from DIRACCommon.Core.Utilities.JDL import loadJDLAsCFG, dumpCFGAsJDL + + +class TestJDLParsing(unittest.TestCase): + """Test cases for JDL parsing functions""" + + def test_loadJDLAsCFG_simple(self): + """Test basic JDL parsing""" + jdl = '[Executable = "test.sh"; Arguments = "arg1 arg2";]' + result = loadJDLAsCFG(jdl) + self.assertTrue(result["OK"]) + cfg, _ = result["Value"] + self.assertEqual(cfg.getOption("Executable"), "test.sh") + self.assertEqual(cfg.getOption("Arguments"), "arg1 arg2") + + def test_loadJDLAsCFG_with_lists(self): + """Test JDL parsing with lists""" + jdl = '[InputSandbox = {"file1.txt", "file2.txt"}; Priority = 5;]' + result = loadJDLAsCFG(jdl) + self.assertTrue(result["OK"]) + cfg, _ = result["Value"] + self.assertEqual(cfg.getOption("InputSandbox"), "file1.txt, file2.txt") + self.assertEqual(cfg.getOption("Priority"), "5") + + def test_loadJDLAsCFG_empty(self): + """Test empty JDL parsing""" + jdl = "[]" + result = loadJDLAsCFG(jdl) + self.assertTrue(result["OK"]) + cfg, _ = result["Value"] + self.assertEqual(len(cfg.listOptions()), 0) + + def 
test_loadJDLAsCFG_invalid_format(self): + """Test invalid JDL format""" + jdl = '[Executable = "test.sh" Arguments = "missing semicolon"]' + result = loadJDLAsCFG(jdl) + # This should either succeed with partial parsing or fail gracefully + # The exact behavior depends on the parser implementation + + def test_dumpCFGAsJDL_simple(self): + """Test basic CFG to JDL conversion""" + cfg = CFG() + cfg.setOption("Executable", "test.sh") + cfg.setOption("Arguments", "arg1 arg2") + + jdl = dumpCFGAsJDL(cfg) + + # Should contain the key fields + self.assertIn("Executable", jdl) + self.assertIn("test.sh", jdl) + self.assertIn("Arguments", jdl) + self.assertIn("arg1 arg2", jdl) + + # Should be properly formatted + self.assertTrue(jdl.strip().startswith("[")) + self.assertTrue(jdl.strip().endswith("]")) + + def test_dumpCFGAsJDL_with_numbers(self): + """Test CFG to JDL with numerical values""" + cfg = CFG() + cfg.setOption("CPUTime", "3600") + cfg.setOption("Priority", "5") + + jdl = dumpCFGAsJDL(cfg) + + # Numbers should not be quoted + self.assertIn("CPUTime = 3600", jdl) + self.assertIn("Priority = 5", jdl) + + def test_dumpCFGAsJDL_with_lists(self): + """Test CFG to JDL with list values""" + cfg = CFG() + cfg.setOption("InputSandbox", "file1.txt, file2.txt") + + jdl = dumpCFGAsJDL(cfg) + + # Lists should be formatted with braces + self.assertIn("InputSandbox", jdl) + self.assertIn("{", jdl) + self.assertIn("}", jdl) + self.assertIn("file1.txt", jdl) + self.assertIn("file2.txt", jdl) + + def test_roundtrip_conversion(self): + """Test that JDL -> CFG -> JDL preserves basic information""" + original_jdl = '[Executable = "test.sh"; CPUTime = 3600; InputSandbox = {"file1.txt", "file2.txt"};]' + + # Parse JDL to CFG + result = loadJDLAsCFG(original_jdl) + self.assertTrue(result["OK"]) + cfg, _ = result["Value"] + + # Convert back to JDL + new_jdl = dumpCFGAsJDL(cfg) + + # Parse the new JDL + result2 = loadJDLAsCFG(new_jdl) + self.assertTrue(result2["OK"]) + cfg2, _ = 
result2["Value"] + + # Compare key values + self.assertEqual(cfg.getOption("Executable"), cfg2.getOption("Executable")) + self.assertEqual(cfg.getOption("CPUTime"), cfg2.getOption("CPUTime")) + # Note: List format might be different but content should be preserved + + +if __name__ == "__main__": + unittest.main() diff --git a/src/DIRAC/Core/Utilities/test/Test_List.py b/dirac-common/tests/Core/Utilities/test_List.py similarity index 96% rename from src/DIRAC/Core/Utilities/test/Test_List.py rename to dirac-common/tests/Core/Utilities/test_List.py index 05031ea9257..b0cf2512bf9 100644 --- a/src/DIRAC/Core/Utilities/test/Test_List.py +++ b/dirac-common/tests/Core/Utilities/test_List.py @@ -5,19 +5,19 @@ """.. module:: ListTestCase -Test cases for DIRAC.Core.Utilities.List module. +Test cases for DIRACCommon.Core.Utilities.List module. """ import unittest # sut -from DIRAC.Core.Utilities import List +from DIRACCommon.Core.Utilities import List ######################################################################## class ListTestCase(unittest.TestCase): """py:class ListTestCase - Test case for DIRAC.Core.Utilities.List module. + Test case for DIRACCommon.Core.Utilities.List module. 
""" def testUniqueElements(self): diff --git a/src/DIRAC/Core/Utilities/test/Test_Time.py b/dirac-common/tests/Core/Utilities/test_TimeUtilities.py old mode 100755 new mode 100644 similarity index 93% rename from src/DIRAC/Core/Utilities/test/Test_Time.py rename to dirac-common/tests/Core/Utilities/test_TimeUtilities.py index e4885e63183..96f16bb02ae --- a/src/DIRAC/Core/Utilities/test/Test_Time.py +++ b/dirac-common/tests/Core/Utilities/test_TimeUtilities.py @@ -1,10 +1,9 @@ -""" Test class for plugins -""" +"""Test class for TimeUtilities""" # imports import unittest # sut -from DIRAC.Core.Utilities.TimeUtilities import timeThis +from DIRACCommon.Core.Utilities.TimeUtilities import timeThis class logClass: @@ -48,7 +47,7 @@ def myMethodInAClass(self, a, b=None): class TimeTestCase(unittest.TestCase): - """Base class for the Agents test cases""" + """Base class for the TimeUtilities test cases""" def setUp(self): pass diff --git a/dirac-common/tests/WorkloadManagementSystem/Client/JobState/__init__.py b/dirac-common/tests/WorkloadManagementSystem/Client/JobState/__init__.py new file mode 100644 index 00000000000..c311458082a --- /dev/null +++ b/dirac-common/tests/WorkloadManagementSystem/Client/JobState/__init__.py @@ -0,0 +1 @@ +# Init file for WorkloadManagementSystem.Client.JobState tests diff --git a/dirac-common/tests/WorkloadManagementSystem/Client/JobState/test_JobManifest.py b/dirac-common/tests/WorkloadManagementSystem/Client/JobState/test_JobManifest.py new file mode 100644 index 00000000000..68d50f06d82 --- /dev/null +++ b/dirac-common/tests/WorkloadManagementSystem/Client/JobState/test_JobManifest.py @@ -0,0 +1,129 @@ +"""Test the JobManifest class.""" + +import unittest + +from DIRACCommon.WorkloadManagementSystem.Client.JobState.JobManifest import JobManifest + + +class TestJobManifest(unittest.TestCase): + """Test cases for JobManifest""" + + def test_create_empty_manifest(self): + """Test creating an empty manifest""" + manifest = JobManifest() + 
self.assertFalse(manifest.isDirty()) + + def test_load_simple_jdl(self): + """Test loading a simple JDL""" + jdl = '[Executable = "test.sh"; Arguments = "arg1";]' + manifest = JobManifest(jdl) + + self.assertEqual(manifest.getOption("Executable"), "test.sh") + self.assertEqual(manifest.getOption("Arguments"), "arg1") + + def test_load_cfg_format(self): + """Test loading CFG format""" + cfg_content = """Executable = test.sh +Arguments = arg1 arg2 +CPUTime = 3600""" + + manifest = JobManifest() + result = manifest.loadCFG(cfg_content) + self.assertTrue(result["OK"]) + + self.assertEqual(manifest.getOption("Executable"), "test.sh") + self.assertEqual(manifest.getOption("Arguments"), "arg1 arg2") + self.assertEqual(manifest.getOption("CPUTime"), "3600") + + def test_set_and_get_options(self): + """Test setting and getting options""" + manifest = JobManifest() + + manifest.setOption("TestOption", "TestValue") + self.assertTrue(manifest.isDirty()) + + self.assertEqual(manifest.getOption("TestOption"), "TestValue") + + def test_dump_formats(self): + """Test dumping to different formats""" + manifest = JobManifest() + manifest.setOption("Executable", "test.sh") + manifest.setOption("CPUTime", "3600") + + # Test CFG format + cfg_output = manifest.dumpAsCFG() + self.assertIn("Executable", cfg_output) + self.assertIn("test.sh", cfg_output) + + # Test JDL format + jdl_output = manifest.dumpAsJDL() + self.assertIn("Executable", jdl_output) + self.assertIn("test.sh", jdl_output) + self.assertTrue(jdl_output.strip().startswith("[")) + self.assertTrue(jdl_output.strip().endswith("]")) + + def test_contains_operator(self): + """Test the __contains__ operator""" + manifest = JobManifest() + manifest.setOption("TestKey", "TestValue") + + self.assertTrue("TestKey" in manifest) + self.assertFalse("NonExistentKey" in manifest) + + def test_create_and_get_section(self): + """Test creating and getting sections""" + manifest = JobManifest() + + result = 
manifest.createSection("TestSection") + self.assertTrue(result["OK"]) + + result = manifest.getSection("TestSection") + self.assertTrue(result["OK"]) + + # Try to create same section again - should fail + result = manifest.createSection("TestSection") + self.assertFalse(result["OK"]) + + def test_option_list_operations(self): + """Test getting lists of options and sections""" + manifest = JobManifest() + manifest.setOption("Option1", "Value1") + manifest.setOption("Option2", "Value2") + manifest.createSection("Section1") + + options = manifest.getOptionList() + self.assertIn("Option1", options) + self.assertIn("Option2", options) + + sections = manifest.getSectionList() + self.assertIn("Section1", sections) + + def test_remove_option(self): + """Test removing options""" + manifest = JobManifest() + manifest.setOption("TestOption", "TestValue") + + self.assertTrue("TestOption" in manifest) + + result = manifest.remove("TestOption") + self.assertTrue(result["OK"]) + + self.assertFalse("TestOption" in manifest) + + def test_dirty_flag_management(self): + """Test dirty flag management""" + manifest = JobManifest() + self.assertFalse(manifest.isDirty()) + + manifest.setOption("Test", "Value") + self.assertTrue(manifest.isDirty()) + + manifest.clearDirty() + self.assertFalse(manifest.isDirty()) + + manifest.setDirty() + self.assertTrue(manifest.isDirty()) + + +if __name__ == "__main__": + unittest.main() diff --git a/dirac-common/tests/WorkloadManagementSystem/Client/__init__.py b/dirac-common/tests/WorkloadManagementSystem/Client/__init__.py new file mode 100644 index 00000000000..6641950d23a --- /dev/null +++ b/dirac-common/tests/WorkloadManagementSystem/Client/__init__.py @@ -0,0 +1 @@ +# Init file for WorkloadManagementSystem.Client tests diff --git a/dirac-common/tests/WorkloadManagementSystem/Utilities/test_JobModel.py b/dirac-common/tests/WorkloadManagementSystem/Utilities/test_JobModel.py new file mode 100644 index 00000000000..052c464b771 --- /dev/null +++ 
b/dirac-common/tests/WorkloadManagementSystem/Utilities/test_JobModel.py @@ -0,0 +1,264 @@ +"""Test the BaseJobDescriptionModel class and its validators.""" + +# pylint: disable=invalid-name + +import pytest +from pydantic import ValidationError + +from DIRACCommon.Core.Utilities.ReturnValues import S_OK +from DIRACCommon.WorkloadManagementSystem.Utilities.JobModel import ( + BaseJobDescriptionModel, + BaseJobDescriptionModelConfg, +) + +EXECUTABLE = "dirac-jobexec" +VO = "vo" + + +def _make_test_config(*args, **kwargs) -> BaseJobDescriptionModelConfg: + """Create a test configuration for BaseJobDescriptionModel""" + return { + "cpuTime": 86400, + "priority": 1, + "minCPUTime": 100, + "maxCPUTime": 500000, + "allowedJobTypes": ["User", "Test", "Hospital"], + "maxInputDataFiles": 10000, + "minNumberOfProcessors": 1, + "maxNumberOfProcessors": 1024, + "minPriority": 0, + "maxPriority": 10, + "possibleLogLevels": ["DEBUG", "INFO", "WARN", "ERROR"], + "sites": S_OK(["LCG.CERN.ch", "LCG.IN2P3.fr"]), + } + + +class JobDescriptionModelForTest(BaseJobDescriptionModel): + """Test version of BaseJobDescriptionModel with test configuration""" + + _config_builder = _make_test_config + + +@pytest.mark.parametrize( + "cpuTime", + [ + 100, # Lower bound + 86400, # Default + 500000, # Higher bound + ], +) +def test_cpuTimeValidator_valid(cpuTime: int): + """Test the cpuTime validator.""" + JobDescriptionModelForTest(executable=EXECUTABLE, cpuTime=cpuTime) + + +@pytest.mark.parametrize( + "cpuTime", + [ + 1, # Too low + 100000000, # Too high + "qwerty", # Not an int + ], +) +def test_cpuTimeValidator_invalid(cpuTime: int): + """Test the cpuTime validator with invalid input.""" + with pytest.raises(ValidationError): + JobDescriptionModelForTest(executable=EXECUTABLE, cpuTime=cpuTime) + + +def test_jobType_valid(): + """Test the jobType validator with valid input.""" + JobDescriptionModelForTest(executable=EXECUTABLE, jobType="User") + + +def test_jobType_invalid(): + """Test the 
jobType validator with invalid input.""" + with pytest.raises(ValidationError): + JobDescriptionModelForTest(executable=EXECUTABLE, jobType="Production") + + +@pytest.mark.parametrize( + "priority", + [ + 0, # Lower bound + 1, # Default + 10, # Higher bound + ], +) +def test_priorityValidator_valid(priority): + """Test the priority validator with valid input.""" + JobDescriptionModelForTest( + executable=EXECUTABLE, + priority=priority, + ) + + +@pytest.mark.parametrize( + "priority", + [ + -1, # Too low + 11, # Too high + "qwerty", # Not an int + ], +) +def test_priorityValidator_invalid(priority): + """Test the priority validator with invalid input""" + with pytest.raises(ValidationError): + JobDescriptionModelForTest( + executable=EXECUTABLE, + priority=priority, + ) + + +@pytest.mark.parametrize( + "inputData,parsedInputData,jobType", + [ + ({f" /{VO}/1", " "}, {f"LFN:/{VO}/1"}, "User"), + ({f"/{VO}/1", f"LFN:/{VO}/2"}, {f"LFN:/{VO}/1", f"LFN:/{VO}/2"}, "User"), + ({f"LFN:/{VO}/1", f"LFN:/{VO}/2"}, {f"LFN:/{VO}/1", f"LFN:/{VO}/2"}, "User"), + ( + {f"LFN:/{VO}/{i}" for i in range(100)}, + {f"LFN:/{VO}/{i}" for i in range(100)}, + "Test", + ), # Reduced size for DIRACCommon + ], +) +def test_inputDataValidator_valid(inputData: set[str], parsedInputData: set[str], jobType: str): + """Test the inputData validator with valid input.""" + job = JobDescriptionModelForTest( + executable=EXECUTABLE, + inputData=inputData, + jobType=jobType, + ) + assert job.inputData == parsedInputData + + +@pytest.mark.parametrize( + "inputData,jobType", + [ + ({f"LFN:/{VO}/{i}" for i in range(10001)}, "User"), # Too many files for User job + ], +) +def test_inputDataValidator_invalid(inputData: set[str], jobType: str): + """Test the inputData validator with invalid input.""" + with pytest.raises(ValidationError): + JobDescriptionModelForTest( + executable=EXECUTABLE, + inputData=inputData, + jobType=jobType, + ) + + +def test_inputDataValidator_basic(): + """Test basic inputData 
functionality without specific validation.""" + job = JobDescriptionModelForTest(executable=EXECUTABLE, jobType="Test") + assert job.jobType == "Test" + + +@pytest.mark.parametrize( + "minNumberOfProcessors,maxNumberOfProcessors", + [ + (1, 1), # Same values + (1, 4), # Valid range + (2, 8), # Valid range + ], +) +def test_numberOfProcessorsValidator_valid(minNumberOfProcessors: int, maxNumberOfProcessors: int): + """Test the numberOfProcessors validator with valid input.""" + JobDescriptionModelForTest( + executable=EXECUTABLE, + minNumberOfProcessors=minNumberOfProcessors, + maxNumberOfProcessors=maxNumberOfProcessors, + ) + + +@pytest.mark.parametrize( + "minNumberOfProcessors,maxNumberOfProcessors", + [ + (0, 1), # Min too low + (1, 0), # Max too low + (2, 1), # Min > Max + (1025, 1025), # Both too high + ], +) +def test_numberOfProcessorsValidator_invalid(minNumberOfProcessors: int, maxNumberOfProcessors: int): + """Test the numberOfProcessors validator with invalid input.""" + with pytest.raises(ValidationError): + JobDescriptionModelForTest( + executable=EXECUTABLE, + minNumberOfProcessors=minNumberOfProcessors, + maxNumberOfProcessors=maxNumberOfProcessors, + ) + + +def test_basic_model_creation(): + """Test basic model creation with minimal required fields.""" + model = JobDescriptionModelForTest(executable=EXECUTABLE) + assert model.executable == EXECUTABLE + assert model.cpuTime == 86400 # Should use default from config + assert model.priority == 1 # Should use default from config + + +def test_model_with_all_fields(): + """Test model creation with many fields.""" + model = JobDescriptionModelForTest( + executable=EXECUTABLE, + arguments="--help", + cpuTime=3600, + priority=5, + jobType="Test", + inputData={f"LFN:/{VO}/test.root"}, + outputData={f"LFN:/{VO}/output.root"}, + inputSandbox={"script.py"}, + outputSandbox={"log.txt"}, + minNumberOfProcessors=1, + maxNumberOfProcessors=2, + platform="x86_64-el9-gcc11-opt", + sites={"LCG.CERN.ch"}, + 
bannedSites={"LCG.Broken.ch"}, + tags={"GPU", "HighMem"}, + ) + + assert model.executable == EXECUTABLE + assert model.arguments == "--help" + assert model.cpuTime == 3600 + assert model.priority == 5 + assert model.jobType == "Test" + assert f"LFN:/{VO}/test.root" in model.inputData + assert f"LFN:/{VO}/output.root" in model.outputData + assert "script.py" in model.inputSandbox + assert "log.txt" in model.outputSandbox + assert model.minNumberOfProcessors == 1 + assert model.maxNumberOfProcessors == 2 + assert model.platform == "x86_64-el9-gcc11-opt" + assert "LCG.CERN.ch" in model.sites + # Note: bannedSites may be processed differently, just check it was set + assert isinstance(model.bannedSites, set) + assert "GPU" in model.tags + assert "HighMem" in model.tags + + +def test_outputDataValidator(): + """Test output data validation.""" + model = JobDescriptionModelForTest( + executable=EXECUTABLE, + outputData={f"LFN:/{VO}/output1.root", f"LFN:/{VO}/output2.root"}, + ) + assert len(model.outputData) == 2 + assert f"LFN:/{VO}/output1.root" in model.outputData + assert f"LFN:/{VO}/output2.root" in model.outputData + + +def test_sandboxValidator(): + """Test sandbox validation.""" + model = JobDescriptionModelForTest( + executable=EXECUTABLE, + inputSandbox={"script.py", "config.txt"}, + outputSandbox={"log.txt", "results.dat"}, + ) + assert len(model.inputSandbox) == 2 + assert "script.py" in model.inputSandbox + assert "config.txt" in model.inputSandbox + assert len(model.outputSandbox) == 2 + assert "log.txt" in model.outputSandbox + assert "results.dat" in model.outputSandbox diff --git a/dirac-common/tests/WorkloadManagementSystem/Utilities/test_JobStatusUtility.py b/dirac-common/tests/WorkloadManagementSystem/Utilities/test_JobStatusUtility.py new file mode 100644 index 00000000000..705016a6dd8 --- /dev/null +++ b/dirac-common/tests/WorkloadManagementSystem/Utilities/test_JobStatusUtility.py @@ -0,0 +1,164 @@ +"""Test the JobStatusUtility stateless 
functions.""" + +import unittest +from datetime import datetime +from unittest.mock import MagicMock + +from DIRACCommon.Core.Utilities.ReturnValues import S_OK, S_ERROR +from DIRACCommon.WorkloadManagementSystem.Client.JobStatus import WAITING, MATCHED, RUNNING, DONE, FAILED +from DIRACCommon.WorkloadManagementSystem.Utilities.JobStatusUtility import getStartAndEndTime, getNewStatus + + +class TestJobStatusUtility(unittest.TestCase): + """Test cases for JobStatusUtility functions""" + + def test_getStartAndEndTime_no_running_status(self): + """Test getStartAndEndTime when job never reaches running state""" + startTime = None + endTime = None + updateTimes = ["2023-01-01 10:00:00", "2023-01-01 11:00:00"] + timeStamps = [(1672563600.0, WAITING), (1672567200.0, MATCHED)] + statusDict = {"2023-01-01 10:00:00": {"Status": WAITING}, "2023-01-01 11:00:00": {"Status": MATCHED}} + + newStartTime, newEndTime = getStartAndEndTime(startTime, endTime, updateTimes, timeStamps, statusDict) + + self.assertIsNone(newStartTime) + self.assertIsNone(newEndTime) + + def test_getStartAndEndTime_with_running_and_done(self): + """Test getStartAndEndTime when job runs and completes""" + startTime = None + endTime = None + updateTimes = [ + "2023-01-01 10:00:00", # WAITING + "2023-01-01 11:00:00", # MATCHED + "2023-01-01 12:00:00", # RUNNING + "2023-01-01 13:00:00", # DONE + ] + timeStamps = [(1672563600.0, WAITING), (1672567200.0, MATCHED), (1672570800.0, RUNNING), (1672574400.0, DONE)] + statusDict = { + "2023-01-01 10:00:00": {"Status": WAITING}, + "2023-01-01 11:00:00": {"Status": MATCHED}, + "2023-01-01 12:00:00": {"Status": RUNNING}, + "2023-01-01 13:00:00": {"Status": DONE}, + } + + newStartTime, newEndTime = getStartAndEndTime(startTime, endTime, updateTimes, timeStamps, statusDict) + + self.assertEqual(newStartTime, "2023-01-01 12:00:00") # When it started running + self.assertEqual(newEndTime, "2023-01-01 13:00:00") # When it finished + + def 
test_getStartAndEndTime_existing_start_time(self): + """Test getStartAndEndTime when startTime already exists""" + startTime = "2023-01-01 09:00:00" # Already set + endTime = None + updateTimes = ["2023-01-01 12:00:00", "2023-01-01 13:00:00"] + timeStamps = [(1672570800.0, RUNNING), (1672574400.0, DONE)] + statusDict = {"2023-01-01 12:00:00": {"Status": RUNNING}, "2023-01-01 13:00:00": {"Status": DONE}} + + newStartTime, newEndTime = getStartAndEndTime(startTime, endTime, updateTimes, timeStamps, statusDict) + + self.assertEqual(newStartTime, "2023-01-01 09:00:00") # Should keep existing + self.assertEqual(newEndTime, "2023-01-01 13:00:00") # Should set end time + + def test_getNewStatus_simple_progression(self): + """Test getNewStatus with simple status progression""" + jobID = 123 + updateTimes = [datetime.fromisoformat("2023-01-01 10:00:00")] + lastTime = datetime.fromisoformat("2023-01-01 09:00:00") + statusDict = { + datetime.fromisoformat("2023-01-01 10:00:00"): { + "Status": MATCHED, + "MinorStatus": "JobAgent", + "ApplicationStatus": "Starting", + } + } + currentStatus = WAITING + force = False + + # Mock logger + log = MagicMock() + log.debug = MagicMock() + log.error = MagicMock() + + result = getNewStatus(jobID, updateTimes, lastTime, statusDict, currentStatus, force, log) + + self.assertTrue(result["OK"]) + status, minor, application = result["Value"] + self.assertEqual(status, MATCHED) + self.assertEqual(minor, "JobAgent") + self.assertEqual(application, "Starting") + + def test_getNewStatus_no_updates_after_last_time(self): + """Test getNewStatus when no updates after lastTime""" + jobID = 123 + updateTimes = [datetime.fromisoformat("2023-01-01 08:00:00")] # Before lastTime + lastTime = datetime.fromisoformat("2023-01-01 09:00:00") + statusDict = {datetime.fromisoformat("2023-01-01 08:00:00"): {"Status": WAITING}} + currentStatus = WAITING + force = False + log = MagicMock() + + result = getNewStatus(jobID, updateTimes, lastTime, statusDict, 
currentStatus, force, log) + + self.assertTrue(result["OK"]) + status, minor, application = result["Value"] + self.assertEqual(status, "") # No status change + self.assertEqual(minor, "") + self.assertEqual(application, "") + + def test_getNewStatus_multiple_updates(self): + """Test getNewStatus with multiple status updates""" + jobID = 123 + updateTimes = [ + datetime.fromisoformat("2023-01-01 10:00:00"), + datetime.fromisoformat("2023-01-01 11:00:00"), + datetime.fromisoformat("2023-01-01 12:00:00"), + ] + lastTime = datetime.fromisoformat("2023-01-01 09:00:00") + statusDict = { + datetime.fromisoformat("2023-01-01 10:00:00"): {"Status": MATCHED, "MinorStatus": "Pilot Agent"}, + datetime.fromisoformat("2023-01-01 11:00:00"): {"Status": RUNNING, "ApplicationStatus": "Running"}, + datetime.fromisoformat("2023-01-01 12:00:00"): { + "Status": DONE, + "MinorStatus": "Execution Complete", + "ApplicationStatus": "Success", + }, + } + currentStatus = WAITING + force = False + log = MagicMock() + + result = getNewStatus(jobID, updateTimes, lastTime, statusDict, currentStatus, force, log) + + self.assertTrue(result["OK"]) + status, minor, application = result["Value"] + self.assertEqual(status, DONE) # Final status + self.assertEqual(minor, "Execution Complete") # Final minor status + self.assertEqual(application, "Success") # Final application status + + def test_getNewStatus_force_mode(self): + """Test getNewStatus with force=True bypasses state machine""" + jobID = 123 + updateTimes = [datetime.fromisoformat("2023-01-01 10:00:00")] + lastTime = datetime.fromisoformat("2023-01-01 09:00:00") + statusDict = { + datetime.fromisoformat("2023-01-01 10:00:00"): { + "Status": DONE, # Direct jump to DONE (would normally be rejected) + "MinorStatus": "Forced", + } + } + currentStatus = WAITING + force = True # Force mode + log = MagicMock() + + result = getNewStatus(jobID, updateTimes, lastTime, statusDict, currentStatus, force, log) + + self.assertTrue(result["OK"]) + status, 
minor, application = result["Value"] + self.assertEqual(status, DONE) + self.assertEqual(minor, "Forced") + + +if __name__ == "__main__": + unittest.main() diff --git a/src/DIRAC/Core/Utilities/JDL.py b/src/DIRAC/Core/Utilities/JDL.py index 935094f1ed9..3e63be38f0f 100644 --- a/src/DIRAC/Core/Utilities/JDL.py +++ b/src/DIRAC/Core/Utilities/JDL.py @@ -1,203 +1,9 @@ """Transformation classes around the JDL format.""" -from diraccfg import CFG -from pydantic import ValidationError +from DIRACCommon.Core.Utilities.JDL import * # noqa: F403,F401 -from DIRAC import S_OK, S_ERROR -from DIRAC.Core.Utilities import List -from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd from DIRAC.WorkloadManagementSystem.Utilities.JobModel import BaseJobDescriptionModel -ARGUMENTS = "Arguments" -BANNED_SITES = "BannedSites" -CPU_TIME = "CPUTime" -EXECUTABLE = "Executable" -EXECUTION_ENVIRONMENT = "ExecutionEnvironment" -GRID_CE = "GridCE" -INPUT_DATA = "InputData" -INPUT_DATA_POLICY = "InputDataPolicy" -INPUT_SANDBOX = "InputSandbox" -JOB_CONFIG_ARGS = "JobConfigArgs" -JOB_TYPE = "JobType" -JOB_GROUP = "JobGroup" -LOG_LEVEL = "LogLevel" -NUMBER_OF_PROCESSORS = "NumberOfProcessors" -MAX_NUMBER_OF_PROCESSORS = "MaxNumberOfProcessors" -MIN_NUMBER_OF_PROCESSORS = "MinNumberOfProcessors" -OUTPUT_DATA = "OutputData" -OUTPUT_PATH = "OutputPath" -OUTPUT_SE = "OutputSE" -PLATFORM = "Platform" -PRIORITY = "Priority" -STD_ERROR = "StdError" -STD_OUTPUT = "StdOutput" -OUTPUT_SANDBOX = "OutputSandbox" -JOB_NAME = "JobName" -SITE = "Site" -TAGS = "Tags" - -OWNER = "Owner" -OWNER_GROUP = "OwnerGroup" -VO = "VirtualOrganization" - -CREDENTIALS_FIELDS = {OWNER, OWNER_GROUP, VO} - - -def loadJDLAsCFG(jdl): - """ - Load a JDL as CFG - """ - - def cleanValue(value): - value = value.strip() - if value[0] == '"': - entries = [] - iPos = 1 - current = "" - state = "in" - while iPos < len(value): - if value[iPos] == '"': - if state == "in": - entries.append(current) - current = "" - state = "out" - 
elif state == "out": - current = current.strip() - if current not in (",",): - return S_ERROR("value seems a list but is not separated in commas") - current = "" - state = "in" - else: - current += value[iPos] - iPos += 1 - if state == "in": - return S_ERROR('value is opened with " but is not closed') - return S_OK(", ".join(entries)) - else: - return S_OK(value.replace('"', "")) - - def assignValue(key, value, cfg): - key = key.strip() - if len(key) == 0: - return S_ERROR("Invalid key name") - value = value.strip() - if not value: - return S_ERROR(f"No value for key {key}") - if value[0] == "{": - if value[-1] != "}": - return S_ERROR("Value '%s' seems a list but does not end in '}'" % (value)) - valList = List.fromChar(value[1:-1]) - for i in range(len(valList)): - result = cleanValue(valList[i]) - if not result["OK"]: - return S_ERROR(f"Var {key} : {result['Message']}") - valList[i] = result["Value"] - if valList[i] is None: - return S_ERROR(f"List value '{value}' seems invalid for item {i}") - value = ", ".join(valList) - else: - result = cleanValue(value) - if not result["OK"]: - return S_ERROR(f"Var {key} : {result['Message']}") - nV = result["Value"] - if nV is None: - return S_ERROR(f"Value '{value} seems invalid") - value = nV - cfg.setOption(key, value) - return S_OK() - - if jdl[0] == "[": - iPos = 1 - else: - iPos = 0 - key = "" - value = "" - action = "key" - insideLiteral = False - cfg = CFG() - while iPos < len(jdl): - char = jdl[iPos] - if char == ";" and not insideLiteral: - if key.strip(): - result = assignValue(key, value, cfg) - if not result["OK"]: - return result - key = "" - value = "" - action = "key" - elif char == "[" and not insideLiteral: - key = key.strip() - if not key: - return S_ERROR("Invalid key in JDL") - if value.strip(): - return S_ERROR(f"Key {key} seems to have a value and open a sub JDL at the same time") - result = loadJDLAsCFG(jdl[iPos:]) - if not result["OK"]: - return result - subCfg, subPos = result["Value"] - 
cfg.createNewSection(key, contents=subCfg) - key = "" - value = "" - action = "key" - insideLiteral = False - iPos += subPos - elif char == "=" and not insideLiteral: - if action == "key": - action = "value" - insideLiteral = False - else: - value += char - elif char == "]" and not insideLiteral: - key = key.strip() - if len(key) > 0: - result = assignValue(key, value, cfg) - if not result["OK"]: - return result - return S_OK((cfg, iPos)) - else: - if action == "key": - key += char - else: - value += char - if char == '"': - insideLiteral = not insideLiteral - iPos += 1 - - return S_OK((cfg, iPos)) - - -def dumpCFGAsJDL(cfg, level=1, tab=" "): - indent = tab * level - contents = [f"{tab * (level - 1)}["] - sections = cfg.listSections() - - for key in cfg: - if key in sections: - contents.append(f"{indent}{key} =") - contents.append(f"{dumpCFGAsJDL(cfg[key], level + 1, tab)};") - else: - val = List.fromChar(cfg[key]) - # Some attributes are never lists - if len(val) < 2 or key in [ARGUMENTS, EXECUTABLE, STD_OUTPUT, STD_ERROR]: - value = cfg[key] - try: - try_value = float(value) - contents.append(f"{tab * level}{key} = {value};") - except Exception: - contents.append(f'{tab * level}{key} = "{value}";') - else: - contents.append(f"{indent}{key} =") - contents.append("%s{" % indent) - for iPos in range(len(val)): - try: - value = float(val[iPos]) - except Exception: - val[iPos] = f'"{val[iPos]}"' - contents.append(",\n".join([f"{tab * (level + 1)}{value}" for value in val])) - contents.append("%s};" % indent) - contents.append(f"{tab * (level - 1)}]") - return "\n".join(contents) - def jdlToBaseJobDescriptionModel(classAd: ClassAd): """ diff --git a/src/DIRAC/Core/Utilities/List.py b/src/DIRAC/Core/Utilities/List.py old mode 100755 new mode 100644 index ea8e121af22..6bbba95840c --- a/src/DIRAC/Core/Utilities/List.py +++ b/src/DIRAC/Core/Utilities/List.py @@ -1,127 +1 @@ -"""Collection of DIRAC useful list related modules. - By default on Error they return None. 
-""" -import random -import sys -from typing import Any, TypeVar -from collections.abc import Iterable - -T = TypeVar("T") - - -def uniqueElements(aList: list) -> list: - """Utility to retrieve list of unique elements in a list (order is kept).""" - - # Use dict.fromkeys instead of set ensure the order is preserved - return list(dict.fromkeys(aList)) - - -def appendUnique(aList: list, anObject: Any): - """Append to list if object does not exist. - - :param aList: list of elements - :param anObject: object you want to append - """ - if anObject not in aList: - aList.append(anObject) - - -def fromChar(inputString: str, sepChar: str = ","): - """Generates a list splitting a string by the required character(s) - resulting string items are stripped and empty items are removed. - - :param inputString: list serialised to string - :param sepChar: separator - :return: list of strings or None if sepChar has a wrong type - """ - # to prevent getting an empty String as argument - if not (isinstance(inputString, str) and isinstance(sepChar, str) and sepChar): - return None - return [fieldString.strip() for fieldString in inputString.split(sepChar) if len(fieldString.strip()) > 0] - - -def randomize(aList: Iterable[T]) -> list[T]: - """Return a randomly sorted list. - - :param aList: list to permute - """ - tmpList = list(aList) - random.shuffle(tmpList) - return tmpList - - -def pop(aList, popElement): - """Pop the first element equal to popElement from the list. - - :param aList: list - :type aList: python:list - :param popElement: element to pop - """ - if popElement in aList: - return aList.pop(aList.index(popElement)) - - -def stringListToString(aList: list) -> str: - """This function is used for making MySQL queries with a list of string elements. 
- - :param aList: list to be serialized to string for making queries - """ - return ",".join(f"'{x}'" for x in aList) - - -def intListToString(aList: list) -> str: - """This function is used for making MySQL queries with a list of int elements. - - :param aList: list to be serialized to string for making queries - """ - return ",".join(str(x) for x in aList) - - -def getChunk(aList: list, chunkSize: int): - """Generator yielding chunk from a list of a size chunkSize. - - :param aList: list to be splitted - :param chunkSize: lenght of one chunk - :raise: StopIteration - - Usage: - - >>> for chunk in getChunk( aList, chunkSize=10): - process( chunk ) - - """ - chunkSize = int(chunkSize) - for i in range(0, len(aList), chunkSize): - yield aList[i : i + chunkSize] - - -def breakListIntoChunks(aList: list, chunkSize: int): - """This function takes a list as input and breaks it into list of size 'chunkSize'. - It returns a list of lists. - - :param aList: list of elements - :param chunkSize: len of a single chunk - :return: list of lists of length of chunkSize - :raise: RuntimeError if numberOfFilesInChunk is less than 1 - """ - if chunkSize < 1: - raise RuntimeError("chunkSize cannot be less than 1") - if isinstance(aList, (set, dict, tuple, {}.keys().__class__, {}.items().__class__, {}.values().__class__)): - aList = list(aList) - return [chunk for chunk in getChunk(aList, chunkSize)] - - -def getIndexInList(anItem: Any, aList: list) -> int: - """Return the index of the element x in the list l - or sys.maxint if it does not exist - - :param anItem: element to look for - :param aList: list to look into - - :return: the index or sys.maxint - """ - # try: - if anItem in aList: - return aList.index(anItem) - else: - return sys.maxsize +from DIRACCommon.Core.Utilities.List import * # noqa: F401,F403 diff --git a/src/DIRAC/Core/Utilities/StateMachine.py b/src/DIRAC/Core/Utilities/StateMachine.py index aece8b95427..b0bdb7e9d8d 100644 --- 
a/src/DIRAC/Core/Utilities/StateMachine.py +++ b/src/DIRAC/Core/Utilities/StateMachine.py @@ -1,185 +1,19 @@ -""" StateMachine +"""Backward compatibility wrapper - moved to DIRACCommon - This module contains the basic blocks to build a state machine (State and StateMachine) -""" -from DIRAC import S_OK, S_ERROR, gLogger - - -class State: - """ - State class that represents a single step on a StateMachine, with all the - possible transitions, the default transition and an ordering level. - - - examples: - >>> s0 = State(100) - >>> s1 = State(0, ['StateName1', 'StateName2'], defState='StateName1') - >>> s2 = State(0, ['StateName1', 'StateName2']) - # this example is tricky. The transition rule says that will go to - # nextState, e.g. 'StateNext'. But, it is not on the stateMap, and there - # is no default defined, so it will end up going to StateNext anyway. You - # must be careful while defining states and their stateMaps and defaults. - """ - - def __init__(self, level, stateMap=None, defState=None): - """ - :param int level: each state is mapped to an integer, which is used to sort the states according to that integer. - :param list stateMap: it is a list (of strings) with the reachable states from this particular status. - If not defined, we assume there are no restrictions. - :param str defState: default state used in case the next state is not in stateMap (not defined or simply not there). - """ - - self.level = level - self.stateMap = stateMap if stateMap else [] - self.default = defState - - def transitionRule(self, nextState): - """ - Method that selects next state, knowing the default and the transitions - map, and the proposed next state. If is in stateMap, goes there. - If not, then goes to if any. Otherwise, goes to - anyway. 
- - examples: - >>> s0.transitionRule('nextState') - 'nextState' - >>> s1.transitionRule('StateName2') - 'StateName2' - >>> s1.transitionRule('StateNameNotInMap') - 'StateName1' - >>> s2.transitionRule('StateNameNotInMap') - 'StateNameNotInMap' - - :param str nextState: name of the state in the stateMap - :return: state name - :rtype: str - """ - - # If next state is on the list of next states, go ahead. - if nextState in self.stateMap: - return nextState - - # If not, calculate defaultState: - # if there is a default, that one - # otherwise is nextState (states with empty list have no movement restrictions) - defaultNext = self.default if self.default else nextState - return defaultNext - - -class StateMachine: - """ - StateMachine class that represents the whole state machine with all transitions. - - examples: - >>> sm0 = StateMachine() - >>> sm1 = StateMachine(state = 'Active') - - :param state: current state of the StateMachine, could be None if we do not use the - StateMachine to calculate transitions. Beware, it is not checked if the - state is on the states map ! - :type state: None or str - - """ +This module has been moved to DIRACCommon.Core.Utilities.StateMachine to avoid +circular dependencies and allow DiracX to use these utilities without +triggering DIRAC's global state initialization. - def __init__(self, state=None): - """ - Constructor. - """ - - self.state = state - # To be overwritten by child classes, unless you like Nirvana state that much. - self.states = {"Nirvana": State(100)} - - def getLevelOfState(self, state): - """ - Given a state name, it returns its level (integer), which defines the hierarchy. +All exports are maintained for backward compatibility. 
+""" +# Re-export everything from DIRACCommon for backward compatibility +from DIRACCommon.Core.Utilities.StateMachine import * # noqa: F401, F403 - >>> sm0.getLevelOfState('Nirvana') - 100 - >>> sm0.getLevelOfState('AnotherState') - -1 +from DIRAC import gLogger - :param str state: name of the state, it should be on key set - :return: `int` || -1 (if not in ) - """ - if state not in self.states: - return -1 - return self.states[state].level +class StateMachine(StateMachine): # noqa: F405 pylint: disable=function-redefined + """Backward compatibility wrapper - moved to DIRACCommon""" def setState(self, candidateState, noWarn=False): - """Makes sure the state is either None or known to the machine, and that it is a valid state to move into. - Final states are also checked. - - examples: - >>> sm0.setState(None)['OK'] - True - >>> sm0.setState('Nirvana')['OK'] - True - >>> sm0.setState('AnotherState')['OK'] - False - - :param state: state which will be set as current state of the StateMachine - :type state: None or str - :return: S_OK || S_ERROR - """ - if candidateState == self.state: - return S_OK(candidateState) - - if not candidateState: - self.state = candidateState - elif candidateState in self.states: - if not self.states[self.state].stateMap: - if not noWarn: - gLogger.warn("Final state, won't move", f"({self.state}, asked to move to {candidateState})") - return S_OK(self.state) - if candidateState not in self.states[self.state].stateMap: - gLogger.warn(f"Can't move from {self.state} to {candidateState}, choosing a good one") - result = self.getNextState(candidateState) - if not result["OK"]: - return result - self.state = result["Value"] - # If the StateMachine does not accept the candidate, return error message - else: - return S_ERROR(f"setState: {candidateState!r} is not a valid state") - - return S_OK(self.state) - - def getStates(self): - """ - Returns all possible states in the state map - - examples: - >>> sm0.getStates() - [ 'Nirvana' ] - - :return: 
list(stateNames) - """ - - return list(self.states) - - def getNextState(self, candidateState): - """ - Method that gets the next state, given the proposed transition to candidateState. - If candidateState is not on the state map , it is rejected. If it is - not the case, we have two options: if is None, then the next state - will be . Otherwise, the current state is using its own - transition rule to decide. - - examples: - >>> sm0.getNextState(None) - S_OK(None) - >>> sm0.getNextState('NextState') - S_OK('NextState') - - :param str candidateState: name of the next state - :return: S_OK(nextState) || S_ERROR - """ - if candidateState not in self.states: - return S_ERROR(f"getNextState: {candidateState!r} is not a valid state") - - # FIXME: do we need this anymore ? - if self.state is None: - return S_OK(candidateState) - - return S_OK(self.states[self.state].transitionRule(candidateState)) + return super().setState(candidateState, noWarn, logger_warn=gLogger.warn) diff --git a/src/DIRAC/Core/Utilities/TimeUtilities.py b/src/DIRAC/Core/Utilities/TimeUtilities.py index 6dc969008c9..dd17d0485a8 100755 --- a/src/DIRAC/Core/Utilities/TimeUtilities.py +++ b/src/DIRAC/Core/Utilities/TimeUtilities.py @@ -1,260 +1,17 @@ -""" -DIRAC TimeUtilities module -Support for basic Date and Time operations -based on system datetime module. - -It provides common interface to UTC timestamps, -converter to string types and back. - -Useful timedelta constant are also provided to -define time intervals. - -Notice: datetime.timedelta objects allow multiplication and division by interger -but not by float. Thus: - - - DIRAC.TimeUtilities.second * 1.5 is not allowed - - DIRAC.TimeUtilities.second * 3 / 2 is allowed +"""Backward compatibility wrapper - moved to DIRACCommon -An timeInterval class provides a method to check -if a give datetime is in the defined interval. 
+This module has been moved to DIRACCommon.Core.Utilities.TimeUtilities to avoid +circular dependencies and allow DiracX to use these utilities without +triggering DIRAC's global state initialization. +All exports are maintained for backward compatibility. """ -import datetime -import sys -import time - -from DIRAC import gLogger - -# Some useful constants for time operations -microsecond = datetime.timedelta(microseconds=1) -second = datetime.timedelta(seconds=1) -minute = datetime.timedelta(minutes=1) -hour = datetime.timedelta(hours=1) -day = datetime.timedelta(days=1) -week = datetime.timedelta(days=7) - - -def timeThis(method): - """Function to be used as a decorator for timing other functions/methods""" - - def timed(*args, **kw): - """What actually times""" - ts = time.time() - result = method(*args, **kw) - if sys.stdout.isatty(): - return result - te = time.time() - - pre = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC ") - - try: - pre += args[0].log.getName() + "/" + args[0].log.getSubName() + " TIME: " + args[0].transString - except AttributeError: - try: - pre += args[0].log.getName() + " TIME: " + args[0].transString - except AttributeError: - try: - pre += args[0].log.getName() + "/" + args[0].log.getSubName() + " TIME: " - except AttributeError: - pre += "TIME: " - except IndexError: - pre += "TIME: " - - argsLen = "" - if args: - try: - if isinstance(args[1], (list, dict)): - argsLen = f"arguments len: {len(args[1])}" - except IndexError: - if kw: - try: - if isinstance(list(list(kw.items())[0])[1], (list, dict)): - argsLen = f"arguments len: {len(list(list(kw.items())[0])[1])}" - except IndexError: - argsLen = "" - - gLogger.info(f"{pre} Exec time ===> function {method.__name__!r} {argsLen} -> {te - ts:2.2f} sec") - return result - - return timed - - -def toEpoch(dateTimeObject=None): - """ - Get seconds since epoch. 
Accepts datetime or date objects - """ - return toEpochMilliSeconds(dateTimeObject) // 1000 +from functools import partial +# Re-export everything from DIRACCommon for backward compatibility +from DIRACCommon.Core.Utilities.TimeUtilities import * # noqa: F401, F403 -def toEpochMilliSeconds(dateTimeObject=None): - """ - Get milliseconds since epoch - """ - if dateTimeObject is None: - dateTimeObject = datetime.datetime.utcnow() - if dateTimeObject.resolution == datetime.timedelta(days=1): - # Add time information corresponding to midnight UTC if it's a datetime.date - dateTimeObject = datetime.datetime.combine( - dateTimeObject, datetime.time.min.replace(tzinfo=datetime.timezone.utc) - ) - posixTime = dateTimeObject.replace(tzinfo=datetime.timezone.utc).timestamp() - return int(posixTime * 1000) - - -def fromEpoch(epoch): - """ - Get datetime object from epoch - """ - # Check if the timestamp is in milliseconds - if epoch > 10**17: # nanoseconds - epoch /= 1000**3 - elif epoch > 10**14: # microseconds - epoch /= 1000**2 - elif epoch > 10**11: # milliseconds - epoch /= 1000 - return datetime.datetime.utcfromtimestamp(epoch) - - -def toString(myDate=None): - """ - Convert to String - if argument type is neither _dateTimeType, _dateType, nor _timeType - the current dateTime converted to String is returned instead - - Notice: datetime.timedelta are converted to strings using the format: - [day] days [hour]:[min]:[sec]:[microsec] - where hour, min, sec, microsec are always positive integers, - and day carries the sign. - To keep internal consistency we are using: - [hour]:[min]:[sec]:[microsec] - where min, sec, microsec are always positive integers and hour carries the sign. 
- """ - if isinstance(myDate, datetime.date): - return str(myDate) - - elif isinstance(myDate, datetime.time): - return "%02d:%02d:%02d.%06d" % ( - myDate.days * 24 + myDate.seconds / 3600, - myDate.seconds % 3600 / 60, - myDate.seconds % 60, - myDate.microseconds, - ) - else: - return toString(datetime.datetime.utcnow()) - - -def fromString(myDate=None): - """ - Convert date/time/datetime String back to appropriated objects - - The format of the string it is assume to be that returned by toString method. - See notice on toString method - On Error, return None - - :param myDate: the date string to be converted - :type myDate: str or datetime.datetime - """ - if isinstance(myDate, datetime.datetime): - return myDate - if isinstance(myDate, str): - if myDate.find(" ") > 0: - dateTimeTuple = myDate.split(" ") - dateTuple = dateTimeTuple[0].split("-") - try: - return datetime.datetime(year=dateTuple[0], month=dateTuple[1], day=dateTuple[2]) + fromString( - dateTimeTuple[1] - ) - # return datetime.datetime.utcnow().combine( fromString( dateTimeTuple[0] ), - # fromString( dateTimeTuple[1] ) ) - except Exception: - try: - return datetime.datetime( - year=int(dateTuple[0]), month=int(dateTuple[1]), day=int(dateTuple[2]) - ) + fromString(dateTimeTuple[1]) - except ValueError: - return None - # return datetime.datetime.utcnow().combine( fromString( dateTimeTuple[0] ), - # fromString( dateTimeTuple[1] ) ) - elif myDate.find(":") > 0: - timeTuple = myDate.replace(".", ":").split(":") - try: - if len(timeTuple) == 4: - return datetime.timedelta( - hours=int(timeTuple[0]), - minutes=int(timeTuple[1]), - seconds=int(timeTuple[2]), - microseconds=int(timeTuple[3]), - ) - elif len(timeTuple) == 3: - try: - return datetime.timedelta( - hours=int(timeTuple[0]), - minutes=int(timeTuple[1]), - seconds=int(timeTuple[2]), - microseconds=0, - ) - except ValueError: - return None - else: - return None - except Exception: - return None - elif myDate.find("-") > 0: - dateTuple = 
myDate.split("-") - try: - return datetime.date(int(dateTuple[0]), int(dateTuple[1]), int(dateTuple[2])) - except Exception: - return None - - return None - - -class timeInterval: - """ - Simple class to define a timeInterval object able to check if a given - dateTime is inside - """ - - def __init__(self, initialDateTime, intervalTimeDelta): - """ - Initialization method, it requires the initial dateTime and the - timedelta that define the limits. - The upper limit is not included thus it is [begin,end) - If not properly initialized an error flag is set, and subsequent calls - to any method will return None - """ - if not isinstance(initialDateTime, datetime.datetime) or not isinstance(intervalTimeDelta, datetime.timedelta): - self.__error = True - return None - self.__error = False - if intervalTimeDelta.days < 0: - self.__startDateTime = initialDateTime + intervalTimeDelta - self.__endDateTime = initialDateTime - else: - self.__startDateTime = initialDateTime - self.__endDateTime = initialDateTime + intervalTimeDelta - - def includes(self, myDateTime): - """ """ - if self.__error: - return None - if not isinstance(myDateTime, datetime.datetime): - return None - if myDateTime < self.__startDateTime: - return False - if myDateTime >= self.__endDateTime: - return False - return True - - -def queryTime(f): - """Decorator to measure the function call time""" +from DIRAC import gLogger - def measureQueryTime(*args, **kwargs): - start = time.time() - result = f(*args, **kwargs) - if result["OK"] and "QueryTime" not in result: - result["QueryTime"] = time.time() - start - return result - return measureQueryTime +timeThis = partial(timeThis, logger_info=gLogger.info) diff --git a/src/DIRAC/Core/Utilities/test/Test_JDL.py b/src/DIRAC/Core/Utilities/test/Test_JDL.py index b918bf3bb2e..22cfd1ba69e 100644 --- a/src/DIRAC/Core/Utilities/test/Test_JDL.py +++ b/src/DIRAC/Core/Utilities/test/Test_JDL.py @@ -19,9 +19,6 @@ def jdl_monkey_business(monkeypatch): 
monkeypatch.setattr("DIRAC.Core.Base.API.getSites", lambda: S_OK(["LCG.IN2P3.fr"])) monkeypatch.setattr("DIRAC.WorkloadManagementSystem.Utilities.JobModel.getSites", lambda: S_OK(["LCG.IN2P3.fr"])) monkeypatch.setattr("DIRAC.Interfaces.API.Job.getDIRACPlatforms", lambda: S_OK("x86_64-slc6-gcc49-opt")) - monkeypatch.setattr( - "DIRAC.WorkloadManagementSystem.Utilities.JobModel.getDIRACPlatforms", lambda: S_OK("x86_64-slc6-gcc49-opt") - ) yield diff --git a/src/DIRAC/WorkloadManagementSystem/Client/JobState/JobManifest.py b/src/DIRAC/WorkloadManagementSystem/Client/JobState/JobManifest.py index 882d30658f6..6f05f3633f7 100644 --- a/src/DIRAC/WorkloadManagementSystem/Client/JobState/JobManifest.py +++ b/src/DIRAC/WorkloadManagementSystem/Client/JobState/JobManifest.py @@ -1,266 +1,37 @@ -from diraccfg import CFG +from __future__ import annotations -from DIRAC import S_OK, S_ERROR -from DIRAC.Core.Utilities import List -from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations -from DIRAC.Core.Utilities.JDL import loadJDLAsCFG, dumpCFGAsJDL - - -class JobManifest: - def __init__(self, manifest=""): - self.__manifest = CFG() - self.__dirty = False - self.__ops = False - if manifest: - result = self.load(manifest) - if not result["OK"]: - raise Exception(result["Message"]) - - def isDirty(self): - return self.__dirty - - def setDirty(self): - self.__dirty = True - - def clearDirty(self): - self.__dirty = False - - def load(self, dataString): - """ - Auto discover format type based on [ .. 
] of JDL - """ - dataString = dataString.strip() - if dataString[0] == "[" and dataString[-1] == "]": - return self.loadJDL(dataString) - else: - return self.loadCFG(dataString) - - def loadJDL(self, jdlString): - """ - Load job manifest from JDL format - """ - result = loadJDLAsCFG(jdlString.strip()) - if not result["OK"]: - self.__manifest = CFG() - return result - self.__manifest = result["Value"][0] - return S_OK() - - def loadCFG(self, cfgString): - """ - Load job manifest from CFG format - """ - try: - self.__manifest.loadFromBuffer(cfgString) - except Exception as e: - return S_ERROR(f"Can't load manifest from cfg: {str(e)}") - return S_OK() - - def dumpAsCFG(self): - return str(self.__manifest) - - def getAsCFG(self): - return self.__manifest.clone() - - def dumpAsJDL(self): - return dumpCFGAsJDL(self.__manifest) - - def __getCSValue(self, varName, defaultVal=None): - if not self.__ops: - self.__ops = Operations(group=self.__manifest["OwnerGroup"]) - if varName[0] != "/": - varName = f"JobDescription/{varName}" - return self.__ops.getValue(varName, defaultVal) - - def __checkNumericalVar(self, varName, defaultVal, minVal, maxVal): - """ - Check a numerical var - """ - initialVal = False - if varName not in self.__manifest: - varValue = self.__getCSValue(f"Default{varName}", defaultVal) - else: - varValue = self.__manifest[varName] - initialVal = varValue - try: - varValue = int(varValue) - except ValueError: - return S_ERROR(f"{varName} must be a number") - minVal = self.__getCSValue(f"Min{varName}", minVal) - maxVal = self.__getCSValue(f"Max{varName}", maxVal) - varValue = max(minVal, min(varValue, maxVal)) - if initialVal != varValue: - self.__manifest.setOption(varName, varValue) - return S_OK(varValue) - - def __checkChoiceVar(self, varName, defaultVal, choices): - """ - Check a choice var - """ - initialVal = False - if varName not in self.__manifest: - varValue = self.__getCSValue(f"Default{varName}", defaultVal) - else: - varValue = 
self.__manifest[varName] - initialVal = varValue - if varValue not in self.__getCSValue(f"Choices{varName}", choices): - return S_ERROR(f"{varValue} is not a valid value for {varName}") - if initialVal != varValue: - self.__manifest.setOption(varName, varValue) - return S_OK(varValue) - - def __checkMultiChoice(self, varName, choices): - """ - Check a multi choice var - """ - initialVal = False - if varName not in self.__manifest: - return S_OK() - else: - varValue = self.__manifest[varName] - initialVal = varValue - choices = self.__getCSValue(f"Choices{varName}", choices) - for v in List.fromChar(varValue): - if v not in choices: - return S_ERROR(f"{v} is not a valid value for {varName}") - if initialVal != varValue: - self.__manifest.setOption(varName, varValue) - return S_OK(varValue) +from DIRACCommon.WorkloadManagementSystem.Client.JobState.JobManifest import * # noqa: F401, F403 - def __checkMaxInputData(self, maxNumber): - """ - Check Maximum Number of Input Data files allowed - """ - varName = "InputData" - if varName not in self.__manifest: - return S_OK() - varValue = self.__manifest[varName] - if len(List.fromChar(varValue)) > maxNumber: - return S_ERROR( - "Number of Input Data Files (%s) greater than current limit: %s" - % (len(List.fromChar(varValue)), maxNumber) - ) - return S_OK() - - def __contains__(self, key): - """Check if the manifest has the required key""" - return key in self.__manifest +from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations - def setOptionsFromDict(self, varDict): - for k in sorted(varDict): - self.setOption(k, varDict[k]) +def makeJobManifestConfig(ownerGroup: str) -> JobManifestConfig: + ops = Operations(group=ownerGroup) + + allowedJobTypesForGroup = ops.getValue( + "JobDescription/ChoicesJobType", + ops.getValue("JobDescription/AllowedJobTypes", ["User", "Test", "Hospital"]) + + ops.getValue("Transformations/DataProcessing", []), + ) + + return { + "defaultForGroup": { + "CPUTime": 
ops.getValue("JobDescription/DefaultCPUTime", 86400), + "Priority": ops.getValue("JobDescription/DefaultPriority", 1), + }, + "minForGroup": { + "CPUTime": ops.getValue("JobDescription/MinCPUTime", 100), + "Priority": ops.getValue("JobDescription/MinPriority", 0), + }, + "maxForGroup": { + "CPUTime": ops.getValue("JobDescription/MaxCPUTime", 500000), + "Priority": ops.getValue("JobDescription/MaxPriority", 10), + }, + "allowedJobTypesForGroup": allowedJobTypesForGroup, + "maxInputData": Operations().getValue("JobDescription/MaxInputData", 500), + } + + +class JobManifest(JobManifest): # noqa: F405 pylint: disable=function-redefined def check(self): - """ - Check that the manifest is OK - """ - for k in ["Owner", "OwnerGroup"]: - if k not in self.__manifest: - return S_ERROR(f"Missing var {k} in manifest") - - # Check CPUTime - result = self.__checkNumericalVar("CPUTime", 86400, 100, 500000) - if not result["OK"]: - return result - - result = self.__checkNumericalVar("Priority", 1, 0, 10) - if not result["OK"]: - return result - - maxInputData = Operations().getValue("JobDescription/MaxInputData", 500) - result = self.__checkMaxInputData(maxInputData) - if not result["OK"]: - return result - - operation = Operations(group=self.__manifest["OwnerGroup"]) - allowedJobTypes = operation.getValue("JobDescription/AllowedJobTypes", ["User", "Test", "Hospital"]) - transformationTypes = operation.getValue("Transformations/DataProcessing", []) - result = self.__checkMultiChoice("JobType", allowedJobTypes + transformationTypes) - if not result["OK"]: - return result - return S_OK() - - def createSection(self, secName, contents=False): - if secName not in self.__manifest: - if contents and not isinstance(contents, CFG): - return S_ERROR(f"Contents for section {secName} is not a cfg object") - self.__dirty = True - return S_OK(self.__manifest.createNewSection(secName, contents=contents)) - return S_ERROR(f"Section {secName} already exists") - - def getSection(self, secName): - 
self.__dirty = True - if secName not in self.__manifest: - return S_ERROR(f"{secName} does not exist") - sec = self.__manifest[secName] - if not sec: - return S_ERROR(f"{secName} section empty") - return S_OK(sec) - - def setSectionContents(self, secName, contents): - if contents and not isinstance(contents, CFG): - return S_ERROR(f"Contents for section {secName} is not a cfg object") - self.__dirty = True - if secName in self.__manifest: - self.__manifest[secName].reset() - self.__manifest[secName].mergeWith(contents) - else: - self.__manifest.createNewSection(secName, contents=contents) - - def setOption(self, varName, varValue): - """ - Set a var in job manifest - """ - self.__dirty = True - levels = List.fromChar(varName, "/") - cfg = self.__manifest - for l in levels[:-1]: - if l not in cfg: - cfg.createNewSection(l) - cfg = cfg[l] - cfg.setOption(levels[-1], varValue) - - def remove(self, opName): - levels = List.fromChar(opName, "/") - cfg = self.__manifest - for l in levels[:-1]: - if l not in cfg: - return S_ERROR(f"{opName} does not exist") - cfg = cfg[l] - if cfg.deleteKey(levels[-1]): - self.__dirty = True - return S_OK() - return S_ERROR(f"{opName} does not exist") - - def getOption(self, varName, defaultValue=None): - """ - Get a variable from the job manifest - """ - cfg = self.__manifest - return cfg.getOption(varName, defaultValue) - - def getOptionList(self, section=""): - """ - Get a list of variables in a section of the job manifest - """ - cfg = self.__manifest.getRecursive(section) - if not cfg or "value" not in cfg: - return [] - cfg = cfg["value"] - return cfg.listOptions() - - def isOption(self, opName): - """ - Check if it is a valid option - """ - return self.__manifest.isOption(opName) - - def getSectionList(self, section=""): - """ - Get a list of sections in the job manifest - """ - cfg = self.__manifest.getRecursive(section) - if not cfg or "value" not in cfg: - return [] - cfg = cfg["value"] - return cfg.listSections() + return 
super().check(config=makeJobManifestConfig(self.__manifest["OwnerGroup"])) diff --git a/src/DIRAC/WorkloadManagementSystem/Client/JobStatus.py b/src/DIRAC/WorkloadManagementSystem/Client/JobStatus.py index a259945e694..5c77b3fc5db 100644 --- a/src/DIRAC/WorkloadManagementSystem/Client/JobStatus.py +++ b/src/DIRAC/WorkloadManagementSystem/Client/JobStatus.py @@ -1,95 +1,10 @@ -""" -This module contains constants and lists for the possible job states. -""" - -from DIRAC.Core.Utilities.StateMachine import State, StateMachine - -#: -SUBMITTING = "Submitting" -#: -RECEIVED = "Received" -#: -CHECKING = "Checking" -#: -STAGING = "Staging" -#: -SCOUTING = "Scouting" -#: -WAITING = "Waiting" -#: -MATCHED = "Matched" -#: The Rescheduled status is effectively never stored in the DB. -#: It could be considered a "virtual" status, and might even be dropped. -RESCHEDULED = "Rescheduled" -#: -RUNNING = "Running" -#: -STALLED = "Stalled" -#: -COMPLETING = "Completing" -#: -DONE = "Done" -#: -COMPLETED = "Completed" -#: -FAILED = "Failed" -#: -DELETED = "Deleted" -#: -KILLED = "Killed" - -#: Possible job states -JOB_STATES = [ - SUBMITTING, - RECEIVED, - CHECKING, - SCOUTING, - STAGING, - WAITING, - MATCHED, - RESCHEDULED, - RUNNING, - STALLED, - COMPLETING, - DONE, - COMPLETED, - FAILED, - DELETED, - KILLED, -] +"""Backward compatibility wrapper - moved to DIRACCommon -# Job States when the payload work has finished -JOB_FINAL_STATES = [DONE, COMPLETED, FAILED, KILLED] +This module has been moved to DIRACCommon.WorkloadManagementSystem.Client.JobStatus to avoid +circular dependencies and allow DiracX to use these utilities without +triggering DIRAC's global state initialization. 
-# WMS internal job States indicating the job object won't be updated -JOB_REALLY_FINAL_STATES = [DELETED] - - -class JobsStateMachine(StateMachine): - """Jobs state machine""" - - def __init__(self, state): - """c'tor - Defines the state machine transactions - """ - super().__init__(state) - - # States transitions - self.states = { - DELETED: State(15), # final state - KILLED: State(14, [DELETED], defState=KILLED), - FAILED: State(13, [RESCHEDULED, DELETED], defState=FAILED), - DONE: State(12, [DELETED], defState=DONE), - COMPLETED: State(11, [DONE, FAILED], defState=COMPLETED), - COMPLETING: State(10, [DONE, FAILED, COMPLETED, STALLED, KILLED], defState=COMPLETING), - STALLED: State(9, [RUNNING, FAILED, KILLED], defState=STALLED), - RUNNING: State(8, [STALLED, DONE, FAILED, RESCHEDULED, COMPLETING, KILLED, RECEIVED], defState=RUNNING), - RESCHEDULED: State(7, [WAITING, RECEIVED, DELETED, FAILED, KILLED], defState=RESCHEDULED), - MATCHED: State(6, [RUNNING, FAILED, RESCHEDULED, KILLED], defState=MATCHED), - WAITING: State(5, [MATCHED, RESCHEDULED, DELETED, KILLED], defState=WAITING), - STAGING: State(4, [CHECKING, WAITING, FAILED, KILLED], defState=STAGING), - SCOUTING: State(3, [CHECKING, FAILED, STALLED, KILLED], defState=SCOUTING), - CHECKING: State(2, [SCOUTING, STAGING, WAITING, RESCHEDULED, FAILED, DELETED, KILLED], defState=CHECKING), - RECEIVED: State(1, [SCOUTING, CHECKING, STAGING, WAITING, FAILED, DELETED, KILLED], defState=RECEIVED), - SUBMITTING: State(0, [RECEIVED, CHECKING, DELETED, KILLED], defState=SUBMITTING), # initial state - } +All exports are maintained for backward compatibility. 
+""" +# Re-export everything from DIRACCommon for backward compatibility +from DIRACCommon.WorkloadManagementSystem.Client.JobStatus import * # noqa: F401, F403 diff --git a/src/DIRAC/WorkloadManagementSystem/DB/JobDBUtils.py b/src/DIRAC/WorkloadManagementSystem/DB/JobDBUtils.py index f96c34cc492..b022d4b8c09 100644 --- a/src/DIRAC/WorkloadManagementSystem/DB/JobDBUtils.py +++ b/src/DIRAC/WorkloadManagementSystem/DB/JobDBUtils.py @@ -1,139 +1,33 @@ from __future__ import annotations -import base64 -import zlib +# Import stateless functions from DIRACCommon for backward compatibility +from DIRACCommon.WorkloadManagementSystem.DB.JobDBUtils import * from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations -from DIRAC.Core.Utilities.DErrno import EWMSSUBM from DIRAC.Core.Utilities.ObjectLoader import ObjectLoader -from DIRAC.Core.Utilities.ReturnValues import S_ERROR, S_OK, returnValueOrRaise -from DIRAC.WorkloadManagementSystem.Client import JobStatus -from DIRAC.WorkloadManagementSystem.Client.JobState.JobManifest import JobManifest - -# Import stateless functions from DIRACCommon for backward compatibility -from DIRACCommon.WorkloadManagementSystem.DB.JobDBUtils import compressJDL, extractJDL, fixJDL +from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise +from DIRAC.WorkloadManagementSystem.Client.JobState.JobManifest import makeJobManifestConfig getDIRACPlatform = returnValueOrRaise( ObjectLoader().loadObject("ConfigurationSystem.Client.Helpers.Resources", "getDIRACPlatform") ) -def checkAndAddOwner(jdl: str, owner: str, ownerGroup: str) -> JobManifest: - jobManifest = JobManifest() - res = jobManifest.load(jdl) - if not res["OK"]: - return res - - jobManifest.setOptionsFromDict({"Owner": owner, "OwnerGroup": ownerGroup}) - res = jobManifest.check() - if not res["OK"]: - return res - - return S_OK(jobManifest) - - -def checkAndPrepareJob(jobID, classAdJob, classAdReq, owner, ownerGroup, jobAttrs, vo): - error = "" - - jdlOwner = 
classAdJob.getAttributeString("Owner") - jdlOwnerGroup = classAdJob.getAttributeString("OwnerGroup") - jdlVO = classAdJob.getAttributeString("VirtualOrganization") - - # The below is commented out since this is always overwritten by the submitter IDs - # but the check allows to findout inconsistent client environments - if jdlOwner and jdlOwner != owner: - error = "Wrong Owner in JDL" - elif jdlOwnerGroup and jdlOwnerGroup != ownerGroup: - error = "Wrong Owner Group in JDL" - elif jdlVO and jdlVO != vo: - error = "Wrong Virtual Organization in JDL" - - classAdJob.insertAttributeString("Owner", owner) - classAdJob.insertAttributeString("OwnerGroup", ownerGroup) - - if vo: - classAdJob.insertAttributeString("VirtualOrganization", vo) - - classAdReq.insertAttributeString("Owner", owner) - classAdReq.insertAttributeString("OwnerGroup", ownerGroup) - if vo: - classAdReq.insertAttributeString("VirtualOrganization", vo) - - inputDataPolicy = Operations(vo=vo).getValue("InputDataPolicy/InputDataModule") - if inputDataPolicy and not classAdJob.lookupAttribute("InputDataModule"): - classAdJob.insertAttributeString("InputDataModule", inputDataPolicy) - - softwareDistModule = Operations(vo=vo).getValue("SoftwareDistModule") - if softwareDistModule and not classAdJob.lookupAttribute("SoftwareDistModule"): - classAdJob.insertAttributeString("SoftwareDistModule", softwareDistModule) - - # priority - priority = classAdJob.getAttributeInt("Priority") - if priority is None: - priority = 0 - classAdReq.insertAttributeInt("UserPriority", priority) - - # CPU time - cpuTime = classAdJob.getAttributeInt("CPUTime") - if cpuTime is None: - opsHelper = Operations(group=ownerGroup) - cpuTime = opsHelper.getValue("JobDescription/DefaultCPUTime", 86400) - classAdReq.insertAttributeInt("CPUTime", cpuTime) - - # platform(s) - platformList = classAdJob.getListFromExpression("Platform") - if platformList: - result = getDIRACPlatform(platformList) - if not result["OK"]: - return result - if 
result["Value"]: - classAdReq.insertAttributeVectorString("Platforms", result["Value"]) - else: - error = "OS compatibility info not found" - if error: - retVal = S_ERROR(EWMSSUBM, error) - retVal["JobId"] = jobID - retVal["Status"] = JobStatus.FAILED - retVal["MinorStatus"] = error - - jobAttrs["Status"] = JobStatus.FAILED - - jobAttrs["MinorStatus"] = error - return retVal - return S_OK() - - -def createJDLWithInitialStatus( - classAdJob, classAdReq, jdl2DBParameters, jobAttrs, initialStatus, initialMinorStatus, *, modern=False -): - """ - :param modern: if True, store boolean instead of string for VerifiedFlag (used by diracx only) - """ - priority = classAdJob.getAttributeInt("Priority") - if priority is None: - priority = 0 - jobAttrs["UserPriority"] = priority - - for jdlName in jdl2DBParameters: - # Defaults are set by the DB. - jdlValue = classAdJob.getAttributeString(jdlName) - if jdlValue: - jobAttrs[jdlName] = jdlValue - - jdlValue = classAdJob.getAttributeString("Site") - if jdlValue: - if jdlValue.find(",") != -1: - jobAttrs["Site"] = "Multiple" - else: - jobAttrs["Site"] = jdlValue - - jobAttrs["VerifiedFlag"] = True if modern else "True" +def checkAndPrepareJob( + jobID, classAdJob, classAdReq, owner, ownerGroup, jobAttrs, vo +): # pylint: disable=function-redefined + from DIRACCommon.WorkloadManagementSystem.DB.JobDBUtils import checkAndPrepareJob - jobAttrs["Status"] = initialStatus + config = { + "inputDataPolicyForVO": Operations(vo=vo).getValue("InputDataPolicy/InputDataModule"), + "softwareDistModuleForVO": Operations(vo=vo).getValue("SoftwareDistModule"), + "defaultCPUTimeForOwnerGroup": Operations(group=ownerGroup).getValue("JobDescription/DefaultCPUTime", 86400), + "getDIRACPlatform": getDIRACPlatform, + } + return checkAndPrepareJob(jobID, classAdJob, classAdReq, owner, ownerGroup, jobAttrs, vo, config=config) - jobAttrs["MinorStatus"] = initialMinorStatus - reqJDL = classAdReq.asJDL() - classAdJob.insertAttributeInt("JobRequirements", 
reqJDL) +def checkAndAddOwner(jdl: str, owner: str, ownerGroup: str): # pylint: disable=function-redefined + from DIRACCommon.WorkloadManagementSystem.DB.JobDBUtils import checkAndAddOwner - return classAdJob.asJDL() + return checkAndAddOwner(jdl, owner, ownerGroup, job_manifest_config=makeJobManifestConfig(ownerGroup)) diff --git a/src/DIRAC/WorkloadManagementSystem/Utilities/JobModel.py b/src/DIRAC/WorkloadManagementSystem/Utilities/JobModel.py index 2d27aedf174..9b3afdc1d12 100644 --- a/src/DIRAC/WorkloadManagementSystem/Utilities/JobModel.py +++ b/src/DIRAC/WorkloadManagementSystem/Utilities/JobModel.py @@ -1,209 +1,38 @@ -""" This module contains the JobModel class, which is used to validate the job description """ +from __future__ import annotations -# pylint: disable=no-self-argument, no-self-use, invalid-name, missing-function-docstring - -from collections.abc import Iterable -from typing import Any, Annotated, TypeAlias, Self - -from pydantic import BaseModel, BeforeValidator, model_validator, field_validator, ConfigDict +from typing import ClassVar +from pydantic import PrivateAttr +from DIRACCommon.WorkloadManagementSystem.Utilities.JobModel import * # noqa: F401, F403 from DIRAC import gLogger -from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations -from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getDIRACPlatforms, getSites - - -# HACK: Convert appropriate iterables into sets -def default_set_validator(value): - if value is None: - return set() - elif not isinstance(value, Iterable): - return value - elif isinstance(value, (str, bytes, bytearray)): - return value - else: - return set(value) - - -CoercibleSetStr: TypeAlias = Annotated[set[str], BeforeValidator(default_set_validator)] - - -class BaseJobDescriptionModel(BaseModel): - """Base model for the job description (not parametric)""" - - model_config = ConfigDict(validate_assignment=True) - - arguments: str = "" - bannedSites: CoercibleSetStr = set() - # TODO: This 
should use a field factory - cpuTime: int = Operations().getValue("JobDescription/DefaultCPUTime", 86400) - executable: str - executionEnvironment: dict = None - gridCE: str = "" - inputSandbox: CoercibleSetStr = set() - inputData: CoercibleSetStr = set() - inputDataPolicy: str = "" - jobConfigArgs: str = "" - jobGroup: str = "" - jobType: str = "User" - jobName: str = "Name" - # TODO: This should be an StrEnum - logLevel: str = "INFO" - # TODO: This can't be None with this type hint - maxNumberOfProcessors: int = None - minNumberOfProcessors: int = 1 - outputData: CoercibleSetStr = set() - outputPath: str = "" - outputSandbox: CoercibleSetStr = set() - outputSE: str = "" - platform: str = "" - # TODO: This should use a field factory - priority: int = Operations().getValue("JobDescription/DefaultPriority", 1) - sites: CoercibleSetStr = set() - stderr: str = "std.err" - stdout: str = "std.out" - tags: CoercibleSetStr = set() - extraFields: dict[str, Any] = {} - - @field_validator("cpuTime") - def checkCPUTimeBounds(cls, v): - minCPUTime = Operations().getValue("JobDescription/MinCPUTime", 100) - maxCPUTime = Operations().getValue("JobDescription/MaxCPUTime", 500000) - if not minCPUTime <= v <= maxCPUTime: - raise ValueError(f"cpuTime out of bounds (must be between {minCPUTime} and {maxCPUTime})") - return v - - @field_validator("executable") - def checkExecutableIsNotAnEmptyString(cls, v: str): - if not v: - raise ValueError("executable must not be an empty string") - return v - - @field_validator("jobType") - def checkJobTypeIsAllowed(cls, v: str): - jobTypes = Operations().getValue("JobDescription/AllowedJobTypes", ["User", "Test", "Hospital"]) - transformationTypes = Operations().getValue("Transformations/DataProcessing", []) - allowedTypes = jobTypes + transformationTypes - if v not in allowedTypes: - raise ValueError(f"jobType '{v}' is not allowed for this kind of user (must be in {allowedTypes})") - return v - - @field_validator("inputData") - def 
checkInputDataDoesntContainDoubleSlashes(cls, v): - if v: - for lfn in v: - if lfn.find("//") > -1: - raise ValueError("Input data contains //") - return v - - @field_validator("inputData") - def addLFNPrefixIfStringStartsWithASlash(cls, v: set[str]): - if v: - v = {lfn.strip() for lfn in v if lfn.strip()} - v = {f"LFN:{lfn}" if lfn.startswith("/") else lfn for lfn in v} - - for lfn in v: - if not lfn.startswith("LFN:/"): - raise ValueError("Input data files must start with LFN:/") - return v - - @model_validator(mode="after") - def checkNumberOfInputDataFiles(self) -> Self: - if self.inputData: - maxInputDataFiles = Operations().getValue("JobDescription/MaxInputData", 500) - if self.jobType == "User" and len(self.inputData) >= maxInputDataFiles: - raise ValueError(f"inputData contains too many files (must contain at most {maxInputDataFiles})") - return self - - @field_validator("inputSandbox") - def checkLFNSandboxesAreWellFormated(cls, v: set[str]): - for inputSandbox in v: - if inputSandbox.startswith("LFN:") and not inputSandbox.startswith("LFN:/"): - raise ValueError("LFN files must start by LFN:/") - return v - - @field_validator("logLevel") - def checkLogLevelIsValid(cls, v: str): - v = v.upper() - possibleLogLevels = gLogger.getAllPossibleLevels() - if v not in possibleLogLevels: - raise ValueError(f"Log level {v} not in {possibleLogLevels}") - return v - - @field_validator("minNumberOfProcessors") - def checkMinNumberOfProcessorsBounds(cls, v): - minNumberOfProcessors = Operations().getValue("JobDescription/MinNumberOfProcessors", 1) - maxNumberOfProcessors = Operations().getValue("JobDescription/MaxNumberOfProcessors", 1024) - if not minNumberOfProcessors <= v <= maxNumberOfProcessors: - raise ValueError( - f"minNumberOfProcessors out of bounds (must be between {minNumberOfProcessors} and {maxNumberOfProcessors})" - ) - return v - - @field_validator("maxNumberOfProcessors") - def checkMaxNumberOfProcessorsBounds(cls, v): - minNumberOfProcessors = 
Operations().getValue("JobDescription/MinNumberOfProcessors", 1) - maxNumberOfProcessors = Operations().getValue("JobDescription/MaxNumberOfProcessors", 1024) - if not minNumberOfProcessors <= v <= maxNumberOfProcessors: - raise ValueError( - f"minNumberOfProcessors out of bounds (must be between {minNumberOfProcessors} and {maxNumberOfProcessors})" - ) - return v - - @model_validator(mode="after") - def checkThatMaxNumberOfProcessorsIsGreaterThanMinNumberOfProcessors(self) -> Self: - if self.maxNumberOfProcessors: - if self.maxNumberOfProcessors < self.minNumberOfProcessors: - raise ValueError("maxNumberOfProcessors must be greater than minNumberOfProcessors") - return self - - @model_validator(mode="after") - def addTagsDependingOnNumberOfProcessors(self) -> Self: - if self.minNumberOfProcessors == self.maxNumberOfProcessors: - self.tags.add(f"{self.minNumberOfProcessors}Processors") - if self.minNumberOfProcessors > 1: - self.tags.add("MultiProcessor") - return self +from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getSites - @field_validator("sites") - def checkSites(cls, v: set[str]): - if v: - res = getSites() - if not res["OK"]: - raise ValueError(res["Message"]) - invalidSites = v - set(res["Value"]).union({"ANY"}) - if invalidSites: - raise ValueError(f"Invalid sites: {' '.join(invalidSites)}") - return v - @model_validator(mode="after") - def checkThatSitesAndBannedSitesAreNotMutuallyExclusive(self) -> Self: - if self.sites and self.bannedSites: - while self.bannedSites: - self.sites.discard(self.bannedSites.pop()) - if not self.sites: - raise ValueError("sites and bannedSites are mutually exclusive") - return self +def _make_model_config(cls=None) -> BaseJobDescriptionModelConfig: + from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations - @field_validator("priority") - def checkPriorityBounds(cls, v): - minPriority = Operations().getValue("JobDescription/MinPriority", 0) - maxPriority = 
Operations().getValue("JobDescription/MaxPriority", 10) - if not minPriority <= v <= maxPriority: - raise ValueError(f"priority out of bounds (must be between {minPriority} and {maxPriority})") - return v + ops = Operations() + allowedJobTypes = ops.getValue("JobDescription/AllowedJobTypes", ["User", "Test", "Hospital"]) + allowedJobTypes += ops.getValue("Transformations/DataProcessing", []) + return { + "cpuTime": ops.getValue("JobDescription/DefaultCPUTime", 86400), + "priority": ops.getValue("JobDescription/DefaultPriority", 1), + "minCPUTime": ops.getValue("JobDescription/MinCPUTime", 100), + "maxCPUTime": ops.getValue("JobDescription/MaxCPUTime", 500000), + "allowedJobTypes": allowedJobTypes, + "maxInputDataFiles": ops.getValue("JobDescription/MaxInputData", 500), + "minNumberOfProcessors": ops.getValue("JobDescription/MinNumberOfProcessors", 1), + "maxNumberOfProcessors": ops.getValue("JobDescription/MaxNumberOfProcessors", 1024), + "minPriority": ops.getValue("JobDescription/MinPriority", 0), + "maxPriority": ops.getValue("JobDescription/MaxPriority", 10), + "possibleLogLevels": gLogger.getAllPossibleLevels(), + "sites": getSites(), + } -class JobDescriptionModel(BaseJobDescriptionModel): - """Model for the job description (non parametric job with user credentials, i.e server side)""" +class BaseJobDescriptionModel(BaseJobDescriptionModel): # noqa: F405 pylint: disable=function-redefined + _config_builder: ClassVar = _make_model_config - owner: str - ownerGroup: str - vo: str - @model_validator(mode="after") - def checkLFNMatchesREGEX(self) -> Self: - if self.inputData: - for lfn in self.inputData: - if not lfn.startswith(f"LFN:/{self.vo}/"): - raise ValueError(f"Input data not correctly specified (must start with LFN:/{self.vo}/)") - return self +class JobDescriptionModel(JobDescriptionModel): # noqa: F405 pylint: disable=function-redefined + _config_builder: ClassVar = _make_model_config diff --git 
a/src/DIRAC/WorkloadManagementSystem/Utilities/JobStatusUtility.py b/src/DIRAC/WorkloadManagementSystem/Utilities/JobStatusUtility.py index 5bdc81014f6..ffdb633a2d3 100644 --- a/src/DIRAC/WorkloadManagementSystem/Utilities/JobStatusUtility.py +++ b/src/DIRAC/WorkloadManagementSystem/Utilities/JobStatusUtility.py @@ -9,6 +9,7 @@ from DIRAC.Core.Utilities import TimeUtilities from DIRAC.Core.Utilities.ObjectLoader import ObjectLoader from DIRAC.WorkloadManagementSystem.Client import JobStatus +from DIRACCommon.WorkloadManagementSystem.Utilities.JobStatusUtility import getStartAndEndTime, getNewStatus if TYPE_CHECKING: from DIRAC.WorkloadManagementSystem.DB.JobLoggingDB import JobLoggingDB @@ -180,66 +181,3 @@ def setJobStatusBulk(self, jobID: int, statusDict: dict, force: bool = False): return result return S_OK((attrNames, attrValues)) - - -def getStartAndEndTime(startTime, endTime, updateTimes, timeStamps, statusDict): - newStat = "" - firstUpdate = TimeUtilities.toEpoch(TimeUtilities.fromString(updateTimes[0])) - for ts, st in timeStamps: - if firstUpdate >= ts: - newStat = st - # Pick up start and end times from all updates - for updTime in updateTimes: - sDict = statusDict[updTime] - newStat = sDict.get("Status", newStat) - - if not startTime and newStat == JobStatus.RUNNING: - # Pick up the start date when the job starts running if not existing - startTime = updTime - elif not endTime and newStat in JobStatus.JOB_FINAL_STATES: - # Pick up the end time when the job is in a final status - endTime = updTime - - return startTime, endTime - - -def getNewStatus( - jobID: int, - updateTimes: list[datetime], - lastTime: datetime, - statusDict: dict[datetime, Any], - currentStatus, - force: bool, - log, -): - status = "" - minor = "" - application = "" - # Get the last status values looping on the most recent upupdateTimes in chronological order - for updTime in [dt for dt in updateTimes if dt >= lastTime]: - sDict = statusDict[updTime] - log.debug(f"\tTime {updTime} - 
Statuses {str(sDict)}") - status = sDict.get("Status", currentStatus) - # evaluate the state machine if the status is changing - if not force and status != currentStatus: - res = JobStatus.JobsStateMachine(currentStatus).getNextState(status) - if not res["OK"]: - return res - newStat = res["Value"] - # If the JobsStateMachine does not accept the candidate, don't update - if newStat != status: - # keeping the same status - log.error( - f"Job Status Error: {jobID} can't move from {currentStatus} to {status}: using {newStat}", - ) - status = newStat - sDict["Status"] = newStat - # Change the source to indicate this is not what was requested - source = sDict.get("Source", "") - sDict["Source"] = source + "(SM)" - # at this stage status == newStat. Set currentStatus to this new status - currentStatus = newStat - - minor = sDict.get("MinorStatus", minor) - application = sDict.get("ApplicationStatus", application) - return S_OK((status, minor, application)) diff --git a/src/DIRAC/WorkloadManagementSystem/Utilities/test/Test_JobModel.py b/src/DIRAC/WorkloadManagementSystem/Utilities/test/Test_JobModel.py index 2e5c6eaedeb..eafe43e83e8 100644 --- a/src/DIRAC/WorkloadManagementSystem/Utilities/test/Test_JobModel.py +++ b/src/DIRAC/WorkloadManagementSystem/Utilities/test/Test_JobModel.py @@ -175,11 +175,7 @@ def test_logLevelValidator_invalid(): def test_platformValidator_valid(): """Test the platform validator with valid input.""" - with patch( - "DIRAC.WorkloadManagementSystem.Utilities.JobModel.getDIRACPlatforms", - return_value=S_OK(["x86_64-slc6-gcc62-opt"]), - ): - job = BaseJobDescriptionModel(executable=EXECUTABLE, platform="x86_64-slc6-gcc62-opt") + job = BaseJobDescriptionModel(executable=EXECUTABLE, platform="x86_64-slc6-gcc62-opt") assert job.platform == "x86_64-slc6-gcc62-opt"