
INT4 W4A16 kernel for AWQ #57

Merged
mingfeima merged 19 commits into mingfeima:cpu_opt_ww11 from gau-nernst:w4a16_kernel on May 19, 2025

Conversation

@gau-nernst

@gau-nernst gau-nernst commented Apr 16, 2025

Modifications

Add INT4 W4A16 kernel for AWQ to replace torch._weight_int4pack_mm_for_cpu()

Key features

  • Integer zero point for AWQ (float zero point and no-zero-point variants may need to be supported in the future for other INT4 weight formats)
  • Use a lookup table for fast UINT4->BF16 conversion (see the sketch after this list)
  • Compute-bound case (large M) handled with brgemm: dequantize B, then call brgemm
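
A minimal PyTorch sketch of the intended dequant semantics (my reference-level illustration only, not the actual AVX-512 kernel; the function name dequant_w4a16 and the tensor layouts are assumptions): build a 32-entry BF16 lookup table so that lut[w - z + 15] == w - z, then apply the per-group scale. In the compute-bound path, the dequantized B would then be fed to brgemm (an ordinary matmul would play that role here).

import torch

# lut[i] holds (i - 15) as BF16, so lut[w - z + 15] == w - z for w, z in [0, 15]
lut = torch.arange(-15, 17, dtype=torch.float32).to(torch.bfloat16)  # 32 entries

def dequant_w4a16(qweight_u4, qzeros_u8, scales_bf16):
    # qweight_u4:  (N, K) uint8 holding unpacked 4-bit values in 0..15
    # qzeros_u8:   (N, K // group_size) uint8 integer zero points in 0..15
    # scales_bf16: (N, K // group_size) bfloat16 per-group scales
    group_size = qweight_u4.shape[-1] // qzeros_u8.shape[-1]
    z = qzeros_u8.repeat_interleave(group_size, dim=-1).to(torch.long)
    s = scales_bf16.repeat_interleave(group_size, dim=-1)
    idx = qweight_u4.to(torch.long) - z + 15   # in [0, 30]
    return lut[idx] * s                        # dequantized BF16 weights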

Benchmarks on an Intel Xeon Platinum 8481C (VM access, not bare metal)

python -m sglang.bench_one_batch --batch-size 1 --input 1024 --output 1024 --model TechxGenus/DeepSeek-V2-Lite-Chat-AWQ --device cpu --trust-remote-code --dtype bfloat16

TechxGenus/DeepSeek-V2-Lite-Chat-AWQ (MoE) - 5.3GB active

Impl                               Prefill        Decode
fix_awq (d3e2a509)                 354.75 tok/s   25.00 tok/s
This PR, w/o FusedMoE (5744ff6c)   873.55 tok/s   24.90 tok/s
This PR, w/ FusedMoE (f779332)     964.16 tok/s   33.02 tok/s
  • Prefill is better thanks to AMX (via brgemm); decode only improves once FusedMoE is implemented
  • TODO: INT4 for the shared experts? - but this requires changes in deepseek_v2.py, which can get quite messy

Let me know if you need me to add a correctness test. I ran some quick tests while developing the kernel, as well as a sanity check by inspecting the model outputs.

Checklist

@gau-nernst gau-nernst marked this pull request as ready for review May 6, 2025 02:30
@mingfeima
Owner

@Xia-Weiwen could you help review this one?

  • I remember the PR changed the zero point to float, which may harm accuracy a little (I might be mistaken about this). Please check whether we can follow the strict AWQ int4 format.
  • Please evaluate the performance on our machines: achieved memory bandwidth versus int8 on both DDR5 and MRDIMM. I will help check the kernel details later once I have the bandwidth.
  • Examine whether we can merge the quantization recipe with existing ones instead of creating a new one called Int4CPULinearMethod; this might cause trouble when upstreaming. Ideally we follow the existing FP8 impl: the same quant config, and on a CPU device it goes to a CPU kernel impl rather than the CUDA ones.
  • One thing I am concerned about is how to expand to other quant formats, e.g. GGUF. You may propose how we can do this elegantly if a refactor is needed.

@gau-nernst
Author

I can comment a bit on point 3 about Int4CPULinearMethod

  • cpu_opt_ww11 is based on an SGLang commit that relies on vLLM for AWQ (link), so we can't directly add a new CPU kernel to it. Hence, I added a separate quant config/method to override AWQ, as you can see here.
  • The latest SGLang main has AWQConfig directly in its codebase (link), so it would be possible to integrate there directly.

On point 4, how to expand to other quant formats:

  • Apart from supporting GGUF, I'm also interested in supporting compressed-tensors, e.g. ISTA-DASLab/DeepSeek-R1-GPTQ-4b-128g-experts
  • One approach is to add CPU kernels (and weight-packing logic) to the existing GGUFConfig and CompressedTensorsConfig (currently these live in vLLM; SGLang doesn't have them, so SGLang would need to make a copy first)
  • Another approach is creating a separate "quantization config" that can override an existing quantization. vLLM IPEX uses this approach (link); a rough sketch of that override pattern follows below.
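
A rough, self-contained sketch of the override pattern (purely illustrative: the AWQConfig and Int4CPULinearMethod classes below are stand-in stubs with made-up methods, not the real vLLM/SGLang API):

# Stand-in for the existing AWQ quant config shipped with vLLM/SGLang.
class AWQConfig:
    def get_quant_method(self, layer):
        return "default_awq_linear_method"  # e.g. the CUDA path

# Stand-in for this PR's CPU W4A16 linear method.
class Int4CPULinearMethod:
    def __init__(self, quant_config):
        self.quant_config = quant_config

# The override: reuse AWQ checkpoint parsing/weight loading, but dispatch
# linear layers to the CPU kernel instead of the default implementation.
class CPUAWQConfig(AWQConfig):
    def get_quant_method(self, layer):
        if getattr(layer, "is_linear", False):
            return Int4CPULinearMethod(self)
        return super().get_quant_method(layer)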


qzeros = (qzeros.unsqueeze(-1) >> bitshifts) & 0xF
qzeros = qzeros.flatten(-2).to(torch.uint8)
return qweight, qzeros, scales

By convention, scales is put before qzeros. How about reordering them?

@Xia-Weiwen

I remember the PR changed the zero point to float, which may harm accuracy a little (I might be mistaken about this). Please check whether we can follow the strict AWQ int4 format.

This PR uses integer zero points, which aligns with the AWQ format.

@Xia-Weiwen

Xia-Weiwen commented May 9, 2025

Please evaluate the performance on our machines: achieved memory bandwidth versus int8 on both DDR5 and MRDIMM.

@MingxuZh will help collect data. We will review the data internally first. Thanks.

@Xia-Weiwen

Examine whether we can merge the quantization recipe with existing ones instead of creating a new one called Int4CPULinearMethod; this might cause trouble when upstreaming. Ideally we follow the existing FP8 impl: the same quant config, and on a CPU device it goes to a CPU kernel impl rather than the CUDA ones.
One thing I am concerned about is how to expand to other quant formats, e.g. GGUF. You may propose how we can do this elegantly if a refactor is needed.

For points 3 and 4, I think the comments from @gau-nernst are very helpful, and I believe he has thought about them seriously. Since I am not very familiar with the vLLM or SGLang code, I'm afraid I cannot comment further on these issues. Thanks.

@mingfeima
Owner

We tested the performance numbers and found that when the number of concurrent requests is small, the improvement of int4 over int8 is not large, but it grows as the number of concurrent requests increases. This was a little outside my expectation; anyway, I will debug deeper to find out why.

Owner

@mingfeima mingfeima left a comment

Merging this one now. I will help refine the kernels later on. We verified that the functionality is OK; performance-wise it still needs some work.

Thanks for the contribution! @gau-nernst

@mingfeima mingfeima merged commit 89737d3 into mingfeima:cpu_opt_ww11 May 19, 2025
1 check passed
@gau-nernst gau-nernst deleted the w4a16_kernel branch May 19, 2025 02:38
jianan-gu referenced this pull request in jianan-gu/sglang May 28, 2025
* fix
* fix AWQ for DSv3
* don't use absorb MLA for AWQ
* lint
* more fixes
* add w4a16 kernel
* remove unnecessary name
* add note
* use prefetch. simplify impl
* clean up. add brgemm (WIP)
* fix brgemm
* fix mismatch BLOCK_N
* use at::quint4x2 to signify type better
* change type of zero point back to uint8
* add FusedMoE interface
* use FusedMoE kernel
* fix types
* fix MoE
* update deepseek.cpp
jianan-gu referenced this pull request in jianan-gu/sglang Jun 12, 2025
CaoE pushed a commit to CaoE/sglang that referenced this pull request Aug 14, 2025
@jiqing-feng

Hi @gau-nernst. Thanks for your amazing work! I am a little curious why the bf16_lut is

(0x0000, 0x4170, 0x4160, 0x4150, 0x4140, 0x4130, 0x4120, 0x4110,
0x4100, 0x40E0, 0x40C0, 0x40A0, 0x4080, 0x4040, 0x4000, 0x3F80,
0x0000,-0x4080,-0x4000,-0x3FC0,-0x3F80,-0x3F60,-0x3F40,-0x3F20,
-0x3F00,-0x3EF0,-0x3EE0,-0x3ED0,-0x3EC0,-0x3EB0,-0x3EA0,-0x3E90)

If I didn't misunderstand it, the visible values should be

(0.0f, 15.0f, 14.0f, 13.0f, 12.0f, 11.0f, 10.0f, 9.0f,
  8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f,
  0.0f, -4.0f, -3.0f, -3.5f, -1.0f, -0.875f, -0.75f, -0.625f,
  -0.5f, -0.46875f, -0.4375f, -0.40625f, -0.375f, -0.34375f, -0.3125f, -0.28125f)

(w - z) should range over [-15, 15], so (w - z + 15) should range over [0, 30]. I understand the positive values, which map 15-30 to 0-15, but I don't understand why the negative values range over [-4, 0] instead of [-15, 0]. I'd appreciate it if you can give me some pointers. Thanks!

@gau-nernst
Author

@jiqing-feng

import torch

x = torch.tensor([
    0x0000, 0x4170, 0x4160, 0x4150, 0x4140, 0x4130, 0x4120, 0x4110,
    0x4100, 0x40E0, 0x40C0, 0x40A0, 0x4080, 0x4040, 0x4000, 0x3F80,
    0x0000,-0x4080,-0x4000,-0x3FC0,-0x3F80,-0x3F60,-0x3F40,-0x3F20,
    -0x3F00,-0x3EF0,-0x3EE0,-0x3ED0,-0x3EC0,-0x3EB0,-0x3EA0,-0x3E90
], dtype=torch.int16)
print(x.view(torch.bfloat16).view(2, -1))
# Output:
# tensor([[  0.,  15.,  14.,  13.,  12.,  11.,  10.,   9.,   8.,   7.,   6.,   5.,
#            4.,   3.,   2.,   1.],
#         [  0.,  -1.,  -2.,  -3.,  -4.,  -5.,  -6.,  -7.,  -8.,  -9., -10., -11.,
#          -12., -13., -14., -15.]], dtype=torch.bfloat16)

I wrote this a long time ago, so I don't remember the exact details. From the snippet above, it looks like it's working correctly?

I think negative hexadecimal numbers can be confusing. I can't remember exactly how I obtained the LUT, but I think because _mm512_set_epi16() takes signed integers, I had to use negative int16 values to obtain the desired bit patterns. A possibly more readable way is to have a uint16_t[32] global array and then do a 512-bit load.
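
As a quick sanity check of those signed constants (an illustration added here, not code from the kernel), they can be regenerated from the BF16 bit patterns of the intended values 0, 15..1, 0, -1..-15:

import torch

values = [0.0] + [float(v) for v in range(15, 0, -1)] \
       + [0.0] + [float(-v) for v in range(1, 16)]
bits = torch.tensor(values, dtype=torch.bfloat16).view(torch.int16)
# Prints the same signed hex constants as the LUT above (modulo zero-padding),
# e.g. 0x4170 for 15.0 and -0x4080 for -1.0.
print([hex(b) for b in bits.tolist()])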

@jiqing-feng

Hi @gau-nernst. I made a mistake when converting the BF16 values to FP32. You were right. Thanks for your help!

@jiqing-feng

Hi @gau-nernst. We are going to adopt your amazing work in bitsandbytes. Could you please provide guidance on the licensing/attribution we need for the code here? Thanks!
