Skip to content

Commit 4a8f0ea

Browse files
refactor: 升级至 1.1.0,统一无锁原子核心并修复 Windows 序列化问题
- 重构 C++ 后端:移除自旋锁,统一采用 std::atomic 无锁实现,提升并发性能。
- 修复 Windows 下的 PicklingError:引入 CachetoolsDecorator 支持 spawn 模式下的缓存序列化。
- 简化 API:提供智能工厂方法 get_opt_cache,自动处理单/多进程后端路由。
- 同步 lab 脚本:更新所有实验脚本以适配 1.1.0 API,修复语法分析器警告。
- 文档更新:同步 README 示例并清理过时的实验文件。
1 parent 7f031ce commit 4a8f0ea

13 files changed

Lines changed: 370 additions & 271 deletions

File tree

README.md

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,35 +39,39 @@ pip install OPT4TorchDataset
3939
4040
## Quick Start
4141

42-
### Method 1: API
42+
### Method 1: Using get_opt_cache (Recommended)
4343

4444
```python
45-
from OPT4TorchDataSet.cachelib import generate_precomputed_file, OPTCacheDecorator
45+
from OPT4TorchDataSet import generate_precomputed_file, get_opt_cache
4646
from torch.utils.data import DataLoader
4747

4848
# Step 1: Offline generation of precomputed file (one-time)
4949
generate_precomputed_file(
5050
dataset_size=10000,
5151
total_iterations=100000,
52-
persist_path="precomputed/my_experiment.safetensors",
52+
persist_path="precomputed/my_plan.safetensors",
5353
random_seed=0,
54-
replacement=True,
5554
maxsize=3000
5655
)
5756

58-
# Step 2: Create cache decorator at runtime
59-
decorator = OPTCacheDecorator(
60-
precomputed_path="precomputed/my_experiment.safetensors",
61-
maxsize=3000, # Must be consistent with maxsize during precomputation
62-
total_iter=100000
57+
# Step 2: Create cache decorator (Intelligent mode: auto-handles single/multi-process)
58+
# num_workers=0 automatically uses high-performance Python version
59+
# num_workers>0 automatically uses Shared Memory C++ version
60+
dataset = MyDataset()
61+
dataset.cache = get_opt_cache(
62+
num_workers=0,
63+
precomputed_path="precomputed/my_plan.safetensors",
64+
maxsize=3000,
65+
total_iter=100000, # Required for Python mode
66+
dataset_size=10000, # Required for Shared Memory (C++) mode
67+
item_shape=(3, 224, 224) # Required for Shared Memory (C++) mode
6368
)
6469

6570
# Step 3: Apply to dataset
66-
dataset = MyDataset()
67-
dataset.__getitem__ = decorator(dataset.__getitem__)
71+
dataset.__getitem__ = dataset.cache(dataset.__getitem__)
6872

6973
# Use DataLoader
70-
dataloader = DataLoader(dataset, batch_size=32)
74+
dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
7175
for batch in dataloader:
7276
pass
7377
```

README_zh.md

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,35 +39,39 @@ pip install OPT4TorchDataset
3939
4040
## Quick Start
4141

42-
### Method 1: API
42+
### Method 1: 使用 get_opt_cache (推荐)
4343

