|
| 1 | +""" |
| 2 | +VLLMPolicy - Policy for TRL's VLLMClient |
| 3 | +
|
| 4 | +Simple policy that calls TRL's vllm_client directly instead of going through LiteLLM. |
| 5 | +Works with `trl vllm-serve` endpoints. |
| 6 | +""" |
| 7 | + |
| 8 | +from typing import Any, Dict, List, Optional |
| 9 | + |
| 10 | + |
class VLLMPolicy:
    """
    Policy that uses TRL's VLLMClient for generation.

    This is designed to work with `trl vllm-serve` which provides
    custom /generate/ and /chat/ endpoints. A client object that exposes an
    ``llm_engine`` attribute is treated as an in-process vLLM ``LLM``
    (colocate mode) and is driven through ``SamplingParams`` instead.
    """

    def __init__(
        self,
        vllm_client,  # trainer.vllm_client
        tokenizer=None,  # Optional tokenizer for decoding
        temperature: float = 1.0,
        max_tokens: int = 100,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        **kwargs,
    ):
        """
        Initialize VLLMPolicy.

        Args:
            vllm_client: TRL's VLLMClient instance (from trainer.vllm_client),
                or an object with an ``llm_engine`` attribute (colocate mode)
            tokenizer: Optional tokenizer used to apply the chat template and
                to decode completion token IDs to text
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            top_p: Top-p sampling; ``None`` is stored as 1.0
            top_k: Top-k sampling; ``None`` is stored as -1
            **kwargs: Additional generation parameters, merged into the
                keyword arguments of ``vllm_client.generate`` in server mode
        """
        self.vllm_client = vllm_client
        self.tokenizer = tokenizer
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.top_p = top_p if top_p is not None else 1.0
        self.top_k = top_k if top_k is not None else -1
        self.kwargs = kwargs

    @staticmethod
    def _fallback_prompt(messages: List[Dict[str, Any]]) -> str:
        """Join messages as ``role: content`` lines when no chat template is usable."""
        lines = []
        for message in messages:
            role = message.get("role", "?")
            content = message.get("content")
            # A missing or None content must not crash (or print "None" into)
            # the prompt; render it as an empty string instead.
            lines.append(f"{role}: {'' if content is None else content}")
        return "\n".join(lines)

    def _build_prompt(self, messages: List[Dict[str, Any]]) -> str:
        """Convert chat messages to a prompt string, preferring the tokenizer's chat template."""
        if self.tokenizer is None:
            # No tokenizer: simple concatenation
            return self._fallback_prompt(messages)

        # Only the template call itself is guarded: a failure in the debug
        # output below must not discard a correctly templated prompt.
        try:
            prompt_text = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=False,
            )
        except Exception as e:
            print(f"[VLLMPolicy] Warning: Failed to apply chat template: {e}", flush=True)
            return self._fallback_prompt(messages)

        print("\n[VLLMPolicy] ===== CHAT TEMPLATE APPLIED =====", flush=True)
        print(f"[VLLMPolicy] Input messages ({len(messages)} messages):", flush=True)
        for i, msg in enumerate(messages):
            content_preview = str(msg.get("content", ""))[:100]
            print(f"  [{i}] {msg.get('role', '?')}: {content_preview}...", flush=True)
        print(f"[VLLMPolicy] Formatted prompt (length={len(prompt_text)}):", flush=True)
        print("[VLLMPolicy] Prompt preview (last 500 chars):", flush=True)
        print(f"{prompt_text[-500:]}", flush=True)
        print("[VLLMPolicy] ===================================", flush=True)
        return prompt_text

    def _generate(self, prompt_text: str):
        """Run one generation and return ``(prompt_ids, completion_ids)`` as lists."""
        # An object with `llm_engine` is treated as a vLLM LLM (colocate mode);
        # anything else is treated as a VLLMClient (server mode).
        if hasattr(self.vllm_client, "llm_engine"):
            print("[VLLMPolicy] Using vLLM LLM (colocate mode) with SamplingParams", flush=True)
            # Imported lazily so server mode does not require vllm locally.
            from vllm import SamplingParams

            # NOTE(review): self.kwargs is only forwarded in server mode;
            # confirm whether colocate mode should honour it as well.
            sampling_params = SamplingParams(
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                top_p=self.top_p,
                top_k=self.top_k,
                n=1,
            )

            print("[VLLMPolicy] Calling LLM.generate()...", flush=True)
            outputs = self.vllm_client.generate(
                [prompt_text], sampling_params=sampling_params, use_tqdm=False
            )

            # Extract from vLLM output format; copy into lists so both modes
            # hand back the same container type.
            output = outputs[0]
            return list(output.prompt_token_ids), list(output.outputs[0].token_ids)

        # Server mode: use VLLMClient with kwargs
        print("[VLLMPolicy] Using VLLMClient (server mode)", flush=True)
        vllm_params = {
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "n": 1,
        }
        # User-supplied kwargs take precedence over the defaults above.
        vllm_params.update(self.kwargs)

        print("[VLLMPolicy] Calling vllm_client.generate()...", flush=True)
        response = self.vllm_client.generate(
            prompts=[prompt_text],
            **vllm_params,
        )

        # Extract first result (a single prompt was sent with n=1 by default)
        return response["prompt_ids"][0], response["completion_ids"][0]

    def _decode_completion(self, prompt_ids, completion_ids) -> str:
        """Decode completion token IDs to text, or return a placeholder when that is not possible."""
        if self.tokenizer is None:
            # Fallback: just indicate number of tokens
            return f"<{len(completion_ids)}_tokens>"

        # Only the decode call is guarded so that a logging problem cannot
        # replace a valid completion with the error placeholder.
        try:
            completion_text = self.tokenizer.decode(completion_ids, skip_special_tokens=True)
        except Exception as e:
            print(f"[VLLMPolicy] Warning: Failed to decode completion: {e}", flush=True)
            return f"<decoded_error:{len(completion_ids)}_tokens>"

        print("\n[VLLMPolicy] ===== GENERATION RESULT =====", flush=True)
        print(f"[VLLMPolicy] Prompt tokens: {len(prompt_ids)}", flush=True)
        print(f"[VLLMPolicy] Completion tokens: {len(completion_ids)}", flush=True)
        print(f"[VLLMPolicy] FULL decoded completion ({len(completion_text)} chars):", flush=True)
        print("───────────────────────────────────────", flush=True)
        print(f"{completion_text}", flush=True)
        print("───────────────────────────────────────", flush=True)
        print("[VLLMPolicy] ==============================", flush=True)
        return completion_text

    async def _make_llm_call(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List] = None,
    ) -> Dict[str, Any]:
        """
        Make LLM call using TRL's VLLMClient.

        Args:
            messages: List of message dicts with 'role' and 'content'
            tools: Not used (for compatibility)

        Returns:
            OpenAI-compatible response dict with ``choices`` and ``usage``,
            plus the raw ``prompt_ids`` and ``completion_ids`` of the call.
        """
        # NOTE: the underlying generate() call is synchronous, so this
        # coroutine does not yield to the event loop while generating.
        prompt_text = self._build_prompt(messages)
        prompt_ids, completion_ids = self._generate(prompt_text)
        completion_text = self._decode_completion(prompt_ids, completion_ids)

        # Convert to OpenAI-compatible format for compatibility with OpenEnvRolloutProcessor
        # Also include raw token IDs for TRL integration (avoids double encoding)
        return {
            "choices": [
                {
                    "message": {
                        "content": completion_text,
                        "role": "assistant",
                    }
                }
            ],
            "usage": {
                "prompt_tokens": len(prompt_ids),
                "completion_tokens": len(completion_ids),
                "total_tokens": len(prompt_ids) + len(completion_ids),
            },
            # Include raw token IDs for TRL (avoids re-encoding)
            "prompt_ids": prompt_ids,
            "completion_ids": completion_ids,
        }
0 commit comments