This document describes the Policy Network approach for mathematical reasoning using In-Context Reinforcement Learning (ICRL). The core idea is to learn an intelligent example selection policy that chooses the most helpful few-shot examples for solving mathematical problems.
Traditional few-shot prompting relies on:
- Random example selection - No guarantee examples are relevant
- Fixed example sets - Same examples for all problems
- Manual curation - Expensive and doesn't scale
Instead of random selection, we train a Policy Network that:
- Analyzes the target problem using semantic embeddings
- Evaluates candidate examples from a large pool
- Selects optimal examples that maximize solving success
- Learns from feedback to improve selection over time
This transforms static few-shot prompting into adaptive, learned example selection.
```text
Input: [Problem Embedding (1536-D), Candidate Embeddings (N×1536)]
        ↓
Input Projection (1536 → 768)
        ↓
Multi-Head Attention (8 heads, 768-D)
        ↓
Feed-Forward Network + Residual Connection
        ↓
Scoring & Selection Head
        ↓
Output: [Probability Distribution over N candidates]
```
- Purpose: Captures complex relationships between problem and candidates
- Architecture: 8 attention heads with 768-dimensional representations
- Benefit: Allows model to focus on different semantic aspects simultaneously
- Implementation: PyTorch MultiheadAttention with batch_first=True
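
To make the data flow concrete, here is a minimal PyTorch sketch of the scorer described above; the class name, the self-attention layout over the problem and candidate tokens, and the linear scoring head are illustrative assumptions rather than the repository's exact implementation.

```python
import torch
import torch.nn as nn

class PolicyNetworkSketch(nn.Module):
    """Minimal sketch: project embeddings to 768-D, attend, score each candidate."""

    def __init__(self, emb_dim=1536, hidden_dim=768, num_heads=8):
        super().__init__()
        self.proj = nn.Linear(emb_dim, hidden_dim)          # input projection 1536 -> 768
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.score = nn.Linear(hidden_dim, 1)                # scoring & selection head

    def forward(self, problem_emb, candidate_embs):
        # problem_emb: (1536,), candidate_embs: (N, 1536)
        p = self.proj(problem_emb).view(1, 1, -1)            # (1, 1, 768) problem token
        c = self.proj(candidate_embs).unsqueeze(0)           # (1, N, 768) candidate tokens
        tokens = torch.cat([p, c], dim=1)                    # (1, N+1, 768)
        attended, _ = self.attn(tokens, tokens, tokens)      # 8-head self-attention
        h = tokens + attended                                # residual connection
        h = h + self.ffn(h)                                  # feed-forward + residual
        scores = self.score(h[:, 1:, :]).squeeze(-1)         # one score per candidate, (1, N)
        return torch.softmax(scores, dim=-1).squeeze(0)      # distribution over N candidates
```

As a quick shape check for this sketch, `PolicyNetworkSketch()(torch.randn(1536), torch.randn(20, 1536))` returns a length-20 probability vector over the candidate pool.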
- Correctness Reward (60%): Whether selected examples lead to correct solutions
- Semantic Similarity (30%): Cosine similarity between problem and examples
- Diversity Reward (10%): Ensuring varied example types
- Total: Weighted combination for balanced learning
- Problem: A single fixed softmax temperature can make the selection distribution too sharp for some problems and too flat for others
- Solution: Learnable temperature parameter that adapts during training
- Range: Constrained between 0.1 and 2.0 for stability
- Result: Better probability distributions for example selection
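
As an illustration, a minimal sketch of how such a learnable, clamped temperature could be applied to the candidate scores; the module and its parameterization are assumptions, not the repository API.

```python
import torch
import torch.nn as nn

class LearnableTemperature(nn.Module):
    """Scale scores by a temperature learned during training, clamped to [0.1, 2.0]."""

    def __init__(self, init_temp: float = 1.0):
        super().__init__()
        self.log_temp = nn.Parameter(torch.log(torch.tensor(init_temp)))

    def forward(self, scores):
        temp = self.log_temp.exp().clamp(0.1, 2.0)    # constrain to a stable range
        return torch.softmax(scores / temp, dim=-1)   # <1 sharpens, >1 softens the distribution
```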
- Dataset Processing: Extract mathematical problems from datasets (TAT-QA, GSM8K, etc.)
- Solution Generation: Use Function Prototype Prompting (FPP) to generate code solutions
- Embedding Creation: Convert problem+context to 1536-D embeddings using OpenAI text-embedding-3-small
- Quality Filtering: Keep only candidates with valid, executable code
- Validation: Execute code and verify correctness against ground truth
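
The last two steps can be pictured with a simplified harness like the one below; the `solve()` entry point, the `exec`-based execution, and the numeric tolerance are assumptions for illustration, not the project's actual runner.

```python
def validate_candidate(solution_code: str, ground_truth: float, tol: float = 1e-4) -> bool:
    """Execute a generated solution in an isolated namespace and compare it to the ground truth."""
    namespace: dict = {}
    try:
        exec(solution_code, namespace)    # assumes the FPP solution defines a solve() function
        predicted = namespace["solve"]()
        return abs(float(predicted) - ground_truth) <= tol
    except Exception:
        return False                      # non-executable or incorrect candidates are filtered out
```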
- Problem Sampling: Random selection of target problems from training set
- Candidate Pool: Create diverse pool of potential examples for each problem
- Policy Selection: Use current policy to select k examples (typically k=2-3)
- GPT Evaluation: Generate solution using selected examples, check correctness
- Reward Calculation: Multi-objective reward combining correctness, similarity, diversity
- Policy Update: PPO loss with KL divergence regularization and gradient clipping
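
A hedged sketch of the policy update in the last step, combining the clipped PPO objective with the KL divergence regularization and gradient clipping mentioned above; the coefficients and the function signature are illustrative.

```python
import torch

def ppo_update(policy_net, optimizer, new_probs, old_probs, advantages,
               epsilon=0.2, kl_coef=0.01, max_grad_norm=1.0):
    # Clipped PPO objective over the selection probabilities
    ratio = new_probs / (old_probs.detach() + 1e-8)
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
    policy_loss = -torch.min(ratio * advantages, clipped * advantages).mean()

    # KL divergence regularization keeps the updated policy close to the old one
    kl = (old_probs * (torch.log(old_probs + 1e-8) - torch.log(new_probs + 1e-8))).sum(-1).mean()
    loss = policy_loss + kl_coef * kl

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy_net.parameters(), max_grad_norm)  # gradient clipping
    optimizer.step()
    return loss.item()
```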
- Primary: Maximize success rate of selected examples
- Secondary: Maintain selection consistency across similar problems
- Regularization: Contrastive loss for better representation learning
```python
from openai import OpenAI

client = OpenAI()  # OpenAI text-embedding-3-small returns 1536-D vectors

# Problem representation
problem_text = f"{context} {question}"
problem_embedding = client.embeddings.create(
    model="text-embedding-3-small", input=problem_text).data[0].embedding

# Candidate representation
candidate_text = f"{candidate_context} {candidate_question}"
candidate_embedding = client.embeddings.create(
    model="text-embedding-3-small", input=candidate_text).data[0].embedding
```

```python
def select_examples(policy_net, problem_emb, candidate_embs, candidates, k=3):
    # Forward pass through policy network: probability distribution over candidates
    probs = policy_net(problem_emb, candidate_embs)
    # Sample k distinct examples according to the learned probabilities
    selected_indices = torch.multinomial(probs, k, replacement=False)
    return [candidates[i] for i in selected_indices]
```

```python
# Calculate multi-objective reward
accuracy_reward = 1.0 if is_correct else 0.0
similarity_reward = F.cosine_similarity(problem_emb, example_embs.mean(dim=0), dim=-1).item()
diversity_reward = 1.0 - F.cosine_similarity(example_embs[0], example_embs[1], dim=-1).item()
total_reward = 0.6 * accuracy_reward + 0.3 * similarity_reward + 0.1 * diversity_reward
```

```python
# PPO loss with advantages
ratio = new_probs / (old_probs.detach() + 1e-8)
clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
policy_loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()
```

```python
# Contrastive loss: pull positive examples closer, push negative ones away
pos_sim = F.cosine_similarity(problem_emb, positive_embs, dim=-1).mean()
neg_sim = F.cosine_similarity(problem_emb, negative_embs, dim=-1).mean()
contrastive_loss = -torch.log(torch.exp(pos_sim) / (torch.exp(pos_sim) + torch.exp(neg_sim)))
```

The policy learns to distinguish between different types of similarity:
- Surface similarity (similar keywords, domain) - Traditional methods rely on this
- Structural similarity (similar reasoning patterns) - Policy captures this
- Solution utility (examples that actually help solve the problem) - Policy optimizes for this
Traditional methods use fixed heuristics (cosine similarity, complexity ranking), while our policy learns optimal selection from data:
- Adaptive criteria: Selection strategy adapts to problem characteristics
- Context awareness: Considers both problem and available candidates
- Feedback learning: Improves from success/failure experiences
Unlike single-metric approaches, our reward function balances multiple objectives:
- Correctness: Primary goal of solving problems correctly
- Diversity: Prevents mode collapse to similar examples
- Similarity: Ensures relevance to target problem
| Approach | Selection Strategy | Adaptability | Learning | Training Required |
|---|---|---|---|---|
| Random | Uniform sampling | None | No | No |
| KATE | Semantic similarity | Static | No | No |
| CDS | Curriculum-based | Semi-static | No | No |
| Policy (Ours) | Learned optimization | Dynamic | Yes | Yes |
```bash
pip install torch transformers openai pandas numpy tqdm
```

```python
from mint.icrl.policy_network import PolicyNetwork
from mint.icrl.evaluator import PolicyNetworkEvaluator
import torch
# Initialize components
policy_net = PolicyNetwork(emb_dim=1536, hidden_dim=768)
evaluator = PolicyNetworkEvaluator()
# Load trained model
checkpoint = torch.load('models/dataset_policy_best.pt', map_location='cpu')
policy_net.load_state_dict(checkpoint['model_state_dict'])
# Select examples for a new problem
selected_examples = evaluator.select_with_policy(
policy_net, problem_dict, candidate_pool, k=3
)
```

```python
from mint.icrl.trainer import PolicyNetworkTrainer
# Initialize trainer
trainer = PolicyNetworkTrainer(
dataset_name='TAT-QA',
candidates_dir='candidates',
models_dir='models'
)
# Train for specified epochs
training_history = trainer.train(
num_epochs=3,
save_best=True
)
print(f"Training completed successfully")from mint.icrl.evaluator import PolicyNetworkEvaluator
# Load trained policy
evaluator = PolicyNetworkEvaluator()
# Compare with baselines
results = evaluator.evaluate_policy_vs_random(
policy_net=policy_net,
dataset_candidates=candidates,
dataset_name='TAT-QA',
n_trials=150
)
print(f"Policy vs Random comparison completed")| Dataset | k (examples) | Pool Size | Learning Rate | Epochs | Expected Behavior |
|---|---|---|---|---|---|
| GSM8K | 2 | 20 | 3e-4 | 3 | Fast convergence on arithmetic patterns |
| SVAMP | 2 | 15 | 3e-4 | 4 | Good handling of linguistic variations |
| TabMWP | 2 | 25 | 2e-4 | 4 | Effective table structure recognition |
| TAT-QA | 3 | 25 | 2e-4 | 3 | Complex financial reasoning support |
| FinQA | 2 | 30 | 1e-4 | 5 | Multi-step financial calculations |
```bash
# Train on different datasets
python train_policy.py --dataset TAT-QA --epochs 3
python train_policy.py --dataset GSM8K --epochs 3 --lr 3e-4
python train_policy.py --dataset FinQA --epochs 5 --lr 1e-4
python train_policy.py --dataset SVAMP --epochs 4 --lr 3e-4
python train_policy.py --dataset TabMWP --epochs 4 --lr 2e-4
```

- Cross-Dataset Transfer: Train on one dataset, evaluate on others
- Dynamic k Selection: Learn optimal number of examples per problem
- Hierarchical Policies: Different policies for different problem types
- Meta-Learning: Quick adaptation to new mathematical domains
- Multi-Modal: Incorporate visual reasoning for TabMWP
- Efficient Inference: Reduce selection time for real-time applications
- Distributed Training: Scale to larger candidate pools and datasets
- Online Learning: Continuously improve from user feedback
- Robustness: Handle out-of-distribution problems gracefully
- Ablation Studies: Impact of different reward components
- Architecture Search: Optimal network design for mathematical reasoning
- Data Efficiency: Minimum training data for effective policies
- Failure Analysis: When and why policy selection fails
This approach builds upon recent advances in:
- In-Context Learning: Understanding how LLMs use few-shot examples
- Reinforcement Learning: Policy gradient methods for discrete selection
- Representation Learning: Semantic embeddings for mathematical reasoning
- Meta-Learning: Learning to learn from examples
- Multi-Objective Reward: Balancing correctness, similarity, and diversity
- Attention-Based Selection: Using transformer architecture for example selection
- End-to-End Pipeline: Complete system from candidate generation to evaluation
- Domain Adaptability: Configurable for different mathematical reasoning domains
🎯 Policy Network Vision: This approach demonstrates how reinforcement learning can be applied to optimize in-context learning for mathematical reasoning, providing an adaptive alternative to traditional heuristic-based example selection methods.