Skip to content

Commit 3d9891f

Browse files
committed
feat: enhance pipeline validation with parameter and dependency checks
1 parent 27176eb commit 3d9891f

6 files changed

Lines changed: 119 additions & 9 deletions

File tree

ai_agent/pipeline_agent.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ def validate_pipeline(pipeline_def: dict, registry: dict) -> tuple[list[str], li
128128
for i, step in enumerate(pipeline["steps"]):
129129
step_id = step.get("id", f"step_{i}")
130130

131+
params = step.get("params") or {}
132+
if not isinstance(params, dict):
133+
errors.append(f"Step '{step_id}': 'params' must be an object/dict")
134+
params = {}
135+
131136
if step_id in step_ids:
132137
errors.append(f"Duplicate step ID: '{step_id}'")
133138
step_ids.add(step_id)
@@ -150,16 +155,35 @@ def validate_pipeline(pipeline_def: dict, registry: dict) -> tuple[list[str], li
150155
for pname, pinfo in svc_info.get("params", {}).items():
151156
if pname == "dataset_name":
152157
continue # auto-injected by the compiler
153-
if pinfo.get("required") and pname not in step.get("params", {}):
158+
if pinfo.get("required") and pname not in params:
154159
errors.append(
155160
f"Step '{step_id}': missing required param '{pname}' for service '{service}'"
156161
)
157162

163+
# Validate depends_on references and semantics
164+
depends_on = step.get("depends_on", [])
165+
if depends_on is None:
166+
depends_on = []
167+
if not isinstance(depends_on, list):
168+
errors.append(f"Step '{step_id}': 'depends_on' must be a list")
169+
depends_on = []
170+
158171
# Validate depends_on references
159-
for dep in step.get("depends_on", []):
172+
for dep in depends_on:
160173
if dep not in all_step_ids:
161174
errors.append(f"Step '{step_id}': depends_on references unknown step '{dep}'")
162175

176+
if service in valid_services:
177+
svc_type = registry["services"][service]["type"]
178+
if svc_type == "extract" and depends_on:
179+
errors.append(f"Step '{step_id}': extract steps must not have depends_on")
180+
if svc_type != "extract" and not depends_on:
181+
errors.append(f"Step '{step_id}': non-extract steps require depends_on")
182+
if service == "join_datasets" and len(depends_on) != 2:
183+
errors.append(f"Step '{step_id}': join_datasets requires exactly 2 depends_on entries")
184+
if service != "join_datasets" and len(depends_on) > 1:
185+
errors.append(f"Step '{step_id}': only join_datasets supports multiple depends_on entries")
186+
163187
if not has_extract:
164188
errors.append("Pipeline must have at least one extract step")
165189

ai_agent/pipeline_compiler.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,9 @@ def _build_dispatch_registry(prep) -> dict[str, Callable]:
189189
}
190190

191191

192+
_EXTRACT_SERVICES = {"extract_csv", "extract_excel", "extract_api", "extract_sql"}
193+
194+
192195
# ── Pipeline Compiler ──────────────────────────────────────────────
193196

194197
class PipelineCompiler:
@@ -417,4 +420,10 @@ def _dispatch_step(
417420
"(one for each input dataset)"
418421
)
419422

423+
if service not in _EXTRACT_SERVICES and input_data is None:
424+
raise ValueError(
425+
f"Service '{service}' requires input data. "
426+
"Check depends_on and upstream outputs."
427+
)
428+
420429
return handler(params, input_data, dataset_name, input_data_2)

airflow/Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ FROM apache/airflow:2.10.4
66
ENV AIRFLOW__METRICS__STATSD_ON=True
77
ENV AIRFLOW__METRICS__STATSD_HOST=statsd-exporter
88
ENV AIRFLOW__METRICS__STATSD_PORT=9125
9+
ENV PYTHONWARNINGS="ignore:invalid escape sequence.*:SyntaxWarning:azure\\.synapse\\.artifacts\\.models\\._models_py3"
910

1011
# Crea le cartelle necessarie
1112
RUN mkdir -p /opt/airflow/dags /opt/airflow/logs /opt/airflow/plugins

services/data-quality-service/app/dq.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ def basic_quality_checks(arrow_table, rules=None):
102102
if range_rules:
103103
range_results = {}
104104
for col, bounds in range_rules.items():
105+
if not isinstance(bounds, dict):
106+
range_results[col] = {"pass": False, "reason": "invalid bounds"}
107+
continue
105108
if col not in df.columns:
106109
range_results[col] = {"pass": False, "reason": "column not found"}
107110
continue
@@ -111,15 +114,22 @@ def basic_quality_checks(arrow_table, rules=None):
111114
col_min = float(df[col].min()) if df[col].notna().any() else None
112115
col_max = float(df[col].max()) if df[col].notna().any() else None
113116
ok = True
114-
if "min" in bounds and col_min is not None:
115-
ok = ok and col_min >= bounds["min"]
116-
if "max" in bounds and col_max is not None:
117-
ok = ok and col_max <= bounds["max"]
117+
try:
118+
expected_min = float(bounds["min"]) if "min" in bounds else None
119+
expected_max = float(bounds["max"]) if "max" in bounds else None
120+
except (TypeError, ValueError):
121+
range_results[col] = {"pass": False, "reason": "invalid bounds"}
122+
continue
123+
124+
if expected_min is not None and col_min is not None:
125+
ok = ok and col_min >= expected_min
126+
if expected_max is not None and col_max is not None:
127+
ok = ok and col_max <= expected_max
118128
range_results[col] = {
119129
"actual_min": col_min,
120130
"actual_max": col_max,
121-
"expected_min": bounds.get("min"),
122-
"expected_max": bounds.get("max"),
131+
"expected_min": expected_min,
132+
"expected_max": expected_max,
123133
"pass": bool(ok)
124134
}
125135
result["checks"]["value_range"] = range_results

services/outlier-detection-service/app/outliers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ def detect_and_remove_outliers(arrow_table, column, z_threshold=3.0):
4040
return pa.Table.from_pandas(df), 0
4141

4242
z_score = (series_numeric - mean_val).abs() / std_val
43-
filtered_df = df[z_score <= z_threshold]
43+
# Keep rows where z_score is NaN to avoid dropping non-numeric rows silently
44+
keep_mask = (z_score <= z_threshold) | z_score.isna()
45+
filtered_df = df[keep_mask]
4446
removed_count = before_rows - filtered_df.shape[0]
4547

4648
new_table = pa.Table.from_pandas(filtered_df)

tests/unit/test_pipeline_agent.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@
4242
"dataset_name": {"type": "string", "required": True, "description": "Dataset name"},
4343
},
4444
},
45+
"join_datasets": {
46+
"type": "transform",
47+
"params": {
48+
"dataset_name": {"type": "string", "required": True, "description": "Dataset name"},
49+
"join_key": {"type": "string", "required": False},
50+
"join_type": {"type": "string", "required": False},
51+
},
52+
},
4553
}
4654
}
4755

