Skip to content

Commit 79bbaa1

Browse files
committed
Add baseline scenario configs and pipeline/modeling updates
1 parent 8d6c5fd commit 79bbaa1

17 files changed

Lines changed: 629 additions & 41 deletions

analysis_pipeline/build_trial_table.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,13 +375,18 @@ def main() -> None:
375375
subject_dirs = sorted(path for path in bids_root.glob("sub-*") if path.is_dir())
376376
if not subject_dirs:
377377
raise FileNotFoundError(f"No subject directories found in {bids_root}")
378+
print(
379+
f"Stage 0 starting. task={task} subjects={len(subject_dirs)} "
380+
f"trial_duration_s={args.trial_duration}"
381+
)
378382

379383
all_rows: list[dict[str, str]] = []
380384
subject_summaries: list[dict[str, Any]] = []
381385
all_anomalies: list[str] = []
382386

383-
for subject_dir in subject_dirs:
387+
for subject_idx, subject_dir in enumerate(subject_dirs, start=1):
384388
subject = subject_dir.name
389+
print(f"[Subject {subject_idx}/{len(subject_dirs)}] {subject}")
385390
events_path = subject_dir / "eeg" / f"{subject}_task-{task}_events.tsv"
386391
if not events_path.exists():
387392
all_anomalies.append(f"{subject}: Missing events file {events_path.name}.")

