diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index abf9419d35..5d448f5ed1 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -269,12 +269,23 @@ def quant_policy(parser, default: int = 0): from lmdeploy.messages import QuantPolicy + _aliases = {p.name.lower(): p.value for p in QuantPolicy} + _aliases['fp8_e4m3'] = QuantPolicy.FP8.value + + def _parse(x): + key = x.lower() + if key in _aliases: + return _aliases[key] + v = int(x) + if v not in list(QuantPolicy): + raise ValueError(f'invalid quant_policy: {x!r}') + return v + return parser.add_argument('--quant-policy', - type=int, + type=_parse, default=0, - choices=list(QuantPolicy), - help='KV cache quantization policy. ' - '0: no quantization; 4: 4-bit; 8: 8-bit; 42: TurboQuant (K4V2)') + help='KV cache quant policy: none/int4/int8/fp8/fp8_e5m2/' + 'turbo_quant (or 0/4/8/16/17/42). fp8 defaults to fp8_e4m3.') @staticmethod def rope_scaling_factor(parser): diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 13c95d2be6..ec3240e3ca 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -22,6 +22,8 @@ class QuantPolicy(enum.IntEnum): NONE = 0 INT4 = 4 # 4-bit KV cache INT8 = 8 # 8-bit KV cache + FP8 = 16 # FP8 KV cache (float8_e4m3fn, per-tensor scale) + FP8_E5M2 = 17 # FP8 KV cache (float8_e5m2, per-tensor scale) TURBO_QUANT = 42 # TurboQuant: K=4bit QJL4 + V=2bit MSE LogitsProcessor = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] @@ -242,8 +244,9 @@ class TurbomindEngineConfig: a k/v block, default to 64 enable_prefix_caching: enable cache prompts for block reuse, default to False - quant_policy: default to 0. When k/v is quantized into 4 or 8 - bit, set it to 4 or 8, respectively + quant_policy: default to 0. For TurboMind, when k/v is quantized + into int4, int8, or fp8, set it to 4, 8, or 16, + respectively rope_scaling_factor: scaling factor used for dynamic ntk, default to 0. TurboMind follows the implementation of transformer LlamaAttention @@ -311,8 +314,8 @@ def __post_init__(self): assert self.dtype in ['auto', 'float16', 'bfloat16'] assert self.tp >= 1, 'tp must be a positive integer' assert self.cache_max_entry_count > 0, 'invalid cache_max_entry_count' - assert self.quant_policy in (QuantPolicy.NONE, QuantPolicy.INT4, QuantPolicy.INT8, QuantPolicy.TURBO_QUANT), \ - 'invalid quant_policy' + assert self.quant_policy in (QuantPolicy.NONE, QuantPolicy.INT4, QuantPolicy.INT8, QuantPolicy.FP8, + QuantPolicy.TURBO_QUANT), 'invalid quant_policy' assert self.rope_scaling_factor >= 0, 'invalid rope_scaling_factor' assert self.max_prefill_token_num >= 0, \ 'invalid max_prefill_token_num' @@ -364,8 +367,9 @@ class PytorchEngineConfig: revision: The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. - quant_policy: default to 0. When k/v is quantized into 4 or 8 - bit, set it to 4 or 8, respectively + quant_policy: default to 0. 
When k/v is quantized into int4, + int8, fp8, or fp8_e5m2, set it to 4, 8, 16, or 17, + respectively distributed_executor_backend: backend of distributed backend, options: ['uni', 'mp', 'ray'] empty_init: Whether to load the model weights, you should set @@ -457,8 +461,14 @@ def __post_init__(self): assert self.max_prefill_token_num >= 0, \ 'invalid max_prefill_token_num' assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks' - assert self.quant_policy in (QuantPolicy.NONE, QuantPolicy.INT4, QuantPolicy.INT8, QuantPolicy.TURBO_QUANT), \ - 'invalid quant_policy' + assert self.quant_policy in ( + QuantPolicy.NONE, + QuantPolicy.INT4, + QuantPolicy.INT8, + QuantPolicy.FP8, + QuantPolicy.FP8_E5M2, + QuantPolicy.TURBO_QUANT, + ), 'invalid quant_policy' assert self.device_type in ['cuda', 'ascend', 'maca', 'camb'], (f'invalid device_type: {self.device_type}') assert self.kernel_block_size >= 16 and \ (self.kernel_block_size & (self.kernel_block_size - 1)) == 0, \ diff --git a/lmdeploy/pytorch/backends/attention.py b/lmdeploy/pytorch/backends/attention.py index 29b9f27290..ad90488311 100644 --- a/lmdeploy/pytorch/backends/attention.py +++ b/lmdeploy/pytorch/backends/attention.py @@ -98,6 +98,8 @@ def forward( attn_metadata: T, k_scales_zeros: torch.Tensor = None, v_scales_zeros: torch.Tensor = None, + k_scale: torch.Tensor = None, + v_scale: torch.Tensor = None, learnable_sink: torch.Tensor = None, nsa_indices: torch.Tensor = None, inplace: bool = False, diff --git a/lmdeploy/pytorch/backends/cuda/attention/default.py b/lmdeploy/pytorch/backends/cuda/attention/default.py index 5f9f97f0ec..7d50db869a 100644 --- a/lmdeploy/pytorch/backends/cuda/attention/default.py +++ b/lmdeploy/pytorch/backends/cuda/attention/default.py @@ -24,7 +24,7 @@ class TritonAttentionMetadata(AttentionMetadata): q_seqlens: Length of each query sequence [batch_size]. kv_start_loc: Start location of each KV sequence [batch_size]. kv_seqlens: Length of each KV sequence [batch_size]. - quant_policy: Quantization policy (0=none, 4=int4, 8=int8/fp8). + quant_policy: Quantization policy (0=none, 4=int4, 8=int8, 16/17=per-tensor fp8). kv_flatten_size: Total size of flattened KV cache. tile_scheduler_metadata: Scheduler metadata for Flash MLA. num_splits: Number of splits for Flash MLA. @@ -149,6 +149,8 @@ def _fill_kv_cache_impl( max_q_seqlen: int, k_scales_zeros: torch.Tensor = None, v_scales_zeros: torch.Tensor = None, + k_scale: torch.Tensor = None, + v_scale: torch.Tensor = None, ): """Fill kv cache.""" kv_seqlens = attn_metadata.kv_seqlens @@ -175,6 +177,8 @@ def _fill_kv_cache_impl( block_offsets=block_offsets, k_scales_zeros=k_scales_zeros, v_scales_zeros=v_scales_zeros, + k_scale=k_scale, + v_scale=v_scale, quant_policy=quant_policy, ) @@ -187,6 +191,8 @@ def _forward_decoding( max_q_seqlen: int, k_scales_zeros: torch.Tensor = None, v_scales_zeros: torch.Tensor = None, + k_scale: torch.Tensor = None, + v_scale: torch.Tensor = None, learnable_sink: torch.Tensor = None, ) -> torch.Tensor: """Forward pass for decoding stage. @@ -199,6 +205,8 @@ def _forward_decoding( max_q_seqlen: Maximum query sequence length. k_scales_zeros: Key quantization scales/zeros. v_scales_zeros: Value quantization scales/zeros. + k_scale: Per-tensor key scale for normal FP8 KV cache. + v_scale: Per-tensor value scale for normal FP8 KV cache. learnable_sink: Learnable sink tokens. 
Returns: @@ -224,6 +232,8 @@ def _forward_decoding( quant_policy=quant_policy, k_scales_zeros=k_scales_zeros, v_scales_zeros=v_scales_zeros, + k_scale=k_scale, + v_scale=v_scale, ) return attn_output @@ -236,6 +246,8 @@ def _forward_prefill( max_q_seqlen: int, k_scales_zeros: torch.Tensor = None, v_scales_zeros: torch.Tensor = None, + k_scale: torch.Tensor = None, + v_scale: torch.Tensor = None, learnable_sink: torch.Tensor = None, ) -> torch.Tensor: """Forward pass for prefill stage. @@ -248,6 +260,8 @@ def _forward_prefill( max_q_seqlen: Maximum query sequence length. k_scales_zeros: Key quantization scales/zeros. v_scales_zeros: Value quantization scales/zeros. + k_scale: Per-tensor key scale for normal FP8 KV cache. + v_scale: Per-tensor value scale for normal FP8 KV cache. learnable_sink: Learnable sink tokens. Returns: @@ -275,6 +289,8 @@ def _forward_prefill( out_dtype=query.dtype, k_scales_zeros=k_scales_zeros, v_scales_zeros=v_scales_zeros, + k_scale=k_scale, + v_scale=v_scale, quant_policy=quant_policy, flatten_kv_layout=kv_layout, ) @@ -323,6 +339,8 @@ def forward( attn_metadata: TritonAttentionMetadata, k_scales_zeros: torch.Tensor = None, v_scales_zeros: torch.Tensor = None, + k_scale: torch.Tensor = None, + v_scale: torch.Tensor = None, learnable_sink: torch.Tensor = None, inplace: bool = True, **kwargs, @@ -343,6 +361,8 @@ def forward( attn_metadata: Attention metadata containing stage info and indices. k_scales_zeros: Key quantization scales/zeros. v_scales_zeros: Value quantization scales/zeros. + k_scale: Per-tensor key scale for normal FP8 KV cache. + v_scale: Per-tensor value scale for normal FP8 KV cache. learnable_sink: Learnable sink tokens. inplace: Whether to modify query inplace (unused, kept for compatibility). @@ -363,6 +383,8 @@ def forward( max_q_seqlen=max_q_seqlen, k_scales_zeros=k_scales_zeros, v_scales_zeros=v_scales_zeros, + k_scale=k_scale, + v_scale=v_scale, ) # Validate alibi configuration @@ -379,6 +401,8 @@ def forward( max_q_seqlen, k_scales_zeros=k_scales_zeros, v_scales_zeros=v_scales_zeros, + k_scale=k_scale, + v_scale=v_scale, learnable_sink=learnable_sink, ) else: @@ -390,5 +414,7 @@ def forward( max_q_seqlen, k_scales_zeros=k_scales_zeros, v_scales_zeros=v_scales_zeros, + k_scale=k_scale, + v_scale=v_scale, learnable_sink=learnable_sink, ) diff --git a/lmdeploy/pytorch/backends/cuda/attention/fa3.py b/lmdeploy/pytorch/backends/cuda/attention/fa3.py index 2a37866da9..3fba8694ac 100644 --- a/lmdeploy/pytorch/backends/cuda/attention/fa3.py +++ b/lmdeploy/pytorch/backends/cuda/attention/fa3.py @@ -146,6 +146,8 @@ def _decoding_standard( max_q_seqlen: int, k_scales_zeros: torch.Tensor = None, v_scales_zeros: torch.Tensor = None, + k_scale: torch.Tensor = None, + v_scale: torch.Tensor = None, ) -> torch.Tensor: """Standard single-token decoding. @@ -160,6 +162,8 @@ def _decoding_standard( max_q_seqlen: Maximum query sequence length (= 1). k_scales_zeros: Key quantization scales/zeros. v_scales_zeros: Value quantization scales/zeros. + k_scale: Scalar key scale for normal FP8 KV cache. + v_scale: Scalar value scale for normal FP8 KV cache. Returns: Attention output tensor. 
@@ -183,6 +187,8 @@ def _decoding_standard( # custom args k_scales_zeros=k_scales_zeros, v_scales_zeros=v_scales_zeros, + k_scale=k_scale, + v_scale=v_scale, quant_policy=quant_policy, ) return attn_output @@ -196,6 +202,8 @@ def _forward_decoding( max_q_seqlen: int, k_scales_zeros: torch.Tensor = None, v_scales_zeros: torch.Tensor = None, + k_scale: torch.Tensor = None, + v_scale: torch.Tensor = None, ) -> torch.Tensor: """Forward pass for decoding stage. @@ -211,6 +219,8 @@ def _forward_decoding( max_q_seqlen: Maximum query sequence length. k_scales_zeros: Key quantization scales/zeros. v_scales_zeros: Value quantization scales/zeros. + k_scale: Scalar key scale for normal FP8 KV cache. + v_scale: Scalar value scale for normal FP8 KV cache. Returns: Attention output tensor. @@ -219,7 +229,7 @@ def _forward_decoding( return self._decoding_speculative(query, k_cache, v_cache, attn_metadata, max_q_seqlen) else: return self._decoding_standard(query, k_cache, v_cache, attn_metadata, max_q_seqlen, k_scales_zeros, - v_scales_zeros) + v_scales_zeros, k_scale, v_scale) def _forward_prefill( self, @@ -230,6 +240,8 @@ def _forward_prefill( max_q_seqlen: int, k_scales_zeros: torch.Tensor = None, v_scales_zeros: torch.Tensor = None, + k_scale: torch.Tensor = None, + v_scale: torch.Tensor = None, ) -> torch.Tensor: """Forward pass for prefill stage. @@ -244,6 +256,8 @@ def _forward_prefill( max_q_seqlen: Maximum query sequence length. k_scales_zeros: Key quantization scales/zeros. v_scales_zeros: Value quantization scales/zeros. + k_scale: Scalar key scale for normal FP8 KV cache. + v_scale: Scalar value scale for normal FP8 KV cache. Returns: Attention output tensor. @@ -265,6 +279,8 @@ def _forward_prefill( out_dtype=query.dtype, k_scales_zeros=k_scales_zeros, v_scales_zeros=v_scales_zeros, + k_scale=k_scale, + v_scale=v_scale, quant_policy=quant_policy, flatten_kv_layout='shd', ) @@ -310,6 +326,8 @@ def forward( attn_metadata: TritonAttentionMetadata, k_scales_zeros: torch.Tensor = None, v_scales_zeros: torch.Tensor = None, + k_scale: torch.Tensor = None, + v_scale: torch.Tensor = None, learnable_sink: torch.Tensor = None, inplace: bool = True, ) -> torch.Tensor: @@ -333,6 +351,8 @@ def forward( attn_metadata: Attention metadata containing stage info and indices. k_scales_zeros: Key quantization scales/zeros. v_scales_zeros: Value quantization scales/zeros. + k_scale: Scalar key scale for normal FP8 KV cache. + v_scale: Scalar value scale for normal FP8 KV cache. learnable_sink: Learnable sink tokens (unused in FA3). inplace: Whether to modify query inplace (unused, kept for compatibility). 
@@ -353,6 +373,8 @@ def forward( max_q_seqlen=max_q_seqlen, k_scales_zeros=k_scales_zeros, v_scales_zeros=v_scales_zeros, + k_scale=k_scale, + v_scale=v_scale, ) # Dispatch to stage-specific forward method @@ -365,6 +387,8 @@ def forward( max_q_seqlen, k_scales_zeros, v_scales_zeros, + k_scale, + v_scale, ) else: return self._forward_prefill( @@ -375,4 +399,6 @@ def forward( max_q_seqlen, k_scales_zeros, v_scales_zeros, + k_scale, + v_scale, ) diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py index a8eea27545..4bafb6d43f 100644 --- a/lmdeploy/pytorch/backends/dlinfer/attention.py +++ b/lmdeploy/pytorch/backends/dlinfer/attention.py @@ -66,6 +66,8 @@ def forward( attn_metadata: DlinferAttentionMetadata, k_scales_zeros: Tensor = None, v_scales_zeros: Tensor = None, + k_scale: Tensor = None, + v_scale: Tensor = None, learnable_sink: Tensor = None, nsa_indices: Tensor = None, inplace: bool = True, diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 4ebbc157fc..2f1ed12fdc 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -390,7 +390,7 @@ def from_pretrained( activations. Refer to `PyTorchEngineConfig` for details hf_overrides (dict[str, Any]): overrides for the HF config. """ - from transformers import AutoConfig + from transformers import AutoConfig # noqa: I001 from lmdeploy.pytorch.transformers import config_from_pretrained hf_config = config_from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code) diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py index c7e904bd6f..c8844d8a2f 100644 --- a/lmdeploy/pytorch/engine/cache_engine.py +++ b/lmdeploy/pytorch/engine/cache_engine.py @@ -52,6 +52,35 @@ def _get_kv_cache_dtype(model_config: ModelConfig): return kv_cache_dtype +def _is_fp8_quant_policy(quant_policy: QuantPolicy): + """Return whether quant policy stores KV payload as torch FP8.""" + return quant_policy in (QuantPolicy.FP8, QuantPolicy.FP8_E5M2) + + +def _get_fp8_cache_dtype(quant_policy: QuantPolicy): + """Get the cache tensor dtype for an FP8 KV-cache quant policy.""" + if quant_policy == QuantPolicy.FP8: + return torch.float8_e4m3fn + if quant_policy == QuantPolicy.FP8_E5M2: + return torch.float8_e5m2 + raise ValueError(f'Not an FP8 quant policy: {quant_policy}') + + +def _describe_kv_cache_quant_policy(quant_policy: QuantPolicy): + """Describe the active KV-cache quantization policy for logs.""" + if quant_policy == QuantPolicy.FP8: + return 'fp8_e4m3 per-tensor KV cache (torch.float8_e4m3fn)' + if quant_policy == QuantPolicy.FP8_E5M2: + return 'fp8_e5m2 per-tensor KV cache (torch.float8_e5m2)' + if quant_policy == QuantPolicy.INT4: + return 'int4 KV cache' + if quant_policy == QuantPolicy.INT8: + return 'int8 KV cache' + if quant_policy == QuantPolicy.TURBO_QUANT: + return 'TurboQuant KV cache' + return None + + # 512*1 + 4*4 + 64*2 = 656 MLA_FP8_HEAD_DIM = 656 @@ -91,7 +120,11 @@ def __init__( if self.model_config.use_mla_fp8_cache: cache_config.quant_policy = 0 - if cache_config.quant_policy > 0: + if _is_fp8_quant_policy(cache_config.quant_policy): + self.kv_cache_dtype = _get_fp8_cache_dtype(cache_config.quant_policy) + assert self.cache_config.device_type in ['cuda'], \ + f'FP8 quantization is only supported on CUDA device, but got {self.cache_config.device_type}.' 
+ elif cache_config.quant_policy > 0: if self.cache_config.device_type in ['cuda']: self.kv_cache_dtype = torch.uint8 elif self.cache_config.device_type in ['ascend', 'npu']: @@ -99,6 +132,10 @@ def __init__( else: raise ValueError(f'unsupported device_type {self.cache_config.device_type}') + quant_desc = _describe_kv_cache_quant_policy(cache_config.quant_policy) + if quant_desc is not None: + logger.info('Using %s.', quant_desc) + # Initialize the cache. self.local_gpu_cache = self.allocate_gpu_cache() self.local_cpu_cache = self.allocate_cpu_cache() @@ -210,7 +247,9 @@ def get_k_cache_desc(cls, model_config: ModelConfig, cache_config: CacheConfig, ) shape = list(shape) dtype = _get_kv_cache_dtype(model_config) - if cache_config.quant_policy in (QuantPolicy.INT4, QuantPolicy.INT8, QuantPolicy.TURBO_QUANT): + if _is_fp8_quant_policy(cache_config.quant_policy): + dtype = _get_fp8_cache_dtype(cache_config.quant_policy) + elif cache_config.quant_policy in (QuantPolicy.INT4, QuantPolicy.INT8, QuantPolicy.TURBO_QUANT): dtype = torch.uint8 return CacheDesc(shape=shape, dtype=dtype) @@ -229,7 +268,9 @@ def get_v_cache_desc(cls, model_config: ModelConfig, cache_config: CacheConfig, ) shape = list(shape) dtype = _get_kv_cache_dtype(model_config) - if cache_config.quant_policy in (QuantPolicy.INT4, QuantPolicy.INT8, QuantPolicy.TURBO_QUANT): + if _is_fp8_quant_policy(cache_config.quant_policy): + dtype = _get_fp8_cache_dtype(cache_config.quant_policy) + elif cache_config.quant_policy in (QuantPolicy.INT4, QuantPolicy.INT8, QuantPolicy.TURBO_QUANT): dtype = torch.uint8 return CacheDesc(shape=shape, dtype=dtype) @@ -239,6 +280,8 @@ def get_quant_cache_descs(cls, k_cache_desc: CacheDesc, v_cache_desc: CacheDesc, """Get quant cache descs.""" if cache_config.quant_policy == QuantPolicy.NONE: return [] + if _is_fp8_quant_policy(cache_config.quant_policy): + return [] dtype = model_config.dtype # For quant_policy==QuantPolicy.TURBO_QUANT, K uses 4-bit quantization (has MSE norm and QJL norm), diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py index f4c710d85d..0f8af26e68 100644 --- a/lmdeploy/pytorch/engine/model_agent/agent.py +++ b/lmdeploy/pytorch/engine/model_agent/agent.py @@ -981,6 +981,9 @@ def _build_model(self): if adapters is not None: logger.debug(msg_with_rank(rank, 'loading adapters.')) add_adapters(patched_model, adapters, dtype=self.model_config.dtype, device=device) + for module in patched_model.modules(): + if hasattr(module, 'finalize_kv_scales'): + module.finalize_kv_scales(self.cache_config.quant_policy) self.patched_model = patched_model self.build_model_ctx = build_model_ctx diff --git a/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py b/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py index 9e5622e507..0b3668d607 100644 --- a/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py +++ b/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py @@ -135,6 +135,103 @@ def _fill_kv_cache_kernel( tl.store(vc_ptrs, v, mask=mask_vc) +@triton.jit +def _fill_kv_cache_fp8_scalar_kernel( + KStates, + VStates, + KCaches, + VCaches, + KScale, + VScale, + QStartLoc, + QSeqLens, + KVSeqLens, + BlockOffsets, + fp8_min: tl.constexpr, + fp8_max: tl.constexpr, + is_decoding: tl.constexpr, + head_dim: tl.constexpr, + head_dim_v: tl.constexpr, + stride_kss, + stride_ksh, + stride_ksd, + stride_vss, + stride_vsh, + stride_vsd, + stride_kcn: tl.constexpr, + stride_kcb: tl.constexpr, + stride_kch: tl.constexpr, + stride_kcd: tl.constexpr, + stride_vcn: tl.constexpr, + 
stride_vcb: tl.constexpr, + stride_vch: tl.constexpr, + stride_vcd: tl.constexpr, + stride_boff, + BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_DV: tl.constexpr, +): + """Fill FP8 KV cache with per-tensor K/V scales.""" + batch_id = tl.program_id(2) + head_id = tl.program_id(0) + block_id = tl.program_id(1) + + q_startloc = tl.load(QStartLoc + batch_id) + q_seqlen = tl.load(QSeqLens + batch_id) + kv_seqlen = tl.load(KVSeqLens + batch_id) + history_seqlen = kv_seqlen - q_seqlen + kv_block_id = history_seqlen // BLOCK + block_id + + if kv_seqlen <= 0: + return + if kv_block_id * BLOCK >= kv_seqlen: + return + + if is_decoding: + page_offs = tl.full((1, ), history_seqlen % BLOCK, dtype=tl.int32) + kv_mask = tl.full((1, ), 1, dtype=tl.int1) + q_offs = tl.full((1, ), q_startloc, dtype=tl.int32) + else: + page_offs = tl.arange(0, BLOCK) + kv_offs = kv_block_id * BLOCK + page_offs + kv_mask = (kv_offs >= history_seqlen) & (kv_offs < kv_seqlen) + token_off = q_startloc + kv_block_id * BLOCK - history_seqlen + q_offs = token_off + page_offs + + block_off = tl.load(BlockOffsets + batch_id * stride_boff + kv_block_id) + block_off = block_off.to(tl.int64) + + d_off = tl.arange(0, BLOCK_D) + mask_ks = kv_mask[:, None] + mask_kc = mask_ks & (d_off[None, :] < head_dim) + d_off = d_off % head_dim + + ks_ptr = KStates + head_id * stride_ksh + ks_ptrs = ks_ptr + q_offs[:, None] * stride_kss + d_off[None, :] * stride_ksd + kc_ptr = KCaches + block_off * stride_kcn + head_id * stride_kch + kc_ptrs = kc_ptr + page_offs[:, None] * stride_kcb + d_off[None, :] * stride_kcd + + k_scale = tl.load(KScale).to(tl.float32) + k = tl.load(ks_ptrs, mask=mask_ks).to(tl.float32) / k_scale + k = tl.clamp(k, fp8_min, fp8_max).to(KCaches.dtype.element_ty) + tl.store(kc_ptrs, k, mask=mask_kc) + + if BLOCK_DV > 0: + dv_off = tl.arange(0, BLOCK_DV) + mask_vs = kv_mask[:, None] + mask_vc = mask_vs & (dv_off[None, :] < head_dim_v) + dv_off = dv_off % head_dim_v + vs_ptr = VStates + head_id * stride_vsh + vs_ptrs = vs_ptr + q_offs[:, None] * stride_vss + dv_off[None, :] * stride_vsd + vc_ptr = VCaches + block_off * stride_vcn + head_id * stride_vch + vc_ptrs = vc_ptr + page_offs[:, None] * stride_vcb + dv_off[None, :] * stride_vcd + + v_scale = tl.load(VScale).to(tl.float32) + v = tl.load(vs_ptrs, mask=mask_vs).to(tl.float32) / v_scale + v = tl.clamp(v, fp8_min, fp8_max).to(VCaches.dtype.element_ty) + tl.store(vc_ptrs, v, mask=mask_vc) + + @triton.jit def _fill_page_quant_int8( state_ptr, @@ -698,8 +795,15 @@ def fill_kv_cache(k_states: Tensor, k_scales_zeros: Tensor = None, v_scales_zeros: Tensor = None, quant_policy: QuantPolicy = QuantPolicy.NONE, + k_scale: Tensor = None, + v_scale: Tensor = None, kv_layout: str = 'bshd'): - """Fill key/value state to cache for paged attention.""" + """Fill key/value state to cache for paged attention. + + Args: + k_scale: Per-tensor key scale for normal FP8 KV cache. + v_scale: Per-tensor value scale for normal FP8 KV cache. 
+ """ if kv_layout == 'bshd': b_dim, s_dim, h_dim, d_dim = (0, 1, 2, 3) elif kv_layout == 'bhsd': @@ -805,6 +909,49 @@ def fill_kv_cache(k_states: Tensor, num_warps=4, num_stages=3, ) + elif quant_policy in (QuantPolicy.FP8, QuantPolicy.FP8_E5M2): + if k_scale is None: + k_scale = torch.ones((), device=k_states.device, dtype=torch.float32) + if v_scale is None: + v_scale = k_scale + finfo = torch.finfo(k_caches.dtype) + _fill_kv_cache_fp8_scalar_kernel[grid]( + k_states, + v_states, + k_caches, + v_caches, + k_scale, + v_scale, + q_start_loc, + q_seq_length, + kv_seq_length, + block_offsets, + fp8_min=finfo.min, + fp8_max=finfo.max, + is_decoding=is_decoding, + head_dim=head_dim, + head_dim_v=head_dim_v, + stride_kss=k_states.stride(-3), + stride_ksh=k_states.stride(-2), + stride_ksd=k_states.stride(-1), + stride_vss=v_states.stride(-3), + stride_vsh=v_states.stride(-2), + stride_vsd=v_states.stride(-1), + stride_kcn=k_caches.stride(b_dim), + stride_kcb=k_caches.stride(s_dim), + stride_kch=k_caches.stride(h_dim), + stride_kcd=k_caches.stride(d_dim), + stride_vcn=v_caches.stride(b_dim), + stride_vcb=v_caches.stride(s_dim), + stride_vch=v_caches.stride(h_dim), + stride_vcd=v_caches.stride(d_dim), + stride_boff=block_offsets.stride(0), + BLOCK=BLOCK, + BLOCK_D=BLOCK_D, + BLOCK_DV=BLOCK_DV, + num_warps=4, + num_stages=3, + ) else: _fill_kv_cache_quant_kernel[grid]( k_states, diff --git a/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py b/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py index 057786e400..c918e1e2a3 100644 --- a/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py +++ b/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py @@ -91,6 +91,81 @@ def _flatten_kv_cache( tl.store(vo_ptrs, vc, mask=mask_bs[:, None] & mask_dv[None, :]) +@triton.jit +def _flatten_kv_cache_fp8_scalar( + kc_ptr, + vc_ptr, + ko_ptr, + vo_ptr, + k_scale_ptr, + v_scale_ptr, + start_loc_ptr, + seqlens_ptr, + block_offsets_ptr, + stride_kcb: tl.constexpr, + stride_kcs: tl.constexpr, + stride_kch: tl.constexpr, + stride_kcd: tl.constexpr, + stride_vcb: tl.constexpr, + stride_vcs: tl.constexpr, + stride_vch: tl.constexpr, + stride_vcd: tl.constexpr, + stride_koh, + stride_kos: tl.constexpr, + stride_kod: tl.constexpr, + stride_voh, + stride_vos: tl.constexpr, + stride_vod: tl.constexpr, + stride_boff, + OUT_SIZE, + HEAD_DIM_K: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + BLOCK_BS: tl.constexpr, + BLOCK_DK: tl.constexpr, + BLOCK_DV: tl.constexpr, +): + """Flatten per-tensor FP8 KV cache.""" + page_id = tl.program_id(0) + batch_id = tl.program_id(1) + head_id = tl.program_id(2) + + num_batches = tl.num_programs(1) + seqlen = tl.load(seqlens_ptr + batch_id) + start_loc = tl.load(start_loc_ptr + batch_id) + if batch_id == num_batches - 1: + seqlen = (OUT_SIZE - start_loc).to(seqlen.dtype) + if page_id * BLOCK_BS >= seqlen: + return + + b_off = tl.load(block_offsets_ptr + batch_id * stride_boff + page_id) + b_off = b_off.to(tl.int64) + offs_bs = tl.arange(0, BLOCK_BS) + offs_dk = tl.arange(0, BLOCK_DK) % HEAD_DIM_K + offs_dv = tl.arange(0, BLOCK_DV) % HEAD_DIM_V + offs_obs = page_id * BLOCK_BS + tl.arange(0, BLOCK_BS) + mask_bs = offs_obs < seqlen + mask_dk = tl.arange(0, BLOCK_DK) < HEAD_DIM_K + mask_dv = tl.arange(0, BLOCK_DV) < HEAD_DIM_V + + kc_ptrs = (kc_ptr + b_off * stride_kcb + offs_bs[:, None] * stride_kcs + head_id * stride_kch + + offs_dk[None, :] * stride_kcd) + vc_ptrs = (vc_ptr + b_off * stride_vcb + offs_bs[:, None] * stride_vcs + head_id * stride_vch + + offs_dv[None, :] * stride_vcd) + ko_ptrs = (ko_ptr + 
head_id * stride_koh + (start_loc + offs_obs[:, None]) * stride_kos + + offs_dk[None, :] * stride_kod) + vo_ptrs = (vo_ptr + head_id * stride_voh + (start_loc + offs_obs[:, None]) * stride_vos + + offs_dv[None, :] * stride_vod) + + k_scale = tl.load(k_scale_ptr).to(tl.float32) + kc = tl.load(kc_ptrs).to(tl.float32) * k_scale + tl.store(ko_ptrs, kc.to(ko_ptr.dtype.element_ty), mask=mask_bs[:, None] & mask_dk[None, :]) + + if HEAD_DIM_V > 0: + v_scale = tl.load(v_scale_ptr).to(tl.float32) + vc = tl.load(vc_ptrs).to(tl.float32) * v_scale + tl.store(vo_ptrs, vc.to(vo_ptr.dtype.element_ty), mask=mask_bs[:, None] & mask_dv[None, :]) + + @triton.jit def _dequant_int4(val, HEAD_DIM: tl.constexpr, BLOCK: tl.constexpr): """Dequant int4.""" @@ -261,10 +336,17 @@ def flatten_kv_cache(k_caches: Tensor, out_dtype: torch.dtype = None, k_scales_zeros: Tensor = None, v_scales_zeros: Tensor = None, + k_scale: Tensor = None, + v_scale: Tensor = None, quant_policy: QuantPolicy = QuantPolicy.NONE, kv_layout: str = 'bshd', flatten_kv_layout: str = 'hsd'): - """Recovery paged kv cache to normal kv cache.""" + """Recovery paged kv cache to normal kv cache. + + Args: + k_scale: Per-tensor key scale for normal FP8 KV cache. + v_scale: Per-tensor value scale for normal FP8 KV cache. + """ if kv_layout == 'bshd': b_dim, s_dim, h_dim, d_dim = (0, 1, 2, 3) elif kv_layout == 'bhsd': @@ -273,7 +355,7 @@ def flatten_kv_cache(k_caches: Tensor, raise RuntimeError('Unsupported layout.') if out_dtype is None: - if quant_policy == QuantPolicy.TURBO_QUANT: + if quant_policy in (QuantPolicy.FP8, QuantPolicy.FP8_E5M2, QuantPolicy.TURBO_QUANT): out_dtype = torch.float16 else: out_dtype = k_caches.dtype @@ -355,6 +437,43 @@ def flatten_kv_cache(k_caches: Tensor, BLOCK_DK=BLOCK_DK, BLOCK_DV=BLOCK_DV, ) + elif quant_policy in (QuantPolicy.FP8, QuantPolicy.FP8_E5M2): + if k_scale is None: + k_scale = torch.ones((), device=k_caches.device, dtype=torch.float32) + if v_scale is None: + v_scale = k_scale + _flatten_kv_cache_fp8_scalar[grid]( + k_caches, + v_caches, + k_states, + v_states, + k_scale, + v_scale, + start_loc, + seqlens, + block_offsets, + stride_kcb=k_caches.stride(b_dim), + stride_kcs=k_caches.stride(s_dim), + stride_kch=k_caches.stride(h_dim), + stride_kcd=k_caches.stride(d_dim), + stride_vcb=v_caches.stride(b_dim), + stride_vcs=v_caches.stride(s_dim), + stride_vch=v_caches.stride(h_dim), + stride_vcd=v_caches.stride(d_dim), + stride_koh=stride_koh, + stride_kos=stride_kos, + stride_kod=k_states.stride(2), + stride_voh=stride_voh, + stride_vos=stride_vos, + stride_vod=v_states.stride(2), + stride_boff=block_offsets.stride(0), + OUT_SIZE=out_size, + HEAD_DIM_K=k_head_dim, + HEAD_DIM_V=v_head_dim, + BLOCK_BS=BLOCK_BS, + BLOCK_DK=BLOCK_DK, + BLOCK_DV=BLOCK_DV, + ) else: if quant_policy == QuantPolicy.TURBO_QUANT: # K = QJL4 => 3bit centroid codebook diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py index a4f0bcef99..6df8181d88 100644 --- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py +++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py @@ -23,6 +23,7 @@ Q_POLICY_NONE = tl.constexpr(0) Q_POLICY_INT4 = tl.constexpr(4) Q_POLICY_INT8 = tl.constexpr(8) +Q_POLICY_FP8 = tl.constexpr(16) Q_POLICY_TURBO = tl.constexpr(42) TRITON_VERSION = version.parse(triton.__version__) @@ -459,6 +460,9 @@ def _fwd_grouped_split_quant_kernel( k_cent = _k4v2_k_centroid((k & 0x7), head_size) k_sign = ((k >> 3) & 0x1).to(tl.float32) * 2.0 - 1.0 k = (kmse_norm * (k_cent + kqjl_norm * 
k_sign)).to(q.dtype) + elif quant_policy == Q_POLICY_FP8: + ks = tl.load(KScalesZeros).to(tl.float32) + k = k.to(q.dtype) else: ks = tl.load(ksz_ptrs + b_offset * stride_kszp) kz = tl.load(ksz_ptrs + b_offset * stride_kszp + 1) @@ -476,6 +480,8 @@ def _fwd_grouped_split_quant_kernel( k1_cent = _k4v2_k_centroid((k1 & 0x7), head_size) k1_sign = ((k1 >> 3) & 0x1).to(tl.float32) * 2.0 - 1.0 k1 = (kmse_norm * (k1_cent + kqjl_norm * k1_sign)).to(q.dtype) + elif quant_policy == Q_POLICY_FP8: + k1 = k1.to(q.dtype) else: k1 = ((k1 - kz) * ks).to(q.dtype) @@ -493,6 +499,9 @@ def _fwd_grouped_split_quant_kernel( vs = tl.load(vsz_ptrs + b_offset * stride_vszp) v = _k4v2_v_centroid(v, head_size_v) v = (v * vs).to(q.dtype) + elif quant_policy == Q_POLICY_FP8: + vs = tl.load(VScalesZeros).to(tl.float32) + v = v.to(q.dtype) else: vs = tl.load(vsz_ptrs + b_offset * stride_vszp) vz = tl.load(vsz_ptrs + b_offset * stride_vszp + 1) @@ -503,6 +512,8 @@ def _fwd_grouped_split_quant_kernel( qk += tl.dot(q, k) if BLOCK_DMODEL1 != 0: qk += tl.dot(q1, k1) + if quant_policy == Q_POLICY_FP8: + qk *= ks qk *= sm_scale if logit_softcapping > 0.0: qk = qk / logit_softcapping @@ -536,6 +547,8 @@ def _fwd_grouped_split_quant_kernel( acc = acc * alpha[:, None] # update acc + if quant_policy == Q_POLICY_FP8: + p = p * vs p, v = _convert_pv(p, v) acc += tl.dot(p, v) # update m_i and l_i @@ -756,6 +769,8 @@ def flash_attn_with_kvcache( alibi_slopes: Tensor = None, k_scales_zeros: Tensor = None, v_scales_zeros: Tensor = None, + k_scale: Tensor = None, + v_scale: Tensor = None, quant_policy: QuantPolicy = QuantPolicy.NONE, sinks: Tensor = None, kv_layout: str = 'bshd', @@ -763,6 +778,10 @@ def flash_attn_with_kvcache( """Paged Attention forward. Note that this kernel is decoding-only + + Args: + k_scale: Per-tensor key scale for normal FP8 KV cache. + v_scale: Per-tensor value scale for normal FP8 KV cache. 
""" global _nv_cap @@ -871,6 +890,7 @@ def _get_block_d(Lk): else: num_warps, num_stages = _kernel_meta_sm9x(BLOCK_DMODEL, BLOCK_H) + is_fp8_scalar = quant_policy in (QuantPolicy.FP8, QuantPolicy.FP8_E5M2) SPLIT_K = _get_split_k(q.device.index, grid_1, batch, num_warps) if quant_policy == QuantPolicy.INT4 or quant_policy == QuantPolicy.TURBO_QUANT: @@ -885,11 +905,33 @@ def _get_block_d(Lk): ) if quant_policy != QuantPolicy.NONE: + if is_fp8_scalar: + if k_scale is None: + k_scale = torch.ones((), device=q.device, dtype=torch.float32) + if v_scale is None: + v_scale = k_scale + k_scales_arg = k_scale + v_scales_arg = v_scale + stride_kszp = stride_kszbs = stride_kszh = stride_kszd = 0 + stride_vszp = stride_vszbs = stride_vszh = stride_vszd = 0 + triton_quant_policy = QuantPolicy.FP8 + else: + k_scales_arg = k_scales_zeros + v_scales_arg = v_scales_zeros + stride_kszp = k_scales_zeros.stride(b_dim) + stride_kszbs = k_scales_zeros.stride(s_dim) + stride_kszh = k_scales_zeros.stride(h_dim) + stride_kszd = k_scales_zeros.stride(d_dim) + stride_vszp = v_scales_zeros.stride(b_dim) + stride_vszbs = v_scales_zeros.stride(s_dim) + stride_vszh = v_scales_zeros.stride(h_dim) + stride_vszd = v_scales_zeros.stride(d_dim) + triton_quant_policy = quant_policy _fwd_grouped_split_quant_kernel[grid](q, k_cache, v_cache, - k_scales_zeros, - v_scales_zeros, + k_scales_arg, + v_scales_arg, softmax_scale, cache_seqlens, page_table, @@ -906,15 +948,15 @@ def _get_block_d(Lk): stride_vbs=v_cache.stride(s_dim), stride_vh=v_cache.stride(h_dim), stride_vd=v_cache.stride(d_dim), - stride_kszp=k_scales_zeros.stride(b_dim), - stride_kszbs=k_scales_zeros.stride(s_dim), - stride_kszh=k_scales_zeros.stride(h_dim), - stride_kszd=k_scales_zeros.stride(d_dim), - stride_vszp=v_scales_zeros.stride(b_dim), - stride_vszbs=v_scales_zeros.stride(s_dim), - stride_vszh=v_scales_zeros.stride(h_dim), - stride_vszd=v_scales_zeros.stride(d_dim), - quant_policy=quant_policy, + stride_kszp=stride_kszp, + stride_kszbs=stride_kszbs, + stride_kszh=stride_kszh, + stride_kszd=stride_kszd, + stride_vszp=stride_vszp, + stride_vszbs=stride_vszbs, + stride_vszh=stride_vszh, + stride_vszd=stride_vszd, + quant_policy=triton_quant_policy, stride_ok=acc.stride(-2), stride_obs=acc.stride(-4), stride_oh=acc.stride(-3), diff --git a/lmdeploy/pytorch/nn/attention.py b/lmdeploy/pytorch/nn/attention.py index a8747aa3d4..fdac8cbd23 100644 --- a/lmdeploy/pytorch/nn/attention.py +++ b/lmdeploy/pytorch/nn/attention.py @@ -2,12 +2,22 @@ import torch from torch import nn +from lmdeploy.messages import QuantPolicy from lmdeploy.pytorch.distributed import get_tp_world_rank +from lmdeploy.utils import get_logger from ..backends import OpType, get_backend from ..backends.attention import AttentionMetadata from .utils import get_distribute_size +logger = get_logger('lmdeploy') +_DEFAULT_FP8_SCALE_WARNED = False + + +def _is_normal_fp8_quant_policy(quant_policy: QuantPolicy): + """Return whether quant_policy uses per-tensor FP8 KV cache.""" + return quant_policy in (QuantPolicy.FP8, QuantPolicy.FP8_E5M2) + def _update_num_heads(num_heads: int, num_kv_heads: int): """Update heads.""" @@ -68,6 +78,9 @@ def __init__( self.alibi_ready = False else: self.alibi_ready = True + scale_device = kwargs.get('device', None) + self.register_buffer('k_scale', torch.ones((), dtype=torch.float32, device=scale_device)) + self.register_buffer('v_scale', torch.ones((), dtype=torch.float32, device=scale_device)) def _lazy_init(self, device): """Lazy init.""" @@ -84,6 +97,24 @@ def 
_lazy_init(self, device): self.impl.set_alibi_slopes(alibi_slopes) self.alibi_ready = True + @torch.no_grad() + def finalize_kv_scales(self, quant_policy: QuantPolicy): + """Finalize loaded per-tensor FP8 KV scales before inference.""" + global _DEFAULT_FP8_SCALE_WARNED + if not _is_normal_fp8_quant_policy(quant_policy): + return + + if _DEFAULT_FP8_SCALE_WARNED or quant_policy != QuantPolicy.FP8: + return + if self.k_scale.item() == 1.0 and self.v_scale.item() == 1.0: + logger.warning('Using normal FP8 E4M3 KV cache with default k_scale=v_scale=1.0. ' + 'This matches vLLM no-calibration behavior but may affect accuracy.') + _DEFAULT_FP8_SCALE_WARNED = True + + def _effective_kv_scales(self): + """Return per-tensor K/V scales.""" + return self.k_scale, self.v_scale + def forward( self, query: torch.Tensor, @@ -101,6 +132,13 @@ def forward( """forward.""" self._lazy_init(query.device) + quant_policy = attn_metadata.quant_policy + if _is_normal_fp8_quant_policy(quant_policy): + k_scale, v_scale = self._effective_kv_scales() + else: + k_scale = None + v_scale = None + kwargs = dict() if nsa_indices is not None: kwargs['nsa_indices'] = nsa_indices @@ -115,6 +153,8 @@ def forward( attn_metadata=attn_metadata, k_scales_zeros=k_scales_zeros, v_scales_zeros=v_scales_zeros, + k_scale=k_scale, + v_scale=v_scale, inplace=inplace, **kwargs, ) diff --git a/tests/pytorch/kernel/test_fill_kv_cache.py b/tests/pytorch/kernel/test_fill_kv_cache.py index d9ee15a599..3f2826beb0 100644 --- a/tests/pytorch/kernel/test_fill_kv_cache.py +++ b/tests/pytorch/kernel/test_fill_kv_cache.py @@ -675,6 +675,7 @@ def test_fill_kv_cache(self, k_states, v_states, k_caches, v_caches, ks_caches, # uncache out_k, out_ks, out_v, out_vs = self.uncache(k_caches, ks_caches, v_caches, vs_caches, cu_seqlen_q, kv_seq_length, block_offsets) + out_k = out_k.float() out_k = out_k / out_k.max() gt_k = gt_k.float() gt_k = gt_k / gt_k.max() @@ -686,3 +687,133 @@ def test_fill_kv_cache(self, k_states, v_states, k_caches, v_caches, ks_caches, torch.testing.assert_close(out_ks, gt_ks) torch.testing.assert_close(out_v, gt_v) torch.testing.assert_close(out_vs, gt_vs) + + +def _quant_fp8_scalar(x: torch.Tensor, fp8_dtype: torch.dtype, scale: float): + """Per-tensor FP8 quantization.""" + fp8_max = torch.finfo(fp8_dtype).max + scale_t = x.new_tensor(scale, dtype=torch.float32) + q = (x.to(torch.float32) / scale_t).clamp(-fp8_max, fp8_max).to(fp8_dtype) + return q, scale_t + + +def _skip_unsupported_triton_fp8_dtype(fp8_dtype: torch.dtype): + if fp8_dtype is torch.float8_e4m3fn and torch.cuda.get_device_capability()[0] < 9: + pytest.skip('Triton float8_e4m3fn conversion requires device with cc>=9.0') + + +def _assert_fp8_cache_close(actual: torch.Tensor, + expected: torch.Tensor, + fp8_dtype: torch.dtype, + scale: torch.Tensor = None): + actual = actual.to(torch.float32) + expected = expected.to(torch.float32) + if fp8_dtype is torch.float8_e5m2: + if scale is not None: + scale = scale.to(torch.float32) + actual = actual * scale + expected = expected * scale + torch.testing.assert_close(actual, expected, atol=1e-6, rtol=0.25) + else: + torch.testing.assert_close(actual, expected) + + +class TestFillKVCacheFP8Scalar(TestFillKVCache): + """Tests for fill_kv_cache with normal per-tensor QuantPolicy.FP8.""" + + @pytest.fixture(autouse=True) + def skip_unsupported_fp8_dtype(self, fp8_dtype): + _skip_unsupported_triton_fp8_dtype(fp8_dtype) + + @pytest.fixture + def fp8_dtype(self): + yield torch.float8_e4m3fn + + @pytest.fixture + def 
quant_policy(self): + yield QuantPolicy.FP8 + + @pytest.fixture + def head_dim(self, request): + yield request.param + + @pytest.fixture + def k_caches(self, batch_size, max_num_blocks, block_size, num_heads, head_dim, fp8_dtype): + shape = (batch_size * max_num_blocks, block_size, num_heads, head_dim) + yield torch.zeros(shape, dtype=fp8_dtype).cuda() + + @pytest.fixture + def v_caches(self, batch_size, max_num_blocks, block_size, num_heads, head_dim, fp8_dtype): + shape = (batch_size * max_num_blocks, block_size, num_heads, head_dim) + yield torch.zeros(shape, dtype=fp8_dtype).cuda() + + @pytest.fixture + def gt(self, k_states, v_states, k_caches, v_caches, seq_lens, history_lens, block_offsets, block_size, fp8_dtype): + k_states_q, k_scale = _quant_fp8_scalar(k_states, fp8_dtype, scale=0.25) + v_states_q, v_scale = _quant_fp8_scalar(v_states, fp8_dtype, scale=0.5) + batch_size = len(seq_lens) + k_caches = k_caches.clone() + v_caches = v_caches.clone() + + splited_k_states = k_states_q.split(seq_lens) + splited_v_states = v_states_q.split(seq_lens) + + for bidx in range(batch_size): + k_state = splited_k_states[bidx] + v_state = splited_v_states[bidx] + h_len = history_lens[bidx] + b_offs = block_offsets[bidx] + block_id = _div_up(h_len + 1, block_size) - 1 + fill_start = h_len % block_size + fill_size = min(block_size - fill_start, k_state.size(0)) + while True: + boff = b_offs[block_id] + fill_end = fill_start + fill_size + k_caches[boff, fill_start:fill_end] = k_state[:fill_size] + v_caches[boff, fill_start:fill_end] = v_state[:fill_size] + k_state = k_state[fill_size:] + v_state = v_state[fill_size:] + block_id += 1 + fill_start = 0 + fill_size = min(block_size, k_state.size(0)) + if fill_size == 0: + break + + yield k_caches, v_caches, k_scale, v_scale + + @pytest.mark.parametrize('head_dim', [128], indirect=True) + @pytest.mark.parametrize(['seq_lens', 'history_lens'], [ + ((1, 1, 1, 1), (1, 16, 31, 24)), + ((1, 8, 16, 24), (1, 16, 31, 24)), + ], + indirect=True) + def test_fill_kv_cache(self, k_states, v_states, k_caches, v_caches, block_offsets, q_start_loc, q_seq_length, + kv_seq_length, max_q_seq_length, gt, quant_policy, fp8_dtype): + from lmdeploy.pytorch.kernels.cuda.fill_kv_cache import fill_kv_cache + gt_k, gt_v, k_scale, v_scale = gt + fill_kv_cache(k_states, + v_states, + k_caches, + v_caches, + q_start_loc, + q_seq_length, + kv_seq_length, + max_q_seq_length, + block_offsets, + quant_policy=quant_policy, + k_scale=k_scale, + v_scale=v_scale) + _assert_fp8_cache_close(k_caches, gt_k, fp8_dtype, k_scale) + _assert_fp8_cache_close(v_caches, gt_v, fp8_dtype, v_scale) + + +class TestFillKVCacheFP8E5M2Scalar(TestFillKVCacheFP8Scalar): + """Tests for fill_kv_cache with normal per-tensor QuantPolicy.FP8_E5M2.""" + + @pytest.fixture + def fp8_dtype(self): + yield torch.float8_e5m2 + + @pytest.fixture + def quant_policy(self): + yield QuantPolicy.FP8_E5M2 diff --git a/tests/pytorch/kernel/test_flatten_kv_cache.py b/tests/pytorch/kernel/test_flatten_kv_cache.py index 76d88f70eb..5108fdc251 100644 --- a/tests/pytorch/kernel/test_flatten_kv_cache.py +++ b/tests/pytorch/kernel/test_flatten_kv_cache.py @@ -175,6 +175,89 @@ def rtol(self): yield 1e-3 +def quant_fp8_scalar(kv: torch.Tensor, fp8_dtype: torch.dtype, scale: float): + """Quantize KV cache with one per-tensor FP8 scale.""" + fp8_max = torch.finfo(fp8_dtype).max + scale_t = kv.new_tensor(scale, dtype=torch.float32) + q_kv = (kv.to(torch.float32) / scale_t).clamp(-fp8_max, fp8_max).to(fp8_dtype) + dq_kv = 
(q_kv.to(torch.float32) * scale_t).to(kv.dtype) + return q_kv, scale_t, dq_kv + + +def _skip_unsupported_triton_fp8_dtype(fp8_dtype: torch.dtype): + if fp8_dtype is torch.float8_e4m3fn and torch.cuda.get_device_capability()[0] < 9: + pytest.skip('Triton float8_e4m3fn conversion requires device with cc>=9.0') + + +def flatten_reference(k_caches, v_caches, kv_lens, block_offsets, block_size, num_heads, out_size, k_head_dim, + v_head_dim): + """Reference flatten for paged KV cache tensors.""" + k_states = k_caches.new_empty(num_heads, out_size, k_head_dim) + v_states = v_caches.new_empty(num_heads, out_size, v_head_dim) + start_loc = 0 + for kv_len, block_offs in zip(kv_lens, block_offsets): + remain_len = kv_len + for idx, _ in enumerate(range(0, kv_len, block_size)): + b_off = block_offs[idx] + block_len = min(block_size, remain_len) + end_loc = start_loc + block_len + k_block = k_caches[b_off, :block_len] + v_block = v_caches[b_off, :block_len] + k_states[:, start_loc:end_loc] = k_block.transpose(0, 1) + v_states[:, start_loc:end_loc] = v_block.transpose(0, 1) + start_loc = end_loc + remain_len -= block_len + return k_states, v_states + + +class TestFlattenKVCacheFP8Scalar(TestFlattenKVCache): + + @pytest.fixture(autouse=True) + def skip_unsupported_fp8_dtype(self, fp8_dtype): + _skip_unsupported_triton_fp8_dtype(fp8_dtype) + + @pytest.fixture + def fp8_dtype(self): + yield torch.float8_e4m3fn + + @pytest.fixture + def quant_policy(self): + yield QuantPolicy.FP8 + + def test_flatten_kv_cache(self, k_caches, v_caches, kv_lens, kv_seqlens, block_offsets, block_size, num_heads, + out_size, head_dim, out_dtype, fp8_dtype, quant_policy): + from lmdeploy.pytorch.kernels.cuda.flatten_kv_cache import flatten_kv_cache + + k_caches_fp8, k_scale, k_dequant = quant_fp8_scalar(k_caches, fp8_dtype, scale=0.25) + v_caches_fp8, v_scale, v_dequant = quant_fp8_scalar(v_caches, fp8_dtype, scale=0.5) + gt = flatten_reference(k_dequant, v_dequant, kv_lens, block_offsets, block_size, num_heads, out_size, head_dim, + head_dim) + + k_states, v_states = flatten_kv_cache(k_caches_fp8, + v_caches_fp8, + kv_seqlens, + block_offsets, + out_size=out_size, + out_dtype=out_dtype, + k_scale=k_scale, + v_scale=v_scale, + quant_policy=quant_policy) + + torch.testing.assert_close(k_states, gt[0], atol=1e-3, rtol=1e-5) + torch.testing.assert_close(v_states, gt[1], atol=1e-3, rtol=1e-5) + + +class TestFlattenKVCacheFP8E5M2Scalar(TestFlattenKVCacheFP8Scalar): + + @pytest.fixture + def fp8_dtype(self): + yield torch.float8_e5m2 + + @pytest.fixture + def quant_policy(self): + yield QuantPolicy.FP8_E5M2 + + @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason='require device with cc>=9.0') class TestFlattenKVCacheMLAFP8(TestFlattenKVCache): diff --git a/tests/pytorch/kernel/test_paged_attention.py b/tests/pytorch/kernel/test_paged_attention.py index 65eb5c3d8b..efcc1d6305 100644 --- a/tests/pytorch/kernel/test_paged_attention.py +++ b/tests/pytorch/kernel/test_paged_attention.py @@ -435,6 +435,45 @@ def _make_blocked_cache_quant(batched_k, batched_v, seq_lens, history_lens, bloc return blocked_k, blocked_v, blocked_ksz, blocked_vsz +def quant_fp8_scalar(kv: torch.Tensor, fp8_dtype: torch.dtype, scale: float): + """Quantize KV to FP8 with a per-tensor scale.""" + finfo = torch.finfo(fp8_dtype) + scale_t = kv.new_tensor(scale, dtype=torch.float32) + q_kv = torch.clamp(kv.to(torch.float32) / scale_t, finfo.min, finfo.max).to(fp8_dtype) + dq_kv = (q_kv.to(torch.float32) * scale_t).to(kv.dtype) + return q_kv, scale_t, 
dq_kv + + +def _skip_unsupported_triton_fp8_dtype(fp8_dtype: torch.dtype): + if fp8_dtype is torch.float8_e4m3fn and torch.cuda.get_device_capability()[0] < 9: + pytest.skip('Triton float8_e4m3fn conversion requires device with cc>=9.0') + + +def _make_blocked_cache_fp8_scalar(batched_k, batched_v, seq_lens, history_lens, block_offsets, block_size, + num_heads_k, feat_dim, feat_dim_v, fp8_dtype=torch.float8_e4m3fn): + max_blocks_nums = block_offsets.max() + 1 + full_seq_lens = seq_lens + history_lens + batched_k, k_scale, dequant_k = quant_fp8_scalar(batched_k, fp8_dtype, scale=0.25) + batched_v, v_scale, dequant_v = quant_fp8_scalar(batched_v, fp8_dtype, scale=0.5) + + blocked_k = batched_k.new_zeros(max_blocks_nums, block_size, num_heads_k, feat_dim) + blocked_v = batched_v.new_zeros(max_blocks_nums, block_size, num_heads_k, feat_dim_v) + + for batch_id, offset in enumerate(block_offsets): + ori_k = batched_k[batch_id] + ori_v = batched_v[batch_id] + seq_len = full_seq_lens[batch_id] + for block_id, block_start in enumerate(range(0, seq_len, block_size)): + block_off = offset[block_id] + tmp_k = ori_k[block_start:block_start + block_size] + tmp_v = ori_v[block_start:block_start + block_size] + size = tmp_k.size(0) + blocked_k[block_off, :size] = tmp_k + blocked_v[block_off, :size] = tmp_v + + return blocked_k, blocked_v, k_scale, v_scale, dequant_k, dequant_v + + class TestPagedAttentionInt8(TestPagedAttention): @pytest.fixture @@ -505,6 +544,66 @@ def nbits(self): yield 4 +class TestPagedAttentionFP8Scalar(TestPagedAttentionBase): + + @pytest.fixture(autouse=True) + def skip_unsupported_fp8_dtype(self, fp8_dtype): + _skip_unsupported_triton_fp8_dtype(fp8_dtype) + + @pytest.fixture + def fp8_dtype(self): + yield torch.float8_e4m3fn + + @pytest.fixture + def quant_policy(self): + from lmdeploy.messages import QuantPolicy + yield QuantPolicy.FP8 + + @pytest.fixture + def blocked_kv(self, batched_kv, seq_lens, history_lens, block_offsets, block_size, num_heads_k, feat_dim, + feat_dim_v, fp8_dtype): + batched_k, batched_v = batched_kv + yield _make_blocked_cache_fp8_scalar(batched_k, batched_v, seq_lens, history_lens, block_offsets, block_size, + num_heads_k, feat_dim, feat_dim_v, fp8_dtype) + + @pytest.fixture + def gt(self, batched_q, blocked_kv, mask): + _, _, _, _, dequant_k, dequant_v = blocked_kv + yield _naive_attention(batched_q, (dequant_k, dequant_v), mask) + + @pytest.mark.parametrize('feat_dim', [48, 32], indirect=True) + @pytest.mark.parametrize('feat_dim_v', [32], indirect=True) + @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(8, 2), (2, 2)], indirect=True) + @pytest.mark.parametrize('history_lens', [(50, 40, 30, 20)], indirect=True) + @pytest.mark.parametrize('block_size', [16], indirect=True) + def test_paged_attention(self, conti_q, blocked_kv, block_offsets, kv_seqlens, conti_gt, quant_policy): + from lmdeploy.pytorch.kernels.cuda import flash_attn_with_kvcache + + blocked_k, blocked_v, k_scale, v_scale, _, _ = blocked_kv + + out = flash_attn_with_kvcache(conti_q, + blocked_k, + blocked_v, + k_scale=k_scale, + v_scale=v_scale, + quant_policy=quant_policy, + page_table=block_offsets, + cache_seqlens=kv_seqlens) + torch.testing.assert_close(out, conti_gt, atol=1e-3, rtol=1e-5) + + +class TestPagedAttentionFP8E5M2Scalar(TestPagedAttentionFP8Scalar): + + @pytest.fixture + def fp8_dtype(self): + yield torch.float8_e5m2 + + @pytest.fixture + def quant_policy(self): + from lmdeploy.messages import QuantPolicy + yield QuantPolicy.FP8_E5M2 + + class 
TestPagedAttentionBlockDecoding(TestPagedAttentionBase): @pytest.fixture diff --git a/tests/test_lmdeploy/test_fp8_kv_cache_policy.py b/tests/test_lmdeploy/test_fp8_kv_cache_policy.py new file mode 100644 index 0000000000..5abbef7a69 --- /dev/null +++ b/tests/test_lmdeploy/test_fp8_kv_cache_policy.py @@ -0,0 +1,67 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +from types import SimpleNamespace + +import pytest +import torch +from pydantic_core import ValidationError + +from lmdeploy.cli.utils import ArgumentHelper +from lmdeploy.messages import PytorchEngineConfig, QuantPolicy, TurbomindEngineConfig +from lmdeploy.pytorch.config import CacheConfig +from lmdeploy.pytorch.engine.cache_engine import ( + CacheDesc, + CacheEngine, + _describe_kv_cache_quant_policy, + _get_fp8_cache_dtype, +) + + +def test_quant_policy_fp8_aliases(): + parser = argparse.ArgumentParser() + ArgumentHelper.quant_policy(parser) + + assert parser.parse_args(['--quant-policy', 'fp8']).quant_policy == QuantPolicy.FP8 + assert parser.parse_args(['--quant-policy', 'fp8_e4m3']).quant_policy == QuantPolicy.FP8 + assert parser.parse_args(['--quant-policy', 'fp8_e5m2']).quant_policy == QuantPolicy.FP8_E5M2 + assert parser.parse_args(['--quant-policy', '17']).quant_policy == QuantPolicy.FP8_E5M2 + + +def test_pytorch_config_accepts_fp8_quant_policies(): + config = PytorchEngineConfig(quant_policy=QuantPolicy.FP8_E5M2) + + assert config.quant_policy == QuantPolicy.FP8_E5M2 + + +def test_turbomind_config_rejects_fp8_e5m2_quant_policy(): + with pytest.raises(ValidationError, match='invalid quant_policy'): + TurbomindEngineConfig(quant_policy=QuantPolicy.FP8_E5M2) + + +def test_fp8_kv_cache_dtype_mapping(): + assert _get_fp8_cache_dtype(QuantPolicy.FP8) is torch.float8_e4m3fn + assert _get_fp8_cache_dtype(QuantPolicy.FP8_E5M2) is torch.float8_e5m2 + + +def test_fp8_kv_cache_log_description(): + assert 'fp8_e4m3' in _describe_kv_cache_quant_policy(QuantPolicy.FP8) + assert 'per-tensor' in _describe_kv_cache_quant_policy(QuantPolicy.FP8) + assert 'torch.float8_e4m3fn' in _describe_kv_cache_quant_policy(QuantPolicy.FP8) + assert 'fp8_e5m2' in _describe_kv_cache_quant_policy(QuantPolicy.FP8_E5M2) + assert 'per-tensor' in _describe_kv_cache_quant_policy(QuantPolicy.FP8_E5M2) + assert 'torch.float8_e5m2' in _describe_kv_cache_quant_policy(QuantPolicy.FP8_E5M2) + assert _describe_kv_cache_quant_policy(QuantPolicy.NONE) is None + + +def test_fp8_quant_cache_descs_are_empty(): + model_config = SimpleNamespace(dtype=torch.float16) + k_desc = CacheDesc(shape=[4, 16, 2, 128], dtype=torch.float8_e4m3fn) + v_desc = CacheDesc(shape=[4, 16, 2, 128], dtype=torch.float8_e4m3fn) + + normal_cache_config = CacheConfig(max_batches=1, + block_size=16, + num_cpu_blocks=0, + num_gpu_blocks=1, + quant_policy=QuantPolicy.FP8) + + assert CacheEngine.get_quant_cache_descs(k_desc, v_desc, model_config, normal_cache_config) == [] diff --git a/tests/test_lmdeploy/test_quant_policy.py b/tests/test_lmdeploy/test_quant_policy.py index 400a18b06c..15f646f237 100644 --- a/tests/test_lmdeploy/test_quant_policy.py +++ b/tests/test_lmdeploy/test_quant_policy.py @@ -18,6 +18,10 @@ MODEL_ID = 'Qwen/Qwen3-8B' +def _e4m3_fp8_requires_sm90(): + return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9 + + # ============================================================================= # Shared Fixtures # ============================================================================= @@ -28,11 +32,12 @@ def model_id(): return MODEL_ID 
-@pytest.fixture(scope='session') +@pytest.fixture(scope='class') def pipe_no_quant(model_id): """Create pipeline without quantization (baseline). - This fixture has session scope to avoid reloading the model for each test. Caller is responsible for cleanup. + This fixture has class scope so large model instances are released before later FP8 accuracy tests allocate their + own pipelines. """ engine_config = PytorchEngineConfig( tp=1, @@ -49,11 +54,12 @@ def pipe_no_quant(model_id): torch.cuda.empty_cache() -@pytest.fixture(scope='session') +@pytest.fixture(scope='class') def pipe_quant_42(model_id): """Create pipeline with quant_policy=QuantPolicy.TURBO_QUANT. - This fixture has session scope to avoid reloading the model for each test. Caller is responsible for cleanup. + This fixture has class scope so large model instances are released before later FP8 accuracy tests allocate their + own pipelines. """ engine_config = PytorchEngineConfig( tp=1, @@ -287,3 +293,229 @@ def test_logprobs_sanity(self, pipe_no_quant, pipe_quant_42): assert isinstance(response_quant.logprobs, list) else: print('\nLogprobs not available (this is expected for some configurations)') + + +# ============================================================================= +# FP8 Tests (QuantPolicy.FP8) +# ============================================================================= + + +@pytest.mark.skipif(_e4m3_fp8_requires_sm90(), reason='Triton float8_e4m3fn conversion requires device with cc>=9.0') +class TestQuantPolicyFP8Basic: + """Basic functional tests for quant_policy=QuantPolicy.FP8.""" + + @pytest.fixture(scope='class') + def pipe(self): + """Create pipeline with quant_policy=QuantPolicy.FP8.""" + engine_config = PytorchEngineConfig( + tp=1, + cache_max_entry_count=0.1, + quant_policy=QuantPolicy.FP8, + ) + pipe = pipeline(MODEL_ID, backend_config=engine_config, log_level='INFO') + yield pipe + pipe.close() + del pipe + gc.collect() + if torch.cuda.is_available() and torch.cuda.device_count() > 0: + torch.cuda.empty_cache() + + def test_infer_single_prompt(self, pipe): + """Test single prompt inference with quant_policy=QuantPolicy.FP8.""" + prompt = 'Hello, how are you?' + response = pipe.infer(prompt, max_new_tokens=30) + + assert isinstance(response, Response) + assert hasattr(response, 'text') + assert len(response.text) > 0 + assert len(response.text.strip()) > 0 + + def test_infer_batch_prompts(self, pipe): + """Test batch inference with quant_policy=QuantPolicy.FP8.""" + prompts = ['What is AI?', 'Hello!'] + responses = pipe.infer(prompts, max_new_tokens=20) + + assert isinstance(responses, list) + assert len(responses) == len(prompts) + for resp in responses: + assert isinstance(resp, Response) + assert len(resp.text) > 0 + + def test_infer_with_generation_config(self, pipe): + """Test inference with GenerationConfig.""" + gen_config = GenerationConfig(max_new_tokens=20, temperature=0.7) + prompt = 'Tell me a short joke' + response = pipe.infer(prompt, gen_config=gen_config) + + assert isinstance(response, Response) + assert len(response.text) > 0 + + +@pytest.mark.skipif(_e4m3_fp8_requires_sm90(), reason='Triton float8_e4m3fn conversion requires device with cc>=9.0') +class TestQuantPolicyFP8Accuracy: + """Accuracy tests comparing quant_policy=QuantPolicy.FP8 against non- + quantized baseline. + + FP8 (float8_e4m3fn, per-tensor scale) is more precise than 4-bit TurboQuant, so + thresholds are tighter: MAE < 0.05, Max AE < 0.3. 
+ + Uses class-scoped fixtures to avoid holding three models in GPU memory simultaneously + when running the full test suite alongside TestQuantPolicy42Accuracy. + """ + + @pytest.fixture(scope='class') + def pipe_no_quant(self): + """Class-scoped baseline pipeline (no quantization).""" + engine_config = PytorchEngineConfig( + tp=1, + cache_max_entry_count=0.05, + quant_policy=QuantPolicy.NONE, + ) + pipe = pipeline(MODEL_ID, backend_config=engine_config, log_level='INFO') + yield pipe + pipe.close() + del pipe + gc.collect() + if torch.cuda.is_available() and torch.cuda.device_count() > 0: + torch.cuda.empty_cache() + + @pytest.fixture(scope='class') + def pipe_quant_fp8(self): + """Class-scoped FP8 pipeline.""" + engine_config = PytorchEngineConfig( + tp=1, + cache_max_entry_count=0.05, + quant_policy=QuantPolicy.FP8, + ) + pipe = pipeline(MODEL_ID, backend_config=engine_config, log_level='INFO') + yield pipe + pipe.close() + del pipe + gc.collect() + if torch.cuda.is_available() and torch.cuda.device_count() > 0: + torch.cuda.empty_cache() + + def test_logits_accuracy(self, pipe_no_quant, pipe_quant_fp8): + """Test logits accuracy by comparing FP8 and non-quantized output + logits.""" + gen_config = GenerationConfig( + max_new_tokens=0, + temperature=0.0, + top_p=1.0, + top_k=1, + output_logits='all', + ) + + prompt = 'The capital of France is' + + response_no_quant = pipe_no_quant.infer(prompt, gen_config=gen_config) + response_quant = pipe_quant_fp8.infer(prompt, gen_config=gen_config) + + assert isinstance(response_no_quant, Response) + assert isinstance(response_quant, Response) + + if response_no_quant.logits is not None and response_quant.logits is not None: + logits_no_quant = response_no_quant.logits + logits_quant = response_quant.logits + + assert logits_no_quant.shape == logits_quant.shape, \ + f'Logits shape mismatch: {logits_no_quant.shape} vs {logits_quant.shape}' + + abs_error = (logits_no_quant - logits_quant).abs() + mean_abs_error = abs_error.mean().item() + max_abs_error = abs_error.max().item() + + print('\nFP8 logits accuracy metrics:') + print(f' Mean absolute error: {mean_abs_error:.6f}') + print(f' Max absolute error: {max_abs_error:.6f}') + + assert mean_abs_error < 0.05, \ + f'Mean absolute error {mean_abs_error:.6f} exceeds threshold 0.05' + assert max_abs_error < 0.3, \ + f'Max absolute error {max_abs_error:.6f} exceeds threshold 0.3' + else: + pytest.skip('Logits not available for comparison') + + def test_token_accuracy(self, pipe_no_quant, pipe_quant_fp8): + """Test token-level accuracy comparing FP8 and non-quantized output.""" + gen_config = GenerationConfig( + max_new_tokens=20, + temperature=0.0, + top_p=1.0, + top_k=1, + ) + + prompt = 'Hello, how are you?' 
+ + response_no_quant = pipe_no_quant.infer(prompt, gen_config=gen_config) + response_quant = pipe_quant_fp8.infer(prompt, gen_config=gen_config) + + assert isinstance(response_no_quant, Response) + assert isinstance(response_quant, Response) + + tokens_no_quant = response_no_quant.token_ids + tokens_quant = response_quant.token_ids + + min_len = min(len(tokens_no_quant), len(tokens_quant)) + if min_len > 0: + matching_tokens = sum(1 for i in range(min_len) if tokens_no_quant[i] == tokens_quant[i]) + match_rate = matching_tokens / min_len + + print('\nFP8 token accuracy metrics:') + print(f' Baseline tokens: {len(tokens_no_quant)}') + print(f' FP8 tokens: {len(tokens_quant)}') + print(f' Matching tokens: {matching_tokens}/{min_len}') + print(f' Match rate: {match_rate:.2%}') + + assert len(tokens_no_quant) > 0, 'Baseline produced no tokens' + assert len(tokens_quant) > 0, 'FP8 model produced no tokens' + else: + pytest.skip('No tokens generated for comparison') + + def test_text_quality(self, pipe_no_quant, pipe_quant_fp8): + """Test that FP8 output is still meaningful text.""" + gen_config = GenerationConfig( + max_new_tokens=30, + temperature=0.7, + top_p=0.9, + ) + + prompt = 'Write a short story about a robot.' + + response_no_quant = pipe_no_quant.infer(prompt, gen_config=gen_config) + response_quant = pipe_quant_fp8.infer(prompt, gen_config=gen_config) + + assert isinstance(response_no_quant, Response) + assert isinstance(response_quant, Response) + + assert len(response_no_quant.text.strip()) > 0, 'Baseline output is empty' + assert len(response_quant.text.strip()) > 0, 'FP8 output is empty' + + print('\nFP8 text quality metrics:') + print(f' Baseline text length: {len(response_no_quant.text)}') + print(f' FP8 text length: {len(response_quant.text)}') + + def test_logprobs_sanity(self, pipe_no_quant, pipe_quant_fp8): + """Test that logprobs are reasonable when available.""" + gen_config = GenerationConfig( + max_new_tokens=10, + temperature=0.0, + top_p=1.0, + top_k=1, + logprobs=1, + ) + + prompt = 'What is 2+2?' + + response_no_quant = pipe_no_quant.infer(prompt, gen_config=gen_config) + response_quant = pipe_quant_fp8.infer(prompt, gen_config=gen_config) + + assert isinstance(response_no_quant, Response) + assert isinstance(response_quant, Response) + + if response_no_quant.logprobs is not None and response_quant.logprobs is not None: + print('\nFP8 logprobs available for both models') + assert isinstance(response_no_quant.logprobs, list) + assert isinstance(response_quant.logprobs, list) + else: + print('\nLogprobs not available (this is expected for some configurations)')
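
Usage note (not part of the patch): a minimal sketch of enabling the per-tensor FP8 KV cache added above through the PyTorch engine, mirroring the pipelines built in tests/test_lmdeploy/test_quant_policy.py. The model id and cache_max_entry_count value are taken from those tests and are only illustrative; fp8_e4m3 additionally assumes an SM90+ GPU, since the Triton float8_e4m3fn conversion requires cc >= 9.0.

    # Sketch: enable the per-tensor FP8 KV cache (QuantPolicy.FP8 == 16).
    # QuantPolicy.FP8_E5M2 (17) selects the float8_e5m2 variant instead.
    # On the CLI, the same policy maps to `--quant-policy fp8` / `fp8_e5m2`.
    from lmdeploy import pipeline
    from lmdeploy.messages import PytorchEngineConfig, QuantPolicy

    engine_config = PytorchEngineConfig(
        tp=1,
        cache_max_entry_count=0.1,
        quant_policy=QuantPolicy.FP8,
    )
    pipe = pipeline('Qwen/Qwen3-8B', backend_config=engine_config)
    response = pipe.infer('Hello, how are you?', max_new_tokens=30)
    print(response.text)
    pipe.close()

When no calibrated k_scale/v_scale buffers are loaded, the engine falls back to k_scale = v_scale = 1.0 and logs the warning added in lmdeploy/pytorch/nn/attention.py.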