Skip to content

Commit b3bf4d2

Browse files
romitjainkmehantseshapad
committed
v1 with cuml
Signed-off-by: romit <romit@ibm.com> Co-authored-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com> Co-authored-by: Padmanabha Venkatagiri Seshadri <seshapad@in.ibm.com>
1 parent cb382a6 commit b3bf4d2

5 files changed

Lines changed: 75 additions & 59 deletions

File tree

plugins/online-data-mixing/artifacts/custom_loop_usage.py

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,21 @@
88

99
# Third Party
1010
from accelerate import Accelerator, DataLoaderConfiguration
11-
from datasets import load_dataset
11+
from datasets import load_dataset, DatasetDict
1212
from torch.utils.data import DataLoader
1313
from tqdm import tqdm
1414
from transformers import (
1515
AutoModelForCausalLM,
16-
AutoTokenizer,
17-
DataCollatorForLanguageModeling,
16+
AutoTokenizer
1817
)
1918
import torch
19+
from functools import partial
2020

2121
# First Party
2222
from fms_acceleration_odm import OnlineMixingDataset
23+
from fms_acceleration_odm.odm.reward import Reward
2324

24-
model_name = "ibm-granite/granite-3.1-2b-instruct"
25+
model_name = "ibm-granite/granite-4.0-350m"
2526
output_dir = "./odm_custom_use"
2627
max_steps = 125
2728
batch_size = 12
@@ -40,15 +41,19 @@
4041

4142

4243
# dataset related
43-
def tokenize_fn(examples):
44-
return tokenizer(
45-
examples["text"], truncation=True, padding="max_length", max_length=128
46-
)
47-
44+
# If you have a single dataset, you can declare it with a single key, pair.
45+
# ODM will auto categorize the dataset into psuedo categories
46+
# If you have multiple categories of dataset, you can declare it with multiple key, pair, eg:
47+
# dataset_dict = {
48+
# "alpaca": load_dataset("tatsu-lab/alpaca", split="train[:1%]"),
49+
# "oasst": load_dataset("hakurei/open-instruct-v1", split="train[:1%]"),
50+
# }
4851

4952
dataset_dict = {
50-
"alpaca": load_dataset("tatsu-lab/alpaca", split="train[:1%]"),
51-
"oasst": load_dataset("hakurei/open-instruct-v1", split="train[:1%]"),
53+
"alpaca_train": load_dataset("tatsu-lab/alpaca", split="train[90%:]")
54+
}
55+
eval_dict = {
56+
"alpaca_val": load_dataset("tatsu-lab/alpaca", split="train[:1%]")
5257
}
5358

5459

@@ -63,43 +68,49 @@ def format_example(example):
6368
for name in dataset_dict:
6469
dataset_dict[name] = dataset_dict[name].map(format_example)
6570

71+
for name in eval_dict:
72+
eval_dict[name] = eval_dict[name].map(format_example)
73+
74+
dataset_dict = DatasetDict(dataset_dict) #type: ignore
75+
eval_dict = DatasetDict(eval_dict) #type: ignore
76+
77+
def collate_fn(batch, tokenizer):
78+
msgs = [b.pop("text") for b in batch]
6679

67-
def tokenize_fn(examples):
6880
return tokenizer(
69-
examples["text"],
81+
msgs,
7082
truncation=True,
7183
padding="max_length",
7284
max_length=1024,
73-
)
74-
75-
76-
for name in dataset_dict:
77-
dataset_dict[name] = dataset_dict[name].map(
78-
tokenize_fn,
79-
batched=True,
80-
remove_columns=dataset_dict[name].column_names,
85+
return_tensors="pt"
8186
)
8287

8388
collator_dict = {
84-
name: DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
89+
name: partial(collate_fn, tokenizer=tokenizer)
8590
for name in dataset_dict
8691
}
8792

93+
eval_collator_dict = {
94+
name: partial(collate_fn, tokenizer=tokenizer)
95+
for name in eval_dict
96+
}
97+
8898
# dataset preparation
8999
dataset = OnlineMixingDataset(
90100
dataset_dict=dataset_dict,
91101
collators_dict=collator_dict,
92-
eval_dataset_dict={},
93-
eval_collators_dict={},
102+
eval_dataset_dict=eval_dict,
103+
eval_collators_dict=eval_collator_dict,
94104
output_dir=output_dir,
95-
reward_type="train_loss",
105+
reward_type=Reward.TRAIN_LOSS,
96106
sampling_interval=batch_size,
107+
auto_categorize_config={"text_field": "text"}
97108
)
98109
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=None)
99110

100111
# distributed setup
101112
dataloader_config = DataLoaderConfiguration(split_batches=True, dispatch_batches=True)
102-
accelerator = Accelerator(split_batches=True, dataloader_config=dataloader_config)
113+
accelerator = Accelerator(dataloader_config=dataloader_config)
103114
model, dataloader = accelerator.prepare(model, dataloader)
104115

105116
# training setup

plugins/online-data-mixing/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ dependencies = [
2626
"datasets",
2727
"torchdata",
2828
"sentence-transformers",
29-
"scikit-learn",
29+
"cuml-cu12==25.10.*",
3030
]
3131

3232
[tool.hatch.build.targets.wheel]

plugins/online-data-mixing/src/fms_acceleration_odm/odm/auto_categorizer.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323
import math
2424

2525
# Third Party
26-
from datasets import Dataset, DatasetDict
2726
import numpy as np
27+
from datasets import Dataset, DatasetDict
2828
from sentence_transformers import SentenceTransformer
29-
from sklearn.cluster import KMeans
29+
from cuml import KMeans
3030

