Skip to content

Commit 94a3eaf

Browse files
authored
[DOP-34705] add recursive query for job_dependencies (#412)
* [DOP-34705] add recursive query for job_dependencies * [DOP-34705] replace columns to model in cte * [DOP-34705] replace columns to model in cte * [DOP-34705] add tests for edge cases * [DOP-34705] rebase fix * [DOP-34705] add changelog * [DOP-34705] combine queries for direction both * [DOP-34705] move query for both into header
1 parent a7da6c3 commit 94a3eaf

7 files changed

Lines changed: 229 additions & 18 deletions

File tree

data_rentgen/db/repositories/job_dependency.py

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
# SPDX-License-Identifier: Apache-2.0
33
from typing import Literal
44

5-
from sqlalchemy import ARRAY, Integer, any_, bindparam, cast, func, or_, select, tuple_
5+
from sqlalchemy import ARRAY, Integer, any_, bindparam, cast, func, literal, select, tuple_
6+
from sqlalchemy.orm import aliased
67

78
from data_rentgen.db.models.job_dependency import JobDependency
89
from data_rentgen.db.repositories.base import Repository
@@ -26,6 +27,59 @@
2627
JobDependency.to_job_id == bindparam("to_job_id"),
2728
)
2829

30+
upstream_jobs_query_base_part = (
31+
select(
32+
JobDependency,
33+
literal(1).label("depth"),
34+
)
35+
.select_from(JobDependency)
36+
.where(JobDependency.to_job_id == any_(bindparam("job_ids")))
37+
)
38+
upstream_jobs_query_cte = upstream_jobs_query_base_part.cte(name="upstream_jobs_query", recursive=True)
39+
40+
upstream_jobs_query_recursive_part = (
41+
select(
42+
JobDependency,
43+
(upstream_jobs_query_cte.c.depth + 1).label("depth"),
44+
)
45+
.select_from(JobDependency)
46+
.where(
47+
upstream_jobs_query_cte.c.depth < bindparam("depth"),
48+
JobDependency.to_job_id == upstream_jobs_query_cte.c.from_job_id,
49+
)
50+
)
51+
52+
53+
upstream_jobs_query_cte = upstream_jobs_query_cte.union(upstream_jobs_query_recursive_part)
54+
upstream_entities_query = select(aliased(JobDependency, upstream_jobs_query_cte))
55+
56+
downstream_jobs_query_base_part = (
57+
select(
58+
JobDependency,
59+
literal(1).label("depth"),
60+
)
61+
.select_from(JobDependency)
62+
.where(JobDependency.from_job_id == any_(bindparam("job_ids")))
63+
)
64+
downstream_jobs_query_cte = downstream_jobs_query_base_part.cte(name="downstream_jobs_query", recursive=True)
65+
66+
downstream_jobs_query_recursive_part = (
67+
select(
68+
JobDependency,
69+
(downstream_jobs_query_cte.c.depth + 1).label("depth"),
70+
)
71+
.select_from(JobDependency)
72+
.where(
73+
downstream_jobs_query_cte.c.depth < bindparam("depth"),
74+
JobDependency.from_job_id == downstream_jobs_query_cte.c.to_job_id,
75+
)
76+
)
77+
78+
downstream_jobs_query_cte = downstream_jobs_query_cte.union(downstream_jobs_query_recursive_part)
79+
downstream_entities_query = select(aliased(JobDependency, downstream_jobs_query_cte))
80+
81+
both_entities_query = select(aliased(JobDependency, (upstream_entities_query.union(downstream_entities_query)).cte()))
82+
2983

