Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit 81968c3

Browse files
committed
Allow loading model with 4bit quantization.
For detail on 4bit options, see: https://huggingface.co/blog/4bit-transformers-bitsandbytes
1 parent 1677491 commit 81968c3

3 files changed

Lines changed: 25 additions & 0 deletions

File tree

basaran/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def is_true(value):
2222
MODEL_CACHE_DIR = os.getenv("MODEL_CACHE_DIR", "models")
2323
MODEL_LOAD_IN_8BIT = is_true(os.getenv("MODEL_LOAD_IN_8BIT", ""))
2424
MODEL_LOAD_IN_4BIT = is_true(os.getenv("MODEL_LOAD_IN_4BIT", ""))
25+
MODEL_4BIT_QUANT_TYPE = os.getenv("MODEL_4BIT_QUANT_TYPE", "fp4")
26+
MODEL_4BIT_DOUBLE_QUANT = is_true(os.getenv("MODEL_4BIT_DOUBLE_QUANT", ""))
2527
MODEL_LOCAL_FILES_ONLY = is_true(os.getenv("MODEL_LOCAL_FILES_ONLY", ""))
2628
MODEL_TRUST_REMOTE_CODE = is_true(os.getenv("MODEL_TRUST_REMOTE_CODE", ""))
2729
MODEL_HALF_PRECISION = is_true(os.getenv("MODEL_HALF_PRECISION", ""))

basaran/__main__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from . import MODEL_CACHE_DIR
2222
from . import MODEL_LOAD_IN_8BIT
2323
from . import MODEL_LOAD_IN_4BIT
24+
from . import MODEL_4BIT_QUANT_TYPE
25+
from . import MODEL_4BIT_DOUBLE_QUANT
2426
from . import MODEL_LOCAL_FILES_ONLY
2527
from . import MODEL_TRUST_REMOTE_CODE
2628
from . import MODEL_HALF_PRECISION
@@ -44,6 +46,8 @@
4446
cache_dir=MODEL_CACHE_DIR,
4547
load_in_8bit=MODEL_LOAD_IN_8BIT,
4648
load_in_4bit=MODEL_LOAD_IN_4BIT,
49+
quant_type=MODEL_4BIT_QUANT_TYPE,
50+
double_quant=MODEL_4BIT_DOUBLE_QUANT,
4751
local_files_only=MODEL_LOCAL_FILES_ONLY,
4852
trust_remote_code=MODEL_TRUST_REMOTE_CODE,
4953
half_precision=MODEL_HALF_PRECISION,

basaran/model.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
MinNewTokensLengthLogitsProcessor,
1313
TemperatureLogitsWarper,
1414
TopPLogitsWarper,
15+
BitsAndBytesConfig
1516
)
1617

1718
from .choice import map_choice
@@ -311,6 +312,8 @@ def load_model(
311312
cache_dir=None,
312313
load_in_8bit=False,
313314
load_in_4bit=False,
315+
quant_type="fp4",
316+
double_quant=False,
314317
local_files_only=False,
315318
trust_remote_code=False,
316319
half_precision=False,
@@ -328,10 +331,26 @@ def load_model(
328331

329332
# Set device mapping and quantization options if CUDA is available.
330333
if torch.cuda.is_available():
334+
# Set quantization options if specified.
335+
quant_config = None
336+
if load_in_8bit and load_in_4bit:
337+
raise ValueError("Only one of load_in_8bit and load_in_4bit can be True")
338+
if load_in_8bit:
339+
quant_config = BitsAndBytesConfig(
340+
load_in_8bit=True,
341+
)
342+
elif load_in_4bit:
343+
quant_config = BitsAndBytesConfig(
344+
load_in_4bit=True,
345+
bnb_4bit_quant_type=quant_type,
346+
bnb_4bit_use_double_quant=double_quant,
347+
bnb_4bit_compute_dtype=torch.bfloat16,
348+
)
331349
kwargs = kwargs.copy()
332350
kwargs["device_map"] = "auto"
333351
kwargs["load_in_8bit"] = load_in_8bit
334352
kwargs["load_in_4bit"] = load_in_4bit
353+
kwargs["quantization_config"] = quant_config
335354

336355
# Cast all parameters to float16 if quantization is enabled.
337356
if half_precision or load_in_8bit or load_in_4bit:

0 commit comments

Comments
 (0)