|
15 | 15 | import os |
16 | 16 | import subprocess |
17 | 17 | import tempfile |
| 18 | +import urllib.request |
18 | 19 | from typing import Callable, Dict, Optional |
19 | 20 |
|
20 | 21 | from opencut.helpers import ensure_package, get_ffmpeg_path, get_video_info, run_ffmpeg |
21 | 22 |
|
22 | 23 | logger = logging.getLogger("opencut") |
23 | 24 |
|
| 25 | +REALESRGAN_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".opencut", "models") |
| 26 | +REALESRGAN_MODEL_SPECS = { |
| 27 | + "RealESRGAN_x4plus": { |
| 28 | + "filename": "RealESRGAN_x4plus.pth", |
| 29 | + "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", |
| 30 | + "scale": 4, |
| 31 | + "num_block": 23, |
| 32 | + }, |
| 33 | + "RealESRGAN_x2plus": { |
| 34 | + "filename": "RealESRGAN_x2plus.pth", |
| 35 | + "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", |
| 36 | + "scale": 2, |
| 37 | + "num_block": 23, |
| 38 | + }, |
| 39 | + "RealESRGAN_x4plus_anime_6B": { |
| 40 | + "filename": "RealESRGAN_x4plus_anime_6B.pth", |
| 41 | + "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", |
| 42 | + "scale": 4, |
| 43 | + "num_block": 6, |
| 44 | + }, |
| 45 | +} |
| 46 | +_REALESRGAN_ALIASES = { |
| 47 | + "realesrgan_x4plus": "RealESRGAN_x4plus", |
| 48 | + "realesrgan-x4plus": "RealESRGAN_x4plus", |
| 49 | + "x4plus": "RealESRGAN_x4plus", |
| 50 | + "realesrgan_x2plus": "RealESRGAN_x2plus", |
| 51 | + "realesrgan-x2plus": "RealESRGAN_x2plus", |
| 52 | + "x2plus": "RealESRGAN_x2plus", |
| 53 | + "realesrgan_x4plus_anime_6b": "RealESRGAN_x4plus_anime_6B", |
| 54 | + "realesrgan-x4plus-anime-6b": "RealESRGAN_x4plus_anime_6B", |
| 55 | + "x4plus_anime_6b": "RealESRGAN_x4plus_anime_6B", |
| 56 | + "anime": "RealESRGAN_x4plus_anime_6B", |
| 57 | +} |
| 58 | + |
| 59 | + |
| 60 | +def _canonical_realesrgan_model_name(model_name: str) -> str: |
| 61 | + if model_name in REALESRGAN_MODEL_SPECS: |
| 62 | + return model_name |
| 63 | + normalized = str(model_name or "").strip().lower().replace(" ", "_") |
| 64 | + return _REALESRGAN_ALIASES.get(normalized, "RealESRGAN_x4plus") |
| 65 | + |
| 66 | + |
| 67 | +def _resolve_realesrgan_model_path( |
| 68 | + model_name: str, |
| 69 | + on_progress: Optional[Callable] = None, |
| 70 | +) -> str: |
| 71 | + canonical_name = _canonical_realesrgan_model_name(model_name) |
| 72 | + spec = REALESRGAN_MODEL_SPECS[canonical_name] |
| 73 | + os.makedirs(REALESRGAN_MODELS_DIR, exist_ok=True) |
| 74 | + model_path = os.path.join(REALESRGAN_MODELS_DIR, spec["filename"]) |
| 75 | + if os.path.isfile(model_path) and os.path.getsize(model_path) >= 1024: |
| 76 | + return model_path |
| 77 | + |
| 78 | + if on_progress: |
| 79 | + on_progress(4, f"Downloading {canonical_name} weights...") |
| 80 | + |
| 81 | + tmp_path = f"{model_path}.download" |
| 82 | + try: |
| 83 | + urllib.request.urlretrieve(spec["url"], tmp_path) |
| 84 | + if not os.path.isfile(tmp_path) or os.path.getsize(tmp_path) < 1024: |
| 85 | + raise RuntimeError(f"Downloaded Real-ESRGAN weights are empty: {spec['url']}") |
| 86 | + os.replace(tmp_path, model_path) |
| 87 | + except Exception: |
| 88 | + try: |
| 89 | + if os.path.exists(tmp_path): |
| 90 | + os.unlink(tmp_path) |
| 91 | + except OSError: |
| 92 | + pass |
| 93 | + raise |
| 94 | + return model_path |
| 95 | + |
24 | 96 | # --------------------------------------------------------------------------- |
25 | 97 | # Availability |
26 | 98 | # --------------------------------------------------------------------------- |
@@ -119,9 +191,20 @@ def upscale_realesrgan( |
119 | 191 | if on_progress: |
120 | 192 | on_progress(5, "Loading Real-ESRGAN model...") |
121 | 193 |
|
122 | | - model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) |
| 194 | + canonical_model_name = _canonical_realesrgan_model_name(model_name) |
| 195 | + model_spec = REALESRGAN_MODEL_SPECS[canonical_model_name] |
| 196 | + model_path = _resolve_realesrgan_model_path(canonical_model_name, on_progress) |
| 197 | + model_scale = int(model_spec["scale"]) |
| 198 | + model = RRDBNet( |
| 199 | + num_in_ch=3, |
| 200 | + num_out_ch=3, |
| 201 | + num_feat=64, |
| 202 | + num_block=int(model_spec["num_block"]), |
| 203 | + num_grow_ch=32, |
| 204 | + scale=model_scale, |
| 205 | + ) |
123 | 206 | upsampler = RealESRGANer( |
124 | | - scale=4, model_path=None, model=model, tile=tile, |
| 207 | + scale=model_scale, model_path=model_path, model=model, tile=tile, |
125 | 208 | half=torch.cuda.is_available(), |
126 | 209 | ) |
127 | 210 |
|
|
0 commit comments