Skip to content

Commit d510e9c

Browse files
committed
convert : write tensors in parallel ggml-org#12837
gguf-py : add more clarifying comments for multi-thread writes. gguf-py : use ThreadPoolExecutor when writing tensors. gguf-py : handle (limited) retries for remote tensors. Original author: @compilade. Merged branch 'master' and branch 'compilade/parallel-convert' into NXS_Llama.cpp.
1 parent 1c3f7c4 commit d510e9c

4 files changed

Lines changed: 175 additions & 33 deletions

File tree

convert_hf_to_gguf.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
114114
use_temp_file: bool = False, eager: bool = False,
115115
metadata_override: Path | None = None, model_name: str | None = None,
116116
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
117-
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
117+
small_first_shard: bool = False, hparams: dict[str, Any] | None = None,
118+
remote_hf_model_id: str | None = None, thread_count: int = 2,
118119
disable_mistral_community_chat_template: bool = False,
119120
sentence_transformers_dense_modules: bool = False):
120121
if type(self) is ModelBase or \
@@ -162,7 +163,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
162163

163164
# Configure GGUF Writer
164165
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
165-
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
166+
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard,
167+
thread_count=thread_count)
166168

167169
# Mistral specific
168170
self.disable_mistral_community_chat_template = disable_mistral_community_chat_template
@@ -12236,6 +12238,11 @@ def parse_args() -> argparse.Namespace:
1223612238
"Default these modules are not included.")
1223712239
)
1223812240

12241+
parser.add_argument(
12242+
"-t", "--threads", type=int, default=2,
12243+
help="Number of threads to use when writing the tensors. Make sure you have enough RAM for at least THREADS of the biggest tensors in the model when setting this. Defaults to 2.",
12244+
)
12245+
1223912246
args = parser.parse_args()
1224012247
if not args.print_supported_models and args.model is None:
1224112248
parser.error("the following arguments are required: model")
@@ -12425,7 +12432,8 @@ def main() -> None:
1242512432
split_max_tensors=args.split_max_tensors,
1242612433
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
1242712434
small_first_shard=args.no_tensor_first_split,
12428-
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
12435+
remote_hf_model_id=hf_repo_id, thread_count=args.threads,
12436+
disable_mistral_community_chat_template=disable_mistral_community_chat_template,
1242912437
sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
1243012438
)
1243112439

gguf-py/gguf/gguf_writer.py

Lines changed: 117 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
import struct
77
import sys
88
import tempfile
9+
import threading
910
from dataclasses import dataclass
1011
from enum import Enum, auto
1112
from math import prod
1213
from pathlib import Path
1314
from io import BufferedWriter
1415
from typing import IO, Any, Sequence, Mapping
1516
from string import ascii_letters, digits
17+
from concurrent.futures import FIRST_EXCEPTION, Future, ThreadPoolExecutor, wait
1618

1719
import numpy as np
1820

@@ -62,8 +64,63 @@ class WriterState(Enum):
6264
WEIGHTS = auto()
6365

6466

67+
# To close files which were opened in thread-local context
68+
# Necessary because ThreadPoolExecutor doesn't allow setting a custom finalizer
69+
# ref: https://github.com/python/cpython/issues/89502
70+
class _ThreadedOpenFiles:
71+
files: dict[Path, BufferedWriter]
72+
73+
def __init__(self):
74+
self.files = {}
75+
76+
def __del__(self):
77+
for file in self.files.values():
78+
file.close()
79+
80+
def __getitem__(self, key: Path, /) -> BufferedWriter:
81+
if key not in self.files:
82+
self.files[key] = open(key, "r+b")
83+
return self.files[key]
84+
85+
@classmethod
86+
def init_thread_local(cls, local_data):
87+
local_data.open_files = _ThreadedOpenFiles()
88+
89+
90+
# Exit quickly instead of waiting
91+
class _InterruptibleThreadPoolExecutor(ThreadPoolExecutor):
92+
def __exit__(self, exc_type, exc_val, exc_tb) -> bool | None:
93+
del exc_type, exc_val, exc_tb
94+
self.shutdown(wait=False, cancel_futures=True)
95+
return False
96+
97+
98+
@dataclass
99+
class _ThreadedTensorWriteInfo:
100+
filename: Path
101+
offset: int
102+
post_pad: int
103+
tensor: np.ndarray
104+
bar: Any | None # optional tqdm progress bar
105+
106+
def write_chunk(self, open_files: _ThreadedOpenFiles):
107+
# This is called from a thread pool,
108+
# and each thread should have its own file handle per output file
109+
# so that they can have different seek locations.
110+
f = open_files[self.filename]
111+
112+
f.seek(self.offset)
113+
f.write(self.tensor.data)
114+
if self.post_pad > 0:
115+
f.write(bytes([0] * self.post_pad))
116+
if self.bar is not None:
117+
self.bar.update(self.tensor.nbytes)
118+
119+
65120
class GGUFWriter:
66121
fout: list[BufferedWriter] | None
122+
filenames: list[Path] | None
123+
thread_count: int
67124
path: Path | None
68125
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
69126
tensors: list[dict[str, TensorInfo]]
@@ -85,7 +142,8 @@ class GGUFWriter:
85142

