Skip to content

Commit 6411b3a

Browse files
authored
Update speculative (Tencent#71)
1 parent 7c5cc3c commit 6411b3a

3 files changed

Lines changed: 17 additions & 41 deletions

File tree

angelslim/compressor/speculative/inference/models/eagle3/draft/llama3_eagle3.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -240,17 +240,13 @@ def __init__(self, config):
240240
self.config = config
241241
self.hidden_size = config.hidden_size
242242
self.num_heads = config.num_attention_heads
243-
self.head_dim = self.hidden_size // self.num_heads
243+
self.head_dim = getattr(
244+
config, "head_dim", config.hidden_size // config.num_attention_heads
245+
)
244246
self.num_key_value_heads = config.num_key_value_heads
245247
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
246248
self.max_position_embeddings = config.max_position_embeddings
247249

248-
if (self.head_dim * self.num_heads) != self.hidden_size:
249-
raise ValueError(
250-
f"hidden_size must be divisible by num_heads "
251-
f"(got `hidden_size`: {self.hidden_size}"
252-
f" and `num_heads`: {self.num_heads})."
253-
)
254250
self.q_proj = nn.Linear(
255251
self.hidden_size * 2, self.num_heads * self.head_dim, bias=False
256252
)
@@ -400,15 +396,8 @@ def forward(
400396
).to(query_states.dtype)
401397
attn_output = torch.matmul(attn_weights, value_states)
402398

403-
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
404-
raise ValueError(
405-
f"`attn_output` should be of size "
406-
f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
407-
f" {attn_output.size()}"
408-
)
409-
410399
attn_output = attn_output.transpose(1, 2).contiguous()
411-
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
400+
attn_output = attn_output.reshape(bsz, q_len, -1)
412401

413402
if self.config.pretraining_tp > 1:
414403
attn_output = attn_output.split(

docs/source/performance/speculative_decoding/benchmarks.md

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,28 +23,13 @@
2323

2424
### Hunyuan Series Models
2525

26-
<table>
27-
<thead>
28-
<tr>
29-
<th>&nbsp</th><th>&nbsp</th>
30-
<th colspan="2" style="text-align: center; vertical-align: middle;">MT-bench</th>
31-
<th colspan="2" style="text-align: center; vertical-align: middle;">HumanEval</th>
32-
<th colspan="2" style="text-align: center; vertical-align: middle;">GSM8K</th>
33-
<th colspan="2" style="text-align: center; vertical-align: middle;">Alpaca</th>
34-
<th colspan="2" style="text-align: center; vertical-align: middle;">Mean</th></tr>
35-
<tr><th>Temperature</th><th>Model</th><th>Speedup</th><th>τ</th><th>Speedup</th><th>τ</th><th>Speedup</th><th>τ</th><th>Speedup</th><th>τ</th><th>Speedup</th><th>τ</th></tr>
36-
</thead>
37-
<tbody>
38-
<!-- <tr><td colspan="12" style="text-align: center; vertical-align: middle;"><strong>Temperature=0</strong></td></tr> -->
39-
<tr><td rowspan="3"><strong>Temperature=0</strong></td>
40-
<td>Hunyuan-1.8B-Instruct</td><td>1.97x</td><td>2.90</td><td>2.58x</td><td>3.73</td><td>2.61x</td><td>3.71</td><td>1.71x</td><td>2.43</td><td>2.22x</td><td>3.19</td></tr>
41-
<tr> <td>Hunyuan-4B-Instruct</td><td>1.77x</td><td>2.60</td><td>2.64x</td><td>3.35</td><td>2.14x</td><td>3.17</td><td>1.72x</td><td>2.57</td><td>2.07x</td><td>2.92</td></tr>
42-
<tr><td>Hunyuan-7B-Instruct</td><td>2.22x</td><td>3.58</td><td>3.59x</td><td>5.47</td><td>2.96x</td><td>4.68</td><td>1.64x</td><td>2.56</td><td>2.60x</td><td>4.07</td></tr>
43-
<!-- <tr><td colspan="12" style="text-align: center; vertical-align: middle;"><strong>Temperature=1</strong></td></tr> -->
44-
<tr><td rowspan="3"><strong>Temperature=1</strong></td>
45-
<td>Hunyuan-1.8B-Instruct</td><td>1.58x</td><td>2.36</td><td>2.35x</td><td>3.56</td><td>2.23x</td><td>3.38</td><td>1.26x</td><td>1.87</td><td>1.86x</td><td>2.79</td></tr>
46-
<tr><td>Hunyuan-4B-Instruct</td><td>1.36x</td><td>2.05</td><td>1.97x</td><td>2.86</td><td>1.72x</td><td>2.68</td><td>1.14x</td><td>1.76</td><td>1.55x</td><td>2.34</td></tr>
47-
<tr><td>Hunyuan-7B-Instruct</td><td>1.90x</td><td>3.11</td><td>3.12x</td><td>5.09</td><td>2.74x</td><td>4.34</td><td>1.47x</td><td>2.39</td><td>2.31x</td><td>3.73</td></tr>
48-
</tbody>
49-
</table>
50-
</table>
26+
| | | MT-bench | | HumanEval | | GSM8K | | Alpaca | | Mean | |
27+
|------------------|--------------|------------------|------------|-------------------|-------------|----------------|---------|----------------|----------|---------------|--------|
28+
| | Model | Speedup | τ | Speedup | τ | Speedup | τ | Speedup | τ | Speedup | τ |
29+
| | Hunyuan-1.8B | 1.97x | 2.90 | 2.58x | 3.73 | 2.61x | 3.71 | 1.71x | 2.43 | 2.22x | 3.19 |
30+
| **Temperature=0**| Hunyuan-4B | 1.77x | 2.60 | 2.64x | 3.35 | 2.14x | 3.17 | 1.72x | 2.57 | 2.07x | 2.92 |
31+
| | Hunyuan-7B | 2.22x | 3.58 | 3.59x | 5.47 | 2.96x | 4.68 | 1.64x | 2.56 | 2.60x | 4.07 |
32+
| | | | | | | | | | | | |
33+
| | Hunyuan-1.8B | 1.58x | 2.36 | 2.35x | 3.56 | 2.23x | 3.38 | 1.26x | 1.87 | 1.86x | 2.79 |
34+
| **Temperature=1**| Hunyuan-1.8B | 1.36x | 2.05 | 1.97x | 2.86 | 1.72x | 2.68 | 1.14x | 1.76 | 1.55x | 2.34 |
35+
| | Hunyuan-1.8B | 1.90x | 3.11 | 3.12x | 5.09 | 2.74x | 4.34 | 1.47x | 2.39 | 2.31x | 3.73 |

requirements.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,6 @@ datasets
1414
fschat
1515
openai
1616
anthropic
17-
ray
17+
ray
18+
referencing
19+
jsonschema_specifications

0 commit comments

Comments
 (0)