|
| 1 | +import os |
1 | 2 | import time |
2 | 3 | from typing import List, Optional, Union |
3 | 4 |
|
@@ -108,6 +109,18 @@ def __init__(self, pipeline_config): |
108 | 109 | "Use cache_backend='none' or 'cache_dit' (not 'teacache')." |
109 | 110 | ) |
110 | 111 |
|
| 112 | + # Fixed latent for reproducible benchmarking (e.g. MLPerf). |
| 113 | + # Set TRTLLM_VIDEO_FIXED_LATENT_PATH to a .pt file containing a pre-sampled |
| 114 | + # noise tensor; it will be used in place of freshly sampled random latents for |
| 115 | + # all T2V requests. Loaded once at server startup, reused across requests. |
| 116 | + self._fixed_latent: Optional[torch.Tensor] = None |
| 117 | + _fixed_latent_path = os.environ.get("TRTLLM_VIDEO_FIXED_LATENT_PATH") |
| 118 | + if _fixed_latent_path: |
| 119 | + self._fixed_latent = torch.load(_fixed_latent_path, weights_only=True) |
| 120 | + logger.warning( |
| 121 | + f"Loaded fixed latent from {_fixed_latent_path}, shape={self._fixed_latent.shape}" |
| 122 | + ) |
| 123 | + |
111 | 124 | super().__init__(pipeline_config) |
112 | 125 |
|
113 | 126 | def _compute_wan_timestep_embedding(self, module, timestep=None, **kwargs): |
@@ -486,6 +499,8 @@ def forward( |
486 | 499 | latents, i2v_condition, i2v_first_frame_mask = self._prepare_latents_wan22_5B_i2v( |
487 | 500 | batch_size, image, height, width, num_frames, generator |
488 | 501 | ) |
| 502 | + elif self._fixed_latent is not None: |
| 503 | + latents = self._fixed_latent.to(device=self.device, dtype=self.dtype) |
489 | 504 | else: |
490 | 505 | latents = self._prepare_latents(batch_size, height, width, num_frames, generator) |
491 | 506 | logger.debug(f"Latents shape: {latents.shape}") |
|
0 commit comments