Skip to content

Commit b4b0e6a

Browse files
Fix some remaining references of 'policy' to 'generator'
Differential Revision: D89990890 Pull Request resolved: #692
1 parent d60c5eb commit b4b0e6a

6 files changed

Lines changed: 48 additions & 46 deletions

File tree

apps/grpo/qwen3_1_7b.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ off_by_n: 1 # Off by one by default
1111
compile: true # Enable torch.compile for trainer/ref_model, and CUDA graphs for vLLM
1212

1313
# Main loop configuration
14-
rollout_threads: 1 # Recommended to set equal to policy.num_replicas
14+
rollout_threads: 1 # Recommended to set equal to generator.num_replicas
1515

1616

1717
# Observability configuration

benchmarks/generator/throughput.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,12 @@ async def run_throughput_benchmark(
105105
)
106106

107107
print("Spawning Generator service...")
108-
generator = await Generator.options(**cfg.services.policy).as_service(**cfg.policy)
108+
generator = await Generator.options(**cfg.services.generator).as_service(
109+
**cfg.generator
110+
)
109111

110112
print(f"Generating {num_requests} benchmark requests...")
111-
model_name = cfg.policy.engine_args.get("model", "unknown")
113+
model_name = cfg.generator.engine_args.get("model", "unknown")
112114
tokenizer = get_tokenizer(
113115
model_name,
114116
tokenizer_mode="auto",

docs/source/index.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,12 @@ Before diving in, check out {doc}`getting_started` and ensure your system meets
100100
With TorchForge, your RL logic looks like pseudocode:
101101

102102
```python
103-
async def generate_episode(dataloader, policy, reward, replay_buffer):
103+
async def generate_episode(dataloader, generator, reward, replay_buffer):
104104
# Sample a prompt
105105
prompt, target = await dataloader.sample.route()
106106

107107
# Generate response
108-
response = await policy.generate.route(prompt)
108+
response = await generator.generate.route(prompt)
109109

110110
# Score the response
111111
reward_value = await reward.evaluate_response.route(

docs/source/tutorial_sources/zero-to-forge/1_RL_and_Forge_Fundamentals.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ graph TD
3737
### RL Components Defined (TorchForge Names)
3838

3939
1. **Dataset**: Provides questions/prompts (like "What is 2+2?")
40-
2. **Policy**: The AI being trained (generates answers like "The answer is 4")
40+
2. **Generator**: The policy being trained (generates answers like "The answer is 4")
4141
3. **Reward Model**: Evaluates answer quality (gives scores like 0.95)
4242
4. **Reference Model**: Original policy copy (prevents drift from baseline)
4343
5. **Replay Buffer**: Stores experiences (question + answer + score)
@@ -53,7 +53,7 @@ def conceptual_rl_step():
5353
question = dataset.sample() # "What is 2+2?"
5454

5555
# 2. Student generates answer
56-
answer = policy.generate(question) # "The answer is 4"
56+
answer = generator.generate(question) # "The answer is 4"
5757

5858
# 3. Teacher grades it
5959
score = reward_model.evaluate(question, answer) # 0.95
@@ -289,7 +289,7 @@ async def real_rl_training_step(services, step):
289289

290290
### Automatic Resource Management
291291
```python
292-
responses = await policy.generate.route(prompt=question)
292+
responses = await generator.generate.route(prompt=question)
293293
answer = responses[0].text # responses is list[Completion]
294294
```
295295

@@ -333,7 +333,7 @@ group_size = 1
333333
model=model,
334334
),
335335
# Generator service with GPU
336-
Policy.options(procs=1, with_gpus=True, num_replicas=1).as_service(
336+
Generator.options(procs=1, with_gpus=True, num_replicas=1).as_service(
337337
engine_config={
338338
"model": model,
339339
"tensor_parallel_size": 1,
@@ -381,15 +381,15 @@ TorchForge has two types of distributed components:
381381
- **Actors**: Single instances that handle their own internal distribution (like TitanTrainer, ReplayBuffer)
382382

383383
We cover this distinction in detail in Part 2, but for now this explains the scaling patterns:
384-
- Policy service: num_replicas=8 for high inference demand
384+
- Generator service: num_replicas=8 for high inference demand
385385
- RewardActor service: num_replicas=16 for parallel evaluation
386386
- TitanTrainer actor: Single instance with internal distributed training
387387

388388

389389
### Fault Tolerance
390390
```python
391-
# If a policy replica fails:
392-
responses = await policy.generate.route(prompt=question)
391+
# If a generator replica fails:
392+
responses = await generator.generate.route(prompt=question)
393393
answer = responses[0].text
394394
# -> TorchForge automatically routes to healthy replica
395395
# -> Failed replica respawns in background

tests/integration_tests/test_vllm_policy_correctness.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import pytest
88

9-
from forge.actors.generator import Generator as Policy
9+
from forge.actors.generator import Generator
1010
from vllm import SamplingParams
1111
from vllm.engine.arg_utils import AsyncEngineArgs
1212
from vllm.sampling_params import RequestOutputKind
@@ -30,15 +30,15 @@
3030

3131
@pytest.mark.asyncio
3232
async def test_same_output():
33-
"""Compare outputs between vLLM and Policy service"""
33+
"""Compare outputs between vLLM and Generator service"""
3434
test_prompts = [
3535
"Hello, how are you?",
3636
"What is 2+2?",
3737
"Tell me a joke.",
3838
"Explain machine learning briefly.",
3939
"What color is the sky?",
4040
]
41-
policy = None
41+
generator = None
4242
try:
4343
# Setup vLLM directly
4444
args = AsyncEngineArgs(
@@ -50,8 +50,8 @@ async def test_same_output():
5050
)
5151
vllm_model = AsyncLLM.from_engine_args(args)
5252

53-
# Setup Policy service
54-
policy = await Policy.options(
53+
# Setup Generator service
54+
generator = await Generator.options(
5555
procs=1, num_replicas=1, with_gpus=True
5656
).as_service(
5757
engine_args={
@@ -72,7 +72,7 @@ async def test_same_output():
7272

7373
print("Models ready. Generating outputs...\n")
7474
vllm_outputs = []
75-
policy_outputs = []
75+
generator_outputs = []
7676
sampling_params = SamplingParams(
7777
max_tokens=MAX_TOKENS,
7878
temperature=TEMPERATURE,
@@ -89,19 +89,19 @@ async def test_same_output():
8989
vllm_outputs.append(res.outputs[0].text)
9090

9191
# Generator service generation
92-
policy_result = await policy.generate.route(prompt)
92+
policy_result = await generator.generate.route(prompt)
9393
policy_text = policy_result[0].text
94-
policy_outputs.append(policy_text)
94+
generator_outputs.append(policy_text)
9595

9696
# Final check
97-
for vllm_output, policy_output in zip(vllm_outputs, policy_outputs):
97+
for vllm_output, generator_output in zip(vllm_outputs, generator_outputs):
9898
assert vllm_output != ""
99-
assert policy_output != ""
100-
assert vllm_output == policy_output
99+
assert generator_output != ""
100+
assert vllm_output == generator_output
101101

102102
finally:
103-
if policy is not None:
104-
await policy.shutdown()
103+
if generator is not None:
104+
await generator.shutdown()
105105

106106

107107
@pytest.mark.asyncio
@@ -126,7 +126,7 @@ async def test_cache_usage():
126126
via the AsyncLLM interface.
127127
- We do not test different block sizes.
128128
"""
129-
policy = None
129+
generator = None
130130
try:
131131
# Setup vLLM directly
132132
args = AsyncEngineArgs(
@@ -139,8 +139,8 @@ async def test_cache_usage():
139139
)
140140
vllm_model = AsyncLLM.from_engine_args(args)
141141

142-
# Setup Policy service
143-
policy = await Policy.options(
142+
# Setup Generator service
143+
generator = await Generator.options(
144144
procs=1, num_replicas=1, with_gpus=True
145145
).as_service(
146146
engine_args={
@@ -170,7 +170,7 @@ async def test_cache_usage():
170170
output_kind=RequestOutputKind.FINAL_ONLY,
171171
)
172172
vllm_outputs = []
173-
policy_outputs = []
173+
generator_outputs = []
174174

175175
# Exactly 16 tokens to fill up 1 block
176176
first_prompt = (
@@ -182,9 +182,9 @@ async def test_cache_usage():
182182
):
183183
vllm_outputs.append(res.outputs[0].text)
184184
assert res.num_cached_tokens == expected_cached_tokens
185-
res = await policy.generate.route(first_prompt)
185+
res = await generator.generate.route(first_prompt)
186186
assert res[0].metadata["num_cached_tokens"] == expected_cached_tokens
187-
policy_outputs.append(res[0].text)
187+
generator_outputs.append(res[0].text)
188188

189189
# Another 16 tokens to now populate 2 blocks (+ reuse the first block)
190190
second_prompt = (
@@ -197,9 +197,9 @@ async def test_cache_usage():
197197
):
198198
vllm_outputs.append(res.outputs[0].text)
199199
assert res.num_cached_tokens == expected_cached_tokens
200-
res = await policy.generate.route(second_prompt)
200+
res = await generator.generate.route(second_prompt)
201201
assert res[0].metadata["num_cached_tokens"] == expected_cached_tokens
202-
policy_outputs.append(res[0].text)
202+
generator_outputs.append(res[0].text)
203203

204204
# The first same 32 tokens should now be populated in blocks
205205
third_prompt = second_prompt
@@ -209,13 +209,13 @@ async def test_cache_usage():
209209
):
210210
vllm_outputs.append(res.outputs[0].text)
211211
assert res.num_cached_tokens == expected_cached_tokens
212-
res = await policy.generate.route(third_prompt)
212+
res = await generator.generate.route(third_prompt)
213213
assert res[0].metadata["num_cached_tokens"] == expected_cached_tokens
214-
policy_outputs.append(res[0].text)
214+
generator_outputs.append(res[0].text)
215215

216216
# Now, let's clear the cache
217217
await vllm_model.reset_prefix_cache()
218-
await policy._reset_prefix_cache.route()
218+
await generator._reset_prefix_cache.route()
219219

220220
# And try the third prompt again (should not use any cached tokens)
221221
expected_cached_tokens = 0
@@ -224,16 +224,16 @@ async def test_cache_usage():
224224
):
225225
vllm_outputs.append(res.outputs[0].text)
226226
assert res.num_cached_tokens == expected_cached_tokens
227-
res = await policy.generate.route(third_prompt)
227+
res = await generator.generate.route(third_prompt)
228228
assert res[0].metadata["num_cached_tokens"] == expected_cached_tokens
229-
policy_outputs.append(res[0].text)
229+
generator_outputs.append(res[0].text)
230230

231231
# Sanity check that outputs are still the same
232-
for vllm_output, policy_output in zip(vllm_outputs, policy_outputs):
232+
for vllm_output, generator_output in zip(vllm_outputs, generator_outputs):
233233
assert vllm_output != ""
234-
assert policy_output != ""
235-
assert vllm_output == policy_output
234+
assert generator_output != ""
235+
assert vllm_output == generator_output
236236

237237
finally:
238-
if policy is not None:
239-
await policy.shutdown()
238+
if generator is not None:
239+
await generator.shutdown()

tests/sandbox/weight_sync/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,14 @@ async def main(cfg: DictConfig):
130130
print("Initializing trainer and generator...")
131131
init_start = time.time()
132132

133-
trainer, policy = await asyncio.gather(
133+
trainer, generator = await asyncio.gather(
134134
RLTrainer.options(**cfg.actors.trainer).as_actor(
135135
**cfg.trainer,
136136
loss=lambda *args, **kwargs: torch.tensor(
137137
1.0, requires_grad=True, device="cuda"
138138
),
139139
),
140-
Generator.options(**cfg.actors.policy).as_actor(**cfg.policy),
140+
Generator.options(**cfg.actors.generator).as_actor(**cfg.generator),
141141
)
142142

143143
init_time = time.time() - init_start
@@ -172,7 +172,7 @@ async def main(cfg: DictConfig):
172172
print("Updating generator weights from store...")
173173
update_start = time.time()
174174

175-
await policy.update_weights.call(version=1)
175+
await generator.update_weights.call(version=1)
176176

177177
update_time = time.time() - update_start
178178
print(f"Updated generator weights ({update_time:.2f}s)\n")

0 commit comments

Comments
 (0)