4444
```python
45-
from OPT4TorchDataSet.cachelib import generate_precomputed_file, OPTCacheDecorator
45+
from OPT4TorchDataSet import generate_precomputed_file, get_opt_cache
4646
from torch.utils.data import DataLoader
4747

4848
# Step 1: 离线生成预计算文件(一次性)
4949
generate_precomputed_file(
5050
dataset_size=10000,
5151
total_iterations=100000,
52-
persist_path="precomputed/my_experiment.safetensors",
52+
persist_path="precomputed/my_plan.safetensors",
5353
random_seed=0,
54-
replacement=True,
5554
maxsize=3000
5655
)
5756

58-
# Step 2: 运行时创建缓存装饰器
59-
decorator = OPTCacheDecorator(
60-
precomputed_path="precomputed/my_experiment.safetensors",
61-
maxsize=3000, # 必须与预计算时的maxsize一致
62-
total_iter=100000
57+
# Step 2: 运行时创建缓存装饰器(智能模式:自动处理单进程/多进程)
58+
# num_workers=0 会自动使用高性能 Python 版
59+
# num_workers>0 会自动使用共享内存 C++ 版
60+
dataset = MyDataset()
61+
dataset.cache = get_opt_cache(
62+
num_workers=0,
63+
precomputed_path="precomputed/my_plan.safetensors",
64+
maxsize=3000,
65+
total_iter=100000, # Python 模式需要
66+
dataset_size=10000, # 共享内存 (C++) 模式需要
67+
item_shape=(3, 224, 224) # 共享内存 (C++) 模式需要
6368
)
6469

6570
# Step 3: 应用到数据集
66-
dataset = MyDataset()
67-
dataset.__getitem__ = decorator(dataset.__getitem__)
71+
dataset.__getitem__ = dataset.cache(dataset.__getitem__)
6872

6973
# 使用数据加载器
70-
dataloader = DataLoader(dataset, batch_size=32)
74+
dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
7175
for batch in dataloader:
7276
pass
7377
```

lab/hit_rate/experiment.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@
99
import json
1010
import sys
1111
from pathlib import Path
12-
from typing import List, Dict, Tuple, Optional, Union, Any
12+
from typing import List, Tuple, Optional, Union, Any
1313

14-
import torch
1514
import typer
1615
from loguru import logger
1716
from torch.utils.data import DataLoader, RandomSampler
18-
from cachetools import cached, LRUCache, LFUCache, FIFOCache, RRCache
17+
from cachetools import LRUCache, LFUCache, FIFOCache, RRCache
1918

2019
PROJECT_ROOT = Path(__file__).resolve().parents[2]
2120
if str(PROJECT_ROOT) not in sys.path:
@@ -24,10 +23,11 @@
2423
sys.path.insert(0, str(PROJECT_ROOT / "src"))
2524

2625
from lib.hit_rate_dataset import HitRateDataset
26+
from OPT4TorchDataSet import get_opt_cache, generate_precomputed_file
2727
from OPT4TorchDataSet.cachelib import (
2828
OPTCacheDecorator,
29+
SharedOPTCacheDecorator,
2930
CachetoolsDecorator,
30-
generate_precomputed_file,
3131
)
3232

3333
# Setup logging
@@ -41,11 +41,11 @@
4141
class CacheExperiment:
4242
def __init__(
4343
self,
44-
caches: Optional[List[Tuple[str, float, Any]]] = None,
44+
caches: List[Tuple[str, float, Any]],
4545
output_dir: Union[str, Path] = "results",
4646
batch_size: int = 32,
4747
num_workers: int = 0,
48-
dataset: Optional[torch.utils.data.Dataset] = None,
48+
dataset: Optional[HitRateDataset] = None,
4949
epochs: int = 1,
5050
):
5151
"""
@@ -78,9 +78,13 @@ def _run_single_experiment(self, cache) -> float:
7878
Returns:
7979
float: 未命中次数
8080
"""
81+
if self.caches is None:
82+
raise ValueError("Caches list cannot be None")
83+
if self.dataset is None:
84+
raise ValueError("Dataset cannot be None")
8185

8286
# 创建新的数据集实例,确保每次实验都从干净状态开始
83-
dataset = HitRateDataset(len(self.dataset.dataset))
87+
dataset = HitRateDataset(len(self.dataset))
8488
dataset.setCache(cache)
8589

