|
4 | 4 | from pydantic import BaseModel |
5 | 5 |
|
6 | 6 | from modalities.batch import DatasetBatch |
7 | | - |
8 | | - |
9 | | -class RandomDatasetBatchGeneratorConfig(BaseModel): |
10 | | - vocab_size: int |
11 | | - sequence_length: int |
12 | | - batch_size: int |
| 7 | +from modalities.config.lookup_enum import LookupEnum |
13 | 8 |
|
14 | 9 |
|
class DatasetBatchGeneratorIF(ABC):
    """Interface for dataset batch generators.

    Implementations produce a ``DatasetBatch`` on demand, e.g. random data
    for benchmarking without a real dataset.
    """

    # NOTE(review): consider decorating with @abstractmethod so incomplete
    # subclasses fail at instantiation time rather than at call time.
    def get_dataset_batch(self) -> DatasetBatch:
        """Return a single ``DatasetBatch``.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError
18 | 13 |
|
19 | 14 |
|
class DataTypeEnum(LookupEnum):
    """Torch dtypes supported for randomly generated batches.

    Each member's value is the actual ``torch.dtype`` object, so configs can
    reference a dtype by name and code can use ``member.value`` directly.
    """

    float32 = torch.float32
    bfloat16 = torch.bfloat16
    int64 = torch.int64
| 19 | + |
| 20 | + |
class RandomDatasetBatchGeneratorConfig(BaseModel):
    """Pydantic config for ``RandomDatasetBatchGenerator``."""

    # Mapping of dimension name to size; the order of the values defines the
    # generated tensor shape, e.g. {"batch": 4, "seq": 128} -> (4, 128).
    dims: dict[str, int]
    # dtype of the generated tensors (float32, bfloat16, or int64).
    data_type: DataTypeEnum
    # Lower bound for generated values (inclusive).
    min_val: int
    # Upper bound for generated values — presumably exclusive, matching
    # torch.randint / the rand scaling in the generator; confirm with usage.
    max_val: int
| 26 | + |
| 27 | + |
class RandomDatasetBatchGenerator(DatasetBatchGeneratorIF):
    """Generates ``DatasetBatch`` objects filled with random data.

    Useful for benchmarking or testing pipelines without a real dataset.
    Tensors are created directly on the selected device (CUDA when
    available, otherwise CPU) to avoid host-to-device copies.
    """

    def __init__(self, dims: dict[str, int], data_type: DataTypeEnum, min_val: int, max_val: int):
        """Initialize the generator.

        Args:
            dims: Mapping of dimension name to size; the order of the values
                defines the tensor shape, e.g. {"batch": 4, "seq": 128} -> (4, 128).
            data_type: dtype of the generated tensors (float32, bfloat16, or int64).
            min_val: Lower bound (inclusive) of generated values.
            max_val: Upper bound of generated values (exclusive — matches
                ``torch.randint``'s ``high`` and the ``torch.rand`` scaling below).
        """
        self._dims = dims
        self._data_type = data_type
        self._min_val = min_val
        self._max_val = max_val
        self._device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    def _generate_random_tensor(self, size: tuple[int, ...]) -> torch.Tensor:
        """Create one random tensor of shape ``size`` with values in [min_val, max_val).

        Raises:
            ValueError: if the configured data type is not supported.
        """
        if self._data_type == DataTypeEnum.int64:
            return torch.randint(low=self._min_val, high=self._max_val, size=size, device=self._device)
        if self._data_type in {DataTypeEnum.float32, DataTypeEnum.bfloat16}:
            # torch.rand samples U[0, 1); scale and shift into [min_val, max_val).
            return (
                torch.rand(size=size, device=self._device, dtype=self._data_type.value)
                * (self._max_val - self._min_val)
                + self._min_val
            )
        raise ValueError(f"Unsupported data type: {self._data_type}")

    def get_dataset_batch(self) -> DatasetBatch:
        """Return a ``DatasetBatch`` with independently sampled inputs and targets."""
        size = tuple(self._dims.values())
        # Inputs and targets are drawn independently from the same distribution;
        # previously this sampling logic was duplicated inline for each tensor.
        return DatasetBatch(
            samples={"input_ids": self._generate_random_tensor(size)},
            targets={"target_ids": self._generate_random_tensor(size)},
        )
0 commit comments