86143
def __init__(
87144
self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
88-
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
145+
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False,
146+
thread_count: int = 2,
89147
):
90148
self.fout = None
91149
self.path = Path(path) if path else None
@@ -100,6 +158,7 @@ def __init__(
100158
self.split_max_size = split_max_size
101159
self.dry_run = dry_run
102160
self.small_first_shard = small_first_shard
161+
self.thread_count = thread_count
103162
logger.info("gguf: This GGUF file is for {0} Endian only".format(
104163
"Big" if self.endianess == GGUFEndian.BIG else "Little",
105164
))
@@ -176,6 +235,7 @@ def open_output_file(self, path: Path | None = None) -> None:
176235

177236
if self.path is not None:
178237
filenames = self.print_plan()
238+
self.filenames = filenames
179239
self.fout = [open(filename, "wb") for filename in filenames]
180240
self.state = WriterState.EMPTY
181241

@@ -437,40 +497,76 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
437497
self.write_ti_data_to_file()
438498

439499
assert self.fout is not None
500+
assert self.filenames is not None
440501

441502
for fout in self.fout:
442503
self.write_padding(fout, fout.tell())
443504

444505
if self.temp_file is None:
445-
shard_bar = None
446506
bar = None
507+
# Initial file offsets before writing the tensor data
508+
offsets: list[int] = [fout.tell() for fout in self.fout]
447509

448510
if progress:
511+
# TODO: add back the shard bar to show which shard is being written when single-threaded
449512
from tqdm import tqdm
450513

451514
total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values())
452515

453-
if len(self.fout) > 1:
454-
shard_bar = tqdm(desc=f"Shard (0/{len(self.fout)})", total=None, unit="byte", unit_scale=True)
455516
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
456517

457-
for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)):
458-
if shard_bar is not None:
459-
shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})")
460-
total = sum(ti.nbytes for ti in tensors.values())
461-
shard_bar.reset(total=(total if total > 0 else None))
462-
463-
# relying on the fact that Python dicts preserve insertion order (since 3.7)
464-
for ti in tensors.values():
465-
assert ti.tensor is not None # can only iterate once over the tensors
466-
assert ti.tensor.nbytes == ti.nbytes
467-
ti.tensor.tofile(fout)
468-
if shard_bar is not None:
469-
shard_bar.update(ti.nbytes)
470-
if bar is not None:
471-
bar.update(ti.nbytes)
472-
self.write_padding(fout, ti.nbytes)
473-
ti.tensor = None
518+
# Allow opening the files only once per worker
519+
local_data = threading.local()
520+
521+
# Unit of work
522+
def thread_write_tensor(tensor: _ThreadedTensorWriteInfo):
523+
tensor.write_chunk(local_data.open_files)
524+
525+
with _InterruptibleThreadPoolExecutor(
526+
max_workers=self.thread_count,
527+
initializer=_ThreadedOpenFiles.init_thread_local,
528+
initargs=(local_data,),
529+
) as executor:
530+
531+
futures: list[Future] = []
532+
533+
# Fill the tensor queue with all the pending tensor writes
534+
for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)):
535+
offset = offsets[i]
536+
537+
# relying on the fact that Python dicts preserve insertion order (since 3.7)
538+
for ti in tensors.values():
539+
assert ti.tensor is not None # can only iterate once over the tensors
540+
assert ti.tensor.nbytes == ti.nbytes
541+
start_offset = offset
542+
nbytes = ti.tensor.nbytes
543+
offset = self.ggml_pad(start_offset + nbytes, self.data_alignment)
544+
padding = offset - (start_offset + nbytes)
545+
futures.append(
546+
executor.submit(
547+
thread_write_tensor,
548+
_ThreadedTensorWriteInfo(
549+
filename=filename,
550+
offset=start_offset,
551+
post_pad=padding,
552+
tensor=ti.tensor,
553+
bar=bar,
554+
),
555+
)
556+
)
557+
ti.tensor = None # avoid keeping a reference to written tensors
558+
559+
# FIXME: there's still some weird behavior with KeyboardInterrupt
560+
# not being able to interrupt a future mid-execution
561+
done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
562+
exc = None
563+
if any(f for f in done
564+
if not f.cancelled() and (exc := f.exception()) is not None):
565+
raise RuntimeError("Error writing tensors") from exc
566+
elif len(not_done) != 0:
567+
raise RuntimeError("Not all tensors were written")
568+
569+
del local_data
474570
else:
475571
self.temp_file.seek(0)
476572

