diff --git a/DEBUGGING_NOTES.md b/DEBUGGING_NOTES.md new file mode 100644 index 000000000000..93bfdb1dfefd --- /dev/null +++ b/DEBUGGING_NOTES.md @@ -0,0 +1,95 @@ +# Debugging Notes - Inference Scaling Implementation + +## Issues Found and Fixed + +### 1. **Noise vs Latent Image Handling** ✅ FIXED +**Issue:** Was incorrectly extracting `samples` from `latent_image` dict and passing it as `noise`. + +**Fix:** +- Now properly extracts `latent_image_tensor` from dict +- Uses `comfy.sample.prepare_noise()` to generate noise from latent_image and seed (same as `common_ksampler`) +- Passes both `noise` and `latent_image_tensor` to `comfy.sample.sample()` + +### 2. **Sampler Usage** ✅ FIXED +**Issue:** Was creating `KSampler` instance but should use `comfy.sample.sample()` directly. + +**Fix:** +- Now uses `comfy.sample.sample()` which handles all the setup properly +- This matches how `common_ksampler` works in `nodes.py` + +### 3. **VAE Decoding** ✅ FIXED +**Issue:** Needed to ensure proper tensor format and device handling. + +**Fix:** +- VAE.decode expects `[B, C, H, W]` tensor and returns `[B, H, W, C]` +- Added proper error handling and tensor validation +- Moves decoded images to CPU for quality checking to avoid GPU memory issues + +### 4. **Callback Scope Issues** ✅ FIXED +**Issue:** Variables like `steps` not accessible in callback closure. + +**Fix:** +- Removed reference to `steps` in logging (not needed) +- Added `nonlocal` declarations for variables modified in callback +- Added check for `check_interval > 0` to avoid division by zero + +### 5. **Error Handling** ✅ IMPROVED +**Issue:** Errors could break the sampling process. + +**Fix:** +- Added comprehensive try/except blocks +- Quality check failures don't stop sampling +- Better logging with appropriate log levels +- Traceback only in debug mode to avoid spam + +### 6. **Result Format** ✅ FIXED +**Issue:** Need to preserve all keys from input latent dict. 
+ +**Fix:** +- Uses `latent_image.copy()` to preserve all keys (batch_index, noise_mask, etc.) +- Only updates `samples` key with result + +## Remaining Limitations + +### CFG Adjustment +- **Status:** Cannot be dynamically adjusted during sampling +- **Reason:** CFG is set at the start of sampling and used throughout +- **Workaround:** Quality checking still works and provides valuable feedback +- **Future:** Could implement adaptive step count or early stopping + +### Performance +- **VAE Decoding:** Adds overhead during sampling (decodes at intervals) +- **Mitigation:** Only checks at specified intervals (default: every 5 steps) +- **Future:** Could optimize by using TAESD for faster preview decoding + +## Testing Checklist + +- [x] Syntax check passes (`py_compile`) +- [x] No linter errors +- [ ] Import test (requires Python 3.8+ for `get_origin`) +- [ ] Runtime test with actual ComfyUI +- [ ] Test with different samplers/schedulers +- [ ] Test with different VAE models +- [ ] Test quality verification logic +- [ ] Test error handling (invalid verifier_id, etc.) + +## Code Quality Improvements Made + +1. **Better Error Messages:** More descriptive exceptions with tracebacks +2. **Logging:** Added info/warning logs for quality checks +3. **Type Safety:** Added isinstance checks for tensors +4. **Memory Management:** Moves decoded images to CPU +5. **Code Organization:** Follows ComfyUI patterns (`common_ksampler` style) + +## Files Modified + +- `comfy_api_nodes/nodes_inference_scaling.py` - Main implementation +- `DEBUGGING_NOTES.md` - This file + +## Next Steps for Full Testing + +1. Test in actual ComfyUI environment +2. Verify callback is called correctly +3. Test VAE decoding with different models +4. Verify quality metrics are reasonable +5. Test edge cases (empty latents, different batch sizes, etc.) 
diff --git a/HOW_TO_WRITE_NODES.md b/HOW_TO_WRITE_NODES.md new file mode 100644 index 000000000000..59d68fff2273 --- /dev/null +++ b/HOW_TO_WRITE_NODES.md @@ -0,0 +1,434 @@ +# How to Write Nodes in ComfyUI-Inference-Scaling + +This guide will walk you through creating custom nodes for ComfyUI using the modern API system. + +## Overview + +In this ComfyUI project, nodes are created using the **ComfyAPI** system. Each node is a class that: +1. Inherits from `IO.ComfyNode` +2. Defines its schema (inputs, outputs, metadata) via `define_schema()` +3. Implements its execution logic via `execute()` +4. Is registered through a `ComfyExtension` class + +## Basic Structure + +### 1. Node Class + +Every node is a class that inherits from `IO.ComfyNode`: + +```python +from comfy_api.latest import IO, ComfyExtension +from typing_extensions import override +from inspect import cleandoc + +class MyCustomNode(IO.ComfyNode): + """ + Description of what your node does. + This docstring will appear in the UI. + """ + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MyCustomNode", + display_name="My Custom Node", + category="mycategory/subcategory", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + # Define inputs here + ], + outputs=[ + # Define outputs here + ], + ) + + @classmethod + async def execute(cls, ...) -> IO.NodeOutput: + # Implementation here + pass +``` + +### 2. Schema Definition + +The `define_schema()` method defines: +- **node_id**: Unique identifier (usually matches class name) +- **display_name**: Name shown in the UI +- **category**: Where it appears in the node menu (use `/` for subcategories) +- **description**: Tooltip/help text +- **inputs**: List of input definitions +- **outputs**: List of output definitions +- **hidden**: Optional hidden inputs (like auth tokens) +- **is_api_node**: Set to `True` for API nodes + +### 3. 
Input Types + +Common input types available: + +```python +# String input +IO.String.Input( + "prompt", + default="", + multiline=True, # For longer text + tooltip="Help text shown on hover", + optional=True, # Makes it optional +) + +# Integer input +IO.Int.Input( + "seed", + default=0, + min=0, + max=100, + step=1, + display_mode=IO.NumberDisplay.slider, # or .number + control_after_generate=True, # Shows randomize button + tooltip="Random seed", +) + +# Float input +IO.Float.Input( + "strength", + default=0.5, + min=0.0, + max=1.0, + step=0.01, + tooltip="Strength value", +) + +# Combo/Dropdown input +IO.Combo.Input( + "model", + options=["option1", "option2", "option3"], + default="option1", + tooltip="Select a model", +) + +# Image input +IO.Image.Input( + "image", + tooltip="Input image", + optional=True, +) + +# Mask input +IO.Mask.Input( + "mask", + tooltip="Mask for inpainting", + optional=True, +) + +# Audio input +IO.Audio.Input( + "audio", + tooltip="Input audio", + optional=True, +) + +# Video input +IO.Video.Input( + "video", + tooltip="Input video", + optional=True, +) +``` + +### 4. Output Types + +Common output types: + +```python +# Image output +IO.Image.Output() + +# Audio output +IO.Audio.Output() + +# Video output +IO.Video.Output() + +# String output +IO.String.Output() + +# Integer output +IO.Int.Output() + +# Float output +IO.Float.Output() +``` + +### 5. Execute Method + +The `execute()` method is where your node's logic runs: + +```python +@classmethod +async def execute( + cls, + # Parameters match input names from define_schema + prompt: str, + seed: int = 0, + image: Optional[torch.Tensor] = None, + # ... other inputs +) -> IO.NodeOutput: + """ + Execute the node logic. + + Args: + prompt: Text prompt + seed: Random seed + image: Optional image tensor (shape: [B, H, W, C]) + ... + + Returns: + IO.NodeOutput with the result + """ + # Your implementation here + + # For image outputs: + result_tensor = ... 
# torch.Tensor with shape [B, H, W, C] + return IO.NodeOutput(result_tensor) + + # For multiple outputs: + return IO.NodeOutput(image=result_image, metadata=result_metadata) +``` + +**Important Notes:** +- The method is `async` - use `await` for async operations +- Input parameters match the names from `define_schema()` +- Image tensors have shape `[Batch, Height, Width, Channels]` (usually RGBA) +- Use `IO.NodeOutput()` to return results + +### 6. Extension Registration + +To register your nodes, create an extension class: + +```python +class MyExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + MyCustomNode, + AnotherNode, + # ... list all your nodes + ] + + +async def comfy_entrypoint() -> MyExtension: + """ + Entry point function that ComfyUI calls to load your extension. + """ + return MyExtension() +``` + +## Complete Example + +Here's a complete example of a simple image processing node: + +```python +from io import BytesIO +import torch +import numpy as np +from PIL import Image +from typing import Optional +from typing_extensions import override +from inspect import cleandoc + +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.util import validate_string + + +class ImageBrightnessNode(IO.ComfyNode): + """ + Adjusts the brightness of an input image. + """ + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageBrightnessNode", + display_name="Image Brightness", + category="image/processing", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.Image.Input( + "image", + tooltip="Input image to adjust", + ), + IO.Float.Input( + "brightness", + default=1.0, + min=0.0, + max=2.0, + step=0.1, + tooltip="Brightness multiplier (1.0 = no change)", + ), + ], + outputs=[ + IO.Image.Output(), + ], + ) + + @classmethod + async def execute( + cls, + image: torch.Tensor, + brightness: float = 1.0, + ) -> IO.NodeOutput: + """ + Adjust image brightness. 
+ + Args: + image: Input image tensor [B, H, W, C] + brightness: Brightness multiplier + + Returns: + Brightness-adjusted image + """ + # Ensure we have a batch dimension + if len(image.shape) == 3: + image = image.unsqueeze(0) + + # Apply brightness adjustment + # Clamp values to [0, 1] range + adjusted = torch.clamp(image * brightness, 0.0, 1.0) + + return IO.NodeOutput(adjusted) + + +class MyExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + ImageBrightnessNode, + ] + + +async def comfy_entrypoint() -> MyExtension: + return MyExtension() +``` + +## API Node Example + +For API nodes (nodes that call external APIs), you'll typically: + +1. Use utility functions from `comfy_api_nodes.util`: + - `sync_op()` - for synchronous API calls + - `poll_op()` - for polling async operations + - `validate_string()` - for input validation + - `tensor_to_bytesio()` - convert image tensors to bytes + - `bytesio_to_image_tensor()` - convert bytes to image tensors + +2. Include hidden inputs for authentication: +```python +hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, +], +is_api_node=True, +``` + +3. Use `ApiEndpoint` for API calls: +```python +from comfy_api_nodes.util import ApiEndpoint, sync_op + +response = await sync_op( + cls, + ApiEndpoint(path="/api/endpoint", method="POST"), + response_model=YourResponseModel, + data=YourRequestModel(...), + files={...}, # Optional file uploads + content_type="application/json", +) +``` + +## File Organization + +1. **Create your node file**: `comfy_api_nodes/nodes_yourname.py` +2. **Follow naming conventions**: + - Node classes: `YourNodeName` (PascalCase) + - Extension class: `YourExtension` (PascalCase) + - File: `nodes_yourname.py` (snake_case) +3. 
**Import required modules**: + - `from comfy_api.latest import IO, ComfyExtension` + - `from typing_extensions import override` + - `from inspect import cleandoc` + +## Testing Your Node + +1. **Start ComfyUI**: + ```bash + python main.py + ``` + +2. **Check the node appears** in the node menu under your specified category + +3. **Test the node** by: + - Adding it to a workflow + - Connecting inputs + - Executing the workflow + +## Common Patterns + +### Working with Images + +```python +# Image tensor shape: [Batch, Height, Width, Channels] +# Channels are usually RGBA (4 channels) + +# Convert PIL Image to tensor +pil_img = Image.open("image.png").convert("RGBA") +arr = np.asarray(pil_img).astype(np.float32) / 255.0 +tensor = torch.from_numpy(arr).unsqueeze(0) # Add batch dimension + +# Convert tensor to PIL Image +tensor = image.squeeze(0).cpu() # Remove batch, move to CPU +image_np = (tensor.numpy() * 255).astype(np.uint8) +pil_img = Image.fromarray(image_np) +``` + +### Validation + +```python +from comfy_api_nodes.util import validate_string + +# Validate string inputs +validate_string(prompt, strip_whitespace=False) + +# Check optional inputs +if image is not None: + # Process image + pass +``` + +### Error Handling + +```python +if some_condition: + raise Exception("Error message here") +``` + +## Tips + +1. **Use type hints** - They help with IDE autocomplete and documentation +2. **Add tooltips** - Help users understand what each input does +3. **Use `cleandoc()`** - Cleans up docstrings for display +4. **Make inputs optional** when appropriate - Improves usability +5. **Follow existing patterns** - Look at `nodes_openai.py` or `nodes_stability.py` for examples +6. **Test thoroughly** - Especially edge cases (None inputs, empty strings, etc.) + +## Resources + +- Existing node examples: `comfy_api_nodes/nodes_*.py` +- ComfyAPI documentation: `comfy_api/latest/` +- Utility functions: `comfy_api_nodes/util/` + +## Next Steps + +1. 
Look at existing nodes for reference +2. Start with a simple node +3. Test incrementally +4. Add more features as needed + +Happy node writing! 🎨 diff --git a/INFERENCE_SCALING_README.md b/INFERENCE_SCALING_README.md new file mode 100644 index 000000000000..90d3a6559ace --- /dev/null +++ b/INFERENCE_SCALING_README.md @@ -0,0 +1,129 @@ +# Inference Scaling Implementation + +This document describes the inference scaling feature implementation for ComfyUI. + +## Overview + +The inference scaling system monitors image quality during the generation loop (not just at the end) and can adjust sampling parameters based on quality metrics. It consists of two main nodes: + +1. **VerifierSelectionNode** - Creates and configures an image quality verifier +2. **InferenceScalingNode** - Wraps the standard KSampler with quality monitoring + +## Architecture + +### VerifierSelectionNode + +This node allows users to select and configure a quality verifier: + +- **Inputs:** + - `verifier_type`: Type of verifier ("simple" or "custom") + - `quality_threshold`: Minimum quality score threshold (0.0-1.0) + +- **Outputs:** + - `verifier_id`: String identifier for the verifier (used by InferenceScalingNode) + +### InferenceScalingNode + +This node wraps the standard sampling process with quality monitoring: + +- **Inputs:** + - `verifier_id`: Verifier identifier from VerifierSelectionNode + - `model`: The diffusion model + - `positive`: Positive conditioning + - `negative`: Negative conditioning + - `latent_image`: Initial latent image + - `vae`: VAE model for decoding latents during quality checks + - `seed`: Random seed + - `steps`: Number of sampling steps + - `cfg`: CFG scale + - `sampler_name`: Sampler algorithm name + - `scheduler`: Scheduler name + - `check_interval`: Check quality every N steps (default: 5) + - `quality_threshold`: Quality threshold for adjustments + - `scale_factor`: Factor to scale parameters when quality is low + +- **Outputs:** + - `latent`: The denoised latent 
image + +## How It Works + +1. **During Sampling:** + - The node creates a callback function that is called at each sampling step + - At intervals specified by `check_interval`, the callback: + - Decodes the current latent representation to an image using VAE + - Uses the verifier to assess image quality + - Logs quality metrics + - Can adjust parameters based on quality (currently limited - see Limitations) + +2. **Quality Verification:** + - The `SimpleVerifier` class uses: + - Image variance (higher = more detail) + - Edge detection strength (measures structure) + - Combined into a quality score (0.0-1.0) + +3. **Scaling Logic:** + - When quality is below threshold: Can increase CFG (though this is limited - see Limitations) + - Quality history is tracked for analysis + +## Limitations + +### CFG Adjustment + +**Important:** Currently, CFG cannot be dynamically adjusted during sampling because: +- CFG is set at the beginning of sampling and used throughout +- The sampling process doesn't support mid-run parameter changes + +**Workarounds:** +- Quality checking still works and provides valuable feedback +- Quality history can be used to inform future runs +- Early stopping could be implemented (not yet done) + +### Future Improvements + +1. **Dynamic Step Adjustment:** Implement adaptive step count based on quality +2. **Early Stopping:** Stop sampling early if quality is consistently good +3. **Better Verifiers:** Add more sophisticated quality metrics (e.g., perceptual metrics, CLIP-based scoring) +4. **Parameter Tuning:** Implement more sophisticated parameter adjustment strategies + +## Usage Example + +``` +1. Create VerifierSelectionNode + - Set verifier_type: "simple" + - Set quality_threshold: 0.5 + - Connect output to InferenceScalingNode's verifier_id input + +2. 
Create InferenceScalingNode + - Connect model, positive, negative, latent_image, vae + - Set steps: 20 + - Set cfg: 7.0 + - Set check_interval: 5 (check every 5 steps) + - Set quality_threshold: 0.5 + - Set scale_factor: 1.2 + - Connect verifier_id from VerifierSelectionNode +``` + +## Files + +- `comfy_api_nodes/nodes_inference_scaling.py` - Main implementation +- `INFERENCE_SCALING_README.md` - This file + +## Testing + +To test the implementation: + +1. Start ComfyUI +2. Create a workflow with: + - Model loading + - Text encoding (positive/negative) + - Empty latent image + - VerifierSelectionNode + - InferenceScalingNode + - VAE decode (to see final result) +3. Run the workflow and observe quality checks in logs + +## Notes + +- VAE decoding during sampling adds computational overhead +- Quality checks are performed at intervals to balance performance and monitoring +- The verifier registry stores verifiers in memory (simple approach for now) diff --git a/LOCAL_FEATURES.md b/LOCAL_FEATURES.md new file mode 100644 index 000000000000..5786649259b4 --- /dev/null +++ b/LOCAL_FEATURES.md @@ -0,0 +1,160 @@ +# ComfyUI-Inference-Scaling 本地新功能说明 + +> 基于 ComfyUI 官方仓库,本项目在 `feature/inference-scaling` 分支上添加了以下自定义功能。 + +--- + +## 新增功能概览 + +### 1. 
推理缩放节点 (Inference Scaling Nodes) + +**文件:** `comfy_api_nodes/nodes_inference_scaling.py` + +在标准 KSampler 采样流程中加入了**实时图像质量监控**机制,可以在生成过程中(而非仅在最终结果)动态评估图像质量。 + +#### 1.1 QualityVerifier(质量验证器基类) + +- 定义了图像质量评估的抽象接口 +- 输入:`torch.Tensor` 格式的图像(`[B, H, W, C]`,范围 `[0, 1]`) +- 输出:`(is_acceptable: bool, quality_score: float)` 元组 + +#### 1.2 SimpleVerifier(简单质量验证器) + +基于以下两项指标评估图像质量: + +| 指标 | 方法 | 权重 | +|------|------|------| +| 图像方差 | 高方差 = 更多细节 | 60% | +| 边缘强度 | 类 Sobel 算子检测结构 | 40% | + +质量分数归一化到 `[0.0, 1.0]`。 + +#### 1.3 VerifierSelectionNode(验证器选择节点) + +ComfyUI 节点,用于创建和配置质量验证器。 + +**输入参数:** +- `verifier_type`:验证器类型(`"simple"` / `"custom"`) +- `quality_threshold`:最低质量分数阈值(`0.0–1.0`,默认 `0.5`) + +**输出:** +- `verifier_id`:验证器的字符串 ID,供 InferenceScalingNode 使用 + +#### 1.4 InferenceScalingNode(推理缩放节点) + +核心节点,将标准采样流程包装并加入质量监控回调。 + +**输入参数:** + +| 参数 | 类型 | 默认值 | 说明 | +|------|------|--------|------| +| `verifier_id` | String | `""` | 来自 VerifierSelectionNode 的验证器 ID | +| `model` | Model | — | 扩散模型 | +| `positive` | Conditioning | — | 正向提示词条件 | +| `negative` | Conditioning | — | 负向提示词条件 | +| `latent_image` | Latent | — | 初始潜空间图像 | +| `vae` | VAE | — | 用于步骤中间解码的 VAE 模型 | +| `seed` | Int | `0` | 随机种子 | +| `steps` | Int | `20` | 采样步数 | +| `cfg` | Float | `7.0` | CFG 引导系数 | +| `sampler_name` | Combo | `"euler"` | 采样器名称 | +| `scheduler` | Combo | `"normal"` | 调度器名称 | +| `check_interval` | Int | `5` | 每 N 步检查一次质量 | +| `quality_threshold` | Float | `0.5` | 质量调整触发阈值 | +| `scale_factor` | Float | `1.2` | 质量低时的参数缩放系数 | + +**输出:** +- `latent`:去噪后的潜空间图像(格式与标准 KSampler 一致) + +**工作原理:** +1. 在每个采样步骤的回调中,按 `check_interval` 间隔解码当前潜空间 +2. 使用 VAE 将潜变量解码为图像(`[B, H, W, C]`) +3. 调用验证器评估质量分数 +4. 记录质量历史,低质量时输出警告日志 +5. 采样完成后返回标准格式的 latent dict + +--- + +### 2. 
节点开发教程 + +**文件:** `HOW_TO_WRITE_NODES.md` + +详细说明了如何基于本项目的 ComfyAPI 系统编写自定义节点,包括: + +- `IO.ComfyNode` 的继承结构 +- `define_schema()` 方法的编写方式(输入/输出定义) +- `execute()` 异步方法的实现模式 +- 通过 `ComfyExtension` 注册节点的方式 +- 各种输入类型的使用示例(String、Int、Float、Image、Combo 等) + +--- + +### 3. 节点模板示例 + +**文件:** `comfy_api_nodes/nodes_template_example.py` + +可直接复制修改的节点模板,演示了: + +- 文本输入(`IO.String.Input`) +- 数值输入带滑块(`IO.Int.Input` + `IO.NumberDisplay.slider`) +- 可选图像输入(`IO.Image.Input`, `optional=True`) +- 多类型输出 +- 完整的 `execute()` 实现示例 + +--- + +### 4. 调试笔记 + +**文件:** `DEBUGGING_NOTES.md` + +记录了开发过程中发现并修复的问题: + +| 问题 | 状态 | 解决方案 | +|------|------|----------| +| 噪声/潜变量处理错误 | ✅ 已修复 | 正确使用 `comfy.sample.prepare_noise()` 生成噪声 | +| KSampler 实例化方式错误 | ✅ 已修复 | 改用 `comfy.sample.sample()` 直接调用 | +| VAE 解码张量格式问题 | ✅ 已修复 | 正确处理 `[B,C,H,W]` → `[B,H,W,C]` 转换 | +| 回调闭包变量作用域 | ✅ 已修复 | 添加 `nonlocal` 声明 | +| 错误处理不完善 | ✅ 已改进 | 质量检查失败不中断采样,仅记录警告 | +| latent dict 键丢失 | ✅ 已修复 | 使用 `latent_image.copy()` 保留所有键 | + +--- + +## 已知限制 + +- **CFG 动态调整:** 由于 CFG 在采样开始时确定,无法在采样过程中实时修改 +- **VAE 解码开销:** 每 N 步解码一次会增加额外计算时间(可通过 `check_interval` 控制频率) + +## 未来改进方向 + +- 实现早停机制(质量持续良好时提前结束采样) +- 支持基于质量的动态步数调整 +- 引入更高级的质量评估方法(如感知损失、CLIP 评分) +- 使用 TAESD 替代完整 VAE 加速预览解码 + +--- + +## 使用方式 + +``` +1. 添加 VerifierSelectionNode + - verifier_type: "simple" + - quality_threshold: 0.5 + +2. 添加 InferenceScalingNode + - 连接 model / positive / negative / latent_image / vae + - 连接 verifier_id(来自 VerifierSelectionNode) + - steps: 20, cfg: 7.0 + - check_interval: 5(每5步检查一次) + - scale_factor: 1.2 + +3. 
连接后续的 VAE Decode 节点查看最终图像 +``` + +--- + +## 分支与版本信息 + +- **自定义分支:** `feature/inference-scaling` +- **基于 ComfyUI 版本:** 已 rebase 到上游最新提交(`f21f6b22`,2026-04-04 更新) +- **自定义提交:** `c254fa10 Add inference scaling nodes with quality monitoring during generation` diff --git a/comfy_api_nodes/nodes_inference_scaling.py b/comfy_api_nodes/nodes_inference_scaling.py new file mode 100644 index 000000000000..fb67fc089a0a --- /dev/null +++ b/comfy_api_nodes/nodes_inference_scaling.py @@ -0,0 +1,437 @@ +""" +Inference Scaling Nodes for ComfyUI + +This module provides nodes for adaptive inference scaling during image generation. +The system monitors image quality during the generation loop and adjusts sampling +parameters dynamically based on quality metrics. +""" + +from typing import Any +from typing_extensions import override +from inspect import cleandoc +import torch +import numpy as np + +from comfy_api.latest import IO, ComfyExtension +from comfy.samplers import KSampler +from comfy.sd import VAE +from comfy import model_management +import comfy.sample + +# Get sampler and scheduler options +SAMPLER_OPTIONS = KSampler.SAMPLERS +SCHEDULER_OPTIONS = KSampler.SCHEDULERS + + +class QualityVerifier: + """ + Base class for image quality verification. + Subclasses should implement specific quality metrics. + """ + + def __init__(self, threshold: float = 0.5): + self.threshold = threshold + + def verify(self, image: torch.Tensor) -> tuple[bool, float]: + """ + Verify image quality. + + Args: + image: Image tensor [B, H, W, C] in range [0, 1] + + Returns: + Tuple of (is_acceptable, quality_score) + """ + raise NotImplementedError("Subclasses must implement verify method") + + def get_quality_score(self, image: torch.Tensor) -> float: + """ + Get quality score without threshold check. 
+ + Args: + image: Image tensor [B, H, W, C] in range [0, 1] + + Returns: + Quality score in range [0, 1] + """ + _, score = self.verify(image) + return score + + +class SimpleVerifier(QualityVerifier): + """ + Simple verifier using basic image statistics. + Uses variance and edge detection as quality indicators. + """ + + def verify(self, image: torch.Tensor) -> tuple[bool, float]: + """ + Simple quality check based on image variance and structure. + """ + # Convert to grayscale for analysis + if image.shape[-1] == 4: # RGBA + gray = image[..., :3].mean(dim=-1, keepdim=True) + else: + gray = image.mean(dim=-1, keepdim=True) + + # Calculate variance (higher variance = more detail) + variance = gray.var().item() + + # Simple edge detection using Sobel-like operator + # Convert to numpy for easier processing + img_np = gray.squeeze().cpu().numpy() + if len(img_np.shape) == 3: + img_np = img_np[0] # Take first batch + + # Calculate gradients + grad_x = np.gradient(img_np, axis=1) + grad_y = np.gradient(img_np, axis=0) + edge_strength = np.sqrt(grad_x**2 + grad_y**2).mean() + + # Combine metrics (normalize to [0, 1]) + variance_score = min(variance * 10, 1.0) # Scale variance + edge_score = min(edge_strength * 5, 1.0) # Scale edge strength + + quality_score = (variance_score * 0.6 + edge_score * 0.4) + + is_acceptable = quality_score >= self.threshold + + return is_acceptable, quality_score + + +class VerifierSelectionNode(IO.ComfyNode): + """ + Node for selecting or creating an image quality verifier. + The verifier is used by InferenceScalingNode to judge image quality + during the generation process. 
+ """ + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VerifierSelectionNode", + display_name="Verifier Selection", + category="inference_scaling", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.Combo.Input( + "verifier_type", + options=["simple", "custom"], + default="simple", + tooltip="Type of quality verifier to use", + ), + IO.Float.Input( + "quality_threshold", + default=0.5, + min=0.0, + max=1.0, + step=0.01, + tooltip="Minimum quality score threshold (0.0-1.0)", + ), + ], + outputs=[ + IO.String.Output(), # Verifier identifier (for now, just a string) + ], + ) + + @classmethod + async def execute( + cls, + verifier_type: str = "simple", + quality_threshold: float = 0.5, + ) -> IO.NodeOutput: + """ + Create and return a verifier identifier. + + In a real implementation, this would create and store the verifier + instance. For now, we return a serialized identifier. + """ + # Create verifier based on type + if verifier_type == "simple": + verifier = SimpleVerifier(threshold=quality_threshold) + else: + verifier = SimpleVerifier(threshold=quality_threshold) + + # Store verifier in a way that InferenceScalingNode can access it + # For now, we'll use a simple approach: store in a module-level dict + import comfy_api_nodes.nodes_inference_scaling as mod + if not hasattr(mod, '_verifier_registry'): + mod._verifier_registry = {} + + verifier_id = f"verifier_{id(verifier)}" + mod._verifier_registry[verifier_id] = verifier + + return IO.NodeOutput(verifier_id) + + +class InferenceScalingNode(IO.ComfyNode): + """ + Main inference scaling node that wraps KSampler with quality monitoring. + + This node: + 1. Wraps the standard sampling process + 2. Periodically decodes latents to images using VAE during sampling + 3. Uses a verifier to judge image quality + 4. 
Adjusts sampling parameters (steps, CFG) based on quality + """ + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="InferenceScalingNode", + display_name="Inference Scaling", + category="inference_scaling", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.String.Input( + "verifier_id", + default="", + tooltip="Verifier identifier from VerifierSelectionNode", + ), + IO.Model.Input( + "model", + tooltip="Model to use for sampling", + ), + IO.Conditioning.Input( + "positive", + tooltip="Positive conditioning", + ), + IO.Conditioning.Input( + "negative", + tooltip="Negative conditioning", + ), + IO.Latent.Input( + "latent_image", + tooltip="Initial latent image", + ), + IO.Vae.Input( + "vae", + tooltip="VAE model for decoding latents during quality checks", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xffffffffffffffff, + control_after_generate=True, + tooltip="Random seed", + ), + IO.Int.Input( + "steps", + default=20, + min=1, + max=1000, + tooltip="Number of sampling steps", + ), + IO.Float.Input( + "cfg", + default=7.0, + min=0.0, + max=30.0, + step=0.1, + tooltip="CFG scale", + ), + IO.Combo.Input( + "sampler_name", + options=SAMPLER_OPTIONS, + default="euler", + tooltip="Sampler name", + ), + IO.Combo.Input( + "scheduler", + options=SCHEDULER_OPTIONS, + default="normal", + tooltip="Scheduler name", + ), + IO.Int.Input( + "check_interval", + default=5, + min=1, + max=50, + tooltip="Check quality every N steps", + ), + IO.Float.Input( + "quality_threshold", + default=0.5, + min=0.0, + max=1.0, + step=0.01, + tooltip="Quality threshold for early stopping or adjustment", + ), + IO.Float.Input( + "scale_factor", + default=1.2, + min=0.5, + max=2.0, + step=0.1, + tooltip="Factor to scale steps/CFG when quality is low", + ), + ], + outputs=[ + IO.Latent.Output(), # Latent output + ], + ) + + @classmethod + async def execute( + cls, + verifier_id: str, + model: Any, + positive: Any, + negative: Any, + latent_image: dict, + 
vae: VAE, + seed: int = 0, + steps: int = 20, + cfg: float = 7.0, + sampler_name: str = "euler", + scheduler: str = "normal", + check_interval: int = 5, + quality_threshold: float = 0.5, + scale_factor: float = 1.2, + ) -> IO.NodeOutput: + """ + Execute inference scaling sampling. + + This wraps the standard sampling process and adds quality monitoring. + """ + # Get verifier from registry + import comfy_api_nodes.nodes_inference_scaling as mod + if not hasattr(mod, '_verifier_registry') or verifier_id not in mod._verifier_registry: + raise ValueError(f"Verifier {verifier_id} not found. Please create it with VerifierSelectionNode first.") + + verifier = mod._verifier_registry[verifier_id] + + # Prepare sampling parameters + model_management.get_torch_device() + + # Extract latent_image tensor from dict + latent_image_tensor = latent_image["samples"] + + # Fix empty latent channels if needed + latent_image_tensor = comfy.sample.fix_empty_latent_channels(model, latent_image_tensor) + + # Prepare noise from latent_image and seed (same as common_ksampler does) + batch_inds = latent_image.get("batch_index", None) + noise = comfy.sample.prepare_noise(latent_image_tensor, seed, batch_inds) + + # Get noise_mask if present + noise_mask = latent_image.get("noise_mask", None) + + # Track quality during sampling + quality_history = [] + current_cfg = cfg + + def quality_check_callback(callback_dict: dict): + """ + Callback function called during sampling steps. + Decodes latents, checks quality, and adjusts parameters. 
+ + Args: + callback_dict: Dictionary with keys 'x', 'i', 'sigma', 'sigma_hat', 'denoised' + """ + nonlocal current_cfg, quality_history + + step = callback_dict['i'] + denoised = callback_dict['denoised'] # This is the predicted x0 + + # Only check at specified intervals + if check_interval > 0 and step % check_interval != 0: + return + + try: + # Decode latent to image using VAE + # denoised is the predicted x0 in latent space [B, C, H, W] + with torch.no_grad(): + # Ensure denoised is a proper tensor + if not isinstance(denoised, torch.Tensor): + return + + # VAE.decode expects [B, C, H, W] format and handles device management + # It returns [B, H, W, C] in range [0, 1] + decoded = vae.decode(denoised) + + # Ensure decoded is in correct format [B, H, W, C] + if decoded is None: + return + + # Move to CPU for quality checking to avoid GPU memory issues + decoded_cpu = decoded.cpu() + + # Check quality + is_acceptable, quality_score = verifier.verify(decoded_cpu) + quality_history.append((step, quality_score, is_acceptable)) + + # Log quality for debugging + import logging + logging.info(f"Inference Scaling: Step {step}, Quality: {quality_score:.3f}, Acceptable: {is_acceptable}") + + # Note: CFG adjustment here won't affect current sampling run + # as CFG is set at the start. This is for future reference/logging. 
+ if not is_acceptable and quality_score < quality_threshold: + # Quality is low - log for analysis + logging.warning(f"Inference Scaling: Low quality detected at step {step}: {quality_score:.3f} < {quality_threshold}") + elif is_acceptable and quality_score > quality_threshold * 1.2: + # Quality is good + logging.info(f"Inference Scaling: Good quality at step {step}: {quality_score:.3f}") + + except Exception as e: + # If decoding fails, continue sampling (don't break the generation) + import logging + logging.warning(f"Inference Scaling: Quality check failed at step {step}: {e}") + # Don't log full traceback in production to avoid spam + # logging.debug(traceback.format_exc()) + + # Perform sampling with quality monitoring + try: + # Use comfy.sample.sample which handles everything properly + result_samples = comfy.sample.sample( + model=model, + noise=noise, + steps=steps, + cfg=current_cfg, + sampler_name=sampler_name, + scheduler=scheduler, + positive=positive, + negative=negative, + latent_image=latent_image_tensor, + denoise=1.0, + disable_noise=False, + start_step=None, + last_step=None, + force_full_denoise=False, + noise_mask=noise_mask, + callback=quality_check_callback, + disable_pbar=False, + seed=seed, + ) + + # Return the result as a latent dict (preserve other keys from input) + result_latent = latent_image.copy() + result_latent["samples"] = result_samples + + return IO.NodeOutput(result_latent) + + except Exception as e: + import traceback + raise Exception(f"Inference scaling sampling failed: {e}\n{traceback.format_exc()}") + + +class InferenceScalingExtension(ComfyExtension): + """ + Extension class that registers inference scaling nodes. + """ + + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + VerifierSelectionNode, + InferenceScalingNode, + ] + + +async def comfy_entrypoint() -> InferenceScalingExtension: + """ + Entry point function that ComfyUI calls to load the extension. 
+    """
+    return InferenceScalingExtension()
diff --git a/comfy_api_nodes/nodes_template_example.py b/comfy_api_nodes/nodes_template_example.py
new file mode 100644
index 000000000000..16b4ab8ddad1
--- /dev/null
+++ b/comfy_api_nodes/nodes_template_example.py
@@ -0,0 +1,180 @@
+"""
+Template example node file.
+Copy this file and modify it to create your own custom nodes.
+"""
+
+from typing import Optional
+from typing_extensions import override
+from inspect import cleandoc
+import torch
+import numpy as np
+from PIL import Image  # NOTE(review): numpy and PIL appear unused in this file — verify before removing
+
+from comfy_api.latest import IO, ComfyExtension
+
+
+class ExampleNode(IO.ComfyNode):
+    """
+    This is an example node that demonstrates the basic structure.
+    Replace this docstring with a description of what your node does.
+    """
+
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="ExampleNode",
+            display_name="Example Node",
+            category="example",
+            description=cleandoc(cls.__doc__ or ""),
+            inputs=[
+                IO.String.Input(
+                    "text_input",
+                    default="Hello, ComfyUI!",
+                    multiline=False,
+                    tooltip="A text input field",
+                ),
+                IO.Int.Input(
+                    "number_input",
+                    default=42,
+                    min=0,
+                    max=100,
+                    step=1,
+                    display_mode=IO.NumberDisplay.slider,
+                    tooltip="A number input with slider",
+                ),
+                IO.Image.Input(
+                    "image_input",
+                    tooltip="An optional image input",
+                    optional=True,
+                ),
+            ],
+            outputs=[
+                IO.String.Output(),
+                IO.Int.Output(),
+                IO.Image.Output(),
+            ],
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        text_input: str,
+        number_input: int,
+        image_input: Optional[torch.Tensor] = None,
+    ) -> IO.NodeOutput:
+        """
+        Execute the node logic.
+
+        Args:
+            text_input: The text input value
+            number_input: The number input value
+            image_input: Optional image tensor [B, H, W, C]
+
+        Returns:
+            NodeOutput with processed results
+        """
+        # Process text
+        processed_text = f"Processed: {text_input}"
+
+        # Process number
+        processed_number = number_input * 2
+
+        # Process image if provided
+        if image_input is not None:
+            # Ensure batch dimension exists
+            if len(image_input.shape) == 3:
+                image_input = image_input.unsqueeze(0)
+
+            # Example: invert the image (assumes pixel values in [0, 1])
+            processed_image = 1.0 - image_input
+        else:
+            # Create a default image if none provided
+            # NOTE(review): default is 4-channel (256x256 RGBA); most IMAGE consumers expect 3-channel RGB — confirm
+            default_image = torch.ones(1, 256, 256, 4) * 0.5
+            processed_image = default_image
+
+        # Return multiple outputs (kwarg names presumably map to the declared output types — confirm against IO.NodeOutput)
+        return IO.NodeOutput(
+            string=processed_text,
+            int=processed_number,
+            image=processed_image,
+        )
+
+
+class SimpleImageNode(IO.ComfyNode):
+    """
+    A simpler example with just image input/output.
+    """
+
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="SimpleImageNode",
+            display_name="Simple Image Node",
+            category="example/image",
+            description=cleandoc(cls.__doc__ or ""),
+            inputs=[
+                IO.Image.Input(
+                    "image",
+                    tooltip="Input image",
+                ),
+                IO.Float.Input(
+                    "multiplier",
+                    default=1.0,
+                    min=0.0,
+                    max=2.0,
+                    step=0.1,
+                    tooltip="Brightness multiplier",
+                ),
+            ],
+            outputs=[
+                IO.Image.Output(),
+            ],
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        image: torch.Tensor,
+        multiplier: float = 1.0,
+    ) -> IO.NodeOutput:
+        """
+        Multiply image brightness.
+
+        Args:
+            image: Input image tensor [B, H, W, C]
+            multiplier: Brightness multiplier
+
+        Returns:
+            Adjusted image
+        """
+        # Ensure batch dimension
+        if len(image.shape) == 3:
+            image = image.unsqueeze(0)
+
+        # Apply multiplier and clamp back into the valid [0, 1] range
+        result = torch.clamp(image * multiplier, 0.0, 1.0)
+
+        return IO.NodeOutput(result)
+
+
+class ExampleExtension(ComfyExtension):
+    """
+    Extension class that registers all your nodes.
+    """
+
+    @override
+    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+        return [  # every node class listed here is registered with ComfyUI
+            ExampleNode,
+            SimpleImageNode,
+            # Add more nodes here as you create them
+        ]
+
+
+async def comfy_entrypoint() -> ExampleExtension:
+    """
+    Entry point function that ComfyUI calls to load your extension.
+    This function name must be exactly 'comfy_entrypoint'.
+    """
+    return ExampleExtension()