analysis_pipeline/config/pipeline.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,13 @@ stage6:
4747
clip_lower_quantile: 0.01
4848
clip_upper_quantile: 0.99
4949
random_seed: 42
50+
torch_device: "auto"
5051
class_scenarios:
5152
- name: "all_bins"
5253
- name: "omit_easiest"
5354
drop_labels: ["0.6-1.5"]
55+
- name: "omit_hardest"
56+
drop_labels: ["6.0-6.9"]
5457
- name: "three_level_merged"
5558
merge_map:
5659
"0.6-1.5": "low"
@@ -67,3 +70,5 @@ stage6_confusions:
6770
args:
6871
metric: "balanced_accuracy_mean"
6972
top_k_per_protocol: 1
73+
include_all: true
74+
out_png_dir: "analysis_pipeline/reports/confusion_pngs"
[New file added in this commit — file path not captured in this extract]

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
version: 1
2+
3+
paths:
4+
bids_root: "./data/bids_arithmetic"
5+
python_executable: "python"
6+
7+
reports:
8+
run_manifest: "analysis_pipeline/reports/run_manifest_baseline_advanced_nn.json"
9+
10+
stages:
11+
stage0: false
12+
stage1: false
13+
stage2: false
14+
stage3: false
15+
stage4: false
16+
stage5: true
17+
stage6: true
18+
stage6_confusions: true
19+
20+
stage_args:
21+
stage5:
22+
include_tutorial: true
23+
dropout_policy: "absolute"
24+
dropout_threshold: 35.0
25+
fused_out: "analysis_pipeline/features/features_fused_tutorial_baseline.tsv"
26+
split_manifest_out: "analysis_pipeline/features/split_manifest_tutorial_baseline.json"
27+
summary_json: "analysis_pipeline/reports/fusion_summary_tutorial_baseline.json"
28+
unimodal_tag: "tutorial_baseline"
29+
30+
stage6:
31+
run_tag_prefix: "baseline_adv_nn"
32+
results_json_template: "analysis_pipeline/reports/ml_results_{scenario}_baseline_advanced_nn.json"
33+
summary_md_template: "analysis_pipeline/reports/ml_summary_{scenario}_baseline_advanced_nn.md"
34+
base_args:
35+
split_manifest: "analysis_pipeline/features/split_manifest_tutorial_baseline.json"
36+
datasets: ["eeg", "ecg", "pupil", "fused"]
37+
protocols: ["loso", "group_holdout", "within_participant"]
38+
models: ["lstm1d", "gru1d", "cnn1d", "transformer", "bilstm1d", "bigru1d", "cnn1d_deep", "transformer_xl"]
39+
feature_selectors: ["none"]
40+
inner_folds: 2
41+
max_param_combos: 2
42+
max_outer_splits_per_protocol: 2
43+
clip_lower_quantile: 0.01
44+
clip_upper_quantile: 0.99
45+
random_seed: 42
46+
torch_device: "auto"
47+
baseline_from_tutorial_label: "baseline"
48+
class_scenarios:
49+
- name: "baseline_all_bins"
50+
- name: "baseline_omit_hardest"
51+
drop_labels: ["6.0-6.9"]
52+
- name: "baseline_low_high_omit_hardest"
53+
drop_labels: ["6.0-6.9"]
54+
merge_map:
55+
"0.6-1.5": "low_1_2_3"
56+
"1.5-2.4": "low_1_2_3"
57+
"2.4-3.3": "low_1_2_3"
58+
"3.3-4.2": "high_4_5_6"
59+
"4.2-5.1": "high_4_5_6"
60+
"5.1-6.0": "high_4_5_6"
61+
- name: "baseline_grouped_4class_omit_hardest"
62+
drop_labels: ["6.0-6.9"]
63+
merge_map:
64+
"0.6-1.5": "low_1_2"
65+
"1.5-2.4": "low_1_2"
66+
"2.4-3.3": "mid_3_4"
67+
"3.3-4.2": "mid_3_4"
68+
"4.2-5.1": "high_5_6"
69+
"5.1-6.0": "high_5_6"
70+
- name: "baseline_omit_easiest"
71+
drop_labels: ["0.6-1.5"]
72+
73+
stage6_confusions:
74+
out_json_template: "analysis_pipeline/reports/confusion_highlights_{scenario}_baseline_advanced_nn.json"
75+
out_md_template: "analysis_pipeline/reports/confusion_highlights_{scenario}_baseline_advanced_nn.md"
76+
args:
77+
metric: "balanced_accuracy_mean"
78+
top_k_per_protocol: 1
79+
include_all: true
80+
out_png_dir: "analysis_pipeline/reports/confusion_pngs"
[New file added in this commit — file path not captured in this extract]

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
version: 1
2+
3+
paths:
4+
bids_root: "./data/bids_arithmetic"
5+
python_executable: "python"
6+
7+
reports:
8+
run_manifest: "analysis_pipeline/reports/run_manifest_baseline_variants.json"
9+
10+
stages:
11+
stage0: false
12+
stage1: false
13+
stage2: false
14+
stage3: false
15+
stage4: false
16+
stage5: true
17+
stage6: true
18+
stage6_confusions: true
19+
20+
stage_args:
21+
stage5:
22+
include_tutorial: true
23+
dropout_policy: "absolute"
24+
dropout_threshold: 35.0
25+
fused_out: "analysis_pipeline/features/features_fused_tutorial_baseline.tsv"
26+
split_manifest_out: "analysis_pipeline/features/split_manifest_tutorial_baseline.json"
27+
summary_json: "analysis_pipeline/reports/fusion_summary_tutorial_baseline.json"
28+
unimodal_tag: "tutorial_baseline"
29+
30+
stage6:
31+
run_tag_prefix: "baseline_variant"
32+
results_json_template: "analysis_pipeline/reports/ml_results_{scenario}_baseline.json"
33+
summary_md_template: "analysis_pipeline/reports/ml_summary_{scenario}_baseline.md"
34+
base_args:
35+
split_manifest: "analysis_pipeline/features/split_manifest_tutorial_baseline.json"
36+
datasets: ["eeg", "ecg", "pupil", "fused"]
37+
protocols: ["loso", "group_holdout", "within_participant"]
38+
models: ["logreg", "knn", "svm", "gaussian_nb", "decision_tree", "mlp", "rf"]
39+
feature_selectors: ["none"]
40+
inner_folds: 2
41+
max_param_combos: 2
42+
max_outer_splits_per_protocol: 2
43+
clip_lower_quantile: 0.01
44+
clip_upper_quantile: 0.99
45+
random_seed: 42
46+
torch_device: "auto"
47+
baseline_from_tutorial_label: "baseline"
48+
class_scenarios:
49+
- name: "baseline_all_bins"
50+
- name: "baseline_omit_hardest"
51+
drop_labels: ["6.0-6.9"]
52+
- name: "baseline_low_high_omit_hardest"
53+
drop_labels: ["6.0-6.9"]
54+
merge_map:
55+
"0.6-1.5": "low_1_2_3"
56+
"1.5-2.4": "low_1_2_3"
57+
"2.4-3.3": "low_1_2_3"
58+
"3.3-4.2": "high_4_5_6"
59+
"4.2-5.1": "high_4_5_6"
60+
"5.1-6.0": "high_4_5_6"
61+
- name: "baseline_grouped_4class_omit_hardest"
62+
drop_labels: ["6.0-6.9"]
63+
merge_map:
64+
"0.6-1.5": "low_1_2"
65+
"1.5-2.4": "low_1_2"
66+
"2.4-3.3": "mid_3_4"
67+
"3.3-4.2": "mid_3_4"
68+
"4.2-5.1": "high_5_6"
69+
"5.1-6.0": "high_5_6"
70+
- name: "baseline_omit_easiest"
71+
drop_labels: ["0.6-1.5"]
72+
73+
stage6_confusions:
74+
out_json_template: "analysis_pipeline/reports/confusion_highlights_{scenario}_baseline.json"
75+
out_md_template: "analysis_pipeline/reports/confusion_highlights_{scenario}_baseline.md"
76+
args:
77+
metric: "balanced_accuracy_mean"
78+
top_k_per_protocol: 1
79+
include_all: true
80+
out_png_dir: "analysis_pipeline/reports/confusion_pngs"