gguf-py/gguf/lazy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,4 +225,9 @@ def tofile(self, *args, **kwargs):
225225
eager = LazyNumpyTensor.to_eager(self)
226226
return eager.tofile(*args, **kwargs)
227227

228+
@property
def data(self):
    """Writable buffer of the tensor, materializing the lazy value first.

    Mirrors ``tofile`` above: the lazy tensor is evaluated eagerly so the
    underlying numpy buffer can be handed out (numpy's ``.data`` memoryview).
    """
    return LazyNumpyTensor.to_eager(self).data
232+
228233
# TODO: __array_function__

gguf-py/gguf/utility.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,15 @@
88
import json
99
import numpy as np
1010

11+
import time
12+
import logging
13+
14+
import requests
15+
from urllib.parse import urlparse
16+
17+
18+
logger = logging.getLogger(__name__)
19+
1120

1221
def fill_templated_filename(filename: str, output_type: str | None) -> str:
1322
# Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
@@ -77,16 +86,38 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
7786

7887
@dataclass
class RemoteTensor:
    """Location and metadata of one tensor inside a remote safetensors file."""
    name: str
    dtype: str
    shape: tuple[int, ...]
    offset_start: int    # absolute byte offset of the tensor data in the remote file
    size: int            # payload size in bytes
    url: str

    def data(self) -> bytearray:
        """Download the raw tensor bytes, retrying transient network errors.

        Retries up to MAX_RETRIES times with linear backoff on connection /
        decoding errors; raises RuntimeError when all attempts fail.
        """
        data = None
        MAX_RETRIES = 8
        for i in range(MAX_RETRIES):
            try:
                # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
                data = bytearray(
                    SafetensorRemote.get_data_by_range(
                        url=self.url, start=self.offset_start, size=self.size
                    )
                )
                # Success: stop here instead of looping again, which would
                # pointlessly re-download the same range MAX_RETRIES times.
                break
            except (
                requests.exceptions.ChunkedEncodingError,
                requests.exceptions.ContentDecodingError,
                requests.exceptions.ConnectionError,
            ) as e:
                if i == MAX_RETRIES - 1:
                    raise RuntimeError(f"Failed to download tensor {self.name}") from e
                logger.warning(f"Retry ({i + 1}/{MAX_RETRIES}) downloading tensor {self.name} because of {e}")
                time.sleep(2 * i + 1)  # backoff: 1 3 5 7 9 11 13
                continue

        if data is None:
            raise RuntimeError(f"Failed to download tensor {self.name}")

        return data
91122

92123

@@ -174,7 +205,14 @@ def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
174205
offset_start_relative, offset_end_relative = meta["data_offsets"]
175206
size = offset_end_relative - offset_start_relative
176207
offset_start = data_start_offset + offset_start_relative
177-
res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
208+
res[name] = RemoteTensor(
209+
name=name,
210+
dtype=dtype,
211+
shape=tuple(shape),
212+
offset_start=offset_start,
213+
size=size,
214+
url=url,
215+
)
178216
except KeyError as e:
179217
raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
180218

@@ -223,8 +261,6 @@ def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
223261
Get raw byte data from a remote file by range.
224262
If size is not specified, it will read the entire file.
225263
"""
226-
import requests
227-
from urllib.parse import urlparse
228264

229265
parsed_url = urlparse(url)
230266
if not parsed_url.scheme or not parsed_url.netloc:
@@ -245,9 +281,6 @@ def check_file_exist(cls, url: str) -> bool:
245281
Check if a file exists at the given URL.
246282
Returns True if the file exists, False otherwise.
247283
"""
248-
import requests
249-
from urllib.parse import urlparse
250-
251284
parsed_url = urlparse(url)
252285
if not parsed_url.scheme or not parsed_url.netloc:
253286
raise ValueError(f"Invalid URL: {url}")

0 commit comments

Comments
 (0)