Skip to content

Commit 85d4982

Browse files
committed
fix tests
Signed-off-by: Ye Yu <yeyu@nvidia.com>
1 parent 2fc1837 commit 85d4982

1 file changed

Lines changed: 2 additions & 4 deletions

File tree

tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,7 @@ def _test_speculative_gpt_model(algo, num_layers, activation_func, normalization
104104
("eagle3", 2, "swiglu", "RMSNorm"), # GQA
105105
],
106106
)
107-
def test_speculative_gpt_model(
108-
algo, num_medusa_heads_or_eagle_layers, activation_func, normalization
109-
):
107+
def test_speculative_gpt_model(algo, num_layers, activation_func, normalization):
110108
if algo == "eagle3":
111109
try:
112110
import megatron.core.post_training # noqa: F401
@@ -118,7 +116,7 @@ def test_speculative_gpt_model(
118116
job=partial(
119117
_test_speculative_gpt_model,
120118
algo,
121-
num_medusa_heads_or_eagle_layers,
119+
num_layers,
122120
activation_func,
123121
normalization,
124122
),

0 commit comments

Comments
 (0)