-
-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathtest_distributed_execution.py
More file actions
81 lines (59 loc) · 2.76 KB
/
test_distributed_execution.py
File metadata and controls
81 lines (59 loc) · 2.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import asyncio
from fastapi import BackgroundTasks
from backend.api.routes.executions import create_execution, execute_pipeline_background, executions_db
from backend.api.routes.pipelines import create_pipeline, pipelines_db
from backend.api.schemas import ExecutionCreate, PipelineCreate, StageCreate, StageTypeSchema
from backend.core.distributed_executor import DistributedExecutionDispatcher
from backend.models.pipeline import ExecutionStatus
def _build_pipeline():
    """Reset the in-memory stores and create a minimal two-stage pipeline.

    Clears ``pipelines_db`` and ``executions_db`` so each test starts from a
    clean slate, then registers a pipeline whose OUTPUT stage depends on its
    INPUT stage, and returns the object produced by ``create_pipeline``.
    """
    pipelines_db.clear()
    executions_db.clear()

    input_stage = StageCreate(
        id="in",
        name="Input",
        type=StageTypeSchema.INPUT,
        config={"source": "x", "data": [1, 2]},
    )
    output_stage = StageCreate(
        id="out",
        name="Output",
        type=StageTypeSchema.OUTPUT,
        config={"destination": "console"},
        dependencies=["in"],
    )
    payload = PipelineCreate(
        name="distributed-pipeline",
        description="pipeline for distributed testing",
        stages=[input_stage, output_stage],
    )
    return asyncio.run(create_pipeline(payload))
def test_dispatcher_falls_back_to_local_for_unknown_backend():
    """An unrecognized backend name must fall back to local execution."""
    pipeline = _build_pipeline()
    outcome = DistributedExecutionDispatcher().run(
        pipeline, backend_override="unknown-backend"
    )
    # The dispatcher should substitute "local" and still finish the run.
    assert outcome.backend_used == "local"
    assert outcome.execution.status == ExecutionStatus.COMPLETED
def test_create_execution_tracks_requested_backend_and_backend_used():
    """The requested backend is recorded both at creation and after execution."""
    pipeline = _build_pipeline()
    request = ExecutionCreate(pipeline_id=pipeline.id, execution_backend="celery")
    execution = asyncio.run(create_execution(request, BackgroundTasks()))

    # The requested backend is tracked in the execution context immediately.
    assert execution.context["requested_execution_backend"] == "celery"

    asyncio.run(execute_pipeline_background(pipeline.id, execution.id, "celery"))

    stored = executions_db[execution.id]
    assert stored.context["requested_execution_backend"] == "celery"
    # In the test environment the celery backend is expected to fall back to
    # local execution, which is what gets recorded as the backend actually used.
    assert stored.context["distributed_execution"]["backend_used"] == "local"
def _assert_dispatcher_handles_optional_backend(backend):
    """Run the dispatcher with *backend* and assert it completes gracefully.

    The optional dependency for *backend* may or may not be installed in the
    test environment, so the dispatcher is allowed to either use the requested
    backend or fall back to local — but the execution must complete either way.
    """
    pipeline = _build_pipeline()
    dispatcher = DistributedExecutionDispatcher()
    result = dispatcher.run(pipeline, backend_override=backend)
    assert result.backend_used in {backend, "local"}
    assert result.execution.status == ExecutionStatus.COMPLETED


def test_dispatcher_spark_falls_back_when_dependency_missing():
    _assert_dispatcher_handles_optional_backend("spark")


def test_dispatcher_dask_falls_back_when_dependency_missing():
    _assert_dispatcher_handles_optional_backend("dask")