3084
class JobDependencyRepository(Repository[JobDependency]):
3185
async def fetch_bulk(
@@ -60,25 +114,19 @@ async def get_dependencies(
60114
self,
61115
job_ids: list[int],
62116
direction: Literal["UPSTREAM", "DOWNSTREAM", "BOTH"],
117+
depth: int,
63118
) -> list[JobDependency]:
64119

65-
job_dependency_query = select(JobDependency)
66120
match direction:
67121
case "UPSTREAM":
68-
job_dependency_query = job_dependency_query.where(JobDependency.to_job_id == any_(bindparam("job_ids")))
122+
query = upstream_entities_query
69123
case "DOWNSTREAM":
70-
job_dependency_query = job_dependency_query.where(
71-
JobDependency.from_job_id == any_(bindparam("job_ids"))
72-
)
124+
query = downstream_entities_query
73125
case "BOTH":
74-
job_dependency_query = job_dependency_query.where(
75-
or_(
76-
JobDependency.from_job_id == any_(bindparam("job_ids")),
77-
JobDependency.to_job_id == any_(bindparam("job_ids")),
78-
)
79-
)
80-
scalars = await self._session.scalars(job_dependency_query, {"job_ids": job_ids})
81-
return list(scalars.all())
126+
query = both_entities_query
127+
128+
result = await self._session.scalars(query, {"job_ids": job_ids, "depth": depth})
129+
return list(result.all())
82130

83131
async def _get(self, job_dependency: JobDependencyDTO) -> JobDependency | None:
84132
return await self._session.scalar(

data_rentgen/server/api/v1/router/job.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,11 @@ async def get_job_dependencies(
9191
job_service: Annotated[JobService, Depends()],
9292
current_user: Annotated[User, Depends(get_user())],
9393
) -> JobDependenciesResponseV1:
94-
job_dependencies = await job_service.get_job_dependencies(query_args.start_node_id, query_args.direction)
94+
job_dependencies = await job_service.get_job_dependencies(
95+
start_node_id=query_args.start_node_id,
96+
direction=query_args.direction,
97+
depth=query_args.depth,
98+
)
9599
return JobDependenciesResponseV1(
96100
relations=JobDependenciesRelationsV1(
97101
parents=[

data_rentgen/server/schemas/v1/job.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,4 +129,5 @@ class JobDependenciesQueryV1(BaseModel):
129129
description="Direction of the lineage",
130130
examples=["DOWNSTREAM", "UPSTREAM", "BOTH"],
131131
)
132+
depth: int = Field(description="Depth of dependencies between jobs", default=1)
132133
model_config = ConfigDict(extra="ignore")

data_rentgen/server/services/job.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,15 @@ async def get_job_dependencies(
111111
self,
112112
start_node_id: int,
113113
direction: Literal["UPSTREAM", "DOWNSTREAM", "BOTH"],
114+
depth: int,
114115
) -> JobDependenciesResult:
115-
logger.info("Get Job dependencies with start at job with id %s and direction: %s", start_node_id, direction)
116+
logger.info(
117+
"Get Job dependencies with start at job with id %s and next params: direction: %s, depth: %s",
118+
start_node_id,
119+
direction,
120+
depth,
121+
)
122+
job_ids = {start_node_id}
116123

117124
ancestor_relations = await self._uow.job.list_ancestor_relations([start_node_id])
118125
descendant_relations = await self._uow.job.list_descendant_relations([start_node_id])
@@ -122,10 +129,13 @@ async def get_job_dependencies(
122129
| {p.child_job_id for p in descendant_relations}
123130
)
124131

125-
dependencies = await self._uow.job_dependency.get_dependencies(job_ids=list(job_ids), direction=direction)
132+
dependencies = await self._uow.job_dependency.get_dependencies(
133+
job_ids=list(job_ids),
134+
direction=direction,
135+
depth=depth,
136+
)
126137
dependency_job_ids = {d.from_job_id for d in dependencies} | {d.to_job_id for d in dependencies}
127138
job_ids |= dependency_job_ids
128-
129139
# return ancestors for all found jobs in the graph
130140
ancestor_relations += await self._uow.job.list_ancestor_relations(list(dependency_job_ids))
131141
job_ids |= {p.parent_job_id for p in ancestor_relations}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add ``depth`` query parameter to the ``GET /v1/jobs/dependencies`` endpoint, controlling how many levels of job dependencies are traversed in the requested direction. Defaults to ``1`` (immediate dependencies only).

tests/test_server/fixtures/factories/job.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,51 @@ async def jobs_with_same_parent_job(
327327
await clean_db(async_session)
328328

329329

330+
@pytest_asyncio.fixture
331+
async def job_dependency_depth_chain(
332+
async_session_maker: Callable[[], AbstractAsyncContextManager[AsyncSession]],
333+
) -> AsyncGenerator[list[Job], None]:
334+
"""
335+
Linear dependency chain of 5 jobs:
336+
337+
job_1 → job_2 → job_3 → job_4 → job_5
338+
339+
Each arrow is a JobDependency edge with type "DIRECT_DEPENDENCY".
340+
Used for testing depth-limited dependency queries.
341+
"""
342+
async with async_session_maker() as async_session:
343+
location = await create_location(async_session)
344+
job_type = await create_job_type(async_session)
345+
346+
jobs = []
347+
for i in range(1, 6):
348+
job = await create_job(
349+
async_session,
350+
location_id=location.id,
351+
job_type_id=job_type.id,
352+
job_kwargs={"name": f"depth-chain-job-{i}"},
353+
)
354+
jobs.append(job)
355+
356+
async_session.add_all(
357+
[
358+
JobDependency(
359+
from_job_id=jobs[i].id,
360+
to_job_id=jobs[i + 1].id,
361+
type="DIRECT_DEPENDENCY",
362+
)
363+
for i in range(len(jobs) - 1)
364+
],
365+
)
366+
await async_session.commit()
367+
async_session.expunge_all()
368+
369+
yield jobs
370+
371+
async with async_session_maker() as async_session:
372+
await clean_db(async_session)
373+
374+
330375
@pytest_asyncio.fixture
331376
async def job_dependency_chain(
332377
async_session_maker: Callable[[], AbstractAsyncContextManager[AsyncSession]],

tests/test_server/test_jobs/test_job_dependencies.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,105 @@ async def test_get_job_dependencies_with_direction_downstream(
202202
},
203203
"nodes": {"jobs": jobs_to_json(expected_nodes)},
204204
}
205+
206+
207+
@pytest.mark.parametrize(
208+
["depth", "direction", "expected_dep_indices", "expected_job_indices"],
209+
[
210+
(1, "DOWNSTREAM", [(2, 3)], [2, 3]),
211+
(2, "DOWNSTREAM", [(2, 3), (3, 4)], [2, 3, 4]),
212+
(1, "UPSTREAM", [(1, 2)], [1, 2]),
213+
(2, "UPSTREAM", [(0, 1), (1, 2)], [0, 1, 2]),
214+
(1, "BOTH", [(1, 2), (2, 3)], [1, 2, 3]),
215+
(2, "BOTH", [(0, 1), (1, 2), (2, 3), (3, 4)], [0, 1, 2, 3, 4]),
216+
(5, "BOTH", [(0, 1), (1, 2), (2, 3), (3, 4)], [0, 1, 2, 3, 4]),
217+
],
218+
ids=[
219+
"depth_1-downstream",
220+
"depth_2-downstream",
221+
"depth_1-upstream",
222+
"depth_2-upstream",
223+
"depth_1-both",
224+
"depth_2-both",
225+
"depth_5-both",
226+
],
227+
)
228+
async def test_get_job_dependencies_with_depth(
229+
test_client: AsyncClient,
230+
job_dependency_depth_chain: tuple[Job, ...],
231+
async_session: AsyncSession,
232+
mocked_user: MockedUser,
233+
depth: int,
234+
direction: str,
235+
expected_dep_indices: list[tuple[int, int]],
236+
expected_job_indices: list[int],
237+
):
238+
"""
239+
Fixture chain (labels are 0-based indices into the fixture list): job_0 → job_1 → job_2 → job_3 → job_4
240+
Start node is always job_2 (middle of the chain).
241+
"""
242+
jobs = job_dependency_depth_chain
243+
start_job = jobs[2]
244+
245+
expected_jobs = await enrich_jobs([jobs[i] for i in expected_job_indices], async_session)
246+
247+
response = await test_client.get(
248+
"v1/jobs/dependencies",
249+
headers={"Authorization": f"Bearer {mocked_user.access_token}"},
250+
params={"start_node_id": start_job.id, "depth": depth, "direction": direction},
251+
)
252+
assert response.status_code == HTTPStatus.OK, response.json()
253+
assert response.json() == {
254+
"relations": {
255+
"parents": [],
256+
"dependencies": [
257+
{
258+
"from": {"kind": "JOB", "id": str(jobs[i].id)},
259+
"to": {"kind": "JOB", "id": str(jobs[j].id)},
260+
"type": "DIRECT_DEPENDENCY",
261+
}
262+
for i, j in sorted(expected_dep_indices)
263+
],
264+
},
265+
"nodes": {"jobs": jobs_to_json(expected_jobs)},
266+
}
267+
268+
269+
@pytest.mark.parametrize(
270+
["direction", "start_node_index"],
271+
[
272+
("UPSTREAM", 0),
273+
("DOWNSTREAM", 4),
274+
],
275+
ids=["upstream_boundary", "downstream_boundary"],
276+
)
277+
async def test_get_job_dependencies_with_depth_on_boundary(
278+
test_client: AsyncClient,
279+
job_dependency_depth_chain: tuple[Job, ...],
280+
async_session: AsyncSession,
281+
mocked_user: MockedUser,
282+
direction: str,
283+
start_node_index: int,
284+
):
285+
"""
286+
Fixture chain (labels are 0-based indices into the fixture list): job_0 → job_1 → job_2 → job_3 → job_4
287+
Start node is job_0 or job_4.
288+
"""
289+
jobs = job_dependency_depth_chain
290+
start_job = jobs[start_node_index]
291+
292+
[expected_job] = await enrich_jobs([start_job], async_session)
293+
294+
response = await test_client.get(
295+
"v1/jobs/dependencies",
296+
headers={"Authorization": f"Bearer {mocked_user.access_token}"},
297+
params={"start_node_id": start_job.id, "depth": 2, "direction": direction},
298+
)
299+
assert response.status_code == HTTPStatus.OK, response.json()
300+
assert response.json() == {
301+
"relations": {
302+
"parents": [],
303+
"dependencies": [],
304+
},
305+
"nodes": {"jobs": jobs_to_json([expected_job])},
306+
}

0 commit comments

Comments
 (0)