Skip to content

Commit 15a1926

Browse files
authored
Merge pull request #356 from ryuwd/roneil-jobattr-update-fix
Fix job attribute update to account for mismatching columns between rows
2 parents 97b3c58 + ed019f1 commit 15a1926

3 files changed

Lines changed: 94 additions & 7 deletions

File tree

diracx-db/src/diracx/db/sql/job/db.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from datetime import datetime, timezone
44
from typing import TYPE_CHECKING, Any
55

6-
from sqlalchemy import bindparam, delete, func, insert, select, update
6+
from sqlalchemy import bindparam, case, delete, func, insert, select, update
77
from sqlalchemy.exc import IntegrityError, NoResultFound
88

99
if TYPE_CHECKING:
@@ -219,13 +219,25 @@ async def setJobAttributesBulk(self, jobData):
219219
jobData[job_id].update(
220220
{"LastUpdateTime": datetime.now(tz=timezone.utc)}
221221
)
222+
columns = set(key for attrs in jobData.values() for key in attrs.keys())
223+
case_expressions = {
224+
column: case(
225+
*[
226+
(Jobs.__table__.c.JobID == job_id, attrs[column])
227+
for job_id, attrs in jobData.items()
228+
if column in attrs
229+
],
230+
else_=getattr(Jobs.__table__.c, column), # Retain original value
231+
)
232+
for column in columns
233+
}
222234

223-
await self.conn.execute(
224-
Jobs.__table__.update().where(
225-
Jobs.__table__.c.JobID == bindparam("b_JobID")
226-
),
227-
[{"b_JobID": job_id, **attrs} for job_id, attrs in jobData.items()],
235+
stmt = (
236+
Jobs.__table__.update()
237+
.values(**case_expressions)
238+
.where(Jobs.__table__.c.JobID.in_(jobData.keys()))
228239
)
240+
await self.conn.execute(stmt)
229241

230242
async def getJobJDL(self, job_id: int, original: bool = False) -> str:
231243
from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL

diracx-db/src/diracx/db/sql/utils/job.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,11 +325,12 @@ def parse_jdl(job_id, job_jdl):
325325
"failed": failed,
326326
"success": {
327327
job_id: {
328-
"InputData": job_jdls[job_id],
328+
"InputData": job_jdls.get(job_id, None),
329329
**attribute_changes[job_id],
330330
**set_status_result.model_dump(),
331331
}
332332
for job_id, set_status_result in set_job_status_result.success.items()
333+
if job_id not in failed
333334
},
334335
}
335336

diracx-routers/tests/test_job_manager.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,80 @@ def test_insert_and_reschedule(normal_user_client: TestClient):
934934
}
935935

936936

937+
## test edge case for rescheduling
938+
939+
940+
def test_reschedule_job_attr_update(normal_user_client: TestClient):
941+
job_definitions = [TEST_JDL] * 15
942+
943+
r = normal_user_client.post("/api/jobs/jdl", json=job_definitions)
944+
assert r.status_code == 200, r.json()
945+
assert len(r.json()) == len(job_definitions)
946+
947+
submitted_job_ids = sorted([job_dict["JobID"] for job_dict in r.json()])
948+
949+
# Test /jobs/reschedule and
950+
# test max_reschedule
951+
952+
max_resched = 3
953+
954+
fail_resched_ids = submitted_job_ids[0:5]
955+
good_resched_ids = list(set(submitted_job_ids) - set(fail_resched_ids))
956+
957+
for i in range(max_resched):
958+
r = normal_user_client.post(
959+
"/api/jobs/reschedule",
960+
params={"job_ids": fail_resched_ids},
961+
)
962+
assert r.status_code == 200, r.json()
963+
result = r.json()
964+
successful_results = result["success"]
965+
for jid in fail_resched_ids:
966+
assert str(jid) in successful_results, result
967+
assert successful_results[str(jid)]["Status"] == JobStatus.RECEIVED
968+
assert successful_results[str(jid)]["MinorStatus"] == "Job Rescheduled"
969+
assert successful_results[str(jid)]["RescheduleCounter"] == i + 1
970+
971+
for i in range(max_resched):
972+
r = normal_user_client.post(
973+
"/api/jobs/reschedule",
974+
params={"job_ids": submitted_job_ids},
975+
)
976+
assert r.status_code == 200, r.json()
977+
result = r.json()
978+
successful_results = result["success"]
979+
failed_results = result["failed"]
980+
for jid in good_resched_ids:
981+
assert str(jid) in successful_results, result
982+
assert successful_results[str(jid)]["Status"] == JobStatus.RECEIVED
983+
assert successful_results[str(jid)]["MinorStatus"] == "Job Rescheduled"
984+
assert successful_results[str(jid)]["RescheduleCounter"] == i + 1
985+
for jid in fail_resched_ids:
986+
assert str(jid) in failed_results, result
987+
# assert successful_results[jid]["Status"] == JobStatus.RECEIVED
988+
# assert successful_results[jid]["MinorStatus"] == "Job Rescheduled"
989+
# assert successful_results[jid]["RescheduleCounter"] == i + 1
990+
991+
r = normal_user_client.post(
992+
"/api/jobs/reschedule",
993+
params={"job_ids": submitted_job_ids},
994+
)
995+
assert (
996+
r.status_code != 200
997+
), f"Rescheduling more than {max_resched} times should have failed by now {r.json()}"
998+
assert r.json() == {
999+
"detail": {
1000+
"success": [],
1001+
"failed": {
1002+
str(i): {
1003+
"detail": f"Maximum number of reschedules exceeded ({max_resched})"
1004+
}
1005+
for i in good_resched_ids + fail_resched_ids
1006+
},
1007+
}
1008+
}
1009+
1010+
9371011
# Test delete job
9381012

9391013

0 commit comments

Comments
 (0)