Skip to content

Commit ba1db33

Browse files
committed
feat: Add include_sql to Search Pipeline Run API
1 parent 9847bfa commit ba1db33

2 files changed

Lines changed: 128 additions & 10 deletions

File tree

cloud_pipelines_backend/api_server_sql.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class GetPipelineRunResponse(PipelineRunResponse):
6464
class ListPipelineJobsResponse:
6565
pipeline_runs: list[PipelineRunResponse]
6666
next_page_token: str | None = None
67+
sql: str | None = None
6768

6869

6970
class PipelineRunsApiService_Sql:
@@ -171,6 +172,36 @@ def terminate(
171172
execution_node.extra_data["desired_state"] = "TERMINATED"
172173
session.commit()
173174

175+
@staticmethod
176+
def _compile_sql_string(
177+
stmt: sql.Select,
178+
dialect: sql.engine.Dialect,
179+
) -> str:
180+
"""Compile a SQLAlchemy statement to a SQL string for debugging.
181+
182+
Uses ``literal_binds=True`` to inline bound parameters as literal
183+
values, producing a self-contained query string::
184+
185+
SELECT ... WHERE key = 'environment' AND created_at < '2024-01-15' LIMIT 10
186+
187+
If a column type lacks a ``literal_processor`` (raises CompileError or
188+
NotImplementedError), falls back to placeholder syntax with a params
189+
comment::
190+
191+
SELECT ... WHERE key = :key_1 AND created_at < :created_at_1 LIMIT :param_1
192+
-- params: {'key_1': 'environment', 'created_at_1': '2024-01-15', 'param_1': 10}
193+
"""
194+
try:
195+
compiled = stmt.compile(
196+
dialect=dialect,
197+
compile_kwargs={"literal_binds": True},
198+
)
199+
return str(compiled)
200+
except (sql.exc.CompileError, NotImplementedError):
201+
compiled = stmt.compile(dialect=dialect)
202+
params_suffix = f"\n-- params: {compiled.params}" if compiled.params else ""
203+
return str(compiled) + params_suffix
204+
174205
# Note: This method must be last to not shadow the "list" type
175206
def list(
176207
self,
@@ -182,6 +213,7 @@ def list(
182213
current_user: str | None = None,
183214
include_pipeline_names: bool = False,
184215
include_execution_stats: bool = False,
216+
include_sql: bool = False,
185217
) -> ListPipelineJobsResponse:
186218
where_clauses = filter_query_sql.build_list_filters(
187219
filter_value=filter,
@@ -190,18 +222,22 @@ def list(
190222
current_user=current_user,
191223
)
192224

193-
pipeline_runs = list(
194-
session.scalars(
195-
sql.select(bts.PipelineRun)
196-
.where(*where_clauses)
197-
.order_by(
198-
bts.PipelineRun.created_at.desc(),
199-
bts.PipelineRun.id.desc(),
200-
)
201-
.limit(self._DEFAULT_PAGE_SIZE)
202-
).all()
225+
stmt = (
226+
sql.select(bts.PipelineRun)
227+
.where(*where_clauses)
228+
.order_by(
229+
bts.PipelineRun.created_at.desc(),
230+
bts.PipelineRun.id.desc(),
231+
)
232+
.limit(self._DEFAULT_PAGE_SIZE)
203233
)
204234

235+
sql_string = None
236+
if include_sql:
237+
sql_string = self._compile_sql_string(stmt, session.bind.dialect)
238+
239+
pipeline_runs = list(session.scalars(stmt).all())
240+
205241
next_page_token = filter_query_sql.maybe_next_page_token(
206242
rows=pipeline_runs, page_size=self._DEFAULT_PAGE_SIZE
207243
)
@@ -217,6 +253,7 @@ def list(
217253
for pipeline_run in pipeline_runs
218254
],
219255
next_page_token=next_page_token,
256+
sql=sql_string,
220257
)
221258

222259
def _create_pipeline_run_response(

tests/test_api_server_sql.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,87 @@ def test_list_filter_created_by_me(self, session_factory, service):
295295
assert len(result.pipeline_runs) == 1
296296
assert result.pipeline_runs[0].created_by == "alice@example.com"
297297

298+
def test_list_include_sql_default_none(self, session_factory, service):
299+
_create_run(session_factory, service, root_task=_make_task_spec())
300+
301+
with session_factory() as session:
302+
result = service.list(session=session)
303+
assert result.sql is None
304+
305+
def test_list_include_sql_true(self, session_factory, service):
306+
_create_run(session_factory, service, root_task=_make_task_spec())
307+
308+
with session_factory() as session:
309+
result = service.list(session=session, include_sql=True)
310+
expected = (
311+
"SELECT pipeline_run.id, pipeline_run.root_execution_id,"
312+
" pipeline_run.annotations, pipeline_run.created_by,"
313+
" pipeline_run.created_at, pipeline_run.updated_at,"
314+
" pipeline_run.parent_pipeline_id, pipeline_run.extra_data \n"
315+
"FROM pipeline_run"
316+
" ORDER BY pipeline_run.created_at DESC, pipeline_run.id DESC\n"
317+
" LIMIT 10 OFFSET 0"
318+
)
319+
assert result.sql == expected
320+
321+
def test_list_include_sql_with_filter_query(self, session_factory, service):
322+
run = _create_run(session_factory, service, root_task=_make_task_spec())
323+
with session_factory() as session:
324+
service.set_annotation(session=session, id=run.id, key="team", value="ml")
325+
326+
fq = json.dumps({"and": [{"key_exists": {"key": "team"}}]})
327+
with session_factory() as session:
328+
result = service.list(session=session, filter_query=fq, include_sql=True)
329+
expected = (
330+
"SELECT pipeline_run.id, pipeline_run.root_execution_id,"
331+
" pipeline_run.annotations, pipeline_run.created_by,"
332+
" pipeline_run.created_at, pipeline_run.updated_at,"
333+
" pipeline_run.parent_pipeline_id, pipeline_run.extra_data \n"
334+
"FROM pipeline_run \n"
335+
"WHERE EXISTS (SELECT pipeline_run_annotation.pipeline_run_id \n"
336+
"FROM pipeline_run_annotation \n"
337+
"WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id"
338+
" AND pipeline_run_annotation.\"key\" = 'team')"
339+
" ORDER BY pipeline_run.created_at DESC, pipeline_run.id DESC\n"
340+
" LIMIT 10 OFFSET 0"
341+
)
342+
assert result.sql == expected
343+
344+
def test_list_include_sql_with_cursor(self, session_factory, service):
345+
for i in range(12):
346+
_create_run(
347+
session_factory,
348+
service,
349+
root_task=_make_task_spec(f"pipeline-{i}"),
350+
)
351+
352+
with session_factory() as session:
353+
page1 = service.list(session=session)
354+
assert page1.next_page_token is not None
355+
356+
with session_factory() as session:
357+
page2 = service.list(
358+
session=session,
359+
page_token=page1.next_page_token,
360+
include_sql=True,
361+
)
362+
363+
cursor_dt_iso, cursor_id = page1.next_page_token.split("~")
364+
cursor_dt = datetime.datetime.fromisoformat(cursor_dt_iso)
365+
sql_dt = cursor_dt.strftime("%Y-%m-%d %H:%M:%S.%f")
366+
expected = (
367+
"SELECT pipeline_run.id, pipeline_run.root_execution_id,"
368+
" pipeline_run.annotations, pipeline_run.created_by,"
369+
" pipeline_run.created_at, pipeline_run.updated_at,"
370+
" pipeline_run.parent_pipeline_id, pipeline_run.extra_data \n"
371+
"FROM pipeline_run \n"
372+
f"WHERE (pipeline_run.created_at, pipeline_run.id)"
373+
f" < ('{sql_dt}', '{cursor_id}')"
374+
" ORDER BY pipeline_run.created_at DESC, pipeline_run.id DESC\n"
375+
" LIMIT 10 OFFSET 0"
376+
)
377+
assert page2.sql == expected
378+
298379

299380
class TestCreatePipelineRunResponse:
300381
def test_base_response(self, session_factory, service):

0 commit comments

Comments
 (0)