This repository was archived by the owner on Jan 24, 2024. It is now read-only.
File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -22,6 +22,8 @@ def is_true(value):
2222MODEL_CACHE_DIR = os .getenv ("MODEL_CACHE_DIR" , "models" )
2323MODEL_LOAD_IN_8BIT = is_true (os .getenv ("MODEL_LOAD_IN_8BIT" , "" ))
2424MODEL_LOAD_IN_4BIT = is_true (os .getenv ("MODEL_LOAD_IN_4BIT" , "" ))
25+ MODEL_4BIT_QUANT_TYPE = os .getenv ("MODEL_4BIT_QUANT_TYPE" , "fp4" )
26+ MODEL_4BIT_DOUBLE_QUANT = is_true (os .getenv ("MODEL_4BIT_DOUBLE_QUANT" , "" ))
2527MODEL_LOCAL_FILES_ONLY = is_true (os .getenv ("MODEL_LOCAL_FILES_ONLY" , "" ))
2628MODEL_TRUST_REMOTE_CODE = is_true (os .getenv ("MODEL_TRUST_REMOTE_CODE" , "" ))
2729MODEL_HALF_PRECISION = is_true (os .getenv ("MODEL_HALF_PRECISION" , "" ))
Original file line number Diff line number Diff line change 2121from . import MODEL_CACHE_DIR
2222from . import MODEL_LOAD_IN_8BIT
2323from . import MODEL_LOAD_IN_4BIT
24+ from . import MODEL_4BIT_QUANT_TYPE
25+ from . import MODEL_4BIT_DOUBLE_QUANT
2426from . import MODEL_LOCAL_FILES_ONLY
2527from . import MODEL_TRUST_REMOTE_CODE
2628from . import MODEL_HALF_PRECISION
4446 cache_dir = MODEL_CACHE_DIR ,
4547 load_in_8bit = MODEL_LOAD_IN_8BIT ,
4648 load_in_4bit = MODEL_LOAD_IN_4BIT ,
49+ quant_type = MODEL_4BIT_QUANT_TYPE ,
50+ double_quant = MODEL_4BIT_DOUBLE_QUANT ,
4751 local_files_only = MODEL_LOCAL_FILES_ONLY ,
4852 trust_remote_code = MODEL_TRUST_REMOTE_CODE ,
4953 half_precision = MODEL_HALF_PRECISION ,
Original file line number Diff line number Diff line change 1212 MinNewTokensLengthLogitsProcessor ,
1313 TemperatureLogitsWarper ,
1414 TopPLogitsWarper ,
15+ BitsAndBytesConfig
1516)
1617
1718from .choice import map_choice
@@ -311,6 +312,8 @@ def load_model(
311312 cache_dir = None ,
312313 load_in_8bit = False ,
313314 load_in_4bit = False ,
315+ quant_type = "fp4" ,
316+ double_quant = False ,
314317 local_files_only = False ,
315318 trust_remote_code = False ,
316319 half_precision = False ,
@@ -328,10 +331,26 @@ def load_model(
328331
329332 # Set device mapping and quantization options if CUDA is available.
330333 if torch .cuda .is_available ():
334+ # Set quantization options if specified.
335+ quant_config = None
336+ if load_in_8bit and load_in_4bit :
337+ raise ValueError ("Only one of load_in_8bit and load_in_4bit can be True" )
338+ if load_in_8bit :
339+ quant_config = BitsAndBytesConfig (
340+ load_in_8bit = True ,
341+ )
342+ elif load_in_4bit :
343+ quant_config = BitsAndBytesConfig (
344+ load_in_4bit = True ,
345+ bnb_4bit_quant_type = quant_type ,
346+ bnb_4bit_use_double_quant = double_quant ,
347+ bnb_4bit_compute_dtype = torch .bfloat16 ,
348+ )
331349 kwargs = kwargs .copy ()
332350 kwargs ["device_map" ] = "auto"
333351 kwargs ["load_in_8bit" ] = load_in_8bit
334352 kwargs ["load_in_4bit" ] = load_in_4bit
353+ kwargs ["quantization_config" ] = quant_config
335354
336355 # Cast all parameters to float16 if quantization is enabled.
337356 if half_precision or load_in_8bit or load_in_4bit :
You can’t perform that action at this time.
0 commit comments