66
77import pytest
88
9- from forge .actors .generator import Generator as Policy
9+ from forge .actors .generator import Generator
1010from vllm import SamplingParams
1111from vllm .engine .arg_utils import AsyncEngineArgs
1212from vllm .sampling_params import RequestOutputKind
3030
3131@pytest .mark .asyncio
3232async def test_same_output ():
33- """Compare outputs between vLLM and Policy service"""
33+ """Compare outputs between vLLM and Generator service"""
3434 test_prompts = [
3535 "Hello, how are you?" ,
3636 "What is 2+2?" ,
3737 "Tell me a joke." ,
3838 "Explain machine learning briefly." ,
3939 "What color is the sky?" ,
4040 ]
41- policy = None
41+ generator = None
4242 try :
4343 # Setup vLLM directly
4444 args = AsyncEngineArgs (
@@ -50,8 +50,8 @@ async def test_same_output():
5050 )
5151 vllm_model = AsyncLLM .from_engine_args (args )
5252
53- # Setup Policy service
54- policy = await Policy .options (
53+ # Setup Generator service
54+ generator = await Generator .options (
5555 procs = 1 , num_replicas = 1 , with_gpus = True
5656 ).as_service (
5757 engine_args = {
@@ -72,7 +72,7 @@ async def test_same_output():
7272
7373 print ("Models ready. Generating outputs...\n " )
7474 vllm_outputs = []
75- policy_outputs = []
75+ generator_outputs = []
7676 sampling_params = SamplingParams (
7777 max_tokens = MAX_TOKENS ,
7878 temperature = TEMPERATURE ,
@@ -89,19 +89,19 @@ async def test_same_output():
8989 vllm_outputs .append (res .outputs [0 ].text )
9090
9191 # Policy generation
92- policy_result = await policy .generate .route (prompt )
92+ policy_result = await generator .generate .route (prompt )
9393 policy_text = policy_result [0 ].text
94- policy_outputs .append (policy_text )
94+ generator_outputs .append (policy_text )
9595
9696 # Final check
97- for vllm_output , policy_output in zip (vllm_outputs , policy_outputs ):
97+ for vllm_output , generator_output in zip (vllm_outputs , generator_outputs ):
9898 assert vllm_output != ""
99- assert policy_output != ""
100- assert vllm_output == policy_output
99+ assert generator_output != ""
100+ assert vllm_output == generator_output
101101
102102 finally :
103- if policy is not None :
104- await policy .shutdown ()
103+ if generator is not None :
104+ await generator .shutdown ()
105105
106106
107107@pytest .mark .asyncio
@@ -126,7 +126,7 @@ async def test_cache_usage():
126126 via the AsyncLLM interface.
127127 - We do not test different different block sizes.
128128 """
129- policy = None
129+ generator = None
130130 try :
131131 # Setup vLLM directly
132132 args = AsyncEngineArgs (
@@ -139,8 +139,8 @@ async def test_cache_usage():
139139 )
140140 vllm_model = AsyncLLM .from_engine_args (args )
141141
142- # Setup Policy service
143- policy = await Policy .options (
142+ # Setup Generator service
143+ generator = await Generator .options (
144144 procs = 1 , num_replicas = 1 , with_gpus = True
145145 ).as_service (
146146 engine_args = {
@@ -170,7 +170,7 @@ async def test_cache_usage():
170170 output_kind = RequestOutputKind .FINAL_ONLY ,
171171 )
172172 vllm_outputs = []
173- policy_outputs = []
173+ generator_outputs = []
174174
175175 # Exactly 16 tokens to fill up 1 block
176176 first_prompt = (
@@ -182,9 +182,9 @@ async def test_cache_usage():
182182 ):
183183 vllm_outputs .append (res .outputs [0 ].text )
184184 assert res .num_cached_tokens == expected_cached_tokens
185- res = await policy .generate .route (first_prompt )
185+ res = await generator .generate .route (first_prompt )
186186 assert res [0 ].metadata ["num_cached_tokens" ] == expected_cached_tokens
187- policy_outputs .append (res [0 ].text )
187+ generator_outputs .append (res [0 ].text )
188188
189189 # Another 16 tokens to now populate 2 blocks (+ reuse the first block)
190190 second_prompt = (
@@ -197,9 +197,9 @@ async def test_cache_usage():
197197 ):
198198 vllm_outputs .append (res .outputs [0 ].text )
199199 assert res .num_cached_tokens == expected_cached_tokens
200- res = await policy .generate .route (second_prompt )
200+ res = await generator .generate .route (second_prompt )
201201 assert res [0 ].metadata ["num_cached_tokens" ] == expected_cached_tokens
202- policy_outputs .append (res [0 ].text )
202+ generator_outputs .append (res [0 ].text )
203203
204204 # The first same 32 tokens should now be populated in blocks
205205 third_prompt = second_prompt
@@ -209,13 +209,13 @@ async def test_cache_usage():
209209 ):
210210 vllm_outputs .append (res .outputs [0 ].text )
211211 assert res .num_cached_tokens == expected_cached_tokens
212- res = await policy .generate .route (third_prompt )
212+ res = await generator .generate .route (third_prompt )
213213 assert res [0 ].metadata ["num_cached_tokens" ] == expected_cached_tokens
214- policy_outputs .append (res [0 ].text )
214+ generator_outputs .append (res [0 ].text )
215215
216216 # Now, let's clear the cache
217217 await vllm_model .reset_prefix_cache ()
218- await policy ._reset_prefix_cache .route ()
218+ await generator ._reset_prefix_cache .route ()
219219
220220 # And try the third prompt again (should not use any cached tokens)
221221 expected_cached_tokens = 0
@@ -224,16 +224,16 @@ async def test_cache_usage():
224224 ):
225225 vllm_outputs .append (res .outputs [0 ].text )
226226 assert res .num_cached_tokens == expected_cached_tokens
227- res = await policy .generate .route (third_prompt )
227+ res = await generator .generate .route (third_prompt )
228228 assert res [0 ].metadata ["num_cached_tokens" ] == expected_cached_tokens
229- policy_outputs .append (res [0 ].text )
229+ generator_outputs .append (res [0 ].text )
230230
231231 # Sanity check that outputs are still the same
232- for vllm_output , policy_output in zip (vllm_outputs , policy_outputs ):
232+ for vllm_output , generator_output in zip (vllm_outputs , generator_outputs ):
233233 assert vllm_output != ""
234- assert policy_output != ""
235- assert vllm_output == policy_output
234+ assert generator_output != ""
235+ assert vllm_output == generator_output
236236
237237 finally :
238- if policy is not None :
239- await policy .shutdown ()
238+ if generator is not None :
239+ await generator .shutdown ()
0 commit comments