66import struct
77import sys
88import tempfile
9+ import threading
910from dataclasses import dataclass
1011from enum import Enum , auto
1112from math import prod
1213from pathlib import Path
1314from io import BufferedWriter
1415from typing import IO , Any , Sequence , Mapping
1516from string import ascii_letters , digits
17+ from concurrent .futures import FIRST_EXCEPTION , Future , ThreadPoolExecutor , wait
1618
1719import numpy as np
1820
@@ -62,8 +64,63 @@ class WriterState(Enum):
6264 WEIGHTS = auto ()
6365
6466
# Closes files which were opened in thread-local context.
# Necessary because ThreadPoolExecutor doesn't allow setting a custom finalizer
# ref: https://github.com/python/cpython/issues/89502
class _ThreadedOpenFiles:
    """Per-thread cache of writable file handles, closed when the instance dies."""

    files: dict[Path, BufferedWriter]

    def __init__(self):
        self.files = {}

    def __del__(self):
        # Close every handle this thread opened; runs when the thread-local
        # storage (and hence this instance) is garbage-collected.
        for handle in self.files.values():
            handle.close()

    def __getitem__(self, key: Path, /) -> BufferedWriter:
        # Lazily open each output file at most once per thread (EAFP).
        try:
            return self.files[key]
        except KeyError:
            handle = self.files[key] = open(key, "r+b")
            return handle

    @classmethod
    def init_thread_local(cls, local_data):
        # ThreadPoolExecutor worker initializer: give each thread its own file map.
        local_data.open_files = cls()
88+
89+
# Exit quickly instead of waiting for queued work to drain.
class _InterruptibleThreadPoolExecutor(ThreadPoolExecutor):
    """Thread pool whose context exit cancels pending futures instead of draining them."""

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool | None:
        # The exception details don't matter here: tear down immediately either way.
        del exc_type, exc_val, exc_tb
        self.shutdown(wait=False, cancel_futures=True)
        # Never suppress an in-flight exception.
        return False
96+
97+
@dataclass
class _ThreadedTensorWriteInfo:
    """One pending tensor write: target file, byte offset, data, and trailing padding."""
    filename: Path
    offset: int       # absolute byte offset of the tensor data within the file
    post_pad: int     # zero bytes of alignment padding written after the tensor
    tensor: np.ndarray
    bar: Any | None   # optional tqdm progress bar, advanced by tensor.nbytes

    def write_chunk(self, open_files: _ThreadedOpenFiles):
        # This is called from a thread pool,
        # and each thread should have its own file handle per output file
        # so that they can have different seek locations.
        f = open_files[self.filename]

        f.seek(self.offset)
        f.write(self.tensor.data)
        if self.post_pad > 0:
            # bytes(n) is a zero-filled buffer of length n;
            # avoids building a temporary list of n ints as bytes([0] * n) did.
            f.write(bytes(self.post_pad))
        if self.bar is not None:
            self.bar.update(self.tensor.nbytes)