3131
logger = getLogger(__name__)
3232

@@ -39,12 +39,18 @@ class AutoCategorizeConfig:
3939
num_categories: Optional[int] = None
4040
min_categories: int = 2
4141
max_categories: int = 15
42-
model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
42+
model_name: str = "Qwen/Qwen3-Embedding-0.6B"
4343
batch_size: int = 64
4444
cluster_algo: str = "kmeans"
4545
random_state: int = 0
4646
category_prefix: str = "auto_category"
47-
device: Optional[str] = None
47+
# Args for loading model
48+
model_kwargs: Dict[str, any] = field(
49+
default_factory=lambda: {
50+
"device_map": "auto",
51+
# "attn_implementation": "flash_attention_2",
52+
}
53+
)
4854
cluster_kwargs: Dict[str, Any] = field(default_factory=dict)
4955

5056

@@ -92,21 +98,21 @@ def _determine_category_count(self, dataset_size: int) -> int:
9298
def _compute_embeddings(self, dataset: Dataset) -> np.ndarray:
9399
model = SentenceTransformer(
94100
self.config.model_name,
95-
device=self.config.device,
101+
model_kwargs=self.config.model_kwargs,
102+
prompts={
103+
"clustering": "Identify the topic or theme based on the text: ",
104+
},
105+
default_prompt_name="clustering",
96106
)
97-
vectors: List[np.ndarray] = []
98-
batched_dataset = dataset.batch(self.config.batch_size, num_proc=8)
99-
for batch in batched_dataset:
100-
texts = batch[self.config.text_field] # type: ignore
101-
vec = model.encode(
102-
texts,
103-
convert_to_numpy=True,
104-
show_progress_bar=False,
105-
batch_size=min(len(texts), self.config.batch_size),
106-
normalize_embeddings=True,
107-
)
108-
vectors.append(vec)
109-
return np.vstack(vectors)
107+
108+
vectors = model.encode(
109+
dataset[self.config.text_field],
110+
convert_to_numpy=True,
111+
show_progress_bar=True,
112+
batch_size=self.config.batch_size,
113+
normalize_embeddings=True
114+
)
115+
return vectors
110116

111117
def _cluster_embeddings(self, embeddings: np.ndarray, num_categories: int) -> np.ndarray:
112118
if self.config.cluster_algo.lower() != "kmeans":
@@ -117,6 +123,9 @@ def _cluster_embeddings(self, embeddings: np.ndarray, num_categories: int) -> np
117123
kwargs = {"n_init": 10, "random_state": self.config.random_state}
118124
kwargs.update(self.config.cluster_kwargs)
119125
model = KMeans(n_clusters=num_categories, **kwargs)
126+
127+
logger.info(f"Starting {self.config.cluster_algo} clustering")
128+
120129
return model.fit_predict(embeddings)
121130

122131
def _build_dataset_dict(self, dataset: Dataset, labels: np.ndarray) -> DatasetDict:
@@ -129,10 +138,3 @@ def _build_dataset_dict(self, dataset: Dataset, labels: np.ndarray) -> DatasetDi
129138
categorized[name] = dataset.select(indices)
130139
return DatasetDict(categorized)
131140

132-
133-
def auto_categorize_dataset(
134-
dataset: Dataset,
135-
config: Optional[AutoCategorizeConfig] = None,
136-
) -> DatasetDict:
137-
"""Convenience wrapper to auto-categorize a dataset."""
138-
return DatasetAutoCategorizer(config)(dataset)

plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import torch
1515

1616
# Local
17-
from .auto_categorizer import AutoCategorizeConfig, auto_categorize_dataset
17+
from .auto_categorizer import AutoCategorizeConfig, DatasetAutoCategorizer
1818
from .reward import Reward, compute_reward
1919

2020
logger = getLogger(__name__)
@@ -339,10 +339,14 @@ def _maybe_auto_categorize_dataset(
339339
logger.info("Starting auto categorization process")
340340

341341
dataset_candidate: Dataset = next(iter(dataset_container.values()))
342-
categorized = auto_categorize_dataset(
343-
dataset=dataset_candidate,
344-
config=self._auto_categorize_config
345-
)
342+
auto_categorizer = DatasetAutoCategorizer(config=self._auto_categorize_config)
343+
categorized = auto_categorizer(dataset=dataset_candidate)
344+
345+
# We can delete the auto categorizer object since
346+
# it loads a sentence embedding model
347+
del(auto_categorizer)
348+
torch.cuda.empty_cache()
349+
346350
collators_dict = self._broadcast_collators_to_auto_categories(
347351
collators_dict, list(categorized.keys()) # type: ignore
348352
)

plugins/online-data-mixing/tests/test_auto_categorization.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ def test_auto_categorize_single_dataset(monkeypatch):
5353
assert set(odm_dataset.category_list) == {"train_cluster_0", "train_cluster_1"}
5454
# Ensure collators were broadcast to the generated categories
5555
assert set(odm_dataset.collators_dict.keys()) == set(odm_dataset.dataset_dict.keys())
56-
for name in odm_dataset.collators_dict:
57-
assert odm_dataset.collators_dict[name] is collator
56+
5857
# Combined rows should match original dataset size
5958
total_rows = sum(len(ds) for ds in odm_dataset.dataset_dict.values())
6059
assert total_rows == len(dataset)

0 commit comments

Comments
 (0)