diff --git a/tensorrt_llm/_torch/visual_gen/pipeline.py b/tensorrt_llm/_torch/visual_gen/pipeline.py index ec35e53bd95..9f1d0452e26 100644 --- a/tensorrt_llm/_torch/visual_gen/pipeline.py +++ b/tensorrt_llm/_torch/visual_gen/pipeline.py @@ -170,16 +170,19 @@ def _cuda_profiler_stop(self): logger.info("CUDA profiler stopped") def _setup_cuda_graphs(self): - """Wrap all transformer components with CUDA graph capture/replay.""" + """Wrap all transformer components with CUDA graph capture/replay. + + Composes with torch.compile: the runner wraps the (outer) transformer + ``forward`` while torch.compile compiles the inner transformer blocks + (see ``torch_compile``). Graph capture happens during warmup, by which + point the runner's own ``WARMUP_STEPS`` eager iterations have already + triggered torch.compile's lazy compilation, so the captured graph + contains the optimized compiled kernels. (The ``LTX2Pipeline`` override + relies on the same ordering.) + """ if not self.pipeline_config.cuda_graph.enable: return - if self.pipeline_config.torch_compile.enable: - logger.warning( - "CUDA graphs with torch.compile not yet supported. Using torch.compile only." - ) - return - if len(self.transformer_components) > 1: logger.info( "CUDA graph runner: multiple transformer components, using shared graph pool" @@ -188,6 +191,7 @@ def _setup_cuda_graphs(self): else: shared_pool = None + compile_note = " (with torch.compile)" if self.pipeline_config.torch_compile.enable else "" for name in self.transformer_components: model = getattr(self, name, None) if model is None: @@ -198,7 +202,7 @@ def _setup_cuda_graphs(self): shared_pool, ) model.register_cuda_graph_extra_key_fns(runner) - logger.info(f"CUDA graph runner: wrapping {name}.forward") + logger.info(f"CUDA graph runner: wrapping {name}.forward{compile_note}") model.forward = runner.wrap(model.forward) self._cuda_graph_runners[name] = runner diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 799077d8fe9..d5436def024 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -458,6 +458,7 @@ unittest/_torch/multi_gpu/test_user_buffers.py::test_user_buffers_pass[2-fp16-_t unittest/_torch/ray_orchestrator/multi_gpu/test_llm_update_weights_multi_gpu.py::test_llm_partial_update_weights_nvfp4[auto-Qwen3/Qwen3-8B] SKIP (https://nvbugs/6372690) unittest/_torch/ray_orchestrator/multi_gpu/test_llm_update_weights_multi_gpu.py::test_llm_partial_update_weights_nvfp4[fp8-Qwen3/Qwen3-30B-A3B] SKIP (https://nvbugs/6372690) unittest/_torch/ray_orchestrator/multi_gpu/test_llm_update_weights_multi_gpu.py::test_llm_partial_update_weights_nvfp4[fp8-Qwen3/Qwen3-8B] SKIP (https://nvbugs/6372690) +unittest/_torch/sampler/test_beam_search_speculative_d2h.py::test_speculative_d2h_predictor_hit_is_sync_free SKIP (https://nvbugs/6378901) unittest/_torch/thop/parallel/test_fp8_rowwise_linear.py::test_fp8_rowwise_linear[dtype1] SKIP (https://nvbugs/6301807) unittest/_torch/thop/serial/test_moe.py::TestMoeFp4::test_no_autotune[use_score_as_input-RoutingDSv3-swiglu-1024-1024-1] SKIP (https://nvbugs/5908070) unittest/_torch/thop/serial/test_moe.py::TestMoeFp4::test_no_autotune[use_score_as_input-RoutingRenormalize_qwen_next-swiglu-1024-1024-150] SKIP (https://nvbugs/5908070)