analysis_pipeline/config/pipeline_class_variants.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,5 @@ stage6_confusions:
6363
args:
6464
metric: "balanced_accuracy_mean"
6565
top_k_per_protocol: 1
66+
include_all: true
67+
out_png_dir: "analysis_pipeline/reports/confusion_pngs"

analysis_pipeline/config/pipeline_model_feature_sweep.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,5 @@ stage6_confusions:
4141
args:
4242
metric: "balanced_accuracy_mean"
4343
top_k_per_protocol: 1
44+
include_all: true
45+
out_png_dir: "analysis_pipeline/reports/confusion_pngs"

analysis_pipeline/config/pipeline_with_deep_models.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,14 @@ stage6:
3030
base_args:
3131
datasets: ["eeg", "ecg", "pupil", "fused"]
3232
protocols: ["loso", "group_holdout", "within_participant"]
33-
models: ["logreg", "knn", "svm", "gaussian_nb", "decision_tree", "mlp", "rf", "lstm1d", "gru1d", "cnn1d", "transformer"]
33+
models: ["lstm1d", "gru1d", "cnn1d", "transformer", "bilstm1d", "bigru1d", "cnn1d_deep", "transformer_xl"]
3434
inner_folds: 2
3535
max_param_combos: 2
3636
max_outer_splits_per_protocol: 2
3737
clip_lower_quantile: 0.01
3838
clip_upper_quantile: 0.99
3939
random_seed: 42
40+
torch_device: "auto"
4041
class_scenarios:
4142
- name: "all_bins"
4243
- name: "omit_easiest"
@@ -57,3 +58,5 @@ stage6_confusions:
5758
args:
5859
metric: "balanced_accuracy_mean"
5960
top_k_per_protocol: 1
61+
include_all: true
62+
out_png_dir: "analysis_pipeline/reports/confusion_pngs"

analysis_pipeline/run_pipeline.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import shlex
88
import subprocess
99
import sys
10+
import time
1011
from dataclasses import dataclass
1112
from datetime import datetime, timezone
1213
from pathlib import Path
@@ -414,6 +415,14 @@ def main() -> None:
414415
run_stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
415416
logs_dir = _reports_dir() / "run_logs" / run_stamp
416417
workdir = _analysis_root().parent
418+
pipeline_start = time.time()
419+
total_steps = len(steps)
420+
421+
print("Pipeline run starting.")
422+
print(f" Config: {config_path}")
423+
print(f" Working directory: {workdir}")
424+
print(f" Dry run: {bool(args.dry_run)}")
425+
print(f" Planned steps: {total_steps}")
417426

418427
manifest: dict[str, Any] = {
419428
"pipeline_started_utc": _utc_now(),
@@ -433,12 +442,19 @@ def main() -> None:
433442
}
434443

