19 changes: 15 additions & 4 deletions lmdeploy/cli/utils.py
@@ -269,12 +269,23 @@ def quant_policy(parser, default: int = 0):

from lmdeploy.messages import QuantPolicy

_aliases = {p.name.lower(): p.value for p in QuantPolicy}
_aliases['fp8_e4m3'] = QuantPolicy.FP8.value

def _parse(x):
key = x.lower()
if key in _aliases:
return _aliases[key]
v = int(x)
if v not in list(QuantPolicy):
raise ValueError(f'invalid quant_policy: {x!r}')
return v

return parser.add_argument('--quant-policy',
type=int,
type=_parse,
default=0,
choices=list(QuantPolicy),
help='KV cache quantization policy. '
'0: no quantization; 4: 4-bit; 8: 8-bit; 42: TurboQuant (K4V2)')
help='KV cache quant policy: none/int4/int8/fp8/fp8_e5m2/'
'turbo_quant (or 0/4/8/16/17/42). fp8 is an alias for fp8_e4m3.')

@staticmethod
def rope_scaling_factor(parser):
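For reference, the alias resolution above can be exercised standalone. This sketch inlines the `QuantPolicy` enum from `lmdeploy/messages.py` (as extended in this PR) so it runs without lmdeploy installed:

```python
# Standalone sketch of the new --quant-policy parsing (mirrors the diff above).
import argparse
import enum


class QuantPolicy(enum.IntEnum):
    NONE = 0
    INT4 = 4
    INT8 = 8
    FP8 = 16
    FP8_E5M2 = 17
    TURBO_QUANT = 42


_aliases = {p.name.lower(): p.value for p in QuantPolicy}
_aliases['fp8_e4m3'] = QuantPolicy.FP8.value  # 'fp8' and 'fp8_e4m3' both map to 16


def _parse(x):
    key = x.lower()
    if key in _aliases:
        return _aliases[key]
    v = int(x)  # unknown strings raise ValueError, which argparse reports
    if v not in list(QuantPolicy):  # IntEnum members compare equal to ints
        raise ValueError(f'invalid quant_policy: {x!r}')
    return v


parser = argparse.ArgumentParser()
parser.add_argument('--quant-policy', type=_parse, default=0)

assert parser.parse_args(['--quant-policy', 'fp8']).quant_policy == 16
assert parser.parse_args(['--quant-policy', 'FP8_E5M2']).quant_policy == 17
assert parser.parse_args(['--quant-policy', '42']).quant_policy == 42
```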
26 changes: 18 additions & 8 deletions lmdeploy/messages.py
@@ -22,6 +22,8 @@ class QuantPolicy(enum.IntEnum):
NONE = 0
INT4 = 4 # 4-bit KV cache
INT8 = 8 # 8-bit KV cache
FP8 = 16 # FP8 KV cache (float8_e4m3fn, per-tensor scale)
FP8_E5M2 = 17 # FP8 KV cache (float8_e5m2, per-tensor scale)
TURBO_QUANT = 42 # TurboQuant: K=4bit QJL4 + V=2bit MSE

LogitsProcessor = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
@@ -242,8 +244,9 @@ class TurbomindEngineConfig:
a k/v block, default to 64
enable_prefix_caching: enable cache prompts for block reuse,
default to False
quant_policy: default to 0. When k/v is quantized into 4 or 8
bit, set it to 4 or 8, respectively
quant_policy: default to 0. For TurboMind, when k/v is quantized
into int4, int8, or fp8, set it to 4, 8, or 16,
respectively; TurboQuant (42) is also accepted
rope_scaling_factor: scaling factor used for dynamic ntk,
default to 0. TurboMind follows the implementation of transformer
LlamaAttention
@@ -311,8 +314,8 @@ def __post_init__(self):
assert self.dtype in ['auto', 'float16', 'bfloat16']
assert self.tp >= 1, 'tp must be a positive integer'
assert self.cache_max_entry_count > 0, 'invalid cache_max_entry_count'
assert self.quant_policy in (QuantPolicy.NONE, QuantPolicy.INT4, QuantPolicy.INT8, QuantPolicy.TURBO_QUANT), \
'invalid quant_policy'
assert self.quant_policy in (QuantPolicy.NONE, QuantPolicy.INT4, QuantPolicy.INT8, QuantPolicy.FP8,
QuantPolicy.TURBO_QUANT), 'invalid quant_policy'
assert self.rope_scaling_factor >= 0, 'invalid rope_scaling_factor'
assert self.max_prefill_token_num >= 0, \
'invalid max_prefill_token_num'
@@ -364,8 +367,9 @@ class PytorchEngineConfig:
revision: The specific model version to use.
It can be a branch name, a tag name, or a commit id.
If unspecified, will use the default version.
quant_policy: default to 0. When k/v is quantized into 4 or 8
bit, set it to 4 or 8, respectively
quant_policy: default to 0. When k/v is quantized into int4,
int8, fp8, or fp8_e5m2, set it to 4, 8, 16, or 17,
respectively; TurboQuant (42) is also accepted
distributed_executor_backend: backend of the distributed executor,
options: ['uni', 'mp', 'ray']
empty_init: Whether to load the model weights, you should set
@@ -457,8 +461,14 @@ def __post_init__(self):
assert self.max_prefill_token_num >= 0, \
'invalid max_prefill_token_num'
assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks'
assert self.quant_policy in (QuantPolicy.NONE, QuantPolicy.INT4, QuantPolicy.INT8, QuantPolicy.TURBO_QUANT), \
'invalid quant_policy'
assert self.quant_policy in (
QuantPolicy.NONE,
QuantPolicy.INT4,
QuantPolicy.INT8,
QuantPolicy.FP8,
QuantPolicy.FP8_E5M2,
QuantPolicy.TURBO_QUANT,
), 'invalid quant_policy'
assert self.device_type in ['cuda', 'ascend', 'maca', 'camb'], (f'invalid device_type: {self.device_type}')
assert self.kernel_block_size >= 16 and \
(self.kernel_block_size & (self.kernel_block_size - 1)) == 0, \
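A minimal usage sketch of the widened validation, assuming the remaining dataclass fields keep their current defaults:

```python
# Sketch: the PyTorch engine now accepts both fp8 policies, while
# TurboMind accepts fp8_e4m3 (16) but not fp8_e5m2 (17).
from lmdeploy.messages import (PytorchEngineConfig, QuantPolicy,
                               TurbomindEngineConfig)

pt_cfg = PytorchEngineConfig(quant_policy=QuantPolicy.FP8_E5M2)  # ok (17)
tm_cfg = TurbomindEngineConfig(quant_policy=QuantPolicy.FP8)     # ok (16)

try:
    TurbomindEngineConfig(quant_policy=QuantPolicy.FP8_E5M2)
except AssertionError:
    pass  # 'invalid quant_policy': e5m2 is not in TurboMind's allow-list
```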
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/backends/attention.py
@@ -98,6 +98,8 @@ def forward(
attn_metadata: T,
k_scales_zeros: torch.Tensor = None,
v_scales_zeros: torch.Tensor = None,
k_scale: torch.Tensor = None,
v_scale: torch.Tensor = None,
learnable_sink: torch.Tensor = None,
nsa_indices: torch.Tensor = None,
inplace: bool = False,
28 changes: 27 additions & 1 deletion lmdeploy/pytorch/backends/cuda/attention/default.py
@@ -24,7 +24,7 @@ class TritonAttentionMetadata(AttentionMetadata):
q_seqlens: Length of each query sequence [batch_size].
kv_start_loc: Start location of each KV sequence [batch_size].
kv_seqlens: Length of each KV sequence [batch_size].
quant_policy: Quantization policy (0=none, 4=int4, 8=int8/fp8).
quant_policy: Quantization policy (0=none, 4=int4, 8=int8, 16/17=per-tensor fp8).
kv_flatten_size: Total size of flattened KV cache.
tile_scheduler_metadata: Scheduler metadata for Flash MLA.
num_splits: Number of splits for Flash MLA.
@@ -149,6 +149,8 @@ def _fill_kv_cache_impl(
max_q_seqlen: int,
k_scales_zeros: torch.Tensor = None,
v_scales_zeros: torch.Tensor = None,
k_scale: torch.Tensor = None,
v_scale: torch.Tensor = None,
):
"""Fill kv cache."""
kv_seqlens = attn_metadata.kv_seqlens
@@ -175,6 +177,8 @@ def _fill_kv_cache_impl(
block_offsets=block_offsets,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
k_scale=k_scale,
v_scale=v_scale,
quant_policy=quant_policy,
)

@@ -187,6 +191,8 @@ def _forward_decoding(
max_q_seqlen: int,
k_scales_zeros: torch.Tensor = None,
v_scales_zeros: torch.Tensor = None,
k_scale: torch.Tensor = None,
v_scale: torch.Tensor = None,
learnable_sink: torch.Tensor = None,
) -> torch.Tensor:
"""Forward pass for decoding stage.
@@ -199,6 +205,8 @@ def _forward_decoding(
max_q_seqlen: Maximum query sequence length.
k_scales_zeros: Key quantization scales/zeros.
v_scales_zeros: Value quantization scales/zeros.
k_scale: Per-tensor key scale for normal FP8 KV cache.
v_scale: Per-tensor value scale for normal FP8 KV cache.
learnable_sink: Learnable sink tokens.

Returns:
@@ -224,6 +232,8 @@ def _forward_decoding(
quant_policy=quant_policy,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
k_scale=k_scale,
v_scale=v_scale,
)
return attn_output

@@ -236,6 +246,8 @@ def _forward_prefill(
max_q_seqlen: int,
k_scales_zeros: torch.Tensor = None,
v_scales_zeros: torch.Tensor = None,
k_scale: torch.Tensor = None,
v_scale: torch.Tensor = None,
learnable_sink: torch.Tensor = None,
) -> torch.Tensor:
"""Forward pass for prefill stage.
@@ -248,6 +260,8 @@ def _forward_prefill(
max_q_seqlen: Maximum query sequence length.
k_scales_zeros: Key quantization scales/zeros.
v_scales_zeros: Value quantization scales/zeros.
k_scale: Per-tensor key scale for normal FP8 KV cache.
v_scale: Per-tensor value scale for normal FP8 KV cache.
learnable_sink: Learnable sink tokens.

Returns:
@@ -275,6 +289,8 @@ def _forward_prefill(
out_dtype=query.dtype,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
k_scale=k_scale,
v_scale=v_scale,
quant_policy=quant_policy,
flatten_kv_layout=kv_layout,
)
@@ -323,6 +339,8 @@ def forward(
attn_metadata: TritonAttentionMetadata,
k_scales_zeros: torch.Tensor = None,
v_scales_zeros: torch.Tensor = None,
k_scale: torch.Tensor = None,
v_scale: torch.Tensor = None,
learnable_sink: torch.Tensor = None,
inplace: bool = True,
**kwargs,
@@ -343,6 +361,8 @@ def forward(
attn_metadata: Attention metadata containing stage info and indices.
k_scales_zeros: Key quantization scales/zeros.
v_scales_zeros: Value quantization scales/zeros.
k_scale: Per-tensor key scale for normal FP8 KV cache.
v_scale: Per-tensor value scale for normal FP8 KV cache.
learnable_sink: Learnable sink tokens.
inplace: Whether to modify query inplace (unused, kept for compatibility).

@@ -363,6 +383,8 @@ def forward(
max_q_seqlen=max_q_seqlen,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
k_scale=k_scale,
v_scale=v_scale,
)

# Validate alibi configuration
@@ -379,6 +401,8 @@ def forward(
max_q_seqlen,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
k_scale=k_scale,
v_scale=v_scale,
learnable_sink=learnable_sink,
)
else:
@@ -390,5 +414,7 @@ def forward(
max_q_seqlen,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
k_scale=k_scale,
v_scale=v_scale,
learnable_sink=learnable_sink,
)
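The new `k_scale`/`v_scale` arguments carry a single per-tensor scale for the fp8 KV cache, unlike the `k_scales_zeros`/`v_scales_zeros` pair used by the int4/int8 paths. Below is a back-of-envelope sketch of per-tensor fp8 (e4m3) round-tripping in plain PyTorch; this is not lmdeploy's kernel code, which receives the scales through the arguments threaded above:

```python
# Back-of-envelope sketch of per-tensor fp8 KV quantization, i.e. what a
# scalar k_scale/v_scale means. Requires PyTorch >= 2.1 for float8 dtypes.
import torch

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def quantize_per_tensor(x: torch.Tensor, fp8_max: float = E4M3_MAX):
    """Quantize with one scale for the whole tensor."""
    scale = x.abs().amax().float() / fp8_max
    scale = torch.clamp(scale, min=1e-12)  # guard against an all-zero tensor
    q = (x.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q, scale


def dequantize_per_tensor(q: torch.Tensor, scale: torch.Tensor,
                          dtype=torch.float16) -> torch.Tensor:
    return (q.float() * scale).to(dtype)


k = torch.randn(2, 64, 8, 128, dtype=torch.float16)  # toy KV block
k_q, k_scale = quantize_per_tensor(k)
k_hat = dequantize_per_tensor(k_q, k_scale)
print((k - k_hat).abs().max())  # small fp8 rounding error
```

Per-tensor scaling keeps the bookkeeping to two scalars, which is why a single `k_scale`/`v_scale` argument suffices here.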
28 changes: 27 additions & 1 deletion lmdeploy/pytorch/backends/cuda/attention/fa3.py
@@ -146,6 +146,8 @@ def _decoding_standard(
max_q_seqlen: int,
k_scales_zeros: torch.Tensor = None,
v_scales_zeros: torch.Tensor = None,
k_scale: torch.Tensor = None,
v_scale: torch.Tensor = None,
) -> torch.Tensor:
"""Standard single-token decoding.

@@ -160,6 +162,8 @@ def _decoding_standard(
max_q_seqlen: Maximum query sequence length (= 1).
k_scales_zeros: Key quantization scales/zeros.
v_scales_zeros: Value quantization scales/zeros.
k_scale: Scalar key scale for normal FP8 KV cache.
v_scale: Scalar value scale for normal FP8 KV cache.

Returns:
Attention output tensor.
@@ -183,6 +187,8 @@ def _decoding_standard(
# custom args
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
k_scale=k_scale,
v_scale=v_scale,
quant_policy=quant_policy,
)
return attn_output
@@ -196,6 +202,8 @@ def _forward_decoding(
max_q_seqlen: int,
k_scales_zeros: torch.Tensor = None,
v_scales_zeros: torch.Tensor = None,
k_scale: torch.Tensor = None,
v_scale: torch.Tensor = None,
) -> torch.Tensor:
"""Forward pass for decoding stage.

@@ -211,6 +219,8 @@ def _forward_decoding(
max_q_seqlen: Maximum query sequence length.
k_scales_zeros: Key quantization scales/zeros.
v_scales_zeros: Value quantization scales/zeros.
k_scale: Scalar key scale for normal FP8 KV cache.
v_scale: Scalar value scale for normal FP8 KV cache.

Returns:
Attention output tensor.
@@ -219,7 +229,7 @@ def _forward_decoding(
return self._decoding_speculative(query, k_cache, v_cache, attn_metadata, max_q_seqlen)
else:
return self._decoding_standard(query, k_cache, v_cache, attn_metadata, max_q_seqlen, k_scales_zeros,
v_scales_zeros)
v_scales_zeros, k_scale, v_scale)

def _forward_prefill(
self,
@@ -230,6 +240,8 @@ def _forward_prefill(
max_q_seqlen: int,
k_scales_zeros: torch.Tensor = None,
v_scales_zeros: torch.Tensor = None,
k_scale: torch.Tensor = None,
v_scale: torch.Tensor = None,
) -> torch.Tensor:
"""Forward pass for prefill stage.

@@ -244,6 +256,8 @@ def _forward_prefill(
max_q_seqlen: Maximum query sequence length.
k_scales_zeros: Key quantization scales/zeros.
v_scales_zeros: Value quantization scales/zeros.
k_scale: Scalar key scale for normal FP8 KV cache.
v_scale: Scalar value scale for normal FP8 KV cache.

Returns:
Attention output tensor.
Expand All @@ -265,6 +279,8 @@ def _forward_prefill(
out_dtype=query.dtype,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
k_scale=k_scale,
v_scale=v_scale,
quant_policy=quant_policy,
flatten_kv_layout='shd',
)
@@ -310,6 +326,8 @@ def forward(
attn_metadata: TritonAttentionMetadata,
k_scales_zeros: torch.Tensor = None,
v_scales_zeros: torch.Tensor = None,
k_scale: torch.Tensor = None,
v_scale: torch.Tensor = None,
learnable_sink: torch.Tensor = None,
inplace: bool = True,
) -> torch.Tensor:
@@ -333,6 +351,8 @@ def forward(
attn_metadata: Attention metadata containing stage info and indices.
k_scales_zeros: Key quantization scales/zeros.
v_scales_zeros: Value quantization scales/zeros.
k_scale: Scalar key scale for normal FP8 KV cache.
v_scale: Scalar value scale for normal FP8 KV cache.
learnable_sink: Learnable sink tokens (unused in FA3).
inplace: Whether to modify query inplace (unused, kept for compatibility).

@@ -353,6 +373,8 @@ def forward(
max_q_seqlen=max_q_seqlen,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
k_scale=k_scale,
v_scale=v_scale,
)

# Dispatch to stage-specific forward method
@@ -365,6 +387,8 @@ def forward(
max_q_seqlen,
k_scales_zeros,
v_scales_zeros,
k_scale,
v_scale,
)
else:
return self._forward_prefill(
@@ -375,4 +399,6 @@ def forward(
max_q_seqlen,
k_scales_zeros,
v_scales_zeros,
k_scale,
v_scale,
)
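For context on why policies 16 and 17 both exist: the two fp8 formats trade precision for range. The numbers below come from `torch.finfo` and are properties of the formats themselves, not of lmdeploy:

```python
# e4m3 keeps more mantissa bits (precision); e5m2 keeps more exponent
# bits (range). Requires PyTorch >= 2.1 for float8 dtypes.
import torch

for dt in (torch.float8_e4m3fn, torch.float8_e5m2):
    fi = torch.finfo(dt)
    print(f'{dt}: max={fi.max}, smallest_normal={fi.tiny}, eps={fi.eps}')

# float8_e4m3fn: max=448.0   (default 'fp8' alias -> policy 16)
# float8_e5m2:   max=57344.0 (policy 17)
```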
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/attention.py
@@ -66,6 +66,8 @@ def forward(
attn_metadata: DlinferAttentionMetadata,
k_scales_zeros: Tensor = None,
v_scales_zeros: Tensor = None,
k_scale: Tensor = None,
v_scale: Tensor = None,
learnable_sink: Tensor = None,
nsa_indices: Tensor = None,
inplace: bool = True,
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/config.py
@@ -390,7 +390,7 @@ def from_pretrained(
activations. Refer to `PytorchEngineConfig` for details
hf_overrides (dict[str, Any]): overrides for the HF config.
"""
from transformers import AutoConfig
from transformers import AutoConfig # noqa: I001

from lmdeploy.pytorch.transformers import config_from_pretrained
hf_config = config_from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)