Skip to content

Commit fea73ae

Browse files
committed
feat: allow parallelization of CAI data generation
Ugly, but gets me down to a few minutes instead of almost an hour of processing.
1 parent f930d68 commit fea73ae

1 file changed

Lines changed: 14 additions & 1 deletion

File tree

toolbox/datasets/characterai.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import logging
3+
import math
34
import os
45
import typing as t
56
from dataclasses import dataclass
@@ -97,7 +98,19 @@ def _enumerate_json_files(root_path: str) -> list[str]:
9798
absolute_file_path = os.path.abspath(os.path.join(root_path, item))
9899
files.append(absolute_file_path)
99100

100-
return files
101+
# Super nasty code to allow generation of CAI data with separate processes
102+
# so I can speed it up. Pass the "SHARD" and "TOTAL_SHARDS" environment
103+
# variables to operate on the different parts of the data.
104+
if "SHARD" not in os.environ:
105+
return files
106+
107+
TOTAL_SHARDS = int(os.environ.get("TOTAL_SHARDS", 10))
108+
items_per_shard = math.floor(len(files) / TOTAL_SHARDS)
109+
110+
shard = int(os.environ["SHARD"])
111+
file_range = (items_per_shard * shard, (items_per_shard * (shard + 1)) - 1)
112+
113+
return files[file_range[0]:file_range[1]]
101114

102115

103116
def _available_json_data() -> t.Generator[dict[str, t.Any], None, None]:

0 commit comments

Comments
 (0)