Skip to content

Commit 82a875f

Browse files
committed
distribution worker fix remote vae v2
1 parent 2c15fe0 commit 82a875f

1 file changed

Lines changed: 29 additions & 10 deletions

File tree

distribution_worker.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -308,20 +308,25 @@ def _process_job(self, job):
308308
print(f"[Worker {self.worker_id}] 📦 Loaded model: {config['model']}")
309309

310310
# --- Per-config VAE loading ---
311-
# Workers only support local VAE — remote VAE endpoints (remote:https://...)
312-
# are handled by the master's RemoteVAEDecodeWorker, so fall back to
313-
# the model's bundled VAE when a remote VAE is specified.
311+
# Supports three modes:
312+
# 1. "Default" → use model's bundled VAE
313+
# 2. "remote:https://..." → remote HuggingFace VAE endpoint (decoded at decode step)
314+
# 3. "ae.safetensors" → load local VAE file
314315
config_vae = config.get("vae", "Default")
315-
if config_vae != "Default" and not config_vae.startswith("remote:"):
316-
vae = load_vae_by_name(config_vae)
317-
elif config_vae.startswith("remote:"):
318-
print(f"[Worker {self.worker_id}] ⚠️ Skipping remote VAE '{config_vae}' — using model's bundled VAE")
319-
vae = self._loaded_vae
316+
_remote_vae_url = None
317+
if config_vae != "Default":
318+
if config_vae.startswith("remote:"):
319+
_remote_vae_url = config_vae[len("remote:"):]
320+
vae = None # No local VAE needed — will use remote endpoint
321+
print(f"[Worker {self.worker_id}] 🌐 Using remote VAE: {_remote_vae_url}")
322+
else:
323+
vae = load_vae_by_name(config_vae)
320324
else:
321325
vae = self._loaded_vae
322326

323327
# Validate VAE is available (non-checkpoint models like GGUF may not bundle a VAE)
324-
if vae is None:
328+
# Skip this check when using remote VAE — no local VAE object needed.
329+
if vae is None and _remote_vae_url is None:
325330
raise RuntimeError(
326331
f"No VAE available for model '{config.get('model', 'unknown')}'. "
327332
f"Non-checkpoint models (GGUF/diffusion) require a VAE to be specified "
@@ -387,7 +392,21 @@ def _process_job(self, job):
387392
)
388393

389394
# --- VAE Decode ---
390-
image = decode_latent_with_vae(vae, result_latent["samples"])
395+
if _remote_vae_url:
396+
# Remote VAE: send latents to HuggingFace endpoint for decoding
397+
from .remote_vae import remote_decode_hf
398+
from PIL import Image as PILImage
399+
import numpy as np
400+
latent_samples = result_latent["samples"]
401+
if latent_samples.ndim == 3:
402+
latent_samples = latent_samples.unsqueeze(0)
403+
decoded = remote_decode_hf(_remote_vae_url, latent_samples, h, w)
404+
# Denormalize from [-1, 1] → [0, 1] (same as VaeImageProcessor.postprocess)
405+
decoded = (decoded / 2 + 0.5).clamp(0, 1)
406+
img_data = decoded[0].permute(1, 2, 0).cpu().numpy()
407+
image = PILImage.fromarray((img_data * 255).round().astype(np.uint8))
408+
else:
409+
image = decode_latent_with_vae(vae, result_latent["samples"])
391410

392411
# Free latent tensors ASAP (before image conversion, which is CPU-only)
393412
del result_latent, latent_in, pos_cond, neg_cond

0 commit comments

Comments
 (0)