# MemCoE: Learning How and What to Memorize: Cognition-Inspired Two-Stage Optimization for Evolving Memory
This repository contains the official implementation of MemCoE, a cognition-inspired framework that optimizes both how to memorize and what to memorize through a two-stage approach: Memory Guideline Induction (MGI) and Guideline-Aligned Memory Policy Optimization (GMPO).
- Overview
- Installation
- Datasets
- Stage 1: Memory Guideline Induction (MGI)
- Stage 2: Guideline-Aligned Memory Policy Optimization (GMPO)
- Acknowledgements
- Citation
## Overview

MemCoE introduces a two-stage optimization paradigm for evolving memory systems:
- **Memory Guideline Induction (MGI):** Automatically induces high-level memory guidelines through textual gradient optimization, determining what information is worth memorizing.
- **Guideline-Aligned Memory Policy Optimization (GMPO):** Trains a memory policy model aligned with the induced guidelines, optimizing how to effectively store and retrieve memories.
## Installation

Create a new conda environment and install dependencies:

```bash
conda create -n memcoe python=3.10
conda activate memcoe
pip install -r requirements.txt
```

## Datasets

Download the following datasets and place them in the `MGI_module/datasets/` directory:
| Dataset | Source |
|---|---|
| PersonaMem | HuggingFace |
| PrefEval | HuggingFace |
| PersonaBench | GitHub |
```bash
mkdir -p MGI_module/datasets
# Download and extract datasets to MGI_module/datasets/
```

## Stage 1: Memory Guideline Induction (MGI)

MGI supports both local model deployment via vLLM and API-based models (e.g., GPT-4o-mini).
Deploy a local model (e.g., Qwen2.5-7B-Instruct) using vLLM:

```bash
CUDA_VISIBLE_DEVICES=4,5,6,7 vllm serve Qwen2.5-7B-Instruct \
  --tensor_parallel_size 4 \
  --port 6025 \
  --host 0.0.0.0 \
  --served-model-name Qwen2.5-7B-Instruct
```

Verify the deployment:
```bash
cd MGI_module/
python async_vllm.py --model Qwen2.5-7B-Instruct
```

To use an API-based model instead, configure your API credentials in `async_llm.py`:
```python
API_KEY = 'your-api-key'
BASE_URL = 'your-base-url'
```

Run the textual gradient optimization to induce memory guidelines:
```bash
cd MGI_module/textgrad/
python train_textgrad.py \
  --model Qwen2.5-7B-Instruct \
  --num_chunks 8 \
  --batch_size 10 \
  --num_feedback 5
```

**Output:** Results are saved in the `textgrad/` directory after each optimization step.
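For intuition, here is a conceptual sketch of a textual-gradient loop in plain Python; it is not the code in `train_textgrad.py`. The three LLM roles (scoring a guideline on a sample, critiquing it in natural language, and rewriting it from the critiques) are replaced by hypothetical stub functions, and reading `--batch_size` as the mini-batch size and `--num_feedback` as the number of critiques applied per step is an assumption based on the flag names.

```python
# Conceptual sketch only: a textual-gradient loop with stub "LLM" roles.
# Mapping batch_size / num_feedback onto this loop is an assumption,
# not a description of train_textgrad.py.
import random
from typing import Callable, List


def optimize_guideline(
    guideline: str,
    train_set: List[str],
    evaluate: Callable[[str, str], bool],      # forward pass + loss: does the guideline cover x?
    critique: Callable[[str, str], str],       # backward pass: feedback written in natural language
    rewrite: Callable[[str, List[str]], str],  # optimizer step: edit the guideline text
    steps: int = 3,
    batch_size: int = 10,
    num_feedback: int = 5,
    seed: int = 0,
) -> str:
    rng = random.Random(seed)
    for _ in range(steps):
        # Sample a mini-batch of training items.
        batch = rng.sample(train_set, min(batch_size, len(train_set)))
        # Only failed items produce a textual "gradient".
        failures = [x for x in batch if not evaluate(guideline, x)]
        if not failures:
            continue
        # Apply at most num_feedback critiques in this step.
        feedback = [critique(guideline, x) for x in failures[:num_feedback]]
        guideline = rewrite(guideline, feedback)
    return guideline


# Hypothetical stubs standing in for LLM calls, for illustration only.
def covers(guideline: str, topic: str) -> bool:
    return topic in guideline


def suggest(guideline: str, topic: str) -> str:
    return f"Also memorize the user's {topic}."


def apply_feedback(guideline: str, feedback: List[str]) -> str:
    return " ".join([guideline, *feedback])


topics = ["preferences", "allergies", "schedule"]
print(optimize_guideline("Memorize key facts.", topics, covers, suggest, apply_feedback))
```

Because the "gradient" is text, the update is a rewrite of the guideline rather than an arithmetic change to weights; in a real pipeline each stub would be an LLM call.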
**Note:** All MGI prompt templates are defined in `meta_prompt.py`.
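Inference sends requests to a model server. With a local vLLM deployment, one quick way to confirm that the server is up and exposes the expected `--served-model-name` is to query its OpenAI-compatible `GET /v1/models` endpoint. This minimal sketch uses only the standard library and assumes the server from the deployment step is reachable at `http://localhost:6025`; `is_model_served` and `served_model_ids` are hypothetical helpers, not part of the repository.

```python
# Minimal readiness check for a server started with `vllm serve`.
# Assumes the OpenAI-compatible API is reachable on port 6025.
import json
import urllib.request
from typing import List


def served_model_ids(body: str) -> List[str]:
    """Return the model ids listed in an OpenAI-style /v1/models response body."""
    return [entry["id"] for entry in json.loads(body).get("data", [])]


def is_model_served(model: str, base_url: str = "http://localhost:6025", timeout: float = 5.0) -> bool:
    """Query GET {base_url}/v1/models and check that `model` is among the served names."""
    with urllib.request.urlopen(f"{base_url}/v1/models", timeout=timeout) as resp:
        return model in served_model_ids(resp.read().decode("utf-8"))


# Usage (requires a running server):
# is_model_served("Qwen2.5-7B-Instruct")
```

An unreachable server surfaces as a `urllib.error.URLError` raised by `urlopen`.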
After training, update the `TEMPLATE_EVOLVE` prompt template in `MGI_module/processors/memory.py` with the optimized guidelines, then run inference:
```bash
cd MGI_module/
# PersonaMem (32k and 128k context)
python run_inference.py --dataset personamem --mode memory --context longcontext --num_chunks 8 --size 32k --model Qwen2.5-7B-Instruct --output_dir results
python run_inference.py --dataset personamem --mode memory --context longcontext --num_chunks 8 --size 128k --model Qwen2.5-7B-Instruct --output_dir results
# PrefEval (explicit and implicit)
python run_inference.py --dataset prefeval --mode memory --num_chunks 8 --pref_form explicit --model Qwen2.5-7B-Instruct --output_dir results
python run_inference.py --dataset prefeval --mode memory --num_chunks 8 --pref_form implicit --model Qwen2.5-7B-Instruct --output_dir results
# PersonaBench (varying noise levels)
python run_inference.py --dataset personabench --mode memory --num_chunks 8 --noise 0.0 --model Qwen2.5-7B-Instruct --output_dir results
python run_inference.py --dataset personabench --mode memory --num_chunks 8 --noise 0.3 --model Qwen2.5-7B-Instruct --output_dir results
python run_inference.py --dataset personabench --mode memory --num_chunks 8 --noise 0.5 --model Qwen2.5-7B-Instruct --output_dir results
python run_inference.py --dataset personabench --mode memory --num_chunks 8 --noise 0.7 --model Qwen2.5-7B-Instruct --output_dir results
```

**Output:** Results are saved in the `results/` directory.
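The eight inference commands differ only in the dataset and one or two dataset-specific flags. A small driver like the following hypothetical sketch can generate and run them for any served model name. It assumes `run_inference.py` accepts its flags in any order (as an `argparse`-based script would) and that it is launched from `MGI_module/`.

```python
# Hypothetical driver that rebuilds the inference grid for one model.
import subprocess
from typing import List

# (dataset, dataset-specific flags), copied from the commands in this README.
SETTINGS = [
    ("personamem", ["--context", "longcontext", "--size", "32k"]),
    ("personamem", ["--context", "longcontext", "--size", "128k"]),
    ("prefeval", ["--pref_form", "explicit"]),
    ("prefeval", ["--pref_form", "implicit"]),
] + [("personabench", ["--noise", n]) for n in ("0.0", "0.3", "0.5", "0.7")]


def build_commands(model: str, num_chunks: int = 8, output_dir: str = "results") -> List[List[str]]:
    """Return one argument list per configuration, ready for subprocess.run."""
    return [
        ["python", "run_inference.py", "--dataset", dataset, "--mode", "memory", *extra,
         "--num_chunks", str(num_chunks), "--model", model, "--output_dir", output_dir]
        for dataset, extra in SETTINGS
    ]


def run_all(model: str, dry_run: bool = True) -> None:
    """Print every command; execute them sequentially unless dry_run is set."""
    for cmd in build_commands(model):
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing run


# Usage, from MGI_module/:
# run_all("Qwen2.5-7B-Instruct", dry_run=False)
# run_all("MemCoE-7B", dry_run=False)
```

Starting with `dry_run=True` prints the exact commands so they can be compared with the list above before anything is launched.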
## Stage 2: Guideline-Aligned Memory Policy Optimization (GMPO)

Train the memory policy model:
```bash
bash run_memory_7B.sh
```

**Note:** Modify `run_memory_7B.sh` to adjust hyperparameters and resource configurations for your hardware setup.
**Output:** Model checkpoints are saved in `memory_agent/7B/global_step_xxx/`.
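Each save creates a new `global_step_<N>` directory. If you want the most recent checkpoint, the hypothetical helper below finds it; it compares step numbers numerically, because sorting the directory names as strings would rank `global_step_90` above `global_step_100`. The default root `memory_agent/7B` is taken from the output path above.

```python
# Hypothetical helper: locate the newest global_step_<N> checkpoint directory.
import re
from pathlib import Path
from typing import Optional

_STEP = re.compile(r"global_step_(\d+)")


def latest_checkpoint(root: str = "memory_agent/7B") -> Optional[Path]:
    """Return the global_step_<N> directory with the largest N, or None if there is none."""
    best: Optional[Path] = None
    best_step = -1
    for path in Path(root).glob("global_step_*"):
        match = _STEP.fullmatch(path.name)
        # Compare steps as integers, and skip stray files with a matching name.
        if path.is_dir() and match and int(match.group(1)) > best_step:
            best, best_step = path, int(match.group(1))
    return best


# Usage:
# print(latest_checkpoint())
```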
Run the merger script to convert the checkpoint into the Hugging Face format that vLLM serves in the next step:
```bash
bash scripts/merger.sh
```

**Note:** Update `CKPT=memory_agent/7B/global_step_xxx` in the script with the actual checkpoint step number.
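Instead of editing the script by hand, the `CKPT=` line can be rewritten programmatically. This is a hypothetical sketch that assumes the assignment appears at the start of its own line in `scripts/merger.sh`, in the form `CKPT=memory_agent/7B/global_step_xxx`; a variant such as `export CKPT=...` would not match.

```python
# Hypothetical helper: point the CKPT=... line of scripts/merger.sh at a checkpoint.
import re


def set_ckpt(script_text: str, ckpt: str) -> str:
    """Replace every line that starts with `CKPT=` by `CKPT=<ckpt>`."""
    # A function replacement avoids re treating backslashes in `ckpt` as escapes.
    new_text, count = re.subn(
        r"^CKPT=.*$", lambda _m: f"CKPT={ckpt}", script_text, flags=re.MULTILINE
    )
    if count == 0:
        raise ValueError("no line starting with 'CKPT=' found")
    return new_text


# Usage (the step number is a placeholder):
# from pathlib import Path
# script = Path("scripts/merger.sh")
# script.write_text(set_ckpt(script.read_text(), "memory_agent/7B/global_step_100"))
```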
Serve the merged model with vLLM:

```bash
cd MGI_module/
CUDA_VISIBLE_DEVICES=4,5,6,7 vllm serve memory_agent/7B/global_step_xxx/huggingface \
  --tensor_parallel_size 4 \
  --port 6025 \
  --host 0.0.0.0 \
  --served-model-name MemCoE-7B
```

Then run inference with the trained policy model:

```bash
# PersonaMem (32k and 128k context)
python run_inference.py --dataset personamem --mode memory --context longcontext --num_chunks 8 --size 32k --model MemCoE-7B --output_dir results
python run_inference.py --dataset personamem --mode memory --context longcontext --num_chunks 8 --size 128k --model MemCoE-7B --output_dir results
# PrefEval (explicit and implicit)
python run_inference.py --dataset prefeval --mode memory --num_chunks 8 --pref_form explicit --model MemCoE-7B --output_dir results
python run_inference.py --dataset prefeval --mode memory --num_chunks 8 --pref_form implicit --model MemCoE-7B --output_dir results
# PersonaBench (varying noise levels)
python run_inference.py --dataset personabench --mode memory --num_chunks 8 --noise 0.0 --model MemCoE-7B --output_dir results
python run_inference.py --dataset personabench --mode memory --num_chunks 8 --noise 0.3 --model MemCoE-7B --output_dir results
python run_inference.py --dataset personabench --mode memory --num_chunks 8 --noise 0.5 --model MemCoE-7B --output_dir results
python run_inference.py --dataset personabench --mode memory --num_chunks 8 --noise 0.7 --model MemCoE-7B --output_dir results
```

## Acknowledgements

We thank the following projects for their excellent work and open-source contributions:
