File tree Expand file tree Collapse file tree
examples/speculative_decoding/collect_hidden_states Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -85,6 +85,17 @@ def parse_args() -> argparse.Namespace:
8585 action = "store_true" ,
8686 help = "Set trust_remote_code for Huggingface models and tokenizers" ,
8787 )
88+ parser .add_argument (
89+ "--gpu-memory-util" ,
90+ type = float ,
91+ default = None ,
92+ help = "Override vLLM's default gpu_memory_utilization. Lower this on shared GPUs." ,
93+ )
94+ parser .add_argument (
95+ "--enforce-eager" ,
96+ action = "store_true" ,
97+ help = "Disable CUDA graph capture in vLLM. Faster startup, lower throughput." ,
98+ )
8899 add_aux_layers_args (parser )
89100 add_answer_only_loss_args (parser )
90101 return parser .parse_args ()
@@ -188,6 +199,11 @@ def keep_conversation(entry):
188199 return
189200
190201 with tempfile .TemporaryDirectory () as tmpdir :
202+ llm_kwargs = {}
203+ if args .gpu_memory_util is not None :
204+ llm_kwargs ["gpu_memory_utilization" ] = args .gpu_memory_util
205+ if args .enforce_eager :
206+ llm_kwargs ["enforce_eager" ] = True
191207 llm = LLM (
192208 model = args .model ,
193209 speculative_config = {
@@ -208,6 +224,7 @@ def keep_conversation(entry):
208224 },
209225 tensor_parallel_size = args .tp ,
210226 trust_remote_code = args .trust_remote_code ,
227+ ** llm_kwargs ,
211228 )
212229
213230 sampling_params = SamplingParams (max_tokens = 1 )
You can’t perform that action at this time.
0 commit comments