8690
dataloader = DataLoader(
@@ -104,13 +108,17 @@ def _run_single_experiment(self, cache) -> float:
104108

105109
def run(self):
106110
"""运行所有缓存实验并保存结果"""
111+
if self.caches is None:
112+
raise ValueError("Caches list cannot be None")
113+
if self.dataset is None:
114+
raise ValueError("Dataset cannot be None")
107115

108116
logger.info("Starting Cache Performance Experiments")
109117
results = []
110118

111119
for name, cache_size, cache in self.caches:
112120
# 对于OPT缓存,需要重置状态
113-
if isinstance(cache, OPTCacheDecorator):
121+
if isinstance(cache, (OPTCacheDecorator, SharedOPTCacheDecorator)):
114122
cache.reset()
115123

116124
miss_count = self._run_single_experiment(cache)
@@ -201,10 +209,12 @@ def main(
201209

202210
logger.info(f"预计算文件生成完成: {precomputed_path}")
203211

204-
opt_decorator = OPTCacheDecorator(
212+
opt_decorator = get_opt_cache(
213+
mode="python",
205214
precomputed_path=precomputed_path,
206215
maxsize=cache_size,
207216
total_iter=total_iter,
217+
num_workers=0,
208218
)
209219
caches.append(("OPT", size, opt_decorator))
210220

lab/ram_usage/experiment.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from typing import List, Dict, Tuple, Optional, Union, Any
1313
from copy import deepcopy
1414

15-
import torch
1615
import psutil
1716
import typer
1817
from loguru import logger
@@ -26,7 +25,7 @@
2625
sys.path.insert(0, str(PROJECT_ROOT / "src"))
2726

2827
from lib.hit_rate_dataset import HitRateDataset
29-
from OPT4TorchDataSet.cachelib import OPTCacheDecorator, generate_precomputed_file
28+
from OPT4TorchDataSet import get_opt_cache, generate_precomputed_file
3029

3130
# Setup logging
3231
OUTPUT_DIR = Path(__file__).parent / "results"
@@ -39,11 +38,11 @@
3938
class CacheExperiment:
4039
def __init__(
4140
self,
42-
caches: Optional[List[Tuple[str, float, Any]]] = None,
41+
caches: List[Tuple[str, float, Any]],
4342
output_dir: Union[str, Path] = "results",
4443
batch_size: int = 32,
4544
num_workers: int = 0,
46-
dataset: Optional[torch.utils.data.Dataset] = None,
45+
dataset: Optional[HitRateDataset] = None,
4746
epochs: int = 1,
4847
):
4948
self.caches = caches
@@ -57,6 +56,8 @@ def __init__(
5756

5857
def _run_single_experiment(self, cache) -> Dict:
5958
"""Run single experiment focused on RAM usage"""
59+
if self.dataset is None:
60+
raise ValueError("Dataset cannot be None")
6061

6162
# 强制进行垃圾回收以获得更准确的基线内存使用量
6263
gc.collect()
@@ -98,21 +99,24 @@ def _run_single_experiment(self, cache) -> Dict:
9899

99100
# 获取缓存统计信息
100101
entry_count = 0
101-
if hasattr(dataset, "_getitem_impl"):
102-
if hasattr(dataset._getitem_impl, "__wrapped__"):
103-
cache = getattr(dataset._getitem_impl, "__wrapped__", None)
104-
if (
105-
cache
106-
and hasattr(cache, "cache")
107-
and hasattr(cache.cache, "__dict__")
108-
):
109-
entry_count = (
110-
len(cache.cache) if hasattr(cache.cache, "__len__") else 0
111-
)
112-
elif (
113-
cache and hasattr(cache, "__dict__") and "_cache" in cache.__dict__
114-
):
115-
entry_count = len(cache._cache)
102+
# Check both _wrapped_getitem (HitRateDataset) and generic __getitem__
103+
target_func = getattr(dataset, "_wrapped_getitem", None)
104+
if target_func and hasattr(target_func, "__wrapped__"):
105+
cache_obj = getattr(target_func, "__wrapped__", None)
106+
if (
107+
cache_obj
108+
and hasattr(cache_obj, "cache")
109+
and hasattr(cache_obj.cache, "__dict__")
110+
):
111+
entry_count = (
112+
len(cache_obj.cache) if hasattr(cache_obj.cache, "__len__") else 0
113+
)
114+
elif (
115+
cache_obj
116+
and hasattr(cache_obj, "__dict__")
117+
and "_cache" in cache_obj.__dict__
118+
):
119+
entry_count = len(cache_obj._cache)
116120

117121
# 清理
118122
del dataloader
@@ -226,10 +230,12 @@ def main(
226230
(
227231
"OPT",
228232
size,
229-
OPTCacheDecorator(
233+
get_opt_cache(
234+
mode="python",
230235
precomputed_path=precomputed_path,
231236
maxsize=cache_size,
232237
total_iter=total_iter,
238+
num_workers=0,
233239
),
234240
)
235241
)

0 commit comments

Comments
 (0)