`OnlineMixingDataset` can be imported easily and integrated into existing training loops with minimal changes. A sample custom training loop implementation can be found [here](./artifacts/custom_loop_usage.py). The sample uses two instruction-tuning datasets and trains the `ibm-granite/granite-3.1-2b-instruct` model on the next-token prediction task.
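To illustrate the core idea behind such a mixing dataset, here is a minimal, self-contained sketch. The class name, constructor arguments, and the EXP3-style weight update below are illustrative assumptions, not the actual `OnlineMixingDataset` API; it only shows how a stream can pick a source dataset every few samples using bandit-maintained weights.

```python
import math
import random

class ToyMixingStream:
    """Illustrative sketch (NOT the real OnlineMixingDataset API): a stream
    that re-picks the source category every `sampling_interval` samples,
    using weights maintained by an EXP3-style multi-armed bandit."""

    def __init__(self, datasets, sampling_interval=4, lr=0.1, seed=0):
        # Arms are sorted by name; indices start at 0.
        self.names = sorted(datasets)
        self.iters = {name: iter(datasets[name]) for name in self.names}
        self.weights = [0.0] * len(self.names)  # log-weights of the bandit
        self.sampling_interval = sampling_interval
        self.lr = lr
        self.rng = random.Random(seed)
        self.arm_idx = 0
        self.produced = 0

    def _ratios(self):
        # Softmax over log-weights gives the sampling ratios.
        z = [math.exp(w) for w in self.weights]
        s = sum(z)
        return [v / s for v in z]

    def update(self, reward):
        # "update" action: push the last-sampled arm toward higher reward.
        self.weights[self.arm_idx] += self.lr * reward

    def __iter__(self):
        return self

    def __next__(self):
        if self.produced % self.sampling_interval == 0:
            # "sample" action: weighted random choice of the next category.
            self.arm_idx = self.rng.choices(
                range(len(self.names)), weights=self._ratios())[0]
        self.produced += 1
        return next(self.iters[self.names[self.arm_idx]])
```

In a training loop this would be driven by alternating `next(stream)` to fetch batches and `stream.update(reward)` after computing a reward signal from the model's outputs.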
## Metrics
All metrics related to online data mixing are logged to the `odm.jsonl` file in the checkpoint output directory.
Metric | Description
--|--
`samples_produced_so_far` | Total number of samples produced by the dataset at the time of logging.
`sampling_interval` | Sampling interval "n": at every "n" samples, a category/dataset is chosen by weighted random sampling, where the weights are provided by the Multi-Armed Bandit (MAB) algorithm.
`total_categories` | Total number of categories or datasets involved in mixing.
`current_sampling_weights` | State of the sampling weights at the time of logging.
`current_sampling_ratio` | State of the sampling ratios at the time of logging.
`arm_idx` | Index of the last sampled category. Categories/datasets are sorted in ascending order by name; indices start at 0, and each index corresponds to its respective category/dataset.
`category_level_counts_so_far` | Per-dataset split of the sample count at the time of logging.
`rewards` | State of the rewards at the time of logging, i.e. the most recently provided reward for each dataset.
`action` | The action that took place at the time of logging: either "update" or "sample", corresponding to a weight update of the MAB algorithm or to category sampling, respectively.
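Since each line of `odm.jsonl` is a JSON record, the log is easy to inspect programmatically. The field names below come from the metrics table; the record layout is otherwise an assumption about what a run produces, so treat this as a sketch.

```python
import json

def load_sampling_ratio_history(path):
    """Return (samples_produced_so_far, current_sampling_ratio) pairs
    from "sample" actions, e.g. to plot how the mixture evolves.

    Field names follow the metrics table; the exact record layout
    is an assumption, not guaranteed by the library."""
    history = []
    with open(path) as f:
        for line in f:
            rec = json.loads(line)
            if rec.get("action") == "sample":
                history.append((rec["samples_produced_so_far"],
                                rec["current_sampling_ratio"]))
    return history
```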
## Rewards
Below are the currently available rewards. We are constantly looking to improve the existing rewards and add new ones, and we encourage users to identify rewards that help their use cases.
Rewards | Description
--|--
`ENTROPY` | Shannon entropy of the logits, averaged across all tokens. Higher entropy means the model requires more samples from that dataset/category.
`ENTROPY3_VARENT1` | Combination of 3 parts Shannon entropy and 1 part variance of the entropy. Higher values mean more samples are required.
`ENTROPY_LAST_TOKEN` | Shannon entropy of the last token in the sample. Higher values mean more samples are required.
`TRAIN_LOSS` | Training loss, maintained per category and updated with the latest loss of the sampled dataset/category. Higher values mean more samples are required.
`VALIDATION_LOSS` | Validation loss per category, computed using an evaluation dataset from each category. Higher values mean more samples are required.
`GRADNORM` | Gradient norm, maintained per category and updated with the latest values of the sampled dataset/category. Higher values mean fewer samples should come from that particular dataset/category.
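As a concrete reference for the entropy-based rewards, the sketch below computes Shannon entropy of a model's output distribution, averaged across tokens. The softmax/entropy math is standard; everything else (function names, plain-list logits instead of tensors) is illustrative, not the library's implementation.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def mean_token_entropy(logits_per_token):
    """Average Shannon entropy (in nats) over a sequence of per-token logits.
    Illustrative sketch of the ENTROPY-style reward: higher values suggest
    the model is more uncertain on this category's data."""
    total = 0.0
    for logits in logits_per_token:
        probs = softmax(logits)
        total += -sum(p * math.log(p) for p in probs if p > 0)
    return total / len(logits_per_token)
```

A uniform distribution over a vocabulary of size V yields the maximum entropy ln(V), while a sharply peaked distribution yields entropy near zero.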
### Adding a Custom Reward
Custom rewards can be added by extending the `compute_reward` function and adding a new member to the `Reward` enum. If the custom reward requires specific information from the training loop, then the `_extract_information_from_state_for_reward` function, a member function of `OnlineMixingDataset`, has to be extended to extract that information from the trainer state.
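The shape of such an extension might look like the following. Only the names `Reward`, `compute_reward`, and `_extract_information_from_state_for_reward` come from the text above; the enum members shown, the function signature, and the sequence-length reward itself are hypothetical, for illustration only.

```python
from enum import Enum

class Reward(Enum):
    # Hypothetical enum values -- the real enum lives in the library source.
    TRAIN_LOSS = "train_loss"
    SEQ_LENGTH = "seq_length"   # our custom reward

def compute_reward(reward_type, info):
    """Hypothetical dispatch: `info` stands in for whatever
    _extract_information_from_state_for_reward pulls from trainer state."""
    if reward_type is Reward.TRAIN_LOSS:
        # Higher loss => more samples from this category.
        return info["loss"]
    if reward_type is Reward.SEQ_LENGTH:
        # Custom reward: favor categories still producing long sequences.
        return info["seq_len"] / info["max_seq_len"]
    raise NotImplementedError(reward_type)
```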
### Planned TODOs
Please see issue [#153](https://github.com/foundation-model-stack/fms-acceleration/issues/153).