Skip to content

Commit 796d521

Browse files
committed
feat: add mmd operator
1 parent c336376 commit 796d521

2 files changed

Lines changed: 232 additions & 0 deletions

File tree

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
from typing import Any, Literal
2+
3+
import pandas as pd
4+
from distflow.data.types import DatasetProcessOutputItem, MessageData
5+
from distflow.mmd import MMDDistance
6+
7+
from dataflow import get_logger
8+
from dataflow.core import OperatorABC
9+
from dataflow.utils.registry import OPERATOR_REGISTRY
10+
from dataflow.utils.storage import DataFlowStorage
11+
12+
13+
@OPERATOR_REGISTRY.register()
class MMDDatasetEvaluator(OperatorABC):
    """Score the distribution gap between an evaluation dataset and a reference dataset.

    At construction time the reference dataset is read from ``ref_frame``,
    sampled/shuffled once and kept in ``self.ref_data``; an embedder
    (optionally wrapped in a Redis-backed cache) and an ``MMDDistance``
    instance are built from the remaining options.  ``run`` then samples the
    evaluation dataset the same way and asks ``MMDDistance`` for the distance
    between the two sample sets.
    """

    def __init__(
        self,
        ref_frame: DataFlowStorage,
        *,
        # dataset config
        ref_max_sample_num: int = 5000,
        ref_shuffle_seed: int = 42,
        ref_instruction_key: str = "input",
        ref_output_key: str = "output",
        # kernel
        kernel_type: Literal["RBF"] = "RBF",
        bias: bool = True,
        rbf_sigma: float = 1.0,
        # embedding common
        embedding_type: Literal[
            "vllm", "sentence_transformers"
        ] = "sentence_transformers",
        embedding_model_name: str | None = None,
        # sentence_transformers specific
        st_device: str = "cuda",
        st_batch_size: int = 32,
        st_normalize_embeddings: bool = True,
        # vllm specific
        vllm_max_num_seqs: int = 128,
        vllm_gpu_memory_utilization: float = 0.9,
        vllm_tensor_parallel_size: int = 1,
        vllm_pipeline_parallel_size: int = 1,
        vllm_truncate_max_length: int = 40960,
        # cache config
        cache_type: Literal["redis", "none"] = "none",
        redis_url: str = "redis://127.0.0.1:6379",
        max_concurrent_requests: int = 50,
        redis_db: int = 0,
        cache_model_id: str | None = None,
    ):
        """Load the reference samples and build the embedder and MMD calculator.

        Args:
            ref_frame: Storage holding the reference dataset; read via
                ``ref_frame.read("dataframe")``.
            ref_max_sample_num: Upper bound on the number of reference rows
                kept. Values ``<= 0`` or ``>=`` the dataset size keep every
                row (shuffled). Also the default sample size used by ``run``.
            ref_shuffle_seed: Random seed for sampling/shuffling the
                reference rows. Also the default seed used by ``run``.
            ref_instruction_key: Column holding the instruction text.
            ref_output_key: Column holding the output text.
            kernel_type: Forwarded to ``MMDDistance`` as ``kernel_type``.
            bias: Forwarded to ``MMDDistance`` as ``bias``.
                NOTE(review): presumably selects the biased estimator --
                confirm against distflow.
            rbf_sigma: Forwarded to ``MMDDistance`` as ``rbf_sigma``.
            embedding_type: Which embedding backend to build:
                ``"sentence_transformers"`` or ``"vllm"``.
            embedding_model_name: Model name passed to the chosen backend.
                Required despite the ``None`` default.
            st_device: ``device`` for ``SentenceTransformersEmbed``.
            st_batch_size: ``batch_size`` for ``SentenceTransformersEmbed``.
            st_normalize_embeddings: ``normalize_embeddings`` for
                ``SentenceTransformersEmbed``.
            vllm_max_num_seqs: ``max_num_seqs`` for ``VllmEmbed``.
            vllm_gpu_memory_utilization: ``gpu_memory_utilization`` for
                ``VllmEmbed``.
            vllm_tensor_parallel_size: ``tensor_parallel_size`` for
                ``VllmEmbed``.
            vllm_pipeline_parallel_size: ``pipeline_parallel_size`` for
                ``VllmEmbed``.
            vllm_truncate_max_length: ``truncate_max_length`` for
                ``VllmEmbed``.
            cache_type: ``"redis"`` wraps the embedder in ``CachedEmbed``
                backed by ``RedisCache``; ``"none"`` uses the embedder as is.
            redis_url: ``redis_url`` for ``RedisCache``.
            max_concurrent_requests: ``max_concurrent_requests`` for
                ``RedisCache``.
            redis_db: ``redis_db`` for ``RedisCache``.
            cache_model_id: ``cache_model_id`` for ``CachedEmbed``.

        Raises:
            ValueError: If ``embedding_model_name`` is ``None``, or if
                ``embedding_type`` / ``cache_type`` is not a supported value.
            KeyError: If a reference column is missing.
            TypeError: If a reference instruction/output value is not a ``str``.
        """
        self.logger = get_logger()
        self.logger.info(f"Initializing {self.__class__.__name__}...")

        # Validate the cheap-to-check configuration before reading and
        # sampling the reference dataset, and raise explicitly instead of
        # using ``assert`` (assert statements are removed under ``python -O``).
        if embedding_model_name is None:
            raise ValueError("embedding_model_name must be specified")

        self.ref_max_sample_num = ref_max_sample_num
        self.ref_shuffle_seed = ref_shuffle_seed
        self.ref_data = self._sample_data_helper(
            data_frame=ref_frame,
            max_sample_num=ref_max_sample_num,
            shuffle_seed=ref_shuffle_seed,
            instruction_key=ref_instruction_key,
            output_key=ref_output_key,
        )

        # Backend modules are imported lazily so that only the selected
        # backend has to be importable.
        if embedding_type == "sentence_transformers":
            from distflow.embed.sentence_transformers import SentenceTransformersEmbed

            embedder = SentenceTransformersEmbed(
                model_name=embedding_model_name,
                device=st_device,
                batch_size=st_batch_size,
                normalize_embeddings=st_normalize_embeddings,
                trust_remote_code=True,
            )
        elif embedding_type == "vllm":
            from distflow.embed.vllm import VllmEmbed

            embedder = VllmEmbed(
                model_name=embedding_model_name,
                max_num_seqs=vllm_max_num_seqs,
                gpu_memory_utilization=vllm_gpu_memory_utilization,
                tensor_parallel_size=vllm_tensor_parallel_size,
                pipeline_parallel_size=vllm_pipeline_parallel_size,
                truncate_max_length=vllm_truncate_max_length,
            )
        else:
            raise ValueError(f"Unsupported embedding_type: {embedding_type}")

        if cache_type == "redis":
            from distflow.cache.redis_cache import RedisCache
            from distflow.embed.cache_wrapper import CachedEmbed

            cache = RedisCache(
                redis_url=redis_url,
                max_concurrent_requests=max_concurrent_requests,
                redis_db=redis_db,
            )
            embedder = CachedEmbed(embedder, cache, cache_model_id=cache_model_id)
        elif cache_type != "none":
            raise ValueError(f"Unsupported cache_type: {cache_type}")

        self.mmd_distance = MMDDistance(
            embedder=embedder,
            kernel_type=kernel_type,
            bias=bias,
            rbf_sigma=rbf_sigma,
        )

    def _sample_data_helper(
        self,
        data_frame: DataFlowStorage,
        max_sample_num: int,
        shuffle_seed: int,
        instruction_key: str,
        output_key: str,
    ) -> list[DatasetProcessOutputItem]:
        """Read a dataset, sample/shuffle it and convert rows to chat-style items.

        Args:
            data_frame: Storage to read from via ``read("dataframe")``.
            max_sample_num: If ``0 < max_sample_num < len(dataset)``, that
                many rows are drawn at random; otherwise all rows are kept
                and shuffled.
            shuffle_seed: ``random_state`` passed to ``DataFrame.sample``.
            instruction_key: Column used as the ``user`` message content.
            output_key: Column used as the ``assistant`` message content.

        Returns:
            One ``DatasetProcessOutputItem`` per sampled row, each with a
            user/assistant message pair and a ``meta`` dict recording the
            sampling parameters.

        Raises:
            KeyError: If ``instruction_key`` or ``output_key`` is not a column.
            TypeError: If any instruction or output value is not a ``str``.
        """
        samples: pd.DataFrame = data_frame.read("dataframe")

        if 0 < max_sample_num < len(samples):
            self.logger.info(f"随机采样 {max_sample_num} 条数据")
            sampled_df = samples.sample(n=max_sample_num, random_state=shuffle_seed)
        else:
            self.logger.info("使用全部数据并打乱顺序")
            sampled_df = samples.sample(frac=1, random_state=shuffle_seed)

        sampled_df = sampled_df.reset_index(drop=True)

        # Same exception type pandas would raise, but with a message that
        # names the missing column(s) and what is actually available.
        missing_keys = [
            key for key in (instruction_key, output_key) if key not in sampled_df.columns
        ]
        if missing_keys:
            raise KeyError(
                f"Column(s) {missing_keys} not found in dataframe; "
                f"available columns: {list(sampled_df.columns)}"
            )

        instructions = sampled_df[instruction_key].to_list()
        outputs = sampled_df[output_key].to_list()
        # Loop-invariant: the storage description is the same for every row.
        frame_desc = str(data_frame)

        data: list[DatasetProcessOutputItem] = []
        for row_idx, (instruction, output) in enumerate(zip(instructions, outputs)):
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``; missing values (e.g. NaN) are rejected here too.
            if not (isinstance(instruction, str) and isinstance(output, str)):
                raise TypeError(
                    "Instruction and output must be strings "
                    f"(row {row_idx}: instruction is {type(instruction).__name__}, "
                    f"output is {type(output).__name__})"
                )
            data.append(
                DatasetProcessOutputItem(
                    messages=[
                        MessageData(role="user", content=instruction),
                        MessageData(role="assistant", content=output),
                    ],
                    # A fresh dict per item so items never share mutable state.
                    meta={
                        "frame": frame_desc,
                        "instruction_key": instruction_key,
                        "output_key": output_key,
                        "max_samples": max_sample_num,
                        "shuffle_seed": shuffle_seed,
                    },
                )
            )
        return data

    @staticmethod
    def get_desc(lang: str = "zh"):
        """Return a human-readable description of this operator.

        Args:
            lang: ``"zh"`` for the Chinese description, ``"en"`` for the
                English one; any other value yields a one-line English summary.

        Returns:
            The description string.
        """
        if lang == "zh":
            return (
                "使用最大均值差异 (MMD) 方法评估两个数据集之间的分布差异。\n"
                "通过将文本嵌入到高维空间并计算核函数差异,量化评估数据集与参考数据集的分布偏移程度。\n"
                "输入参数:\n"
                "- ref_frame: 参考数据集 (DataFlowStorage),作为分布比较的基准\n"
                "- ref_max_sample_num: 参考集最大采样数,默认 5000\n"
                "- ref_shuffle_seed: 参考集随机种子,默认 42\n"
                "- ref_instruction_key: 参考集中指令字段名,默认 'input'\n"
                "- ref_output_key: 参考集中输出字段名,默认 'output'\n"
                "- kernel_type: 核函数类型,当前仅支持 'RBF'\n"
                "- rbf_sigma: RBF 核带宽参数,默认 1.0\n"
                "- embedding_type: 嵌入模型类型,可选 'sentence_transformers' 或 'vllm'\n"
                "- embedding_model_name: 嵌入模型名称(必填)\n"
                "- st_device/st_batch_size: SentenceTransformers 设备与批次大小\n"
                "- vllm_*: vLLM 相关配置参数\n"
                "- cache_type: 缓存类型,可选 'redis' 或 'none'\n"
                "输出参数:\n"
                "- MMDScore: MMD 距离值(越小表示分布越接近)\n"
                "- MMDMeta: 包含计算细节的元数据字典"
            )
        elif lang == "en":
            return (
                "Evaluate distribution discrepancy between two datasets using Maximum Mean Discrepancy (MMD).\n"
                "Quantifies distribution shift by computing kernel-based distance between embeddings of evaluation data and reference data.\n"
                "Input Parameters:\n"
                "- ref_frame: Reference dataset (DataFlowStorage) as distribution baseline\n"
                "- ref_max_sample_num: Max samples from reference, default 5000\n"
                "- ref_shuffle_seed: Random seed for reference sampling, default 42\n"
                "- ref_instruction_key: Instruction field name in reference, default 'input'\n"
                "- ref_output_key: Output field name in reference, default 'output'\n"
                "- kernel_type: Kernel function type, currently only 'RBF' supported\n"
                "- rbf_sigma: RBF kernel bandwidth, default 1.0\n"
                "- embedding_type: Embedding backend, 'sentence_transformers' or 'vllm'\n"
                "- embedding_model_name: Embedding model name (required)\n"
                "- st_device/st_batch_size: SentenceTransformers device and batch size\n"
                "- vllm_*: vLLM configuration parameters\n"
                "- cache_type: Cache type, 'redis' or 'none'\n"
                "Output Parameters:\n"
                "- MMDScore: MMD distance value (smaller indicates closer distributions)\n"
                "- MMDMeta: Metadata dictionary with computation details"
            )
        else:
            return "Evaluate dataset distribution discrepancy using Maximum Mean Discrepancy (MMD)."

    def run(
        self,
        storage: DataFlowStorage,
        input_instruction_key: str,
        input_output_key: str,
        max_sample_num: int | None = None,
        shuffle_seed: int | None = None,
    ) -> tuple[float, dict[str, Any]]:
        """Compute the MMD between the dataset in ``storage`` and the reference data.

        Args:
            storage: Storage holding the evaluation dataset.
            input_instruction_key: Column holding the instruction text.
            input_output_key: Column holding the output text.
            max_sample_num: Sample-size cap for the evaluation dataset;
                ``None`` falls back to the reference cap given at construction.
            shuffle_seed: Sampling seed for the evaluation dataset; ``None``
                falls back to the reference seed given at construction.

        Returns:
            ``(value, meta)`` taken from the first element of the result of
            ``MMDDistance.compute``.

        Raises:
            KeyError: If an evaluation column is missing.
            TypeError: If an instruction/output value is not a ``str``.
        """
        if max_sample_num is None:
            max_sample_num = self.ref_max_sample_num
        if shuffle_seed is None:
            shuffle_seed = self.ref_shuffle_seed

        eval_data = self._sample_data_helper(
            data_frame=storage,
            max_sample_num=max_sample_num,
            shuffle_seed=shuffle_seed,
            instruction_key=input_instruction_key,
            output_key=input_output_key,
        )
        mmd_result = self.mmd_distance.compute(
            src=eval_data,
            tgt=self.ref_data,
        )
        # NOTE(review): assumes ``compute`` returns a non-empty sequence whose
        # first element carries ``value`` and ``meta`` -- confirm against distflow.
        first_result = mmd_result[0]
        mmd_value = first_result.value
        mmd_meta = first_result.meta
        self.logger.info(
            f"MMDDatasetEvaluator result: MMD={mmd_value}, meta={mmd_meta}"
        )
        return mmd_value, mmd_meta

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,5 @@ gcsfs
8686

8787
db-dtypes
8888
google-cloud-bigquery-storage
89+
90+
distflow

0 commit comments

Comments
 (0)