Static RAG systems waste tokens by retrieving the same number of passages for every query — regardless of difficulty.
This project implements a reinforcement learning controller that dynamically selects retrieval depth per query, achieving:
- 28.8% token reduction vs fixed k=4 retrieval
- Near-parity F1 with best static baselines (0.430 vs 0.432)
- Stable Pareto-optimal tradeoff under a cost-aware reward function
| Strategy | F1 | EM | Avg Tokens | Reward | Cost/200q |
|---|---|---|---|---|---|
| No Retrieval | 0.2206 | 0.0600 | 71 | 0.2135 | ₹0.31 |
| Fixed k=2 | 0.4209 | 0.2200 | 334 | 0.3876 | ₹0.92 |
| Fixed k=4 | 0.4319 | 0.2250 | 608 | 0.3712 | ₹1.60 |
| RL Policy | 0.4298 | 0.2300 | 433 | 0.3865 | ₹1.16 |
The RL policy achieves k=4-level quality at k=2-level cost by learning when each query benefits from more retrieval context.
Before RL, we evaluate threshold-based heuristic gating on retriever confidence scores, sweeping the threshold across the P20–P80 percentiles. The RL policy dominates every heuristic variant on the Pareto frontier (see `results/heuristic_pareto.png`).
Key insight: heuristic gating can reduce tokens (P50 → 420 tokens) but sacrifices F1 because the threshold is global rather than per-query. The RL policy learns a query-specific decision boundary that the heuristic cannot capture.
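For concreteness, here is a minimal sketch of what percentile-threshold gating looks like. The function names and the `top1_score` input are illustrative, not the actual `baseline/heuristic_gate.py` interface:

```python
import numpy as np

def fit_threshold(train_top1_scores, percentile=50):
    """One global confidence threshold, e.g. the P50 of retriever top-1 scores."""
    return np.percentile(train_top1_scores, percentile)

def heuristic_k(top1_score, threshold):
    """Same rule for every query: high retriever confidence -> shallower retrieval."""
    return 2 if top1_score >= threshold else 4
```

The RL policy replaces this single scalar rule with a decision over the full 391d state.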
In production RAG systems:
- API costs scale linearly with context tokens — every unnecessary passage burns budget
- Static k wastes context — easy factoid queries don't need 4 passages; hard multi-fact queries do
- Dynamic gating reduces token burn — a lightweight policy (83K params) decides retrieval depth in <1ms
- Tunable λ controls aggressiveness — adjust the cost-quality tradeoff without retraining from scratch
This is infrastructure-level optimization for any system serving LLM calls at scale.
Text version:

```
Query
  │
  ├─── Embedder (all-MiniLM-L6-v2, 384d, local)
  │         │
  │    ┌────┴─────────────────────────────┐
  │    │ State Features                   │
  │    │   384d embedding                 │
  │    │   + 4d retriever confidence      │
  │    │   + 3d lexical signals           │
  │    │   = 391d → normalized            │
  │    └────┬─────────────────────────────┘
  │         │
  │    ┌────┴─────────────────────────────┐
  │    │ Policy Network (MLP)             │
  │    │   Linear(384→128) projection     │
  │    │   Linear(135→128) → ReLU         │
  │    │   Linear(128→128) → ReLU         │
  │    │   Linear(128→3) → softmax        │
  │    │   → action ∈ {skip, k=2, k=4}    │
  │    └────┬─────────────────────────────┘
  │         │
  ├─── Retriever (FAISS IndexFlatIP, 500K passages)
  │         │
  └─── LLM (GPT-4o-mini, temp=0)
            │
      Answer + Reward
      R = F1 - λ·(tokens/1000)
```
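A minimal PyTorch sketch of this head, assuming the layer grouping shown in the diagram (names are illustrative; the real module lives in `controller/network.py`). This layout reproduces the reported 83,587-parameter count:

```python
import torch
import torch.nn as nn

class PolicyNetwork(nn.Module):
    """391d state (384d embedding + 7d aux features) -> probabilities over 3 actions."""

    def __init__(self, embed_dim=384, aux_dim=7, hidden=128, n_actions=3):
        super().__init__()
        self.proj = nn.Linear(embed_dim, hidden)  # 384 -> 128 projection
        self.mlp = nn.Sequential(
            nn.Linear(hidden + aux_dim, hidden), nn.ReLU(),  # 135 -> 128
            nn.Linear(hidden, hidden), nn.ReLU(),            # 128 -> 128
            nn.Linear(hidden, n_actions),                    # 128 -> 3
        )

    def forward(self, embedding, aux):
        h = self.proj(embedding)                   # project the query embedding
        x = torch.cat([h, aux], dim=-1)            # + 4d confidence + 3d lexical
        return torch.softmax(self.mlp(x), dim=-1)  # {skip, k=2, k=4} probabilities
```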
Training: REINFORCE with running-mean baseline (β=0.99), gradient clipping at 1.0, 10 epochs over 1,000 NQ-Open queries.
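A condensed sketch of that update rule (names are illustrative; `run_query` stands in for the retrieve-then-generate step, and the real loop lives in `training/trainer.py`):

```python
import torch

LAMBDA, BETA, CLIP = 0.10, 0.99, 1.0
baseline = 0.0  # running mean of past rewards

def reinforce_step(policy, optimizer, state_emb, state_aux, run_query):
    """One REINFORCE update; run_query(action) -> (f1, tokens)."""
    global baseline
    probs = policy(state_emb, state_aux)
    dist = torch.distributions.Categorical(probs)
    action = dist.sample()

    f1, tokens = run_query(action.item())       # retrieval + LLM call
    reward = f1 - LAMBDA * (tokens / 1000.0)    # R = F1 - λ·(tokens/1000)

    baseline = BETA * baseline + (1 - BETA) * reward  # running-mean variance reduction
    loss = -dist.log_prob(action) * (reward - baseline)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy.parameters(), CLIP)
    optimizer.step()
    return reward
```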
```bash
pip install -r requirements.txt
echo "OPENAI_API_KEY=sk-your-key" > .env

python main.py download        # Download NQ questions + Wikipedia corpus
python main.py index           # Build FAISS index (~2h on CPU)
python main.py recall-check    # Verify retrieval quality (need ≥60%)
python main.py baselines       # Evaluate fixed-k baselines
python main.py fit-normalizer  # Fit state normalizer
python main.py train           # Train RL policy (10 epochs)
python main.py evaluate        # Evaluate on test set
python main.py plot            # Generate all plots
```

Or run everything at once:

```bash
python main.py full-pipeline
```

Programmatic usage:

```python
from controller.rl_controller import RLRAGController

controller = RLRAGController(
    model_path="models/policy_final.pt",
    normalizer_path="models/state_normalizer.npz",
)

response = controller.answer("Who discovered penicillin?")
print(response["answer"])  # "Alexander Fleming"
print(response["action"])  # "retrieve_k2"
print(response["tokens"])  # 287
```

```
adaptive-cost-aware-rag/
│
├── README.md
├── requirements.txt
├── LICENSE
├── CITATION.cff
├── .gitignore
│
├── configs/
│   └── default.py            # All hyperparameters in one place
│
├── controller/
│   ├── network.py            # Policy MLP (83K params)
│   ├── features.py           # State extraction + normalization
│   └── rl_controller.py      # Production-ready wrapper
│
├── training/
│   └── trainer.py            # REINFORCE training loop
│
├── retriever/
│   ├── embedder.py           # Local sentence-transformer embeddings
│   ├── indexer.py            # FAISS index builder
│   └── retriever.py          # Top-k retrieval interface
│
├── generator/
│   └── llm.py                # GPT-4o-mini wrapper with cache
│
├── baseline/
│   ├── fixed_topk.py         # Fixed-k baselines
│   └── heuristic_gate.py     # Threshold-based adaptive gating
│
├── evaluation/
│   ├── metrics.py            # F1 + EM with answer normalization
│   ├── cost_tracker.py       # Token/cost logging
│   └── evaluate.py           # Full evaluation pipeline
│
├── data/
│   └── download_nq.py        # Dataset download + preprocessing
│
├── plotting/
│   ├── pareto.py             # Publication-quality visualizations
│   └── generate_visuals.py   # Heuristic + λ sweep plot generator
│
├── results/                  # Pre-generated plots
│   ├── pareto_cost_accuracy.png
│   ├── baseline_comparison.png
│   ├── training_curves.png
│   ├── action_distribution.png
│   ├── heuristic_comparison.png
│   ├── heuristic_pareto.png
│   ├── lambda_sweep.png
│   └── lambda_training_dynamics.png
│
└── main.py                   # CLI entry point
```
The policy initially leans on heavy retrieval (k=4 is the safe default, peaking at ~73% of actions by epoch 5), then shifts toward lighter retrieval (k=2) as it learns which queries don't need the extra context:
| Epoch | Train F1 | Train Reward | Loss | k=2 % | k=4 % |
|---|---|---|---|---|---|
| 1 | 0.363 | 0.325 | 0.305 | 23.5% | 43.8% |
| 5 | 0.435 | 0.381 | 0.080 | 25.4% | 72.9% |
| 7 | 0.437 | 0.387 | 0.052 | 40.1% | 57.7% |
| 10 | 0.435 | 0.390 | 0.034 | 57.8% | 41.1% |
The λ parameter controls the quality-cost tradeoff. Higher λ pushes the policy toward cheaper actions:
| λ | F1 | Tokens | skip % | k=2 % | k=4 % | Behavior |
|---|---|---|---|---|---|---|
| 0.05 | 0.434 | 560 | 2% | 18% | 80% | Quality-first — mostly k=4 |
| 0.10 | 0.430 | 433 | 0% | 63.5% | 36.5% | Balanced (default) |
| 0.20 | 0.405 | 280 | 12% | 72% | 16% | Cost-aggressive — heavy skipping |
The training dynamics (see `results/lambda_training_dynamics.png`) show how λ shapes the policy over epochs: low λ keeps k=4 dominant, while high λ rapidly shifts probability mass toward k=2 and skip.
This turns the controller into a tunable framework — adjust λ to match your cost-quality requirements without retraining from scratch.
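As a sanity check on the reward scale, plugging the default run's numbers from the tables above into R = F1 - λ·(tokens/1000):

```python
def reward(f1, tokens, lam):
    # R = F1 - λ·(tokens/1000)
    return f1 - lam * (tokens / 1000.0)

print(reward(0.4298, 433, 0.10))  # 0.3865 -> matches the reported RL policy reward
print(reward(0.4298, 433, 0.20))  # 0.3432 -> same answers, twice the token penalty
```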
- Dataset: NaturalQuestions-Open (1,000 train / 200 test, seed=42)
- Corpus: 500K passages — 69K NQ answer passages + 431K Wikipedia (hybrid for recall)
- Embeddings: all-MiniLM-L6-v2 (384d, L2-normalized, local)
- Index: FAISS IndexFlatIP (exact search, 732 MB)
- LLM: GPT-4o-mini (temperature=0, max_tokens=100)
- RL: REINFORCE with running-mean baseline (β=0.99)
- Policy: 3-layer MLP, 83,587 parameters, ~1ms inference on CPU
- Recall@10: 72% on test set (pass threshold: 60%)
- Total training cost: ₹70 (~$0.85 USD)
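These settings map onto `configs/default.py` roughly as follows. The field names are hypothetical; only the values come from the list above:

```python
# Hypothetical shape of configs/default.py: field names are illustrative,
# values are the ones reported in this README.
CONFIG = dict(
    seed=42,
    embed_model="all-MiniLM-L6-v2",  # 384d, L2-normalized, local
    faiss_index="IndexFlatIP",       # exact inner-product search
    llm="gpt-4o-mini",
    temperature=0.0,
    max_tokens=100,
    reward_lambda=0.10,              # cost weight in R = F1 - λ·(tokens/1000)
    baseline_beta=0.99,              # running-mean baseline decay
    grad_clip=1.0,
    epochs=10,
)
```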
This controller is designed to integrate with Agent Lightning for scalable RL fine-tuning:
| Phase 1 (This Repo) | Phase 2 (Planned) |
|---|---|
| REINFORCE + baseline | GRPO via Agent Lightning VERL |
| 3 actions | 13 actions (gate × k × compression) |
| GPT-4o-mini API | Qwen2.5-7B via vLLM (local) |
| 1,000 NQ queries | Full NQ + HotpotQA + TriviaQA |
| Simple F1 - λ·cost | Full 4-component reward |
```bibtex
@software{suvarna2026rlrag,
  title  = {Adaptive Cost-Aware RAG Controller},
  author = {Suvarna, Shreyas},
  year   = {2026},
  url    = {https://github.com/ORION2809/RLRAG-Cost-Aware-Retrieval-Depth-Control-for-Production-RAG}
}
```

MIT — see LICENSE for details.