118+
119+
65120class GGUFWriter :
66121 fout : list [BufferedWriter ] | None
122+ filenames : list [Path ] | None
123+ thread_count : int
67124 path : Path | None
68125 temp_file : tempfile .SpooledTemporaryFile [bytes ] | None
69126 tensors : list [dict [str , TensorInfo ]]
@@ -85,7 +142,8 @@ class GGUFWriter:
85142
86143 def __init__ (
87144 self , path : os .PathLike [str ] | str | None , arch : str , use_temp_file : bool = False , endianess : GGUFEndian = GGUFEndian .LITTLE ,
88- split_max_tensors : int = 0 , split_max_size : int = 0 , dry_run : bool = False , small_first_shard : bool = False
145+ split_max_tensors : int = 0 , split_max_size : int = 0 , dry_run : bool = False , small_first_shard : bool = False ,
146+ thread_count : int = 2 ,
89147 ):
90148 self .fout = None
91149 self .path = Path (path ) if path else None
@@ -100,6 +158,7 @@ def __init__(
100158 self .split_max_size = split_max_size
101159 self .dry_run = dry_run
102160 self .small_first_shard = small_first_shard
161+ self .thread_count = thread_count
103162 logger .info ("gguf: This GGUF file is for {0} Endian only" .format (
104163 "Big" if self .endianess == GGUFEndian .BIG else "Little" ,
105164 ))
@@ -176,6 +235,7 @@ def open_output_file(self, path: Path | None = None) -> None:
176235
177236 if self .path is not None :
178237 filenames = self .print_plan ()
238+ self .filenames = filenames
179239 self .fout = [open (filename , "wb" ) for filename in filenames ]
180240 self .state = WriterState .EMPTY
181241
@@ -437,40 +497,76 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
437497 self .write_ti_data_to_file ()
438498
439499 assert self .fout is not None
500+ assert self .filenames is not None
440501
441502 for fout in self .fout :
442503 self .write_padding (fout , fout .tell ())
443504
444505 if self .temp_file is None :
445- shard_bar = None
446506 bar = None
507+ # Initial file offsets before writing the tensor data
508+ offsets : list [int ] = [fout .tell () for fout in self .fout ]
447509
448510 if progress :
511+ # TODO: add back the shard bar to show which shard is being written when single-threaded
449512 from tqdm import tqdm
450513
451514 total_bytes = sum (ti .nbytes for t in self .tensors for ti in t .values ())
452515
453- if len (self .fout ) > 1 :
454- shard_bar = tqdm (desc = f"Shard (0/{ len (self .fout )} )" , total = None , unit = "byte" , unit_scale = True )
455516 bar = tqdm (desc = "Writing" , total = total_bytes , unit = "byte" , unit_scale = True )
456517
457- for i , (fout , tensors ) in enumerate (zip (self .fout , self .tensors )):
458- if shard_bar is not None :
459- shard_bar .set_description (f"Shard ({ i + 1 } /{ len (self .fout )} )" )
460- total = sum (ti .nbytes for ti in tensors .values ())
461- shard_bar .reset (total = (total if total > 0 else None ))
462-
463- # relying on the fact that Python dicts preserve insertion order (since 3.7)
464- for ti in tensors .values ():
465- assert ti .tensor is not None # can only iterate once over the tensors
466- assert ti .tensor .nbytes == ti .nbytes
467- ti .tensor .tofile (fout )
468- if shard_bar is not None :
469- shard_bar .update (ti .nbytes )
470- if bar is not None :
471- bar .update (ti .nbytes )
472- self .write_padding (fout , ti .nbytes )
473- ti .tensor = None
518+ # Allow opening the files only once per worker
519+ local_data = threading .local ()
520+
521+ # Unit of work
522+ def thread_write_tensor (tensor : _ThreadedTensorWriteInfo ):
523+ tensor .write_chunk (local_data .open_files )
524+
525+ with _InterruptibleThreadPoolExecutor (
526+ max_workers = self .thread_count ,
527+ initializer = _ThreadedOpenFiles .init_thread_local ,
528+ initargs = (local_data ,),
529+ ) as executor :
530+
531+ futures : list [Future ] = []
532+
533+ # Fill the tensor queue with all the pending tensor writes
534+ for i , (filename , tensors ) in enumerate (zip (self .filenames , self .tensors )):
535+ offset = offsets [i ]
536+
537+ # relying on the fact that Python dicts preserve insertion order (since 3.7)
538+ for ti in tensors .values ():
539+ assert ti .tensor is not None # can only iterate once over the tensors
540+ assert ti .tensor .nbytes == ti .nbytes
541+ start_offset = offset
542+ nbytes = ti .tensor .nbytes
543+ offset = self .ggml_pad (start_offset + nbytes , self .data_alignment )
544+ padding = offset - (start_offset + nbytes )
545+ futures .append (
546+ executor .submit (
547+ thread_write_tensor ,
548+ _ThreadedTensorWriteInfo (
549+ filename = filename ,
550+ offset = start_offset ,
551+ post_pad = padding ,
552+ tensor = ti .tensor ,
553+ bar = bar ,
554+ ),
555+ )
556+ )
557+ ti .tensor = None # avoid keeping a reference to written tensors
558+
559+ # FIXME: there's still some weird behavior with KeyboardInterrupt
560+ # not being able to interrupt a future mid-execution
561+ done , not_done = wait (futures , return_when = FIRST_EXCEPTION )
562+ exc = None
563+ if any (f for f in done
564+ if not f .cancelled () and (exc := f .exception ()) is not None ):
565+ raise RuntimeError ("Error writing tensors" ) from exc
566+ elif len (not_done ) != 0 :
567+ raise RuntimeError ("Not all tensors were written" )
568+
569+ del local_data
474570 else :
475571 self .temp_file .seek (0 )
476572
0 commit comments