@@ -171,6 +179,62 @@ def test_step_missing_service_field(self):
171179
errors, warnings = validate_pipeline(pipeline, MINIMAL_REGISTRY)
172180
assert any("service" in e.lower() for e in errors)
173181

182+
def test_non_extract_requires_depends_on(self):
183+
pipeline = _pipeline([
184+
{"id": "extract", "service": "extract_csv", "params": {"file_path": "/data/f.csv"}},
185+
{"id": "clean", "service": "clean_nan"},
186+
])
187+
errors, warnings = validate_pipeline(pipeline, MINIMAL_REGISTRY)
188+
assert any("requires depends_on" in e.lower() for e in errors)
189+
190+
def test_extract_cannot_have_depends_on(self):
191+
pipeline = _pipeline([
192+
{"id": "extract", "service": "extract_csv", "params": {"file_path": "/data/f.csv"}, "depends_on": ["x"]},
193+
])
194+
errors, warnings = validate_pipeline(pipeline, MINIMAL_REGISTRY)
195+
assert any("must not have depends_on" in e.lower() for e in errors)
196+
197+
def test_params_must_be_dict(self):
198+
pipeline = _pipeline([
199+
{"id": "extract", "service": "extract_csv", "params": []},
200+
])
201+
errors, warnings = validate_pipeline(pipeline, MINIMAL_REGISTRY)
202+
assert any("params" in e.lower() and "dict" in e.lower() for e in errors)
203+
204+
def test_depends_on_must_be_list(self):
205+
pipeline = _pipeline([
206+
{"id": "extract", "service": "extract_csv", "params": {"file_path": "/data/f.csv"}},
207+
{"id": "clean", "service": "clean_nan", "depends_on": "extract"},
208+
])
209+
errors, warnings = validate_pipeline(pipeline, MINIMAL_REGISTRY)
210+
assert any("depends_on" in e.lower() and "list" in e.lower() for e in errors)
211+
212+
def test_non_join_multiple_depends_on_invalid(self):
213+
pipeline = _pipeline([
214+
{"id": "extract1", "service": "extract_csv", "params": {"file_path": "/data/a.csv"}},
215+
{"id": "extract2", "service": "extract_csv", "params": {"file_path": "/data/b.csv"}},
216+
{"id": "clean", "service": "clean_nan", "depends_on": ["extract1", "extract2"]},
217+
])
218+
errors, warnings = validate_pipeline(pipeline, MINIMAL_REGISTRY)
219+
assert any("multiple depends_on" in e.lower() for e in errors)
220+
221+
def test_join_requires_two_depends_on(self):
222+
pipeline = _pipeline([
223+
{"id": "extract1", "service": "extract_csv", "params": {"file_path": "/data/a.csv"}},
224+
{"id": "join", "service": "join_datasets", "depends_on": ["extract1"]},
225+
])
226+
errors, warnings = validate_pipeline(pipeline, MINIMAL_REGISTRY)
227+
assert any("join_datasets" in e.lower() and "exactly 2" in e.lower() for e in errors)
228+
229+
def test_join_accepts_two_depends_on(self):
230+
pipeline = _pipeline([
231+
{"id": "extract1", "service": "extract_csv", "params": {"file_path": "/data/a.csv"}},
232+
{"id": "extract2", "service": "extract_csv", "params": {"file_path": "/data/b.csv"}},
233+
{"id": "join", "service": "join_datasets", "depends_on": ["extract1", "extract2"]},
234+
])
235+
errors, warnings = validate_pipeline(pipeline, MINIMAL_REGISTRY)
236+
assert errors == []
237+
174238

175239
# ── Return type contract ──────────────────────────────────────────────────────
176240

0 commit comments

Comments
 (0)