Skip to content

Commit c04e459

Browse files
authored
Merge pull request #88 from fuzziecoder/codex/implement-ray-for-ml-orchestration
Add distributed execution support (Celery & Ray)
2 parents 4a420ec + 0141730 commit c04e459

6 files changed

Lines changed: 251 additions & 3 deletions

File tree

backend/README.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,33 @@ curl -X POST http://localhost:8000/api/executions \
103103
```
104104

105105

106+
### Distributed Task Execution (Celery / Ray)
107+
108+
FlexiRoaster supports selectable execution backends for asynchronous and distributed workloads:
109+
110+
- `local`: default in-process execution
111+
- `celery`: async jobs, retries, and scheduling support through Celery workers
112+
- `ray`: distributed Python execution, optimized for ML/AI-heavy pipelines
113+
114+
Use the optional `execution_backend` field when creating an execution:
115+
116+
```bash
117+
curl -X POST http://localhost:8000/api/executions -H "Content-Type: application/json" -d '{"pipeline_id": "your-pipeline-id", "execution_backend": "ray"}'
118+
```
119+
120+
Or set a default backend via environment variables in `backend/.env`:
121+
122+
```env
123+
DISTRIBUTED_EXECUTION_BACKEND=local
124+
CELERY_BROKER_URL=redis://localhost:6379/0
125+
CELERY_RESULT_BACKEND=redis://localhost:6379/1
126+
CELERY_EXECUTION_TASK=flexiroaster.execute_pipeline
127+
RAY_ADDRESS=auto
128+
RAY_NAMESPACE=flexiroaster
129+
```
130+
131+
If Celery or Ray is unavailable, FlexiRoaster automatically falls back to local execution and records the fallback reason in execution context.
132+
106133
## Authentication & Security
107134

108135
- JWT authentication endpoint: `POST /api/auth/token`

backend/api/routes/executions.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
SuccessResponse
1616
)
1717
from backend.models.pipeline import Execution, ExecutionStatus
18+
from backend.core.distributed_executor import DistributedExecutionDispatcher
1819
from backend.core.executor import PipelineExecutor
1920
from backend.core.orchestration import OrchestrationEngine, OrchestrationRequest, OrchestrationRegistry
2021
ExecutionResponse,
@@ -88,6 +89,10 @@ def _merge_execution_logs(result_logs: List, existing_logs: List) -> List:
8889

8990
def initialize_execution(
9091
pipeline_id: str,
92+
context: Optional[Dict[str, Any]] = None,
93+
execution_backend: Optional[str] = None,
94+
) -> Execution:
95+
"""Create and store a pending execution record."""
9196
user_id: str,
9297
context: Optional[Dict[str, Any]] = None,
9398
) -> Execution:
@@ -96,13 +101,18 @@ def initialize_execution(
96101

97102
execution_id = f"exec-{uuid.uuid4()}"
98103
pipeline = pipelines_db[pipeline_id]
104+
merged_context = context.copy() if context else {}
105+
if execution_backend:
106+
merged_context["requested_execution_backend"] = execution_backend
107+
99108
execution = Execution(
100109
id=execution_id,
101110
pipeline_id=pipeline_id,
102111
user_id=user_id,
103112
status=ExecutionStatus.PENDING,
104113
started_at=datetime.now(),
105114
total_stages=len(pipeline.stages),
115+
context=merged_context
106116
context=context or {},
107117
)
108118
executions_db[execution_id] = execution
@@ -121,12 +131,19 @@ def initialize_execution(
121131
return execution
122132

123133

124-
async def execute_pipeline_background(pipeline_id: str, execution_id: str):
134+
async def execute_pipeline_background(
135+
pipeline_id: str,
136+
execution_id: str,
137+
execution_backend: Optional[str] = None,
138+
):
125139
"""Background task to execute pipeline."""
126140
try:
127141
pipeline = pipelines_db[pipeline_id]
128-
executor = PipelineExecutor()
129-
result = executor.execute(pipeline)
142+
dispatcher = DistributedExecutionDispatcher()
143+
dispatch_result = dispatcher.run(pipeline, backend_override=execution_backend)
144+
result = dispatch_result.execution
145+
result.context.setdefault("distributed_execution", {})
146+
result.context["distributed_execution"]["backend_used"] = dispatch_result.backend_used
130147

131148
existing = executions_db.get(execution_id)
132149
if existing:
@@ -216,6 +233,19 @@ async def create_execution(
216233
orchestration_context = _build_orchestration_context(execution_data)
217234

218235
# Create execution record
236+
execution = initialize_execution(
237+
execution_data.pipeline_id,
238+
execution_backend=execution_data.execution_backend
239+
)
240+
241+
# Start execution in background
242+
background_tasks.add_task(
243+
execute_pipeline_background,
244+
execution_data.pipeline_id,
245+
execution.id,
246+
execution_data.execution_backend
247+
)
248+
219249
execution = initialize_execution(execution_data.pipeline_id, context=orchestration_context)
220250

221251
orchestration_engine = ORCHESTRATION_SCHEMA_TO_CORE[execution_data.orchestration.engine]

backend/api/schemas.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ class OrchestrationConfig(BaseModel):
123123
options: Dict[str, Any] = Field(default_factory=dict)
124124

125125
pipeline_id: str
126+
execution_backend: Optional[str] = Field(
127+
default=None,
128+
description="Optional override for distributed backend: local, celery, or ray",
129+
)
126130
orchestration: OrchestrationConfig = Field(default_factory=OrchestrationConfig)
127131

128132

backend/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,18 @@ class Settings(BaseSettings):
5151
TOPIC_EXECUTION_FAILED: str = "execution.failed"
5252
TOPIC_EXECUTION_COMPLETED: str = "execution.completed"
5353

54+
# Distributed execution backends: local|celery|ray
55+
DISTRIBUTED_EXECUTION_BACKEND: str = "local"
56+
57+
# Celery settings
58+
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
59+
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
60+
CELERY_EXECUTION_TASK: str = "flexiroaster.execute_pipeline"
61+
62+
# Ray settings
63+
RAY_ADDRESS: str = "auto"
64+
RAY_NAMESPACE: str = "flexiroaster"
65+
5466
@field_validator("CORS_ORIGINS", mode="before")
5567
@classmethod
5668
def parse_cors_origins(cls, v):
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""Distributed execution dispatcher with optional Celery and Ray backends."""
2+
from __future__ import annotations
3+
4+
from dataclasses import dataclass
5+
from typing import Optional
6+
import logging
7+
8+
from backend.config import settings
9+
from backend.core.executor import PipelineExecutor
10+
from backend.models.pipeline import Execution, Pipeline
11+
12+
logger = logging.getLogger(__name__)
13+
14+
SUPPORTED_BACKENDS = {"local", "celery", "ray"}
15+
16+
17+
@dataclass
class DispatchResult:
    """Pair a finished pipeline execution with the backend that produced it.

    Returned by ``DistributedExecutionDispatcher.run`` so callers can record
    which backend actually ran the work (it may differ from the requested
    backend when Celery/Ray are unavailable and the dispatcher fell back to
    local execution).
    """

    # Completed Execution record from whichever backend ran the pipeline.
    execution: Execution
    # Backend that actually executed: "local", "celery", or "ray".
    backend_used: str
23+
24+
25+
class DistributedExecutionDispatcher:
    """Run pipeline execution on the local runtime or a distributed framework.

    Backend selection order: explicit ``backend_override`` argument, then the
    ``DISTRIBUTED_EXECUTION_BACKEND`` setting, then ``"local"``. Celery and
    Ray are strictly optional: any failure to import, connect, or execute
    remotely falls back to the in-process ``PipelineExecutor`` and records the
    reason under ``execution.context["distributed_execution"]``.
    """

    def __init__(self):
        # The local executor doubles as the fallback for every remote backend.
        self.executor = PipelineExecutor()

    def run(self, pipeline: Pipeline, backend_override: Optional[str] = None) -> DispatchResult:
        """Execute *pipeline* on the selected backend.

        Args:
            pipeline: Pipeline model to execute.
            backend_override: Optional per-call backend name ("local",
                "celery", or "ray"); overrides the configured default.

        Returns:
            DispatchResult carrying the finished execution and the name of
            the backend that actually ran it.
        """
        backend = (backend_override or settings.DISTRIBUTED_EXECUTION_BACKEND or "local").lower().strip()

        if backend not in SUPPORTED_BACKENDS:
            # Unknown names degrade gracefully instead of failing the request.
            logger.warning("Unsupported backend '%s'. Falling back to local.", backend)
            backend = "local"

        if backend == "celery":
            execution, used_backend = self._execute_with_celery(pipeline)
        elif backend == "ray":
            execution, used_backend = self._execute_with_ray(pipeline)
        else:
            execution, used_backend = self.executor.execute(pipeline), "local"

        return DispatchResult(execution=execution, backend_used=used_backend)

    def _fallback_local(self, pipeline: Pipeline, requested_backend: str, exc: Exception) -> tuple[Execution, str]:
        """Execute locally after a remote backend failed, recording why.

        Shared by the Celery and Ray paths so the fallback bookkeeping stays
        in one place.
        """
        logger.warning("%s backend unavailable (%s). Executing locally.", requested_backend.capitalize(), exc)
        execution = self.executor.execute(pipeline)
        execution.context.setdefault("distributed_execution", {})
        execution.context["distributed_execution"].update(
            {
                "requested_backend": requested_backend,
                "fallback_backend": "local",
                "fallback_reason": str(exc),
            }
        )
        return execution, "local"

    def _execute_with_celery(self, pipeline: Pipeline) -> tuple[Execution, str]:
        """Dispatch to a Celery worker; fall back to local on any error.

        The worker task (``settings.CELERY_EXECUTION_TASK``) is expected to
        accept the JSON-serialized pipeline and return a JSON-serializable
        Execution payload.
        """
        try:
            from celery import Celery

            # NOTE(review): a fresh Celery app per call is simple but pays
            # connection setup each time — consider caching if this is hot.
            app = Celery(
                "flexiroaster",
                broker=settings.CELERY_BROKER_URL,
                backend=settings.CELERY_RESULT_BACKEND,
            )
            task_name = settings.CELERY_EXECUTION_TASK

            payload = pipeline.model_dump(mode="json")
            async_result = app.send_task(task_name, kwargs={"pipeline": payload})
            # Blocking wait, capped at 10 minutes; a timeout triggers the
            # local fallback below.
            remote_output = async_result.get(timeout=600)
            execution = Execution.model_validate(remote_output)
            logger.info("Pipeline %s executed via Celery task %s", pipeline.id, task_name)
            return execution, "celery"
        except Exception as exc:
            return self._fallback_local(pipeline, "celery", exc)

    def _execute_with_ray(self, pipeline: Pipeline) -> tuple[Execution, str]:
        """Dispatch to a Ray remote function; fall back to local on any error."""
        try:
            import ray

            if not ray.is_initialized():
                ray.init(address=settings.RAY_ADDRESS, namespace=settings.RAY_NAMESPACE, ignore_reinit_error=True)

            @ray.remote
            def execute_pipeline_remote(pipeline_payload: dict):
                # Imports happen inside the task so Ray workers resolve them
                # in their own process.
                from backend.core.executor import PipelineExecutor
                from backend.models.pipeline import Pipeline

                model = Pipeline.model_validate(pipeline_payload)
                result = PipelineExecutor().execute(model)
                return result.model_dump(mode="json")

            payload = pipeline.model_dump(mode="json")
            remote_ref = execute_pipeline_remote.remote(payload)
            remote_output = ray.get(remote_ref)
            execution = Execution.model_validate(remote_output)
            logger.info("Pipeline %s executed via Ray remote function", pipeline.id)
            return execution, "ray"
        except Exception as exc:
            return self._fallback_local(pipeline, "ray", exc)
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import asyncio
2+
3+
from fastapi import BackgroundTasks
4+
5+
from backend.api.routes.executions import create_execution, execute_pipeline_background, executions_db
6+
from backend.api.routes.pipelines import create_pipeline, pipelines_db
7+
from backend.api.schemas import ExecutionCreate, PipelineCreate, StageCreate, StageTypeSchema
8+
from backend.core.distributed_executor import DistributedExecutionDispatcher
9+
from backend.models.pipeline import ExecutionStatus
10+
11+
12+
def _build_pipeline():
    """Reset the in-memory stores and create a minimal two-stage pipeline.

    Clears both ``pipelines_db`` and ``executions_db`` so each test starts
    from a clean slate, then creates an input -> output pipeline via the
    async route handler (driven with ``asyncio.run``).
    """
    pipelines_db.clear()
    executions_db.clear()
    return asyncio.run(
        create_pipeline(
            PipelineCreate(
                name="distributed-pipeline",
                description="pipeline for distributed testing",
                stages=[
                    StageCreate(id="in", name="Input", type=StageTypeSchema.INPUT, config={"source": "x", "data": [1, 2]}),
                    StageCreate(
                        id="out",
                        name="Output",
                        type=StageTypeSchema.OUTPUT,
                        config={"destination": "console"},
                        dependencies=["in"],
                    ),
                ],
            )
        )
    )
33+
34+
35+
def test_dispatcher_falls_back_to_local_for_unknown_backend():
    """An unrecognized backend name must silently degrade to local execution."""
    pipeline = _build_pipeline()

    dispatcher = DistributedExecutionDispatcher()
    # "spark" is not in SUPPORTED_BACKENDS, so the dispatcher should warn
    # and run locally rather than raise.
    result = dispatcher.run(pipeline, backend_override="spark")

    assert result.backend_used == "local"
    assert result.execution.status == ExecutionStatus.COMPLETED
43+
44+
45+
def test_create_execution_tracks_requested_backend_and_backend_used():
    """Requested backend is recorded in context; missing Celery falls back to local."""
    pipeline = _build_pipeline()

    execution = asyncio.run(
        create_execution(
            ExecutionCreate(pipeline_id=pipeline.id, execution_backend="celery"),
            BackgroundTasks(),
        )
    )

    # The route stores the caller's requested backend in execution context.
    assert execution.context["requested_execution_backend"] == "celery"

    # Drive the background task directly; with no Celery broker available the
    # dispatcher is expected to fall back to local execution.
    asyncio.run(execute_pipeline_background(pipeline.id, execution.id, "celery"))
    stored = executions_db[execution.id]

    assert stored.context["requested_execution_backend"] == "celery"
    assert stored.context["distributed_execution"]["backend_used"] == "local"

0 commit comments

Comments
 (0)