Skip to content

Commit b300a61

Browse files
committed
Add domain-diversity selector op
1 parent 8ab3b22 commit b300a61

5 files changed

Lines changed: 362 additions & 8 deletions

File tree

configs/config_all.yaml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -923,7 +923,7 @@ process:
923923
redis_address: 'redis://localhost:6379' # the address of redis server
924924
lowercase: false # whether to convert text to lower case
925925
ignore_non_character: false # whether to ignore non-alphabet characters, including whitespaces, digits, and punctuations
926-
- ray_bts_minhash_deduplicator: # the document deduplicator that can run on multi-nodes using minhashLSH algorithm
926+
- ray_bts_minhash_deduplicator: # the document deduplicator that can run on multi-nodes using minhashLSH algorithm
927927
tokenization: space # tokenization method for text. One of [space, punctuation, character, sentencepiece]
928928
window_size: 5 # window size of shingling
929929
num_permutations: 256 # number of permutations in minhash computing
@@ -943,6 +943,16 @@ process:
943943
tmp_file_name: './outputs/ray-dedup-tmp/' # the temporary folder name for deduplication.
944944

945945
# Selector ops
946+
- domain_diversity_selector: # selector to select samples based on the data's domain diversity
947+
api_or_hf_model: 'text-embedding-v3' # API or huggingface embedding model name
948+
is_hf_model: False # indicates if the model is from HuggingFace
949+
api_endpoint: '/embeddings' # embedding URL endpoint for the API
950+
response_path: 'data.0.embedding' # path to extract content from the API response
951+
model_params: {} # parameters for initializing the API model
952+
select_ratio: # the ratio to be sampled
953+
init_k: 3 # the value of k in k-means algorithm
954+
ebd_dim: 512 # the embedding's dimension via API
955+
strategy: 'inter' # the selection strategy based on the relation across domains
946956
- frequency_specified_field_selector: # selector to select samples based on the sorted frequency of specified field value
947957
field_key: '' # the target keys corresponding to multi-level field information need to be separated by '.'
948958
top_ratio: # ratio of selected top specified field value
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .domain_diversity_selector import DomainDiversitySelector
12
from .frequency_specified_field_selector import FrequencySpecifiedFieldSelector
23
from .random_selector import RandomSelector
34
from .range_specified_field_selector import RangeSpecifiedFieldSelector
@@ -6,6 +7,6 @@
67

