
Commit ca37f48

Merge pull request #248 from AInVFX/main
Fix: torch.mps AttributeError on Windows
2 parents: b4cd3ce + c41e364

8 files changed, 17 additions and 17 deletions
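
Background (summary for this merge, not part of the diff below): on PyTorch builds that ship without Metal support, notably the Windows wheels, the torch.mps submodule or its is_available attribute can be missing, so a bare torch.mps.is_available() call can raise AttributeError as soon as these modules load or run. The fix wraps every call site in an attribute-and-callable check that short-circuits before the missing API is touched. A minimal sketch of that guard written as a standalone helper (the name mps_available is illustrative only; the commit inlines the check at each call site rather than adding such a helper):

import torch

def mps_available() -> bool:
    """True only if the MPS backend module exists and reports availability."""
    # hasattr() guards builds where torch.mps does not exist (e.g. Windows),
    # callable() guards builds where the submodule exists but is_available()
    # does not; short-circuit evaluation means the call only runs when safe.
    return (
        hasattr(torch, "mps")
        and callable(getattr(torch.mps, "is_available", None))
        and torch.mps.is_available()
    )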


pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 [project]
 name = "seedvr2_videoupscaler"
 description = "SeedVR2 official ComfyUI integration: ByteDance-Seed's one-step diffusion-based video/image upscaling with memory-efficient inference"
-version = "2.5.1"
+version = "2.5.2"
 authors = [
     {name = "numz"},
     {name = "adrientoupet"}

src/common/distributed/basic.py

Lines changed: 1 addition & 1 deletion

@@ -47,7 +47,7 @@ def get_device() -> torch.device:
     """
     Get current rank device.
     """
-    if torch.mps.is_available():
+    if hasattr(torch, 'mps') and callable(getattr(torch.mps, 'is_available', None)) and torch.mps.is_available():
         return torch.device("mps")
     return torch.device("cuda", get_local_rank())
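
For reference, an illustrative snippet (not part of the commit) showing the failure mode named in the commit title and how the new guard avoids it; whether torch.mps exists depends on the installed PyTorch build:

import torch

# On builds without the mps submodule (e.g. Windows wheels), the unguarded
# attribute lookup is what raised the AttributeError this commit fixes.
try:
    torch.mps.is_available()
except AttributeError as exc:
    print(f"unguarded call failed: {exc}")

# With the guard, hasattr() returns False first, is_available() is never
# reached, and get_device() falls through to its CUDA branch.
if hasattr(torch, "mps") and callable(getattr(torch.mps, "is_available", None)) and torch.mps.is_available():
    print("MPS backend selected")
else:
    print("falling back to the CUDA branch")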

src/data/image/transforms/area_resize.py

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@ def __init__(
         self.max_area = max_area
         self.downsample_only = downsample_only
         self.interpolation = interpolation
-        if torch.mps.is_available():
+        if hasattr(torch, 'mps') and callable(getattr(torch.mps, 'is_available', None)) and torch.mps.is_available():
             self.interpolation = InterpolationMode.BILINEAR

     def __call__(self, image: Union[torch.Tensor, Image.Image]):

src/data/image/transforms/na_resize.py

Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@ def NaResize(
     max_resolution: int = 0,
     interpolation: InterpolationMode = InterpolationMode.BICUBIC,
 ):
-    Interpolation = InterpolationMode.BILINEAR if torch.mps.is_available() else interpolation
+    Interpolation = InterpolationMode.BILINEAR if (hasattr(torch, 'mps') and callable(getattr(torch.mps, 'is_available', None)) and torch.mps.is_available()) else interpolation
     if mode == "area":
         return AreaResize(
             max_area=resolution**2,

src/data/image/transforms/side_resize.py

Lines changed: 1 addition & 1 deletion

@@ -30,7 +30,7 @@ def __init__(
         self.max_size = max_size
         self.downsample_only = downsample_only
         self.interpolation = interpolation
-        if torch.mps.is_available():
+        if hasattr(torch, 'mps') and callable(getattr(torch.mps, 'is_available', None)) and torch.mps.is_available():
             self.interpolation = InterpolationMode.BILINEAR

     def __call__(self, image: Union[torch.Tensor, Image.Image]):

src/optimization/compatibility.py

Lines changed: 2 additions & 2 deletions

@@ -207,7 +207,7 @@ def __init__(self, dit_model, debug: 'Debug', compute_dtype: torch.dtype = torch
         self._convert_rope_freqs(target_dtype=self.compute_dtype)
         self.debug.end_timer("_convert_rope_freqs", "RoPE freqs conversion")

-        if torch.mps.is_available():
+        if hasattr(torch, 'mps') and callable(getattr(torch.mps, 'is_available', None)) and torch.mps.is_available():
             self.debug.log(f"Also converting NaDiT parameters/buffers for MPS backend", category="setup", force=True)
             self.debug.start_timer("_force_nadit_precision")
             self._force_nadit_precision(target_dtype=self.compute_dtype)

@@ -510,7 +510,7 @@ def _compute_sdpa_attention(self, module: torch.nn.Module, x: torch.Tensor, *arg
         k = k.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
         v = v.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

-        if torch.mps.is_available():
+        if hasattr(torch, 'mps') and callable(getattr(torch.mps, 'is_available', None)) and torch.mps.is_available():
             attn_output = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 dropout_p=0.0,

src/optimization/memory_manager.py

Lines changed: 7 additions & 7 deletions

@@ -80,7 +80,7 @@ def get_basic_vram_info(device: Optional[torch.device] = None) -> Dict[str, Any]
         elif not isinstance(device, torch.device):
             device = torch.device(device)
         free_memory, total_memory = torch.cuda.mem_get_info(device)
-    elif torch.mps.is_available():
+    elif hasattr(torch, 'mps') and callable(getattr(torch.mps, 'is_available', None)) and torch.mps.is_available():
         # MPS doesn't support per-device queries or mem_get_info
         # Use system memory as proxy
         mem = psutil.virtual_memory()

@@ -100,7 +100,7 @@ def get_basic_vram_info(device: Optional[torch.device] = None) -> Dict[str, Any]
 # Initial VRAM check at module load
 vram_info = get_basic_vram_info(device=None)
 if "error" not in vram_info:
-    backend = "MPS" if torch.mps.is_available() else "CUDA"
+    backend = "MPS" if (hasattr(torch, 'mps') and callable(getattr(torch.mps, 'is_available', None)) and torch.mps.is_available()) else "CUDA"
     print(f"📊 Initial {backend} memory: {vram_info['free_gb']:.2f}GB free / {vram_info['total_gb']:.2f}GB total")
 else:
     print(f"⚠️ Memory check failed: {vram_info['error']} - No available backend!")

@@ -129,7 +129,7 @@ def get_vram_usage(device: Optional[torch.device] = None, debug: Optional['Debug
         reserved = torch.cuda.memory_reserved(device) / (1024**3)
         max_allocated = torch.cuda.max_memory_allocated(device) / (1024**3)
         return allocated, reserved, max_allocated
-    elif torch.mps.is_available():
+    elif hasattr(torch, 'mps') and callable(getattr(torch.mps, 'is_available', None)) and torch.mps.is_available():
         # MPS doesn't support per-device queries - uses global memory tracking
         allocated = torch.mps.current_allocated_memory() / (1024**3)
         reserved = torch.mps.driver_allocated_memory() / (1024**3)

@@ -235,11 +235,11 @@ def clear_memory(debug: Optional['Debug'] = None, deep: bool = False, force: boo
             if free_ratio < 0.05:
                 should_clear = True
                 if debug:
-                    backend = "MPS" if torch.mps.is_available() else "VRAM"
+                    backend = "MPS" if (hasattr(torch, 'mps') and callable(getattr(torch.mps, 'is_available', None)) and torch.mps.is_available()) else "VRAM"
                     debug.log(f"{backend} pressure: {mem_info['free_gb']:.2f}GB free of {mem_info['total_gb']:.2f}GB", category="memory")

     # For non-MPS systems, also check system RAM separately
-    if not should_clear and not torch.mps.is_available():
+    if not should_clear and not (hasattr(torch, 'mps') and callable(getattr(torch.mps, 'is_available', None)) and torch.mps.is_available()):
         mem = psutil.virtual_memory()
         if mem.available < mem.total * 0.05:
             should_clear = True

@@ -265,7 +265,7 @@ def clear_memory(debug: Optional['Debug'] = None, deep: bool = False, force: boo
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
-    elif torch.mps.is_available():
+    elif hasattr(torch, 'mps') and callable(getattr(torch.mps, 'is_available', None)) and torch.mps.is_available():
         torch.mps.empty_cache()

     if debug:

@@ -302,7 +302,7 @@ def clear_memory(debug: Optional['Debug'] = None, deep: bool = False, force: boo
         handle = _os_memory_lib.GetCurrentProcess()
         _os_memory_lib.SetProcessWorkingSetSize(handle, -1, -1)

-    elif torch.mps.is_available():
+    elif hasattr(torch, 'mps') and callable(getattr(torch.mps, 'is_available', None)) and torch.mps.is_available():
         # macOS with MPS
         import ctypes  # Import only when needed
         import ctypes.util
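
The same guard drives the backend dispatch for the memory queries in this file. A condensed sketch of that dispatch order (illustrative only; the name query_backend_memory and the error path are assumptions, and the real helpers return richer per-backend info and report problems through the debug logger):

import psutil
import torch

def query_backend_memory() -> tuple[float, float]:
    """Return (free_gb, total_gb) following the CUDA -> MPS -> error order above."""
    gib = 1024 ** 3
    if torch.cuda.is_available():
        free, total = torch.cuda.mem_get_info()  # per-device CUDA query
        return free / gib, total / gib
    if hasattr(torch, "mps") and callable(getattr(torch.mps, "is_available", None)) and torch.mps.is_available():
        # MPS exposes no mem_get_info; unified memory makes system RAM a reasonable proxy
        mem = psutil.virtual_memory()
        return mem.available / gib, mem.total / gib
    raise RuntimeError("No available backend")  # the module prints a warning instead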

src/utils/debug.py

Lines changed: 3 additions & 3 deletions

@@ -333,7 +333,7 @@ def _collect_memory_metrics(self) -> Dict[str, Any]:
         }

         # VRAM metrics
-        if torch.cuda.is_available() or torch.mps.is_available():
+        if torch.cuda.is_available() or (hasattr(torch, 'mps') and callable(getattr(torch.mps, 'is_available', None)) and torch.mps.is_available()):
             metrics['vram_allocated'], metrics['vram_reserved'], current_global_peak = get_vram_usage(device=None, debug=self)

             # Calculate peak since last log_memory_state

@@ -346,7 +346,7 @@ def _collect_memory_metrics(self) -> Dict[str, Any]:
             metrics['vram_free'] = vram_info["free_gb"]
             metrics['vram_total'] = vram_info["total_gb"]

-            backend = "MPS" if torch.mps.is_available() else "VRAM"
+            backend = "MPS" if (hasattr(torch, 'mps') and callable(getattr(torch.mps, 'is_available', None)) and torch.mps.is_available()) else "VRAM"
             metrics['summary_vram'] = (f" [{backend}] {metrics['vram_allocated']:.2f}GB allocated / "
                                        f"{metrics['vram_reserved']:.2f}GB reserved / "
                                        f"Peak: {metrics['vram_peak_since_last']:.2f}GB / "

@@ -369,7 +369,7 @@ def _collect_memory_metrics(self) -> Dict[str, Any]:
             metrics['summary_ram'] = ""

         # Update VRAM history for tracking
-        if torch.cuda.is_available() or torch.mps.is_available():
+        if torch.cuda.is_available() or (hasattr(torch, 'mps') and callable(getattr(torch.mps, 'is_available', None)) and torch.mps.is_available()):
             self.vram_history.append(metrics['vram_allocated'])

         return metrics
