Commit dceb63b
Merge branch 'foundation-model-stack:main' into main
2 parents: 9d79ba7 + 7777b49
6 files changed: 152 additions & 54 deletions

examples/GPTQ/README.md

Lines changed: 53 additions & 28 deletions
@@ -7,6 +7,7 @@ For generative LLMs, very often the bottleneck of inference is no longer the com
 
 - [FMS Model Optimizer requirements](../../README.md#requirements)
 - `gptqmodel` is needed for this example. Use `pip install gptqmodel` or [install from source](https://github.com/ModelCloud/GPTQModel/tree/main?tab=readme-ov-file)
+  - It is advised to install from source if you plan to use `GPTQv2`
 - Optionally, for the evaluation section below, install [lm-eval](https://github.com/EleutherAI/lm-evaluation-harness)
   ```
   pip install lm-eval
@@ -32,7 +33,7 @@ This end-to-end example utilizes the common set of interfaces provided by `fms_m
 > - Tokenized data will be saved in `<path_to_save>_train` and `<path_to_save>_test`
 > - If you have trouble downloading the Llama family of models from Hugging Face ([Llama models require access](https://www.llama.com/docs/getting-the-models/hugging-face/)), you can use `ibm-granite/granite-8b-code` instead
 
-2. **Quantize the model** using the data generated above. The following command will kick off the quantization job (by invoking `gptqmodel` under the hood). Additional acceptable arguments can be found in [GPTQArguments](../../fms_mo/training_args.py#L127).
+2. **Quantize the model** using the data generated above. The following command will kick off the `GPTQv1` quantization job (by invoking `gptqmodel` under the hood). Additional acceptable arguments can be found in [GPTQArguments](../../fms_mo/training_args.py#L127).
 
   ```bash
   python -m fms_mo.run_quant \
@@ -41,9 +42,10 @@ This end-to-end example utilizes the common set of interfaces provided by `fms_m
       --quant_method gptq \
       --output_dir Meta-Llama-3-8B-GPTQ \
       --bits 4 \
-      --group_size 128
+      --group_size 128 \
+
   ```
-  The model that can be found in the specified output directory (`Meta-Llama-3-8B-GPTQ` in our case) can be deployed and inferenced via `vLLM`.
+  The model found in the specified output directory (`Meta-Llama-3-8B-GPTQ` in our case) can be deployed and inferenced via `vLLM`. To enable `GPTQv2`, set the `quant_method` argument to `gptqv2`.
 
 > [!NOTE]
 > - In GPTQ, `group_size` is a trade-off between accuracy and speed, but there is an additional constraint: the `in_features` of each Linear layer to be quantized needs to be an **integer multiple** of `group_size`, i.e. some models may have to use a smaller `group_size` than the default.
@@ -82,44 +84,67 @@ This end-to-end example utilizes the common set of interfaces provided by `fms_m
 ## Example Test Results
 
 - Unquantized Model
-
-|Model | Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
-|------------|--------------|------:|------|-----:|----------|---|-----:|---|-----:|
-| LLAMA3-8B |lambada_openai| 1|none | 5|acc |↑ |0.7103|± |0.0063|
-| | | |none | 5|perplexity|↓ |3.7915|± |0.0727|
+
+  |Model | Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
+  |------------|--------------|------:|------|-----:|----------|---|-----:|---|-----:|
+  | LLAMA3-8B |lambada_openai| 1|none | 5|acc |↑ |0.7103|± |0.0063|
+  | | | |none | 5|perplexity|↓ |3.7915|± |0.0727|
 
 - Quantized model with the settings shown above (`desc_act` defaults to False)
-
-|Model | Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
-|------------|--------------|------:|------|-----:|----------|---|------:|---|-----:|
-| LLAMA3-8B |lambada_openai| 1|none | 5|acc |↑ |0.6365 |± |0.0067|
-| | | |none | 5|perplexity|↓ |5.9307 |± |0.1830|
+
+  - `GPTQv1`
+
+    |Model | Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
+    |------------|--------------|------:|------|-----:|----------|---|------:|---|-----:|
+    | LLAMA3-8B |lambada_openai| 1|none | 5|acc |↑ |0.6365 |± |0.0067|
+    | | | |none | 5|perplexity|↓ |5.9307 |± |0.1830|
+
+  - `GPTQv2`
+
+    |Model | Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
+    |------------|--------------|------:|------|-----:|----------|---|------:|---|-----:|
+    | LLAMA3-8B |lambada_openai| 1|none | 5|acc |↑ |0.6817 |± |0.0065|
+    | | | |none | 5|perplexity|↓ |4.3994 |± |0.0995|
 
 - Quantized model with `desc_act` set to `True` (could improve model quality, but at the cost of inference speed)
-
-|Model | Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
-|------------|--------------|------:|------|-----:|----------|---|------:|---|-----:|
-| LLAMA3-8B |lambada_openai| 1|none | 5|acc |↑ |0.6193 |± |0.0068|
-| | | |none | 5|perplexity|↓ |5.8879 |± |0.1546|
+
+  - `GPTQv1`
+
+    |Model | Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
+    |------------|--------------|------:|------|-----:|----------|---|------:|---|-----:|
+    | LLAMA3-8B |lambada_openai| 1|none | 5|acc |↑ |0.6193 |± |0.0068|
+    | | | |none | 5|perplexity|↓ |5.8879 |± |0.1546|
 
 > [!NOTE]
 > There is some randomness in generating the model and data; the resulting accuracy may vary ~$\pm$ 0.05.
 
 
 ## Code Walk-through
 
-1. Command line arguments will be used to create a GPTQ quantization config. Information about the required arguments and their default values can be found [here](../../fms_mo/training_args.py)
+1. Command line arguments will be used to create a GPTQ quantization config. Information about the required arguments and their default values can be found [here](../../fms_mo/training_args.py). Both `GPTQv1` and `GPTQv2` are supported.
 
-   ```python
-   from gptqmodel import GPTQModel, QuantizeConfig
+   - To use `GPTQv1`, set the parameter `quant_method` to `gptq` in the command line.
 
-   quantize_config = QuantizeConfig(
-       bits=gptq_args.bits,
-       group_size=gptq_args.group_size,
-       desc_act=gptq_args.desc_act,
-       damp_percent=gptq_args.damp_percent,
-   )
+     ```python
+     from gptqmodel import GPTQModel, QuantizeConfig
+
+     quantize_config = QuantizeConfig(
+         bits=gptq_args.bits,
+         group_size=gptq_args.group_size,
+         desc_act=gptq_args.desc_act,
+         damp_percent=gptq_args.damp_percent,
+     )
+     ```
+   - To use `GPTQv2`, simply set `quant_method` to `gptqv2` in the command line. Under the hood, two additional arguments will be added to `QuantizeConfig`, i.e. `v2=True` and `v2_memory_device="cpu"`.
 
+     ```python
+     from gptqmodel import GPTQModel, QuantizeConfig
+
+     quantize_config = QuantizeConfig(
+         bits=gptq_args.bits,
+         group_size=gptq_args.group_size,
+         desc_act=gptq_args.desc_act,
+         damp_percent=gptq_args.damp_percent,
+         v2=True,
+         v2_memory_device='cpu',
+     )
 ```
 
 2. Load the pre-trained model with the `gptqmodel` class/wrapper. The tokenizer is optional because we already tokenized the data in a previous step.
@@ -158,4 +183,4 @@ This end-to-end example utilizes the common set of interfaces provided by `fms_m
     tokenizer.save_pretrained(output_dir)  # optional
 ```
 > [!NOTE]
-> 1. GPTQ of a 70B model usually takes ~4-10 hours on A100.
+> 1. GPTQ of a 70B model usually takes ~4-10 hours on A100 with `GPTQv1`.
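The `group_size` constraint in the README's note can be sanity-checked before launching a quantization job. A minimal sketch, assuming nothing from fms-mo (the helper name and candidate list are illustrative):

```python
def largest_valid_group_size(in_features: int, candidates=(128, 64, 32)) -> int:
    """Return the largest candidate group_size that evenly divides in_features.

    GPTQ requires in_features % group_size == 0 for every quantized Linear
    layer, so models with unusual hidden sizes may need a group_size below
    the default of 128.
    """
    for g in candidates:
        if in_features % g == 0:
            return g
    return 1  # fall back to per-column quantization
```

For example, a Linear layer with `in_features=4096` accepts the default of 128, while a hypothetical layer with `in_features=4160` would have to drop to 64.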

fms_mo/aiu_addons/fp8/fp8_attn.py

Lines changed: 20 additions & 1 deletion
@@ -318,12 +318,31 @@ def _spyre_scaled_paged_compute_op(
         attn_kwargs["block_table"],
     )
 
+
+def __spyre_scaled_paged_validate_attn_kwargs_op(
+    input_ids: torch.Tensor,
+    position_ids: torch.Tensor,
+    past_key_value_states: Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None,
+    **attn_kwargs,
+):
+    __spyre_paged_validate_attn_kwargs_op(
+        input_ids, position_ids, past_key_value_states, **attn_kwargs
+    )
+
+    if past_key_value_states is not None:
+        for k, v in past_key_value_states:
+            assert isinstance(k, ScaledTensor)
+            assert isinstance(v, ScaledTensor)
+
+            # assert that for each layer, the scales are per-sequence
+            assert k._scale.shape[0] == input_ids.shape[0]
+            assert v._scale.shape[0] == input_ids.shape[0]
+
+
 register_attention_op(
     "spyre_paged_attn_fp8",
     _spyre_scaled_paged_store_op,
     compute_op=_math_fp8_compute_op,
     is_prefill_op=lambda **attn_kwargs: attn_kwargs.get("block_table", None)
     is None,
     compute_decode_op=_spyre_scaled_paged_compute_op,
-    validate_attn_kwargs_op=__spyre_paged_validate_attn_kwargs_op,
+    validate_attn_kwargs_op=__spyre_scaled_paged_validate_attn_kwargs_op,
 )
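The intent of the new validation op can be illustrated with a stand-in for `ScaledTensor`. This is a toy sketch under stated assumptions: `FakeScaledTensor` and `validate_per_sequence_scales` are hypothetical names, and plain lists stand in for the real tensors and their `_scale` attribute:

```python
from dataclasses import dataclass


@dataclass
class FakeScaledTensor:
    """Stand-in for fms-mo's ScaledTensor; only the per-sequence _scale matters here."""
    _scale: list  # one scale entry per sequence in the batch


def validate_per_sequence_scales(batch_size: int, past_key_value_states) -> bool:
    """Mirror the new check: each layer's K/V cache must be a scaled tensor
    whose leading scale dimension matches the batch size."""
    for k, v in past_key_value_states:
        assert isinstance(k, FakeScaledTensor) and isinstance(v, FakeScaledTensor)
        assert len(k._scale) == batch_size
        assert len(v._scale) == batch_size
    return True
```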

fms_mo/quant/quantizers.py

Lines changed: 54 additions & 15 deletions
@@ -123,23 +123,28 @@ def get_activation_quantizer(
             )
         elif qa_mode == "dorefa":
             act_quantizer = dorefa_quantize_activation
-        elif (
-            qa_mode == "max"
-        ):  # NOTE Need to be careful using this for activation, particular to 1 sided.
-            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
-        elif qa_mode == "minmax":
-            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+        elif "max" in qa_mode:
+            # NOTE Need to be careful using this for activation, particularly the 1-sided case.
+            if "min" in qa_mode:
+                act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+            elif "pertoken" in qa_mode or "perToken" in qa_mode:
+                act_quantizer = QMaxDynamic(nbits, dim=-1)
+            elif "per_channel" in qa_mode or "perCh" in qa_mode:
+                act_quantizer = QMaxDynamic(nbits, dim=-2)
+            elif "sym" in qa_mode:
+                act_quantizer = Qmax(
+                    nbits,
+                    align_zero=True,
+                    minmax=False,
+                    extend_act_range=extend_act_range,
+                )
+            else:
+                act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
         elif qa_mode == "fix":
             act_quantizer = QFixSymmetric(
                 nbits, init_clip_val=clip_val, align_zero=align_zero
             )
-        elif qa_mode == "maxsym":
-            act_quantizer = Qmax(
-                nbits,
-                align_zero=True,
-                minmax=False,
-                extend_act_range=extend_act_range,
-            )
         elif qa_mode == "pactsym":
             act_quantizer = PACT2Sym(
                 nbits,
@@ -179,8 +184,6 @@ def get_activation_quantizer(
                 perToken=perToken,
                 emulate=True,
             )
-        elif qa_mode == "pertokenmax":
-            act_quantizer = PerTokenMax(nbits)
         else:
             raise ValueError(f"unrecognized activation quantization mode {qa_mode}")
     else:  # swcap-compatible activation quantizers
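The refactored branch replaces exact-match modes (`max`, `minmax`, `maxsym`, `pertokenmax`) with substring dispatch, which is why the separate `maxsym` and `pertokenmax` branches could be deleted. A pure-Python sketch of the routing, where the returned labels merely name which quantizer the real code would construct:

```python
def route_max_mode(qa_mode: str) -> str:
    """Mimic the new substring dispatch for the "max" family of activation modes."""
    if "max" not in qa_mode:
        return "other"
    if "min" in qa_mode:                                 # e.g. "minmax"
        return "Qmax(minmax=True)"
    if "pertoken" in qa_mode or "perToken" in qa_mode:   # absorbs old "pertokenmax"
        return "QMaxDynamic(dim=-1)"
    if "per_channel" in qa_mode or "perCh" in qa_mode:
        return "QMaxDynamic(dim=-2)"
    if "sym" in qa_mode:                                 # absorbs old "maxsym"
        return "Qmax(extend_act_range)"
    return "Qmax(minmax=False)"                          # plain "max"
```

Note that ordering matters: `"min"` must be caught before the plain `max` fallback, and `pertokenmax` now routes to `QMaxDynamic` rather than the removed `PerTokenMax` branch.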
@@ -3491,6 +3494,42 @@ def __repr__(self):
         return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"
 
 
+class QMaxDynamic(nn.Module):
+    def __init__(self, num_bits, dim=-1):
+        """
+        Per-token or per-channel quantization using abs().max() as the scale, usually for
+        activations; could be used for Qbmm M2 as well.
+        (reduce) dim = -1 -> abs().max() outputs a column vector (if input is 2D) => per-token
+                 dim = -2 -> per-channel
+        Zero is aligned so that the levels are symmetric around zero (losing one level).
+        Since the token length is unknown before runtime, the quantizer can only calculate the
+        scales dynamically at run time, meaning no trainable quantization scales are allowed
+        (unless the input sequence length is always the same, not just padded to a fixed length).
+        """
+        super().__init__()
+        self.num_bits = num_bits
+        self.levels = 2 ** (self.num_bits - 1) - 1
+        # Normalize string aliases first, then validate, so reduce_dim is always set.
+        if isinstance(dim, str):
+            if "perCh" in dim or "per_channel" in dim:
+                dim = -2
+            elif "perToken" in dim or "per_token" in dim or "per_Token" in dim:
+                dim = -1
+        if dim in [-1, -2]:
+            self.reduce_dim = dim
+        else:
+            raise ValueError(
+                f"Reduce dim can only be [-1, -2] or ['perCh', 'perToken'] but found {dim}"
+            )
+
+    def forward(self, input_tensor):
+        amax_dim = input_tensor.abs().max(dim=self.reduce_dim, keepdim=True)[0]
+        scales = amax_dim.clamp(min=1e-5).div(self.levels)
+        return input_tensor.div(scales).round().mul(scales)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"
+
+
 class Qdynamic(nn.Module):
     def __init__(
         self,
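`QMaxDynamic.forward` is a symmetric abs-max fake quantization computed per token at run time. A dependency-free sketch of the per-token (`dim=-1`) path, using plain lists instead of tensors (the function name is illustrative, not the fms-mo API):

```python
def fake_quant_per_token(rows, num_bits=8):
    """Symmetric per-token fake quantization with abs-max scaling.

    Each row (token) gets its own scale = max(|x|) / (2**(num_bits - 1) - 1),
    computed on the fly, so there are no trainable scales. Values are divided
    by the scale, rounded to the integer grid, then multiplied back.
    """
    levels = 2 ** (num_bits - 1) - 1
    out = []
    for row in rows:
        # clamp the abs-max at 1e-5 to avoid a zero scale, as the real class does
        scale = max(max(abs(x) for x in row), 1e-5) / levels
        out.append([round(x / scale) * scale for x in row])
    return out
```

The per-row abs-max always maps back to itself (it sits exactly on the top level), while intermediate values snap to the nearest of the `2 * levels + 1` symmetric levels.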

fms_mo/run_quant.py

Lines changed: 18 additions & 7 deletions
@@ -88,7 +88,7 @@ def quantize(
 
     logger.info(f"{fms_mo_args}\n{opt_args.quant_method}\n")
 
-    if opt_args.quant_method == "gptq":
+    if opt_args.quant_method in ["gptq", "gptqv2"]:
         if not available_packages["gptqmodel"]:
             raise ImportError(
                 "Quantization method has been selected as gptq but unable to use external library, "
@@ -138,12 +138,23 @@ def run_gptq(model_args, data_args, opt_args, gptq_args):
 
     logger = set_log_level(opt_args.log_level, "fms_mo.run_gptq")
 
-    quantize_config = QuantizeConfig(
-        bits=gptq_args.bits,
-        group_size=gptq_args.group_size,
-        desc_act=gptq_args.desc_act,
-        damp_percent=gptq_args.damp_percent,
-    )
+    if opt_args.quant_method == "gptq":
+        quantize_config = QuantizeConfig(
+            bits=gptq_args.bits,
+            group_size=gptq_args.group_size,
+            desc_act=gptq_args.desc_act,
+            damp_percent=gptq_args.damp_percent,
+        )
+    else:
+        quantize_config = QuantizeConfig(
+            bits=gptq_args.bits,
+            group_size=gptq_args.group_size,
+            desc_act=gptq_args.desc_act,
+            damp_percent=gptq_args.damp_percent,
+            v2=True,
+            v2_memory_device="cpu",
+        )
 
     # Add custom model_type mapping to gptqmodel LUT so GPTQModel can recognize them.
     for mtype, cls in custom_gptq_classes.items():
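The two `QuantizeConfig` calls above differ only in the two v2 flags, so the branching can be summarized as a small kwargs builder. A sketch of the logic only; `build_quantize_kwargs` is a hypothetical helper, not part of fms-mo:

```python
def build_quantize_kwargs(quant_method: str, bits=4, group_size=128,
                          desc_act=False, damp_percent=0.01) -> dict:
    """GPTQv1 and GPTQv2 share the base arguments; gptqv2 just adds
    v2=True and v2_memory_device="cpu" before QuantizeConfig is built."""
    kwargs = dict(bits=bits, group_size=group_size,
                  desc_act=desc_act, damp_percent=damp_percent)
    if quant_method == "gptqv2":
        kwargs.update(v2=True, v2_memory_device="cpu")
    return kwargs
```

A helper like this would also remove the duplicated call in `run_gptq`, since `QuantizeConfig(**build_quantize_kwargs(...))` covers both branches.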

fms_mo/training_args.py

Lines changed: 5 additions & 1 deletion
@@ -138,7 +138,10 @@ class OptArguments(TypeChecker):
     """Dataclass for optimization related arguments."""
 
     quant_method: str = field(
-        metadata={"choices": ["gptq", "fp8", "dq"], "help": "Quantization technique"}
+        metadata={
+            "choices": ["gptq", "gptqv2", "fp8", "dq"],
+            "help": "Quantization technique",
+        }
     )
     output_dir: str = field(
         metadata={
@@ -226,6 +229,7 @@ class GPTQArguments(TypeChecker):
     cache_examples_on_gpu: bool = True
 
 
+
 @dataclass
 class FP8Arguments(TypeChecker):
     """Dataclass for FP8 related arguments that will be used by llm-compressor."""

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -23,9 +23,9 @@ classifiers=[
 dynamic = ["version"]
 dependencies = [
     "numpy>=1.26.4,<2.3.0",
-    "accelerate>=0.20.3,!=0.34,<1.9",
+    "accelerate>=0.20.3,!=0.34,<1.10",
     "transformers>=4.45,<4.54",
-    "torch>=2.2.0,<2.6",
+    "torch>=2.2.0,<2.8",
     "tqdm>=4.66.2,<5.0",
     "datasets>=3.0.0,<5.0",
     "pandas",
