Skip to content

Commit 70f3d0e

Browse files
author
Shrey Modi
committed
final
1 parent b57ad2c commit 70f3d0e

5 files changed

Lines changed: 670 additions & 104 deletions

File tree

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
"""
2+
VLLMPolicy - Policy for TRL's VLLMClient
3+
4+
Simple policy that calls TRL's vllm_client directly instead of going through LiteLLM.
5+
Works with `trl vllm-serve` endpoints.
6+
"""
7+
8+
from typing import Any, Dict, List, Optional
9+
10+
11+
class VLLMPolicy:
12+
"""
13+
Policy that uses TRL's VLLMClient for generation.
14+
15+
This is designed to work with `trl vllm-serve` which provides
16+
custom /generate/ and /chat/ endpoints.
17+
"""
18+
19+
def __init__(
20+
self,
21+
vllm_client, # trainer.vllm_client
22+
tokenizer=None, # Optional tokenizer for decoding
23+
temperature: float = 1.0,
24+
max_tokens: int = 100,
25+
top_p: Optional[float] = None,
26+
top_k: Optional[int] = None,
27+
**kwargs,
28+
):
29+
"""
30+
Initialize VLLMPolicy.
31+
32+
Args:
33+
vllm_client: TRL's VLLMClient instance (from trainer.vllm_client)
34+
tokenizer: Optional tokenizer for decoding token IDs to text
35+
temperature: Sampling temperature
36+
max_tokens: Maximum tokens to generate
37+
top_p: Top-p sampling
38+
top_k: Top-k sampling
39+
**kwargs: Additional generation parameters
40+
"""
41+
self.vllm_client = vllm_client
42+
self.tokenizer = tokenizer
43+
self.temperature = temperature
44+
self.max_tokens = max_tokens
45+
self.top_p = top_p if top_p is not None else 1.0
46+
self.top_k = top_k if top_k is not None else -1
47+
self.kwargs = kwargs
48+
49+
async def _make_llm_call(
50+
self,
51+
messages: List[Dict[str, Any]],
52+
tools: Optional[List] = None,
53+
) -> Dict[str, Any]:
54+
"""
55+
Make LLM call using TRL's VLLMClient.
56+
57+
Args:
58+
messages: List of message dicts with 'role' and 'content'
59+
tools: Not used (for compatibility)
60+
61+
Returns:
62+
OpenAI-compatible response dict
63+
"""
64+
# Apply chat template to convert messages to a prompt string
65+
if self.tokenizer is not None:
66+
try:
67+
# Use tokenizer's chat template
68+
prompt_text = self.tokenizer.apply_chat_template(
69+
messages,
70+
add_generation_prompt=True,
71+
tokenize=False,
72+
)
73+
print("\n[VLLMPolicy] ===== CHAT TEMPLATE APPLIED =====", flush=True)
74+
print(f"[VLLMPolicy] Input messages ({len(messages)} messages):", flush=True)
75+
for i, msg in enumerate(messages):
76+
content_preview = str(msg.get("content", ""))[:100]
77+
print(f" [{i}] {msg.get('role', '?')}: {content_preview}...", flush=True)
78+
print(f"[VLLMPolicy] Formatted prompt (length={len(prompt_text)}):", flush=True)
79+
print("[VLLMPolicy] Prompt preview (last 500 chars):", flush=True)
80+
print(f"{prompt_text[-500:]}", flush=True)
81+
print("[VLLMPolicy] ===================================", flush=True)
82+
except Exception as e:
83+
print(f"[VLLMPolicy] Warning: Failed to apply chat template: {e}", flush=True)
84+
# Fallback: simple concatenation
85+
prompt_text = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
86+
else:
87+
# No tokenizer: simple concatenation
88+
prompt_text = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
89+
90+
# Check if vllm_client is VLLMClient (server mode) or LLM (colocate mode)
91+
is_llm_object = hasattr(self.vllm_client, "llm_engine") # LLM has llm_engine
92+
93+
if is_llm_object:
94+
# Colocate mode: use SamplingParams
95+
print("[VLLMPolicy] Using vLLM LLM (colocate mode) with SamplingParams", flush=True)
96+
from vllm import SamplingParams
97+
98+
sampling_params = SamplingParams(
99+
temperature=self.temperature,
100+
max_tokens=self.max_tokens,
101+
top_p=self.top_p,
102+
top_k=self.top_k,
103+
n=1,
104+
)
105+
106+
print("[VLLMPolicy] Calling LLM.generate()...", flush=True)
107+
outputs = self.vllm_client.generate([prompt_text], sampling_params=sampling_params, use_tqdm=False)
108+
109+
# Extract from vLLM output format
110+
output = outputs[0]
111+
prompt_ids = output.prompt_token_ids
112+
completion_ids = output.outputs[0].token_ids
113+
response = {
114+
"prompt_ids": [prompt_ids],
115+
"completion_ids": [completion_ids],
116+
}
117+
else:
118+
# Server mode: use VLLMClient with kwargs
119+
print("[VLLMPolicy] Using VLLMClient (server mode)", flush=True)
120+
vllm_params = {
121+
"temperature": self.temperature,
122+
"max_tokens": self.max_tokens,
123+
"top_p": self.top_p,
124+
"top_k": self.top_k,
125+
"n": 1,
126+
}
127+
vllm_params.update(self.kwargs)
128+
129+
print("[VLLMPolicy] Calling vllm_client.generate()...", flush=True)
130+
response = self.vllm_client.generate(
131+
prompts=[prompt_text],
132+
**vllm_params,
133+
)
134+
135+
# Extract first result
136+
prompt_ids = response["prompt_ids"][0]
137+
completion_ids = response["completion_ids"][0]
138+
139+
# Decode completion text if tokenizer available
140+
if self.tokenizer is not None:
141+
try:
142+
completion_text = self.tokenizer.decode(completion_ids, skip_special_tokens=True)
143+
print("\n[VLLMPolicy] ===== GENERATION RESULT =====", flush=True)
144+
print(f"[VLLMPolicy] Prompt tokens: {len(prompt_ids)}", flush=True)
145+
print(f"[VLLMPolicy] Completion tokens: {len(completion_ids)}", flush=True)
146+
print(f"[VLLMPolicy] FULL decoded completion ({len(completion_text)} chars):", flush=True)
147+
print("───────────────────────────────────────", flush=True)
148+
print(f"{completion_text}", flush=True)
149+
print("───────────────────────────────────────", flush=True)
150+
print("[VLLMPolicy] ==============================", flush=True)
151+
except Exception as e:
152+
print(f"[VLLMPolicy] Warning: Failed to decode completion: {e}", flush=True)
153+
completion_text = f"<decoded_error:{len(completion_ids)}_tokens>"
154+
else:
155+
# Fallback: just indicate number of tokens
156+
completion_text = f"<{len(completion_ids)}_tokens>"
157+
158+
# Convert to OpenAI-compatible format for compatibility with OpenEnvRolloutProcessor
159+
# Also include raw token IDs for TRL integration (avoids double encoding)
160+
return {
161+
"choices": [
162+
{
163+
"message": {
164+
"content": completion_text,
165+
"role": "assistant",
166+
}
167+
}
168+
],
169+
"usage": {
170+
"prompt_tokens": len(prompt_ids),
171+
"completion_tokens": len(completion_ids),
172+
"total_tokens": len(prompt_ids) + len(completion_ids),
173+
},
174+
# Include raw token IDs for TRL (avoids re-encoding)
175+
"prompt_ids": prompt_ids,
176+
"completion_ids": completion_ids,
177+
}

0 commit comments

Comments
 (0)