78
__all__ = [
89
'FrequencySpecifiedFieldSelector', 'RandomSelector',
9-
'RangeSpecifiedFieldSelector', 'TagsSpecifiedFieldSelector',
10-
'TopkSpecifiedFieldSelector'
10+
'DomainDiversitySelector', 'RangeSpecifiedFieldSelector',
11+
'TagsSpecifiedFieldSelector', 'TopkSpecifiedFieldSelector'
1112
]
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
from typing import Dict, Optional
2+
3+
import numpy as np
4+
from pydantic import Field, PositiveInt
5+
from sklearn.cluster import KMeans
6+
from tqdm import tqdm
7+
from typing_extensions import Annotated
8+
9+
from data_juicer.ops.base_op import OPERATORS, Selector
10+
from data_juicer.utils.lazy_loader import LazyLoader
11+
from data_juicer.utils.model_utils import get_model, prepare_model
12+
13+
torch = LazyLoader('torch')
14+
15+
16+
@OPERATORS.register_module('domain_diversity_selector')
17+
class DomainDiversitySelector(Selector):
18+
"""Selector to select samples based on the data's domain diversity. """
19+
20+
_accelerator = 'cuda'
21+
22+
def __init__(self,
23+
api_or_hf_model: str = 'text-embedding-v3',
24+
is_hf_model: bool = False,
25+
api_endpoint: str = '/embeddings',
26+
response_path: str = 'data.0.embedding',
27+
model_params: Dict = {},
28+
select_ratio: Optional[Annotated[float,
29+
Field(ge=0, le=1)]] = None,
30+
init_k: PositiveInt = 3,
31+
ebd_dim: PositiveInt = 512,
32+
strategy: str = 'inter',
33+
*args,
34+
**kwargs):
35+
"""
36+
Initialization method.
37+
38+
:param api_or_hf_model: API or huggingface embedding model name.
39+
:param is_hf_model: Indicates if the model is from HuggingFace.
40+
:param api_endpoint: Embedding URL endpoint for the API.
41+
:param response_path: Path to extract content from the API response.
42+
Defaults to 'data.0.embedding' for embedding model.
43+
:param model_params: Parameters for initializing the API model.
44+
:param select_ratio: The ratio to select.
45+
:param init_k: The value of k in k-means algorithm.
46+
:param ebd_dim: The embedding's dimension via API.
47+
:param strategy: 'inter' - Domain's inter diversity,
48+
'intra' - Domain's intra diversity,
49+
'global' - Diversity to global centroid.
50+
:param args: extra args
51+
:param kwargs: extra args
52+
"""
53+
super().__init__(*args, **kwargs)
54+
self.api_or_hf_model = api_or_hf_model
55+
self.is_hf_model = is_hf_model
56+
self.api_endpoint = api_endpoint
57+
self.response_path = response_path
58+
self.select_ratio = select_ratio
59+
self.init_k = init_k
60+
self.ebd_dim = ebd_dim
61+
self.strategy = strategy
62+
63+
if is_hf_model:
64+
self.model_key = prepare_model(model_type='embedding',
65+
model_path=api_or_hf_model,
66+
trust_remote_code=True,
67+
**model_params)
68+
else:
69+
self.model_key = prepare_model(model_type='api',
70+
model=api_or_hf_model,
71+
endpoint=self.api_endpoint,
72+
response_path=self.response_path,
73+
**model_params)
74+
75+
def dataset_embedding(self, dataset, rank=None):
76+
embeddings = []
77+
model = get_model(self.model_key, rank, self.use_cuda())
78+
79+
if self.is_hf_model:
80+
# Embeddings extract via local models
81+
for sample in tqdm(dataset, desc='Embedding', unit='sample'):
82+
text = sample['text']
83+
with torch.no_grad():
84+
embedding = model.encode(text)
85+
embeddings.append(embedding)
86+
else:
87+
# Embeddings extract via API
88+
for sample in tqdm(dataset, desc='Embedding', unit='sample'):
89+
text = sample['text']
90+
embedding = model(text,
91+
dimensions=self.ebd_dim,
92+
encoding_format='float')
93+
embeddings.append(embedding)
94+
95+
embeddings = np.array(embeddings)
96+
return embeddings
97+
98+
def domain_diversity_status(self, dataset):
99+
100+
data_status = []
101+
102+
embeddings_array = self.dataset_embedding(dataset)
103+
global_centroid = np.mean(embeddings_array, axis=0)
104+
105+
# K-means cluster
106+
kmeans = KMeans(n_clusters=self.init_k, random_state=42)
107+
labels = kmeans.fit_predict(embeddings_array)
108+
109+
centroid_embeddings = []
110+
for label in np.unique(labels):
111+
label_embeddings = embeddings_array[labels == label]
112+
centroid = np.mean(label_embeddings, axis=0)
113+
centroid_embeddings.append(centroid)
114+
115+
centroid_embeddings = np.array(centroid_embeddings)
116+
117+
# Sample-level cos-similarity to other centroids
118+
for i, entry in tqdm(enumerate(dataset),
119+
total=len(dataset),
120+
desc='Calculating similarity:'):
121+
current_embedding = embeddings_array[i]
122+
current_label = int(labels[i])
123+
124+
similarities = []
125+
for j, centroid in enumerate(centroid_embeddings):
126+
if j != current_label:
127+
similarity = torch.nn.functional.cosine_similarity(
128+
torch.tensor(current_embedding).unsqueeze(0),
129+
torch.tensor(centroid).unsqueeze(0)).item()
130+
similarities.append(similarity)
131+
132+
own_centroid_similarity = torch.nn.functional.cosine_similarity(
133+
torch.tensor(current_embedding).unsqueeze(0),
134+
torch.tensor(
135+
centroid_embeddings[current_label]).unsqueeze(0)).item()
136+
137+
global_centroid_similarity = torch.nn.functional.cosine_similarity(
138+
torch.tensor(current_embedding).unsqueeze(0),
139+
torch.tensor(global_centroid).unsqueeze(0)).item()
140+
total_similarity = sum(similarities)
141+
142+
data_status.append({
143+
'text': entry['text'],
144+
'label': current_label,
145+
'similarity_with_other_centroids': similarities,
146+
'total_similarity': total_similarity,
147+
'similarity_with_own_centroid': own_centroid_similarity,
148+
'global_centroid_similarity': global_centroid_similarity,
149+
'original_index': i
150+
})
151+
152+
return data_status, labels
153+
154+
def diversity_process(self, dataset):
155+
data_status, labels = self.domain_diversity_status(dataset)
156+
select_indices = []
157+
158+
for label in np.unique(labels):
159+
label_data_status = [
160+
item for item in data_status if item['label'] == label
161+
]
162+
163+
# Related to the strategy
164+
if self.strategy == 'inter':
165+
label_data_status.sort(key=lambda x: x['total_similarity'])
166+
elif self.strategy == 'intra':
167+
label_data_status.sort(
168+
key=lambda x: x['similarity_with_own_centroid'],
169+
reverse=True)
170+
elif self.strategy == 'global':
171+
label_data_status.sort(
172+
key=lambda x: x['global_centroid_similarity'])
173+
else:
174+
raise ValueError(
175+
"Invalid strategy. Use 'inter', 'intra' or 'global'.")
176+
177+
num_to_select = max(
178+
1, int(self.select_ratio * len(label_data_status)))
179+
selected_indices = [
180+
item['original_index']
181+
for item in label_data_status[:num_to_select]
182+
]
183+
select_indices.extend(selected_indices)
184+
185+
select_dataset = dataset.select(select_indices)
186+
187+
return select_dataset
188+
189+
def process(self, dataset):
190+
191+
if len(dataset) <= 1:
192+
return dataset
193+
if self.select_ratio is None:
194+
return dataset
195+
196+
select_dataset = self.diversity_process(dataset)
197+
return select_dataset

0 commit comments

Comments
 (0)