Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/models/dit_3b/modulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from ...common.cache import Cache
from ...common.distributed.ops import slice_inputs
from ...optimization.compatibility import portable_repeat_interleave

# (dim: int, emb_dim: int)
ada_layer_type = Callable[[int, int], nn.Module]
Expand Down Expand Up @@ -80,7 +81,7 @@ def forward(
emb = cache(
f"emb_repeat_{idx}_{branch_tag}",
lambda: slice_inputs(
torch.repeat_interleave(emb, hid_len, dim=0),
portable_repeat_interleave(emb, hid_len, dim=0),
dim=0,
),
)
Expand Down
3 changes: 2 additions & 1 deletion src/models/dit_3b/nablocks/attention/mmattn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .....common.cache import Cache
from .....common.distributed.ops import gather_heads_scatter_seq, gather_seq_scatter_heads_qkv
from .....common.half_precision_fixes import safe_pad_operation
from .....optimization.compatibility import portable_repeat_interleave

from ... import na
from ...attention import FlashAttentionVarlen
Expand Down Expand Up @@ -210,7 +211,7 @@ def make_window(x: torch.Tensor):
txt_len = cache("txt_len", lambda: txt_shape.prod(-1))

vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1))
txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count))
txt_len_win = cache_win("txt_len", lambda: portable_repeat_interleave(txt_len, window_count, dim=0))
all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win)
concat_win, unconcat_win = cache_win(
"mm_pnp", lambda: na.repeat_concat_idx(vid_len_win, txt_len, window_count)
Expand Down
3 changes: 2 additions & 1 deletion src/models/dit_7b/modulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from ...common.cache import Cache
from ...common.distributed.ops import slice_inputs
from ...optimization.compatibility import portable_repeat_interleave

# (dim: int, emb_dim: int)
ada_layer_type = Callable[[int, int], nn.Module]
Expand Down Expand Up @@ -75,7 +76,7 @@ def forward(
emb = cache(
f"emb_repeat_{idx}_{branch_tag}",
lambda: slice_inputs(
torch.repeat_interleave(emb, hid_len, dim=0),
portable_repeat_interleave(emb, hid_len, dim=0),
dim=0,
),
)
Expand Down
3 changes: 2 additions & 1 deletion src/models/dit_7b/nablocks/mmsr_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# from ..cache import Cache
from ....common.cache import Cache
from ....common.distributed.ops import gather_heads_scatter_seq, gather_seq_scatter_heads_qkv
from ....optimization.compatibility import portable_repeat_interleave

from .. import na
from ..attention import FlashAttentionVarlen
Expand Down Expand Up @@ -117,7 +118,7 @@ def make_window(x: torch.Tensor):
txt_len = cache("txt_len", lambda: txt_shape.prod(-1))

vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1))
txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count))
txt_len_win = cache_win("txt_len", lambda: portable_repeat_interleave(txt_len, window_count, dim=0))
all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win)
concat_win, unconcat_win = cache_win(
"mm_pnp", lambda: na.repeat_concat_idx(vid_len_win, txt_len, window_count)
Expand Down
14 changes: 14 additions & 0 deletions src/optimization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""
Optimization utilities for SeedVR2 Video Upscaler.

Exports:
IS_ROCM: Boolean indicating if running on AMD ROCm/HIP backend
portable_repeat_interleave: Cross-platform repeat_interleave that works on ROCm
"""

from .compatibility import IS_ROCM, portable_repeat_interleave

__all__ = [
"IS_ROCM",
"portable_repeat_interleave",
]
123 changes: 123 additions & 0 deletions src/optimization/compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,129 @@ def _check_conv3d_memory_bug():
NVIDIA_CONV3D_MEMORY_BUG_WORKAROUND = _check_conv3d_memory_bug()


# 5. AMD ROCm/HIP Detection and Portable Operations
def _check_is_rocm() -> bool:
"""Check if running on AMD ROCm/HIP backend."""
return hasattr(torch.version, 'hip') and torch.version.hip is not None

IS_ROCM = _check_is_rocm()


def portable_repeat_interleave(
input: torch.Tensor,
repeats: "Union[int, torch.Tensor]",
dim: int = 0,
debug: "Optional[Debug]" = None,
) -> torch.Tensor:
"""
Cross-platform replacement for torch.repeat_interleave.

torch.repeat_interleave can trigger hipErrorIllegalState on AMD ROCm
due to HIP kernel bugs. This function uses torch.arange + torch.index_select
as a portable alternative on ROCm, while using the native (faster)
implementation on CUDA/CPU.

Supports both:
- Scalar repeats: repeat each element the same number of times
- Tensor repeats: different repeat count per element

Args:
input: Input tensor to repeat elements from
repeats: Int or 1D LongTensor of repeat counts
dim: Dimension along which to repeat (default: 0)
debug: Optional Debug instance for logging tensor diagnostics

Returns:
Tensor with elements repeated along dim according to repeats

Example (scalar):
>>> x = torch.tensor([1, 2, 3])
>>> portable_repeat_interleave(x, 2, dim=0)
tensor([1, 1, 2, 2, 3, 3])

Example (tensor):
>>> x = torch.tensor([[1, 2], [3, 4], [5, 6]])
>>> repeats = torch.tensor([2, 1, 3])
>>> portable_repeat_interleave(x, repeats, dim=0)
tensor([[1, 2], [1, 2], [3, 4], [5, 6], [5, 6], [5, 6]])
"""
# Handle scalar repeats (int or 0-dim tensor)
is_scalar_repeat = isinstance(repeats, int) or (
isinstance(repeats, torch.Tensor) and repeats.dim() == 0
)

# Debug logging for tensor state diagnostics
if debug is not None:
def _tensor_info(name: str, t: torch.Tensor) -> str:
return (f"{name}: device={t.device}, dtype={t.dtype}, "
f"shape={list(t.shape)}, contiguous={t.is_contiguous()}")
repeats_info = str(repeats) if is_scalar_repeat else _tensor_info('repeats', repeats)
debug.log(
f"portable_repeat_interleave called:\n"
f" {_tensor_info('input', input)}\n"
f" repeats={repeats_info} (scalar={is_scalar_repeat})\n"
f" dim={dim}, IS_ROCM={IS_ROCM}",
level="DEBUG",
category="rocm_compat"
)

if not IS_ROCM:
# Use native implementation on CUDA/CPU - faster when it works
return torch.repeat_interleave(input, repeats, dim=dim)

# ROCm-safe implementation using index expansion
# This avoids the buggy HIP kernel in torch.repeat_interleave

if is_scalar_repeat:
# Scalar repeat: each element repeated same number of times
# Build indices [0,0,1,1,2,2,...] for repeat=2
repeat_count = int(repeats)
num_elements = input.shape[dim]
# Create indices: [0,1,2,...] then repeat each
indices = torch.arange(num_elements, device=input.device, dtype=torch.long)
# Use repeat + reshape to expand: [0,1,2] -> [[0,0],[1,1],[2,2]] -> [0,0,1,1,2,2]
expanded_indices = indices.unsqueeze(1).expand(-1, repeat_count).reshape(-1)
else:
# Tensor repeat: different count per element
# Strategy: Build index tensor mapping each output position to source position
# For repeats=[2, 1, 3], build indices=[0, 0, 1, 2, 2, 2]

# Ensure repeats is on same device as input
if repeats.device != input.device:
repeats = repeats.to(input.device)

num_elements = repeats.shape[0]
indices = torch.arange(num_elements, device=input.device, dtype=torch.long)

# Try repeat_interleave on indices first (simpler 1D case may work)
try:
expanded_indices = torch.repeat_interleave(indices, repeats)
except RuntimeError as e:
# Ultimate fallback: build indices manually (slower but always works)
if debug is not None:
debug.log(
f"repeat_interleave on indices failed, using manual fallback: {e}",
level="WARNING",
category="rocm_compat"
)
indices_list = []
repeats_cpu = repeats.cpu().tolist()
for i, r in enumerate(repeats_cpu):
indices_list.extend([i] * int(r))
expanded_indices = torch.tensor(indices_list, device=input.device, dtype=torch.long)

result = torch.index_select(input, dim, expanded_indices)

if debug is not None:
debug.log(
f"portable_repeat_interleave result: {_tensor_info('output', result)}",
level="DEBUG",
category="rocm_compat"
)

return result


# Log all optimization status once globally (cross-process) using environment variable
if not os.environ.get("SEEDVR2_OPTIMIZATIONS_LOGGED"):
os.environ["SEEDVR2_OPTIMIZATIONS_LOGGED"] = "1"
Expand Down