
Adaptive Cost-Aware RAG Controller

Static RAG systems waste tokens by retrieving the same number of passages for every query — regardless of difficulty.

This project implements a reinforcement learning controller that dynamically selects retrieval depth per query, achieving:

  • 28.8% token reduction vs fixed k=4 retrieval
  • Near-parity F1 with the best static baseline (0.430 vs 0.432)
  • Stable Pareto-optimal tradeoff under a cost-aware reward function

Pareto curve: F1 vs Tokens


Results

Strategy      F1      EM      Avg Tokens  Reward  Cost/200q
No Retrieval  0.2206  0.0600          71  0.2135  ₹0.31
Fixed k=2     0.4209  0.2200         334  0.3876  ₹0.92
Fixed k=4     0.4319  0.2250         608  0.3712  ₹1.60
RL Policy     0.4298  0.2300         433  0.3865  ₹1.16

The RL policy achieves k=4-level quality at k=2-level cost by learning when each query benefits from more retrieval context.

Action distribution shift during training


Heuristic Baselines

Before training the RL policy, we evaluate threshold-based heuristic gating on retriever confidence scores, sweeping the gating threshold across the P20–P80 percentiles of the score distribution. The RL policy dominates every heuristic variant on the Pareto frontier:

Pareto frontier with heuristic baselines

Strategy comparison: fixed, heuristic, and RL

Key insight: heuristic gating can reduce tokens (P50 → 420 tokens) but sacrifices F1 because the threshold is global rather than per-query. The RL policy learns a query-specific decision boundary that the heuristic cannot capture.
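
For concreteness, the heuristic gate is just one global cutoff on retriever confidence. A minimal sketch, assuming the top-1 inner-product score is the confidence signal and that high confidence maps to the shallower depth; the function names are illustrative, not the exact baseline/heuristic_gate.py API:

import numpy as np

def fit_threshold(train_top1_scores, percentile=50):
    # One global threshold from the training distribution of top-1 scores (e.g. P50)
    return np.percentile(train_top1_scores, percentile)

def heuristic_k(top1_score, threshold):
    # Confident retrieval gets the shallow depth, uncertain retrieval the deep one.
    # The cutoff is the same for every query; there is no per-query decision boundary.
    return 2 if top1_score >= threshold else 4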


Why This Matters

In production RAG systems:

  • API costs scale linearly with context tokens — every unnecessary passage burns budget
  • Static k wastes context — easy factoid queries don't need 4 passages; hard multi-fact queries do
  • Dynamic gating reduces token burn — a lightweight policy (83K params) decides retrieval depth in <1ms
  • Tunable λ controls aggressiveness — adjust the cost-quality tradeoff without retraining from scratch

This is infrastructure-level optimization for any system serving LLM calls at scale.


Architecture

Architecture diagram

Text version
Query
  │
  ├─── Embedder (all-MiniLM-L6-v2, 384d, local)
  │         │
  │    ┌────┴─────────────────────────────┐
  │    │         State Features           │
  │    │  384d embedding                  │
  │    │  + 4d retriever confidence       │
  │    │  + 3d lexical signals            │
  │    │  = 391d → normalized             │
  │    └────┬─────────────────────────────┘
  │         │
  │    ┌────┴─────────────────────────────┐
  │    │       Policy Network (MLP)       │
  │    │  Linear(384→128) projection      │
  │    │  Linear(135→128) → ReLU          │
  │    │  Linear(128→128) → ReLU          │
  │    │  Linear(128→3)   → softmax       │
  │    │  → action ∈ {skip, k=2, k=4}     │
  │    └────┬─────────────────────────────┘
  │         │
  ├─── Retriever (FAISS IndexFlatIP, 500K passages)
  │         │
  └─── LLM (GPT-4o-mini, temp=0)
            │
        Answer + Reward
        R = F1 - λ·(tokens/1000)
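
The policy is small enough to sketch in full. A minimal PyTorch version matching the layer sizes above; the class name and argument names are illustrative, not the exact controller/network.py interface:

import torch
import torch.nn as nn

class PolicyMLP(nn.Module):
    # 384d query embedding + 7d retriever/lexical features -> 3 action probs
    def __init__(self, embed_dim=384, aux_dim=7, hidden=128, n_actions=3):
        super().__init__()
        self.project = nn.Linear(embed_dim, hidden)   # 384 -> 128 projection
        self.body = nn.Sequential(
            nn.Linear(hidden + aux_dim, hidden),      # 135 -> 128
            nn.ReLU(),
            nn.Linear(hidden, hidden),                # 128 -> 128
            nn.ReLU(),
            nn.Linear(hidden, n_actions),             # 128 -> 3 logits
        )

    def forward(self, embedding, aux):
        h = torch.cat([self.project(embedding), aux], dim=-1)
        return torch.softmax(self.body(h), dim=-1)    # probs over {skip, k=2, k=4}

Summing the layer parameters (49,280 + 17,408 + 16,512 + 387) reproduces the 83,587 figure quoted under Technical Details.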

Training: REINFORCE with running-mean baseline (β=0.99), gradient clipping at 1.0, 10 epochs over 1,000 NQ-Open queries.
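
A minimal sketch of one update under those settings, using the PolicyMLP sketched above; run_query (which executes the sampled action end-to-end and returns F1 and token count) and the other names are hypothetical, not the repo's training/trainer.py API:

import torch

LAMBDA, BETA = 0.10, 0.99   # cost weight and baseline decay
baseline = 0.0              # running mean of rewards

def reinforce_step(policy, optimizer, embedding, aux, run_query):
    global baseline
    probs = policy(embedding, aux)                        # (1, 3) action probabilities
    dist = torch.distributions.Categorical(probs)
    action = dist.sample()
    f1, tokens = run_query(action.item())                 # retrieve, generate, score
    reward = f1 - LAMBDA * (tokens / 1000.0)              # R = F1 - λ·(tokens/1000)
    baseline = BETA * baseline + (1 - BETA) * reward      # running-mean baseline
    loss = -(dist.log_prob(action) * (reward - baseline)).mean()
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0)  # clip gradients at 1.0
    optimizer.step()
    return reward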


Quick Start

1. Install

pip install -r requirements.txt

2. Set API Key

echo "OPENAI_API_KEY=sk-your-key" > .env

3. Run Full Pipeline

python main.py download        # Download NQ questions + Wikipedia corpus
python main.py index           # Build FAISS index (~2h on CPU)
python main.py recall-check    # Verify retrieval quality (need ≥60%)
python main.py baselines       # Evaluate fixed-k baselines
python main.py fit-normalizer  # Fit state normalizer
python main.py train           # Train RL policy (10 epochs)
python main.py evaluate        # Evaluate on test set
python main.py plot            # Generate all plots

Or run everything at once:

python main.py full-pipeline

4. Use the Trained Controller

from controller.rl_controller import RLRAGController

controller = RLRAGController(
    model_path="models/policy_final.pt",
    normalizer_path="models/state_normalizer.npz",
)

response = controller.answer("Who discovered penicillin?")
print(response["answer"])     # "Alexander Fleming"
print(response["action"])     # "retrieve_k2"
print(response["tokens"])     # 287

Project Structure

adaptive-cost-aware-rag/
│
├── README.md
├── requirements.txt
├── LICENSE
├── CITATION.cff
├── .gitignore
│
├── configs/
│   └── default.py              # All hyperparameters in one place
│
├── controller/
│   ├── network.py              # Policy MLP (83K params)
│   ├── features.py             # State extraction + normalization
│   └── rl_controller.py        # Production-ready wrapper
│
├── training/
│   └── trainer.py              # REINFORCE training loop
│
├── retriever/
│   ├── embedder.py             # Local sentence-transformer embeddings
│   ├── indexer.py              # FAISS index builder
│   └── retriever.py            # Top-k retrieval interface
│
├── generator/
│   └── llm.py                  # GPT-4o-mini wrapper with cache
│
├── baseline/
│   ├── fixed_topk.py           # Fixed-k baselines
│   └── heuristic_gate.py       # Threshold-based adaptive gating
│
├── evaluation/
│   ├── metrics.py              # F1 + EM with answer normalization
│   ├── cost_tracker.py         # Token/cost logging
│   └── evaluate.py             # Full evaluation pipeline
│
├── data/
│   └── download_nq.py          # Dataset download + preprocessing
│
├── plotting/
│   ├── pareto.py               # Publication-quality visualizations
│   └── generate_visuals.py     # Heuristic + λ sweep plot generator
│
├── results/                    # Pre-generated plots
│   ├── pareto_cost_accuracy.png
│   ├── baseline_comparison.png
│   ├── training_curves.png
│   ├── action_distribution.png
│   ├── heuristic_comparison.png
│   ├── heuristic_pareto.png
│   ├── lambda_sweep.png
│   └── lambda_training_dynamics.png
│
└── main.py                     # CLI entry point

Training Dynamics

The policy first leans on heavy retrieval (k=4 is the safe default, peaking around epoch 5), then shifts toward lighter retrieval (k=2) as it discovers which queries don't need the extra context:

Epoch  Train F1  Train Reward  Loss   k=2 %  k=4 %
1      0.363     0.325         0.305  23.5%  43.8%
5      0.435     0.381         0.080  25.4%  72.9%
7      0.437     0.387         0.052  40.1%  57.7%
10     0.435     0.390         0.034  57.8%  41.1%

Training curves


Cost Penalty Sensitivity (λ)

The λ parameter controls the quality-cost tradeoff. Higher λ pushes the policy toward cheaper actions:

λ     F1     Tokens  skip %  k=2 %  k=4 %  Behavior
0.05  0.434  560     2%      18%    80%    Quality-first — mostly k=4
0.10  0.430  433     0%      63.5%  36.5%  Balanced (default)
0.20  0.405  280     12%     72%    16%    Cost-aggressive — heavy skipping

Lambda sweep analysis

The training dynamics reveal how λ shapes the policy over epochs — low λ keeps k=4 dominant, high λ rapidly shifts toward k=2 and skip:

Lambda training dynamics

This turns the controller into a tunable framework — adjust λ to match your cost-quality requirements without retraining from scratch.
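
As a sanity check, the Reward column in the Results table above can be reproduced directly from R = F1 - λ·(tokens/1000) at the default λ = 0.10 (small differences come from rounding the average token counts):

lam = 0.10
for name, f1, tokens in [("Fixed k=2", 0.4209, 334),
                         ("Fixed k=4", 0.4319, 608),
                         ("RL Policy", 0.4298, 433)]:
    print(f"{name}: {f1 - lam * tokens / 1000:.4f}")
# Fixed k=2: 0.3875, Fixed k=4: 0.3711, RL Policy: 0.3865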


Technical Details

  • Dataset: NaturalQuestions-Open (1,000 train / 200 test, seed=42)
  • Corpus: 500K passages — 69K NQ answer passages + 431K Wikipedia (hybrid for recall)
  • Embeddings: all-MiniLM-L6-v2 (384d, L2-normalized, local)
  • Index: FAISS IndexFlatIP (exact search, 732 MB)
  • LLM: GPT-4o-mini (temperature=0, max_tokens=100)
  • RL: REINFORCE with running-mean baseline (β=0.99)
  • Policy: 3-layer MLP, 83,587 parameters, ~1ms inference on CPU
  • Recall@10: 72% on test set (pass threshold: 60%; see the sketch below)
  • Total training cost: ₹70 ($0.85 USD)
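
The recall check amounts to asking whether a gold answer string appears in the top-10 retrieved passages. A minimal sketch; the case-insensitive substring match and the retriever.search interface are assumptions, not the repo's exact evaluation code:

def recall_at_k(queries, gold_answers, retriever, k=10):
    # Fraction of queries whose gold answer appears (case-insensitively)
    # in at least one of the top-k retrieved passages.
    hits = 0
    for query, gold in zip(queries, gold_answers):
        passages = retriever.search(query, k=k)  # assumed: returns k passage strings
        if any(gold.lower() in passage.lower() for passage in passages):
            hits += 1
    return hits / len(queries)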

Phase 2 Roadmap

This controller is designed to integrate with Agent Lightning for scalable RL fine-tuning:

Phase 1 (This Repo)   Phase 2 (Planned)
REINFORCE + baseline  GRPO via Agent Lightning VERL
3 actions             13 actions (gate × k × compression)
GPT-4o-mini API       Qwen2.5-7B via vLLM (local)
1,000 NQ queries      Full NQ + HotpotQA + TriviaQA
Simple F1 - λ·cost    Full 4-component reward

Citation

@software{suvarna2026rlrag,
  title   = {Adaptive Cost-Aware RAG Controller},
  author  = {Suvarna, Shreyas},
  year    = {2026},
  url     = {https://github.com/ORION2809/RLRAG-Cost-Aware-Retrieval-Depth-Control-for-Production-RAG}
}

License

MIT — see LICENSE for details.
