Skip to content

Commit af1a409

Browse files
committed
Add distributed execution support with Celery and Ray backends
1 parent 3ac4719 commit af1a409

6 files changed

Lines changed: 245 additions & 8 deletions

File tree

backend/README.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,33 @@ curl -X POST http://localhost:8000/api/executions \
103103
```
104104

105105

106+
### Distributed Task Execution (Celery / Ray)
107+
108+
FlexiRoaster supports selectable execution backends for asynchronous and distributed workloads:
109+
110+
- `local`: default in-process execution
111+
- `celery`: async jobs, retries, and scheduling support through Celery workers
112+
- `ray`: distributed Python execution, optimized for ML/AI-heavy pipelines
113+
114+
Use the optional `execution_backend` field when creating an execution:
115+
116+
```bash
117+
curl -X POST http://localhost:8000/api/executions -H "Content-Type: application/json" -d '{"pipeline_id": "your-pipeline-id", "execution_backend": "ray"}'
118+
```
119+
120+
Or set a default backend via environment variables in `backend/.env`:
121+
122+
```env
123+
DISTRIBUTED_EXECUTION_BACKEND=local
124+
CELERY_BROKER_URL=redis://localhost:6379/0
125+
CELERY_RESULT_BACKEND=redis://localhost:6379/1
126+
CELERY_EXECUTION_TASK=flexiroaster.execute_pipeline
127+
RAY_ADDRESS=auto
128+
RAY_NAMESPACE=flexiroaster
129+
```
130+
131+
If Celery or Ray is unavailable, FlexiRoaster automatically falls back to local execution and records the fallback reason in the execution context.
132+
106133
## Authentication & Security
107134

108135
- JWT authentication endpoint: `POST /api/auth/token`

backend/api/routes/executions.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
SuccessResponse
1515
)
1616
from backend.models.pipeline import Execution, ExecutionStatus
17-
from backend.core.executor import PipelineExecutor
17+
from backend.core.distributed_executor import DistributedExecutionDispatcher
1818
from backend.config import settings
1919
from backend.events import get_event_publisher
2020

@@ -57,19 +57,27 @@ def _merge_execution_logs(result_logs: List, existing_logs: List) -> List:
5757
return merged_logs
5858

5959

60-
def initialize_execution(pipeline_id: str, context: Optional[Dict[str, Any]] = None) -> Execution:
60+
def initialize_execution(
61+
pipeline_id: str,
62+
context: Optional[Dict[str, Any]] = None,
63+
execution_backend: Optional[str] = None,
64+
) -> Execution:
6165
"""Create and store a pending execution record."""
6266
from datetime import datetime
6367

6468
execution_id = f"exec-{uuid.uuid4()}"
6569
pipeline = pipelines_db[pipeline_id]
70+
merged_context = context.copy() if context else {}
71+
if execution_backend:
72+
merged_context["requested_execution_backend"] = execution_backend
73+
6674
execution = Execution(
6775
id=execution_id,
6876
pipeline_id=pipeline_id,
6977
status=ExecutionStatus.PENDING,
7078
started_at=datetime.now(),
7179
total_stages=len(pipeline.stages),
72-
context=context or {}
80+
context=merged_context
7381
)
7482
executions_db[execution_id] = execution
7583

@@ -86,12 +94,19 @@ def initialize_execution(pipeline_id: str, context: Optional[Dict[str, Any]] = N
8694
return execution
8795

8896

89-
async def execute_pipeline_background(pipeline_id: str, execution_id: str):
97+
async def execute_pipeline_background(
98+
pipeline_id: str,
99+
execution_id: str,
100+
execution_backend: Optional[str] = None,
101+
):
90102
"""Background task to execute pipeline."""
91103
try:
92104
pipeline = pipelines_db[pipeline_id]
93-
executor = PipelineExecutor()
94-
result = executor.execute(pipeline)
105+
dispatcher = DistributedExecutionDispatcher()
106+
dispatch_result = dispatcher.run(pipeline, backend_override=execution_backend)
107+
result = dispatch_result.execution
108+
result.context.setdefault("distributed_execution", {})
109+
result.context["distributed_execution"]["backend_used"] = dispatch_result.backend_used
95110

96111
existing = executions_db.get(execution_id)
97112
if existing:
@@ -183,13 +198,17 @@ async def create_execution(
183198
)
184199

185200
# Create execution record
186-
execution = initialize_execution(execution_data.pipeline_id)
201+
execution = initialize_execution(
202+
execution_data.pipeline_id,
203+
execution_backend=execution_data.execution_backend
204+
)
187205

188206
# Start execution in background
189207
background_tasks.add_task(
190208
execute_pipeline_background,
191209
execution_data.pipeline_id,
192-
execution.id
210+
execution.id,
211+
execution_data.execution_backend
193212
)
194213

195214
return execution

backend/api/schemas.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ class PipelineListResponse(BaseModel):
103103
class ExecutionCreate(BaseModel):
104104
"""Schema for creating an execution"""
105105
pipeline_id: str
106+
execution_backend: Optional[str] = Field(
107+
default=None,
108+
description="Optional override for distributed backend: local, celery, or ray",
109+
)
106110

107111

108112
class AirflowTriggerRequest(BaseModel):

backend/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,18 @@ class Settings(BaseSettings):
5050
TOPIC_EXECUTION_FAILED: str = "execution.failed"
5151
TOPIC_EXECUTION_COMPLETED: str = "execution.completed"
5252

53+
# Distributed execution backends: local|celery|ray
54+
DISTRIBUTED_EXECUTION_BACKEND: str = "local"
55+
56+
# Celery settings
57+
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
58+
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
59+
CELERY_EXECUTION_TASK: str = "flexiroaster.execute_pipeline"
60+
61+
# Ray settings
62+
RAY_ADDRESS: str = "auto"
63+
RAY_NAMESPACE: str = "flexiroaster"
64+
5365
@field_validator("CORS_ORIGINS", mode="before")
5466
@classmethod
5567
def parse_cors_origins(cls, v):
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""Distributed execution dispatcher with optional Celery and Ray backends."""
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional
import logging

from backend.config import settings
from backend.core.executor import PipelineExecutor
from backend.models.pipeline import Execution, Pipeline

logger = logging.getLogger(__name__)

# Backend names the dispatcher accepts; any other value is rejected by
# DistributedExecutionDispatcher.run and replaced with "local".
SUPPORTED_BACKENDS = {"local", "celery", "ray"}
15+
16+
17+
@dataclass
class DispatchResult:
    """Result wrapper for execution dispatch metadata."""

    # Execution record produced by whichever backend actually ran the pipeline.
    execution: Execution
    # Backend that performed the run: "local", "celery", or "ray".
    # May differ from the requested backend when a fallback occurred.
    backend_used: str
23+
24+
25+
class DistributedExecutionDispatcher:
    """Run pipeline execution on the local runtime or a distributed framework.

    Backend selection order: the explicit ``backend_override`` argument, then
    ``settings.DISTRIBUTED_EXECUTION_BACKEND``, then ``"local"``. Unknown or
    unavailable backends fall back to local execution; the fallback reason is
    recorded in the execution's context under ``"distributed_execution"``.
    """

    # Upper bound (seconds) to wait for a Celery task result before the
    # dispatcher gives up and falls back to local execution.
    CELERY_RESULT_TIMEOUT: int = 600

    def __init__(self):
        # The local executor doubles as the fallback path for every backend.
        self.executor = PipelineExecutor()

    def run(self, pipeline: Pipeline, backend_override: Optional[str] = None) -> DispatchResult:
        """Execute *pipeline* on the selected backend.

        Args:
            pipeline: Pipeline model to execute.
            backend_override: Optional backend name overriding the configured
                default ("local", "celery", or "ray"; case-insensitive).

        Returns:
            DispatchResult with the finished Execution and the backend that
            actually ran it.
        """
        backend = (backend_override or settings.DISTRIBUTED_EXECUTION_BACKEND or "local").lower().strip()

        if backend not in SUPPORTED_BACKENDS:
            logger.warning("Unsupported backend '%s'. Falling back to local.", backend)
            backend = "local"

        if backend == "celery":
            execution, used_backend = self._execute_with_celery(pipeline)
        elif backend == "ray":
            execution, used_backend = self._execute_with_ray(pipeline)
        else:
            execution, used_backend = self.executor.execute(pipeline), "local"

        return DispatchResult(execution=execution, backend_used=used_backend)

    def _fallback_to_local(
        self, pipeline: Pipeline, requested_backend: str, exc: Exception
    ) -> tuple[Execution, str]:
        """Execute locally after *requested_backend* failed, recording why.

        Shared by the Celery and Ray paths so the fallback bookkeeping
        (warning log + context annotation) stays in one place.
        """
        # capitalize() yields "Celery" / "Ray", matching the original messages.
        logger.warning(
            "%s backend unavailable (%s). Executing locally.",
            requested_backend.capitalize(),
            exc,
        )
        execution = self.executor.execute(pipeline)
        execution.context.setdefault("distributed_execution", {})
        execution.context["distributed_execution"].update(
            {
                "requested_backend": requested_backend,
                "fallback_backend": "local",
                "fallback_reason": str(exc),
            }
        )
        return execution, "local"

    def _execute_with_celery(self, pipeline: Pipeline) -> tuple[Execution, str]:
        """Try the Celery path; fall back to local execution when unavailable."""
        try:
            # Imported lazily so Celery remains an optional dependency.
            from celery import Celery

            app = Celery(
                "flexiroaster",
                broker=settings.CELERY_BROKER_URL,
                backend=settings.CELERY_RESULT_BACKEND,
            )
            task_name = settings.CELERY_EXECUTION_TASK

            # The worker receives a JSON-serializable pipeline payload and is
            # expected to return a serialized Execution.
            payload = pipeline.model_dump(mode="json")
            async_result = app.send_task(task_name, kwargs={"pipeline": payload})
            remote_output = async_result.get(timeout=self.CELERY_RESULT_TIMEOUT)
            execution = Execution.model_validate(remote_output)
            logger.info("Pipeline %s executed via Celery task %s", pipeline.id, task_name)
            return execution, "celery"
        except Exception as exc:  # broad by design: any failure means "run locally"
            return self._fallback_to_local(pipeline, "celery", exc)

    def _execute_with_ray(self, pipeline: Pipeline) -> tuple[Execution, str]:
        """Try the Ray path; fall back to local execution when unavailable."""
        try:
            # Imported lazily so Ray remains an optional dependency.
            import ray

            if not ray.is_initialized():
                ray.init(
                    address=settings.RAY_ADDRESS,
                    namespace=settings.RAY_NAMESPACE,
                    ignore_reinit_error=True,
                )

            @ray.remote
            def execute_pipeline_remote(pipeline_payload: dict):
                # Re-import inside the remote function: it runs in a separate
                # Ray worker process with its own module state.
                from backend.core.executor import PipelineExecutor
                from backend.models.pipeline import Pipeline

                model = Pipeline.model_validate(pipeline_payload)
                result = PipelineExecutor().execute(model)
                return result.model_dump(mode="json")

            payload = pipeline.model_dump(mode="json")
            remote_ref = execute_pipeline_remote.remote(payload)
            remote_output = ray.get(remote_ref)
            execution = Execution.model_validate(remote_output)
            logger.info("Pipeline %s executed via Ray remote function", pipeline.id)
            return execution, "ray"
        except Exception as exc:  # broad by design: any failure means "run locally"
            return self._fallback_to_local(pipeline, "ray", exc)
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import asyncio
2+
3+
from fastapi import BackgroundTasks
4+
5+
from backend.api.routes.executions import create_execution, execute_pipeline_background, executions_db
6+
from backend.api.routes.pipelines import create_pipeline, pipelines_db
7+
from backend.api.schemas import ExecutionCreate, PipelineCreate, StageCreate, StageTypeSchema
8+
from backend.core.distributed_executor import DistributedExecutionDispatcher
9+
from backend.models.pipeline import ExecutionStatus
10+
11+
12+
def _build_pipeline():
    """Reset the in-memory stores and create a fresh two-stage pipeline."""
    pipelines_db.clear()
    executions_db.clear()

    input_stage = StageCreate(
        id="in",
        name="Input",
        type=StageTypeSchema.INPUT,
        config={"source": "x", "data": [1, 2]},
    )
    output_stage = StageCreate(
        id="out",
        name="Output",
        type=StageTypeSchema.OUTPUT,
        config={"destination": "console"},
        dependencies=["in"],
    )
    request = PipelineCreate(
        name="distributed-pipeline",
        description="pipeline for distributed testing",
        stages=[input_stage, output_stage],
    )
    return asyncio.run(create_pipeline(request))
33+
34+
35+
def test_dispatcher_falls_back_to_local_for_unknown_backend():
36+
pipeline = _build_pipeline()
37+
38+
dispatcher = DistributedExecutionDispatcher()
39+
result = dispatcher.run(pipeline, backend_override="spark")
40+
41+
assert result.backend_used == "local"
42+
assert result.execution.status == ExecutionStatus.COMPLETED
43+
44+
45+
def test_create_execution_tracks_requested_backend_and_backend_used():
    """Requesting Celery with no broker records the request and the local fallback."""
    pipeline = _build_pipeline()

    request = ExecutionCreate(pipeline_id=pipeline.id, execution_backend="celery")
    execution = asyncio.run(create_execution(request, BackgroundTasks()))

    assert execution.context["requested_execution_backend"] == "celery"

    # Run the background task synchronously; with no Celery broker available
    # the dispatcher should fall back to the local executor.
    asyncio.run(execute_pipeline_background(pipeline.id, execution.id, "celery"))

    stored = executions_db[execution.id]
    assert stored.context["requested_execution_backend"] == "celery"
    assert stored.context["distributed_execution"]["backend_used"] == "local"

0 commit comments

Comments
 (0)