|
| 1 | +from typing import Dict, Optional |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +from pydantic import Field, PositiveInt |
| 5 | +from sklearn.cluster import KMeans |
| 6 | +from tqdm import tqdm |
| 7 | +from typing_extensions import Annotated |
| 8 | + |
| 9 | +from data_juicer.ops.base_op import OPERATORS, Selector |
| 10 | +from data_juicer.utils.lazy_loader import LazyLoader |
| 11 | +from data_juicer.utils.model_utils import get_model, prepare_model |
| 12 | + |
| 13 | +torch = LazyLoader('torch') |
| 14 | + |
| 15 | + |
| 16 | +@OPERATORS.register_module('domain_diversity_selector') |
| 17 | +class DomainDiversitySelector(Selector): |
| 18 | + """Selector to select samples based on the data's domain diversity. """ |
| 19 | + |
| 20 | + _accelerator = 'cuda' |
| 21 | + |
| 22 | + def __init__(self, |
| 23 | + api_or_hf_model: str = 'text-embedding-v3', |
| 24 | + is_hf_model: bool = False, |
| 25 | + api_endpoint: str = '/embeddings', |
| 26 | + response_path: str = 'data.0.embedding', |
| 27 | + model_params: Dict = {}, |
| 28 | + select_ratio: Optional[Annotated[float, |
| 29 | + Field(ge=0, le=1)]] = None, |
| 30 | + init_k: PositiveInt = 3, |
| 31 | + ebd_dim: PositiveInt = 512, |
| 32 | + strategy: str = 'inter', |
| 33 | + *args, |
| 34 | + **kwargs): |
| 35 | + """ |
| 36 | + Initialization method. |
| 37 | +
|
| 38 | + :param api_or_hf_model: API or huggingface embedding model name. |
| 39 | + :param is_hf_model: Indicates if the model is from HuggingFace. |
| 40 | + :param api_endpoint: Embedding URL endpoint for the API. |
| 41 | + :param response_path: Path to extract content from the API response. |
| 42 | + Defaults to 'data.0.embedding' for embedding model. |
| 43 | + :param model_params: Parameters for initializing the API model. |
| 44 | + :param select_ratio: The ratio to select. |
| 45 | + :param init_k: The value of k in k-means algorithm. |
| 46 | + :param ebd_dim: The embedding's dimension via API. |
| 47 | + :param strategy: 'inter' - Domain's inter diversity, |
| 48 | + 'intra' - Domain's intra diversity, |
| 49 | + 'global' - Diversity to global centroid. |
| 50 | + :param args: extra args |
| 51 | + :param kwargs: extra args |
| 52 | + """ |
| 53 | + super().__init__(*args, **kwargs) |
| 54 | + self.api_or_hf_model = api_or_hf_model |
| 55 | + self.is_hf_model = is_hf_model |
| 56 | + self.api_endpoint = api_endpoint |
| 57 | + self.response_path = response_path |
| 58 | + self.select_ratio = select_ratio |
| 59 | + self.init_k = init_k |
| 60 | + self.ebd_dim = ebd_dim |
| 61 | + self.strategy = strategy |
| 62 | + |
| 63 | + if is_hf_model: |
| 64 | + self.model_key = prepare_model(model_type='embedding', |
| 65 | + model_path=api_or_hf_model, |
| 66 | + trust_remote_code=True, |
| 67 | + **model_params) |
| 68 | + else: |
| 69 | + self.model_key = prepare_model(model_type='api', |
| 70 | + model=api_or_hf_model, |
| 71 | + endpoint=self.api_endpoint, |
| 72 | + response_path=self.response_path, |
| 73 | + **model_params) |
| 74 | + |
| 75 | + def dataset_embedding(self, dataset, rank=None): |
| 76 | + embeddings = [] |
| 77 | + model = get_model(self.model_key, rank, self.use_cuda()) |
| 78 | + |
| 79 | + if self.is_hf_model: |
| 80 | + # Embeddings extract via local models |
| 81 | + for sample in tqdm(dataset, desc='Embedding', unit='sample'): |
| 82 | + text = sample['text'] |
| 83 | + with torch.no_grad(): |
| 84 | + embedding = model.encode(text) |
| 85 | + embeddings.append(embedding) |
| 86 | + else: |
| 87 | + # Embeddings extract via API |
| 88 | + for sample in tqdm(dataset, desc='Embedding', unit='sample'): |
| 89 | + text = sample['text'] |
| 90 | + embedding = model(text, |
| 91 | + dimensions=self.ebd_dim, |
| 92 | + encoding_format='float') |
| 93 | + embeddings.append(embedding) |
| 94 | + |
| 95 | + embeddings = np.array(embeddings) |
| 96 | + return embeddings |
| 97 | + |
| 98 | + def domain_diversity_status(self, dataset): |
| 99 | + |
| 100 | + data_status = [] |
| 101 | + |
| 102 | + embeddings_array = self.dataset_embedding(dataset) |
| 103 | + global_centroid = np.mean(embeddings_array, axis=0) |
| 104 | + |
| 105 | + # K-means cluster |
| 106 | + kmeans = KMeans(n_clusters=self.init_k, random_state=42) |
| 107 | + labels = kmeans.fit_predict(embeddings_array) |
| 108 | + |
| 109 | + centroid_embeddings = [] |
| 110 | + for label in np.unique(labels): |
| 111 | + label_embeddings = embeddings_array[labels == label] |
| 112 | + centroid = np.mean(label_embeddings, axis=0) |
| 113 | + centroid_embeddings.append(centroid) |
| 114 | + |
| 115 | + centroid_embeddings = np.array(centroid_embeddings) |
| 116 | + |
| 117 | + # Sample-level cos-similarity to other centroids |
| 118 | + for i, entry in tqdm(enumerate(dataset), |
| 119 | + total=len(dataset), |
| 120 | + desc='Calculating similarity:'): |
| 121 | + current_embedding = embeddings_array[i] |
| 122 | + current_label = int(labels[i]) |
| 123 | + |
| 124 | + similarities = [] |
| 125 | + for j, centroid in enumerate(centroid_embeddings): |
| 126 | + if j != current_label: |
| 127 | + similarity = torch.nn.functional.cosine_similarity( |
| 128 | + torch.tensor(current_embedding).unsqueeze(0), |
| 129 | + torch.tensor(centroid).unsqueeze(0)).item() |
| 130 | + similarities.append(similarity) |
| 131 | + |
| 132 | + own_centroid_similarity = torch.nn.functional.cosine_similarity( |
| 133 | + torch.tensor(current_embedding).unsqueeze(0), |
| 134 | + torch.tensor( |
| 135 | + centroid_embeddings[current_label]).unsqueeze(0)).item() |
| 136 | + |
| 137 | + global_centroid_similarity = torch.nn.functional.cosine_similarity( |
| 138 | + torch.tensor(current_embedding).unsqueeze(0), |
| 139 | + torch.tensor(global_centroid).unsqueeze(0)).item() |
| 140 | + total_similarity = sum(similarities) |
| 141 | + |
| 142 | + data_status.append({ |
| 143 | + 'text': entry['text'], |
| 144 | + 'label': current_label, |
| 145 | + 'similarity_with_other_centroids': similarities, |
| 146 | + 'total_similarity': total_similarity, |
| 147 | + 'similarity_with_own_centroid': own_centroid_similarity, |
| 148 | + 'global_centroid_similarity': global_centroid_similarity, |
| 149 | + 'original_index': i |
| 150 | + }) |
| 151 | + |
| 152 | + return data_status, labels |
| 153 | + |
| 154 | + def diversity_process(self, dataset): |
| 155 | + data_status, labels = self.domain_diversity_status(dataset) |
| 156 | + select_indices = [] |
| 157 | + |
| 158 | + for label in np.unique(labels): |
| 159 | + label_data_status = [ |
| 160 | + item for item in data_status if item['label'] == label |
| 161 | + ] |
| 162 | + |
| 163 | + # Related to the strategy |
| 164 | + if self.strategy == 'inter': |
| 165 | + label_data_status.sort(key=lambda x: x['total_similarity']) |
| 166 | + elif self.strategy == 'intra': |
| 167 | + label_data_status.sort( |
| 168 | + key=lambda x: x['similarity_with_own_centroid'], |
| 169 | + reverse=True) |
| 170 | + elif self.strategy == 'global': |
| 171 | + label_data_status.sort( |
| 172 | + key=lambda x: x['global_centroid_similarity']) |
| 173 | + else: |
| 174 | + raise ValueError( |
| 175 | + "Invalid strategy. Use 'inter', 'intra' or 'global'.") |
| 176 | + |
| 177 | + num_to_select = max( |
| 178 | + 1, int(self.select_ratio * len(label_data_status))) |
| 179 | + selected_indices = [ |
| 180 | + item['original_index'] |
| 181 | + for item in label_data_status[:num_to_select] |
| 182 | + ] |
| 183 | + select_indices.extend(selected_indices) |
| 184 | + |
| 185 | + select_dataset = dataset.select(select_indices) |
| 186 | + |
| 187 | + return select_dataset |
| 188 | + |
| 189 | + def process(self, dataset): |
| 190 | + |
| 191 | + if len(dataset) <= 1: |
| 192 | + return dataset |
| 193 | + if self.select_ratio is None: |
| 194 | + return dataset |
| 195 | + |
| 196 | + select_dataset = self.diversity_process(dataset) |
| 197 | + return select_dataset |
0 commit comments