diff --git a/src/models/dit_3b/modulation.py b/src/models/dit_3b/modulation.py index 854ae09a..96f89f97 100644 --- a/src/models/dit_3b/modulation.py +++ b/src/models/dit_3b/modulation.py @@ -19,6 +19,7 @@ from ...common.cache import Cache from ...common.distributed.ops import slice_inputs +from ...optimization.compatibility import portable_repeat_interleave # (dim: int, emb_dim: int) ada_layer_type = Callable[[int, int], nn.Module] @@ -80,7 +81,7 @@ def forward( emb = cache( f"emb_repeat_{idx}_{branch_tag}", lambda: slice_inputs( - torch.repeat_interleave(emb, hid_len, dim=0), + portable_repeat_interleave(emb, hid_len, dim=0), dim=0, ), ) diff --git a/src/models/dit_3b/nablocks/attention/mmattn.py b/src/models/dit_3b/nablocks/attention/mmattn.py index a311449f..c99bfa5e 100644 --- a/src/models/dit_3b/nablocks/attention/mmattn.py +++ b/src/models/dit_3b/nablocks/attention/mmattn.py @@ -22,6 +22,7 @@ from .....common.cache import Cache from .....common.distributed.ops import gather_heads_scatter_seq, gather_seq_scatter_heads_qkv from .....common.half_precision_fixes import safe_pad_operation +from .....optimization.compatibility import portable_repeat_interleave from ... import na from ...attention import FlashAttentionVarlen @@ -210,7 +211,7 @@ def make_window(x: torch.Tensor): txt_len = cache("txt_len", lambda: txt_shape.prod(-1)) vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1)) - txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count)) + txt_len_win = cache_win("txt_len", lambda: portable_repeat_interleave(txt_len, window_count, dim=0)) all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win) concat_win, unconcat_win = cache_win( "mm_pnp", lambda: na.repeat_concat_idx(vid_len_win, txt_len, window_count) diff --git a/src/models/dit_7b/modulation.py b/src/models/dit_7b/modulation.py index 38fff07e..af71727f 100644 --- a/src/models/dit_7b/modulation.py +++ b/src/models/dit_7b/modulation.py @@ -19,6 +19,7 @@ from ...common.cache import Cache from ...common.distributed.ops import slice_inputs +from ...optimization.compatibility import portable_repeat_interleave # (dim: int, emb_dim: int) ada_layer_type = Callable[[int, int], nn.Module] @@ -75,7 +76,7 @@ def forward( emb = cache( f"emb_repeat_{idx}_{branch_tag}", lambda: slice_inputs( - torch.repeat_interleave(emb, hid_len, dim=0), + portable_repeat_interleave(emb, hid_len, dim=0), dim=0, ), ) diff --git a/src/models/dit_7b/nablocks/mmsr_block.py b/src/models/dit_7b/nablocks/mmsr_block.py index fe9010ac..4209723e 100644 --- a/src/models/dit_7b/nablocks/mmsr_block.py +++ b/src/models/dit_7b/nablocks/mmsr_block.py @@ -20,6 +20,7 @@ # from ..cache import Cache from ....common.cache import Cache from ....common.distributed.ops import gather_heads_scatter_seq, gather_seq_scatter_heads_qkv +from ....optimization.compatibility import portable_repeat_interleave from .. import na from ..attention import FlashAttentionVarlen @@ -117,7 +118,7 @@ def make_window(x: torch.Tensor): txt_len = cache("txt_len", lambda: txt_shape.prod(-1)) vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1)) - txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count)) + txt_len_win = cache_win("txt_len", lambda: portable_repeat_interleave(txt_len, window_count, dim=0)) all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win) concat_win, unconcat_win = cache_win( "mm_pnp", lambda: na.repeat_concat_idx(vid_len_win, txt_len, window_count) diff --git a/src/optimization/__init__.py b/src/optimization/__init__.py index e69de29b..ee763d13 100644 --- a/src/optimization/__init__.py +++ b/src/optimization/__init__.py @@ -0,0 +1,14 @@ +""" +Optimization utilities for SeedVR2 Video Upscaler. + +Exports: + IS_ROCM: Boolean indicating if running on AMD ROCm/HIP backend + portable_repeat_interleave: Cross-platform repeat_interleave that works on ROCm +""" + +from .compatibility import IS_ROCM, portable_repeat_interleave + +__all__ = [ + "IS_ROCM", + "portable_repeat_interleave", +] diff --git a/src/optimization/compatibility.py b/src/optimization/compatibility.py index c462022b..52c5e9cd 100644 --- a/src/optimization/compatibility.py +++ b/src/optimization/compatibility.py @@ -640,6 +640,129 @@ def _check_conv3d_memory_bug(): NVIDIA_CONV3D_MEMORY_BUG_WORKAROUND = _check_conv3d_memory_bug() +# 5. AMD ROCm/HIP Detection and Portable Operations +def _check_is_rocm() -> bool: + """Check if running on AMD ROCm/HIP backend.""" + return hasattr(torch.version, 'hip') and torch.version.hip is not None + +IS_ROCM = _check_is_rocm() + + +def portable_repeat_interleave( + input: torch.Tensor, + repeats: "Union[int, torch.Tensor]", + dim: int = 0, + debug: "Optional[Debug]" = None, +) -> torch.Tensor: + """ + Cross-platform replacement for torch.repeat_interleave. + + torch.repeat_interleave can trigger hipErrorIllegalState on AMD ROCm + due to HIP kernel bugs. This function uses torch.arange + torch.index_select + as a portable alternative on ROCm, while using the native (faster) + implementation on CUDA/CPU. + + Supports both: + - Scalar repeats: repeat each element the same number of times + - Tensor repeats: different repeat count per element + + Args: + input: Input tensor to repeat elements from + repeats: Int or 1D LongTensor of repeat counts + dim: Dimension along which to repeat (default: 0) + debug: Optional Debug instance for logging tensor diagnostics + + Returns: + Tensor with elements repeated along dim according to repeats + + Example (scalar): + >>> x = torch.tensor([1, 2, 3]) + >>> portable_repeat_interleave(x, 2, dim=0) + tensor([1, 1, 2, 2, 3, 3]) + + Example (tensor): + >>> x = torch.tensor([[1, 2], [3, 4], [5, 6]]) + >>> repeats = torch.tensor([2, 1, 3]) + >>> portable_repeat_interleave(x, repeats, dim=0) + tensor([[1, 2], [1, 2], [3, 4], [5, 6], [5, 6], [5, 6]]) + """ + # Handle scalar repeats (int or 0-dim tensor) + is_scalar_repeat = isinstance(repeats, int) or ( + isinstance(repeats, torch.Tensor) and repeats.dim() == 0 + ) + + # Debug logging for tensor state diagnostics + if debug is not None: + def _tensor_info(name: str, t: torch.Tensor) -> str: + return (f"{name}: device={t.device}, dtype={t.dtype}, " + f"shape={list(t.shape)}, contiguous={t.is_contiguous()}") + repeats_info = str(repeats) if is_scalar_repeat else _tensor_info('repeats', repeats) + debug.log( + f"portable_repeat_interleave called:\n" + f" {_tensor_info('input', input)}\n" + f" repeats={repeats_info} (scalar={is_scalar_repeat})\n" + f" dim={dim}, IS_ROCM={IS_ROCM}", + level="DEBUG", + category="rocm_compat" + ) + + if not IS_ROCM: + # Use native implementation on CUDA/CPU - faster when it works + return torch.repeat_interleave(input, repeats, dim=dim) + + # ROCm-safe implementation using index expansion + # This avoids the buggy HIP kernel in torch.repeat_interleave + + if is_scalar_repeat: + # Scalar repeat: each element repeated same number of times + # Build indices [0,0,1,1,2,2,...] for repeat=2 + repeat_count = int(repeats) + num_elements = input.shape[dim] + # Create indices: [0,1,2,...] then repeat each + indices = torch.arange(num_elements, device=input.device, dtype=torch.long) + # Use repeat + reshape to expand: [0,1,2] -> [[0,0],[1,1],[2,2]] -> [0,0,1,1,2,2] + expanded_indices = indices.unsqueeze(1).expand(-1, repeat_count).reshape(-1) + else: + # Tensor repeat: different count per element + # Strategy: Build index tensor mapping each output position to source position + # For repeats=[2, 1, 3], build indices=[0, 0, 1, 2, 2, 2] + + # Ensure repeats is on same device as input + if repeats.device != input.device: + repeats = repeats.to(input.device) + + num_elements = repeats.shape[0] + indices = torch.arange(num_elements, device=input.device, dtype=torch.long) + + # Try repeat_interleave on indices first (simpler 1D case may work) + try: + expanded_indices = torch.repeat_interleave(indices, repeats) + except RuntimeError as e: + # Ultimate fallback: build indices manually (slower but always works) + if debug is not None: + debug.log( + f"repeat_interleave on indices failed, using manual fallback: {e}", + level="WARNING", + category="rocm_compat" + ) + indices_list = [] + repeats_cpu = repeats.cpu().tolist() + for i, r in enumerate(repeats_cpu): + indices_list.extend([i] * int(r)) + expanded_indices = torch.tensor(indices_list, device=input.device, dtype=torch.long) + + result = torch.index_select(input, dim, expanded_indices) + + if debug is not None: + debug.log( + f"portable_repeat_interleave result: {_tensor_info('output', result)}", + level="DEBUG", + category="rocm_compat" + ) + + return result + + # Log all optimization status once globally (cross-process) using environment variable if not os.environ.get("SEEDVR2_OPTIMIZATIONS_LOGGED"): os.environ["SEEDVR2_OPTIMIZATIONS_LOGGED"] = "1"