435444
try:
436-
for step in steps:
437-
print(f"[{step.stage}] {step.name}")
445+
for step_index, step in enumerate(steps, start=1):
446+
step_start = time.time()
447+
print(f"[{step_index}/{total_steps}] [{step.stage}] {step.name}")
438448
print(" " + " ".join(shlex.quote(part) for part in step.command))
439449
step_result = _run_step(step=step, workdir=workdir, logs_dir=logs_dir, dry_run=args.dry_run)
440450
manifest["steps"].append(step_result)
451+
elapsed_s = time.time() - step_start
452+
print(f" Status: {step_result['status']} (elapsed={elapsed_s:.1f}s)")
441453
if step_result["return_code"] != 0:
454+
if step_result.get("stdout_log"):
455+
print(f" stdout log: {step_result['stdout_log']}")
456+
if step_result.get("stderr_log"):
457+
print(f" stderr log: {step_result['stderr_log']}")
442458
raise RuntimeError(f"Step failed: {step.name} (return_code={step_result['return_code']})")
443459
manifest["status"] = "dry_run" if args.dry_run else "success"
444460
except Exception as exc: # noqa: BLE001
@@ -449,6 +465,8 @@ def main() -> None:
449465
manifest["pipeline_finished_utc"] = _utc_now()
450466
manifest_out.parent.mkdir(parents=True, exist_ok=True)
451467
manifest_out.write_text(json.dumps(manifest, indent=2) + "\n", encoding="utf-8")
468+
print(f"Pipeline status: {manifest['status']}")
469+
print(f"Pipeline elapsed seconds: {time.time() - pipeline_start:.1f}")
452470
print(f"Run manifest: {manifest_out}")
453471

454472

analysis_pipeline/stage1_qc_summary.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -988,12 +988,17 @@ def main() -> None:
988988
subject_dirs = sorted(path for path in bids_root.glob("sub-*") if path.is_dir())
989989
if not subject_dirs:
990990
raise FileNotFoundError(f"No subject directories found in {bids_root}")
991+
print(
992+
f"Stage 1 starting. task={task} subjects={len(subject_dirs)} "
993+
f"trial_rows={len(trial_rows)}"
994+
)
991995

992996
subject_rows: list[dict[str, Any]] = []
993997
all_anomalies: list[str] = []
994-
for subject_dir in subject_dirs:
998+
for subject_idx, subject_dir in enumerate(subject_dirs, start=1):
995999
paths = _resolve_subject_paths(subject_dir, task)
9961000
subject = paths.subject
1001+
print(f"[Subject {subject_idx}/{len(subject_dirs)}] {subject}")
9971002
analysis_included = (
9981003
(participants.get(subject, {}).get("analysis_included") or "n/a").strip().lower()
9991004
)

analysis_pipeline/stage2_preprocess.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -700,9 +700,13 @@ def main() -> None:
700700
subject_dirs = [path for path in subject_dirs if path.name in wanted]
701701
if not subject_dirs:
702702
raise FileNotFoundError("No matching subject directories found.")
703+
print(
704+
f"Stage 2 starting. task={task} subjects={len(subject_dirs)} "
705+
f"overwrite={bool(args.overwrite)}"
706+
)
703707

704708
logs: list[dict[str, str]] = []
705-
for subject_dir in subject_dirs:
709+
for subject_idx, subject_dir in enumerate(subject_dirs, start=1):
706710
paths = _resolve_subject_paths(subject_dir, task)
707711
analysis_included = (
708712
(participants.get(paths.subject, {}).get("analysis_included") or "n/a")
@@ -714,7 +718,8 @@ def main() -> None:
714718
log = _process_subject(paths, out_root, analysis_included, args)
715719
logs.append(log)
716720
print(
717-
f"{paths.subject}: EEG={log['eeg_status']} ECG={log['ecg_status']} "
721+
f"[Subject {subject_idx}/{len(subject_dirs)}] {paths.subject}: "
722+
f"EEG={log['eeg_status']} ECG={log['ecg_status']} "
718723
f"Pupil={log['pupil_status']}"
719724
)
720725

0 commit comments

Comments
 (0)