diff --git a/angelslim/compressor/speculative/inference/models/eagle3/draft/llama3_eagle3.py b/angelslim/compressor/speculative/inference/models/eagle3/draft/llama3_eagle3.py index e77c79ec..2d513619 100644 --- a/angelslim/compressor/speculative/inference/models/eagle3/draft/llama3_eagle3.py +++ b/angelslim/compressor/speculative/inference/models/eagle3/draft/llama3_eagle3.py @@ -240,17 +240,13 @@ def __init__(self, config): self.config = config self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads " - f"(got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) self.q_proj = nn.Linear( self.hidden_size * 2, self.num_heads * self.head_dim, bias=False ) @@ -400,15 +396,8 @@ def forward( ).to(query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size " - f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.reshape(bsz, q_len, -1) if self.config.pretraining_tp > 1: attn_output = attn_output.split( diff --git a/docs/source/performance/speculative_decoding/benchmarks.md b/docs/source/performance/speculative_decoding/benchmarks.md index 63cd82c1..f73c2a4f 100644 --- a/docs/source/performance/speculative_decoding/benchmarks.md +++ b/docs/source/performance/speculative_decoding/benchmarks.md @@ -23,28 +23,13 @@ ### Hunyuan Series Models -
|   |   | MT-bench |   | HumanEval |   | GSM8K |   | Alpaca |   | Mean |   |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Temperature | Model | Speedup | τ | Speedup | τ | Speedup | τ | Speedup | τ | Speedup | τ |
| Temperature=0 | Hunyuan-1.8B-Instruct | 1.97x | 2.90 | 2.58x | 3.73 | 2.61x | 3.71 | 1.71x | 2.43 | 2.22x | 3.19 |
|   | Hunyuan-4B-Instruct | 1.77x | 2.60 | 2.64x | 3.35 | 2.14x | 3.17 | 1.72x | 2.57 | 2.07x | 2.92 |
|   | Hunyuan-7B-Instruct | 2.22x | 3.58 | 3.59x | 5.47 | 2.96x | 4.68 | 1.64x | 2.56 | 2.60x | 4.07 |
| Temperature=1 | Hunyuan-1.8B-Instruct | 1.58x | 2.36 | 2.35x | 3.56 | 2.23x | 3.38 | 1.26x | 1.87 | 1.86x | 2.79 |
|   | Hunyuan-4B-Instruct | 1.36x | 2.05 | 1.97x | 2.86 | 1.72x | 2.68 | 1.14x | 1.76 | 1.55x | 2.34 |
|   | Hunyuan-7B-Instruct | 1.90x | 3.11 | 3.12x | 5.09 | 2.74x | 4.34 | 1.47x | 2.39 | 2.31x | 3.73 |