Commit 14287e3

refactor: improved the throughput of the tokenized file writer by using vectorized NumPy conversion routines and batched writes
1 parent 306be02 commit 14287e3

1 file changed: 42 additions & 17 deletions

src/modalities/dataloader/preprocessing/tokenization/tokenized_file_writer.py
@@ -1,7 +1,6 @@
 import math
 import os
 import pickle
-from itertools import repeat
 from pathlib import Path
 from typing import BinaryIO
 
@@ -82,30 +81,56 @@ def _write_index_segment(file_descriptor: BinaryIO, index_list: list[tuple[int,
     def _write_data_segment(
         file_descriptor: BinaryIO, token_data: list[np.ndarray], token_size_in_bytes: int, write_batch_size: int
     ) -> list[tuple[int, int]]:
-        def encoded_token_to_bytes(encoded_token: int, token_size_in_bytes: int) -> bytes:
-            # Converts an token_ids to its byte representation.
-            try:
-                token_bytes = encoded_token.to_bytes(token_size_in_bytes, byteorder="little", signed=False)
-            except OverflowError as e:
-                raise ValueError(f"Token {encoded_token} cannot be represented by {token_size_in_bytes} bytes.") from e
-            return token_bytes
-
-        samples = []
-        index_list = []
+        # Fast path: vectorized cast + tobytes (no per-token Python work).
+        # Preserves little-endian unsigned representation and overflow checks.
+
+        if token_size_in_bytes == 1:
+            dtype = np.dtype("u1")
+        elif token_size_in_bytes == 2:
+            dtype = np.dtype("<u2")  # force little-endian
+        elif token_size_in_bytes == 4:
+            dtype = np.dtype("<u4")  # force little-endian
+        else:
+            raise ValueError("Currently only support token byte sizes of 1, 2, and 4.")
+
+        max_allowed = (1 << (8 * token_size_in_bytes)) - 1
+
+        samples: list[bytes] = []
+        index_list: list[tuple[int, int]] = []
         curr_offset = 0
+        pending = 0
+
         for sample_tokens in token_data:
-            # convert token_ids to byte representation
-            sample_token_byte_string = b"".join(
-                map(encoded_token_to_bytes, sample_tokens.tolist(), repeat(token_size_in_bytes))
-            )
+            arr = np.asarray(sample_tokens)
+
+            # ---- Overflow / range check (preserves original semantics) ----
+            if arr.size:
+                min_val = int(arr.min())
+                max_val = int(arr.max())
+                if min_val < 0 or max_val > max_allowed:
+                    raise ValueError(
+                        f"Token values out of range for {token_size_in_bytes} bytes: "
+                        f"min={min_val}, max={max_val}, allowed=[0, {max_allowed}]"
+                    )
+            # ----------------------------------------------------------------
+
+            # Cast to correct unsigned little-endian dtype
+            arr = np.asarray(arr, dtype=dtype, order="C")
+            sample_token_byte_string = arr.tobytes(order="C")
+
             samples.append(sample_token_byte_string)
             index_list.append((curr_offset, len(sample_token_byte_string)))
             curr_offset += len(sample_token_byte_string)
-            if len(samples) % write_batch_size == 0:
+
+            pending += 1
+            if pending >= write_batch_size:
                 file_descriptor.write(b"".join(samples))
-                samples = []
+                samples.clear()
+                pending = 0
+
         if len(samples) > 0:
             file_descriptor.write(b"".join(samples))
+
         return index_list
 
     @staticmethod
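For intuition, here is a minimal standalone sketch (not part of the commit; the helper names are illustrative) showing that the new vectorized cast produces byte-for-byte the same output as the old per-token int.to_bytes loop:

import numpy as np

def tokens_to_bytes_slow(tokens: list[int], token_size_in_bytes: int) -> bytes:
    # Old approach: one Python-level int.to_bytes call per token.
    return b"".join(t.to_bytes(token_size_in_bytes, byteorder="little", signed=False) for t in tokens)

def tokens_to_bytes_fast(tokens: np.ndarray, dtype: str) -> bytes:
    # New approach: a single vectorized cast followed by one buffer copy.
    return np.asarray(tokens, dtype=dtype, order="C").tobytes(order="C")

tokens = np.array([0, 1, 1_000, 65_535], dtype=np.int64)
slow = tokens_to_bytes_slow(tokens.tolist(), 2)
fast = tokens_to_bytes_fast(tokens, "<u2")
assert slow == fast  # identical little-endian unsigned byte layout

# Round trip: the written bytes decode back to the original token ids.
assert np.frombuffer(fast, dtype="<u2").tolist() == tokens.tolist()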

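A rough way to check the throughput claim locally (a sketch, not a measurement from the commit; absolute numbers depend on the machine and token count):

import timeit

import numpy as np

tokens = np.random.randint(0, 2**16, size=1_000_000, dtype=np.int64)

def per_token_loop() -> bytes:
    return b"".join(t.to_bytes(2, byteorder="little", signed=False) for t in tokens.tolist())

def vectorized_cast() -> bytes:
    return np.asarray(tokens, dtype="<u2").tobytes()

print("per-token loop: ", timeit.timeit(per_token_loop, number=3))
print("vectorized cast:", timeit.timeit(vectorized_cast, number=3))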
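The (offset, length) pairs accumulated in index_list describe where each sample's bytes sit inside the data segment. A hedged sketch of how a reader could decode one sample from such a pair (the surrounding file layout, headers, and index serialization are repo-specific and assumed away here):

import numpy as np

def read_sample(data_segment: bytes, offset: int, length: int, token_size_in_bytes: int) -> np.ndarray:
    # Slice the sample's bytes out of the data segment and decode the
    # little-endian unsigned tokens written by _write_data_segment.
    dtype = {1: "u1", 2: "<u2", 4: "<u4"}[token_size_in_bytes]
    return np.frombuffer(data_segment[offset : offset + length], dtype=dtype)

# Example: two samples written back-to-back, 2 bytes per token.
segment = np.asarray([1, 2, 3], dtype="<u2").tobytes() + np.asarray([7, 8], dtype="<u2").tobytes()
index_list = [(0, 6), (6, 4)]
assert read_sample(segment, *index_list[1], token_size_in_bytes=2).tolist() == [7, 8]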