diff --git a/.env.example b/.env.example new file mode 100644 index 000000000000..a47e9dcc287a --- /dev/null +++ b/.env.example @@ -0,0 +1,5 @@ + +# Keys +GOOGLE_API_KEY= +FAL_API_KEY= + diff --git a/.gitignore b/.gitignore index 2700ad5c293e..3b0fdb5b23d6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ __pycache__/ *.py[cod] +.env /output/ /input/ !/input/example.png diff --git a/comfy/cli_args.py b/comfy/cli_args.py index e9832acaf97e..5fbd42cc5a09 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -221,13 +221,6 @@ def is_valid_directory(path: str) -> str: parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.") -parser.add_argument( - "--comfy-api-base", - type=str, - default="https://api.comfy.org", - help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)", -) - database_default_path = os.path.abspath( os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") ) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 050031dc0546..abec79e74095 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1332,9 +1332,9 @@ class Hidden(str, Enum): dynprompt = "DYNPROMPT" """DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion.""" auth_token_comfy_org = "AUTH_TOKEN_COMFY_ORG" - """AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend.""" + """Deprecated: BYOK mode uses env vars instead. Kept for framework compatibility.""" api_key_comfy_org = "API_KEY_COMFY_ORG" - """API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend.""" + """Deprecated: BYOK mode uses env vars instead. 
Kept for framework compatibility.""" @dataclass @@ -1521,12 +1521,7 @@ def finalize(self): self.outputs = [] if self.hidden is None: self.hidden = [] - # if is an api_node, will need key-related hidden - if self.is_api_node: - if Hidden.auth_token_comfy_org not in self.hidden: - self.hidden.append(Hidden.auth_token_comfy_org) - if Hidden.api_key_comfy_org not in self.hidden: - self.hidden.append(Hidden.api_key_comfy_org) + # BYOK: ComfyOrg auth auto-injection removed -- API keys come from env vars # if is an output_node, will need prompt and extra_pnginfo if self.is_output_node: if Hidden.prompt not in self.hidden: diff --git a/comfy_api_nodes/apis/fal.py b/comfy_api_nodes/apis/fal.py new file mode 100644 index 000000000000..a478ebb1d5c7 --- /dev/null +++ b/comfy_api_nodes/apis/fal.py @@ -0,0 +1,38 @@ +"""Pydantic models for fal.ai queue API envelope. + +These are shared by all fal.ai-routed nodes. Individual model results +are returned as plain dicts (model-specific output shapes). +""" + +from pydantic import BaseModel, Field + + +class FalQueueSubmitResponse(BaseModel): + """Response from POST queue.fal.run/{model_id}.""" + + request_id: str = Field(...) + response_url: str = Field(...) + status_url: str = Field(...) + cancel_url: str = Field(default="") + + +class FalQueueStatusResponse(BaseModel): + """Response from GET queue.fal.run/{model_id}/requests/{id}/status.""" + + status: str = Field(...) 
# IN_QUEUE, IN_PROGRESS, COMPLETED + queue_position: int | None = Field(default=None) + response_url: str = Field(default="") + + +class FalErrorDetail(BaseModel): + """Single error detail from fal.ai.""" + + loc: list[str] = Field(default_factory=list) + msg: str = Field(default="") + type: str = Field(default="") + + +class FalErrorResponse(BaseModel): + """Error response from fal.ai endpoints.""" + + detail: list[FalErrorDetail] = Field(default_factory=list) diff --git a/comfy_api_nodes/apis/gemini.py b/comfy_api_nodes/apis/gemini.py index 639035feff9d..8ca723d87d11 100644 --- a/comfy_api_nodes/apis/gemini.py +++ b/comfy_api_nodes/apis/gemini.py @@ -145,7 +145,6 @@ class GeminiImageGenerateContentRequest(BaseModel): systemInstruction: GeminiSystemInstructionContent | None = Field(None) tools: list[GeminiTool] | None = Field(None) videoMetadata: GeminiVideoMetadata | None = Field(None) - uploadImagesToStorage: bool = Field(True) class GeminiGenerateContentRequest(BaseModel): diff --git a/comfy_api_nodes/apis/grok.py b/comfy_api_nodes/apis/grok.py deleted file mode 100644 index c56c8aeccdb1..000000000000 --- a/comfy_api_nodes/apis/grok.py +++ /dev/null @@ -1,75 +0,0 @@ -from pydantic import BaseModel, Field - - -class ImageGenerationRequest(BaseModel): - model: str = Field(...) - prompt: str = Field(...) - aspect_ratio: str = Field(...) - n: int = Field(...) - seed: int = Field(...) - response_format: str = Field("url") - resolution: str = Field(...) - - -class InputUrlObject(BaseModel): - url: str = Field(...) - - -class ImageEditRequest(BaseModel): - model: str = Field(...) - images: list[InputUrlObject] = Field(...) - prompt: str = Field(...) - resolution: str = Field(...) - n: int = Field(...) - seed: int = Field(...) - response_format: str = Field("url") - aspect_ratio: str | None = Field(...) - - -class VideoGenerationRequest(BaseModel): - model: str = Field(...) - prompt: str = Field(...) - image: InputUrlObject | None = Field(...) 
- duration: int = Field(...) - aspect_ratio: str | None = Field(...) - resolution: str = Field(...) - seed: int = Field(...) - - -class VideoEditRequest(BaseModel): - model: str = Field(...) - prompt: str = Field(...) - video: InputUrlObject = Field(...) - seed: int = Field(...) - - -class ImageResponseObject(BaseModel): - url: str | None = Field(None) - b64_json: str | None = Field(None) - revised_prompt: str | None = Field(None) - - -class UsageObject(BaseModel): - cost_in_usd_ticks: int | None = Field(None) - - -class ImageGenerationResponse(BaseModel): - data: list[ImageResponseObject] = Field(...) - usage: UsageObject | None = Field(None) - - -class VideoGenerationResponse(BaseModel): - request_id: str = Field(...) - - -class VideoResponseObject(BaseModel): - url: str = Field(...) - upsampled_prompt: str | None = Field(None) - duration: int = Field(...) - - -class VideoStatusResponse(BaseModel): - status: str | None = Field(None) - video: VideoResponseObject | None = Field(None) - model: str | None = Field(None) - usage: UsageObject | None = Field(None) diff --git a/comfy_api_nodes/apis/hitpaw.py b/comfy_api_nodes/apis/hitpaw.py deleted file mode 100644 index b23c5d9eb722..000000000000 --- a/comfy_api_nodes/apis/hitpaw.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import TypedDict - -from pydantic import BaseModel, Field - - -class InputVideoModel(TypedDict): - model: str - resolution: str - - -class ImageEnhanceTaskCreateRequest(BaseModel): - model_name: str = Field(...) - img_url: str = Field(...) - extension: str = Field(".png") - exif: bool = Field(False) - DPI: int | None = Field(None) - - -class VideoEnhanceTaskCreateRequest(BaseModel): - video_url: str = Field(...) - extension: str = Field(".mp4") - model_name: str | None = Field(...) 
- resolution: list[int] = Field(..., description="Target resolution [width, height]") - original_resolution: list[int] = Field(..., description="Original video resolution [width, height]") - - -class TaskCreateDataResponse(BaseModel): - job_id: str = Field(...) - consume_coins: int | None = Field(None) - - -class TaskStatusPollRequest(BaseModel): - job_id: str = Field(...) - - -class TaskCreateResponse(BaseModel): - code: int = Field(...) - message: str = Field(...) - data: TaskCreateDataResponse | None = Field(None) - - -class TaskStatusDataResponse(BaseModel): - job_id: str = Field(...) - status: str = Field(...) - res_url: str = Field("") - - -class TaskStatusResponse(BaseModel): - code: int = Field(...) - message: str = Field(...) - data: TaskStatusDataResponse = Field(...) diff --git a/comfy_api_nodes/apis/magnific.py b/comfy_api_nodes/apis/magnific.py deleted file mode 100644 index b9f148def44d..000000000000 --- a/comfy_api_nodes/apis/magnific.py +++ /dev/null @@ -1,122 +0,0 @@ -from typing import TypedDict - -from pydantic import AliasChoices, BaseModel, Field, model_validator - - -class InputPortraitMode(TypedDict): - portrait_mode: str - portrait_style: str - portrait_beautifier: str - - -class InputAdvancedSettings(TypedDict): - advanced_settings: str - whites: int - blacks: int - brightness: int - contrast: int - saturation: int - engine: str - transfer_light_a: str - transfer_light_b: str - fixed_generation: bool - - -class InputSkinEnhancerMode(TypedDict): - mode: str - skin_detail: int - optimized_for: str - - -class ImageUpscalerCreativeRequest(BaseModel): - image: str = Field(...) - scale_factor: str = Field(...) - optimized_for: str = Field(...) - prompt: str | None = Field(None) - creativity: int = Field(...) - hdr: int = Field(...) - resemblance: int = Field(...) - fractality: int = Field(...) - engine: str = Field(...) - - -class ImageUpscalerPrecisionV2Request(BaseModel): - image: str = Field(...) - sharpen: int = Field(...) 
- smart_grain: int = Field(...) - ultra_detail: int = Field(...) - flavor: str = Field(...) - scale_factor: int = Field(...) - - -class ImageRelightAdvancedSettingsRequest(BaseModel): - whites: int = Field(...) - blacks: int = Field(...) - brightness: int = Field(...) - contrast: int = Field(...) - saturation: int = Field(...) - engine: str = Field(...) - transfer_light_a: str = Field(...) - transfer_light_b: str = Field(...) - fixed_generation: bool = Field(...) - - -class ImageRelightRequest(BaseModel): - image: str = Field(...) - prompt: str | None = Field(None) - transfer_light_from_reference_image: str | None = Field(None) - light_transfer_strength: int = Field(...) - interpolate_from_original: bool = Field(...) - change_background: bool = Field(...) - style: str = Field(...) - preserve_details: bool = Field(...) - advanced_settings: ImageRelightAdvancedSettingsRequest | None = Field(...) - - -class ImageStyleTransferRequest(BaseModel): - image: str = Field(...) - reference_image: str = Field(...) - prompt: str | None = Field(None) - style_strength: int = Field(...) - structure_strength: int = Field(...) - is_portrait: bool = Field(...) - portrait_style: str | None = Field(...) - portrait_beautifier: str | None = Field(...) - flavor: str = Field(...) - engine: str = Field(...) - fixed_generation: bool = Field(...) - - -class ImageSkinEnhancerCreativeRequest(BaseModel): - image: str = Field(...) - sharpen: int = Field(...) - smart_grain: int = Field(...) - - -class ImageSkinEnhancerFaithfulRequest(BaseModel): - image: str = Field(...) - sharpen: int = Field(...) - smart_grain: int = Field(...) - skin_detail: int = Field(...) - - -class ImageSkinEnhancerFlexibleRequest(BaseModel): - image: str = Field(...) - sharpen: int = Field(...) - smart_grain: int = Field(...) - optimized_for: str = Field(...) - - -class TaskResponse(BaseModel): - """Unified response model that handles both wrapped and unwrapped API responses.""" - - task_id: str = Field(...) 
- status: str = Field(validation_alias=AliasChoices("status", "task_status")) - generated: list[str] | None = Field(None) - - @model_validator(mode="before") - @classmethod - def unwrap_data(cls, values: dict) -> dict: - if "data" in values and isinstance(values["data"], dict): - return values["data"] - return values diff --git a/comfy_api_nodes/apis/moonvalley.py b/comfy_api_nodes/apis/moonvalley.py deleted file mode 100644 index 7ec7a4ade542..000000000000 --- a/comfy_api_nodes/apis/moonvalley.py +++ /dev/null @@ -1,152 +0,0 @@ -from enum import Enum -from typing import Optional, Dict, Any - -from pydantic import BaseModel, Field, StrictBytes - - -class MoonvalleyPromptResponse(BaseModel): - error: Optional[Dict[str, Any]] = None - frame_conditioning: Optional[Dict[str, Any]] = None - id: Optional[str] = None - inference_params: Optional[Dict[str, Any]] = None - meta: Optional[Dict[str, Any]] = None - model_params: Optional[Dict[str, Any]] = None - output_url: Optional[str] = None - prompt_text: Optional[str] = None - status: Optional[str] = None - - -class MoonvalleyTextToVideoInferenceParams(BaseModel): - add_quality_guidance: Optional[bool] = Field( - True, description='Whether to add quality guidance' - ) - caching_coefficient: Optional[float] = Field( - 0.3, description='Caching coefficient for optimization' - ) - caching_cooldown: Optional[int] = Field( - 3, description='Number of caching cooldown steps' - ) - caching_warmup: Optional[int] = Field( - 3, description='Number of caching warmup steps' - ) - clip_value: Optional[float] = Field( - 3, description='CLIP value for generation control' - ) - conditioning_frame_index: Optional[int] = Field( - 0, description='Index of the conditioning frame' - ) - cooldown_steps: Optional[int] = Field( - 75, description='Number of cooldown steps (calculated based on num_frames)' - ) - fps: Optional[int] = Field( - 24, description='Frames per second of the generated video' - ) - guidance_scale: Optional[float] = Field( - 
10, description='Guidance scale for generation control' - ) - height: Optional[int] = Field( - 1080, description='Height of the generated video in pixels' - ) - negative_prompt: Optional[str] = Field(None, description='Negative prompt text') - num_frames: Optional[int] = Field(64, description='Number of frames to generate') - seed: Optional[int] = Field( - None, description='Random seed for generation (default: random)' - ) - shift_value: Optional[float] = Field( - 3, description='Shift value for generation control' - ) - steps: Optional[int] = Field(80, description='Number of denoising steps') - use_guidance_schedule: Optional[bool] = Field( - True, description='Whether to use guidance scheduling' - ) - use_negative_prompts: Optional[bool] = Field( - False, description='Whether to use negative prompts' - ) - use_timestep_transform: Optional[bool] = Field( - True, description='Whether to use timestep transformation' - ) - warmup_steps: Optional[int] = Field( - 0, description='Number of warmup steps (calculated based on num_frames)' - ) - width: Optional[int] = Field( - 1920, description='Width of the generated video in pixels' - ) - - -class MoonvalleyTextToVideoRequest(BaseModel): - image_url: Optional[str] = None - inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None - prompt_text: Optional[str] = None - webhook_url: Optional[str] = None - - -class MoonvalleyUploadFileRequest(BaseModel): - file: Optional[StrictBytes] = None - - -class MoonvalleyUploadFileResponse(BaseModel): - access_url: Optional[str] = None - - -class MoonvalleyVideoToVideoInferenceParams(BaseModel): - add_quality_guidance: Optional[bool] = Field( - True, description='Whether to add quality guidance' - ) - caching_coefficient: Optional[float] = Field( - 0.3, description='Caching coefficient for optimization' - ) - caching_cooldown: Optional[int] = Field( - 3, description='Number of caching cooldown steps' - ) - caching_warmup: Optional[int] = Field( - 3, description='Number 
of caching warmup steps' - ) - clip_value: Optional[float] = Field( - 3, description='CLIP value for generation control' - ) - conditioning_frame_index: Optional[int] = Field( - 0, description='Index of the conditioning frame' - ) - cooldown_steps: Optional[int] = Field( - 36, description='Number of cooldown steps (calculated based on num_frames)' - ) - guidance_scale: Optional[float] = Field( - 15, description='Guidance scale for generation control' - ) - negative_prompt: Optional[str] = Field(None, description='Negative prompt text') - seed: Optional[int] = Field( - None, description='Random seed for generation (default: random)' - ) - shift_value: Optional[float] = Field( - 3, description='Shift value for generation control' - ) - steps: Optional[int] = Field(80, description='Number of denoising steps') - use_guidance_schedule: Optional[bool] = Field( - True, description='Whether to use guidance scheduling' - ) - use_negative_prompts: Optional[bool] = Field( - False, description='Whether to use negative prompts' - ) - use_timestep_transform: Optional[bool] = Field( - True, description='Whether to use timestep transformation' - ) - warmup_steps: Optional[int] = Field( - 24, description='Number of warmup steps (calculated based on num_frames)' - ) - - -class ControlType(str, Enum): - motion_control = 'motion_control' - pose_control = 'pose_control' - - -class MoonvalleyVideoToVideoRequest(BaseModel): - control_type: ControlType = Field( - ..., description='Supported types for video control' - ) - inference_params: Optional[MoonvalleyVideoToVideoInferenceParams] = None - prompt_text: str = Field(..., description='Describes the video to generate') - video_url: str = Field(..., description='Url to control video') - webhook_url: Optional[str] = Field( - None, description='Optional webhook URL for notifications' - ) diff --git a/comfy_api_nodes/apis/rodin.py b/comfy_api_nodes/apis/rodin.py deleted file mode 100644 index fc26a6e73a0f..000000000000 --- 
a/comfy_api_nodes/apis/rodin.py +++ /dev/null @@ -1,54 +0,0 @@ -from __future__ import annotations - -from enum import Enum -from typing import Optional, List -from pydantic import BaseModel, Field - - -class Rodin3DGenerateRequest(BaseModel): - seed: int = Field(..., description="seed_") - tier: str = Field(..., description="Tier of generation.") - material: str = Field(..., description="The material type.") - quality_override: int = Field(..., description="The poly count of the mesh.") - mesh_mode: str = Field(..., description="It controls the type of faces of generated models.") - TAPose: Optional[bool] = Field(None, description="") - -class GenerateJobsData(BaseModel): - uuids: List[str] = Field(..., description="str LIST") - subscription_key: str = Field(..., description="subscription key") - -class Rodin3DGenerateResponse(BaseModel): - message: Optional[str] = Field(None, description="Return message.") - prompt: Optional[str] = Field(None, description="Generated Prompt from image.") - submit_time: Optional[str] = Field(None, description="Submit Time") - uuid: Optional[str] = Field(None, description="Task str") - jobs: Optional[GenerateJobsData] = Field(None, description="Details of jobs") - -class JobStatus(str, Enum): - """ - Status for jobs - """ - Done = "Done" - Failed = "Failed" - Generating = "Generating" - Waiting = "Waiting" - -class Rodin3DCheckStatusRequest(BaseModel): - subscription_key: str = Field(..., description="subscription from generate endpoint") - -class JobItem(BaseModel): - uuid: str = Field(..., description="uuid") - status: JobStatus = Field(...,description="Status Currently") - -class Rodin3DCheckStatusResponse(BaseModel): - jobs: List[JobItem] = Field(..., description="Job status List") - -class Rodin3DDownloadRequest(BaseModel): - task_uuid: str = Field(..., description="Task str") - -class RodinResourceItem(BaseModel): - url: str = Field(..., description="Download Url") - name: str = Field(..., description="File name with ext") - 
-class Rodin3DDownloadResponse(BaseModel): - list: List[RodinResourceItem] = Field(..., description="Source List") diff --git a/comfy_api_nodes/apis/runway.py b/comfy_api_nodes/apis/runway.py deleted file mode 100644 index df6f2b845822..000000000000 --- a/comfy_api_nodes/apis/runway.py +++ /dev/null @@ -1,127 +0,0 @@ -from enum import Enum -from typing import Optional, List, Union -from datetime import datetime - -from pydantic import BaseModel, Field, RootModel - - -class RunwayAspectRatioEnum(str, Enum): - field_1280_720 = '1280:720' - field_720_1280 = '720:1280' - field_1104_832 = '1104:832' - field_832_1104 = '832:1104' - field_960_960 = '960:960' - field_1584_672 = '1584:672' - field_1280_768 = '1280:768' - field_768_1280 = '768:1280' - - -class Position(str, Enum): - first = 'first' - last = 'last' - - -class RunwayPromptImageDetailedObject(BaseModel): - position: Position = Field( - ..., - description="The position of the image in the output video. 'last' is currently supported for gen3a_turbo only.", - ) - uri: str = Field( - ..., description='A HTTPS URL or data URI containing an encoded image.' - ) - - -class RunwayPromptImageObject( - RootModel[Union[str, List[RunwayPromptImageDetailedObject]]] -): - root: Union[str, List[RunwayPromptImageDetailedObject]] = Field( - ..., - description='Image(s) to use for the video generation. 
Can be a single URI or an array of image objects with positions.', - ) - - -class RunwayModelEnum(str, Enum): - gen4_turbo = 'gen4_turbo' - gen3a_turbo = 'gen3a_turbo' - - -class RunwayDurationEnum(int, Enum): - integer_5 = 5 - integer_10 = 10 - - -class RunwayImageToVideoRequest(BaseModel): - duration: RunwayDurationEnum - model: RunwayModelEnum - promptImage: RunwayPromptImageObject - promptText: Optional[str] = Field( - None, description='Text prompt for the generation', max_length=1000 - ) - ratio: RunwayAspectRatioEnum - seed: int = Field( - ..., description='Random seed for generation', ge=0, le=4294967295 - ) - - -class RunwayImageToVideoResponse(BaseModel): - id: Optional[str] = Field(None, description='Task ID') - - -class RunwayTaskStatusEnum(str, Enum): - SUCCEEDED = 'SUCCEEDED' - RUNNING = 'RUNNING' - FAILED = 'FAILED' - PENDING = 'PENDING' - CANCELLED = 'CANCELLED' - THROTTLED = 'THROTTLED' - - -class RunwayTaskStatusResponse(BaseModel): - createdAt: datetime = Field(..., description='Task creation timestamp') - id: str = Field(..., description='Task ID') - output: Optional[List[str]] = Field(None, description='Array of output video URLs') - progress: Optional[float] = Field( - None, - description='Float value between 0 and 1 representing the progress of the task. 
Only available if status is RUNNING.', - ge=0.0, - le=1.0, - ) - status: RunwayTaskStatusEnum - - -class Model4(str, Enum): - gen4_image = 'gen4_image' - - -class ReferenceImage(BaseModel): - uri: Optional[str] = Field( - None, description='A HTTPS URL or data URI containing an encoded image' - ) - - -class RunwayTextToImageAspectRatioEnum(str, Enum): - field_1920_1080 = '1920:1080' - field_1080_1920 = '1080:1920' - field_1024_1024 = '1024:1024' - field_1360_768 = '1360:768' - field_1080_1080 = '1080:1080' - field_1168_880 = '1168:880' - field_1440_1080 = '1440:1080' - field_1080_1440 = '1080:1440' - field_1808_768 = '1808:768' - field_2112_912 = '2112:912' - - -class RunwayTextToImageRequest(BaseModel): - model: Model4 = Field(..., description='Model to use for generation') - promptText: str = Field( - ..., description='Text prompt for the image generation', max_length=1000 - ) - ratio: RunwayTextToImageAspectRatioEnum - referenceImages: Optional[List[ReferenceImage]] = Field( - None, description='Array of reference images to guide the generation' - ) - - -class RunwayTextToImageResponse(BaseModel): - id: Optional[str] = Field(None, description='Task ID') diff --git a/comfy_api_nodes/apis/topaz.py b/comfy_api_nodes/apis/topaz.py deleted file mode 100644 index a9e6235a7737..000000000000 --- a/comfy_api_nodes/apis/topaz.py +++ /dev/null @@ -1,133 +0,0 @@ -from typing import Optional, Union - -from pydantic import BaseModel, Field - - -class ImageEnhanceRequest(BaseModel): - model: str = Field("Reimagine") - output_format: str = Field("jpeg") - subject_detection: str = Field("All") - face_enhancement: bool = Field(True) - face_enhancement_creativity: float = Field(0, description="Is ignored if face_enhancement is false") - face_enhancement_strength: float = Field(0.8, description="Is ignored if face_enhancement is false") - source_url: str = Field(...) 
- output_width: Optional[int] = Field(None) - output_height: Optional[int] = Field(None) - crop_to_fill: bool = Field(False) - prompt: Optional[str] = Field(None, description="Text prompt for creative upscaling guidance") - creativity: int = Field(3, description="Creativity settings range from 1 to 9") - face_preservation: str = Field("true", description="To preserve the identity of characters") - color_preservation: str = Field("true", description="To preserve the original color") - - -class ImageAsyncTaskResponse(BaseModel): - process_id: str = Field(...) - - -class ImageStatusResponse(BaseModel): - process_id: str = Field(...) - status: str = Field(...) - progress: Optional[int] = Field(None) - credits: int = Field(...) - - -class ImageDownloadResponse(BaseModel): - download_url: str = Field(...) - expiry: int = Field(...) - - -class Resolution(BaseModel): - width: int = Field(...) - height: int = Field(...) - - -class CreateVideoRequestSource(BaseModel): - container: str = Field(...) - size: int = Field(..., description="Size of the video file in bytes") - duration: int = Field(..., description="Duration of the video file in seconds") - frameCount: int = Field(..., description="Total number of frames in the video") - frameRate: int = Field(...) - resolution: Resolution = Field(...) - - -class VideoFrameInterpolationFilter(BaseModel): - model: str = Field(...) - slowmo: Optional[int] = Field(None) - fps: int = Field(...) - duplicate: bool = Field(...) - duplicate_threshold: float = Field(...) - - -class VideoEnhancementFilter(BaseModel): - model: str = Field(...) 
- auto: Optional[str] = Field(None, description="Auto, Manual, Relative") - focusFixLevel: Optional[str] = Field(None, description="Downscales video input for correction of blurred subjects") - compression: Optional[float] = Field(None, description="Strength of compression recovery") - details: Optional[float] = Field(None, description="Amount of detail reconstruction") - prenoise: Optional[float] = Field(None, description="Amount of noise to add to input to reduce over-smoothing") - noise: Optional[float] = Field(None, description="Amount of noise reduction") - halo: Optional[float] = Field(None, description="Amount of halo reduction") - preblur: Optional[float] = Field(None, description="Anti-aliasing and deblurring strength") - blur: Optional[float] = Field(None, description="Amount of sharpness applied") - grain: Optional[float] = Field(None, description="Grain after AI model processing") - grainSize: Optional[float] = Field(None, description="Size of generated grain") - recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video") - creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only") - isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only") - - -class OutputInformationVideo(BaseModel): - resolution: Resolution = Field(...) - frameRate: int = Field(...) - audioCodec: Optional[str] = Field(..., description="Required if audioTransfer is Copy or Convert") - audioTransfer: str = Field(..., description="Copy, Convert, None") - dynamicCompressionLevel: str = Field(..., description="Low, Mid, High") - - -class Overrides(BaseModel): - isPaidDiffusion: bool = Field(True) - - -class CreateVideoRequest(BaseModel): - source: CreateVideoRequestSource = Field(...) - filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...) - output: OutputInformationVideo = Field(...) 
- overrides: Overrides = Field(Overrides(isPaidDiffusion=True)) - - -class CreateVideoResponse(BaseModel): - requestId: str = Field(...) - - -class VideoAcceptResponse(BaseModel): - uploadId: str = Field(...) - urls: list[str] = Field(...) - - -class VideoCompleteUploadRequestPart(BaseModel): - partNum: int = Field(...) - eTag: str = Field(...) - - -class VideoCompleteUploadRequest(BaseModel): - uploadResults: list[VideoCompleteUploadRequestPart] = Field(...) - - -class VideoCompleteUploadResponse(BaseModel): - message: str = Field(..., description="Confirmation message") - - -class VideoStatusResponseEstimates(BaseModel): - cost: list[int] = Field(...) - - -class VideoStatusResponseDownloadUrl(BaseModel): - url: str = Field(...) - - -class VideoStatusResponse(BaseModel): - status: str = Field(...) - estimates: Optional[VideoStatusResponseEstimates] = Field(None) - progress: Optional[float] = Field(None) - message: Optional[str] = Field("") - download: Optional[VideoStatusResponseDownloadUrl] = Field(None) diff --git a/comfy_api_nodes/apis/tripo.py b/comfy_api_nodes/apis/tripo.py deleted file mode 100644 index ffaaa7dc1914..000000000000 --- a/comfy_api_nodes/apis/tripo.py +++ /dev/null @@ -1,312 +0,0 @@ -from __future__ import annotations -from enum import Enum -from typing import Optional, List, Dict, Any, Union - -from pydantic import BaseModel, Field, RootModel - -class TripoModelVersion(str, Enum): - v3_0_20250812 = 'v3.0-20250812' - v2_5_20250123 = 'v2.5-20250123' - v2_0_20240919 = 'v2.0-20240919' - v1_4_20240625 = 'v1.4-20240625' - - -class TripoGeometryQuality(str, Enum): - standard = 'standard' - detailed = 'detailed' - - -class TripoTextureQuality(str, Enum): - standard = 'standard' - detailed = 'detailed' - - -class TripoStyle(str, Enum): - PERSON_TO_CARTOON = "person:person2cartoon" - ANIMAL_VENOM = "animal:venom" - OBJECT_CLAY = "object:clay" - OBJECT_STEAMPUNK = "object:steampunk" - OBJECT_CHRISTMAS = "object:christmas" - OBJECT_BARBIE = 
"object:barbie" - GOLD = "gold" - ANCIENT_BRONZE = "ancient_bronze" - NONE = "None" - -class TripoTaskType(str, Enum): - TEXT_TO_MODEL = "text_to_model" - IMAGE_TO_MODEL = "image_to_model" - MULTIVIEW_TO_MODEL = "multiview_to_model" - TEXTURE_MODEL = "texture_model" - REFINE_MODEL = "refine_model" - ANIMATE_PRERIGCHECK = "animate_prerigcheck" - ANIMATE_RIG = "animate_rig" - ANIMATE_RETARGET = "animate_retarget" - STYLIZE_MODEL = "stylize_model" - CONVERT_MODEL = "convert_model" - -class TripoTextureAlignment(str, Enum): - ORIGINAL_IMAGE = "original_image" - GEOMETRY = "geometry" - -class TripoOrientation(str, Enum): - ALIGN_IMAGE = "align_image" - DEFAULT = "default" - -class TripoOutFormat(str, Enum): - GLB = "glb" - FBX = "fbx" - -class TripoTopology(str, Enum): - BIP = "bip" - QUAD = "quad" - -class TripoSpec(str, Enum): - MIXAMO = "mixamo" - TRIPO = "tripo" - -class TripoAnimation(str, Enum): - IDLE = "preset:idle" - WALK = "preset:walk" - RUN = "preset:run" - DIVE = "preset:dive" - CLIMB = "preset:climb" - JUMP = "preset:jump" - SLASH = "preset:slash" - SHOOT = "preset:shoot" - HURT = "preset:hurt" - FALL = "preset:fall" - TURN = "preset:turn" - QUADRUPED_WALK = "preset:quadruped:walk" - HEXAPOD_WALK = "preset:hexapod:walk" - OCTOPOD_WALK = "preset:octopod:walk" - SERPENTINE_MARCH = "preset:serpentine:march" - AQUATIC_MARCH = "preset:aquatic:march" - -class TripoStylizeStyle(str, Enum): - LEGO = "lego" - VOXEL = "voxel" - VORONOI = "voronoi" - MINECRAFT = "minecraft" - -class TripoConvertFormat(str, Enum): - GLTF = "GLTF" - USDZ = "USDZ" - FBX = "FBX" - OBJ = "OBJ" - STL = "STL" - _3MF = "3MF" - -class TripoTextureFormat(str, Enum): - BMP = "BMP" - DPX = "DPX" - HDR = "HDR" - JPEG = "JPEG" - OPEN_EXR = "OPEN_EXR" - PNG = "PNG" - TARGA = "TARGA" - TIFF = "TIFF" - WEBP = "WEBP" - -class TripoTaskStatus(str, Enum): - QUEUED = "queued" - RUNNING = "running" - SUCCESS = "success" - FAILED = "failed" - CANCELLED = "cancelled" - UNKNOWN = "unknown" - BANNED = 
"banned" - EXPIRED = "expired" - -class TripoFbxPreset(str, Enum): - BLENDER = "blender" - MIXAMO = "mixamo" - _3DSMAX = "3dsmax" - -class TripoFileTokenReference(BaseModel): - type: Optional[str] = Field(None, description='The type of the reference') - file_token: str - -class TripoUrlReference(BaseModel): - type: Optional[str] = Field(None, description='The type of the reference') - url: str - -class TripoObjectStorage(BaseModel): - bucket: str - key: str - -class TripoObjectReference(BaseModel): - type: str - object: TripoObjectStorage - -class TripoFileEmptyReference(BaseModel): - pass - -class TripoFileReference(RootModel): - root: Union[TripoFileTokenReference, TripoUrlReference, TripoObjectReference, TripoFileEmptyReference] - -class TripoGetStsTokenRequest(BaseModel): - format: str = Field(..., description='The format of the image') - -class TripoTextToModelRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task') - prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024) - negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024) - model_version: Optional[TripoModelVersion] = TripoModelVersion.v2_5_20250123 - face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to') - texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model') - pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model') - image_seed: Optional[int] = Field(None, description='The seed for the text') - model_seed: Optional[int] = Field(None, description='The seed for the model') - texture_seed: Optional[int] = Field(None, description='The seed for the texture') - texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard - geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard - style: 
Optional[TripoStyle] = None - auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') - quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model') - -class TripoImageToModelRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.IMAGE_TO_MODEL, description='Type of task') - file: TripoFileReference = Field(..., description='The file reference to convert to a model') - model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation') - face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to') - texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model') - pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model') - model_seed: Optional[int] = Field(None, description='The seed for the model') - texture_seed: Optional[int] = Field(None, description='The seed for the texture') - texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard - geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard - texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method') - style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model') - auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') - orientation: Optional[TripoOrientation] = TripoOrientation.DEFAULT - quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model') - -class TripoMultiviewToModelRequest(BaseModel): - type: TripoTaskType = TripoTaskType.MULTIVIEW_TO_MODEL - files: List[TripoFileReference] = Field(..., description='The file references to convert to a model') - model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for 
generation') - orthographic_projection: Optional[bool] = Field(False, description='Whether to use orthographic projection') - face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to') - texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model') - pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model') - model_seed: Optional[int] = Field(None, description='The seed for the model') - texture_seed: Optional[int] = Field(None, description='The seed for the texture') - texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard - geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard - texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE - auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') - orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model') - quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model') - -class TripoTextureModelRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description='Type of task') - original_model_task_id: str = Field(..., description='The task ID of the original model') - texture: Optional[bool] = Field(True, description='Whether to apply texture to the model') - pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the model') - model_seed: Optional[int] = Field(None, description='The seed for the model') - texture_seed: Optional[int] = Field(None, description='The seed for the texture') - texture_quality: Optional[TripoTextureQuality] = Field(None, description='The quality of the texture') - texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method') - -class 
TripoRefineModelRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.REFINE_MODEL, description='Type of task') - draft_model_task_id: str = Field(..., description='The task ID of the draft model') - -class TripoAnimatePrerigcheckRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.ANIMATE_PRERIGCHECK, description='Type of task') - original_model_task_id: str = Field(..., description='The task ID of the original model') - -class TripoAnimateRigRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.ANIMATE_RIG, description='Type of task') - original_model_task_id: str = Field(..., description='The task ID of the original model') - out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format') - spec: Optional[TripoSpec] = Field(TripoSpec.TRIPO, description='The specification for rigging') - -class TripoAnimateRetargetRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.ANIMATE_RETARGET, description='Type of task') - original_model_task_id: str = Field(..., description='The task ID of the original model') - animation: TripoAnimation = Field(..., description='The animation to apply') - out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format') - bake_animation: Optional[bool] = Field(True, description='Whether to bake the animation') - -class TripoStylizeModelRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.STYLIZE_MODEL, description='Type of task') - style: TripoStylizeStyle = Field(..., description='The style to apply to the model') - original_model_task_id: str = Field(..., description='The task ID of the original model') - block_size: Optional[int] = Field(80, description='The block size for stylization') - -class TripoConvertModelRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task') - format: TripoConvertFormat = Field(..., description='The format to convert to') - 
original_model_task_id: str = Field(..., description='The task ID of the original model') - quad: Optional[bool] = Field(None, description='Whether to apply quad to the model') - force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry') - face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to') - flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model') - flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom') - texture_size: Optional[int] = Field(None, description='The size of the texture') - texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture') - pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom') - scale_factor: Optional[float] = Field(None, description='The scale factor for the model') - with_animation: Optional[bool] = Field(None, description='Whether to include animations') - pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs') - bake: Optional[bool] = Field(None, description='Whether to bake the model') - part_names: Optional[List[str]] = Field(None, description='The names of the parts to include') - fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export') - export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors') - export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export') - animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place') - - -class TripoTaskRequest(RootModel): - root: Union[ - TripoTextToModelRequest, - TripoImageToModelRequest, - TripoMultiviewToModelRequest, - TripoTextureModelRequest, - TripoRefineModelRequest, - TripoAnimatePrerigcheckRequest, - TripoAnimateRigRequest, - 
TripoAnimateRetargetRequest, - TripoStylizeModelRequest, - TripoConvertModelRequest - ] - -class TripoTaskOutput(BaseModel): - model: Optional[str] = Field(None, description='URL to the model') - base_model: Optional[str] = Field(None, description='URL to the base model') - pbr_model: Optional[str] = Field(None, description='URL to the PBR model') - rendered_image: Optional[str] = Field(None, description='URL to the rendered image') - riggable: Optional[bool] = Field(None, description='Whether the model is riggable') - -class TripoTask(BaseModel): - task_id: str = Field(..., description='The task ID') - type: Optional[str] = Field(None, description='The type of task') - status: Optional[TripoTaskStatus] = Field(None, description='The status of the task') - input: Optional[Dict[str, Any]] = Field(None, description='The input parameters for the task') - output: Optional[TripoTaskOutput] = Field(None, description='The output of the task') - progress: Optional[int] = Field(None, description='The progress of the task', ge=0, le=100) - create_time: Optional[int] = Field(None, description='The creation time of the task') - running_left_time: Optional[int] = Field(None, description='The estimated time left for the task') - queue_position: Optional[int] = Field(None, description='The position in the queue') - -class TripoTaskResponse(BaseModel): - code: int = Field(0, description='The response code') - data: TripoTask = Field(..., description='The task data') - -class TripoGeneralResponse(BaseModel): - code: int = Field(0, description='The response code') - data: Dict[str, str] = Field(..., description='The task ID data') - -class TripoBalanceData(BaseModel): - balance: float = Field(..., description='The account balance') - frozen: float = Field(..., description='The frozen balance') - -class TripoBalanceResponse(BaseModel): - code: int = Field(0, description='The response code') - data: TripoBalanceData = Field(..., description='The balance data') - -class 
TripoErrorResponse(BaseModel): - code: int = Field(..., description='The error code') - message: str = Field(..., description='The error message') - suggestion: str = Field(..., description='The suggestion for fixing the error') diff --git a/comfy_api_nodes/apis/wavespeed.py b/comfy_api_nodes/apis/wavespeed.py deleted file mode 100644 index 07a7bfa5d1cd..000000000000 --- a/comfy_api_nodes/apis/wavespeed.py +++ /dev/null @@ -1,35 +0,0 @@ -from pydantic import BaseModel, Field - - -class SeedVR2ImageRequest(BaseModel): - image: str = Field(...) - target_resolution: str = Field(...) - output_format: str = Field("png") - enable_sync_mode: bool = Field(False) - - -class FlashVSRRequest(BaseModel): - target_resolution: str = Field(...) - video: str = Field(...) - duration: float = Field(...) - - -class TaskCreatedDataResponse(BaseModel): - id: str = Field(...) - - -class TaskCreatedResponse(BaseModel): - code: int = Field(...) - message: str = Field(...) - data: TaskCreatedDataResponse | None = Field(None) - - -class TaskResultDataResponse(BaseModel): - status: str = Field(...) - outputs: list[str] = Field([]) - - -class TaskResultResponse(BaseModel): - code: int = Field(...) - message: str = Field(...) 
- data: TaskResultDataResponse | None = Field(None) diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 23590bf24c30..5b5da3997378 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -14,16 +14,22 @@ Flux2ProGenerateRequest, ) from comfy_api_nodes.util import ( - ApiEndpoint, download_url_to_image_tensor, get_number_of_images, - poll_op, resize_mask_to_image, - sync_op, tensor_to_base64_string, validate_aspect_ratio_string, validate_string, ) +from comfy_api_nodes.util.client import fal_run + +FAL_FLUX_PRO_ULTRA = "fal-ai/flux-pro/v1.1-ultra" +FAL_FLUX_KONTEXT_PRO = "fal-ai/flux-kontext/pro" +FAL_FLUX_KONTEXT_MAX = "fal-ai/flux-kontext/max" +FAL_FLUX_PRO_EXPAND = "fal-ai/flux-pro/v1/expand" +FAL_FLUX_PRO_FILL = "fal-ai/flux-pro/v1/fill" +FAL_FLUX_2_PRO = "fal-ai/flux-pro/v2" +FAL_FLUX_2_MAX = "fal-ai/flux-pro/v2/max" def convert_mask_to_image(mask: Input.Image): @@ -93,14 +99,9 @@ def define_schema(cls) -> IO.Schema: ], outputs=[IO.Image.Output()], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.06}""", - ), ) @classmethod @@ -121,36 +122,21 @@ async def execute( ) -> IO.NodeOutput: if image_prompt is None: validate_string(prompt, strip_whitespace=False) - initial_response = await sync_op( - cls, - ApiEndpoint(path="/proxy/bfl/flux-pro-1.1-ultra/generate", method="POST"), - response_model=BFLFluxProGenerateResponse, - data=BFLFluxProUltraGenerateRequest( - prompt=prompt, - prompt_upsampling=prompt_upsampling, - seed=seed, - aspect_ratio=aspect_ratio, - raw=raw, - image_prompt=(image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)), - image_prompt_strength=(None if image_prompt is None else round(image_prompt_strength, 2)), - ), - ) - response = await poll_op( - cls, - ApiEndpoint(initial_response.polling_url), - response_model=BFLFluxStatusResponse, - 
status_extractor=lambda r: r.status, - progress_extractor=lambda r: r.progress, - completed_statuses=[BFLStatus.ready], - failed_statuses=[ - BFLStatus.request_moderated, - BFLStatus.content_moderated, - BFLStatus.error, - BFLStatus.task_not_found, - ], - queued_statuses=[], - ) - return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) + data = { + "prompt": prompt, + "prompt_upsampling": prompt_upsampling, + "seed": seed, + "aspect_ratio": aspect_ratio, + "raw": raw, + } + if image_prompt is not None: + data["image_prompt"] = tensor_to_base64_string(image_prompt) + data["image_prompt_strength"] = round(image_prompt_strength, 2) + + # TODO: Verify fal.ai response schema for Flux Pro Ultra + result = await fal_run(cls, FAL_FLUX_PRO_ULTRA, data) + image_url = result.get("images", [{}])[0].get("url") or result.get("sample") + return IO.NodeOutput(await download_url_to_image_tensor(image_url)) class FluxKontextProImageNode(IO.ComfyNode): @@ -210,14 +196,12 @@ def define_schema(cls) -> IO.Schema: ], outputs=[IO.Image.Output()], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, ) - BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate" + FAL_MODEL = FAL_FLUX_KONTEXT_PRO NODE_ID = "FluxKontextProImageNode" DISPLAY_NAME = "Flux.1 Kontext [pro] Image" @@ -235,42 +219,27 @@ async def execute( validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1)) if input_image is None: validate_string(prompt, strip_whitespace=False) - initial_response = await sync_op( - cls, - ApiEndpoint(path=cls.BFL_PATH, method="POST"), - response_model=BFLFluxProGenerateResponse, - data=BFLFluxKontextProGenerateRequest( - prompt=prompt, - prompt_upsampling=prompt_upsampling, - guidance=round(guidance, 1), - steps=steps, - seed=seed, - aspect_ratio=aspect_ratio, - input_image=(input_image if input_image is None else tensor_to_base64_string(input_image)), - ), - ) - response = await poll_op( - cls, - 
ApiEndpoint(initial_response.polling_url), - response_model=BFLFluxStatusResponse, - status_extractor=lambda r: r.status, - progress_extractor=lambda r: r.progress, - completed_statuses=[BFLStatus.ready], - failed_statuses=[ - BFLStatus.request_moderated, - BFLStatus.content_moderated, - BFLStatus.error, - BFLStatus.task_not_found, - ], - queued_statuses=[], - ) - return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) + data = { + "prompt": prompt, + "prompt_upsampling": prompt_upsampling, + "guidance": round(guidance, 1), + "steps": steps, + "seed": seed, + "aspect_ratio": aspect_ratio, + } + if input_image is not None: + data["input_image"] = tensor_to_base64_string(input_image) + + # TODO: Verify fal.ai response schema for Flux Kontext + result = await fal_run(cls, cls.FAL_MODEL, data) + image_url = result.get("images", [{}])[0].get("url") or result.get("sample") + return IO.NodeOutput(await download_url_to_image_tensor(image_url)) class FluxKontextMaxImageNode(FluxKontextProImageNode): DESCRIPTION = "Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio." 
- BFL_PATH = "/proxy/bfl/flux-kontext-max/generate" + FAL_MODEL = FAL_FLUX_KONTEXT_MAX NODE_ID = "FluxKontextMaxImageNode" DISPLAY_NAME = "Flux.1 Kontext [max] Image" @@ -353,14 +322,9 @@ def define_schema(cls) -> IO.Schema: ], outputs=[IO.Image.Output()], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.05}""", - ), ) @classmethod @@ -377,39 +341,22 @@ async def execute( guidance: float, seed=0, ) -> IO.NodeOutput: - initial_response = await sync_op( - cls, - ApiEndpoint(path="/proxy/bfl/flux-pro-1.0-expand/generate", method="POST"), - response_model=BFLFluxProGenerateResponse, - data=BFLFluxExpandImageRequest( - prompt=prompt, - prompt_upsampling=prompt_upsampling, - top=top, - bottom=bottom, - left=left, - right=right, - steps=steps, - guidance=guidance, - seed=seed, - image=tensor_to_base64_string(image), - ), - ) - response = await poll_op( - cls, - ApiEndpoint(initial_response.polling_url), - response_model=BFLFluxStatusResponse, - status_extractor=lambda r: r.status, - progress_extractor=lambda r: r.progress, - completed_statuses=[BFLStatus.ready], - failed_statuses=[ - BFLStatus.request_moderated, - BFLStatus.content_moderated, - BFLStatus.error, - BFLStatus.task_not_found, - ], - queued_statuses=[], - ) - return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) + data = { + "prompt": prompt, + "prompt_upsampling": prompt_upsampling, + "top": top, + "bottom": bottom, + "left": left, + "right": right, + "steps": steps, + "guidance": guidance, + "seed": seed, + "image": tensor_to_base64_string(image), + } + # TODO: Verify fal.ai response schema for Flux Expand + result = await fal_run(cls, FAL_FLUX_PRO_EXPAND, data) + image_url = result.get("images", [{}])[0].get("url") or result.get("sample") + return IO.NodeOutput(await download_url_to_image_tensor(image_url)) class FluxProFillNode(IO.ComfyNode): @@ 
-463,14 +410,9 @@ def define_schema(cls) -> IO.Schema: ], outputs=[IO.Image.Output()], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.05}""", - ), ) @classmethod @@ -487,43 +429,26 @@ async def execute( # prepare mask mask = resize_mask_to_image(mask, image) mask = tensor_to_base64_string(convert_mask_to_image(mask)) - initial_response = await sync_op( - cls, - ApiEndpoint(path="/proxy/bfl/flux-pro-1.0-fill/generate", method="POST"), - response_model=BFLFluxProGenerateResponse, - data=BFLFluxFillImageRequest( - prompt=prompt, - prompt_upsampling=prompt_upsampling, - steps=steps, - guidance=guidance, - seed=seed, - image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed - mask=mask, - ), - ) - response = await poll_op( - cls, - ApiEndpoint(initial_response.polling_url), - response_model=BFLFluxStatusResponse, - status_extractor=lambda r: r.status, - progress_extractor=lambda r: r.progress, - completed_statuses=[BFLStatus.ready], - failed_statuses=[ - BFLStatus.request_moderated, - BFLStatus.content_moderated, - BFLStatus.error, - BFLStatus.task_not_found, - ], - queued_statuses=[], - ) - return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) + data = { + "prompt": prompt, + "prompt_upsampling": prompt_upsampling, + "steps": steps, + "guidance": guidance, + "seed": seed, + "image": tensor_to_base64_string(image[:, :, :, :3]), + "mask": mask, + } + # TODO: Verify fal.ai response schema for Flux Fill + result = await fal_run(cls, FAL_FLUX_PRO_FILL, data) + image_url = result.get("images", [{}])[0].get("url") or result.get("sample") + return IO.NodeOutput(await download_url_to_image_tensor(image_url)) class Flux2ProImageNode(IO.ComfyNode): NODE_ID = "Flux2ProImageNode" DISPLAY_NAME = "Flux.2 [pro] Image" - API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate" + FAL_MODEL_ID = 
FAL_FLUX_2_PRO PRICE_BADGE_EXPR = """ ( $MP := 1024 * 1024; @@ -587,15 +512,9 @@ def define_schema(cls) -> IO.Schema: ], outputs=[IO.Image.Output()], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["width", "height"], inputs=["images"]), - expr=cls.PRICE_BADGE_EXPR, - ), ) @classmethod @@ -608,54 +527,31 @@ async def execute( prompt_upsampling: bool, images: Input.Image | None = None, ) -> IO.NodeOutput: - reference_images = {} + data = { + "prompt": prompt, + "width": width, + "height": height, + "seed": seed, + "prompt_upsampling": prompt_upsampling, + } if images is not None: if get_number_of_images(images) > 9: raise ValueError("The current maximum number of supported images is 9.") for image_index in range(images.shape[0]): key_name = f"input_image_{image_index + 1}" if image_index else "input_image" - reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048) - initial_response = await sync_op( - cls, - ApiEndpoint(path=cls.API_ENDPOINT, method="POST"), - response_model=BFLFluxProGenerateResponse, - data=Flux2ProGenerateRequest( - prompt=prompt, - width=width, - height=height, - seed=seed, - prompt_upsampling=prompt_upsampling, - **reference_images, - ), - ) + data[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048) - def price_extractor(_r: BaseModel) -> float | None: - return None if initial_response.cost is None else initial_response.cost / 100 - - response = await poll_op( - cls, - ApiEndpoint(initial_response.polling_url), - response_model=BFLFluxStatusResponse, - status_extractor=lambda r: r.status, - progress_extractor=lambda r: r.progress, - price_extractor=price_extractor, - completed_statuses=[BFLStatus.ready], - failed_statuses=[ - BFLStatus.request_moderated, - BFLStatus.content_moderated, - BFLStatus.error, - BFLStatus.task_not_found, - 
], - queued_statuses=[], - ) - return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) + # TODO: Verify fal.ai response schema for Flux 2 Pro + result = await fal_run(cls, cls.FAL_MODEL_ID, data) + image_url = result.get("images", [{}])[0].get("url") or result.get("sample") + return IO.NodeOutput(await download_url_to_image_tensor(image_url)) class Flux2MaxImageNode(Flux2ProImageNode): NODE_ID = "Flux2MaxImageNode" DISPLAY_NAME = "Flux.2 [max] Image" - API_ENDPOINT = "/proxy/bfl/flux-2-max/generate" + FAL_MODEL_ID = FAL_FLUX_2_MAX PRICE_BADGE_EXPR = """ ( $MP := 1024 * 1024; diff --git a/comfy_api_nodes/nodes_bria.py b/comfy_api_nodes/nodes_bria.py index 4044ee3ead0c..a3bf9db1a3d4 100644 --- a/comfy_api_nodes/nodes_bria.py +++ b/comfy_api_nodes/nodes_bria.py @@ -2,26 +2,18 @@ from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.bria import ( - BriaEditImageRequest, - BriaRemoveBackgroundRequest, - BriaRemoveBackgroundResponse, - BriaRemoveVideoBackgroundRequest, - BriaRemoveVideoBackgroundResponse, - BriaImageEditResponse, - BriaStatusResponse, InputModerationSettings, ) from comfy_api_nodes.util import ( - ApiEndpoint, convert_mask_to_image, download_url_to_image_tensor, download_url_to_video_output, - poll_op, - sync_op, - upload_image_to_comfyapi, - upload_video_to_comfyapi, + upload_video_to_fal, validate_video_duration, ) +from comfy_api_nodes.util.client import fal_run + +FAL_BRIA_TEXT_TO_IMAGE = "fal-ai/bria/text-to-image/hd" class BriaImageEditNode(IO.ComfyNode): @@ -102,14 +94,9 @@ def define_schema(cls): IO.String.Output(display_name="structured_prompt"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.04}""", - ), ) @classmethod @@ -128,37 +115,28 @@ async def execute( ) -> IO.NodeOutput: if not prompt and not structured_prompt: raise ValueError("One of prompt or 
structured_prompt is required to be non-empty.") + from comfy_api_nodes.util.upload_helpers import upload_image_to_fal mask_url = None if mask is not None: - mask_url = await upload_image_to_comfyapi(cls, convert_mask_to_image(mask), wait_label="Uploading mask") - response = await sync_op( - cls, - ApiEndpoint(path="proxy/bria/v2/image/edit", method="POST"), - data=BriaEditImageRequest( - instruction=prompt if prompt else None, - structured_instruction=structured_prompt if structured_prompt else None, - images=[await upload_image_to_comfyapi(cls, image, wait_label="Uploading image")], - mask=mask_url, - negative_prompt=negative_prompt if negative_prompt else None, - guidance_scale=guidance_scale, - seed=seed, - model_version=model, - steps_num=steps, - prompt_content_moderation=moderation.get("prompt_content_moderation", False), - visual_input_content_moderation=moderation.get("visual_input_moderation", False), - visual_output_content_moderation=moderation.get("visual_output_moderation", False), - ), - response_model=BriaStatusResponse, - ) - response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"), - status_extractor=lambda r: r.status, - response_model=BriaImageEditResponse, - ) + mask_url = await upload_image_to_fal(convert_mask_to_image(mask)[0] if len(convert_mask_to_image(mask).shape) > 3 else convert_mask_to_image(mask), "image/png") + image_url = await upload_image_to_fal(image[0] if len(image.shape) > 3 else image, "image/png") + result = await fal_run(cls, FAL_BRIA_TEXT_TO_IMAGE, { + "instruction": prompt if prompt else None, + "structured_instruction": structured_prompt if structured_prompt else None, + "images": [image_url], + "mask": mask_url, + "negative_prompt": negative_prompt if negative_prompt else None, + "guidance_scale": guidance_scale, + "seed": seed, + "model_version": model, + "steps_num": steps, + "prompt_content_moderation": moderation.get("prompt_content_moderation", False), + 
"visual_input_content_moderation": moderation.get("visual_input_moderation", False), + "visual_output_content_moderation": moderation.get("visual_output_moderation", False), + }) # TODO: verify fal.ai field names return IO.NodeOutput( - await download_url_to_image_tensor(response.result.image_url), - response.result.structured_prompt, + await download_url_to_image_tensor(result["image"]["url"]), + result.get("structured_prompt", ""), ) @@ -200,14 +178,9 @@ def define_schema(cls): ], outputs=[IO.Image.Output()], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.018}""", - ), ) @classmethod @@ -217,25 +190,16 @@ async def execute( moderation: dict, seed: int, ) -> IO.NodeOutput: - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/bria/v2/image/edit/remove_background", method="POST"), - data=BriaRemoveBackgroundRequest( - image=await upload_image_to_comfyapi(cls, image, wait_label="Uploading image"), - sync=False, - visual_input_content_moderation=moderation.get("visual_input_moderation", False), - visual_output_content_moderation=moderation.get("visual_output_moderation", False), - seed=seed, - ), - response_model=BriaStatusResponse, - ) - response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"), - status_extractor=lambda r: r.status, - response_model=BriaRemoveBackgroundResponse, - ) - return IO.NodeOutput(await download_url_to_image_tensor(response.result.image_url)) + from comfy_api_nodes.util.upload_helpers import upload_image_to_fal + image_url = await upload_image_to_fal(image[0] if len(image.shape) > 3 else image, "image/png") + result = await fal_run(cls, FAL_BRIA_TEXT_TO_IMAGE, { + "image": image_url, + "sync": False, + "visual_input_content_moderation": moderation.get("visual_input_moderation", False), + "visual_output_content_moderation": 
moderation.get("visual_output_moderation", False), + "seed": seed, + }) # TODO: verify fal.ai field names and model ID for remove_background + return IO.NodeOutput(await download_url_to_image_tensor(result["image"]["url"])) class BriaRemoveVideoBackground(IO.ComfyNode): @@ -278,14 +242,9 @@ def define_schema(cls): ], outputs=[IO.Video.Output()], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""", - ), ) @classmethod @@ -296,24 +255,14 @@ async def execute( seed: int, ) -> IO.NodeOutput: validate_video_duration(video, max_duration=60.0) - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/bria/v2/video/edit/remove_background", method="POST"), - data=BriaRemoveVideoBackgroundRequest( - video=await upload_video_to_comfyapi(cls, video), - background_color=background_color, - output_container_and_codec="mp4_h264", - seed=seed, - ), - response_model=BriaStatusResponse, - ) - response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"), - status_extractor=lambda r: r.status, - response_model=BriaRemoveVideoBackgroundResponse, - ) - return IO.NodeOutput(await download_url_to_video_output(response.result.video_url)) + video_url = await upload_video_to_fal(video) + result = await fal_run(cls, FAL_BRIA_TEXT_TO_IMAGE, { + "video": video_url, + "background_color": background_color, + "output_container_and_codec": "mp4_h264", + "seed": seed, + }) # TODO: verify fal.ai field names and model ID for video remove_background + return IO.NodeOutput(await download_url_to_video_output(result["video"]["url"])) class BriaExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 6dbd5984ed39..f9a808e864bc 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -29,13 +29,17 @@ 
image_tensor_pair_to_batch, poll_op, sync_op, - upload_images_to_comfyapi, + upload_images_to_fal, validate_image_aspect_ratio, validate_image_dimensions, validate_string, ) +from comfy_api_nodes.util._helpers import get_fal_auth_header +from comfy_api_nodes.util.client import fal_run -BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations" +FAL_SEEDREAM = "fal-ai/seedream-4.5" + +BYTEPLUS_IMAGE_ENDPOINT = "fal_image_endpoint" # migrated from /proxy/byteplus/api/v3/images/generations SEEDREAM_MODELS = { "seedream 5.0 lite": "seedream-5-0-260128", @@ -43,9 +47,9 @@ "seedream-4-0-250828": "seedream-4-0-250828", } -# Long-running tasks endpoints(e.g., video) -BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" -BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id} +# Long-running tasks endpoints(e.g., video) - migrated from /proxy/byteplus/... +BYTEPLUS_TASK_ENDPOINT = "fal_task_endpoint" # migrated from /proxy/byteplus/api/v3/contents/generations/tasks +BYTEPLUS_TASK_STATUS_ENDPOINT = "fal_task_status_endpoint" # migrated from /proxy/byteplus/api/v3/contents/generations/tasks + /{task_id} def get_image_url_from_response(response: ImageTaskCreationResponse) -> str: @@ -127,14 +131,9 @@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.03}""", - ), ) @classmethod @@ -171,13 +170,9 @@ async def execute( guidance_scale=guidance_scale, watermark=watermark, ) - response = await sync_op( - cls, - ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), - data=payload, - response_model=ImageTaskCreationResponse, - ) - return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) + result = await fal_run(cls, FAL_SEEDREAM, payload.model_dump(exclude_none=True)) # TODO: verify fal.ai field 
names; use correct fal model for seedream-3 + image_url = result["images"][0]["url"] + return IO.NodeOutput(await download_url_to_image_tensor(image_url)) class ByteDanceSeedreamNode(IO.ComfyNode): @@ -279,25 +274,9 @@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model"]), - expr=""" - ( - $price := $contains(widgets.model, "5.0 lite") ? 0.035 : - $contains(widgets.model, "4-5") ? 0.04 : 0.03; - { - "type":"usd", - "usd": $price, - "format": { "suffix":" x images/Run", "approximate": true } - } - ) - """, - ), ) @classmethod @@ -358,33 +337,28 @@ async def execute( if n_input_images: for i in image: validate_image_aspect_ratio(i, (1, 3), (3, 1)) - reference_images_urls = await upload_images_to_comfyapi( - cls, + reference_images_urls = await upload_images_to_fal( image, max_images=n_input_images, mime_type="image/png", ) - response = await sync_op( - cls, - ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), - response_model=ImageTaskCreationResponse, - data=Seedream4TaskCreationRequest( - model=model, - prompt=prompt, - image=reference_images_urls, - size=f"{w}x{h}", - seed=seed, - sequential_image_generation=sequential_image_generation, - sequential_image_generation_options=Seedream4Options(max_images=max_images), - watermark=watermark, - output_format="png" if model == "seedream-5-0-260128" else None, - ), - ) - if len(response.data) == 1: - return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) - urls = [str(d["url"]) for d in response.data if isinstance(d, dict) and "url" in d] - if fail_on_partial and len(urls) < len(response.data): - raise RuntimeError(f"Only {len(urls)} of {len(response.data)} images were generated before error.") + result = await fal_run(cls, FAL_SEEDREAM, { + "model": model, + "prompt": prompt, + "image": 
reference_images_urls, + "size": f"{w}x{h}", + "seed": seed, + "sequential_image_generation": sequential_image_generation, + "sequential_image_generation_options": {"max_images": max_images}, + "watermark": watermark, + "output_format": "png" if model == "seedream-5-0-260128" else None, + }) # TODO: verify fal.ai field names + images = result.get("images", []) + if len(images) == 1: + return IO.NodeOutput(await download_url_to_image_tensor(images[0]["url"])) + urls = [img["url"] for img in images if "url" in img] + if fail_on_partial and len(urls) < len(images): + raise RuntimeError(f"Only {len(urls)} of {len(images)} images were generated before error.") return IO.NodeOutput(torch.cat([await download_url_to_image_tensor(i) for i in urls])) @@ -470,12 +444,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -602,12 +573,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -631,7 +599,7 @@ async def execute( validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 - image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0] + image_url = (await upload_images_to_fal(image, max_images=1))[0] prompt = ( f"{prompt} " f"--resolution {resolution} " @@ -738,12 +706,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -769,8 +734,7 @@ async def execute( validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000) 
validate_image_aspect_ratio(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 - download_urls = await upload_images_to_comfyapi( - cls, + download_urls = await upload_images_to_fal( image_tensor_pair_to_batch(first_frame, last_frame), max_images=2, mime_type="image/png", @@ -867,46 +831,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), - expr=""" - ( - $priceByModel := { - "seedance-1-0-pro": { - "480p":[0.23,0.24], - "720p":[0.51,0.56] - }, - "seedance-1-0-lite": { - "480p":[0.17,0.18], - "720p":[0.37,0.41] - } - }; - $model := widgets.model; - $modelKey := - $contains($model, "seedance-1-0-pro") ? "seedance-1-0-pro" : - "seedance-1-0-lite"; - $resolution := widgets.resolution; - $resKey := - $contains($resolution, "720") ? "720p" : - "480p"; - $modelPrices := $lookup($priceByModel, $modelKey); - $baseRange := $lookup($modelPrices, $resKey); - $min10s := $baseRange[0]; - $max10s := $baseRange[1]; - $scale := widgets.duration / 10; - $minCost := $min10s * $scale; - $maxCost := $max10s * $scale; - ($minCost = $maxCost) - ? 
{"type":"usd","usd": $minCost} - : {"type":"range_usd","min_usd": $minCost, "max_usd": $maxCost} - ) - """, - ), ) @classmethod @@ -927,7 +854,7 @@ async def execute( validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 - image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png") + image_urls = await upload_images_to_fal(images, max_images=4, mime_type="image/png") prompt = ( f"{prompt} " f"--resolution {resolution} " @@ -952,20 +879,13 @@ async def process_video_task( payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest, estimated_duration: int | None, ) -> IO.NodeOutput: - initial_response = await sync_op( - cls, - ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), - data=payload, - response_model=TaskCreationResponse, - ) - response = await poll_op( - cls, - ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"), - status_extractor=lambda r: r.status, + result = await fal_run( + cls, FAL_SEEDREAM, + payload.model_dump(exclude_none=True), estimated_duration=estimated_duration, - response_model=TaskStatusResponse, - ) - return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) + ) # TODO: verify fal.ai field names; use correct fal model for ByteDance video + video_url = result["video"]["url"] + return IO.NodeOutput(await download_url_to_video_output(video_url)) def raise_if_text_params(prompt: str, text_params: list[str]) -> None: diff --git a/comfy_api_nodes/nodes_elevenlabs.py b/comfy_api_nodes/nodes_elevenlabs.py index e452daf77320..3dbc743ba096 100644 --- a/comfy_api_nodes/nodes_elevenlabs.py +++ b/comfy_api_nodes/nodes_elevenlabs.py @@ -5,28 +5,21 @@ from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.elevenlabs import ( - AddVoiceRequest, - AddVoiceResponse, DialogueInput, DialogueSettings, - 
SpeechToSpeechRequest, - SpeechToTextRequest, - SpeechToTextResponse, - TextToDialogueRequest, - TextToSoundEffectsRequest, - TextToSpeechRequest, TextToSpeechVoiceSettings, ) from comfy_api_nodes.util import ( - ApiEndpoint, audio_bytes_to_audio_input, audio_ndarray_to_bytesio, audio_tensor_to_contiguous_ndarray, - sync_op, - sync_op_raw, - upload_audio_to_comfyapi, + download_url_as_bytesio, validate_string, ) +from comfy_api_nodes.util.client import fal_run +from comfy_api_nodes.util.upload_helpers import upload_file_to_fal + +FAL_ELEVENLABS_TTS = "fal-ai/elevenlabs/tts/turbo-v2.5" ELEVENLABS_MUSIC_SECTIONS = "ELEVENLABS_MUSIC_SECTIONS" # Custom type for music sections ELEVENLABS_COMPOSITION_PLAN = "ELEVENLABS_COMPOSITION_PLAN" # Custom type for composition plan @@ -63,6 +56,13 @@ } +async def _upload_audio_to_fal(audio: Input.Audio) -> str: + """Convert audio input to bytes and upload to fal.ai CDN. Returns the CDN URL.""" + audio_data_np = audio_tensor_to_contiguous_ndarray(audio["waveform"]) + audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, audio["sample_rate"], "mp4", "aac") + return await upload_file_to_fal(audio_bytes_io, "audio/mp4") + + class ElevenLabsSpeechToText(IO.ComfyNode): @classmethod def define_schema(cls) -> IO.Schema: @@ -152,14 +152,9 @@ def define_schema(cls) -> IO.Schema: IO.String.Output(display_name="words_json"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.0073,"format":{"approximate":true,"suffix":"/minute"}}""", - ), ) @classmethod @@ -176,32 +171,25 @@ async def execute( "Number of speakers cannot be specified when diarization is enabled. " "Either disable diarization or set num_speakers to 0." 
) - request = SpeechToTextRequest( - model_id=model["model"], - cloud_storage_url=await upload_audio_to_comfyapi( - cls, audio, container_format="mp4", codec_name="aac", mime_type="audio/mp4" - ), - language_code=language_code if language_code.strip() else None, - tag_audio_events=model["tag_audio_events"], - num_speakers=num_speakers if num_speakers > 0 else None, - timestamps_granularity=model["timestamps_granularity"], - diarize=model["diarize"], - diarization_threshold=model["diarization_threshold"] if model["diarize"] else None, - seed=seed, - temperature=model["temperature"], - ) - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/elevenlabs/v1/speech-to-text", method="POST"), - response_model=SpeechToTextResponse, - data=request, - content_type="multipart/form-data", - ) + audio_url = await _upload_audio_to_fal(audio) + result = await fal_run(cls, FAL_ELEVENLABS_TTS, { + "model_id": model["model"], + "audio_url": audio_url, + "language_code": language_code if language_code.strip() else None, + "tag_audio_events": model["tag_audio_events"], + "num_speakers": num_speakers if num_speakers > 0 else None, + "timestamps_granularity": model["timestamps_granularity"], + "diarize": model["diarize"], + "diarization_threshold": model["diarization_threshold"] if model["diarize"] else None, + "seed": seed, + "temperature": model["temperature"], + }) + # TODO: verify fal.ai field names words_json = json.dumps( - [w.model_dump(exclude_none=True) for w in response.words] if response.words else [], + result.get("words", []), indent=2, ) - return IO.NodeOutput(response.text, response.language_code, words_json) + return IO.NodeOutput(result["text"], result.get("language_code", ""), words_json) class ElevenLabsVoiceSelector(IO.ComfyNode): @@ -358,14 +346,9 @@ def define_schema(cls) -> IO.Schema: IO.Audio.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - 
expr="""{"type":"usd","usd":0.24,"format":{"approximate":true,"suffix":"/1K chars"}}""", - ), ) @classmethod @@ -381,31 +364,27 @@ async def execute( output_format: str, ) -> IO.NodeOutput: validate_string(text, min_length=1) - request = TextToSpeechRequest( - text=text, - model_id=model["model"], - language_code=language_code if language_code.strip() else None, - voice_settings=TextToSpeechVoiceSettings( - stability=stability, - similarity_boost=model["similarity_boost"], - speed=model["speed"], - use_speaker_boost=model.get("use_speaker_boost", None), - style=model.get("style", None), - ), - seed=seed, - apply_text_normalization=apply_text_normalization, - ) - response = await sync_op_raw( - cls, - ApiEndpoint( - path=f"/proxy/elevenlabs/v1/text-to-speech/{voice}", - method="POST", - query_params={"output_format": output_format}, - ), - data=request, - as_binary=True, + voice_settings = TextToSpeechVoiceSettings( + stability=stability, + similarity_boost=model["similarity_boost"], + speed=model["speed"], + use_speaker_boost=model.get("use_speaker_boost", None), + style=model.get("style", None), ) - return IO.NodeOutput(audio_bytes_to_audio_input(response)) + result = await fal_run(cls, FAL_ELEVENLABS_TTS, { + "text": text, + "voice_id": voice, + "model_id": model["model"], + "language_code": language_code if language_code.strip() else None, + "voice_settings": voice_settings.model_dump(exclude_none=True), + "seed": seed, + "apply_text_normalization": apply_text_normalization, + "output_format": output_format, + }) + # TODO: verify fal.ai field names + audio_url = result["audio"]["url"] + audio_bytes_io = await download_url_as_bytesio(audio_url) + return IO.NodeOutput(audio_bytes_to_audio_input(audio_bytes_io.read())) class ElevenLabsAudioIsolation(IO.ComfyNode): @@ -426,14 +405,9 @@ def define_schema(cls) -> IO.Schema: IO.Audio.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - 
price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.24,"format":{"approximate":true,"suffix":"/minute"}}""", - ), ) @classmethod @@ -441,16 +415,14 @@ async def execute( cls, audio: Input.Audio, ) -> IO.NodeOutput: - audio_data_np = audio_tensor_to_contiguous_ndarray(audio["waveform"]) - audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, audio["sample_rate"], "mp4", "aac") - response = await sync_op_raw( - cls, - ApiEndpoint(path="/proxy/elevenlabs/v1/audio-isolation", method="POST"), - files={"audio": ("audio.mp4", audio_bytes_io, "audio/mp4")}, - content_type="multipart/form-data", - as_binary=True, - ) - return IO.NodeOutput(audio_bytes_to_audio_input(response)) + audio_url = await _upload_audio_to_fal(audio) + result = await fal_run(cls, FAL_ELEVENLABS_TTS, { + "audio_url": audio_url, + }) + # TODO: verify fal.ai field names + output_url = result["audio"]["url"] + audio_bytes_io = await download_url_as_bytesio(output_url) + return IO.NodeOutput(audio_bytes_to_audio_input(audio_bytes_io.read())) class ElevenLabsTextToSoundEffects(IO.ComfyNode): @@ -513,14 +485,9 @@ def define_schema(cls) -> IO.Schema: IO.Audio.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.14,"format":{"approximate":true,"suffix":"/minute"}}""", - ), ) @classmethod @@ -531,22 +498,17 @@ async def execute( output_format: str, ) -> IO.NodeOutput: validate_string(text, min_length=1) - response = await sync_op_raw( - cls, - ApiEndpoint( - path="/proxy/elevenlabs/v1/sound-generation", - method="POST", - query_params={"output_format": output_format}, - ), - data=TextToSoundEffectsRequest( - text=text, - duration_seconds=model["duration"], - prompt_influence=model["prompt_influence"], - loop=model.get("loop", None), - ), - as_binary=True, - ) - return IO.NodeOutput(audio_bytes_to_audio_input(response)) + result = await fal_run(cls, 
FAL_ELEVENLABS_TTS, { + "text": text, + "duration_seconds": model["duration"], + "prompt_influence": model["prompt_influence"], + "loop": model.get("loop", None), + "output_format": output_format, + }) + # TODO: verify fal.ai field names + audio_url = result["audio"]["url"] + audio_bytes_io = await download_url_as_bytesio(audio_url) + return IO.NodeOutput(audio_bytes_to_audio_input(audio_bytes_io.read())) class ElevenLabsInstantVoiceClone(IO.ComfyNode): @@ -579,12 +541,9 @@ def define_schema(cls) -> IO.Schema: IO.Custom(ELEVENLABS_VOICE).Output(display_name="voice"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge(expr="""{"type":"usd","usd":0.15}"""), ) @classmethod @@ -593,27 +552,21 @@ async def execute( files: IO.Autogrow.Type, remove_background_noise: bool, ) -> IO.NodeOutput: - file_tuples: list[tuple[str, tuple[str, bytes, str]]] = [] + audio_urls = [] for key in files: audio = files[key] - sample_rate: int = audio["sample_rate"] - waveform = audio["waveform"] - audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) - audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, "mp4", "aac") - file_tuples.append(("files", (f"{key}.mp4", audio_bytes_io.getvalue(), "audio/mp4"))) - - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/elevenlabs/v1/voices/add", method="POST"), - response_model=AddVoiceResponse, - data=AddVoiceRequest( - name=str(uuid.uuid4()), - remove_background_noise=remove_background_noise, - ), - files=file_tuples, - content_type="multipart/form-data", - ) - return IO.NodeOutput(response.voice_id) + audio_data_np = audio_tensor_to_contiguous_ndarray(audio["waveform"]) + audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, audio["sample_rate"], "mp4", "aac") + url = await upload_file_to_fal(audio_bytes_io, "audio/mp4") + audio_urls.append(url) + + result = await fal_run(cls, FAL_ELEVENLABS_TTS, { + "name": 
str(uuid.uuid4()), + "remove_background_noise": remove_background_noise, + "audio_urls": audio_urls, + }) + # TODO: verify fal.ai field names + return IO.NodeOutput(result["voice_id"]) ELEVENLABS_STS_VOICE_SETTINGS = [ @@ -716,14 +669,9 @@ def define_schema(cls) -> IO.Schema: IO.Audio.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.24,"format":{"approximate":true,"suffix":"/minute"}}""", - ), ) @classmethod @@ -737,8 +685,7 @@ async def execute( seed: int, remove_background_noise: bool, ) -> IO.NodeOutput: - audio_data_np = audio_tensor_to_contiguous_ndarray(audio["waveform"]) - audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, audio["sample_rate"], "mp4", "aac") + audio_url = await _upload_audio_to_fal(audio) voice_settings = TextToSpeechVoiceSettings( stability=stability, similarity_boost=model["similarity_boost"], @@ -746,24 +693,19 @@ async def execute( use_speaker_boost=model["use_speaker_boost"], speed=model["speed"], ) - response = await sync_op_raw( - cls, - ApiEndpoint( - path=f"/proxy/elevenlabs/v1/speech-to-speech/{voice}", - method="POST", - query_params={"output_format": output_format}, - ), - data=SpeechToSpeechRequest( - model_id=model["model"], - voice_settings=voice_settings.model_dump_json(exclude_none=True), - seed=seed, - remove_background_noise=remove_background_noise, - ), - files={"audio": ("audio.mp4", audio_bytes_io.getvalue(), "audio/mp4")}, - content_type="multipart/form-data", - as_binary=True, - ) - return IO.NodeOutput(audio_bytes_to_audio_input(response)) + result = await fal_run(cls, FAL_ELEVENLABS_TTS, { + "voice_id": voice, + "audio_url": audio_url, + "model_id": model["model"], + "voice_settings": voice_settings.model_dump(exclude_none=True), + "seed": seed, + "remove_background_noise": remove_background_noise, + "output_format": output_format, + }) + # TODO: verify fal.ai field names 
+ output_url = result["audio"]["url"] + audio_bytes_io = await download_url_as_bytesio(output_url) + return IO.NodeOutput(audio_bytes_to_audio_input(audio_bytes_io.read())) def _generate_dialogue_inputs(count: int) -> list: @@ -856,14 +798,9 @@ def define_schema(cls) -> IO.Schema: IO.Audio.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.24,"format":{"approximate":true,"suffix":"/1K chars"}}""", - ), ) @classmethod @@ -878,31 +815,25 @@ async def execute( output_format: str, ) -> IO.NodeOutput: num_entries = int(inputs["inputs"]) - dialogue_inputs: list[DialogueInput] = [] + dialogue_inputs = [] for i in range(1, num_entries + 1): text = inputs[f"text{i}"] voice_id = inputs[f"voice{i}"] validate_string(text, min_length=1) - dialogue_inputs.append(DialogueInput(text=text, voice_id=voice_id)) - request = TextToDialogueRequest( - inputs=dialogue_inputs, - model_id=model, - language_code=language_code if language_code.strip() else None, - settings=DialogueSettings(stability=stability), - seed=seed, - apply_text_normalization=apply_text_normalization, - ) - response = await sync_op_raw( - cls, - ApiEndpoint( - path="/proxy/elevenlabs/v1/text-to-dialogue", - method="POST", - query_params={"output_format": output_format}, - ), - data=request, - as_binary=True, - ) - return IO.NodeOutput(audio_bytes_to_audio_input(response)) + dialogue_inputs.append({"text": text, "voice_id": voice_id}) + result = await fal_run(cls, FAL_ELEVENLABS_TTS, { + "inputs": dialogue_inputs, + "model_id": model, + "language_code": language_code if language_code.strip() else None, + "settings": {"stability": stability}, + "seed": seed, + "apply_text_normalization": apply_text_normalization, + "output_format": output_format, + }) + # TODO: verify fal.ai field names + audio_url = result["audio"]["url"] + audio_bytes_io = await download_url_as_bytesio(audio_url) + 
return IO.NodeOutput(audio_bytes_to_audio_input(audio_bytes_io.read())) class ElevenLabsExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_fal.py b/comfy_api_nodes/nodes_fal.py new file mode 100644 index 000000000000..0c6cd3957555 --- /dev/null +++ b/comfy_api_nodes/nodes_fal.py @@ -0,0 +1,136 @@ +"""Generic fal.ai node -- submit any fal.ai model by ID with JSON input.""" + +import base64 +import json +from io import BytesIO + +import torch +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.util.client import fal_run +from comfy_api_nodes.util.conversions import bytesio_to_image_tensor +from comfy_api_nodes.util.upload_helpers import upload_image_to_fal + + +class FalGenericNode(IO.ComfyNode): + """Submit any fal.ai model by providing a model ID and JSON input. + + Use this node to access any of the 1200+ models on fal.ai with a single + FAL_API_KEY environment variable. + """ + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="FalGenericNode", + display_name="fal.ai Generic Model", + category="api node/fal.ai", + description=( + "Run any fal.ai model by providing its model ID and JSON input. " + "Browse models at https://fal.ai/models" + ), + inputs=[ + IO.String.Input( + "model_id", + default="fal-ai/flux/dev", + tooltip=( + "fal.ai model ID, e.g. 'fal-ai/flux/dev', 'fal-ai/kling-video/v2/master/text-to-video'. " + "Find model IDs at https://fal.ai/models" + ), + ), + IO.String.Input( + "input_json", + multiline=True, + default='{"prompt": "a cat in space"}', + tooltip="JSON object with the model's input parameters. Check the model's API page for the schema.", + ), + IO.Image.Input( + "image", + optional=True, + tooltip=( + "Optional image input. If provided, it will be uploaded to fal.ai CDN " + "and the URL will be added to the input JSON as 'image_url'." 
+ ), + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Random seed. Added to input JSON as 'seed' if > 0.", + optional=True, + ), + ], + outputs=[ + IO.String.Output("result_json"), + IO.Image.Output("images"), + ], + hidden=[ + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_id: str, + input_json: str, + image: Input.Image | None = None, + seed: int = 0, + ) -> IO.NodeOutput: + # Parse input JSON + try: + data = json.loads(input_json) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid input JSON: {e}") from e + + if not isinstance(data, dict): + raise ValueError("Input JSON must be a JSON object (dict), not an array or scalar.") + + # Upload image if provided and add URL to input + if image is not None: + image_url = await upload_image_to_fal(image) + data["image_url"] = image_url + + # Add seed if provided + if seed > 0: + data["seed"] = seed + + # Run the model via fal.ai queue + result = await fal_run(cls, model_id, data) + + # Extract images if present in result + output_images = None + if "images" in result and isinstance(result["images"], list): + image_tensors = [] + for img_data in result["images"]: + if isinstance(img_data, dict) and "url" in img_data: + # Download image from fal.ai CDN URL + from comfy_api_nodes.util import download_url_to_image_tensor + tensor = await download_url_to_image_tensor(img_data["url"]) + image_tensors.append(tensor) + if image_tensors: + output_images = torch.cat(image_tensors, dim=0) + elif "image" in result and isinstance(result["image"], dict) and "url" in result["image"]: + from comfy_api_nodes.util import download_url_to_image_tensor + output_images = await download_url_to_image_tensor(result["image"]["url"]) + + if output_images is None: + output_images = torch.zeros((1, 64, 64, 3)) + + # Return result JSON and images + result_str = json.dumps(result, indent=2, default=str) + return 
IO.NodeOutput(result_str, output_images)
+
+
+class FalExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+        return [FalGenericNode]
+
+
+async def comfy_entrypoint() -> FalExtension:
+    return FalExtension()
diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py
index 8225ea67e917..290f7d1df81d 100644
--- a/comfy_api_nodes/nodes_gemini.py
+++ b/comfy_api_nodes/nodes_gemini.py
@@ -40,12 +40,12 @@
     get_number_of_images,
     sync_op,
     tensor_to_base64_string,
-    upload_images_to_comfyapi,
     validate_string,
     video_to_base64_string,
 )
+from comfy_api_nodes.util._helpers import get_google_auth_header
 
-GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
+GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/models"
 GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024  # 20 MB
 GEMINI_IMAGE_SYS_PROMPT = (
     "You are an expert image-generation engine. You must ALWAYS produce an image.\n"
@@ -96,23 +96,9 @@ async def create_image_parts(
     # If image_limit == 0 --> use all images; otherwise clamp to image_limit.
     effective_max = total_images if image_limit == 0 else min(total_images, image_limit)
-    # Number of images we'll send as URLs (fileData)
-    num_url_images = min(effective_max, 10)  # Vertex API max number of image links
-    reference_images_urls = await upload_images_to_comfyapi(
-        cls,
-        images,
-        max_images=num_url_images,
-    )
-    for reference_image_url in reference_images_urls:
-        image_parts.append(
-            GeminiPart(
-                fileData=GeminiFileData(
-                    mimeType=GeminiMimeType.image_png,
-                    fileUri=reference_image_url,
-                )
-            )
-        )
-    for idx in range(num_url_images, effective_max):
+    # BYOK: send all images as inline base64 -- no separate upload step needed.
+    # NOTE(review): inline payload is capped by GEMINI_MAX_INPUT_FILE_SIZE (20 MB), not 100 MB.
+ for idx in range(effective_max): image_parts.append( GeminiPart( inlineData=GeminiInlineData( @@ -343,40 +329,9 @@ def define_schema(cls): IO.String.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model"]), - expr=""" - ( - $m := widgets.model; - $contains($m, "gemini-2.5-flash") ? { - "type": "list_usd", - "usd": [0.0003, 0.0025], - "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens"} - } - : $contains($m, "gemini-2.5-pro") ? { - "type": "list_usd", - "usd": [0.00125, 0.01], - "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } - } - : ($contains($m, "gemini-3-pro-preview") or $contains($m, "gemini-3-1-pro")) ? { - "type": "list_usd", - "usd": [0.002, 0.012], - "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } - } - : $contains($m, "gemini-3-1-flash-lite") ? 
{ - "type": "list_usd", - "usd": [0.00025, 0.0015], - "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } - } - : {"type":"text", "text":"Token-based"} - ) - """, - ), ) @classmethod @@ -464,7 +419,11 @@ async def execute( response = await sync_op( cls, - endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + endpoint=ApiEndpoint( + path=f"{GEMINI_BASE_URL}/{model}:generateContent", + method="POST", + headers=get_google_auth_header(), + ), data=GeminiGenerateContentRequest( contents=[ GeminiContent( @@ -637,14 +596,9 @@ def define_schema(cls): IO.String.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.039,"format":{"suffix":"/Image (1K)","approximate":true}}""", - ), ) @classmethod @@ -677,7 +631,7 @@ async def execute( response = await sync_op( cls, - ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"), + ApiEndpoint(path=f"{GEMINI_BASE_URL}/{model}:generateContent", method="POST", headers=get_google_auth_header()), data=GeminiImageGenerateContentRequest( contents=[ GeminiContent(role=GeminiRole.user, parts=parts), @@ -772,12 +726,9 @@ def define_schema(cls): IO.String.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=GEMINI_IMAGE_2_PRICE_BADGE, ) @classmethod @@ -815,7 +766,7 @@ async def execute( response = await sync_op( cls, - ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"), + ApiEndpoint(path=f"{GEMINI_BASE_URL}/{model}:generateContent", method="POST", headers=get_google_auth_header()), data=GeminiImageGenerateContentRequest( contents=[ GeminiContent(role=GeminiRole.user, parts=parts), @@ -933,12 +884,9 @@ def define_schema(cls): IO.String.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, 
], is_api_node=True, - price_badge=GEMINI_IMAGE_2_PRICE_BADGE, ) @classmethod @@ -977,7 +925,7 @@ async def execute( response = await sync_op( cls, - ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"), + ApiEndpoint(path=f"{GEMINI_BASE_URL}/{model}:generateContent", method="POST", headers=get_google_auth_header()), data=GeminiImageGenerateContentRequest( contents=[ GeminiContent(role=GeminiRole.user, parts=parts), diff --git a/comfy_api_nodes/nodes_grok.py b/comfy_api_nodes/nodes_grok.py deleted file mode 100644 index 0716d623964b..000000000000 --- a/comfy_api_nodes/nodes_grok.py +++ /dev/null @@ -1,477 +0,0 @@ -import torch -from typing_extensions import override - -from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis.grok import ( - ImageEditRequest, - ImageGenerationRequest, - ImageGenerationResponse, - InputUrlObject, - VideoEditRequest, - VideoGenerationRequest, - VideoGenerationResponse, - VideoStatusResponse, -) -from comfy_api_nodes.util import ( - ApiEndpoint, - download_url_to_image_tensor, - download_url_to_video_output, - get_fs_object_size, - get_number_of_images, - poll_op, - sync_op, - tensor_to_base64_string, - upload_video_to_comfyapi, - validate_string, - validate_video_duration, -) - - -def _extract_grok_price(response) -> float | None: - if response.usage and response.usage.cost_in_usd_ticks is not None: - return response.usage.cost_in_usd_ticks / 10_000_000_000 - return None - - -class GrokImageNode(IO.ComfyNode): - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="GrokImageNode", - display_name="Grok Image", - category="api node/image/Grok", - description="Generate images using Grok based on a text prompt", - inputs=[ - IO.Combo.Input( - "model", - options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"], - ), - IO.String.Input( - "prompt", - multiline=True, - tooltip="The text prompt used to generate the image", - ), - IO.Combo.Input( - 
"aspect_ratio", - options=[ - "1:1", - "2:3", - "3:2", - "3:4", - "4:3", - "9:16", - "16:9", - "9:19.5", - "19.5:9", - "9:20", - "20:9", - "1:2", - "2:1", - ], - ), - IO.Int.Input( - "number_of_images", - default=1, - min=1, - max=10, - step=1, - tooltip="Number of images to generate", - display_mode=IO.NumberDisplay.number, - ), - IO.Int.Input( - "seed", - default=0, - min=0, - max=2147483647, - step=1, - display_mode=IO.NumberDisplay.number, - control_after_generate=True, - tooltip="Seed to determine if node should re-run; " - "actual results are nondeterministic regardless of seed.", - ), - IO.Combo.Input("resolution", options=["1K", "2K"], optional=True), - ], - outputs=[ - IO.Image.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]), - expr=""" - ( - $rate := $contains(widgets.model, "pro") ? 0.07 : 0.02; - {"type":"usd","usd": $rate * widgets.number_of_images} - ) - """, - ), - ) - - @classmethod - async def execute( - cls, - model: str, - prompt: str, - aspect_ratio: str, - number_of_images: int, - seed: int, - resolution: str = "1K", - ) -> IO.NodeOutput: - validate_string(prompt, strip_whitespace=True, min_length=1) - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/xai/v1/images/generations", method="POST"), - data=ImageGenerationRequest( - model=model, - prompt=prompt, - aspect_ratio=aspect_ratio, - n=number_of_images, - seed=seed, - resolution=resolution.lower(), - ), - response_model=ImageGenerationResponse, - price_extractor=_extract_grok_price, - ) - if len(response.data) == 1: - return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) - return IO.NodeOutput( - torch.cat( - [await download_url_to_image_tensor(i) for i in [str(d.url) for d in response.data if d.url]], - ) - ) - - -class GrokImageEditNode(IO.ComfyNode): - - 
@classmethod - def define_schema(cls): - return IO.Schema( - node_id="GrokImageEditNode", - display_name="Grok Image Edit", - category="api node/image/Grok", - description="Modify an existing image based on a text prompt", - inputs=[ - IO.Combo.Input( - "model", - options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"], - ), - IO.Image.Input("image", display_name="images"), - IO.String.Input( - "prompt", - multiline=True, - tooltip="The text prompt used to generate the image", - ), - IO.Combo.Input("resolution", options=["1K", "2K"]), - IO.Int.Input( - "number_of_images", - default=1, - min=1, - max=10, - step=1, - tooltip="Number of edited images to generate", - display_mode=IO.NumberDisplay.number, - ), - IO.Int.Input( - "seed", - default=0, - min=0, - max=2147483647, - step=1, - display_mode=IO.NumberDisplay.number, - control_after_generate=True, - tooltip="Seed to determine if node should re-run; " - "actual results are nondeterministic regardless of seed.", - ), - IO.Combo.Input( - "aspect_ratio", - options=[ - "auto", - "1:1", - "2:3", - "3:2", - "3:4", - "4:3", - "9:16", - "16:9", - "9:19.5", - "19.5:9", - "9:20", - "20:9", - "1:2", - "2:1", - ], - optional=True, - tooltip="Only allowed when multiple images are connected to the image input.", - ), - ], - outputs=[ - IO.Image.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]), - expr=""" - ( - $rate := $contains(widgets.model, "pro") ? 
0.07 : 0.02; - {"type":"usd","usd": 0.002 + $rate * widgets.number_of_images} - ) - """, - ), - ) - - @classmethod - async def execute( - cls, - model: str, - image: Input.Image, - prompt: str, - resolution: str, - number_of_images: int, - seed: int, - aspect_ratio: str = "auto", - ) -> IO.NodeOutput: - validate_string(prompt, strip_whitespace=True, min_length=1) - if model == "grok-imagine-image-pro": - if get_number_of_images(image) > 1: - raise ValueError("The pro model supports only 1 input image.") - elif get_number_of_images(image) > 3: - raise ValueError("A maximum of 3 input images is supported.") - if aspect_ratio != "auto" and get_number_of_images(image) == 1: - raise ValueError( - "Custom aspect ratio is only allowed when multiple images are connected to the image input." - ) - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"), - data=ImageEditRequest( - model=model, - images=[InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(i)}") for i in image], - prompt=prompt, - resolution=resolution.lower(), - n=number_of_images, - seed=seed, - aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio, - ), - response_model=ImageGenerationResponse, - price_extractor=_extract_grok_price, - ) - if len(response.data) == 1: - return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) - return IO.NodeOutput( - torch.cat( - [await download_url_to_image_tensor(i) for i in [str(d.url) for d in response.data if d.url]], - ) - ) - - -class GrokVideoNode(IO.ComfyNode): - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="GrokVideoNode", - display_name="Grok Video", - category="api node/video/Grok", - description="Generate video from a prompt or an image", - inputs=[ - IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]), - IO.String.Input( - "prompt", - multiline=True, - tooltip="Text description of the desired video.", - ), - 
IO.Combo.Input( - "resolution", - options=["480p", "720p"], - tooltip="The resolution of the output video.", - ), - IO.Combo.Input( - "aspect_ratio", - options=["auto", "16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"], - tooltip="The aspect ratio of the output video.", - ), - IO.Int.Input( - "duration", - default=6, - min=1, - max=15, - step=1, - tooltip="The duration of the output video in seconds.", - display_mode=IO.NumberDisplay.slider, - ), - IO.Int.Input( - "seed", - default=0, - min=0, - max=2147483647, - step=1, - display_mode=IO.NumberDisplay.number, - control_after_generate=True, - tooltip="Seed to determine if node should re-run; " - "actual results are nondeterministic regardless of seed.", - ), - IO.Image.Input("image", optional=True), - ], - outputs=[ - IO.Video.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]), - expr=""" - ( - $rate := widgets.resolution = "720p" ? 0.07 : 0.05; - $base := $rate * widgets.duration; - {"type":"usd","usd": inputs.image.connected ? 
$base + 0.002 : $base} - ) - """, - ), - ) - - @classmethod - async def execute( - cls, - model: str, - prompt: str, - resolution: str, - aspect_ratio: str, - duration: int, - seed: int, - image: Input.Image | None = None, - ) -> IO.NodeOutput: - image_url = None - if image is not None: - if get_number_of_images(image) != 1: - raise ValueError("Only one input image is supported.") - image_url = InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}") - validate_string(prompt, strip_whitespace=True, min_length=1) - initial_response = await sync_op( - cls, - ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"), - data=VideoGenerationRequest( - model=model, - image=image_url, - prompt=prompt, - resolution=resolution, - duration=duration, - aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio, - seed=seed, - ), - response_model=VideoGenerationResponse, - ) - response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), - status_extractor=lambda r: r.status if r.status is not None else "complete", - response_model=VideoStatusResponse, - price_extractor=_extract_grok_price, - ) - return IO.NodeOutput(await download_url_to_video_output(response.video.url)) - - -class GrokVideoEditNode(IO.ComfyNode): - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="GrokVideoEditNode", - display_name="Grok Video Edit", - category="api node/video/Grok", - description="Edit an existing video based on a text prompt.", - inputs=[ - IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]), - IO.String.Input( - "prompt", - multiline=True, - tooltip="Text description of the desired video.", - ), - IO.Video.Input("video", tooltip="Maximum supported duration is 8.7 seconds and 50MB file size."), - IO.Int.Input( - "seed", - default=0, - min=0, - max=2147483647, - step=1, - display_mode=IO.NumberDisplay.number, - control_after_generate=True, - tooltip="Seed to 
determine if node should re-run; " - "actual results are nondeterministic regardless of seed.", - ), - ], - outputs=[ - IO.Video.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd": 0.06, "format": {"suffix": "/sec", "approximate": true}}""", - ), - ) - - @classmethod - async def execute( - cls, - model: str, - prompt: str, - video: Input.Video, - seed: int, - ) -> IO.NodeOutput: - validate_string(prompt, strip_whitespace=True, min_length=1) - validate_video_duration(video, min_duration=1, max_duration=8.7) - video_stream = video.get_stream_source() - video_size = get_fs_object_size(video_stream) - if video_size > 50 * 1024 * 1024: - raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.") - initial_response = await sync_op( - cls, - ApiEndpoint(path="/proxy/xai/v1/videos/edits", method="POST"), - data=VideoEditRequest( - model=model, - video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)), - prompt=prompt, - seed=seed, - ), - response_model=VideoGenerationResponse, - ) - response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), - status_extractor=lambda r: r.status if r.status is not None else "complete", - response_model=VideoStatusResponse, - price_extractor=_extract_grok_price, - ) - return IO.NodeOutput(await download_url_to_video_output(response.video.url)) - - -class GrokExtension(ComfyExtension): - @override - async def get_node_list(self) -> list[type[IO.ComfyNode]]: - return [ - GrokImageNode, - GrokImageEditNode, - GrokVideoNode, - GrokVideoEditNode, - ] - - -async def comfy_entrypoint() -> GrokExtension: - return GrokExtension() diff --git a/comfy_api_nodes/nodes_hitpaw.py b/comfy_api_nodes/nodes_hitpaw.py deleted file mode 100644 index 488080a74a40..000000000000 --- a/comfy_api_nodes/nodes_hitpaw.py +++ /dev/null @@ 
-1,342 +0,0 @@ -import math - -from typing_extensions import override - -from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis.hitpaw import ( - ImageEnhanceTaskCreateRequest, - InputVideoModel, - TaskCreateDataResponse, - TaskCreateResponse, - TaskStatusPollRequest, - TaskStatusResponse, - VideoEnhanceTaskCreateRequest, -) -from comfy_api_nodes.util import ( - ApiEndpoint, - download_url_to_image_tensor, - download_url_to_video_output, - downscale_image_tensor, - get_image_dimensions, - poll_op, - sync_op, - upload_image_to_comfyapi, - upload_video_to_comfyapi, - validate_video_duration, -) - -VIDEO_MODELS_MODELS_MAP = { - "Portrait Restore Model (1x)": "portrait_restore_1x", - "Portrait Restore Model (2x)": "portrait_restore_2x", - "General Restore Model (1x)": "general_restore_1x", - "General Restore Model (2x)": "general_restore_2x", - "General Restore Model (4x)": "general_restore_4x", - "Ultra HD Model (2x)": "ultrahd_restore_2x", - "Generative Model (1x)": "generative_1x", -} - -# Resolution name to target dimension (shorter side) in pixels -RESOLUTION_TARGET_MAP = { - "720p": 720, - "1080p": 1080, - "2K/QHD": 1440, - "4K/UHD": 2160, - "8K": 4320, -} - -# Square (1:1) resolutions use standard square dimensions -RESOLUTION_SQUARE_MAP = { - "720p": 720, - "1080p": 1080, - "2K/QHD": 1440, - "4K/UHD": 2048, # DCI 4K square - "8K": 4096, # DCI 8K square -} - -# Models with limited resolution support (no 8K) -LIMITED_RESOLUTION_MODELS = {"Generative Model (1x)"} - -# Resolution options for different model types -RESOLUTIONS_LIMITED = ["original", "720p", "1080p", "2K/QHD", "4K/UHD"] -RESOLUTIONS_FULL = ["original", "720p", "1080p", "2K/QHD", "4K/UHD", "8K"] - -# Maximum output resolution in pixels -MAX_PIXELS_GENERATIVE = 32_000_000 -MAX_MP_GENERATIVE = MAX_PIXELS_GENERATIVE // 1_000_000 - - -class HitPawGeneralImageEnhance(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="HitPawGeneralImageEnhance", 
- display_name="HitPaw General Image Enhance", - category="api node/image/HitPaw", - description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. " - f"Maximum output: {MAX_MP_GENERATIVE} megapixels.", - inputs=[ - IO.Combo.Input("model", options=["generative_portrait", "generative"]), - IO.Image.Input("image"), - IO.Combo.Input("upscale_factor", options=[1, 2, 4]), - IO.Boolean.Input( - "auto_downscale", - default=False, - tooltip="Automatically downscale input image if output would exceed the limit.", - ), - ], - outputs=[ - IO.Image.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model"]), - expr=""" - ( - $prices := { - "generative_portrait": {"min": 0.02, "max": 0.06}, - "generative": {"min": 0.05, "max": 0.15} - }; - $price := $lookup($prices, widgets.model); - { - "type": "range_usd", - "min_usd": $price.min, - "max_usd": $price.max - } - ) - """, - ), - ) - - @classmethod - async def execute( - cls, - model: str, - image: Input.Image, - upscale_factor: int, - auto_downscale: bool, - ) -> IO.NodeOutput: - height, width = get_image_dimensions(image) - requested_scale = upscale_factor - output_pixels = height * width * requested_scale * requested_scale - if output_pixels > MAX_PIXELS_GENERATIVE: - if auto_downscale: - input_pixels = width * height - scale = 1 - max_input_pixels = MAX_PIXELS_GENERATIVE - - for candidate in [4, 2, 1]: - if candidate > requested_scale: - continue - scale_output_pixels = input_pixels * candidate * candidate - if scale_output_pixels <= MAX_PIXELS_GENERATIVE: - scale = candidate - max_input_pixels = None - break - # Check if we can downscale input by at most 2x to fit - downscale_ratio = math.sqrt(scale_output_pixels / MAX_PIXELS_GENERATIVE) - if downscale_ratio <= 2.0: - scale = candidate - max_input_pixels = MAX_PIXELS_GENERATIVE // 
(candidate * candidate) - break - - if max_input_pixels is not None: - image = downscale_image_tensor(image, total_pixels=max_input_pixels) - upscale_factor = scale - else: - output_width = width * requested_scale - output_height = height * requested_scale - raise ValueError( - f"Output size ({output_width}x{output_height} = {output_pixels:,} pixels) " - f"exceeds maximum allowed size of {MAX_PIXELS_GENERATIVE:,} pixels ({MAX_MP_GENERATIVE}MP). " - f"Enable auto_downscale or use a smaller input image or a lower upscale factor." - ) - - initial_res = await sync_op( - cls, - ApiEndpoint(path="/proxy/hitpaw/api/photo-enhancer", method="POST"), - response_model=TaskCreateResponse, - data=ImageEnhanceTaskCreateRequest( - model_name=f"{model}_{upscale_factor}x", - img_url=await upload_image_to_comfyapi(cls, image, total_pixels=None), - ), - wait_label="Creating task", - final_label_on_success="Task created", - ) - if initial_res.code != 200: - raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}") - request_price = initial_res.data.consume_coins / 1000 - final_response = await poll_op( - cls, - ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"), - data=TaskCreateDataResponse(job_id=initial_res.data.job_id), - response_model=TaskStatusResponse, - status_extractor=lambda x: x.data.status, - price_extractor=lambda x: request_price, - poll_interval=10.0, - max_poll_attempts=480, - ) - return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.res_url)) - - -class HitPawVideoEnhance(IO.ComfyNode): - @classmethod - def define_schema(cls): - model_options = [] - for model_name in VIDEO_MODELS_MODELS_MAP: - if model_name in LIMITED_RESOLUTION_MODELS: - resolutions = RESOLUTIONS_LIMITED - else: - resolutions = RESOLUTIONS_FULL - model_options.append( - IO.DynamicCombo.Option( - model_name, - [IO.Combo.Input("resolution", options=resolutions)], - ) - ) - - return IO.Schema( - node_id="HitPawVideoEnhance", - 
display_name="HitPaw Video Enhance", - category="api node/video/HitPaw", - description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. " - "Prices shown are per second of video.", - inputs=[ - IO.DynamicCombo.Input("model", options=model_options), - IO.Video.Input("video"), - ], - outputs=[ - IO.Video.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]), - expr=""" - ( - $m := $lookup(widgets, "model"); - $res := $lookup(widgets, "model.resolution"); - $standard_model_prices := { - "original": {"min": 0.01, "max": 0.198}, - "720p": {"min": 0.01, "max": 0.06}, - "1080p": {"min": 0.015, "max": 0.09}, - "2k/qhd": {"min": 0.02, "max": 0.117}, - "4k/uhd": {"min": 0.025, "max": 0.152}, - "8k": {"min": 0.033, "max": 0.198} - }; - $ultra_hd_model_prices := { - "original": {"min": 0.015, "max": 0.264}, - "720p": {"min": 0.015, "max": 0.092}, - "1080p": {"min": 0.02, "max": 0.12}, - "2k/qhd": {"min": 0.026, "max": 0.156}, - "4k/uhd": {"min": 0.034, "max": 0.203}, - "8k": {"min": 0.044, "max": 0.264} - }; - $generative_model_prices := { - "original": {"min": 0.015, "max": 0.338}, - "720p": {"min": 0.008, "max": 0.090}, - "1080p": {"min": 0.05, "max": 0.15}, - "2k/qhd": {"min": 0.038, "max": 0.225}, - "4k/uhd": {"min": 0.056, "max": 0.338} - }; - $prices := $contains($m, "ultra hd") ? $ultra_hd_model_prices : - $contains($m, "generative") ? 
$generative_model_prices : - $standard_model_prices; - $price := $lookup($prices, $res); - { - "type": "range_usd", - "min_usd": $price.min, - "max_usd": $price.max, - "format": {"approximate": true, "suffix": "/second"} - } - ) - """, - ), - ) - - @classmethod - async def execute( - cls, - model: InputVideoModel, - video: Input.Video, - ) -> IO.NodeOutput: - validate_video_duration(video, min_duration=0.5, max_duration=60 * 60) - resolution = model["resolution"] - src_width, src_height = video.get_dimensions() - - if resolution == "original": - output_width = src_width - output_height = src_height - else: - if src_width == src_height: - target_size = RESOLUTION_SQUARE_MAP[resolution] - if target_size < src_width: - raise ValueError( - f"Selected resolution {resolution} ({target_size}x{target_size}) is smaller than " - f"the input video ({src_width}x{src_height}). Please select a higher resolution or 'original'." - ) - output_width = target_size - output_height = target_size - else: - min_dimension = min(src_width, src_height) - target_size = RESOLUTION_TARGET_MAP[resolution] - if target_size < min_dimension: - raise ValueError( - f"Selected resolution {resolution} ({target_size}p) is smaller than " - f"the input video's shorter dimension ({min_dimension}p). " - f"Please select a higher resolution or 'original'." 
- ) - if src_width > src_height: - output_height = target_size - output_width = int(target_size * (src_width / src_height)) - else: - output_width = target_size - output_height = int(target_size * (src_height / src_width)) - initial_res = await sync_op( - cls, - ApiEndpoint(path="/proxy/hitpaw/api/video-enhancer", method="POST"), - response_model=TaskCreateResponse, - data=VideoEnhanceTaskCreateRequest( - video_url=await upload_video_to_comfyapi(cls, video), - resolution=[output_width, output_height], - original_resolution=[src_width, src_height], - model_name=VIDEO_MODELS_MODELS_MAP[model["model"]], - ), - wait_label="Creating task", - final_label_on_success="Task created", - ) - request_price = initial_res.data.consume_coins / 1000 - if initial_res.code != 200: - raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}") - final_response = await poll_op( - cls, - ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"), - data=TaskStatusPollRequest(job_id=initial_res.data.job_id), - response_model=TaskStatusResponse, - status_extractor=lambda x: x.data.status, - price_extractor=lambda x: request_price, - poll_interval=10.0, - max_poll_attempts=320, - ) - return IO.NodeOutput(await download_url_to_video_output(final_response.data.res_url)) - - -class HitPawExtension(ComfyExtension): - @override - async def get_node_list(self) -> list[type[IO.ComfyNode]]: - return [ - HitPawGeneralImageEnhance, - HitPawVideoEnhance, - ] - - -async def comfy_entrypoint() -> HitPawExtension: - return HitPawExtension() diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py index bd8bde9973c8..bee7f937bb56 100644 --- a/comfy_api_nodes/nodes_hunyuan3d.py +++ b/comfy_api_nodes/nodes_hunyuan3d.py @@ -2,51 +2,45 @@ from comfy_api.latest import IO, ComfyExtension, Input, Types from comfy_api_nodes.apis.hunyuan3d import ( - Hunyuan3DViewImage, InputGenerateType, ResultFile3D, - SmartTopologyRequest, - TaskFile3DInput, - 
TextureEditTaskRequest, - To3DPartTaskRequest, - To3DProTaskCreateResponse, - To3DProTaskQueryRequest, - To3DProTaskRequest, - To3DProTaskResultResponse, - To3DUVTaskRequest, ) from comfy_api_nodes.util import ( - ApiEndpoint, download_url_to_file_3d, downscale_image_tensor_by_max_side, - poll_op, - sync_op, - upload_3d_model_to_comfyapi, - upload_image_to_comfyapi, + upload_3d_model_to_fal, validate_image_dimensions, validate_string, ) +from comfy_api_nodes.util.client import fal_run +from comfy_api_nodes.util.upload_helpers import upload_image_to_fal - -def _is_tencent_rate_limited(status: int, body: object) -> bool: - return ( - status == 400 - and isinstance(body, dict) - and "RequestLimitExceeded" in str(body.get("Response", {}).get("Error", {}).get("Code", "")) - ) +FAL_HUNYUAN3D_V2 = "fal-ai/hunyuan3d/v2" def get_file_from_response( - response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True -) -> ResultFile3D | None: + response_objs: list, file_type: str, raise_if_not_found: bool = True +): + """Extract file of given type from response list (works with dicts or ResultFile3D objects).""" for i in response_objs: - if i.Type.lower() == file_type.lower(): - return i + if isinstance(i, dict): + if i.get("Type", "").lower() == file_type.lower(): + return i + else: + if i.Type.lower() == file_type.lower(): + return i if raise_if_not_found: raise ValueError(f"'{file_type}' file type is not found in the response.") return None +def _get_url(file_obj) -> str: + """Get URL from a file object (dict or ResultFile3D).""" + if isinstance(file_obj, dict): + return file_obj["Url"] + return file_obj.Url + + class TencentTextToModelNode(IO.ComfyNode): @classmethod @@ -95,23 +89,10 @@ def define_schema(cls): IO.File3DOBJ.Output(display_name="OBJ"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, is_output_node=True, - price_badge=IO.PriceBadge( - 
depends_on=IO.PriceBadgeDepends(widgets=["generate_type", "generate_type.pbr", "face_count"]), - expr=""" - ( - $base := widgets.generate_type = "normal" ? 25 : widgets.generate_type = "lowpoly" ? 30 : 15; - $pbr := $lookup(widgets, "generate_type.pbr") ? 10 : 0; - $face := widgets.face_count != 500000 ? 10 : 0; - {"type":"usd","usd": ($base + $pbr + $face) * 0.02} - ) - """, - ), ) @classmethod @@ -127,37 +108,26 @@ async def execute( validate_string(prompt, field_name="prompt", min_length=1, max_length=1024) if model == "3.1" and generate_type["generate_type"].lower() == "lowpoly": raise ValueError("The LowPoly option is currently unavailable for the 3.1 model.") - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro", method="POST"), - response_model=To3DProTaskCreateResponse, - data=To3DProTaskRequest( - Model=model, - Prompt=prompt, - FaceCount=face_count, - GenerateType=generate_type["generate_type"], - EnablePBR=generate_type.get("pbr", None), - PolygonType=generate_type.get("polygon_type", None), - ), - is_rate_limited=_is_tencent_rate_limited, - ) - if response.Error: - raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") - task_id = response.JobId - result = await poll_op( - cls, - ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"), - data=To3DProTaskQueryRequest(JobId=task_id), - response_model=To3DProTaskResultResponse, - status_extractor=lambda r: r.Status, - ) + result = await fal_run(cls, FAL_HUNYUAN3D_V2, { + "Model": model, + "Prompt": prompt, + "FaceCount": face_count, + "GenerateType": generate_type["generate_type"], + "EnablePBR": generate_type.get("pbr", None), + "PolygonType": generate_type.get("polygon_type", None), + }) + # TODO: verify fal.ai field names + if result.get("Error"): + raise ValueError(f"Task creation failed with code {result['Error']['Code']}: {result['Error']['Message']}") + task_id = result.get("JobId", "hunyuan_task") + file_3ds 
= result["ResultFile3Ds"] return IO.NodeOutput( f"{task_id}.glb", await download_url_to_file_3d( - get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id + get_file_from_response(file_3ds, "glb")["Url"], "glb", task_id=task_id ), await download_url_to_file_3d( - get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id + get_file_from_response(file_3ds, "obj")["Url"], "obj", task_id=task_id ), ) @@ -213,29 +183,10 @@ def define_schema(cls): IO.File3DOBJ.Output(display_name="OBJ"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, is_output_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends( - widgets=["generate_type", "generate_type.pbr", "face_count"], - inputs=["image_left", "image_right", "image_back"], - ), - expr=""" - ( - $base := widgets.generate_type = "normal" ? 25 : widgets.generate_type = "lowpoly" ? 30 : 15; - $multiview := ( - inputs.image_left.connected or inputs.image_right.connected or inputs.image_back.connected - ) ? 10 : 0; - $pbr := $lookup(widgets, "generate_type.pbr") ? 10 : 0; - $face := widgets.face_count != 500000 ? 
10 : 0; - {"type":"usd","usd": ($base + $multiview + $pbr + $face) * 0.02} - ) - """, - ), ) @classmethod @@ -263,54 +214,37 @@ async def execute( if v is None: continue validate_image_dimensions(v, min_width=128, min_height=128) - multiview_images.append( - Hunyuan3DViewImage( - ViewType=k, - ViewImageUrl=await upload_image_to_comfyapi( - cls, - downscale_image_tensor_by_max_side(v, max_side=4900), - mime_type="image/webp", - total_pixels=24_010_000, - ), - ) + view_image_url = await upload_image_to_fal( + downscale_image_tensor_by_max_side(v, max_side=4900), ) - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro", method="POST"), - response_model=To3DProTaskCreateResponse, - data=To3DProTaskRequest( - Model=model, - FaceCount=face_count, - GenerateType=generate_type["generate_type"], - ImageUrl=await upload_image_to_comfyapi( - cls, - downscale_image_tensor_by_max_side(image, max_side=4900), - mime_type="image/webp", - total_pixels=24_010_000, - ), - MultiViewImages=multiview_images if multiview_images else None, - EnablePBR=generate_type.get("pbr", None), - PolygonType=generate_type.get("polygon_type", None), - ), - is_rate_limited=_is_tencent_rate_limited, - ) - if response.Error: - raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") - task_id = response.JobId - result = await poll_op( - cls, - ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"), - data=To3DProTaskQueryRequest(JobId=task_id), - response_model=To3DProTaskResultResponse, - status_extractor=lambda r: r.Status, + multiview_images.append({ + "ViewType": k, + "ViewImageUrl": view_image_url, + }) + image_url = await upload_image_to_fal( + downscale_image_tensor_by_max_side(image, max_side=4900), ) + result = await fal_run(cls, FAL_HUNYUAN3D_V2, { + "Model": model, + "FaceCount": face_count, + "GenerateType": generate_type["generate_type"], + "ImageUrl": image_url, + "MultiViewImages": multiview_images if 
multiview_images else None, + "EnablePBR": generate_type.get("pbr", None), + "PolygonType": generate_type.get("polygon_type", None), + }) + # TODO: verify fal.ai field names + if result.get("Error"): + raise ValueError(f"Task creation failed with code {result['Error']['Code']}: {result['Error']['Message']}") + task_id = result.get("JobId", "hunyuan_task") + file_3ds = result["ResultFile3Ds"] return IO.NodeOutput( f"{task_id}.glb", await download_url_to_file_3d( - get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id + get_file_from_response(file_3ds, "glb")["Url"], "glb", task_id=task_id ), await download_url_to_file_3d( - get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id + get_file_from_response(file_3ds, "obj")["Url"], "obj", task_id=task_id ), ) @@ -347,12 +281,9 @@ def define_schema(cls): IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.2}'), ) SUPPORTED_FORMATS = {"glb", "obj", "fbx"} @@ -370,30 +301,20 @@ async def execute( f"Unsupported file format: '{file_format}'. " f"Supported formats: {', '.join(sorted(cls.SUPPORTED_FORMATS))}." 
) - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"), - response_model=To3DProTaskCreateResponse, - data=To3DUVTaskRequest( - File=TaskFile3DInput( - Type=file_format.upper(), - Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format), - ) - ), - is_rate_limited=_is_tencent_rate_limited, - ) - if response.Error: - raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") - result = await poll_op( - cls, - ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv/query", method="POST"), - data=To3DProTaskQueryRequest(JobId=response.JobId), - response_model=To3DProTaskResultResponse, - status_extractor=lambda r: r.Status, - ) + model_url = await upload_3d_model_to_fal(model_3d, file_format) + result = await fal_run(cls, FAL_HUNYUAN3D_V2, { + "File": { + "Type": file_format.upper(), + "Url": model_url, + }, + }) + # TODO: verify fal.ai field names + if result.get("Error"): + raise ValueError(f"Task creation failed with code {result['Error']['Code']}: {result['Error']['Message']}") + file_3ds = result["ResultFile3Ds"] return IO.NodeOutput( - await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"), - await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"), + await download_url_to_file_3d(get_file_from_response(file_3ds, "obj")["Url"], "obj"), + await download_url_to_file_3d(get_file_from_response(file_3ds, "fbx")["Url"], "fbx"), ) @@ -434,14 +355,9 @@ def define_schema(cls): IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd": 0.6}""", - ), ) @classmethod @@ -456,31 +372,22 @@ async def execute( if file_format != "fbx": raise ValueError(f"Unsupported file format: '{file_format}'. 
Only FBX format is supported.") validate_string(prompt, field_name="prompt", min_length=1, max_length=1024) - model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format) - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"), - response_model=To3DProTaskCreateResponse, - data=TextureEditTaskRequest( - File3D=TaskFile3DInput(Type=file_format.upper(), Url=model_url), - Prompt=prompt, - EnablePBR=True, - ), - is_rate_limited=_is_tencent_rate_limited, - ) - if response.Error: - raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") - - result = await poll_op( - cls, - ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit/query", method="POST"), - data=To3DProTaskQueryRequest(JobId=response.JobId), - response_model=To3DProTaskResultResponse, - status_extractor=lambda r: r.Status, - ) + model_url = await upload_3d_model_to_fal(model_3d, file_format) + result = await fal_run(cls, FAL_HUNYUAN3D_V2, { + "File3D": { + "Type": file_format.upper(), + "Url": model_url, + }, + "Prompt": prompt, + "EnablePBR": True, + }) + # TODO: verify fal.ai field names + if result.get("Error"): + raise ValueError(f"Task creation failed with code {result['Error']['Code']}: {result['Error']['Message']}") + file_3ds = result["ResultFile3Ds"] return IO.NodeOutput( - await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"), - await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"), + await download_url_to_file_3d(get_file_from_response(file_3ds, "glb")["Url"], "glb"), + await download_url_to_file_3d(get_file_from_response(file_3ds, "fbx")["Url"], "fbx"), ) @@ -514,12 +421,9 @@ def define_schema(cls): IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - 
price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.6}'), ) @classmethod @@ -532,27 +436,19 @@ async def execute( file_format = model_3d.format.lower() if file_format != "fbx": raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.") - model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format) - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"), - response_model=To3DProTaskCreateResponse, - data=To3DPartTaskRequest( - File=TaskFile3DInput(Type=file_format.upper(), Url=model_url), - ), - is_rate_limited=_is_tencent_rate_limited, - ) - if response.Error: - raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") - result = await poll_op( - cls, - ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part/query", method="POST"), - data=To3DProTaskQueryRequest(JobId=response.JobId), - response_model=To3DProTaskResultResponse, - status_extractor=lambda r: r.Status, - ) + model_url = await upload_3d_model_to_fal(model_3d, file_format) + result = await fal_run(cls, FAL_HUNYUAN3D_V2, { + "File": { + "Type": file_format.upper(), + "Url": model_url, + }, + }) + # TODO: verify fal.ai field names + if result.get("Error"): + raise ValueError(f"Task creation failed with code {result['Error']['Code']}: {result['Error']['Message']}") + file_3ds = result["ResultFile3Ds"] return IO.NodeOutput( - await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"), + await download_url_to_file_3d(get_file_from_response(file_3ds, "fbx")["Url"], "fbx"), ) @@ -597,12 +493,9 @@ def define_schema(cls): IO.File3DOBJ.Output(display_name="OBJ"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge(expr='{"type":"usd","usd":1.0}'), ) SUPPORTED_FORMATS = {"glb", "obj"} @@ -621,29 +514,21 @@ async def execute( raise ValueError( f"Unsupported 
file format: '{file_format}'. " f"Supported: {', '.join(sorted(cls.SUPPORTED_FORMATS))}." ) - model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format) - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/tencent/hunyuan/3d-smart-topology", method="POST"), - response_model=To3DProTaskCreateResponse, - data=SmartTopologyRequest( - File3D=TaskFile3DInput(Type=file_format.upper(), Url=model_url), - PolygonType=polygon_type, - FaceLevel=face_level, - ), - is_rate_limited=_is_tencent_rate_limited, - ) - if response.Error: - raise ValueError(f"Task creation failed: [{response.Error.Code}] {response.Error.Message}") - result = await poll_op( - cls, - ApiEndpoint(path="/proxy/tencent/hunyuan/3d-smart-topology/query", method="POST"), - data=To3DProTaskQueryRequest(JobId=response.JobId), - response_model=To3DProTaskResultResponse, - status_extractor=lambda r: r.Status, - ) + model_url = await upload_3d_model_to_fal(model_3d, file_format) + result = await fal_run(cls, FAL_HUNYUAN3D_V2, { + "File3D": { + "Type": file_format.upper(), + "Url": model_url, + }, + "PolygonType": polygon_type, + "FaceLevel": face_level, + }) + # TODO: verify fal.ai field names + if result.get("Error"): + raise ValueError(f"Task creation failed: [{result['Error']['Code']}] {result['Error']['Message']}") + file_3ds = result["ResultFile3Ds"] return IO.NodeOutput( - await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"), + await download_url_to_file_3d(get_file_from_response(file_3ds, "obj")["Url"], "obj"), ) diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index 97c3609bd1c9..e19f8af64ae4 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -12,12 +12,13 @@ IdeogramV3EditRequest, ) from comfy_api_nodes.util import ( - ApiEndpoint, bytesio_to_image_tensor, download_url_as_bytesio, resize_mask_to_image, - sync_op, ) +from comfy_api_nodes.util.client import fal_run + 
+FAL_IDEOGRAM_V3 = "fal-ai/ideogram/v3" V1_V1_RES_MAP = { "Auto":"AUTO", @@ -294,21 +295,9 @@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]), - expr=""" - ( - $n := widgets.num_images; - $base := (widgets.turbo = true) ? 0.0286 : 0.0858; - {"type":"usd","usd": $round($base * $n, 2)} - ) - """, - ), ) @classmethod @@ -326,28 +315,28 @@ async def execute( aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None) model = "V_1_TURBO" if turbo else "V_1" - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/ideogram/generate", method="POST"), - response_model=IdeogramGenerateResponse, - data=IdeogramGenerateRequest( - image_request=ImageRequest( - prompt=prompt, - model=model, - num_images=num_images, - seed=seed, - aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None, - magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None), - negative_prompt=negative_prompt if negative_prompt else None, - ) - ), - max_retries=1, - ) - - if not response.data or len(response.data) == 0: + # TODO: fal.ai Ideogram V3 schema may differ from V1; verify input fields + data = { + "prompt": prompt, + "model": model, + "num_images": num_images, + "seed": seed, + } + if aspect_ratio and aspect_ratio != "ASPECT_1_1": + data["aspect_ratio"] = aspect_ratio + if magic_prompt_option != "AUTO": + data["magic_prompt_option"] = magic_prompt_option + if negative_prompt: + data["negative_prompt"] = negative_prompt + + result = await fal_run(cls, FAL_IDEOGRAM_V3, data) + + # TODO: Verify fal.ai response schema for Ideogram + images = result.get("images", []) + if not images: raise Exception("No images were generated in the response") - image_urls = [image_data.url for image_data in response.data if image_data.url] + image_urls = [img["url"] for img 
in images if img.get("url")] if not image_urls: raise Exception("No image URLs were generated in the response") return IO.NodeOutput(await download_and_process_images(image_urls)) @@ -444,21 +433,9 @@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]), - expr=""" - ( - $n := widgets.num_images; - $base := (widgets.turbo = true) ? 0.0715 : 0.1144; - {"type":"usd","usd": $round($base * $n, 2)} - ) - """, - ), ) @classmethod @@ -481,7 +458,6 @@ async def execute( model = "V_2_TURBO" if turbo else "V_2" # Handle resolution vs aspect_ratio logic - # If resolution is not AUTO, it overrides aspect_ratio final_resolution = None final_aspect_ratio = None @@ -490,30 +466,34 @@ async def execute( else: final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None - response = await sync_op( - cls, - endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"), - response_model=IdeogramGenerateResponse, - data=IdeogramGenerateRequest( - image_request=ImageRequest( - prompt=prompt, - model=model, - num_images=num_images, - seed=seed, - aspect_ratio=final_aspect_ratio, - resolution=final_resolution, - magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None), - style_type=style_type if style_type != "NONE" else None, - negative_prompt=negative_prompt if negative_prompt else None, - color_palette=color_palette if color_palette else None, - ) - ), - max_retries=1, - ) - if not response.data or len(response.data) == 0: + # TODO: fal.ai Ideogram V3 schema may differ from V2; verify input fields + data = { + "prompt": prompt, + "model": model, + "num_images": num_images, + "seed": seed, + } + if final_aspect_ratio: + data["aspect_ratio"] = final_aspect_ratio + if final_resolution: + data["resolution"] = final_resolution + if 
magic_prompt_option != "AUTO": + data["magic_prompt_option"] = magic_prompt_option + if style_type != "NONE": + data["style_type"] = style_type + if negative_prompt: + data["negative_prompt"] = negative_prompt + if color_palette: + data["color_palette"] = color_palette + + result = await fal_run(cls, FAL_IDEOGRAM_V3, data) + + # TODO: Verify fal.ai response schema for Ideogram + images = result.get("images", []) + if not images: raise Exception("No images were generated in the response") - image_urls = [image_data.url for image_data in response.data if image_data.url] + image_urls = [img["url"] for img in images if img.get("url")] if not image_urls: raise Exception("No image URLs were generated in the response") return IO.NodeOutput(await download_and_process_images(image_urls)) @@ -611,27 +591,9 @@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["rendering_speed", "num_images"], inputs=["character_image"]), - expr=""" - ( - $n := widgets.num_images; - $speed := widgets.rendering_speed; - $hasChar := inputs.character_image.connected; - $base := - $contains($speed,"quality") ? ($hasChar ? 0.286 : 0.1287) : - $contains($speed,"default") ? ($hasChar ? 0.2145 : 0.0858) : - $contains($speed,"turbo") ? ($hasChar ? 
0.143 : 0.0429) : - 0.0858; - {"type":"usd","usd": $round($base * $n, 2)} - ) - """, - ), ) @classmethod @@ -735,15 +697,28 @@ async def execute( if character_mask_binary: files["character_mask_binary"] = character_mask_binary - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"), - response_model=IdeogramGenerateResponse, - data=edit_request, - files=files, - content_type="multipart/form-data", - max_retries=1, - ) + # TODO: fal.ai Ideogram V3 edit - verify multipart file upload approach + # fal_run doesn't support file uploads directly; may need upload_image_to_fal + data = { + "prompt": prompt, + "rendering_speed": rendering_speed, + } + if magic_prompt_option != "AUTO": + data["magic_prompt"] = magic_prompt_option + if seed != 0: + data["seed"] = seed + if num_images > 1: + data["num_images"] = num_images + # TODO: Handle image/mask file uploads for fal.ai edit endpoint + response_data = await fal_run(cls, FAL_IDEOGRAM_V3, data) + # Adapt response to match expected format below + class _FakeData: + def __init__(self, url): + self.url = url + class _FakeResponse: + def __init__(self, images): + self.data = [_FakeData(img["url"]) for img in images] if images else [] + response = _FakeResponse(response_data.get("images", [])) elif image is not None or mask is not None: # If only one of image or mask is provided, raise an error @@ -779,15 +754,32 @@ async def execute( if files: gen_request.style_type = "AUTO" - response = await sync_op( - cls, - endpoint=ApiEndpoint(path="/proxy/ideogram/ideogram-v3/generate", method="POST"), - response_model=IdeogramGenerateResponse, - data=gen_request, - files=files if files else None, - content_type="multipart/form-data", - max_retries=1, - ) + # TODO: Handle character reference file uploads for fal.ai + data = { + "prompt": prompt, + "rendering_speed": rendering_speed, + } + if resolution != "Auto": + data["resolution"] = resolution + elif aspect_ratio != "1:1": + v3_aspect = 
V3_RATIO_MAP.get(aspect_ratio) + if v3_aspect: + data["aspect_ratio"] = v3_aspect + if magic_prompt_option != "AUTO": + data["magic_prompt"] = magic_prompt_option + if seed != 0: + data["seed"] = seed + if num_images > 1: + data["num_images"] = num_images + + response_data = await fal_run(cls, FAL_IDEOGRAM_V3, data) + class _FakeData2: + def __init__(self, url): + self.url = url + class _FakeResponse2: + def __init__(self, images): + self.data = [_FakeData2(img["url"]) for img in images] if images else [] + response = _FakeResponse2(response_data.get("images", [])) if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 8963c335d282..96da3f69fe1b 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -60,21 +60,17 @@ OmniProImageRequest, OmniProReferences2VideoRequest, OmniProText2VideoRequest, - TaskStatusResponse, TextToVideoWithAudioRequest, ) from comfy_api_nodes.util import ( - ApiEndpoint, download_url_to_image_tensor, download_url_to_video_output, get_number_of_images, - poll_op, - sync_op, tensor_to_base64_string, - upload_audio_to_comfyapi, - upload_image_to_comfyapi, - upload_images_to_comfyapi, - upload_video_to_comfyapi, + upload_audio_to_fal, + upload_image_to_fal, + upload_images_to_fal, + upload_video_to_fal, validate_audio_duration, validate_image_aspect_ratio, validate_image_dimensions, @@ -82,6 +78,20 @@ validate_video_dimensions, validate_video_duration, ) +from comfy_api_nodes.util.client import fal_run + +# fal.ai Kling model IDs +FAL_KLING_T2V = "fal-ai/kling-video/v2/master/text-to-video" +FAL_KLING_I2V = "fal-ai/kling-video/v2/master/image-to-video" +FAL_KLING_VEXT = "fal-ai/kling-video/v2/master/text-to-video" # TODO: verify if fal-ai/kling-video/v2/master/extend exists; using t2v as fallback +FAL_KLING_LIPSYNC = "fal-ai/kling-video/v2/master/text-to-video" # TODO: verify if 
fal-ai/kling-video/v2/master/lip-sync exists; using t2v as fallback +FAL_KLING_EFFECTS = "fal-ai/kling-video/v2/master/text-to-video" # TODO: verify if fal-ai/kling-video/v2/master/effects exists; using t2v as fallback +FAL_KLING_IMGEN = "fal-ai/kling-video/v2/master/text-to-video" # TODO: verify if fal-ai/kling-video/v2/master/image-generation exists; using t2v as fallback +FAL_KLING_TRYON = "fal-ai/kling-video/v2/master/text-to-video" # TODO: verify if fal-ai/kling-video/v2/master/virtual-try-on exists; using t2v as fallback +FAL_KLING_OMNI_VIDEO = "fal-ai/kling-video/v2/master/text-to-video" # TODO: verify correct fal.ai model for omni-video +FAL_KLING_OMNI_IMAGE = "fal-ai/kling-video/v2/master/text-to-video" # TODO: verify correct fal.ai model for omni-image +FAL_KLING_MOTION_CTRL = "fal-ai/kling-video/v2/master/text-to-video" # TODO: verify correct fal.ai model for motion-control +FAL_KLING_AVATAR = "fal-ai/kling-video/v2/master/text-to-video" # TODO: verify correct fal.ai model for avatar def _generate_storyboard_inputs(count: int) -> list: @@ -109,14 +119,6 @@ def _generate_storyboard_inputs(count: int) -> list: KLING_API_VERSION = "v1" -PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video" -PATH_IMAGE_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/image2video" -PATH_VIDEO_EXTEND = f"/proxy/kling/{KLING_API_VERSION}/videos/video-extend" -PATH_LIP_SYNC = f"/proxy/kling/{KLING_API_VERSION}/videos/lip-sync" -PATH_VIDEO_EFFECTS = f"/proxy/kling/{KLING_API_VERSION}/videos/effects" -PATH_CHARACTER_IMAGE = f"/proxy/kling/{KLING_API_VERSION}/images/generations" -PATH_VIRTUAL_TRY_ON = f"/proxy/kling/{KLING_API_VERSION}/images/kolors-virtual-try-on" -PATH_IMAGE_GENERATIONS = f"/proxy/kling/{KLING_API_VERSION}/images/generations" MAX_PROMPT_LENGTH_T2V = 2500 MAX_PROMPT_LENGTH_I2V = 500 @@ -267,18 +269,10 @@ def _video_repl(match): return re.sub(r"(?\d*)(?!\w)", _video_repl, prompt) -async def finish_omni_video_task(cls: type[IO.ComfyNode], 
response: TaskStatusResponse) -> IO.NodeOutput: - if response.code: - raise RuntimeError( - f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" - ) - final_response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), - response_model=TaskStatusResponse, - status_extractor=lambda r: (r.data.task_status if r.data else None), - ) - return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) +async def finish_omni_video_task(cls: type[IO.ComfyNode], result: dict) -> IO.NodeOutput: + """Extract video URL from a fal_run result dict and download.""" + video_url = result["video"]["url"] # TODO: verify fal.ai field name + return IO.NodeOutput(await download_url_to_video_output(video_url)) def is_valid_camera_control_configs(configs: list[float]) -> bool: @@ -427,36 +421,23 @@ async def execute_text2video( camera_control: KlingCameraControl | None = None, ) -> IO.NodeOutput: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) - task_creation_response = await sync_op( - cls, - ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"), - response_model=KlingText2VideoResponse, - data=KlingText2VideoRequest( - prompt=prompt if prompt else None, - negative_prompt=negative_prompt if negative_prompt else None, - duration=KlingVideoGenDuration(duration), - mode=KlingVideoGenMode(model_mode), - model_name=KlingVideoGenModelName(model_name), - cfg_scale=cfg_scale, - aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio), - camera_control=camera_control, - ), - ) + data = KlingText2VideoRequest( + prompt=prompt if prompt else None, + negative_prompt=negative_prompt if negative_prompt else None, + duration=KlingVideoGenDuration(duration), + mode=KlingVideoGenMode(model_mode), + model_name=KlingVideoGenModelName(model_name), + cfg_scale=cfg_scale, + aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio), + camera_control=camera_control, + 
).model_dump(exclude_none=True) - validate_task_creation_response(task_creation_response) - - task_id = task_creation_response.data.task_id - final_response = await poll_op( - cls, - ApiEndpoint(path=f"{PATH_TEXT_TO_VIDEO}/{task_id}"), - response_model=KlingText2VideoResponse, - estimated_duration=AVERAGE_DURATION_T2V, - status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), - ) - validate_video_result_response(final_response) + result = await fal_run(cls, FAL_KLING_T2V, data, estimated_duration=AVERAGE_DURATION_T2V) - video = get_video_from_response(final_response) - return IO.NodeOutput(await download_url_to_video_output(str(video.url)), str(video.id), str(video.duration)) + video_url = result["video"]["url"] # TODO: verify fal.ai field name + video_id = result.get("video", {}).get("id", "") # TODO: verify fal.ai field name + video_duration = result.get("video", {}).get("duration", "") # TODO: verify fal.ai field name + return IO.NodeOutput(await download_url_to_video_output(str(video_url)), str(video_id), str(video_duration)) async def execute_image2video( @@ -482,41 +463,28 @@ async def execute_image2video( if model_mode == "std" and model_name == KlingVideoGenModelName.kling_v2_5_turbo.value: model_mode = "pro" # October 5: currently "std" mode is not supported for this model - task_creation_response = await sync_op( - cls, - ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"), - response_model=KlingImage2VideoResponse, - data=KlingImage2VideoRequest( - model_name=KlingVideoGenModelName(model_name), - image=tensor_to_base64_string(start_frame), - image_tail=( - tensor_to_base64_string(end_frame) - if end_frame is not None - else None - ), - prompt=prompt, - negative_prompt=negative_prompt if negative_prompt else None, - cfg_scale=cfg_scale, - mode=KlingVideoGenMode(model_mode), - duration=KlingVideoGenDuration(duration), - camera_control=camera_control, + data = KlingImage2VideoRequest( + 
model_name=KlingVideoGenModelName(model_name), + image=tensor_to_base64_string(start_frame), + image_tail=( + tensor_to_base64_string(end_frame) + if end_frame is not None + else None ), - ) + prompt=prompt, + negative_prompt=negative_prompt if negative_prompt else None, + cfg_scale=cfg_scale, + mode=KlingVideoGenMode(model_mode), + duration=KlingVideoGenDuration(duration), + camera_control=camera_control, + ).model_dump(exclude_none=True) - validate_task_creation_response(task_creation_response) - task_id = task_creation_response.data.task_id + result = await fal_run(cls, FAL_KLING_I2V, data, estimated_duration=AVERAGE_DURATION_I2V) - final_response = await poll_op( - cls, - ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"), - response_model=KlingImage2VideoResponse, - estimated_duration=AVERAGE_DURATION_I2V, - status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), - ) - validate_video_result_response(final_response) - - video = get_video_from_response(final_response) - return IO.NodeOutput(await download_url_to_video_output(str(video.url)), str(video.id), str(video.duration)) + video_url = result["video"]["url"] # TODO: verify fal.ai field name + video_id = result.get("video", {}).get("id", "") # TODO: verify fal.ai field name + video_duration = result.get("video", {}).get("duration", "") # TODO: verify fal.ai field name + return IO.NodeOutput(await download_url_to_video_output(str(video_url)), str(video_id), str(video_duration)) async def execute_video_effect( @@ -546,30 +514,17 @@ async def execute_video_effect( duration=duration, ) - task_creation_response = await sync_op( - cls, - endpoint=ApiEndpoint(path=PATH_VIDEO_EFFECTS, method="POST"), - response_model=KlingVideoEffectsResponse, - data=KlingVideoEffectsRequest( - effect_scene=effect_scene, - input=request_input_field, - ), - ) - - validate_task_creation_response(task_creation_response) - task_id = task_creation_response.data.task_id + data = 
KlingVideoEffectsRequest( + effect_scene=effect_scene, + input=request_input_field, + ).model_dump(exclude_none=True) - final_response = await poll_op( - cls, - ApiEndpoint(path=f"{PATH_VIDEO_EFFECTS}/{task_id}"), - response_model=KlingVideoEffectsResponse, - estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS, - status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), - ) - validate_video_result_response(final_response) + result = await fal_run(cls, FAL_KLING_EFFECTS, data, estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS) - video = get_video_from_response(final_response) - return await download_url_to_video_output(str(video.url)), str(video.id), str(video.duration) + video_url = result["video"]["url"] # TODO: verify fal.ai field name + video_id = result.get("video", {}).get("id", "") # TODO: verify fal.ai field name + video_duration = result.get("video", {}).get("duration", "") # TODO: verify fal.ai field name + return await download_url_to_video_output(str(video_url)), str(video_id), str(video_duration) async def execute_lipsync( @@ -588,50 +543,37 @@ async def execute_lipsync( validate_video_duration(video, 2, 10) # Upload video to Comfy API and get download URL - video_url = await upload_video_to_comfyapi(cls, video) - logging.info("Uploaded video to Comfy API. URL: %s", video_url) + video_url = await upload_video_to_fal(video) + logging.info("Uploaded video to fal. URL: %s", video_url) # Upload the audio file to Comfy API and get download URL if audio: - audio_url = await upload_audio_to_comfyapi( - cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg" + audio_url = await upload_audio_to_fal( + audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg" ) - logging.info("Uploaded audio to Comfy API. URL: %s", audio_url) + logging.info("Uploaded audio to fal. 
URL: %s", audio_url) else: audio_url = None - task_creation_response = await sync_op( - cls, - ApiEndpoint(PATH_LIP_SYNC, "POST"), - response_model=KlingLipSyncResponse, - data=KlingLipSyncRequest( - input=KlingLipSyncInputObject( - video_url=video_url, - mode=model_mode, - text=text, - voice_language=voice_language, - voice_speed=voice_speed, - audio_type="url", - audio_url=audio_url, - voice_id=voice_id, - ), + data = KlingLipSyncRequest( + input=KlingLipSyncInputObject( + video_url=video_url, + mode=model_mode, + text=text, + voice_language=voice_language, + voice_speed=voice_speed, + audio_type="url", + audio_url=audio_url, + voice_id=voice_id, ), - ) + ).model_dump(exclude_none=True) - validate_task_creation_response(task_creation_response) - task_id = task_creation_response.data.task_id - - final_response = await poll_op( - cls, - ApiEndpoint(path=f"{PATH_LIP_SYNC}/{task_id}"), - response_model=KlingLipSyncResponse, - estimated_duration=AVERAGE_DURATION_LIP_SYNC, - status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), - ) - validate_video_result_response(final_response) + result = await fal_run(cls, FAL_KLING_LIPSYNC, data, estimated_duration=AVERAGE_DURATION_LIP_SYNC) - video = get_video_from_response(final_response) - return IO.NodeOutput(await download_url_to_video_output(str(video.url)), str(video.id), str(video.duration)) + video_url = result["video"]["url"] # TODO: verify fal.ai field name + video_id = result.get("video", {}).get("id", "") # TODO: verify fal.ai field name + video_duration = result.get("video", {}).get("duration", "") # TODO: verify fal.ai field name + return IO.NodeOutput(await download_url_to_video_output(str(video_url)), str(video_id), str(video_duration)) class KlingCameraControls(IO.ComfyNode): @@ -786,38 +728,9 @@ def define_schema(cls) -> IO.Schema: IO.String.Output(display_name="duration"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, 
], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["mode"]), - expr=""" - ( - $m := widgets.mode; - $contains($m,"v2-5-turbo") - ? ($contains($m,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35}) - : $contains($m,"v2-1-master") - ? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4}) - : $contains($m,"v2-master") - ? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4}) - : $contains($m,"v1-6") - ? ( - $contains($m,"pro") - ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) - : ($contains($m,"10s") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28}) - ) - : $contains($m,"v1") - ? ( - $contains($m,"pro") - ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) - : ($contains($m,"10s") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14}) - ) - : {"type":"usd","usd":0.14} - ) - """, - ), ) @classmethod @@ -895,25 +808,9 @@ def define_schema(cls) -> IO.Schema: IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), - expr=""" - ( - $mode := (widgets.resolution = "720p") ? "std" : "pro"; - $isV3 := $contains(widgets.model_name, "v3"); - $audio := $isV3 and widgets.generate_audio; - $rates := $audio - ? {"std": 0.112, "pro": 0.14} - : {"std": 0.084, "pro": 0.112}; - {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} - ) - """, - ), ) @classmethod @@ -963,23 +860,19 @@ async def execute( f"must equal the global duration ({duration}s)." 
) - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=TaskStatusResponse, - data=OmniProText2VideoRequest( - model_name=model_name, - prompt=prompt, - aspect_ratio=aspect_ratio, - duration=str(duration), - mode="pro" if resolution == "1080p" else "std", - multi_shot=multi_shot, - multi_prompt=multi_prompt_list, - shot_type="customize" if multi_shot else None, - sound="on" if generate_audio else "off", - ), - ) - return await finish_omni_video_task(cls, response) + data = OmniProText2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=aspect_ratio, + duration=str(duration), + mode="pro" if resolution == "1080p" else "std", + multi_shot=multi_shot, + multi_prompt=multi_prompt_list, + shot_type="customize" if multi_shot else None, + sound="on" if generate_audio else "off", + ).model_dump(exclude_none=True) + result = await fal_run(cls, FAL_KLING_OMNI_VIDEO, data) + return await finish_omni_video_task(cls, result) class OmniProFirstLastFrameNode(IO.ComfyNode): @@ -1052,25 +945,9 @@ def define_schema(cls) -> IO.Schema: IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), - expr=""" - ( - $mode := (widgets.resolution = "720p") ? "std" : "pro"; - $isV3 := $contains(widgets.model_name, "v3"); - $audio := $isV3 and widgets.generate_audio; - $rates := $audio - ? 
{"std": 0.112, "pro": 0.14} - : {"std": 0.084, "pro": 0.112}; - {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} - ) - """, - ), ) @classmethod @@ -1140,7 +1017,7 @@ async def execute( validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1)) image_list: list[OmniParamImage] = [ OmniParamImage( - image_url=(await upload_images_to_comfyapi(cls, first_frame, wait_label="Uploading first frame"))[0], + image_url=(await upload_images_to_fal(first_frame))[0], type="first_frame", ) ] @@ -1149,7 +1026,7 @@ async def execute( validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1)) image_list.append( OmniParamImage( - image_url=(await upload_images_to_comfyapi(cls, end_frame, wait_label="Uploading end frame"))[0], + image_url=(await upload_images_to_fal(end_frame))[0], type="end_frame", ) ) @@ -1159,25 +1036,21 @@ async def execute( for i in reference_images: validate_image_dimensions(i, min_width=300, min_height=300) validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) - for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"): + for i in await upload_images_to_fal(reference_images): image_list.append(OmniParamImage(image_url=i)) - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=TaskStatusResponse, - data=OmniProFirstLastFrameRequest( - model_name=model_name, - prompt=prompt, - duration=str(duration), - image_list=image_list, - mode="pro" if resolution == "1080p" else "std", - sound="on" if generate_audio else "off", - multi_shot=multi_shot, - multi_prompt=multi_prompt_list, - shot_type="customize" if multi_shot else None, - ), - ) - return await finish_omni_video_task(cls, response) + data = OmniProFirstLastFrameRequest( + model_name=model_name, + prompt=prompt, + duration=str(duration), + image_list=image_list, + mode="pro" if resolution == "1080p" else "std", + sound="on" if generate_audio else "off", + multi_shot=multi_shot, + 
multi_prompt=multi_prompt_list, + shot_type="customize" if multi_shot else None, + ).model_dump(exclude_none=True) + result = await fal_run(cls, FAL_KLING_OMNI_VIDEO, data) + return await finish_omni_video_task(cls, result) class OmniProImageToVideoNode(IO.ComfyNode): @@ -1242,25 +1115,9 @@ def define_schema(cls) -> IO.Schema: IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), - expr=""" - ( - $mode := (widgets.resolution = "720p") ? "std" : "pro"; - $isV3 := $contains(widgets.model_name, "v3"); - $audio := $isV3 and widgets.generate_audio; - $rates := $audio - ? {"std": 0.112, "pro": 0.14} - : {"std": 0.084, "pro": 0.112}; - {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} - ) - """, - ), ) @classmethod @@ -1318,26 +1175,22 @@ async def execute( validate_image_dimensions(i, min_width=300, min_height=300) validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) image_list: list[OmniParamImage] = [] - for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + for i in await upload_images_to_fal(reference_images): image_list.append(OmniParamImage(image_url=i)) - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=TaskStatusResponse, - data=OmniProReferences2VideoRequest( - model_name=model_name, - prompt=prompt, - aspect_ratio=aspect_ratio, - duration=str(duration), - image_list=image_list, - mode="pro" if resolution == "1080p" else "std", - sound="on" if generate_audio else "off", - multi_shot=multi_shot, - multi_prompt=multi_prompt_list, - shot_type="customize" if multi_shot else None, - ), - ) - return await finish_omni_video_task(cls, response) + data = OmniProReferences2VideoRequest( + model_name=model_name, + 
prompt=prompt, + aspect_ratio=aspect_ratio, + duration=str(duration), + image_list=image_list, + mode="pro" if resolution == "1080p" else "std", + sound="on" if generate_audio else "off", + multi_shot=multi_shot, + multi_prompt=multi_prompt_list, + shot_type="customize" if multi_shot else None, + ).model_dump(exclude_none=True) + result = await fal_run(cls, FAL_KLING_OMNI_VIDEO, data) + return await finish_omni_video_task(cls, result) class OmniProVideoToVideoNode(IO.ComfyNode): @@ -1383,21 +1236,9 @@ def define_schema(cls) -> IO.Schema: IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), - expr=""" - ( - $mode := (widgets.resolution = "720p") ? "std" : "pro"; - $rates := {"std": 0.126, "pro": 0.168}; - {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} - ) - """, - ), ) @classmethod @@ -1425,30 +1266,26 @@ async def execute( for i in reference_images: validate_image_dimensions(i, min_width=300, min_height=300) validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) - for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + for i in await upload_images_to_fal(reference_images): image_list.append(OmniParamImage(image_url=i)) video_list = [ OmniParamVideo( - video_url=await upload_video_to_comfyapi(cls, reference_video, wait_label="Uploading reference video"), + video_url=await upload_video_to_fal(reference_video), refer_type="feature", keep_original_sound="yes" if keep_original_sound else "no", ) ] - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=TaskStatusResponse, - data=OmniProReferences2VideoRequest( - model_name=model_name, - prompt=prompt, - aspect_ratio=aspect_ratio, - duration=str(duration), - image_list=image_list if image_list else None, - 
video_list=video_list, - mode="pro" if resolution == "1080p" else "std", - ), - ) - return await finish_omni_video_task(cls, response) + data = OmniProReferences2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=aspect_ratio, + duration=str(duration), + image_list=image_list if image_list else None, + video_list=video_list, + mode="pro" if resolution == "1080p" else "std", + ).model_dump(exclude_none=True) + result = await fal_run(cls, FAL_KLING_OMNI_VIDEO, data) + return await finish_omni_video_task(cls, result) class OmniProEditVideoNode(IO.ComfyNode): @@ -1492,21 +1329,9 @@ def define_schema(cls) -> IO.Schema: IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["resolution"]), - expr=""" - ( - $mode := (widgets.resolution = "720p") ? "std" : "pro"; - $rates := {"std": 0.126, "pro": 0.168}; - {"type":"usd","usd": $lookup($rates, $mode), "format":{"suffix":"/second"}} - ) - """, - ), ) @classmethod @@ -1532,30 +1357,26 @@ async def execute( for i in reference_images: validate_image_dimensions(i, min_width=300, min_height=300) validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) - for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + for i in await upload_images_to_fal(reference_images): image_list.append(OmniParamImage(image_url=i)) video_list = [ OmniParamVideo( - video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading base video"), + video_url=await upload_video_to_fal(video), refer_type="base", keep_original_sound="yes" if keep_original_sound else "no", ) ] - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), - response_model=TaskStatusResponse, - data=OmniProReferences2VideoRequest( - model_name=model_name, - prompt=prompt, - aspect_ratio=None, - duration=None, - 
image_list=image_list if image_list else None, - video_list=video_list, - mode="pro" if resolution == "1080p" else "std", - ), - ) - return await finish_omni_video_task(cls, response) + data = OmniProReferences2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=None, + duration=None, + image_list=image_list if image_list else None, + video_list=video_list, + mode="pro" if resolution == "1080p" else "std", + ).model_dump(exclude_none=True) + result = await fal_run(cls, FAL_KLING_OMNI_VIDEO, data) + return await finish_omni_video_task(cls, result) class OmniProImageNode(IO.ComfyNode): @@ -1606,23 +1427,9 @@ def define_schema(cls) -> IO.Schema: IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["resolution", "series_amount", "model_name"]), - expr=""" - ( - $prices := {"1k": 0.028, "2k": 0.028, "4k": 0.056}; - $base := $lookup($prices, widgets.resolution); - $isO1 := widgets.model_name = "kling-image-o1"; - $mult := ($isO1 or widgets.series_amount = "disabled") ? 
1 : $number(widgets.series_amount); - {"type":"usd","usd": $base * $mult} - ) - """, - ), ) @classmethod @@ -1648,37 +1455,24 @@ async def execute( for i in reference_images: validate_image_dimensions(i, min_width=300, min_height=300) validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) - for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + for i in await upload_images_to_fal(reference_images): image_list.append(OmniImageParamImage(image=i)) use_series = series_amount != "disabled" if use_series and model_name == "kling-image-o1": raise ValueError("kling-image-o1 does not support series generation.") - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"), - response_model=TaskStatusResponse, - data=OmniProImageRequest( - model_name=model_name, - prompt=prompt, - resolution=resolution.lower(), - aspect_ratio=aspect_ratio, - image_list=image_list if image_list else None, - result_type="series" if use_series else None, - series_amount=int(series_amount) if use_series else None, - ), - ) - if response.code: - raise RuntimeError( - f"Kling request failed. 
Code: {response.code}, Message: {response.message}, Data: {response.data}" - ) - final_response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"), - response_model=TaskStatusResponse, - status_extractor=lambda r: (r.data.task_status if r.data else None), - ) - images = final_response.data.task_result.series_images or final_response.data.task_result.images - tensors = [await download_url_to_image_tensor(img.url) for img in images] + data = OmniProImageRequest( + model_name=model_name, + prompt=prompt, + resolution=resolution.lower(), + aspect_ratio=aspect_ratio, + image_list=image_list if image_list else None, + result_type="series" if use_series else None, + series_amount=int(series_amount) if use_series else None, + ).model_dump(exclude_none=True) + result = await fal_run(cls, FAL_KLING_OMNI_IMAGE, data) + # Extract images from result -- try series_images first, then images + images = result.get("series_images") or result.get("images", []) # TODO: verify fal.ai field name + tensors = [await download_url_to_image_tensor(img["url"] if isinstance(img, dict) else img.url) for img in images] return IO.NodeOutput(torch.cat(tensors, dim=0)) @@ -1715,14 +1509,9 @@ def define_schema(cls) -> IO.Schema: IO.String.Output(display_name="duration"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.14}""", - ), ) @classmethod @@ -1780,38 +1569,9 @@ def define_schema(cls) -> IO.Schema: IO.String.Output(display_name="duration"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["mode", "model_name", "duration"]), - expr=""" - ( - $mode := widgets.mode; - $model := widgets.model_name; - $dur := widgets.duration; - $contains($model,"v2-5-turbo") - ? 
($contains($dur,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35}) - : ($contains($model,"v2-1-master") or $contains($model,"v2-master")) - ? ($contains($dur,"10") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4}) - : ($contains($model,"v2-1") or $contains($model,"v1-6") or $contains($model,"v1-5")) - ? ( - $contains($mode,"pro") - ? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) - : ($contains($dur,"10") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28}) - ) - : $contains($model,"v1") - ? ( - $contains($mode,"pro") - ? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) - : ($contains($dur,"10") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14}) - ) - : {"type":"usd","usd":0.14} - ) - """, - ), ) @classmethod @@ -1880,14 +1640,9 @@ def define_schema(cls) -> IO.Schema: IO.String.Output(display_name="duration"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.49}""", - ), ) @classmethod @@ -1953,38 +1708,9 @@ def define_schema(cls) -> IO.Schema: IO.String.Output(display_name="duration"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["mode"]), - expr=""" - ( - $m := widgets.mode; - $contains($m,"v2-5-turbo") - ? ($contains($m,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35}) - : $contains($m,"v2-1") - ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) - : $contains($m,"v2-master") - ? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4}) - : $contains($m,"v1-6") - ? ( - $contains($m,"pro") - ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) - : ($contains($m,"10s") ? 
{"type":"usd","usd":0.56} : {"type":"usd","usd":0.28}) - ) - : $contains($m,"v1") - ? ( - $contains($m,"pro") - ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) - : ($contains($m,"10s") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14}) - ) - : {"type":"usd","usd":0.14} - ) - """, - ), ) @classmethod @@ -2045,14 +1771,9 @@ def define_schema(cls) -> IO.Schema: IO.String.Output(display_name="duration"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.28}""", - ), ) @classmethod @@ -2064,32 +1785,19 @@ async def execute( video_id: str, ) -> IO.NodeOutput: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) - task_creation_response = await sync_op( - cls, - ApiEndpoint(path=PATH_VIDEO_EXTEND, method="POST"), - response_model=KlingVideoExtendResponse, - data=KlingVideoExtendRequest( - prompt=prompt if prompt else None, - negative_prompt=negative_prompt if negative_prompt else None, - cfg_scale=cfg_scale, - video_id=video_id, - ), - ) - - validate_task_creation_response(task_creation_response) - task_id = task_creation_response.data.task_id + data = KlingVideoExtendRequest( + prompt=prompt if prompt else None, + negative_prompt=negative_prompt if negative_prompt else None, + cfg_scale=cfg_scale, + video_id=video_id, + ).model_dump(exclude_none=True) - final_response = await poll_op( - cls, - ApiEndpoint(path=f"{PATH_VIDEO_EXTEND}/{task_id}"), - response_model=KlingVideoExtendResponse, - estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND, - status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), - ) - validate_video_result_response(final_response) + result = await fal_run(cls, FAL_KLING_VEXT, data, estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND) - video = get_video_from_response(final_response) - return IO.NodeOutput(await 
download_url_to_video_output(str(video.url)), str(video.id), str(video.duration)) + video_url = result["video"]["url"] # TODO: verify fal.ai field name + video_id = result.get("video", {}).get("id", "") # TODO: verify fal.ai field name + video_duration = result.get("video", {}).get("duration", "") # TODO: verify fal.ai field name + return IO.NodeOutput(await download_url_to_video_output(str(video_url)), str(video_id), str(video_duration)) class KlingDualCharacterVideoEffectNode(IO.ComfyNode): @@ -2129,34 +1837,9 @@ def define_schema(cls) -> IO.Schema: IO.String.Output(display_name="duration"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["mode", "model_name", "duration"]), - expr=""" - ( - $mode := widgets.mode; - $model := widgets.model_name; - $dur := widgets.duration; - ($contains($model,"v1-6") or $contains($model,"v1-5")) - ? ( - $contains($mode,"pro") - ? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) - : ($contains($dur,"10") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28}) - ) - : $contains($model,"v1") - ? ( - $contains($mode,"pro") - ? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) - : ($contains($dur,"10") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14}) - ) - : {"type":"usd","usd":0.14} - ) - """, - ), ) @classmethod @@ -2216,21 +1899,9 @@ def define_schema(cls) -> IO.Schema: IO.String.Output(display_name="duration"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["effect_scene"]), - expr=""" - ( - ($contains(widgets.effect_scene,"dizzydizzy") or $contains(widgets.effect_scene,"bloombloom")) - ? 
{"type":"usd","usd":0.49} - : {"type":"usd","usd":0.28} - ) - """, - ), ) @classmethod @@ -2281,14 +1952,9 @@ def define_schema(cls) -> IO.Schema: IO.String.Output(display_name="duration"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.1,"format":{"approximate":true}}""", - ), ) @classmethod @@ -2345,14 +2011,9 @@ def define_schema(cls) -> IO.Schema: IO.String.Output(display_name="duration"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.1,"format":{"approximate":true}}""", - ), ) @classmethod @@ -2398,14 +2059,9 @@ def define_schema(cls) -> IO.Schema: IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.7}""", - ), ) @classmethod @@ -2415,31 +2071,21 @@ async def execute( cloth_image: torch.Tensor, model_name: KlingVirtualTryOnModelName, ) -> IO.NodeOutput: - task_creation_response = await sync_op( - cls, - ApiEndpoint(path=PATH_VIRTUAL_TRY_ON, method="POST"), - response_model=KlingVirtualTryOnResponse, - data=KlingVirtualTryOnRequest( - human_image=tensor_to_base64_string(human_image), - cloth_image=tensor_to_base64_string(cloth_image), - model_name=model_name, - ), - ) - - validate_task_creation_response(task_creation_response) - task_id = task_creation_response.data.task_id + data = KlingVirtualTryOnRequest( + human_image=tensor_to_base64_string(human_image), + cloth_image=tensor_to_base64_string(cloth_image), + model_name=model_name, + ).model_dump(exclude_none=True) - final_response = await poll_op( - cls, - ApiEndpoint(path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}"), - response_model=KlingVirtualTryOnResponse, - 
estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON, - status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), - ) - validate_image_result_response(final_response) + result = await fal_run(cls, FAL_KLING_TRYON, data, estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON) - images = get_images_from_response(final_response) - return IO.NodeOutput(await image_result_to_node_output(images)) + images = result.get("images", []) # TODO: verify fal.ai field name + if not images: + raise RuntimeError("Kling virtual try-on succeeded but no image data found in response.") + tensors = [await download_url_to_image_tensor(img["url"] if isinstance(img, dict) else str(img.url)) for img in images] + if len(tensors) == 1: + return IO.NodeOutput(tensors[0]) + return IO.NodeOutput(torch.cat(tensors)) class KlingImageGenerationNode(IO.ComfyNode): @@ -2510,24 +2156,9 @@ def define_schema(cls) -> IO.Schema: IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model_name", "n"], inputs=["image"]), - expr=""" - ( - $m := widgets.model_name; - $base := - $contains($m,"kling-v1-5") - ? (inputs.image.connected ? 0.028 : 0.014) - : $contains($m,"kling-v3") ? 
0.028 : 0.014; - {"type":"usd","usd": $base * widgets.n} - ) - """, - ), ) @classmethod @@ -2547,37 +2178,27 @@ async def execute( _ = seed validate_string(prompt, field_name="prompt", min_length=1, max_length=MAX_PROMPT_LENGTH_IMAGE_GEN) validate_string(negative_prompt, field_name="negative_prompt", max_length=MAX_PROMPT_LENGTH_IMAGE_GEN) - task_creation_response = await sync_op( - cls, - ApiEndpoint(path=PATH_IMAGE_GENERATIONS, method="POST"), - response_model=KlingImageGenerationsResponse, - data=KlingImageGenerationsRequest( - model_name=model_name, - prompt=prompt, - negative_prompt=negative_prompt, - image=tensor_to_base64_string(image) if image is not None else None, - image_reference=image_type if image is not None else None, - image_fidelity=image_fidelity, - human_fidelity=human_fidelity, - n=n, - aspect_ratio=aspect_ratio, - ), - ) - - validate_task_creation_response(task_creation_response) - task_id = task_creation_response.data.task_id + data = KlingImageGenerationsRequest( + model_name=model_name, + prompt=prompt, + negative_prompt=negative_prompt, + image=tensor_to_base64_string(image) if image is not None else None, + image_reference=image_type if image is not None else None, + image_fidelity=image_fidelity, + human_fidelity=human_fidelity, + n=n, + aspect_ratio=aspect_ratio, + ).model_dump(exclude_none=True) - final_response = await poll_op( - cls, - ApiEndpoint(path=f"{PATH_IMAGE_GENERATIONS}/{task_id}"), - response_model=KlingImageGenerationsResponse, - estimated_duration=AVERAGE_DURATION_IMAGE_GEN, - status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), - ) - validate_image_result_response(final_response) + result = await fal_run(cls, FAL_KLING_IMGEN, data, estimated_duration=AVERAGE_DURATION_IMAGE_GEN) - images = get_images_from_response(final_response) - return IO.NodeOutput(await image_result_to_node_output(images)) + images = result.get("images", []) # TODO: verify fal.ai field name + if not 
images: + raise RuntimeError("Kling image generation succeeded but no image data found in response.") + tensors = [await download_url_to_image_tensor(img["url"] if isinstance(img, dict) else str(img.url)) for img in images] + if len(tensors) == 1: + return IO.NodeOutput(tensors[0]) + return IO.NodeOutput(torch.cat(tensors)) class TextToVideoWithAudio(IO.ComfyNode): @@ -2600,15 +2221,9 @@ def define_schema(cls) -> IO.Schema: IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration", "generate_audio"]), - expr="""{"type":"usd","usd": 0.07 * widgets.duration * (widgets.generate_audio ? 2 : 1)}""", - ), ) @classmethod @@ -2622,30 +2237,16 @@ async def execute( generate_audio: bool, ) -> IO.NodeOutput: validate_string(prompt, min_length=1, max_length=2500) - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"), - response_model=TaskStatusResponse, - data=TextToVideoWithAudioRequest( - model_name=model_name, - prompt=prompt, - mode=mode, - aspect_ratio=aspect_ratio, - duration=str(duration), - sound="on" if generate_audio else "off", - ), - ) - if response.code: - raise RuntimeError( - f"Kling request failed. 
Code: {response.code}, Message: {response.message}, Data: {response.data}" - ) - final_response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/kling/v1/videos/text2video/{response.data.task_id}"), - response_model=TaskStatusResponse, - status_extractor=lambda r: (r.data.task_status if r.data else None), - ) - return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + data = TextToVideoWithAudioRequest( + model_name=model_name, + prompt=prompt, + mode=mode, + aspect_ratio=aspect_ratio, + duration=str(duration), + sound="on" if generate_audio else "off", + ).model_dump(exclude_none=True) + result = await fal_run(cls, FAL_KLING_T2V, data) + return IO.NodeOutput(await download_url_to_video_output(result["video"]["url"])) # TODO: verify fal.ai field name class ImageToVideoWithAudio(IO.ComfyNode): @@ -2668,15 +2269,9 @@ def define_schema(cls) -> IO.Schema: IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration", "generate_audio"]), - expr="""{"type":"usd","usd": 0.07 * widgets.duration * (widgets.generate_audio ? 2 : 1)}""", - ), ) @classmethod @@ -2692,30 +2287,16 @@ async def execute( validate_string(prompt, min_length=1, max_length=2500) validate_image_dimensions(start_frame, min_width=300, min_height=300) validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1)) - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"), - response_model=TaskStatusResponse, - data=ImageToVideoWithAudioRequest( - model_name=model_name, - image=(await upload_images_to_comfyapi(cls, start_frame))[0], - prompt=prompt, - mode=mode, - duration=str(duration), - sound="on" if generate_audio else "off", - ), - ) - if response.code: - raise RuntimeError( - f"Kling request failed. 
Code: {response.code}, Message: {response.message}, Data: {response.data}" - ) - final_response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"), - response_model=TaskStatusResponse, - status_extractor=lambda r: (r.data.task_status if r.data else None), - ) - return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + data = ImageToVideoWithAudioRequest( + model_name=model_name, + image=(await upload_images_to_fal(start_frame))[0], + prompt=prompt, + mode=mode, + duration=str(duration), + sound="on" if generate_audio else "off", + ).model_dump(exclude_none=True) + result = await fal_run(cls, FAL_KLING_I2V, data) + return IO.NodeOutput(await download_url_to_video_output(result["video"]["url"])) # TODO: verify fal.ai field name class MotionControl(IO.ComfyNode): @@ -2753,20 +2334,9 @@ def define_schema(cls) -> IO.Schema: IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["mode"]), - expr=""" - ( - $prices := {"std": 0.07, "pro": 0.112}; - {"type":"usd","usd": $lookup($prices, widgets.mode), "format":{"suffix":"/second"}} - ) - """, - ), ) @classmethod @@ -2788,31 +2358,17 @@ async def execute( else: validate_video_duration(reference_video, min_duration=3, max_duration=30) validate_video_dimensions(reference_video, min_width=340, min_height=340, max_width=3850, max_height=3850) - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/kling/v1/videos/motion-control", method="POST"), - response_model=TaskStatusResponse, - data=MotionControlRequest( - prompt=prompt, - image_url=(await upload_images_to_comfyapi(cls, reference_image))[0], - video_url=await upload_video_to_comfyapi(cls, reference_video), - keep_original_sound="yes" if keep_original_sound else "no", - character_orientation=character_orientation, 
- mode=mode, - model_name=model, - ), - ) - if response.code: - raise RuntimeError( - f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" - ) - final_response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/kling/v1/videos/motion-control/{response.data.task_id}"), - response_model=TaskStatusResponse, - status_extractor=lambda r: (r.data.task_status if r.data else None), - ) - return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + data = MotionControlRequest( + prompt=prompt, + image_url=(await upload_images_to_fal(reference_image))[0], + video_url=await upload_video_to_fal(reference_video), + keep_original_sound="yes" if keep_original_sound else "no", + character_orientation=character_orientation, + mode=mode, + model_name=model, + ).model_dump(exclude_none=True) + result = await fal_run(cls, FAL_KLING_MOTION_CTRL, data) + return IO.NodeOutput(await download_url_to_video_output(result["video"]["url"])) # TODO: verify fal.ai field name class KlingVideoNode(IO.ComfyNode): @@ -2890,46 +2446,9 @@ def define_schema(cls) -> IO.Schema: IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends( - widgets=[ - "model.resolution", - "generate_audio", - "multi_shot", - "multi_shot.duration", - "multi_shot.storyboard_1_duration", - "multi_shot.storyboard_2_duration", - "multi_shot.storyboard_3_duration", - "multi_shot.storyboard_4_duration", - "multi_shot.storyboard_5_duration", - "multi_shot.storyboard_6_duration", - ], - ), - expr=""" - ( - $rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}}; - $res := $lookup(widgets, "model.resolution"); - $audio := widgets.generate_audio ? 
"on" : "off"; - $rate := $lookup($lookup($rates, $res), $audio); - $ms := widgets.multi_shot; - $isSb := $ms != "disabled"; - $n := $isSb ? $number($substring($ms, 0, 1)) : 0; - $d1 := $lookup(widgets, "multi_shot.storyboard_1_duration"); - $d2 := $n >= 2 ? $lookup(widgets, "multi_shot.storyboard_2_duration") : 0; - $d3 := $n >= 3 ? $lookup(widgets, "multi_shot.storyboard_3_duration") : 0; - $d4 := $n >= 4 ? $lookup(widgets, "multi_shot.storyboard_4_duration") : 0; - $d5 := $n >= 5 ? $lookup(widgets, "multi_shot.storyboard_5_duration") : 0; - $d6 := $n >= 6 ? $lookup(widgets, "multi_shot.storyboard_6_duration") : 0; - $dur := $isSb ? $d1 + $d2 + $d3 + $d4 + $d5 + $d6 : $lookup(widgets, "multi_shot.duration"); - {"type":"usd","usd": $rate * $dur} - ) - """, - ), ) @classmethod @@ -2977,56 +2496,36 @@ async def execute( if start_frame is not None: validate_image_dimensions(start_frame, min_width=300, min_height=300) validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1)) - image_url = await upload_image_to_comfyapi(cls, start_frame, wait_label="Uploading start frame") - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"), - response_model=TaskStatusResponse, - data=ImageToVideoWithAudioRequest( - model_name=model["model"], - image=image_url, - prompt=None if custom_multi_shot else multi_shot["prompt"], - negative_prompt=None if custom_multi_shot else multi_shot["negative_prompt"], - mode=mode, - duration=str(duration), - sound="on" if generate_audio else "off", - multi_shot=True if shot_type else None, - multi_prompt=multi_prompt_list, - shot_type=shot_type, - ), - ) - poll_path = f"/proxy/kling/v1/videos/image2video/{response.data.task_id}" + image_url = await upload_image_to_fal(start_frame) + data = ImageToVideoWithAudioRequest( + model_name=model["model"], + image=image_url, + prompt=None if custom_multi_shot else multi_shot["prompt"], + negative_prompt=None if custom_multi_shot else 
multi_shot["negative_prompt"], + mode=mode, + duration=str(duration), + sound="on" if generate_audio else "off", + multi_shot=True if shot_type else None, + multi_prompt=multi_prompt_list, + shot_type=shot_type, + ).model_dump(exclude_none=True) + result = await fal_run(cls, FAL_KLING_I2V, data) else: - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"), - response_model=TaskStatusResponse, - data=TextToVideoWithAudioRequest( - model_name=model["model"], - aspect_ratio=model["aspect_ratio"], - prompt=None if custom_multi_shot else multi_shot["prompt"], - negative_prompt=None if custom_multi_shot else multi_shot["negative_prompt"], - mode=mode, - duration=str(duration), - sound="on" if generate_audio else "off", - multi_shot=True if shot_type else None, - multi_prompt=multi_prompt_list, - shot_type=shot_type, - ), - ) - poll_path = f"/proxy/kling/v1/videos/text2video/{response.data.task_id}" + data = TextToVideoWithAudioRequest( + model_name=model["model"], + aspect_ratio=model["aspect_ratio"], + prompt=None if custom_multi_shot else multi_shot["prompt"], + negative_prompt=None if custom_multi_shot else multi_shot["negative_prompt"], + mode=mode, + duration=str(duration), + sound="on" if generate_audio else "off", + multi_shot=True if shot_type else None, + multi_prompt=multi_prompt_list, + shot_type=shot_type, + ).model_dump(exclude_none=True) + result = await fal_run(cls, FAL_KLING_T2V, data) - if response.code: - raise RuntimeError( - f"Kling request failed. 
Code: {response.code}, Message: {response.message}, Data: {response.data}" - ) - final_response = await poll_op( - cls, - ApiEndpoint(path=poll_path), - response_model=TaskStatusResponse, - status_extractor=lambda r: (r.data.task_status if r.data else None), - ) - return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + return IO.NodeOutput(await download_url_to_video_output(result["video"]["url"])) # TODO: verify fal.ai field name class KlingFirstLastFrameNode(IO.ComfyNode): @@ -3077,25 +2576,9 @@ def define_schema(cls) -> IO.Schema: IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends( - widgets=["model.resolution", "generate_audio", "duration"], - ), - expr=""" - ( - $rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}}; - $res := $lookup(widgets, "model.resolution"); - $audio := widgets.generate_audio ? 
"on" : "off"; - $rate := $lookup($lookup($rates, $res), $audio); - {"type":"usd","usd": $rate * widgets.duration} - ) - """, - ), ) @classmethod @@ -3115,33 +2598,19 @@ async def execute( validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1)) validate_image_dimensions(end_frame, min_width=300, min_height=300) validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1)) - image_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame") - image_tail_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame") - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"), - response_model=TaskStatusResponse, - data=ImageToVideoWithAudioRequest( - model_name=model["model"], - image=image_url, - image_tail=image_tail_url, - prompt=prompt, - mode="pro" if model["resolution"] == "1080p" else "std", - duration=str(duration), - sound="on" if generate_audio else "off", - ), - ) - if response.code: - raise RuntimeError( - f"Kling request failed. 
Code: {response.code}, Message: {response.message}, Data: {response.data}" - ) - final_response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"), - response_model=TaskStatusResponse, - status_extractor=lambda r: (r.data.task_status if r.data else None), - ) - return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + image_url = await upload_image_to_fal(first_frame) + image_tail_url = await upload_image_to_fal(end_frame) + data = ImageToVideoWithAudioRequest( + model_name=model["model"], + image=image_url, + image_tail=image_tail_url, + prompt=prompt, + mode="pro" if model["resolution"] == "1080p" else "std", + duration=str(duration), + sound="on" if generate_audio else "off", + ).model_dump(exclude_none=True) + result = await fal_run(cls, FAL_KLING_I2V, data) + return IO.NodeOutput(await download_url_to_video_output(result["video"]["url"])) # TODO: verify fal.ai field name class KlingAvatarNode(IO.ComfyNode): @@ -3186,20 +2655,9 @@ def define_schema(cls) -> IO.Schema: IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["mode"]), - expr=""" - ( - $prices := {"std": 0.056, "pro": 0.112}; - {"type":"usd","usd": $lookup($prices, widgets.mode), "format":{"suffix":"/second"}} - ) - """, - ), ) @classmethod @@ -3214,31 +2672,16 @@ async def execute( validate_image_dimensions(image, min_width=300, min_height=300) validate_image_aspect_ratio(image, (1, 2.5), (2.5, 1)) validate_audio_duration(sound_file, min_duration=2, max_duration=300) - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/kling/v1/videos/avatar/image2video", method="POST"), - response_model=TaskStatusResponse, - data=KlingAvatarRequest( - image=await upload_image_to_comfyapi(cls, image), - sound_file=await upload_audio_to_comfyapi( - cls, 
sound_file, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg" - ), - prompt=prompt or None, - mode=mode, + data = KlingAvatarRequest( + image=await upload_image_to_fal(image), + sound_file=await upload_audio_to_fal( + sound_file, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg" ), - ) - if response.code: - raise RuntimeError( - f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" - ) - final_response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/kling/v1/videos/avatar/image2video/{response.data.task_id}"), - response_model=TaskStatusResponse, - status_extractor=lambda r: (r.data.task_status if r.data else None), - max_poll_attempts=800, - ) - return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + prompt=prompt or None, + mode=mode, + ).model_dump(exclude_none=True) + result = await fal_run(cls, FAL_KLING_AVATAR, data) + return IO.NodeOutput(await download_url_to_video_output(result["video"]["url"])) # TODO: verify fal.ai field name class KlingExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_ltxv.py b/comfy_api_nodes/nodes_ltxv.py index 0a219af96e68..47dc9aaf6af3 100644 --- a/comfy_api_nodes/nodes_ltxv.py +++ b/comfy_api_nodes/nodes_ltxv.py @@ -8,9 +8,11 @@ ApiEndpoint, get_number_of_images, sync_op_raw, - upload_images_to_comfyapi, validate_string, ) +from comfy_api_nodes.util.client import fal_run + +FAL_LTX_VIDEO = "fal-ai/ltx-video-v097" MODELS_MAP = { "LTX-2 (Pro)": "ltx-2-pro", @@ -81,12 +83,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=PRICE_BADGE, ) @classmethod @@ -104,21 +103,17 @@ async def execute( raise ValueError( "Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS." 
) - response = await sync_op_raw( - cls, - ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"), - data=ExecuteTaskRequest( - prompt=prompt, - model=MODELS_MAP[model], - duration=duration, - resolution=resolution, - fps=fps, - generate_audio=generate_audio, - ), - as_binary=True, - max_retries=1, - ) - return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response))) + result = await fal_run(cls, FAL_LTX_VIDEO, { + "prompt": prompt, + "model": MODELS_MAP[model], + "duration": duration, + "resolution": resolution, + "fps": fps, + "generate_audio": generate_audio, + }) # TODO: verify fal.ai field names + video_url = result["video"]["url"] + from comfy_api_nodes.util import download_url_to_video_output + return IO.NodeOutput(await download_url_to_video_output(video_url)) class ImageToVideoNode(IO.ComfyNode): @@ -159,12 +154,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=PRICE_BADGE, ) @classmethod @@ -185,22 +177,20 @@ async def execute( ) if get_number_of_images(image) != 1: raise ValueError("Currently only one input image is supported.") - response = await sync_op_raw( - cls, - ApiEndpoint("/proxy/ltx/v1/image-to-video", "POST"), - data=ExecuteTaskRequest( - image_uri=(await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0], - prompt=prompt, - model=MODELS_MAP[model], - duration=duration, - resolution=resolution, - fps=fps, - generate_audio=generate_audio, - ), - as_binary=True, - max_retries=1, - ) - return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response))) + from comfy_api_nodes.util.upload_helpers import upload_image_to_fal + image_url = await upload_image_to_fal(image[0] if len(image.shape) > 3 else image, "image/png") + result = await fal_run(cls, FAL_LTX_VIDEO, { + "image_uri": image_url, + "prompt": prompt, + "model": MODELS_MAP[model], + "duration": duration, + "resolution": resolution, + 
"fps": fps, + "generate_audio": generate_audio, + }) # TODO: verify fal.ai field names + video_url = result["video"]["url"] + from comfy_api_nodes.util import download_url_to_video_output + return IO.NodeOutput(await download_url_to_video_output(video_url)) class LtxvApiExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 9ed6cd299396..eaabd9a3f67d 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -30,9 +30,13 @@ download_url_to_video_output, poll_op, sync_op, - upload_images_to_comfyapi, + upload_images_to_fal, validate_string, ) +from comfy_api_nodes.util._helpers import get_fal_auth_header +from comfy_api_nodes.util.client import fal_run + +FAL_LUMA_RAY2 = "fal-ai/luma-dream-machine/ray-2" LUMA_T2V_AVERAGE_DURATION = 105 LUMA_I2V_AVERAGE_DURATION = 100 @@ -184,24 +188,9 @@ def define_schema(cls) -> IO.Schema: ], outputs=[IO.Image.Output()], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model"]), - expr=""" - ( - $m := widgets.model; - $contains($m,"photon-flash-1") - ? {"type":"usd","usd":0.0027} - : $contains($m,"photon-1") - ? 
{"type":"usd","usd":0.0104} - : {"type":"usd","usd":0.0246} - ) - """, - ), ) @classmethod @@ -228,36 +217,26 @@ async def execute( # handle character_ref images character_ref = None if character_image is not None: - download_urls = await upload_images_to_comfyapi(cls, character_image, max_images=4) + download_urls = await upload_images_to_fal(character_image, max_images=4) character_ref = LumaCharacterRef(identity0=LumaImageIdentity(images=download_urls)) - response_api = await sync_op( - cls, - ApiEndpoint(path="/proxy/luma/generations/image", method="POST"), - response_model=LumaGeneration, - data=LumaImageGenerationRequest( - prompt=prompt, - model=model, - aspect_ratio=aspect_ratio, - image_ref=api_image_ref, - style_ref=api_style_ref, - character_ref=character_ref, - ), - ) - response_poll = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"), - response_model=LumaGeneration, - status_extractor=lambda x: x.state, - ) - return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image)) + data = { + "prompt": prompt, + "model": model, + "aspect_ratio": aspect_ratio, + "image_ref": api_image_ref.model_dump() if api_image_ref and hasattr(api_image_ref, 'model_dump') else api_image_ref, + "style_ref": api_style_ref.model_dump() if api_style_ref and hasattr(api_style_ref, 'model_dump') else api_style_ref, + "character_ref": character_ref.model_dump() if character_ref and hasattr(character_ref, 'model_dump') else character_ref, + } + result = await fal_run(cls, FAL_LUMA_RAY2, data) # TODO: verify fal.ai field names; use correct fal model for luma image + return IO.NodeOutput(await download_url_to_image_tensor(result["image"]["url"])) @classmethod async def _convert_luma_refs(cls, luma_ref: LumaReferenceChain, max_refs: int): luma_urls = [] ref_count = 0 for ref in luma_ref.refs: - download_urls = await upload_images_to_comfyapi(cls, ref.image, max_images=1) + download_urls = await upload_images_to_fal(ref.image, 
max_images=1) luma_urls.append(download_urls[0]) ref_count += 1 if ref_count >= max_refs: @@ -311,24 +290,9 @@ def define_schema(cls) -> IO.Schema: ], outputs=[IO.Image.Output()], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model"]), - expr=""" - ( - $m := widgets.model; - $contains($m,"photon-flash-1") - ? {"type":"usd","usd":0.0027} - : $contains($m,"photon-1") - ? {"type":"usd","usd":0.0104} - : {"type":"usd","usd":0.0246} - ) - """, - ), ) @classmethod @@ -340,27 +304,17 @@ async def execute( image_weight: float, seed, ) -> IO.NodeOutput: - download_urls = await upload_images_to_comfyapi(cls, image, max_images=1) + download_urls = await upload_images_to_fal(image, max_images=1) image_url = download_urls[0] - response_api = await sync_op( - cls, - ApiEndpoint(path="/proxy/luma/generations/image", method="POST"), - response_model=LumaGeneration, - data=LumaImageGenerationRequest( - prompt=prompt, - model=model, - modify_image_ref=LumaModifyImageRef( - url=image_url, weight=round(max(min(1.0 - image_weight, 0.98), 0.0), 2) - ), - ), - ) - response_poll = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"), - response_model=LumaGeneration, - status_extractor=lambda x: x.state, - ) - return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image)) + result = await fal_run(cls, FAL_LUMA_RAY2, { + "prompt": prompt, + "model": model, + "modify_image_ref": { + "url": image_url, + "weight": round(max(min(1.0 - image_weight, 0.98), 0.0), 2), + }, + }) # TODO: verify fal.ai field names; use correct fal model for luma image modify + return IO.NodeOutput(await download_url_to_image_tensor(result["image"]["url"])) class LumaTextToVideoGenerationNode(IO.ComfyNode): @@ -416,12 +370,9 @@ def define_schema(cls) -> IO.Schema: ], outputs=[IO.Video.Output()], hidden=[ - 
IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -440,28 +391,16 @@ async def execute( duration = duration if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None - response_api = await sync_op( - cls, - ApiEndpoint(path="/proxy/luma/generations", method="POST"), - response_model=LumaGeneration, - data=LumaGenerationRequest( - prompt=prompt, - model=model, - resolution=resolution, - aspect_ratio=aspect_ratio, - duration=duration, - loop=loop, - concepts=luma_concepts.create_api_model() if luma_concepts else None, - ), - ) - response_poll = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"), - response_model=LumaGeneration, - status_extractor=lambda x: x.state, - estimated_duration=LUMA_T2V_AVERAGE_DURATION, - ) - return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video)) + result = await fal_run(cls, FAL_LUMA_RAY2, { + "prompt": prompt, + "model": model, + "resolution": resolution, + "aspect_ratio": aspect_ratio, + "duration": duration, + "loop": loop, + "concepts": luma_concepts.create_api_model() if luma_concepts else None, + }, estimated_duration=LUMA_T2V_AVERAGE_DURATION) # TODO: verify fal.ai field names + return IO.NodeOutput(await download_url_to_video_output(result["video"]["url"])) class LumaImageToVideoGenerationNode(IO.ComfyNode): @@ -527,12 +466,9 @@ def define_schema(cls) -> IO.Schema: ], outputs=[IO.Video.Output()], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=PRICE_BADGE_VIDEO, ) @@ -554,29 +490,17 @@ async def execute( keyframes = await cls._convert_to_keyframes(first_image, last_image) duration = duration if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None - response_api = await 
sync_op( - cls, - ApiEndpoint(path="/proxy/luma/generations", method="POST"), - response_model=LumaGeneration, - data=LumaGenerationRequest( - prompt=prompt, - model=model, - aspect_ratio=LumaAspectRatio.ratio_16_9, # ignored, but still needed by the API for some reason - resolution=resolution, - duration=duration, - loop=loop, - keyframes=keyframes, - concepts=luma_concepts.create_api_model() if luma_concepts else None, - ), - ) - response_poll = await poll_op( - cls, - poll_endpoint=ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"), - response_model=LumaGeneration, - status_extractor=lambda x: x.state, - estimated_duration=LUMA_I2V_AVERAGE_DURATION, - ) - return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video)) + result = await fal_run(cls, FAL_LUMA_RAY2, { + "prompt": prompt, + "model": model, + "aspect_ratio": LumaAspectRatio.ratio_16_9.value if hasattr(LumaAspectRatio.ratio_16_9, 'value') else str(LumaAspectRatio.ratio_16_9), + "resolution": resolution, + "duration": duration, + "loop": loop, + "keyframes": keyframes.model_dump() if keyframes and hasattr(keyframes, 'model_dump') else keyframes, + "concepts": luma_concepts.create_api_model() if luma_concepts else None, + }, estimated_duration=LUMA_I2V_AVERAGE_DURATION) # TODO: verify fal.ai field names + return IO.NodeOutput(await download_url_to_video_output(result["video"]["url"])) @classmethod async def _convert_to_keyframes( @@ -589,10 +513,10 @@ async def _convert_to_keyframes( frame0 = None frame1 = None if first_image is not None: - download_urls = await upload_images_to_comfyapi(cls, first_image, max_images=1) + download_urls = await upload_images_to_fal(first_image, max_images=1) frame0 = LumaImageReference(type="image", url=download_urls[0]) if last_image is not None: - download_urls = await upload_images_to_comfyapi(cls, last_image, max_images=1) + download_urls = await upload_images_to_fal(last_image, max_images=1) frame1 = LumaImageReference(type="image", 
url=download_urls[0]) return LumaKeyframes(frame0=frame0, frame1=frame1) diff --git a/comfy_api_nodes/nodes_magnific.py b/comfy_api_nodes/nodes_magnific.py deleted file mode 100644 index 0f53208d4f3b..000000000000 --- a/comfy_api_nodes/nodes_magnific.py +++ /dev/null @@ -1,945 +0,0 @@ -import math - -from typing_extensions import override - -from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis.magnific import ( - ImageRelightAdvancedSettingsRequest, - ImageRelightRequest, - ImageSkinEnhancerCreativeRequest, - ImageSkinEnhancerFaithfulRequest, - ImageSkinEnhancerFlexibleRequest, - ImageStyleTransferRequest, - ImageUpscalerCreativeRequest, - ImageUpscalerPrecisionV2Request, - InputAdvancedSettings, - InputPortraitMode, - InputSkinEnhancerMode, - TaskResponse, -) -from comfy_api_nodes.util import ( - ApiEndpoint, - download_url_to_image_tensor, - downscale_image_tensor, - get_image_dimensions, - get_number_of_images, - poll_op, - sync_op, - upload_images_to_comfyapi, - validate_image_aspect_ratio, - validate_image_dimensions, -) - -_EUR_TO_USD = 1.19 - - -def _tier_price_eur(megapixels: float) -> float: - """Price in EUR for a single Magnific upscaling step based on input megapixels.""" - if megapixels <= 1.3: - return 0.143 - if megapixels <= 3.0: - return 0.286 - if megapixels <= 6.4: - return 0.429 - return 1.716 - - -def _calculate_magnific_upscale_price_usd(width: int, height: int, scale: int) -> float: - """Calculate total Magnific upscale price in USD for given input dimensions and scale factor.""" - num_steps = int(math.log2(scale)) - total_eur = 0.0 - pixels = width * height - for _ in range(num_steps): - total_eur += _tier_price_eur(pixels / 1_000_000) - pixels *= 4 - return round(total_eur * _EUR_TO_USD, 2) - - -class MagnificImageUpscalerCreativeNode(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="MagnificImageUpscalerCreativeNode", - display_name="Magnific Image Upscale (Creative)", - 
category="api node/image/Magnific", - description="Prompt‑guided enhancement, stylization, and 2x/4x/8x/16x upscaling. " - "Maximum output: 25.3 megapixels.", - inputs=[ - IO.Image.Input("image"), - IO.String.Input("prompt", multiline=True, default=""), - IO.Combo.Input("scale_factor", options=["2x", "4x", "8x", "16x"]), - IO.Combo.Input( - "optimized_for", - options=[ - "standard", - "soft_portraits", - "hard_portraits", - "art_n_illustration", - "videogame_assets", - "nature_n_landscapes", - "films_n_photography", - "3d_renders", - "science_fiction_n_horror", - ], - ), - IO.Int.Input("creativity", min=-10, max=10, default=0, display_mode=IO.NumberDisplay.slider), - IO.Int.Input( - "hdr", - min=-10, - max=10, - default=0, - tooltip="The level of definition and detail.", - display_mode=IO.NumberDisplay.slider, - ), - IO.Int.Input( - "resemblance", - min=-10, - max=10, - default=0, - tooltip="The level of resemblance to the original image.", - display_mode=IO.NumberDisplay.slider, - ), - IO.Int.Input( - "fractality", - min=-10, - max=10, - default=0, - tooltip="The strength of the prompt and intricacy per square pixel.", - display_mode=IO.NumberDisplay.slider, - ), - IO.Combo.Input( - "engine", - options=["automatic", "magnific_illusio", "magnific_sharpy", "magnific_sparkle"], - advanced=True, - ), - IO.Boolean.Input( - "auto_downscale", - default=False, - tooltip="Automatically downscale input image if output would exceed maximum pixel limit.", - advanced=True, - ), - ], - outputs=[ - IO.Image.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["scale_factor", "auto_downscale"]), - expr=""" - ( - $ad := widgets.auto_downscale; - $mins := $ad - ? 
{"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.515} - : {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844}; - $maxs := {"2x": 0.515, "4x": 0.844, "8x": 1.015, "16x": 1.187}; - { - "type": "range_usd", - "min_usd": $lookup($mins, widgets.scale_factor), - "max_usd": $lookup($maxs, widgets.scale_factor), - "format": { "approximate": true } - } - ) - """, - ), - ) - - @classmethod - async def execute( - cls, - image: Input.Image, - prompt: str, - scale_factor: str, - optimized_for: str, - creativity: int, - hdr: int, - resemblance: int, - fractality: int, - engine: str, - auto_downscale: bool, - ) -> IO.NodeOutput: - if get_number_of_images(image) != 1: - raise ValueError("Exactly one input image is required.") - validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) - validate_image_dimensions(image, min_height=160, min_width=160) - - max_output_pixels = 25_300_000 - height, width = get_image_dimensions(image) - requested_scale = int(scale_factor.rstrip("x")) - output_pixels = height * width * requested_scale * requested_scale - - if output_pixels > max_output_pixels: - if auto_downscale: - # Find optimal scale factor that doesn't require >2x downscale. - # Server upscales in 2x steps, so aggressive downscaling degrades quality. 
- input_pixels = width * height - scale = 2 - max_input_pixels = max_output_pixels // 4 - for candidate in [16, 8, 4, 2]: - if candidate > requested_scale: - continue - scale_output_pixels = input_pixels * candidate * candidate - if scale_output_pixels <= max_output_pixels: - scale = candidate - max_input_pixels = None - break - downscale_ratio = math.sqrt(scale_output_pixels / max_output_pixels) - if downscale_ratio <= 2.0: - scale = candidate - max_input_pixels = max_output_pixels // (candidate * candidate) - break - - if max_input_pixels is not None: - image = downscale_image_tensor(image, total_pixels=max_input_pixels) - scale_factor = f"{scale}x" - else: - raise ValueError( - f"Output size ({width * requested_scale}x{height * requested_scale} = {output_pixels:,} pixels) " - f"exceeds maximum allowed size of {max_output_pixels:,} pixels. " - f"Use a smaller input image or lower scale factor." - ) - - final_height, final_width = get_image_dimensions(image) - actual_scale = int(scale_factor.rstrip("x")) - price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, actual_scale) - - initial_res = await sync_op( - cls, - ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"), - response_model=TaskResponse, - data=ImageUpscalerCreativeRequest( - image=(await upload_images_to_comfyapi(cls, image, max_images=1, total_pixels=None))[0], - scale_factor=scale_factor, - optimized_for=optimized_for, - creativity=creativity, - hdr=hdr, - resemblance=resemblance, - fractality=fractality, - engine=engine, - prompt=prompt if prompt else None, - ), - ) - final_response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"), - response_model=TaskResponse, - status_extractor=lambda x: x.status, - price_extractor=lambda _: price_usd, - poll_interval=10.0, - max_poll_attempts=480, - ) - return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) - - -class 
MagnificImageUpscalerPreciseV2Node(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="MagnificImageUpscalerPreciseV2Node", - display_name="Magnific Image Upscale (Precise V2)", - category="api node/image/Magnific", - description="High-fidelity upscaling with fine control over sharpness, grain, and detail. " - "Maximum output: 10060×10060 pixels.", - inputs=[ - IO.Image.Input("image"), - IO.Combo.Input("scale_factor", options=["2x", "4x", "8x", "16x"]), - IO.Combo.Input( - "flavor", - options=["sublime", "photo", "photo_denoiser"], - tooltip="Processing style: " - "sublime for general use, photo for photographs, photo_denoiser for noisy photos.", - ), - IO.Int.Input( - "sharpen", - min=0, - max=100, - default=7, - tooltip="Image sharpness intensity. Higher values increase edge definition and clarity.", - display_mode=IO.NumberDisplay.slider, - ), - IO.Int.Input( - "smart_grain", - min=0, - max=100, - default=7, - tooltip="Intelligent grain/texture enhancement to prevent the image from " - "looking too smooth or artificial.", - display_mode=IO.NumberDisplay.slider, - ), - IO.Int.Input( - "ultra_detail", - min=0, - max=100, - default=30, - tooltip="Controls fine detail, textures, and micro-details added during upscaling.", - display_mode=IO.NumberDisplay.slider, - ), - IO.Boolean.Input( - "auto_downscale", - default=False, - tooltip="Automatically downscale input image if output would exceed maximum resolution.", - advanced=True, - ), - ], - outputs=[ - IO.Image.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]), - expr=""" - ( - $mins := {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844}; - $maxs := {"2x": 2.045, "4x": 2.545, "8x": 2.889, "16x": 3.06}; - { - "type": "range_usd", - "min_usd": $lookup($mins, widgets.scale_factor), - "max_usd": 
$lookup($maxs, widgets.scale_factor), - "format": { "approximate": true } - } - ) - """, - ), - ) - - @classmethod - async def execute( - cls, - image: Input.Image, - scale_factor: str, - flavor: str, - sharpen: int, - smart_grain: int, - ultra_detail: int, - auto_downscale: bool, - ) -> IO.NodeOutput: - if get_number_of_images(image) != 1: - raise ValueError("Exactly one input image is required.") - validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) - validate_image_dimensions(image, min_height=160, min_width=160) - - max_output_dimension = 10060 - height, width = get_image_dimensions(image) - requested_scale = int(scale_factor.strip("x")) - output_width = width * requested_scale - output_height = height * requested_scale - - if output_width > max_output_dimension or output_height > max_output_dimension: - if auto_downscale: - # Find optimal scale factor that doesn't require >2x downscale. - # Server upscales in 2x steps, so aggressive downscaling degrades quality. - max_dim = max(width, height) - scale = 2 - max_input_dim = max_output_dimension // 2 - scale_ratio = max_input_dim / max_dim - max_input_pixels = int(width * height * scale_ratio * scale_ratio) - for candidate in [16, 8, 4, 2]: - if candidate > requested_scale: - continue - output_dim = max_dim * candidate - if output_dim <= max_output_dimension: - scale = candidate - max_input_pixels = None - break - downscale_ratio = output_dim / max_output_dimension - if downscale_ratio <= 2.0: - scale = candidate - max_input_dim = max_output_dimension // candidate - scale_ratio = max_input_dim / max_dim - max_input_pixels = int(width * height * scale_ratio * scale_ratio) - break - - if max_input_pixels is not None: - image = downscale_image_tensor(image, total_pixels=max_input_pixels) - requested_scale = scale - else: - raise ValueError( - f"Output dimensions ({output_width}x{output_height}) exceed maximum allowed " - f"resolution of {max_output_dimension}x{max_output_dimension} pixels. 
" - f"Use a smaller input image or lower scale factor." - ) - - final_height, final_width = get_image_dimensions(image) - price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, requested_scale) - - initial_res = await sync_op( - cls, - ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"), - response_model=TaskResponse, - data=ImageUpscalerPrecisionV2Request( - image=(await upload_images_to_comfyapi(cls, image, max_images=1, total_pixels=None))[0], - scale_factor=requested_scale, - flavor=flavor, - sharpen=sharpen, - smart_grain=smart_grain, - ultra_detail=ultra_detail, - ), - ) - final_response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"), - response_model=TaskResponse, - status_extractor=lambda x: x.status, - price_extractor=lambda _: price_usd, - poll_interval=10.0, - max_poll_attempts=480, - ) - return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) - - -class MagnificImageStyleTransferNode(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="MagnificImageStyleTransferNode", - display_name="Magnific Image Style Transfer", - category="api node/image/Magnific", - description="Transfer the style from a reference image to your input image.", - inputs=[ - IO.Image.Input("image", tooltip="The image to apply style transfer to."), - IO.Image.Input("reference_image", tooltip="The reference image to extract style from."), - IO.String.Input("prompt", multiline=True, default=""), - IO.Int.Input( - "style_strength", - min=0, - max=100, - default=100, - tooltip="Percentage of style strength.", - display_mode=IO.NumberDisplay.slider, - ), - IO.Int.Input( - "structure_strength", - min=0, - max=100, - default=50, - tooltip="Maintains the structure of the original image.", - display_mode=IO.NumberDisplay.slider, - ), - IO.Combo.Input( - "flavor", - options=["faithful", "gen_z", "psychedelia", 
"detaily", "clear", "donotstyle", "donotstyle_sharp"], - tooltip="Style transfer flavor.", - ), - IO.Combo.Input( - "engine", - options=[ - "balanced", - "definio", - "illusio", - "3d_cartoon", - "colorful_anime", - "caricature", - "real", - "super_real", - "softy", - ], - tooltip="Processing engine selection.", - advanced=True, - ), - IO.DynamicCombo.Input( - "portrait_mode", - options=[ - IO.DynamicCombo.Option("disabled", []), - IO.DynamicCombo.Option( - "enabled", - [ - IO.Combo.Input( - "portrait_style", - options=["standard", "pop", "super_pop"], - tooltip="Visual style applied to portrait images.", - ), - IO.Combo.Input( - "portrait_beautifier", - options=["none", "beautify_face", "beautify_face_max"], - tooltip="Facial beautification intensity on portraits.", - ), - ], - ), - ], - tooltip="Enable portrait mode for facial enhancements.", - ), - IO.Boolean.Input( - "fixed_generation", - default=True, - tooltip="When disabled, expect each generation to introduce a degree of randomness, " - "leading to more diverse outcomes.", - advanced=True, - ), - ], - outputs=[ - IO.Image.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.11}""", - ), - ) - - @classmethod - async def execute( - cls, - image: Input.Image, - reference_image: Input.Image, - prompt: str, - style_strength: int, - structure_strength: int, - flavor: str, - engine: str, - portrait_mode: InputPortraitMode, - fixed_generation: bool, - ) -> IO.NodeOutput: - if get_number_of_images(image) != 1: - raise ValueError("Exactly one input image is required.") - if get_number_of_images(reference_image) != 1: - raise ValueError("Exactly one reference image is required.") - validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) - validate_image_aspect_ratio(reference_image, (1, 3), (3, 1), strict=False) - validate_image_dimensions(image, min_height=160, 
min_width=160) - validate_image_dimensions(reference_image, min_height=160, min_width=160) - - is_portrait = portrait_mode["portrait_mode"] == "enabled" - portrait_style = portrait_mode.get("portrait_style", "standard") - portrait_beautifier = portrait_mode.get("portrait_beautifier", "none") - - uploaded_urls = await upload_images_to_comfyapi(cls, [image, reference_image], max_images=2) - - initial_res = await sync_op( - cls, - ApiEndpoint(path="/proxy/freepik/v1/ai/image-style-transfer", method="POST"), - response_model=TaskResponse, - data=ImageStyleTransferRequest( - image=uploaded_urls[0], - reference_image=uploaded_urls[1], - prompt=prompt if prompt else None, - style_strength=style_strength, - structure_strength=structure_strength, - is_portrait=is_portrait, - portrait_style=portrait_style if is_portrait else None, - portrait_beautifier=portrait_beautifier if is_portrait and portrait_beautifier != "none" else None, - flavor=flavor, - engine=engine, - fixed_generation=fixed_generation, - ), - ) - final_response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-style-transfer/{initial_res.task_id}"), - response_model=TaskResponse, - status_extractor=lambda x: x.status, - poll_interval=10.0, - max_poll_attempts=480, - ) - return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) - - -class MagnificImageRelightNode(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="MagnificImageRelightNode", - display_name="Magnific Image Relight", - category="api node/image/Magnific", - description="Relight an image with lighting adjustments and optional reference-based light transfer.", - inputs=[ - IO.Image.Input("image", tooltip="The image to relight."), - IO.String.Input( - "prompt", - multiline=True, - default="", - tooltip="Descriptive guidance for lighting. 
Supports emphasis notation (1-1.4).", - ), - IO.Int.Input( - "light_transfer_strength", - min=0, - max=100, - default=100, - tooltip="Intensity of light transfer application.", - display_mode=IO.NumberDisplay.slider, - ), - IO.Combo.Input( - "style", - options=[ - "standard", - "darker_but_realistic", - "clean", - "smooth", - "brighter", - "contrasted_n_hdr", - "just_composition", - ], - tooltip="Stylistic output preference.", - ), - IO.Boolean.Input( - "interpolate_from_original", - default=False, - tooltip="Restricts generation freedom to match original more closely.", - advanced=True, - ), - IO.Boolean.Input( - "change_background", - default=True, - tooltip="Modifies background based on prompt/reference.", - advanced=True, - ), - IO.Boolean.Input( - "preserve_details", - default=True, - tooltip="Maintains texture and fine details from original.", - advanced=True, - ), - IO.DynamicCombo.Input( - "advanced_settings", - options=[ - IO.DynamicCombo.Option("disabled", []), - IO.DynamicCombo.Option( - "enabled", - [ - IO.Int.Input( - "whites", - min=0, - max=100, - default=50, - tooltip="Adjusts the brightest tones in the image.", - display_mode=IO.NumberDisplay.slider, - ), - IO.Int.Input( - "blacks", - min=0, - max=100, - default=50, - tooltip="Adjusts the darkest tones in the image.", - display_mode=IO.NumberDisplay.slider, - ), - IO.Int.Input( - "brightness", - min=0, - max=100, - default=50, - tooltip="Overall brightness adjustment.", - display_mode=IO.NumberDisplay.slider, - ), - IO.Int.Input( - "contrast", - min=0, - max=100, - default=50, - tooltip="Contrast adjustment.", - display_mode=IO.NumberDisplay.slider, - ), - IO.Int.Input( - "saturation", - min=0, - max=100, - default=50, - tooltip="Color saturation adjustment.", - display_mode=IO.NumberDisplay.slider, - ), - IO.Combo.Input( - "engine", - options=[ - "automatic", - "balanced", - "cool", - "real", - "illusio", - "fairy", - "colorful_anime", - "hard_transform", - "softy", - ], - tooltip="Processing 
engine selection.", - ), - IO.Combo.Input( - "transfer_light_a", - options=["automatic", "low", "medium", "normal", "high", "high_on_faces"], - tooltip="The intensity of light transfer.", - ), - IO.Combo.Input( - "transfer_light_b", - options=[ - "automatic", - "composition", - "straight", - "smooth_in", - "smooth_out", - "smooth_both", - "reverse_both", - "soft_in", - "soft_out", - "soft_mid", - # "strong_mid", # Commented out because requests fail when this is set. - "style_shift", - "strong_shift", - ], - tooltip="Also modifies light transfer intensity. " - "Can be combined with the previous control for varied effects.", - ), - IO.Boolean.Input( - "fixed_generation", - default=True, - tooltip="Ensures consistent output with the same settings.", - ), - ], - ), - ], - tooltip="Fine-tuning options for advanced lighting control.", - ), - IO.Image.Input( - "reference_image", - optional=True, - tooltip="Optional reference image to transfer lighting from.", - ), - ], - outputs=[ - IO.Image.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.11}""", - ), - ) - - @classmethod - async def execute( - cls, - image: Input.Image, - prompt: str, - light_transfer_strength: int, - style: str, - interpolate_from_original: bool, - change_background: bool, - preserve_details: bool, - advanced_settings: InputAdvancedSettings, - reference_image: Input.Image | None = None, - ) -> IO.NodeOutput: - if get_number_of_images(image) != 1: - raise ValueError("Exactly one input image is required.") - if reference_image is not None and get_number_of_images(reference_image) != 1: - raise ValueError("Exactly one reference image is required.") - validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) - validate_image_dimensions(image, min_height=160, min_width=160) - if reference_image is not None: - 
validate_image_aspect_ratio(reference_image, (1, 3), (3, 1), strict=False) - validate_image_dimensions(reference_image, min_height=160, min_width=160) - - image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0] - reference_url = None - if reference_image is not None: - reference_url = (await upload_images_to_comfyapi(cls, reference_image, max_images=1))[0] - - adv_settings = None - if advanced_settings["advanced_settings"] == "enabled": - adv_settings = ImageRelightAdvancedSettingsRequest( - whites=advanced_settings["whites"], - blacks=advanced_settings["blacks"], - brightness=advanced_settings["brightness"], - contrast=advanced_settings["contrast"], - saturation=advanced_settings["saturation"], - engine=advanced_settings["engine"], - transfer_light_a=advanced_settings["transfer_light_a"], - transfer_light_b=advanced_settings["transfer_light_b"], - fixed_generation=advanced_settings["fixed_generation"], - ) - - initial_res = await sync_op( - cls, - ApiEndpoint(path="/proxy/freepik/v1/ai/image-relight", method="POST"), - response_model=TaskResponse, - data=ImageRelightRequest( - image=image_url, - prompt=prompt if prompt else None, - transfer_light_from_reference_image=reference_url, - light_transfer_strength=light_transfer_strength, - interpolate_from_original=interpolate_from_original, - change_background=change_background, - style=style, - preserve_details=preserve_details, - advanced_settings=adv_settings, - ), - ) - final_response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-relight/{initial_res.task_id}"), - response_model=TaskResponse, - status_extractor=lambda x: x.status, - poll_interval=10.0, - max_poll_attempts=480, - ) - return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) - - -class MagnificImageSkinEnhancerNode(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="MagnificImageSkinEnhancerNode", - display_name="Magnific Image Skin Enhancer", 
- category="api node/image/Magnific", - description="Skin enhancement for portraits with multiple processing modes.", - inputs=[ - IO.Image.Input("image", tooltip="The portrait image to enhance."), - IO.Int.Input( - "sharpen", - min=0, - max=100, - default=0, - tooltip="Sharpening intensity level.", - display_mode=IO.NumberDisplay.slider, - ), - IO.Int.Input( - "smart_grain", - min=0, - max=100, - default=2, - tooltip="Smart grain intensity level.", - display_mode=IO.NumberDisplay.slider, - ), - IO.DynamicCombo.Input( - "mode", - options=[ - IO.DynamicCombo.Option("creative", []), - IO.DynamicCombo.Option( - "faithful", - [ - IO.Int.Input( - "skin_detail", - min=0, - max=100, - default=80, - tooltip="Skin detail enhancement level.", - display_mode=IO.NumberDisplay.slider, - ), - ], - ), - IO.DynamicCombo.Option( - "flexible", - [ - IO.Combo.Input( - "optimized_for", - options=[ - "enhance_skin", - "improve_lighting", - "enhance_everything", - "transform_to_real", - "no_make_up", - ], - tooltip="Enhancement optimization target.", - ), - ], - ), - ], - tooltip="Processing mode: creative for artistic enhancement, " - "faithful for preserving original appearance, " - "flexible for targeted optimization.", - ), - ], - outputs=[ - IO.Image.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["mode"]), - expr=""" - ( - $rates := {"creative": 0.29, "faithful": 0.37, "flexible": 0.45}; - {"type":"usd","usd": $lookup($rates, widgets.mode)} - ) - """, - ), - ) - - @classmethod - async def execute( - cls, - image: Input.Image, - sharpen: int, - smart_grain: int, - mode: InputSkinEnhancerMode, - ) -> IO.NodeOutput: - if get_number_of_images(image) != 1: - raise ValueError("Exactly one input image is required.") - validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) - validate_image_dimensions(image, 
min_height=160, min_width=160) - - image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, total_pixels=4096 * 4096))[0] - selected_mode = mode["mode"] - - if selected_mode == "creative": - endpoint = "creative" - data = ImageSkinEnhancerCreativeRequest( - image=image_url, - sharpen=sharpen, - smart_grain=smart_grain, - ) - elif selected_mode == "faithful": - endpoint = "faithful" - data = ImageSkinEnhancerFaithfulRequest( - image=image_url, - sharpen=sharpen, - smart_grain=smart_grain, - skin_detail=mode["skin_detail"], - ) - else: # flexible - endpoint = "flexible" - data = ImageSkinEnhancerFlexibleRequest( - image=image_url, - sharpen=sharpen, - smart_grain=smart_grain, - optimized_for=mode["optimized_for"], - ) - - initial_res = await sync_op( - cls, - ApiEndpoint(path=f"/proxy/freepik/v1/ai/skin-enhancer/{endpoint}", method="POST"), - response_model=TaskResponse, - data=data, - ) - final_response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/freepik/v1/ai/skin-enhancer/{initial_res.task_id}"), - response_model=TaskResponse, - status_extractor=lambda x: x.status, - poll_interval=10.0, - max_poll_attempts=480, - ) - return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) - - -class MagnificExtension(ComfyExtension): - @override - async def get_node_list(self) -> list[type[IO.ComfyNode]]: - return [ - MagnificImageUpscalerCreativeNode, - MagnificImageUpscalerPreciseV2Node, - MagnificImageStyleTransferNode, - MagnificImageRelightNode, - MagnificImageSkinEnhancerNode, - ] - - -async def comfy_entrypoint() -> MagnificExtension: - return MagnificExtension() diff --git a/comfy_api_nodes/nodes_meshy.py b/comfy_api_nodes/nodes_meshy.py index 3cf577f4a0c0..ce8de6645803 100644 --- a/comfy_api_nodes/nodes_meshy.py +++ b/comfy_api_nodes/nodes_meshy.py @@ -4,26 +4,15 @@ from comfy_api_nodes.apis.meshy import ( InputShouldRemesh, InputShouldTexture, - MeshyAnimationRequest, - MeshyAnimationResult, - MeshyImageToModelRequest, 
- MeshyModelResult, - MeshyMultiImageToModelRequest, - MeshyRefineTask, - MeshyRiggedResult, - MeshyRiggingRequest, - MeshyTaskResponse, - MeshyTextToModelRequest, - MeshyTextureRequest, ) from comfy_api_nodes.util import ( - ApiEndpoint, download_url_to_file_3d, - poll_op, - sync_op, - upload_images_to_comfyapi, validate_string, ) +from comfy_api_nodes.util.client import fal_run +from comfy_api_nodes.util.upload_helpers import upload_image_to_fal + +FAL_MESHY_I2_3D = "fal-ai/meshy/v6/image-to-3d" class MeshyTextToModelNode(IO.ComfyNode): @@ -83,15 +72,10 @@ def define_schema(cls): IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, is_output_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.8}""", - ), ) @classmethod @@ -106,35 +90,25 @@ async def execute( seed: int, ) -> IO.NodeOutput: validate_string(prompt, field_name="prompt", min_length=1, max_length=600) - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/meshy/openapi/v2/text-to-3d", method="POST"), - response_model=MeshyTaskResponse, - data=MeshyTextToModelRequest( - prompt=prompt, - art_style=style, - ai_model=model, - topology=should_remesh.get("topology", None), - target_polycount=should_remesh.get("target_polycount", None), - should_remesh=should_remesh["should_remesh"] == "true", - symmetry_mode=symmetry_mode, - pose_mode=pose_mode.lower(), - seed=seed, - ), - ) - task_id = response.result - result = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{task_id}"), - response_model=MeshyModelResult, - status_extractor=lambda r: r.status, - progress_extractor=lambda r: r.progress, - ) + result = await fal_run(cls, FAL_MESHY_I2_3D, { + "prompt": prompt, + "art_style": style, + "ai_model": model, + "topology": should_remesh.get("topology", None), + "target_polycount": should_remesh.get("target_polycount", None), + "should_remesh": 
should_remesh["should_remesh"] == "true", + "symmetry_mode": symmetry_mode, + "pose_mode": pose_mode.lower(), + "seed": seed, + }) + # TODO: verify fal.ai field names + task_id = result.get("id", "meshy_task") + model_urls = result["model_urls"] return IO.NodeOutput( f"{task_id}.glb", task_id, - await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id), - await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id), + await download_url_to_file_3d(model_urls["glb"], "glb", task_id=task_id), + await download_url_to_file_3d(model_urls["fbx"], "fbx", task_id=task_id), ) @@ -178,15 +152,10 @@ def define_schema(cls): IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, is_output_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.4}""", - ), ) @classmethod @@ -204,32 +173,22 @@ async def execute( if texture_prompt: validate_string(texture_prompt, field_name="texture_prompt", max_length=600) if texture_image is not None: - texture_image_url = (await upload_images_to_comfyapi(cls, texture_image, wait_label="Uploading texture"))[0] - response = await sync_op( - cls, - endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v2/text-to-3d", method="POST"), - response_model=MeshyTaskResponse, - data=MeshyRefineTask( - preview_task_id=meshy_task_id, - enable_pbr=enable_pbr, - texture_prompt=texture_prompt if texture_prompt else None, - texture_image_url=texture_image_url, - ai_model=model, - ), - ) - task_id = response.result - result = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{task_id}"), - response_model=MeshyModelResult, - status_extractor=lambda r: r.status, - progress_extractor=lambda r: r.progress, - ) + texture_image_url = await upload_image_to_fal(texture_image) + result = await fal_run(cls, FAL_MESHY_I2_3D, { + "preview_task_id": meshy_task_id, + "enable_pbr": enable_pbr, + 
"texture_prompt": texture_prompt if texture_prompt else None, + "texture_image_url": texture_image_url, + "ai_model": model, + }) + # TODO: verify fal.ai field names + task_id = result.get("id", "meshy_task") + model_urls = result["model_urls"] return IO.NodeOutput( f"{task_id}.glb", task_id, - await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id), - await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id), + await download_url_to_file_3d(model_urls["glb"], "glb", task_id=task_id), + await download_url_to_file_3d(model_urls["fbx"], "fbx", task_id=task_id), ) @@ -321,21 +280,10 @@ def define_schema(cls): IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, is_output_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["should_texture"]), - expr=""" - ( - $prices := {"true": 1.2, "false": 0.8}; - {"type":"usd","usd": $lookup($prices, widgets.should_texture)} - ) - """, - ), ) @classmethod @@ -358,43 +306,30 @@ async def execute( validate_string(should_texture["texture_prompt"], field_name="texture_prompt", max_length=600) texture_prompt = should_texture["texture_prompt"] if should_texture["texture_image"] is not None: - texture_image_url = ( - await upload_images_to_comfyapi( - cls, should_texture["texture_image"], wait_label="Uploading texture" - ) - )[0] - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/meshy/openapi/v1/image-to-3d", method="POST"), - response_model=MeshyTaskResponse, - data=MeshyImageToModelRequest( - image_url=(await upload_images_to_comfyapi(cls, image, wait_label="Uploading base image"))[0], - ai_model=model, - topology=should_remesh.get("topology", None), - target_polycount=should_remesh.get("target_polycount", None), - symmetry_mode=symmetry_mode, - should_remesh=should_remesh["should_remesh"] == "true", - should_texture=texture, - 
enable_pbr=should_texture.get("enable_pbr", None), - pose_mode=pose_mode.lower(), - texture_prompt=texture_prompt, - texture_image_url=texture_image_url, - seed=seed, - ), - ) - task_id = response.result - result = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/meshy/openapi/v1/image-to-3d/{task_id}"), - response_model=MeshyModelResult, - status_extractor=lambda r: r.status, - progress_extractor=lambda r: r.progress, - ) + texture_image_url = await upload_image_to_fal(should_texture["texture_image"]) + image_url = await upload_image_to_fal(image) + result = await fal_run(cls, FAL_MESHY_I2_3D, { + "image_url": image_url, + "ai_model": model, + "topology": should_remesh.get("topology", None), + "target_polycount": should_remesh.get("target_polycount", None), + "symmetry_mode": symmetry_mode, + "should_remesh": should_remesh["should_remesh"] == "true", + "should_texture": texture, + "enable_pbr": should_texture.get("enable_pbr", None), + "pose_mode": pose_mode.lower(), + "texture_prompt": texture_prompt, + "texture_image_url": texture_image_url, + "seed": seed, + }) + # TODO: verify fal.ai field names + task_id = result.get("id", "meshy_task") + model_urls = result["model_urls"] return IO.NodeOutput( f"{task_id}.glb", task_id, - await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id), - await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id), + await download_url_to_file_3d(model_urls["glb"], "glb", task_id=task_id), + await download_url_to_file_3d(model_urls["fbx"], "fbx", task_id=task_id), ) @@ -489,21 +424,10 @@ def define_schema(cls): IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, is_output_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["should_texture"]), - expr=""" - ( - $prices := {"true": 0.6, "false": 0.2}; - {"type":"usd","usd": $lookup($prices, 
widgets.should_texture)} - ) - """, - ), ) @classmethod @@ -526,45 +450,32 @@ async def execute( validate_string(should_texture["texture_prompt"], field_name="texture_prompt", max_length=600) texture_prompt = should_texture["texture_prompt"] if should_texture["texture_image"] is not None: - texture_image_url = ( - await upload_images_to_comfyapi( - cls, should_texture["texture_image"], wait_label="Uploading texture" - ) - )[0] - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/meshy/openapi/v1/multi-image-to-3d", method="POST"), - response_model=MeshyTaskResponse, - data=MeshyMultiImageToModelRequest( - image_urls=await upload_images_to_comfyapi( - cls, list(images.values()), wait_label="Uploading base images" - ), - ai_model=model, - topology=should_remesh.get("topology", None), - target_polycount=should_remesh.get("target_polycount", None), - symmetry_mode=symmetry_mode, - should_remesh=should_remesh["should_remesh"] == "true", - should_texture=texture, - enable_pbr=should_texture.get("enable_pbr", None), - pose_mode=pose_mode.lower(), - texture_prompt=texture_prompt, - texture_image_url=texture_image_url, - seed=seed, - ), - ) - task_id = response.result - result = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/meshy/openapi/v1/multi-image-to-3d/{task_id}"), - response_model=MeshyModelResult, - status_extractor=lambda r: r.status, - progress_extractor=lambda r: r.progress, - ) + texture_image_url = await upload_image_to_fal(should_texture["texture_image"]) + image_urls = [] + for img in list(images.values()): + image_urls.append(await upload_image_to_fal(img)) + result = await fal_run(cls, FAL_MESHY_I2_3D, { + "image_urls": image_urls, + "ai_model": model, + "topology": should_remesh.get("topology", None), + "target_polycount": should_remesh.get("target_polycount", None), + "symmetry_mode": symmetry_mode, + "should_remesh": should_remesh["should_remesh"] == "true", + "should_texture": texture, + "enable_pbr": should_texture.get("enable_pbr", None), 
+ "pose_mode": pose_mode.lower(), + "texture_prompt": texture_prompt, + "texture_image_url": texture_image_url, + "seed": seed, + }) + # TODO: verify fal.ai field names + task_id = result.get("id", "meshy_task") + model_urls = result["model_urls"] return IO.NodeOutput( f"{task_id}.glb", task_id, - await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id), - await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id), + await download_url_to_file_3d(model_urls["glb"], "glb", task_id=task_id), + await download_url_to_file_3d(model_urls["fbx"], "fbx", task_id=task_id), ) @@ -602,15 +513,10 @@ def define_schema(cls): IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, is_output_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.2}""", - ), ) @classmethod @@ -622,30 +528,20 @@ async def execute( ) -> IO.NodeOutput: texture_image_url = None if texture_image is not None: - texture_image_url = (await upload_images_to_comfyapi(cls, texture_image, wait_label="Uploading texture"))[0] - response = await sync_op( - cls, - endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v1/rigging", method="POST"), - response_model=MeshyTaskResponse, - data=MeshyRiggingRequest( - input_task_id=meshy_task_id, - height_meters=height_meters, - texture_image_url=texture_image_url, - ), - ) - task_id = response.result - result = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/meshy/openapi/v1/rigging/{task_id}"), - response_model=MeshyRiggedResult, - status_extractor=lambda r: r.status, - progress_extractor=lambda r: r.progress, - ) + texture_image_url = await upload_image_to_fal(texture_image) + result = await fal_run(cls, FAL_MESHY_I2_3D, { + "input_task_id": meshy_task_id, + "height_meters": height_meters, + "texture_image_url": texture_image_url, + }) + # TODO: verify fal.ai field names + task_id = result.get("id", "meshy_task") + 
rigged = result["result"] return IO.NodeOutput( f"{task_id}.glb", task_id, - await download_url_to_file_3d(result.result.rigged_character_glb_url, "glb", task_id=task_id), - await download_url_to_file_3d(result.result.rigged_character_fbx_url, "fbx", task_id=task_id), + await download_url_to_file_3d(rigged["rigged_character_glb_url"], "glb", task_id=task_id), + await download_url_to_file_3d(rigged["rigged_character_fbx_url"], "fbx", task_id=task_id), ) @@ -674,15 +570,10 @@ def define_schema(cls): IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, is_output_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.12}""", - ), ) @classmethod @@ -691,27 +582,17 @@ async def execute( rig_task_id: str, action_id: int, ) -> IO.NodeOutput: - response = await sync_op( - cls, - endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v1/animations", method="POST"), - response_model=MeshyTaskResponse, - data=MeshyAnimationRequest( - rig_task_id=rig_task_id, - action_id=action_id, - ), - ) - task_id = response.result - result = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/meshy/openapi/v1/animations/{task_id}"), - response_model=MeshyAnimationResult, - status_extractor=lambda r: r.status, - progress_extractor=lambda r: r.progress, - ) + result = await fal_run(cls, FAL_MESHY_I2_3D, { + "rig_task_id": rig_task_id, + "action_id": action_id, + }) + # TODO: verify fal.ai field names + task_id = result.get("id", "meshy_task") + anim = result["result"] return IO.NodeOutput( f"{task_id}.glb", - await download_url_to_file_3d(result.result.animation_glb_url, "glb", task_id=task_id), - await download_url_to_file_3d(result.result.animation_fbx_url, "fbx", task_id=task_id), + await download_url_to_file_3d(anim["animation_glb_url"], "glb", task_id=task_id), + await download_url_to_file_3d(anim["animation_fbx_url"], "fbx", task_id=task_id), ) @@ -756,15 +637,10 @@ def 
define_schema(cls): IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, is_output_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.4}""", - ), ) @classmethod @@ -783,33 +659,23 @@ async def execute( raise ValueError("Either text_style_prompt or image_style is required") image_style_url = None if image_style is not None: - image_style_url = (await upload_images_to_comfyapi(cls, image_style, wait_label="Uploading style"))[0] - response = await sync_op( - cls, - endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v1/retexture", method="POST"), - response_model=MeshyTaskResponse, - data=MeshyTextureRequest( - input_task_id=meshy_task_id, - ai_model=model, - enable_original_uv=enable_original_uv, - enable_pbr=pbr, - text_style_prompt=text_style_prompt if text_style_prompt else None, - image_style_url=image_style_url, - ), - ) - task_id = response.result - result = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/meshy/openapi/v1/retexture/{task_id}"), - response_model=MeshyModelResult, - status_extractor=lambda r: r.status, - progress_extractor=lambda r: r.progress, - ) + image_style_url = await upload_image_to_fal(image_style) + result = await fal_run(cls, FAL_MESHY_I2_3D, { + "input_task_id": meshy_task_id, + "ai_model": model, + "enable_original_uv": enable_original_uv, + "enable_pbr": pbr, + "text_style_prompt": text_style_prompt if text_style_prompt else None, + "image_style_url": image_style_url, + }) + # TODO: verify fal.ai field names + task_id = result.get("id", "meshy_task") + model_urls = result["model_urls"] return IO.NodeOutput( f"{task_id}.glb", task_id, - await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id), - await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id), + await download_url_to_file_3d(model_urls["glb"], "glb", task_id=task_id), + await 
download_url_to_file_3d(model_urls["fbx"], "fbx", task_id=task_id), ) diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index b5d0b461fffa..df6d96046642 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -17,9 +17,13 @@ download_url_to_video_output, poll_op, sync_op, - upload_images_to_comfyapi, + upload_images_to_fal, validate_string, ) +from comfy_api_nodes.util._helpers import get_fal_auth_header +from comfy_api_nodes.util.client import fal_run + +FAL_MINIMAX_VIDEO = "fal-ai/minimax/video-01-director" I2V_AVERAGE_DURATION = 114 T2V_AVERAGE_DURATION = 234 @@ -39,60 +43,26 @@ async def _generate_mm_video( validate_string(prompt_text, field_name="prompt_text") image_url = None if image is not None: - image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0] + image_url = (await upload_images_to_fal(image, max_images=1))[0] # TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model subject_reference = None if subject is not None: - subject_url = (await upload_images_to_comfyapi(cls, subject, max_images=1))[0] + subject_url = (await upload_images_to_fal(subject, max_images=1))[0] subject_reference = [SubjectReferenceItem(image=subject_url)] - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"), - response_model=MinimaxVideoGenerationResponse, - data=MinimaxVideoGenerationRequest( - model=MiniMaxModel(model), - prompt=prompt_text, - callback_url=None, - first_frame_image=image_url, - subject_reference=subject_reference, - prompt_optimizer=None, - ), - ) - - task_id = response.task_id - if not task_id: - raise Exception(f"MiniMax generation failed: {response.base_resp}") - - task_result = await poll_op( - cls, - ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}), - response_model=MinimaxTaskResultResponse, - status_extractor=lambda x: x.status.value, 
+ data = { + "model": model, + "prompt": prompt_text, + "first_frame_image": image_url, + "subject_reference": [s.model_dump() for s in subject_reference] if subject_reference else None, + } + result = await fal_run( + cls, FAL_MINIMAX_VIDEO, data, estimated_duration=average_duration, - ) - - file_id = task_result.file_id - if file_id is None: - raise Exception("Request was not successful. Missing file ID.") - file_result = await sync_op( - cls, - ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}), - response_model=MinimaxFileRetrieveResponse, - ) - - file_url = file_result.file.download_url - if file_url is None: - raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}") - if file_result.file.backup_download_url: - try: - return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2)) - except Exception: # if we have a second URL to retrieve the result, try again using that one - return IO.NodeOutput( - await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3) - ) - return IO.NodeOutput(await download_url_to_video_output(file_url)) + ) # TODO: verify fal.ai field names + video_url = result["video"]["url"] + return IO.NodeOutput(await download_url_to_video_output(video_url)) class MinimaxTextToVideoNode(IO.ComfyNode): @@ -129,14 +99,9 @@ def define_schema(cls) -> IO.Schema: ], outputs=[IO.Video.Output()], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.43}""", - ), ) @classmethod @@ -195,14 +160,9 @@ def define_schema(cls) -> IO.Schema: ], outputs=[IO.Video.Output()], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.43}""", - ), ) @classmethod @@ -262,8 +222,6 @@ def 
define_schema(cls) -> IO.Schema: ], outputs=[IO.Video.Output()], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, @@ -341,25 +299,9 @@ def define_schema(cls) -> IO.Schema: ], outputs=[IO.Video.Output()], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["resolution", "duration"]), - expr=""" - ( - $prices := { - "768p": {"6": 0.28, "10": 0.56}, - "1080p": {"6": 0.49} - }; - $resPrices := $lookup($prices, $lowercase(widgets.resolution)); - $price := $lookup($resPrices, $string(widgets.duration)); - {"type":"usd","usd": $price ? $price : 0.43} - ) - """, - ), ) @classmethod @@ -384,57 +326,23 @@ async def execute( # upload image, if passed in image_url = None if first_frame_image is not None: - image_url = (await upload_images_to_comfyapi(cls, first_frame_image, max_images=1))[0] - - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"), - response_model=MinimaxVideoGenerationResponse, - data=MinimaxVideoGenerationRequest( - model=MiniMaxModel(model), - prompt=prompt_text, - callback_url=None, - first_frame_image=image_url, - prompt_optimizer=prompt_optimizer, - duration=duration, - resolution=resolution, - ), - ) - - task_id = response.task_id - if not task_id: - raise Exception(f"MiniMax generation failed: {response.base_resp}") + image_url = (await upload_images_to_fal(first_frame_image, max_images=1))[0] average_duration = 120 if resolution == "768P" else 240 - task_result = await poll_op( - cls, - ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}), - response_model=MinimaxTaskResultResponse, - status_extractor=lambda x: x.status.value, + data = { + "model": model, + "prompt": prompt_text, + "first_frame_image": image_url, + "prompt_optimizer": prompt_optimizer, + 
"duration": duration, + "resolution": resolution, + } + result = await fal_run( + cls, FAL_MINIMAX_VIDEO, data, estimated_duration=average_duration, - ) - - file_id = task_result.file_id - if file_id is None: - raise Exception("Request was not successful. Missing file ID.") - file_result = await sync_op( - cls, - ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}), - response_model=MinimaxFileRetrieveResponse, - ) - - file_url = file_result.file.download_url - if file_url is None: - raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}") - - if file_result.file.backup_download_url: - try: - return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2)) - except Exception: # if we have a second URL to retrieve the result, try again using that one - return IO.NodeOutput( - await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3) - ) - return IO.NodeOutput(await download_url_to_video_output(file_url)) + ) # TODO: verify fal.ai field names + video_url = result["video"]["url"] + return IO.NodeOutput(await download_url_to_video_output(video_url)) class MinimaxExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py deleted file mode 100644 index 78a230529705..000000000000 --- a/comfy_api_nodes/nodes_moonvalley.py +++ /dev/null @@ -1,534 +0,0 @@ -import logging - -from typing_extensions import override - -from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis.moonvalley import ( - MoonvalleyPromptResponse, - MoonvalleyTextToVideoInferenceParams, - MoonvalleyTextToVideoRequest, - MoonvalleyVideoToVideoInferenceParams, - MoonvalleyVideoToVideoRequest, -) -from comfy_api_nodes.util import ( - ApiEndpoint, - download_url_to_video_output, - poll_op, - sync_op, - trim_video, - upload_images_to_comfyapi, - upload_video_to_comfyapi, - 
validate_container_format_is_mp4, - validate_image_dimensions, - validate_string, -) - -API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads" -API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts" -API_VIDEO2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/video-to-video" -API_TXT2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/text-to-video" -API_IMG2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/image-to-video" - -MIN_WIDTH = 300 -MIN_HEIGHT = 300 - -MAX_WIDTH = 10000 -MAX_HEIGHT = 10000 - -MIN_VID_WIDTH = 300 -MIN_VID_HEIGHT = 300 - -MAX_VID_WIDTH = 10000 -MAX_VID_HEIGHT = 10000 - -MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing - -MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000 - - -def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool: - """Verifies that the initial response contains a task ID.""" - return bool(response.id) - - -def validate_task_creation_response(response) -> None: - if not is_valid_task_creation_response(response): - error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}" - logging.error(error_msg) - raise RuntimeError(error_msg) - - -def validate_video_to_video_input(video: Input.Video) -> Input.Video: - """ - Validates and processes video input for Moonvalley Video-to-Video generation. 
- - Args: - video: Input video to validate - - Returns: - Validated and potentially trimmed video - - Raises: - ValueError: If video doesn't meet requirements - MoonvalleyApiError: If video duration is too short - """ - width, height = _get_video_dimensions(video) - _validate_video_dimensions(width, height) - validate_container_format_is_mp4(video) - - return _validate_and_trim_duration(video) - - -def _get_video_dimensions(video: Input.Video) -> tuple[int, int]: - """Extracts video dimensions with error handling.""" - try: - return video.get_dimensions() - except Exception as e: - logging.error("Error getting dimensions of video: %s", e) - raise ValueError(f"Cannot get video dimensions: {e}") from e - - -def _validate_video_dimensions(width: int, height: int) -> None: - """Validates video dimensions meet Moonvalley V2V requirements.""" - supported_resolutions = { - (1920, 1080), - (1080, 1920), - (1152, 1152), - (1536, 1152), - (1152, 1536), - } - - if (width, height) not in supported_resolutions: - supported_list = ", ".join([f"{w}x{h}" for w, h in sorted(supported_resolutions)]) - raise ValueError(f"Resolution {width}x{height} not supported. 
Supported: {supported_list}") - - -def _validate_and_trim_duration(video: Input.Video) -> Input.Video: - """Validates video duration and trims to 5 seconds if needed.""" - duration = video.get_duration() - _validate_minimum_duration(duration) - return _trim_if_too_long(video, duration) - - -def _validate_minimum_duration(duration: float) -> None: - """Ensures video is at least 5 seconds long.""" - if duration < 5: - raise ValueError("Input video must be at least 5 seconds long.") - - -def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video: - """Trims video to 5 seconds if longer.""" - if duration > 5: - return trim_video(video, 5) - return video - - -def parse_width_height_from_res(resolution: str): - # Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict - res_map = { - "16:9 (1920 x 1080)": {"width": 1920, "height": 1080}, - "9:16 (1080 x 1920)": {"width": 1080, "height": 1920}, - "1:1 (1152 x 1152)": {"width": 1152, "height": 1152}, - "4:3 (1536 x 1152)": {"width": 1536, "height": 1152}, - "3:4 (1152 x 1536)": {"width": 1152, "height": 1536}, - # "21:9 (2560 x 1080)": {"width": 2560, "height": 1080}, - } - return res_map.get(resolution, {"width": 1920, "height": 1080}) - - -def parse_control_parameter(value): - control_map = { - "Motion Transfer": "motion_control", - "Canny": "canny_control", - "Pose Transfer": "pose_control", - "Depth": "depth_control", - } - return control_map.get(value, control_map["Motion Transfer"]) - - -async def get_response(cls: type[IO.ComfyNode], task_id: str) -> MoonvalleyPromptResponse: - return await poll_op( - cls, - ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task_id}"), - response_model=MoonvalleyPromptResponse, - status_extractor=lambda r: (r.status if r and r.status else None), - poll_interval=16.0, - max_poll_attempts=240, - ) - - -class MoonvalleyImg2VideoNode(IO.ComfyNode): - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - 
node_id="MoonvalleyImg2VideoNode", - display_name="Moonvalley Marey Image to Video", - category="api node/video/Moonvalley Marey", - description="Moonvalley Marey Image to Video Node", - inputs=[ - IO.Image.Input( - "image", - tooltip="The reference image used to generate the video", - ), - IO.String.Input( - "prompt", - multiline=True, - ), - IO.String.Input( - "negative_prompt", - multiline=True, - default=" gopro, bright, contrast, static, overexposed, vignette, " - "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " - "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " - "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " - "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " - "wobbly, weird, low quality, plastic, stock footage, video camera, boring", - tooltip="Negative prompt text", - ), - IO.Combo.Input( - "resolution", - options=[ - "16:9 (1920 x 1080)", - "9:16 (1080 x 1920)", - "1:1 (1152 x 1152)", - "4:3 (1536 x 1152)", - "3:4 (1152 x 1536)", - # "21:9 (2560 x 1080)", - ], - default="16:9 (1920 x 1080)", - tooltip="Resolution of the output video", - ), - IO.Float.Input( - "prompt_adherence", - default=4.5, - min=1.0, - max=20.0, - step=1.0, - tooltip="Guidance scale for generation control", - ), - IO.Int.Input( - "seed", - default=9, - min=0, - max=4294967295, - step=1, - display_mode=IO.NumberDisplay.number, - tooltip="Random seed value", - control_after_generate=True, - ), - IO.Int.Input( - "steps", - default=80, - min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0) - max=100, - step=1, - tooltip="Number of denoising steps", - ), - ], - outputs=[IO.Video.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(), - expr="""{"type":"usd","usd": 
1.5}""", - ), - ) - - @classmethod - async def execute( - cls, - image: Input.Image, - prompt: str, - negative_prompt: str, - resolution: str, - prompt_adherence: float, - seed: int, - steps: int, - ) -> IO.NodeOutput: - validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH) - validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - width_height = parse_width_height_from_res(resolution) - - inference_params = MoonvalleyTextToVideoInferenceParams( - negative_prompt=negative_prompt, - steps=steps, - seed=seed, - guidance_scale=prompt_adherence, - width=width_height["width"], - height=width_height["height"], - use_negative_prompts=True, - ) - - # Get MIME type from tensor - assuming PNG format for image tensors - mime_type = "image/png" - image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0] - task_creation_response = await sync_op( - cls, - endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, method="POST"), - response_model=MoonvalleyPromptResponse, - data=MoonvalleyTextToVideoRequest( - image_url=image_url, prompt_text=prompt, inference_params=inference_params - ), - ) - validate_task_creation_response(task_creation_response) - final_response = await get_response(cls, task_creation_response.id) - video = await download_url_to_video_output(final_response.output_url) - return IO.NodeOutput(video) - - -class MoonvalleyVideo2VideoNode(IO.ComfyNode): - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="MoonvalleyVideo2VideoNode", - display_name="Moonvalley Marey Video to Video", - category="api node/video/Moonvalley Marey", - description="", - inputs=[ - IO.String.Input( - "prompt", - multiline=True, - tooltip="Describes the video to generate", - ), - IO.String.Input( - "negative_prompt", - multiline=True, 
- default=" gopro, bright, contrast, static, overexposed, vignette, " - "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " - "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " - "cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, " - "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " - "wobbly, weird, low quality, plastic, stock footage, video camera, boring", - tooltip="Negative prompt text", - ), - IO.Int.Input( - "seed", - default=9, - min=0, - max=4294967295, - step=1, - display_mode=IO.NumberDisplay.number, - tooltip="Random seed value", - control_after_generate=False, - ), - IO.Video.Input( - "video", - tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. " - "Videos longer than 5s will be automatically trimmed. Only MP4 format supported.", - ), - IO.Combo.Input( - "control_type", - options=["Motion Transfer", "Pose Transfer"], - default="Motion Transfer", - optional=True, - ), - IO.Int.Input( - "motion_intensity", - default=100, - min=0, - max=100, - step=1, - tooltip="Only used if control_type is 'Motion Transfer'", - optional=True, - ), - IO.Int.Input( - "steps", - default=60, - min=60, # steps should be greater or equal to cooldown_steps(36) + warmup_steps(24) - max=100, - step=1, - display_mode=IO.NumberDisplay.number, - tooltip="Number of inference steps", - ), - ], - outputs=[IO.Video.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(), - expr="""{"type":"usd","usd": 2.25}""", - ), - ) - - @classmethod - async def execute( - cls, - prompt: str, - negative_prompt: str, - seed: int, - video: Input.Video | None = None, - control_type: str = "Motion Transfer", - motion_intensity: int | None = 100, - steps=60, - prompt_adherence=4.5, 
- ) -> IO.NodeOutput: - validated_video = validate_video_to_video_input(video) - video_url = await upload_video_to_comfyapi(cls, validated_video) - validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - - # Only include motion_intensity for Motion Transfer - control_params = {} - if control_type == "Motion Transfer" and motion_intensity is not None: - control_params["motion_intensity"] = motion_intensity - - inference_params = MoonvalleyVideoToVideoInferenceParams( - negative_prompt=negative_prompt, - seed=seed, - control_params=control_params, - steps=steps, - guidance_scale=prompt_adherence, - ) - - task_creation_response = await sync_op( - cls, - endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT, method="POST"), - response_model=MoonvalleyPromptResponse, - data=MoonvalleyVideoToVideoRequest( - control_type=parse_control_parameter(control_type), - video_url=video_url, - prompt_text=prompt, - inference_params=inference_params, - ), - ) - validate_task_creation_response(task_creation_response) - final_response = await get_response(cls, task_creation_response.id) - return IO.NodeOutput(await download_url_to_video_output(final_response.output_url)) - - -class MoonvalleyTxt2VideoNode(IO.ComfyNode): - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="MoonvalleyTxt2VideoNode", - display_name="Moonvalley Marey Text to Video", - category="api node/video/Moonvalley Marey", - description="", - inputs=[ - IO.String.Input( - "prompt", - multiline=True, - ), - IO.String.Input( - "negative_prompt", - multiline=True, - default=" gopro, bright, contrast, static, overexposed, vignette, " - "artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, " - "flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, " - "cross dissolve, cheesy, ugly hands, mutated 
hands, mutant, disfigured, extra fingers, " - "blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, " - "wobbly, weird, low quality, plastic, stock footage, video camera, boring", - tooltip="Negative prompt text", - ), - IO.Combo.Input( - "resolution", - options=[ - "16:9 (1920 x 1080)", - "9:16 (1080 x 1920)", - "1:1 (1152 x 1152)", - "4:3 (1536 x 1152)", - "3:4 (1152 x 1536)", - "21:9 (2560 x 1080)", - ], - default="16:9 (1920 x 1080)", - tooltip="Resolution of the output video", - ), - IO.Float.Input( - "prompt_adherence", - default=4.0, - min=1.0, - max=20.0, - step=1.0, - tooltip="Guidance scale for generation control", - ), - IO.Int.Input( - "seed", - default=9, - min=0, - max=4294967295, - step=1, - display_mode=IO.NumberDisplay.number, - control_after_generate=True, - tooltip="Random seed value", - ), - IO.Int.Input( - "steps", - default=80, - min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0) - max=100, - step=1, - tooltip="Inference steps", - ), - ], - outputs=[IO.Video.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(), - expr="""{"type":"usd","usd": 1.5}""", - ), - ) - - @classmethod - async def execute( - cls, - prompt: str, - negative_prompt: str, - resolution: str, - prompt_adherence: float, - seed: int, - steps: int, - ) -> IO.NodeOutput: - validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) - width_height = parse_width_height_from_res(resolution) - - inference_params = MoonvalleyTextToVideoInferenceParams( - negative_prompt=negative_prompt, - steps=steps, - seed=seed, - guidance_scale=prompt_adherence, - num_frames=128, - width=width_height["width"], - height=width_height["height"], - ) - - 
task_creation_response = await sync_op( - cls, - endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"), - response_model=MoonvalleyPromptResponse, - data=MoonvalleyTextToVideoRequest(prompt_text=prompt, inference_params=inference_params), - ) - validate_task_creation_response(task_creation_response) - final_response = await get_response(cls, task_creation_response.id) - return IO.NodeOutput(await download_url_to_video_output(final_response.output_url)) - - -class MoonvalleyExtension(ComfyExtension): - @override - async def get_node_list(self) -> list[type[IO.ComfyNode]]: - return [ - MoonvalleyImg2VideoNode, - MoonvalleyTxt2VideoNode, - MoonvalleyVideo2VideoNode, - ] - - -async def comfy_entrypoint() -> MoonvalleyExtension: - return MoonvalleyExtension() diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index 4ee896fa8afe..1061de1f131d 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -24,17 +24,16 @@ OutputContent, ) from comfy_api_nodes.util import ( - ApiEndpoint, download_url_to_bytesio, downscale_image_tensor, - poll_op, - sync_op, tensor_to_base64_string, text_filepath_to_data_uri, validate_string, ) +from comfy_api_nodes.util.client import fal_run + +FAL_GPT_IMAGE_1 = "fal-ai/gpt-image-1/text-to-image" -RESPONSES_ENDPOINT = "/proxy/openai/v1/responses" STARTING_POINT_ID_PATTERN = r"" @@ -148,28 +147,9 @@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["size", "n"]), - expr=""" - ( - $size := widgets.size; - $nRaw := widgets.n; - $n := ($nRaw != null and $nRaw != 0) ? $nRaw : 1; - - $base := - $contains($size, "256x256") ? 0.016 : - $contains($size, "512x512") ? 
0.018 : - 0.02; - - {"type":"usd","usd": $round($base * $n, 3)} - ) - """, - ), ) @classmethod @@ -183,17 +163,16 @@ async def execute( size="1024x1024", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) - model = "dall-e-2" - path = "/proxy/openai/images/generations" - content_type = "application/json" - request_class = OpenAIImageGenerationRequest - img_binary = None + data = { + "model": "dall-e-2", + "prompt": prompt, + "n": n, + "size": size, + "seed": seed, + } if image is not None and mask is not None: - path = "/proxy/openai/images/edits" - content_type = "multipart/form-data" - request_class = OpenAIImageEditRequest - + from comfy_api_nodes.util.upload_helpers import upload_image_to_fal input_tensor = image.squeeze().cpu() height, width, channels = input_tensor.shape rgba_tensor = torch.ones(height, width, 4, device="cpu") @@ -210,33 +189,22 @@ async def execute( img_byte_arr = BytesIO() img.save(img_byte_arr, format="PNG") img_byte_arr.seek(0) - img_binary = img_byte_arr # .getvalue() - img_binary.name = "image.png" + from comfy_api_nodes.util.upload_helpers import upload_file_to_fal + data["image_url"] = await upload_file_to_fal(img_byte_arr, "image/png") elif image is not None or mask is not None: raise Exception("Dall-E 2 image editing requires an image AND a mask") - response = await sync_op( - cls, - ApiEndpoint(path=path, method="POST"), - response_model=OpenAIImageGenerationResponse, - data=request_class( - model=model, - prompt=prompt, - n=n, - size=size, - seed=seed, - ), - files=( - { - "image": ("image.png", img_binary, "image/png"), - } - if img_binary - else None - ), - content_type=content_type, - ) - - return IO.NodeOutput(await validate_and_cast_response(response)) + result = await fal_run(cls, FAL_GPT_IMAGE_1, data) # TODO: verify fal.ai field names; use correct fal model for DALL-E 2 + # Parse fal.ai response to match expected format + image_tensors = [] + for img_data in result.get("images", []): + img_url = 
img_data["url"] + img_io = BytesIO() + await download_url_to_bytesio(img_url, img_io) + pil_img = Image.open(img_io).convert("RGBA") + arr = np.asarray(pil_img).astype(np.float32) / 255.0 + image_tensors.append(torch.from_numpy(arr)) + return IO.NodeOutput(torch.stack(image_tensors, dim=0)) class OpenAIDalle3(IO.ComfyNode): @@ -292,30 +260,9 @@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["size", "quality"]), - expr=""" - ( - $size := widgets.size; - $q := widgets.quality; - $hd := $contains($q, "hd"); - - $price := - $contains($size, "1024x1024") - ? ($hd ? 0.08 : 0.04) - : (($contains($size, "1792x1024") or $contains($size, "1024x1792")) - ? ($hd ? 0.12 : 0.08) - : 0.04); - - {"type":"usd","usd": $price} - ) - """, - ), ) @classmethod @@ -328,24 +275,23 @@ async def execute( size="1024x1024", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) - model = "dall-e-3" - - # build the operation - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/openai/images/generations", method="POST"), - response_model=OpenAIImageGenerationResponse, - data=OpenAIImageGenerationRequest( - model=model, - prompt=prompt, - quality=quality, - size=size, - style=style, - seed=seed, - ), - ) - - return IO.NodeOutput(await validate_and_cast_response(response)) + result = await fal_run(cls, FAL_GPT_IMAGE_1, { + "model": "dall-e-3", + "prompt": prompt, + "quality": quality, + "size": size, + "style": style, + "seed": seed, + }) # TODO: verify fal.ai field names; use correct fal model for DALL-E 3 + image_tensors = [] + for img_data in result.get("images", []): + img_url = img_data["url"] + img_io = BytesIO() + await download_url_to_bytesio(img_url, img_io) + pil_img = Image.open(img_io).convert("RGBA") + arr = np.asarray(pil_img).astype(np.float32) / 255.0 + 
image_tensors.append(torch.from_numpy(arr)) + return IO.NodeOutput(torch.stack(image_tensors, dim=0)) def calculate_tokens_price_image_1(response: OpenAIImageGenerationResponse) -> float | None: @@ -436,33 +382,9 @@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["quality", "n"]), - expr=""" - ( - $ranges := { - "low": [0.011, 0.02], - "medium": [0.046, 0.07], - "high": [0.167, 0.3] - }; - $range := $lookup($ranges, widgets.quality); - $n := widgets.n; - ($n = 1) - ? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1]} - : { - "type":"range_usd", - "min_usd": $range[0], - "max_usd": $range[1], - "format": { "suffix": " x " & $string($n) & "/Run" } - } - ) - """, - ), ) @classmethod @@ -490,23 +412,30 @@ async def execute( else: raise ValueError(f"Unknown model: {model}") + data = { + "model": model, + "prompt": prompt, + "quality": quality, + "background": background, + "n": n, + "seed": seed, + "size": size, + "moderation": "low", + } if image is not None: - files = [] + from comfy_api_nodes.util.upload_helpers import upload_image_to_fal, upload_file_to_fal + image_urls = [] batch_size = image.shape[0] for i in range(batch_size): single_image = image[i : i + 1] scaled_image = downscale_image_tensor(single_image, total_pixels=2048 * 2048).squeeze() - image_np = (scaled_image.numpy() * 255).astype(np.uint8) img = Image.fromarray(image_np) img_byte_arr = BytesIO() img.save(img_byte_arr, format="PNG") img_byte_arr.seek(0) - - if batch_size == 1: - files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png"))) - else: - files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png"))) + image_urls.append(await upload_file_to_fal(img_byte_arr, "image/png")) + data["image_urls"] = image_urls if mask is not None: if image.shape[0] != 1: @@ -516,52 +445,24 @@ async 
def execute( _, height, width = mask.shape rgba_mask = torch.zeros(height, width, 4, device="cpu") rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu() - scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048 * 2048).squeeze() - mask_np = (scaled_mask.numpy() * 255).astype(np.uint8) mask_img = Image.fromarray(mask_np) mask_img_byte_arr = BytesIO() mask_img.save(mask_img_byte_arr, format="PNG") mask_img_byte_arr.seek(0) - files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png"))) - - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/openai/images/edits", method="POST"), - response_model=OpenAIImageGenerationResponse, - data=OpenAIImageEditRequest( - model=model, - prompt=prompt, - quality=quality, - background=background, - n=n, - seed=seed, - size=size, - moderation="low", - ), - content_type="multipart/form-data", - files=files, - price_extractor=price_extractor, - ) - else: - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/openai/images/generations", method="POST"), - response_model=OpenAIImageGenerationResponse, - data=OpenAIImageGenerationRequest( - model=model, - prompt=prompt, - quality=quality, - background=background, - n=n, - seed=seed, - size=size, - moderation="low", - ), - price_extractor=price_extractor, - ) - return IO.NodeOutput(await validate_and_cast_response(response)) + data["mask_url"] = await upload_file_to_fal(mask_img_byte_arr, "image/png") + + result = await fal_run(cls, FAL_GPT_IMAGE_1, data) # TODO: verify fal.ai field names + image_tensors = [] + for img_data in result.get("images", []): + img_url = img_data["url"] + img_io = BytesIO() + await download_url_to_bytesio(img_url, img_io) + pil_img = Image.open(img_io).convert("RGBA") + arr = np.asarray(pil_img).astype(np.float32) / 255.0 + image_tensors.append(torch.from_numpy(arr)) + return IO.NodeOutput(torch.stack(image_tensors, dim=0)) class OpenAIChatNode(IO.ComfyNode): @@ -615,75 +516,9 @@ def define_schema(cls): IO.String.Output(), ], 
hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model"]), - expr=""" - ( - $m := widgets.model; - $contains($m, "o4-mini") ? { - "type": "list_usd", - "usd": [0.0011, 0.0044], - "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } - } - : $contains($m, "o1-pro") ? { - "type": "list_usd", - "usd": [0.15, 0.6], - "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } - } - : $contains($m, "o1") ? { - "type": "list_usd", - "usd": [0.015, 0.06], - "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } - } - : $contains($m, "o3-mini") ? { - "type": "list_usd", - "usd": [0.0011, 0.0044], - "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } - } - : $contains($m, "o3") ? { - "type": "list_usd", - "usd": [0.01, 0.04], - "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } - } - : $contains($m, "gpt-4.1-nano") ? { - "type": "list_usd", - "usd": [0.0001, 0.0004], - "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } - } - : $contains($m, "gpt-4.1-mini") ? { - "type": "list_usd", - "usd": [0.0004, 0.0016], - "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } - } - : $contains($m, "gpt-4.1") ? { - "type": "list_usd", - "usd": [0.002, 0.008], - "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } - } - : $contains($m, "gpt-5-nano") ? { - "type": "list_usd", - "usd": [0.00005, 0.0004], - "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } - } - : $contains($m, "gpt-5-mini") ? { - "type": "list_usd", - "usd": [0.00025, 0.002], - "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } - } - : $contains($m, "gpt-5") ? 
{ - "type": "list_usd", - "usd": [0.00125, 0.01], - "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } - } - : {"type": "text", "text": "Token-based"} - ) - """, - ), ) @classmethod @@ -747,36 +582,30 @@ async def execute( ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) - # Create response - create_response = await sync_op( - cls, - ApiEndpoint(path=RESPONSES_ENDPOINT, method="POST"), - response_model=OpenAIResponse, - data=OpenAICreateResponse( - input=[ - InputMessage( - content=cls.create_input_message_contents(prompt, images, files), - role="user", - ), - ], - store=True, - stream=False, - model=model, - previous_response_id=None, - **(advanced_options.model_dump(exclude_none=True) if advanced_options else {}), - ), - ) - response_id = create_response.id - - # Get result output - result_response = await poll_op( - cls, - ApiEndpoint(path=f"{RESPONSES_ENDPOINT}/{response_id}"), - response_model=OpenAIResponse, - status_extractor=lambda response: response.status, - completed_statuses=["incomplete", "completed"], - ) - return IO.NodeOutput(cls.get_text_from_message_content(cls.get_message_content_from_response(result_response))) + # Create response via fal_run + data = { + "input": [ + { + "content": [c.model_dump() if hasattr(c, 'model_dump') else c for c in cls.create_input_message_contents(prompt, images, files)], + "role": "user", + }, + ], + "store": True, + "stream": False, + "model": model, + "previous_response_id": None, + **(advanced_options.model_dump(exclude_none=True) if advanced_options else {}), + } + result = await fal_run(cls, FAL_GPT_IMAGE_1, data) # TODO: verify fal.ai field names; use correct fal model for ChatGPT + # Extract text from fal.ai response + text = "" + for output in result.get("output", []): + if output.get("type") == "message": + for content_item in output.get("content", []): + if content_item.get("type") == "output_text": + text = content_item.get("text", "") + break + return 
IO.NodeOutput(text if text else "No text output found in response") class OpenAIInputFiles(IO.ComfyNode): diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index e17a24ae72db..b229258c161a 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -17,13 +17,13 @@ pixverse_templates, ) from comfy_api_nodes.util import ( - ApiEndpoint, download_url_to_video_output, - poll_op, - sync_op, tensor_to_bytesio, validate_string, ) +from comfy_api_nodes.util.client import fal_run + +FAL_PIXVERSE_I2V = "fal-ai/pixverse/v3.5/image-to-video" AVERAGE_DURATION_T2V = 32 AVERAGE_DURATION_I2V = 30 @@ -31,16 +31,9 @@ async def upload_image_to_pixverse(cls: type[IO.ComfyNode], image: torch.Tensor): - response_upload = await sync_op( - cls, - ApiEndpoint(path="/proxy/pixverse/image/upload", method="POST"), - response_model=PixverseImageUploadResponse, - files={"image": tensor_to_bytesio(image)}, - content_type="multipart/form-data", - ) - if response_upload.Resp is None: - raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'") - return response_upload.Resp.img_id + from comfy_api_nodes.util.upload_helpers import upload_image_to_fal + image_url = await upload_image_to_fal(image[0] if len(image.shape) > 3 else image, "image/png") + return image_url # TODO: fal.ai returns URL instead of img_id; verify fal.ai integration class PixverseTemplateNode(IO.ComfyNode): @@ -123,12 +116,9 @@ def define_schema(cls) -> IO.Schema: ], outputs=[IO.Video.Output()], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -152,38 +142,17 @@ async def execute( elif duration_seconds != PixverseDuration.dur_5: motion_mode = PixverseMotionMode.normal - response_api = await sync_op( - cls, - ApiEndpoint(path="/proxy/pixverse/video/text/generate", method="POST"), - response_model=PixverseVideoResponse, - 
data=PixverseTextVideoRequest( - prompt=prompt, - aspect_ratio=aspect_ratio, - quality=quality, - duration=duration_seconds, - motion_mode=motion_mode, - negative_prompt=negative_prompt if negative_prompt else None, - template_id=pixverse_template, - seed=seed, - ), - ) - if response_api.Resp is None: - raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") - - response_poll = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"), - response_model=PixverseGenerationStatusResponse, - completed_statuses=[PixverseStatus.successful], - failed_statuses=[ - PixverseStatus.contents_moderation, - PixverseStatus.failed, - PixverseStatus.deleted, - ], - status_extractor=lambda x: x.Resp.status, - estimated_duration=AVERAGE_DURATION_T2V, - ) - return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url)) + result = await fal_run(cls, FAL_PIXVERSE_I2V, { + "prompt": prompt, + "aspect_ratio": aspect_ratio, + "quality": quality, + "duration": duration_seconds, + "motion_mode": motion_mode, + "negative_prompt": negative_prompt if negative_prompt else None, + "template_id": pixverse_template, + "seed": seed, + }, estimated_duration=AVERAGE_DURATION_T2V) # TODO: verify fal.ai field names; use correct fal model for text-to-video + return IO.NodeOutput(await download_url_to_video_output(result["video"]["url"])) class PixverseImageToVideoNode(IO.ComfyNode): @@ -238,12 +207,9 @@ def define_schema(cls) -> IO.Schema: ], outputs=[IO.Video.Output()], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -269,39 +235,17 @@ async def execute( elif duration_seconds != PixverseDuration.dur_5: motion_mode = PixverseMotionMode.normal - response_api = await sync_op( - cls, - ApiEndpoint(path="/proxy/pixverse/video/img/generate", method="POST"), - response_model=PixverseVideoResponse, - 
data=PixverseImageVideoRequest( - img_id=img_id, - prompt=prompt, - quality=quality, - duration=duration_seconds, - motion_mode=motion_mode, - negative_prompt=negative_prompt if negative_prompt else None, - template_id=pixverse_template, - seed=seed, - ), - ) - - if response_api.Resp is None: - raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") - - response_poll = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"), - response_model=PixverseGenerationStatusResponse, - completed_statuses=[PixverseStatus.successful], - failed_statuses=[ - PixverseStatus.contents_moderation, - PixverseStatus.failed, - PixverseStatus.deleted, - ], - status_extractor=lambda x: x.Resp.status, - estimated_duration=AVERAGE_DURATION_I2V, - ) - return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url)) + result = await fal_run(cls, FAL_PIXVERSE_I2V, { + "image_url": img_id, # img_id is now a URL from upload_image_to_fal + "prompt": prompt, + "quality": quality, + "duration": duration_seconds, + "motion_mode": motion_mode, + "negative_prompt": negative_prompt if negative_prompt else None, + "template_id": pixverse_template, + "seed": seed, + }, estimated_duration=AVERAGE_DURATION_I2V) # TODO: verify fal.ai field names + return IO.NodeOutput(await download_url_to_video_output(result["video"]["url"])) class PixverseTransitionVideoNode(IO.ComfyNode): @@ -352,12 +296,9 @@ def define_schema(cls) -> IO.Schema: ], outputs=[IO.Video.Output()], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -384,39 +325,17 @@ async def execute( elif duration_seconds != PixverseDuration.dur_5: motion_mode = PixverseMotionMode.normal - response_api = await sync_op( - cls, - ApiEndpoint(path="/proxy/pixverse/video/transition/generate", method="POST"), - response_model=PixverseVideoResponse, - 
data=PixverseTransitionVideoRequest( - first_frame_img=first_frame_id, - last_frame_img=last_frame_id, - prompt=prompt, - quality=quality, - duration=duration_seconds, - motion_mode=motion_mode, - negative_prompt=negative_prompt if negative_prompt else None, - seed=seed, - ), - ) - - if response_api.Resp is None: - raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") - - response_poll = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"), - response_model=PixverseGenerationStatusResponse, - completed_statuses=[PixverseStatus.successful], - failed_statuses=[ - PixverseStatus.contents_moderation, - PixverseStatus.failed, - PixverseStatus.deleted, - ], - status_extractor=lambda x: x.Resp.status, - estimated_duration=AVERAGE_DURATION_T2V, - ) - return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url)) + result = await fal_run(cls, FAL_PIXVERSE_I2V, { + "first_frame_image_url": first_frame_id, # now a URL from upload_image_to_fal + "last_frame_image_url": last_frame_id, + "prompt": prompt, + "quality": quality, + "duration": duration_seconds, + "motion_mode": motion_mode, + "negative_prompt": negative_prompt if negative_prompt else None, + "seed": seed, + }, estimated_duration=AVERAGE_DURATION_T2V) # TODO: verify fal.ai field names; use correct fal model for transition + return IO.NodeOutput(await download_url_to_video_output(result["video"]["url"])) PRICE_BADGE_VIDEO = IO.PriceBadge( diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index 4d1d508fa915..b7106f97b627 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -24,21 +24,23 @@ get_v3_substyles, ) from comfy_api_nodes.util import ( - ApiEndpoint, bytesio_to_image_tensor, download_url_as_bytesio, resize_mask_to_image, - sync_op, tensor_to_bytesio, validate_string, ) +from comfy_api_nodes.util.client import fal_run +from comfy_api_nodes.util.upload_helpers 
import upload_file_to_fal, upload_image_to_fal from comfy_extras.nodes_images import SVG +FAL_RECRAFT_V3 = "fal-ai/recraft/v3/text-to-image" + async def handle_recraft_file_request( cls: type[IO.ComfyNode], image: torch.Tensor, - path: str, + fal_endpoint: str, mask: torch.Tensor | None = None, total_pixels: int = 4096 * 4096, timeout: int = 1024, @@ -46,26 +48,31 @@ async def handle_recraft_file_request( ) -> list[BytesIO]: """Handle sending common Recraft file-only request to get back file bytes.""" - files = {"image": tensor_to_bytesio(image, total_pixels=total_pixels).read()} + image_url = await upload_image_to_fal(image) + data = {"image_url": image_url} # TODO: verify fal.ai field names if mask is not None: - files["mask"] = tensor_to_bytesio(mask, total_pixels=total_pixels).read() + mask_url = await upload_image_to_fal(mask) + data["mask_url"] = mask_url # TODO: verify fal.ai field names + + if request is not None: + if hasattr(request, "model_dump"): + req_dict = request.model_dump(exclude_none=True) + else: + req_dict = request.__dict__ + data.update(req_dict) + + result = await fal_run(cls, fal_endpoint, data) - response = await sync_op( - cls, - endpoint=ApiEndpoint(path=path, method="POST"), - response_model=RecraftImageGenerationResponse, - data=request if request else None, - files=files, - content_type="multipart/form-data", - multipart_parser=recraft_multipart_parser, - max_retries=1, - ) all_bytesio = [] - if response.image is not None: - all_bytesio.append(await download_url_as_bytesio(response.image.url, timeout=timeout)) - else: - for data in response.data: - all_bytesio.append(await download_url_as_bytesio(data.url, timeout=timeout)) + # TODO: verify fal.ai field names + if "image" in result and result["image"] is not None: + all_bytesio.append(await download_url_as_bytesio(result["image"]["url"], timeout=timeout)) + elif "data" in result: + for item in result["data"]: + all_bytesio.append(await download_url_as_bytesio(item["url"], 
timeout=timeout)) + elif "images" in result: + for item in result["images"]: + all_bytesio.append(await download_url_as_bytesio(item["url"], timeout=timeout)) return all_bytesio @@ -355,14 +362,9 @@ def define_schema(cls): IO.String.Output(display_name="style_id"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd": 0.04}""", - ), ) @classmethod @@ -371,27 +373,26 @@ async def execute( style: str, images: IO.Autogrow.Type, ) -> IO.NodeOutput: - files = [] + image_urls = [] total_size = 0 max_total_size = 5 * 1024 * 1024 # 5 MB limit for i, img in enumerate(list(images.values())): - file_bytes = tensor_to_bytesio(img, total_pixels=2048 * 2048, mime_type="image/webp").read() + file_bytesio = tensor_to_bytesio(img, total_pixels=2048 * 2048, mime_type="image/webp") + file_bytes = file_bytesio.read() total_size += len(file_bytes) if total_size > max_total_size: raise Exception("Total size of all images exceeds 5 MB limit.") - files.append((f"file{i + 1}", file_bytes)) - - response = await sync_op( - cls, - endpoint=ApiEndpoint(path="/proxy/recraft/styles", method="POST"), - response_model=RecraftCreateStyleResponse, - files=files, - data=RecraftCreateStyleRequest(style=style), - content_type="multipart/form-data", - max_retries=1, - ) + file_bytesio.seek(0) + url = await upload_file_to_fal(file_bytesio, "image/webp") + image_urls.append(url) + + # TODO: verify fal.ai field names + result = await fal_run(cls, FAL_RECRAFT_V3, { + "style": style, + "image_urls": image_urls, + }) - return IO.NodeOutput(response.id) + return IO.NodeOutput(result["id"]) class RecraftTextToImageNode(IO.ComfyNode): @@ -444,15 +445,9 @@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - 
depends_on=IO.PriceBadgeDepends(widgets=["n"]), - expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""", - ), ) @classmethod @@ -478,27 +473,22 @@ async def execute( if not negative_prompt: negative_prompt = None - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"), - response_model=RecraftImageGenerationResponse, - data=RecraftImageGenerationRequest( - prompt=prompt, - negative_prompt=negative_prompt, - model="recraftv3", - size=size, - n=n, - style=recraft_style.style, - substyle=recraft_style.substyle, - style_id=recraft_style.style_id, - controls=controls_api, - ), - max_retries=1, + request_data = RecraftImageGenerationRequest( + prompt=prompt, + negative_prompt=negative_prompt, + model="recraftv3", + size=size, + n=n, + style=recraft_style.style, + substyle=recraft_style.substyle, + style_id=recraft_style.style_id, + controls=controls_api, ) + result = await fal_run(cls, FAL_RECRAFT_V3, request_data.model_dump(exclude_none=True)) images = [] - for data in response.data: + for item in result["data"]: # TODO: verify fal.ai field names with handle_recraft_image_output(): - image = bytesio_to_image_tensor(await download_url_as_bytesio(data.url, timeout=1024)) + image = bytesio_to_image_tensor(await download_url_as_bytesio(item["url"], timeout=1024)) if len(image.shape) < 4: image = image.unsqueeze(0) images.append(image) @@ -560,15 +550,9 @@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["n"]), - expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""", - ), ) @classmethod @@ -614,7 +598,7 @@ async def execute( sub_bytes = await handle_recraft_file_request( cls, image=image[i], - path="/proxy/recraft/images/imageToImage", + fal_endpoint=FAL_RECRAFT_V3, request=request, ) with handle_recraft_image_output(): @@ -665,15 
+649,9 @@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["n"]), - expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""", - ), ) @classmethod @@ -716,7 +694,7 @@ async def execute( cls, image=image[i], mask=mask[i : i + 1], - path="/proxy/recraft/images/inpaint", + fal_endpoint=FAL_RECRAFT_V3, request=request, ) with handle_recraft_image_output(): @@ -770,15 +748,9 @@ def define_schema(cls): IO.SVG.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["n"]), - expr="""{"type":"usd","usd": $round(0.08 * widgets.n, 2)}""", - ), ) @classmethod @@ -803,25 +775,20 @@ async def execute( if not negative_prompt: negative_prompt = None - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"), - response_model=RecraftImageGenerationResponse, - data=RecraftImageGenerationRequest( - prompt=prompt, - negative_prompt=negative_prompt, - model="recraftv3", - size=size, - n=n, - style=recraft_style.style, - substyle=recraft_style.substyle, - controls=controls_api, - ), - max_retries=1, + request_data = RecraftImageGenerationRequest( + prompt=prompt, + negative_prompt=negative_prompt, + model="recraftv3", + size=size, + n=n, + style=recraft_style.style, + substyle=recraft_style.substyle, + controls=controls_api, ) + result = await fal_run(cls, FAL_RECRAFT_V3, request_data.model_dump(exclude_none=True)) svg_data = [] - for data in response.data: - svg_data.append(await download_url_as_bytesio(data.url, timeout=1024)) + for item in result["data"]: # TODO: verify fal.ai field names + svg_data.append(await download_url_as_bytesio(item["url"], timeout=1024)) return IO.NodeOutput(SVG(svg_data)) @@ -841,15 
+808,9 @@ def define_schema(cls): IO.SVG.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(), - expr="""{"type":"usd","usd": 0.01}""", - ), ) @classmethod @@ -861,7 +822,7 @@ async def execute(cls, image: torch.Tensor) -> IO.NodeOutput: sub_bytes = await handle_recraft_file_request( cls, image=image[i], - path="/proxy/recraft/images/vectorize", + fal_endpoint=FAL_RECRAFT_V3, ) svgs.append(SVG(sub_bytes)) pbar.update(1) @@ -903,14 +864,9 @@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.04}""", - ), ) @classmethod @@ -947,7 +903,7 @@ async def execute( sub_bytes = await handle_recraft_file_request( cls, image=image[i], - path="/proxy/recraft/images/replaceBackground", + fal_endpoint=FAL_RECRAFT_V3, request=request, ) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) @@ -973,14 +929,9 @@ def define_schema(cls): IO.Mask.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.01}""", - ), ) @classmethod @@ -992,7 +943,7 @@ async def execute(cls, image: torch.Tensor) -> IO.NodeOutput: sub_bytes = await handle_recraft_file_request( cls, image=image[i], - path="/proxy/recraft/images/removeBackground", + fal_endpoint=FAL_RECRAFT_V3, ) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) @@ -1004,7 +955,7 @@ async def execute(cls, image: torch.Tensor) -> IO.NodeOutput: class RecraftCrispUpscaleNode(IO.ComfyNode): - RECRAFT_PATH = "/proxy/recraft/images/crispUpscale" + RECRAFT_ENDPOINT = FAL_RECRAFT_V3 @classmethod def define_schema(cls): @@ -1022,14 +973,9 
@@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.004}""", - ), ) @classmethod @@ -1041,7 +987,7 @@ async def execute(cls, image: torch.Tensor) -> IO.NodeOutput: sub_bytes = await handle_recraft_file_request( cls, image=image[i], - path=cls.RECRAFT_PATH, + fal_endpoint=cls.RECRAFT_ENDPOINT, ) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) @@ -1050,7 +996,7 @@ async def execute(cls, image: torch.Tensor) -> IO.NodeOutput: class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode): - RECRAFT_PATH = "/proxy/recraft/images/creativeUpscale" + RECRAFT_ENDPOINT = FAL_RECRAFT_V3 @classmethod def define_schema(cls): @@ -1068,14 +1014,9 @@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.25}""", - ), ) @@ -1152,20 +1093,9 @@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]), - expr=""" - ( - $prices := {"recraftv4": 0.04, "recraftv4_pro": 0.25}; - {"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n} - ) - """, - ), ) @classmethod @@ -1179,24 +1109,19 @@ async def execute( recraft_controls: RecraftControls | None = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000) - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"), - response_model=RecraftImageGenerationResponse, - data=RecraftImageGenerationRequest( - prompt=prompt, - negative_prompt=negative_prompt if negative_prompt 
else None, - model=model["model"], - size=model["size"], - n=n, - controls=recraft_controls.create_api_model() if recraft_controls else None, - ), - max_retries=1, + request_data = RecraftImageGenerationRequest( + prompt=prompt, + negative_prompt=negative_prompt if negative_prompt else None, + model=model["model"], + size=model["size"], + n=n, + controls=recraft_controls.create_api_model() if recraft_controls else None, ) + result = await fal_run(cls, FAL_RECRAFT_V3, request_data.model_dump(exclude_none=True)) images = [] - for data in response.data: + for item in result["data"]: # TODO: verify fal.ai field names with handle_recraft_image_output(): - image = bytesio_to_image_tensor(await download_url_as_bytesio(data.url, timeout=1024)) + image = bytesio_to_image_tensor(await download_url_as_bytesio(item["url"], timeout=1024)) if len(image.shape) < 4: image = image.unsqueeze(0) images.append(image) @@ -1276,20 +1201,9 @@ def define_schema(cls): IO.SVG.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]), - expr=""" - ( - $prices := {"recraftv4": 0.08, "recraftv4_pro": 0.30}; - {"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n} - ) - """, - ), ) @classmethod @@ -1303,25 +1217,20 @@ async def execute( recraft_controls: RecraftControls | None = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000) - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"), - response_model=RecraftImageGenerationResponse, - data=RecraftImageGenerationRequest( - prompt=prompt, - negative_prompt=negative_prompt if negative_prompt else None, - model=model["model"], - size=model["size"], - n=n, - style="vector_illustration", - substyle=None, - controls=recraft_controls.create_api_model() if recraft_controls else None, - 
), - max_retries=1, + request_data = RecraftImageGenerationRequest( + prompt=prompt, + negative_prompt=negative_prompt if negative_prompt else None, + model=model["model"], + size=model["size"], + n=n, + style="vector_illustration", + substyle=None, + controls=recraft_controls.create_api_model() if recraft_controls else None, ) + result = await fal_run(cls, FAL_RECRAFT_V3, request_data.model_dump(exclude_none=True)) svg_data = [] - for data in response.data: - svg_data.append(await download_url_as_bytesio(data.url, timeout=1024)) + for item in result["data"]: # TODO: verify fal.ai field names + svg_data.append(await download_url_as_bytesio(item["url"], timeout=1024)) return IO.NodeOutput(SVG(svg_data)) diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index 2b829b8db0a3..69e9b481bf2c 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -13,22 +13,14 @@ from io import BytesIO from typing_extensions import override from PIL import Image -from comfy_api_nodes.apis.rodin import ( - Rodin3DGenerateRequest, - Rodin3DGenerateResponse, - Rodin3DCheckStatusRequest, - Rodin3DCheckStatusResponse, - Rodin3DDownloadRequest, - Rodin3DDownloadResponse, - JobStatus, -) from comfy_api_nodes.util import ( - sync_op, - poll_op, - ApiEndpoint, download_url_to_bytesio, download_url_to_file_3d, ) +from comfy_api_nodes.util.client import fal_run +from comfy_api_nodes.util.upload_helpers import upload_file_to_fal + +FAL_RODIN_V2 = "fal-ai/hyper3d/rodin/v2" from comfy_api.latest import ComfyExtension, IO, Types @@ -132,98 +124,62 @@ async def create_generate_task( if len(images) > 5: raise Exception("Rodin 3D generate requires up to 5 image.") - response = await sync_op( - cls, - ApiEndpoint(path="/proxy/rodin/api/v2/rodin", method="POST"), - response_model=Rodin3DGenerateResponse, - data=Rodin3DGenerateRequest( - seed=seed, - tier=tier, - material=material, - quality_override=quality_override, - mesh_mode=mesh_mode, - 
TAPose=ta_pose, - ), - files=[ - ( - "images", - open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image) - ) - for image in images if image is not None - ], - content_type="multipart/form-data", - ) - - if hasattr(response, "error"): - error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}" + image_urls = [] + for image in images: + if image is None: + continue + if isinstance(image, str): + with open(image, "rb") as f: + bio = BytesIO(f.read()) + else: + bio = tensor_to_filelike(image) + url = await upload_file_to_fal(bio, "image/png") + image_urls.append(url) + + result = await fal_run(cls, FAL_RODIN_V2, { + "image_urls": image_urls, + "seed": seed, + "tier": tier, + "material": material, + "quality_override": quality_override, + "mesh_mode": mesh_mode, + "TAPose": ta_pose, + }) + + # TODO: verify fal.ai field names + if "error" in result: + error_message = f"Rodin3D Create 3D generate Task Failed. Message: {result.get('message', '')}, error: {result['error']}" logging.error(error_message) raise Exception(error_message) logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!") - subscription_key = response.jobs.subscription_key - task_uuid = response.uuid + task_uuid = result.get("uuid", result.get("id", "rodin_task")) logging.info("[ Rodin3D API - Submit Jobs ] UUID: %s", task_uuid) - return task_uuid, subscription_key - - -def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str: - all_done = all(job.status == JobStatus.Done for job in response.jobs) - status_list = [str(job.status) for job in response.jobs] - logging.info("[ Rodin3D API - CheckStatus ] Generate Status: %s", status_list) - if any(job.status == JobStatus.Failed for job in response.jobs): - logging.error("[ Rodin3D API - CheckStatus ] Generate Failed: %s, Please try again.", status_list) - raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.") - if all_done: - return "DONE" - 
return "Generating" - -def extract_progress(response: Rodin3DCheckStatusResponse) -> int | None: - if not response.jobs: - return None - completed_count = sum(1 for job in response.jobs if job.status == JobStatus.Done) - return int((completed_count / len(response.jobs)) * 100) - - -async def poll_for_task_status(subscription_key: str, cls: type[IO.ComfyNode]) -> Rodin3DCheckStatusResponse: - logging.info("[ Rodin3D API - CheckStatus ] Generate Start!") - return await poll_op( - cls, - ApiEndpoint(path="/proxy/rodin/api/v2/status", method="POST"), - response_model=Rodin3DCheckStatusResponse, - data=Rodin3DCheckStatusRequest(subscription_key=subscription_key), - status_extractor=check_rodin_status, - progress_extractor=extract_progress, - ) - - -async def get_rodin_download_list(uuid: str, cls: type[IO.ComfyNode]) -> Rodin3DDownloadResponse: - logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") - return await sync_op( - cls, - ApiEndpoint(path="/proxy/rodin/api/v2/download", method="POST"), - response_model=Rodin3DDownloadResponse, - data=Rodin3DDownloadRequest(task_uuid=uuid), - monitor_progress=False, - ) + return task_uuid, result -async def download_files(url_list, task_uuid: str) -> tuple[str | None, Types.File3D | None]: +async def download_files(result: dict, task_uuid: str) -> tuple[str | None, Types.File3D | None]: + """Download files from the fal_run result.""" result_folder_name = f"Rodin3D_{task_uuid}" save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name) os.makedirs(save_path, exist_ok=True) model_file_path = None file_3d = None - for i in url_list.list: - file_path = os.path.join(save_path, i.name) - if i.name.lower().endswith(".glb"): - model_file_path = os.path.join(result_folder_name, i.name) - file_3d = await download_url_to_file_3d(i.url, "glb") + # TODO: verify fal.ai field names + download_list = result.get("list", result.get("downloads", [])) + for i in download_list: + name = i.get("name", 
i.get("filename", "")) + url = i.get("url", "") + file_path = os.path.join(save_path, name) + if name.lower().endswith(".glb"): + model_file_path = os.path.join(result_folder_name, name) + file_3d = await download_url_to_file_3d(url, "glb") # Save to disk for backward compatibility with open(file_path, "wb") as f: f.write(file_3d.get_bytes()) else: - await download_url_to_bytesio(i.url, file_path) + await download_url_to_bytesio(url, file_path) return model_file_path, file_3d @@ -247,14 +203,9 @@ def define_schema(cls) -> IO.Schema: IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.4}""", - ), ) @classmethod @@ -271,7 +222,7 @@ async def execute( for i in range(num_images): m_images.append(Images[i]) mesh_mode, quality_override = get_quality_mode(Polygon_count) - task_uuid, subscription_key = await create_generate_task( + task_uuid, result = await create_generate_task( cls, images=m_images, seed=Seed, @@ -280,9 +231,7 @@ async def execute( tier=tier, mesh_mode=mesh_mode, ) - await poll_for_task_status(subscription_key, cls) - download_list = await get_rodin_download_list(task_uuid, cls) - model_path, file_3d = await download_files(download_list, task_uuid) + model_path, file_3d = await download_files(result, task_uuid) return IO.NodeOutput(model_path, file_3d) @@ -306,14 +255,9 @@ def define_schema(cls) -> IO.Schema: IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.4}""", - ), ) @classmethod @@ -330,7 +274,7 @@ async def execute( for i in range(num_images): m_images.append(Images[i]) mesh_mode, quality_override = get_quality_mode(Polygon_count) - task_uuid, subscription_key = await create_generate_task( + task_uuid, result = 
await create_generate_task( cls, images=m_images, seed=Seed, @@ -339,9 +283,7 @@ async def execute( tier=tier, mesh_mode=mesh_mode, ) - await poll_for_task_status(subscription_key, cls) - download_list = await get_rodin_download_list(task_uuid, cls) - model_path, file_3d = await download_files(download_list, task_uuid) + model_path, file_3d = await download_files(result, task_uuid) return IO.NodeOutput(model_path, file_3d) @@ -365,14 +307,9 @@ def define_schema(cls) -> IO.Schema: IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.4}""", - ), ) @classmethod @@ -388,7 +325,7 @@ async def execute( for i in range(num_images): m_images.append(Images[i]) mesh_mode, quality_override = get_quality_mode(Polygon_count) - task_uuid, subscription_key = await create_generate_task( + task_uuid, result = await create_generate_task( cls, images=m_images, seed=Seed, @@ -397,9 +334,7 @@ async def execute( tier="Smooth", mesh_mode=mesh_mode, ) - await poll_for_task_status(subscription_key, cls) - download_list = await get_rodin_download_list(task_uuid, cls) - model_path, file_3d = await download_files(download_list, task_uuid) + model_path, file_3d = await download_files(result, task_uuid) return IO.NodeOutput(model_path, file_3d) @@ -430,14 +365,9 @@ def define_schema(cls) -> IO.Schema: IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.4}""", - ), ) @classmethod @@ -450,7 +380,7 @@ async def execute( m_images = [] for i in range(num_images): m_images.append(Images[i]) - task_uuid, subscription_key = await create_generate_task( + task_uuid, result = await create_generate_task( cls, images=m_images, seed=Seed, @@ -459,9 +389,7 @@ async def execute( 
tier="Sketch", mesh_mode="Quad", ) - await poll_for_task_status(subscription_key, cls) - download_list = await get_rodin_download_list(task_uuid, cls) - model_path, file_3d = await download_files(download_list, task_uuid) + model_path, file_3d = await download_files(result, task_uuid) return IO.NodeOutput(model_path, file_3d) @@ -500,14 +428,9 @@ def define_schema(cls) -> IO.Schema: IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.4}""", - ), ) @classmethod @@ -525,7 +448,7 @@ async def execute( for i in range(num_images): m_images.append(Images[i]) mesh_mode, quality_override = get_quality_mode(Polygon_count) - task_uuid, subscription_key = await create_generate_task( + task_uuid, result = await create_generate_task( cls, images=m_images, seed=Seed, @@ -535,9 +458,7 @@ async def execute( mesh_mode=mesh_mode, ta_pose=TAPose, ) - await poll_for_task_status(subscription_key, cls) - download_list = await get_rodin_download_list(task_uuid, cls) - model_path, file_3d = await download_files(download_list, task_uuid) + model_path, file_3d = await download_files(result, task_uuid) return IO.NodeOutput(model_path, file_3d) diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py deleted file mode 100644 index 573170ba2d2c..000000000000 --- a/comfy_api_nodes/nodes_runway.py +++ /dev/null @@ -1,534 +0,0 @@ -"""Runway API Nodes - -API Docs: - - https://docs.dev.runwayml.com/api/#tag/Task-management/paths/~1v1~1tasks~1%7Bid%7D/delete - -User Guides: - - https://help.runwayml.com/hc/en-us/sections/30265301423635-Gen-3-Alpha - - https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video - - https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo - - 
https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3 - -""" - -from enum import Enum - -from typing_extensions import override - -from comfy_api.latest import IO, ComfyExtension, Input, InputImpl -from comfy_api_nodes.apis.runway import ( - RunwayImageToVideoRequest, - RunwayImageToVideoResponse, - RunwayTaskStatusResponse as TaskStatusResponse, - RunwayModelEnum as Model, - RunwayDurationEnum as Duration, - RunwayAspectRatioEnum as AspectRatio, - RunwayPromptImageObject, - RunwayPromptImageDetailedObject, - RunwayTextToImageRequest, - RunwayTextToImageResponse, - Model4, - ReferenceImage, - RunwayTextToImageAspectRatioEnum, -) -from comfy_api_nodes.util import ( - image_tensor_pair_to_batch, - validate_string, - validate_image_dimensions, - validate_image_aspect_ratio, - upload_images_to_comfyapi, - download_url_to_video_output, - download_url_to_image_tensor, - ApiEndpoint, - sync_op, - poll_op, -) - -PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video" -PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image" -PATH_GET_TASK_STATUS = "/proxy/runway/tasks" - -AVERAGE_DURATION_I2V_SECONDS = 64 -AVERAGE_DURATION_FLF_SECONDS = 256 -AVERAGE_DURATION_T2I_SECONDS = 41 - - -class RunwayApiError(Exception): - """Base exception for Runway API errors.""" - - pass - - -class RunwayGen4TurboAspectRatio(str, Enum): - """Aspect ratios supported for Image to Video API when using gen4_turbo model.""" - - field_1280_720 = "1280:720" - field_720_1280 = "720:1280" - field_1104_832 = "1104:832" - field_832_1104 = "832:1104" - field_960_960 = "960:960" - field_1584_672 = "1584:672" - - -class RunwayGen3aAspectRatio(str, Enum): - """Aspect ratios supported for Image to Video API when using gen3a_turbo model.""" - - field_768_1280 = "768:1280" - field_1280_768 = "1280:768" - - -def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None: - """Returns the video URL from the task status response if it exists.""" - if hasattr(response, 
"output") and len(response.output) > 0: - return response.output[0] - return None - - -def extract_progress_from_task_status( - response: TaskStatusResponse, -) -> float | None: - if hasattr(response, "progress") and response.progress is not None: - return response.progress * 100 - return None - - -def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None: - """Returns the image URL from the task status response if it exists.""" - if hasattr(response, "output") and len(response.output) > 0: - return response.output[0] - return None - - -async def get_response( - cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None -) -> TaskStatusResponse: - """Poll the task status until it is finished then get the response.""" - return await poll_op( - cls, - ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"), - response_model=TaskStatusResponse, - status_extractor=lambda r: r.status.value, - estimated_duration=estimated_duration, - progress_extractor=extract_progress_from_task_status, - ) - - -async def generate_video( - cls: type[IO.ComfyNode], - request: RunwayImageToVideoRequest, - estimated_duration: int | None = None, -) -> InputImpl.VideoFromFile: - initial_response = await sync_op( - cls, - endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"), - response_model=RunwayImageToVideoResponse, - data=request, - ) - - final_response = await get_response(cls, initial_response.id, estimated_duration) - if not final_response.output: - raise RunwayApiError("Runway task succeeded but no video data found in response.") - - video_url = get_video_url_from_task_status(final_response) - return await download_url_to_video_output(video_url) - - -class RunwayImageToVideoNodeGen3a(IO.ComfyNode): - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="RunwayImageToVideoNodeGen3a", - display_name="Runway Image to Video (Gen3a Turbo)", - category="api node/video/Runway", - description="Generate a video from a single 
starting frame using Gen3a Turbo model. " - "Before diving in, review these best practices to ensure that " - "your input selections will set your generation up for success: " - "https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.", - inputs=[ - IO.String.Input( - "prompt", - multiline=True, - default="", - tooltip="Text prompt for the generation", - ), - IO.Image.Input( - "start_frame", - tooltip="Start frame to be used for the video", - ), - IO.Combo.Input( - "duration", - options=Duration, - ), - IO.Combo.Input( - "ratio", - options=RunwayGen3aAspectRatio, - ), - IO.Int.Input( - "seed", - default=0, - min=0, - max=4294967295, - step=1, - control_after_generate=True, - display_mode=IO.NumberDisplay.number, - tooltip="Random seed for generation", - ), - ], - outputs=[ - IO.Video.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration"]), - expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""", - ), - ) - - @classmethod - async def execute( - cls, - prompt: str, - start_frame: Input.Image, - duration: str, - ratio: str, - seed: int, - ) -> IO.NodeOutput: - validate_string(prompt, min_length=1) - validate_image_dimensions(start_frame, max_width=7999, max_height=7999) - validate_image_aspect_ratio(start_frame, (1, 2), (2, 1)) - - download_urls = await upload_images_to_comfyapi( - cls, - start_frame, - max_images=1, - mime_type="image/png", - ) - - return IO.NodeOutput( - await generate_video( - cls, - RunwayImageToVideoRequest( - promptText=prompt, - seed=seed, - model=Model("gen3a_turbo"), - duration=Duration(duration), - ratio=AspectRatio(ratio), - promptImage=RunwayPromptImageObject( - root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")] - ), - ), - ) - ) - - -class RunwayImageToVideoNodeGen4(IO.ComfyNode): - - 
@classmethod - def define_schema(cls): - return IO.Schema( - node_id="RunwayImageToVideoNodeGen4", - display_name="Runway Image to Video (Gen4 Turbo)", - category="api node/video/Runway", - description="Generate a video from a single starting frame using Gen4 Turbo model. " - "Before diving in, review these best practices to ensure that " - "your input selections will set your generation up for success: " - "https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.", - inputs=[ - IO.String.Input( - "prompt", - multiline=True, - default="", - tooltip="Text prompt for the generation", - ), - IO.Image.Input( - "start_frame", - tooltip="Start frame to be used for the video", - ), - IO.Combo.Input( - "duration", - options=Duration, - ), - IO.Combo.Input( - "ratio", - options=RunwayGen4TurboAspectRatio, - ), - IO.Int.Input( - "seed", - default=0, - min=0, - max=4294967295, - step=1, - control_after_generate=True, - display_mode=IO.NumberDisplay.number, - tooltip="Random seed for generation", - ), - ], - outputs=[ - IO.Video.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration"]), - expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""", - ), - ) - - @classmethod - async def execute( - cls, - prompt: str, - start_frame: Input.Image, - duration: str, - ratio: str, - seed: int, - ) -> IO.NodeOutput: - validate_string(prompt, min_length=1) - validate_image_dimensions(start_frame, max_width=7999, max_height=7999) - validate_image_aspect_ratio(start_frame, (1, 2), (2, 1)) - - download_urls = await upload_images_to_comfyapi( - cls, - start_frame, - max_images=1, - mime_type="image/png", - ) - - return IO.NodeOutput( - await generate_video( - cls, - RunwayImageToVideoRequest( - promptText=prompt, - seed=seed, - model=Model("gen4_turbo"), - duration=Duration(duration), - 
ratio=AspectRatio(ratio), - promptImage=RunwayPromptImageObject( - root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")] - ), - ), - estimated_duration=AVERAGE_DURATION_FLF_SECONDS, - ) - ) - - -class RunwayFirstLastFrameNode(IO.ComfyNode): - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="RunwayFirstLastFrameNode", - display_name="Runway First-Last-Frame to Video", - category="api node/video/Runway", - description="Upload first and last keyframes, draft a prompt, and generate a video. " - "More complex transitions, such as cases where the Last frame is completely different " - "from the First frame, may benefit from the longer 10s duration. " - "This would give the generation more time to smoothly transition between the two inputs. " - "Before diving in, review these best practices to ensure that your input selections " - "will set your generation up for success: " - "https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.", - inputs=[ - IO.String.Input( - "prompt", - multiline=True, - default="", - tooltip="Text prompt for the generation", - ), - IO.Image.Input( - "start_frame", - tooltip="Start frame to be used for the video", - ), - IO.Image.Input( - "end_frame", - tooltip="End frame to be used for the video. 
Supported for gen3a_turbo only.", - ), - IO.Combo.Input( - "duration", - options=Duration, - ), - IO.Combo.Input( - "ratio", - options=RunwayGen3aAspectRatio, - ), - IO.Int.Input( - "seed", - default=0, - min=0, - max=4294967295, - step=1, - control_after_generate=True, - display_mode=IO.NumberDisplay.number, - tooltip="Random seed for generation", - ), - ], - outputs=[ - IO.Video.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration"]), - expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""", - ), - ) - - @classmethod - async def execute( - cls, - prompt: str, - start_frame: Input.Image, - end_frame: Input.Image, - duration: str, - ratio: str, - seed: int, - ) -> IO.NodeOutput: - validate_string(prompt, min_length=1) - validate_image_dimensions(start_frame, max_width=7999, max_height=7999) - validate_image_dimensions(end_frame, max_width=7999, max_height=7999) - validate_image_aspect_ratio(start_frame, (1, 2), (2, 1)) - validate_image_aspect_ratio(end_frame, (1, 2), (2, 1)) - - stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame) - download_urls = await upload_images_to_comfyapi( - cls, - stacked_input_images, - max_images=2, - mime_type="image/png", - ) - if len(download_urls) != 2: - raise RunwayApiError("Failed to upload one or more images to comfy api.") - - return IO.NodeOutput( - await generate_video( - cls, - RunwayImageToVideoRequest( - promptText=prompt, - seed=seed, - model=Model("gen3a_turbo"), - duration=Duration(duration), - ratio=AspectRatio(ratio), - promptImage=RunwayPromptImageObject( - root=[ - RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first"), - RunwayPromptImageDetailedObject(uri=str(download_urls[1]), position="last"), - ] - ), - ), - estimated_duration=AVERAGE_DURATION_FLF_SECONDS, - ) - ) - - -class 
RunwayTextToImageNode(IO.ComfyNode): - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="RunwayTextToImageNode", - display_name="Runway Text to Image", - category="api node/image/Runway", - description="Generate an image from a text prompt using Runway's Gen 4 model. " - "You can also include reference image to guide the generation.", - inputs=[ - IO.String.Input( - "prompt", - multiline=True, - default="", - tooltip="Text prompt for the generation", - ), - IO.Combo.Input( - "ratio", - options=[model.value for model in RunwayTextToImageAspectRatioEnum], - ), - IO.Image.Input( - "reference_image", - tooltip="Optional reference image to guide the generation", - optional=True, - ), - ], - outputs=[ - IO.Image.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.11}""", - ), - ) - - @classmethod - async def execute( - cls, - prompt: str, - ratio: str, - reference_image: Input.Image | None = None, - ) -> IO.NodeOutput: - validate_string(prompt, min_length=1) - - # Prepare reference images if provided - reference_images = None - if reference_image is not None: - validate_image_dimensions(reference_image, max_width=7999, max_height=7999) - validate_image_aspect_ratio(reference_image, (1, 2), (2, 1)) - download_urls = await upload_images_to_comfyapi( - cls, - reference_image, - max_images=1, - mime_type="image/png", - ) - reference_images = [ReferenceImage(uri=str(download_urls[0]))] - - initial_response = await sync_op( - cls, - endpoint=ApiEndpoint(path=PATH_TEXT_TO_IMAGE, method="POST"), - response_model=RunwayTextToImageResponse, - data=RunwayTextToImageRequest( - promptText=prompt, - model=Model4.gen4_image, - ratio=ratio, - referenceImages=reference_images, - ), - ) - - final_response = await get_response( - cls, - initial_response.id, - estimated_duration=AVERAGE_DURATION_T2I_SECONDS, - ) - if not 
final_response.output: - raise RunwayApiError("Runway task succeeded but no image data found in response.") - - return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response))) - - -class RunwayExtension(ComfyExtension): - @override - async def get_node_list(self) -> list[type[IO.ComfyNode]]: - return [ - RunwayFirstLastFrameNode, - RunwayImageToVideoNodeGen3a, - RunwayImageToVideoNodeGen4, - RunwayTextToImageNode, - ] - - -async def comfy_entrypoint() -> RunwayExtension: - return RunwayExtension() diff --git a/comfy_api_nodes/nodes_sora.py b/comfy_api_nodes/nodes_sora.py index afc18bb2559a..5fc051578463 100644 --- a/comfy_api_nodes/nodes_sora.py +++ b/comfy_api_nodes/nodes_sora.py @@ -6,13 +6,13 @@ from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.util import ( - ApiEndpoint, download_url_to_video_output, get_number_of_images, - poll_op, - sync_op, tensor_to_bytesio, ) +from comfy_api_nodes.util.client import fal_run + +FAL_SORA_2 = "fal-ai/sora-2/text-to-video" class Sora2GenerationRequest(BaseModel): @@ -84,29 +84,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "size", "duration"]), - expr=""" - ( - $m := widgets.model; - $size := widgets.size; - $dur := widgets.duration; - $isPro := $contains($m, "sora-2-pro"); - $isSora2 := $contains($m, "sora-2"); - $isProSize := ($size = "1024x1792" or $size = "1792x1024"); - $perSec := - $isPro ? ($isProSize ? 0.5 : 0.3) : - $isSora2 ? 0.1 : - ($isProSize ? 
0.5 : 0.1); - {"type":"usd","usd": $round($perSec * $dur, 2)} - ) - """, - ), ) @classmethod @@ -121,38 +101,23 @@ async def execute( ): if model == "sora-2" and size not in ("720x1280", "1280x720"): raise ValueError("Invalid size for sora-2 model, only 720x1280 and 1280x720 are supported.") - files_input = None + data = { + "model": model, + "prompt": prompt, + "seconds": str(duration), + "size": size, + } if image is not None: if get_number_of_images(image) != 1: raise ValueError("Currently only one input image is supported.") - files_input = {"input_reference": ("image.png", tensor_to_bytesio(image), "image/png")} - initial_response = await sync_op( - cls, - endpoint=ApiEndpoint(path="/proxy/openai/v1/videos", method="POST"), - data=Sora2GenerationRequest( - model=model, - prompt=prompt, - seconds=str(duration), - size=size, - ), - files=files_input, - response_model=Sora2GenerationResponse, - content_type="multipart/form-data", - ) - if initial_response.error: - raise Exception(initial_response.error["message"]) - - model_time_multiplier = 1 if model == "sora-2" else 2 - await poll_op( - cls, - poll_endpoint=ApiEndpoint(path=f"/proxy/openai/v1/videos/{initial_response.id}"), - response_model=Sora2GenerationResponse, - status_extractor=lambda x: x.status, - poll_interval=8.0, - estimated_duration=int(45 * (duration / 4) * model_time_multiplier), - ) + from comfy_api_nodes.util.upload_helpers import upload_image_to_fal + data["input_reference"] = await upload_image_to_fal( + image[0] if len(image.shape) > 3 else image, "image/png" + ) + result = await fal_run(cls, FAL_SORA_2, data) # TODO: verify fal.ai field names + video_url = result["video"]["url"] return IO.NodeOutput( - await download_url_to_video_output(f"/proxy/openai/v1/videos/{initial_response.id}/content", cls=cls), + await download_url_to_video_output(video_url), ) diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index 9ef13c83ba8a..a9c45fc2a977 100644 --- 
a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -27,10 +27,10 @@ bytesio_to_image_tensor, tensor_to_bytesio, audio_bytes_to_audio_input, - sync_op, - poll_op, - ApiEndpoint, ) +from comfy_api_nodes.util.client import fal_run + +FAL_SD35_MEDIUM = "fal-ai/stable-diffusion-v35-medium" import torch import base64 @@ -124,14 +124,9 @@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.08}""", - ), ) @classmethod @@ -158,31 +153,22 @@ async def execute( if style_preset == "None": style_preset = None - files = { - "image": image_binary + data = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "aspect_ratio": aspect_ratio, + "seed": seed, + "strength": image_denoise, + "style_preset": style_preset, } + if image is not None: + from comfy_api_nodes.util.upload_helpers import upload_image_to_fal + data["image_url"] = await upload_image_to_fal(image, "image/png") - response_api = await sync_op( - cls, - ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"), - response_model=StabilityStableUltraResponse, - data=StabilityStableUltraRequest( - prompt=prompt, - negative_prompt=negative_prompt, - aspect_ratio=aspect_ratio, - seed=seed, - strength=image_denoise, - style_preset=style_preset, - ), - files=files, - content_type="multipart/form-data", - ) - - if response_api.finish_reason != "SUCCESS": - raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.") - - image_data = base64.b64decode(response_api.image) - returned_image = bytesio_to_image_tensor(BytesIO(image_data)) + result = await fal_run(cls, FAL_SD35_MEDIUM, data) # TODO: verify fal.ai field names; use correct fal model for Ultra + image_url = result["images"][0]["url"] + from comfy_api_nodes.util import download_url_to_image_tensor + 
returned_image = await download_url_to_image_tensor(image_url) return IO.NodeOutput(returned_image) @@ -266,21 +252,9 @@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model"]), - expr=""" - ( - $contains(widgets.model,"large") - ? {"type":"usd","usd":0.065} - : {"type":"usd","usd":0.035} - ) - """, - ), ) @classmethod @@ -312,34 +286,25 @@ async def execute( if style_preset == "None": style_preset = None - files = { - "image": image_binary + data = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "aspect_ratio": aspect_ratio, + "seed": seed, + "strength": image_denoise, + "style_preset": style_preset, + "cfg_scale": cfg_scale, + "model": model, + "mode": mode.value if hasattr(mode, 'value') else mode, } + if image is not None: + from comfy_api_nodes.util.upload_helpers import upload_image_to_fal + data["image_url"] = await upload_image_to_fal(image, "image/png") - response_api = await sync_op( - cls, - ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"), - response_model=StabilityStableUltraResponse, - data=StabilityStable3_5Request( - prompt=prompt, - negative_prompt=negative_prompt, - aspect_ratio=aspect_ratio, - seed=seed, - strength=image_denoise, - style_preset=style_preset, - cfg_scale=cfg_scale, - model=model, - mode=mode, - ), - files=files, - content_type="multipart/form-data", - ) - - if response_api.finish_reason != "SUCCESS": - raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.") - - image_data = base64.b64decode(response_api.image) - returned_image = bytesio_to_image_tensor(BytesIO(image_data)) + result = await fal_run(cls, FAL_SD35_MEDIUM, data) # TODO: verify fal.ai field names + image_url = result["images"][0]["url"] + from comfy_api_nodes.util import 
download_url_to_image_tensor + returned_image = await download_url_to_image_tensor(image_url) return IO.NodeOutput(returned_image) @@ -395,14 +360,9 @@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.25}""", - ), ) @classmethod @@ -420,29 +380,17 @@ async def execute( if not negative_prompt: negative_prompt = None - files = { - "image": image_binary - } - - response_api = await sync_op( - cls, - ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"), - response_model=StabilityStableUltraResponse, - data=StabilityUpscaleConservativeRequest( - prompt=prompt, - negative_prompt=negative_prompt, - creativity=round(creativity,2), - seed=seed, - ), - files=files, - content_type="multipart/form-data", - ) - - if response_api.finish_reason != "SUCCESS": - raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.") - - image_data = base64.b64decode(response_api.image) - returned_image = bytesio_to_image_tensor(BytesIO(image_data)) + from comfy_api_nodes.util.upload_helpers import upload_image_to_fal + image_url = await upload_image_to_fal(image, "image/png") + result = await fal_run(cls, FAL_SD35_MEDIUM, { + "image_url": image_url, + "prompt": prompt, + "negative_prompt": negative_prompt, + "creativity": round(creativity, 2), + "seed": seed, + }) # TODO: verify fal.ai field names; use correct fal model for upscale conservative + from comfy_api_nodes.util import download_url_to_image_tensor + returned_image = await download_url_to_image_tensor(result["images"][0]["url"]) return IO.NodeOutput(returned_image) @@ -504,14 +452,9 @@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - 
expr="""{"type":"usd","usd":0.25}""", - ), ) @classmethod @@ -532,38 +475,18 @@ async def execute( if style_preset == "None": style_preset = None - files = { - "image": image_binary - } - - response_api = await sync_op( - cls, - ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"), - response_model=StabilityAsyncResponse, - data=StabilityUpscaleCreativeRequest( - prompt=prompt, - negative_prompt=negative_prompt, - creativity=round(creativity,2), - style_preset=style_preset, - seed=seed, - ), - files=files, - content_type="multipart/form-data", - ) - - response_poll = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"), - response_model=StabilityResultsGetResponse, - poll_interval=3, - status_extractor=lambda x: get_async_dummy_status(x), - ) - - if response_poll.finish_reason != "SUCCESS": - raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.") - - image_data = base64.b64decode(response_poll.result) - returned_image = bytesio_to_image_tensor(BytesIO(image_data)) + from comfy_api_nodes.util.upload_helpers import upload_image_to_fal + image_url = await upload_image_to_fal(image, "image/png") + result = await fal_run(cls, FAL_SD35_MEDIUM, { + "image_url": image_url, + "prompt": prompt, + "negative_prompt": negative_prompt, + "creativity": round(creativity, 2), + "style_preset": style_preset, + "seed": seed, + }) # TODO: verify fal.ai field names; use correct fal model for upscale creative + from comfy_api_nodes.util import download_url_to_image_tensor + returned_image = await download_url_to_image_tensor(result["images"][0]["url"]) return IO.NodeOutput(returned_image) @@ -587,37 +510,22 @@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.01}""", - ), ) @classmethod async def 
execute(cls, image: torch.Tensor) -> IO.NodeOutput: image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read() - files = { - "image": image_binary - } - - response_api = await sync_op( - cls, - ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"), - response_model=StabilityStableUltraResponse, - files=files, - content_type="multipart/form-data", - ) - - if response_api.finish_reason != "SUCCESS": - raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.") - - image_data = base64.b64decode(response_api.image) - returned_image = bytesio_to_image_tensor(BytesIO(image_data)) + from comfy_api_nodes.util.upload_helpers import upload_image_to_fal + image_url = await upload_image_to_fal(image, "image/png") + result = await fal_run(cls, FAL_SD35_MEDIUM, { + "image_url": image_url, + }) # TODO: verify fal.ai field names; use correct fal model for upscale fast + from comfy_api_nodes.util import download_url_to_image_tensor + returned_image = await download_url_to_image_tensor(result["images"][0]["url"]) return IO.NodeOutput(returned_image) @@ -674,30 +582,27 @@ def define_schema(cls): IO.Audio.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.2}""", - ), ) @classmethod async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput: validate_string(prompt, max_length=10000) - payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps) - response_api = await sync_op( - cls, - ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"), - response_model=StabilityAudioResponse, - data=payload, - content_type="multipart/form-data", - ) - if not response_api.audio: - raise ValueError("No audio file was received in response.") - return 
IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) + result = await fal_run(cls, FAL_SD35_MEDIUM, { + "prompt": prompt, + "model": model, + "duration": duration, + "seed": seed, + "steps": steps, + }) # TODO: verify fal.ai field names; use correct fal model for text-to-audio + audio_url = result["audio"]["url"] + from comfy_api_nodes.util import download_url_as_bytesio + from io import BytesIO as _BytesIO + audio_bio = _BytesIO() + await download_url_as_bytesio(audio_url, audio_bio) + return IO.NodeOutput(audio_bytes_to_audio_input(audio_bio.getvalue())) class StabilityAudioToAudio(IO.ComfyNode): @@ -762,14 +667,9 @@ def define_schema(cls): IO.Audio.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.2}""", - ), ) @classmethod @@ -778,20 +678,25 @@ async def execute( ) -> IO.NodeOutput: validate_string(prompt, max_length=10000) validate_audio_duration(audio, 6, 190) - payload = StabilityAudioToAudioRequest( - prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength - ) - response_api = await sync_op( - cls, - ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"), - response_model=StabilityAudioResponse, - data=payload, - content_type="multipart/form-data", - files={"audio": audio_input_to_mp3(audio)}, - ) - if not response_api.audio: - raise ValueError("No audio file was received in response.") - return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) + from comfy_api_nodes.util.upload_helpers import upload_file_to_fal + audio_bytes = audio_input_to_mp3(audio) + from io import BytesIO as _BytesIO + audio_bio = _BytesIO(audio_bytes) if isinstance(audio_bytes, bytes) else audio_bytes + audio_url = await upload_file_to_fal(audio_bio, "audio/mpeg") + result = await fal_run(cls, FAL_SD35_MEDIUM, { + 
"prompt": prompt, + "model": model, + "duration": duration, + "seed": seed, + "steps": steps, + "strength": strength, + "audio_url": audio_url, + }) # TODO: verify fal.ai field names; use correct fal model for audio-to-audio + result_audio_url = result["audio"]["url"] + result_bio = _BytesIO() + from comfy_api_nodes.util import download_url_as_bytesio + await download_url_as_bytesio(result_audio_url, result_bio) + return IO.NodeOutput(audio_bytes_to_audio_input(result_bio.getvalue())) class StabilityAudioInpaint(IO.ComfyNode): @@ -864,14 +769,9 @@ def define_schema(cls): IO.Audio.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.2}""", - ), ) @classmethod @@ -891,26 +791,26 @@ async def execute( raise ValueError(f"Value of mask_end({mask_end}) should be greater then mask_start({mask_start})") validate_audio_duration(audio, 6, 190) - payload = StabilityAudioInpaintRequest( - prompt=prompt, - model=model, - duration=duration, - seed=seed, - steps=steps, - mask_start=mask_start, - mask_end=mask_end, - ) - response_api = await sync_op( - cls, - endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"), - response_model=StabilityAudioResponse, - data=payload, - content_type="multipart/form-data", - files={"audio": audio_input_to_mp3(audio)}, - ) - if not response_api.audio: - raise ValueError("No audio file was received in response.") - return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) + from comfy_api_nodes.util.upload_helpers import upload_file_to_fal + audio_bytes = audio_input_to_mp3(audio) + from io import BytesIO as _BytesIO + audio_bio = _BytesIO(audio_bytes) if isinstance(audio_bytes, bytes) else audio_bytes + audio_url = await upload_file_to_fal(audio_bio, "audio/mpeg") + result = await fal_run(cls, FAL_SD35_MEDIUM, { + "prompt": prompt, + "model": 
model, + "duration": duration, + "seed": seed, + "steps": steps, + "mask_start": mask_start, + "mask_end": mask_end, + "audio_url": audio_url, + }) # TODO: verify fal.ai field names; use correct fal model for audio inpaint + result_audio_url = result["audio"]["url"] + result_bio = _BytesIO() + from comfy_api_nodes.util import download_url_as_bytesio + await download_url_as_bytesio(result_audio_url, result_bio) + return IO.NodeOutput(audio_bytes_to_audio_input(result_bio.getvalue())) class StabilityExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py deleted file mode 100644 index 6b61bd4b221b..000000000000 --- a/comfy_api_nodes/nodes_topaz.py +++ /dev/null @@ -1,470 +0,0 @@ -import builtins -from io import BytesIO - -import aiohttp -from typing_extensions import override - -from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis.topaz import ( - CreateVideoRequest, - CreateVideoRequestSource, - CreateVideoResponse, - ImageAsyncTaskResponse, - ImageDownloadResponse, - ImageEnhanceRequest, - ImageStatusResponse, - OutputInformationVideo, - Resolution, - VideoAcceptResponse, - VideoCompleteUploadRequest, - VideoCompleteUploadRequestPart, - VideoCompleteUploadResponse, - VideoEnhancementFilter, - VideoFrameInterpolationFilter, - VideoStatusResponse, -) -from comfy_api_nodes.util import ( - ApiEndpoint, - download_url_to_image_tensor, - download_url_to_video_output, - get_fs_object_size, - get_number_of_images, - poll_op, - sync_op, - upload_images_to_comfyapi, - validate_container_format_is_mp4, -) - -UPSCALER_MODELS_MAP = { - "Starlight (Astra) Fast": "slf-1", - "Starlight (Astra) Creative": "slc-1", -} - - -class TopazImageEnhance(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="TopazImageEnhance", - display_name="Topaz Image Enhance", - category="api node/image/Topaz", - description="Industry-standard upscaling and image enhancement.", - inputs=[ - 
IO.Combo.Input("model", options=["Reimagine"]), - IO.Image.Input("image"), - IO.String.Input( - "prompt", - multiline=True, - default="", - tooltip="Optional text prompt for creative upscaling guidance.", - optional=True, - ), - IO.Combo.Input( - "subject_detection", - options=["All", "Foreground", "Background"], - optional=True, - advanced=True, - ), - IO.Boolean.Input( - "face_enhancement", - default=True, - optional=True, - tooltip="Enhance faces (if present) during processing.", - advanced=True, - ), - IO.Float.Input( - "face_enhancement_creativity", - default=0.0, - min=0.0, - max=1.0, - step=0.01, - display_mode=IO.NumberDisplay.number, - optional=True, - tooltip="Set the creativity level for face enhancement.", - advanced=True, - ), - IO.Float.Input( - "face_enhancement_strength", - default=1.0, - min=0.0, - max=1.0, - step=0.01, - display_mode=IO.NumberDisplay.number, - optional=True, - tooltip="Controls how sharp enhanced faces are relative to the background.", - advanced=True, - ), - IO.Boolean.Input( - "crop_to_fill", - default=False, - optional=True, - tooltip="By default, the image is letterboxed when the output aspect ratio differs. 
" - "Enable to crop the image to fill the output dimensions.", - advanced=True, - ), - IO.Int.Input( - "output_width", - default=0, - min=0, - max=32000, - step=1, - display_mode=IO.NumberDisplay.number, - optional=True, - tooltip="Zero value means to calculate automatically (usually it will be original size or output_height if specified).", - advanced=True, - ), - IO.Int.Input( - "output_height", - default=0, - min=0, - max=32000, - step=1, - display_mode=IO.NumberDisplay.number, - optional=True, - tooltip="Zero value means to output in the same height as original or output width.", - advanced=True, - ), - IO.Int.Input( - "creativity", - default=3, - min=1, - max=9, - step=1, - display_mode=IO.NumberDisplay.slider, - optional=True, - ), - IO.Boolean.Input( - "face_preservation", - default=True, - optional=True, - tooltip="Preserve subjects' facial identity.", - advanced=True, - ), - IO.Boolean.Input( - "color_preservation", - default=True, - optional=True, - tooltip="Preserve the original colors.", - advanced=True, - ), - ], - outputs=[ - IO.Image.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - ) - - @classmethod - async def execute( - cls, - model: str, - image: Input.Image, - prompt: str = "", - subject_detection: str = "All", - face_enhancement: bool = True, - face_enhancement_creativity: float = 1.0, - face_enhancement_strength: float = 0.8, - crop_to_fill: bool = False, - output_width: int = 0, - output_height: int = 0, - creativity: int = 3, - face_preservation: bool = True, - color_preservation: bool = True, - ) -> IO.NodeOutput: - if get_number_of_images(image) != 1: - raise ValueError("Only one input image is supported.") - download_url = await upload_images_to_comfyapi( - cls, image, max_images=1, mime_type="image/png", total_pixels=4096 * 4096 - ) - initial_response = await sync_op( - cls, - ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", 
method="POST"), - response_model=ImageAsyncTaskResponse, - data=ImageEnhanceRequest( - model=model, - prompt=prompt, - subject_detection=subject_detection, - face_enhancement=face_enhancement, - face_enhancement_creativity=face_enhancement_creativity, - face_enhancement_strength=face_enhancement_strength, - crop_to_fill=crop_to_fill, - output_width=output_width if output_width else None, - output_height=output_height if output_height else None, - creativity=creativity, - face_preservation=str(face_preservation).lower(), - color_preservation=str(color_preservation).lower(), - source_url=download_url[0], - output_format="png", - ), - content_type="multipart/form-data", - ) - - await poll_op( - cls, - poll_endpoint=ApiEndpoint(path=f"/proxy/topaz/image/v1/status/{initial_response.process_id}"), - response_model=ImageStatusResponse, - status_extractor=lambda x: x.status, - progress_extractor=lambda x: getattr(x, "progress", 0), - price_extractor=lambda x: x.credits * 0.08, - poll_interval=8.0, - estimated_duration=60, - ) - - results = await sync_op( - cls, - ApiEndpoint(path=f"/proxy/topaz/image/v1/download/{initial_response.process_id}"), - response_model=ImageDownloadResponse, - monitor_progress=False, - ) - return IO.NodeOutput(await download_url_to_image_tensor(results.download_url)) - - -class TopazVideoEnhance(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="TopazVideoEnhance", - display_name="Topaz Video Enhance", - category="api node/video/Topaz", - description="Breathe new life into video with powerful upscaling and recovery technology.", - inputs=[ - IO.Video.Input("video"), - IO.Boolean.Input("upscaler_enabled", default=True), - IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())), - IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]), - IO.Combo.Input( - "upscaler_creativity", - options=["low", "middle", "high"], - default="low", - tooltip="Creativity level (applies 
only to Starlight (Astra) Creative).", - optional=True, - advanced=True, - ), - IO.Boolean.Input("interpolation_enabled", default=False, optional=True), - IO.Combo.Input("interpolation_model", options=["apo-8"], default="apo-8", optional=True, advanced=True), - IO.Int.Input( - "interpolation_slowmo", - default=1, - min=1, - max=16, - display_mode=IO.NumberDisplay.number, - tooltip="Slow-motion factor applied to the input video. " - "For example, 2 makes the output twice as slow and doubles the duration.", - optional=True, - advanced=True, - ), - IO.Int.Input( - "interpolation_frame_rate", - default=60, - min=15, - max=240, - display_mode=IO.NumberDisplay.number, - tooltip="Output frame rate.", - optional=True, - ), - IO.Boolean.Input( - "interpolation_duplicate", - default=False, - tooltip="Analyze the input for duplicate frames and remove them.", - optional=True, - advanced=True, - ), - IO.Float.Input( - "interpolation_duplicate_threshold", - default=0.01, - min=0.001, - max=0.1, - step=0.001, - display_mode=IO.NumberDisplay.number, - tooltip="Detection sensitivity for duplicate frames.", - optional=True, - advanced=True, - ), - IO.Combo.Input( - "dynamic_compression_level", - options=["Low", "Mid", "High"], - default="Low", - tooltip="CQP level.", - optional=True, - advanced=True, - ), - ], - outputs=[ - IO.Video.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - ) - - @classmethod - async def execute( - cls, - video: Input.Video, - upscaler_enabled: bool, - upscaler_model: str, - upscaler_resolution: str, - upscaler_creativity: str = "low", - interpolation_enabled: bool = False, - interpolation_model: str = "apo-8", - interpolation_slowmo: int = 1, - interpolation_frame_rate: int = 60, - interpolation_duplicate: bool = False, - interpolation_duplicate_threshold: float = 0.01, - dynamic_compression_level: str = "Low", - ) -> IO.NodeOutput: - if upscaler_enabled is False and 
interpolation_enabled is False: - raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.") - validate_container_format_is_mp4(video) - src_width, src_height = video.get_dimensions() - src_frame_rate = int(video.get_frame_rate()) - duration_sec = video.get_duration() - src_video_stream = video.get_stream_source() - target_width = src_width - target_height = src_height - target_frame_rate = src_frame_rate - filters = [] - if upscaler_enabled: - if "1080p" in upscaler_resolution: - target_pixel_p = 1080 - max_long_side = 1920 - else: - target_pixel_p = 2160 - max_long_side = 3840 - ar = src_width / src_height - if src_width >= src_height: - # Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width - target_height = target_pixel_p - target_width = int(target_height * ar) - # Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs) - if target_width > max_long_side: - target_width = max_long_side - target_height = int(target_width / ar) - else: - # Portrait; Attempt to set width to target (e.g., 2160), calculate height - target_width = target_pixel_p - target_height = int(target_width / ar) - # Check if height exceeds standard bounds - if target_height > max_long_side: - target_height = max_long_side - target_width = int(target_height * ar) - if target_width % 2 != 0: - target_width += 1 - if target_height % 2 != 0: - target_height += 1 - filters.append( - VideoEnhancementFilter( - model=UPSCALER_MODELS_MAP[upscaler_model], - creativity=(upscaler_creativity if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None), - isOptimizedMode=(True if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None), - ), - ) - if interpolation_enabled: - target_frame_rate = interpolation_frame_rate - filters.append( - VideoFrameInterpolationFilter( - model=interpolation_model, - slowmo=interpolation_slowmo, - fps=interpolation_frame_rate, - duplicate=interpolation_duplicate, - 
duplicate_threshold=interpolation_duplicate_threshold, - ), - ) - initial_res = await sync_op( - cls, - ApiEndpoint(path="/proxy/topaz/video/", method="POST"), - response_model=CreateVideoResponse, - data=CreateVideoRequest( - source=CreateVideoRequestSource( - container="mp4", - size=get_fs_object_size(src_video_stream), - duration=int(duration_sec), - frameCount=video.get_frame_count(), - frameRate=src_frame_rate, - resolution=Resolution(width=src_width, height=src_height), - ), - filters=filters, - output=OutputInformationVideo( - resolution=Resolution(width=target_width, height=target_height), - frameRate=target_frame_rate, - audioCodec="AAC", - audioTransfer="Copy", - dynamicCompressionLevel=dynamic_compression_level, - ), - ), - wait_label="Creating task", - final_label_on_success="Task created", - ) - upload_res = await sync_op( - cls, - ApiEndpoint( - path=f"/proxy/topaz/video/{initial_res.requestId}/accept", - method="PATCH", - ), - response_model=VideoAcceptResponse, - wait_label="Preparing upload", - final_label_on_success="Upload started", - ) - if len(upload_res.urls) > 1: - raise NotImplementedError( - "Large files are not currently supported. Please open an issue in the ComfyUI repository." 
- ) - async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session: - if isinstance(src_video_stream, BytesIO): - src_video_stream.seek(0) - async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res: - upload_etag = res.headers["Etag"] - else: - with builtins.open(src_video_stream, "rb") as video_file: - async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res: - upload_etag = res.headers["Etag"] - await sync_op( - cls, - ApiEndpoint( - path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload", - method="PATCH", - ), - response_model=VideoCompleteUploadResponse, - data=VideoCompleteUploadRequest( - uploadResults=[ - VideoCompleteUploadRequestPart( - partNum=1, - eTag=upload_etag, - ), - ], - ), - wait_label="Finalizing upload", - final_label_on_success="Upload completed", - ) - final_response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"), - response_model=VideoStatusResponse, - status_extractor=lambda x: x.status, - progress_extractor=lambda x: getattr(x, "progress", 0), - price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None), - poll_interval=10.0, - max_poll_attempts=320, - ) - return IO.NodeOutput(await download_url_to_video_output(final_response.download.url)) - - -class TopazExtension(ComfyExtension): - @override - async def get_node_list(self) -> list[type[IO.ComfyNode]]: - return [ - TopazImageEnhance, - TopazVideoEnhance, - ] - - -async def comfy_entrypoint() -> TopazExtension: - return TopazExtension() diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py deleted file mode 100644 index 9f4298dce976..000000000000 --- a/comfy_api_nodes/nodes_tripo.py +++ /dev/null @@ -1,897 +0,0 @@ -from typing_extensions import override - -from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis.tripo import ( - 
TripoAnimateRetargetRequest, - TripoAnimateRigRequest, - TripoConvertModelRequest, - TripoFileEmptyReference, - TripoFileReference, - TripoImageToModelRequest, - TripoModelVersion, - TripoMultiviewToModelRequest, - TripoOrientation, - TripoRefineModelRequest, - TripoStyle, - TripoTaskResponse, - TripoTaskStatus, - TripoTaskType, - TripoTextToModelRequest, - TripoTextureModelRequest, - TripoUrlReference, -) -from comfy_api_nodes.util import ( - ApiEndpoint, - download_url_to_file_3d, - poll_op, - sync_op, - upload_images_to_comfyapi, -) - - -def get_model_url_from_response(response: TripoTaskResponse) -> str: - if response.data is not None: - for key in ["pbr_model", "model", "base_model"]: - if getattr(response.data.output, key, None) is not None: - return getattr(response.data.output, key) - raise RuntimeError(f"Failed to get model url from response: {response}") - - -async def poll_until_finished( - node_cls: type[IO.ComfyNode], - response: TripoTaskResponse, - average_duration: int | None = None, -) -> IO.NodeOutput: - """Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response.""" - if response.code != 0: - raise RuntimeError(f"Failed to generate mesh: {response.error}") - task_id = response.data.task_id - response_poll = await poll_op( - node_cls, - poll_endpoint=ApiEndpoint(path=f"/proxy/tripo/v2/openapi/task/{task_id}"), - response_model=TripoTaskResponse, - completed_statuses=[TripoTaskStatus.SUCCESS], - failed_statuses=[ - TripoTaskStatus.FAILED, - TripoTaskStatus.CANCELLED, - TripoTaskStatus.UNKNOWN, - TripoTaskStatus.BANNED, - TripoTaskStatus.EXPIRED, - ], - status_extractor=lambda x: x.data.status, - progress_extractor=lambda x: x.data.progress, - estimated_duration=average_duration, - ) - if response_poll.data.status == TripoTaskStatus.SUCCESS: - url = get_model_url_from_response(response_poll) - file_glb = await download_url_to_file_3d(url, "glb", task_id=task_id) - return IO.NodeOutput(f"{task_id}.glb", task_id, 
file_glb) - raise RuntimeError(f"Failed to generate mesh: {response_poll}") - - -class TripoTextToModelNode(IO.ComfyNode): - """ - Generates 3D models synchronously based on a text prompt using Tripo's API. - """ - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="TripoTextToModelNode", - display_name="Tripo: Text to Model", - category="api node/3d/Tripo", - inputs=[ - IO.String.Input("prompt", multiline=True), - IO.String.Input("negative_prompt", multiline=True, optional=True), - IO.Combo.Input( - "model_version", options=TripoModelVersion, default=TripoModelVersion.v2_5_20250123, optional=True - ), - IO.Combo.Input("style", options=TripoStyle, default="None", optional=True), - IO.Boolean.Input("texture", default=True, optional=True), - IO.Boolean.Input("pbr", default=True, optional=True), - IO.Int.Input("image_seed", default=42, optional=True, advanced=True), - IO.Int.Input("model_seed", default=42, optional=True, advanced=True), - IO.Int.Input("texture_seed", default=42, optional=True, advanced=True), - IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), - IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True, advanced=True), - IO.Boolean.Input("quad", default=False, optional=True, advanced=True), - IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), - ], - outputs=[ - IO.String.Output(display_name="model_file"), # for backward compatibility only - IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), - IO.File3DGLB.Output(display_name="GLB"), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - is_output_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends( - widgets=[ - "model_version", - "style", - "texture", - "pbr", - "quad", - "texture_quality", - "geometry_quality", - ], - 
), - expr=""" - ( - $isV14 := $contains(widgets.model_version,"v1.4"); - $style := widgets.style; - $hasStyle := ($style != "" and $style != "none"); - $withTexture := widgets.texture or widgets.pbr; - $isHdTexture := (widgets.texture_quality = "detailed"); - $isDetailedGeometry := (widgets.geometry_quality = "detailed"); - $baseCredits := - $isV14 ? 20 : ($withTexture ? 20 : 10); - $credits := - $baseCredits - + ($hasStyle ? 5 : 0) - + (widgets.quad ? 5 : 0) - + ($isHdTexture ? 10 : 0) - + ($isDetailedGeometry ? 20 : 0); - {"type":"usd","usd": $round($credits * 0.01, 2)} - ) - """, - ), - ) - - @classmethod - async def execute( - cls, - prompt: str, - negative_prompt: str | None = None, - model_version=None, - style: str | None = None, - texture: bool | None = None, - pbr: bool | None = None, - image_seed: int | None = None, - model_seed: int | None = None, - texture_seed: int | None = None, - texture_quality: str | None = None, - geometry_quality: str | None = None, - face_limit: int | None = None, - quad: bool | None = None, - ) -> IO.NodeOutput: - style_enum = None if style == "None" else style - if not prompt: - raise RuntimeError("Prompt is required") - response = await sync_op( - cls, - endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), - response_model=TripoTaskResponse, - data=TripoTextToModelRequest( - type=TripoTaskType.TEXT_TO_MODEL, - prompt=prompt, - negative_prompt=negative_prompt if negative_prompt else None, - model_version=model_version, - style=style_enum, - texture=texture, - pbr=pbr, - image_seed=image_seed, - model_seed=model_seed, - texture_seed=texture_seed, - texture_quality=texture_quality, - face_limit=face_limit if face_limit != -1 else None, - geometry_quality=geometry_quality, - auto_size=True, - quad=quad, - ), - ) - return await poll_until_finished(cls, response, average_duration=80) - - -class TripoImageToModelNode(IO.ComfyNode): - """ - Generates 3D models synchronously based on a single image using Tripo's 
API. - """ - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="TripoImageToModelNode", - display_name="Tripo: Image to Model", - category="api node/3d/Tripo", - inputs=[ - IO.Image.Input("image"), - IO.Combo.Input( - "model_version", - options=TripoModelVersion, - tooltip="The model version to use for generation", - optional=True, - ), - IO.Combo.Input("style", options=TripoStyle, default="None", optional=True), - IO.Boolean.Input("texture", default=True, optional=True), - IO.Boolean.Input("pbr", default=True, optional=True), - IO.Int.Input("model_seed", default=42, optional=True, advanced=True), - IO.Combo.Input( - "orientation", options=TripoOrientation, default=TripoOrientation.DEFAULT, optional=True, advanced=True - ), - IO.Int.Input("texture_seed", default=42, optional=True, advanced=True), - IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), - IO.Combo.Input( - "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True - ), - IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True, advanced=True), - IO.Boolean.Input("quad", default=False, optional=True, advanced=True), - IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), - ], - outputs=[ - IO.String.Output(display_name="model_file"), # for backward compatibility only - IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), - IO.File3DGLB.Output(display_name="GLB"), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - is_output_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends( - widgets=[ - "model_version", - "style", - "texture", - "pbr", - "quad", - "texture_quality", - "geometry_quality", - ], - ), - expr=""" - ( - $isV14 := $contains(widgets.model_version,"v1.4"); - 
$style := widgets.style; - $hasStyle := ($style != "" and $style != "none"); - $withTexture := widgets.texture or widgets.pbr; - $isHdTexture := (widgets.texture_quality = "detailed"); - $isDetailedGeometry := (widgets.geometry_quality = "detailed"); - $baseCredits := - $isV14 ? 30 : ($withTexture ? 30 : 20); - $credits := - $baseCredits - + ($hasStyle ? 5 : 0) - + (widgets.quad ? 5 : 0) - + ($isHdTexture ? 10 : 0) - + ($isDetailedGeometry ? 20 : 0); - {"type":"usd","usd": $round($credits * 0.01, 2)} - ) - """, - ), - ) - - @classmethod - async def execute( - cls, - image: Input.Image, - model_version: str | None = None, - style: str | None = None, - texture: bool | None = None, - pbr: bool | None = None, - model_seed: int | None = None, - orientation=None, - texture_seed: int | None = None, - texture_quality: str | None = None, - geometry_quality: str | None = None, - texture_alignment: str | None = None, - face_limit: int | None = None, - quad: bool | None = None, - ) -> IO.NodeOutput: - style_enum = None if style == "None" else style - if image is None: - raise RuntimeError("Image is required") - tripo_file = TripoFileReference( - root=TripoUrlReference( - url=(await upload_images_to_comfyapi(cls, image, max_images=1))[0], - type="jpeg", - ) - ) - response = await sync_op( - cls, - endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), - response_model=TripoTaskResponse, - data=TripoImageToModelRequest( - type=TripoTaskType.IMAGE_TO_MODEL, - file=tripo_file, - model_version=model_version, - style=style_enum, - texture=texture, - pbr=pbr, - model_seed=model_seed, - orientation=orientation, - geometry_quality=geometry_quality, - texture_alignment=texture_alignment, - texture_seed=texture_seed, - texture_quality=texture_quality, - face_limit=face_limit if face_limit != -1 else None, - auto_size=True, - quad=quad, - ), - ) - return await poll_until_finished(cls, response, average_duration=80) - - -class TripoMultiviewToModelNode(IO.ComfyNode): - 
""" - Generates 3D models synchronously based on up to four images (front, left, back, right) using Tripo's API. - """ - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="TripoMultiviewToModelNode", - display_name="Tripo: Multiview to Model", - category="api node/3d/Tripo", - inputs=[ - IO.Image.Input("image"), - IO.Image.Input("image_left", optional=True), - IO.Image.Input("image_back", optional=True), - IO.Image.Input("image_right", optional=True), - IO.Combo.Input( - "model_version", - options=TripoModelVersion, - optional=True, - tooltip="The model version to use for generation", - ), - IO.Combo.Input( - "orientation", - options=TripoOrientation, - default=TripoOrientation.DEFAULT, - optional=True, - advanced=True, - ), - IO.Boolean.Input("texture", default=True, optional=True), - IO.Boolean.Input("pbr", default=True, optional=True), - IO.Int.Input("model_seed", default=42, optional=True, advanced=True), - IO.Int.Input("texture_seed", default=42, optional=True, advanced=True), - IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), - IO.Combo.Input( - "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True - ), - IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True, advanced=True), - IO.Boolean.Input("quad", default=False, optional=True, advanced=True), - IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), - ], - outputs=[ - IO.String.Output(display_name="model_file"), # for backward compatibility only - IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), - IO.File3DGLB.Output(display_name="GLB"), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - is_output_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends( - 
widgets=[ - "model_version", - "texture", - "pbr", - "quad", - "texture_quality", - "geometry_quality", - ], - ), - expr=""" - ( - $isV14 := $contains(widgets.model_version,"v1.4"); - $withTexture := widgets.texture or widgets.pbr; - $isHdTexture := (widgets.texture_quality = "detailed"); - $isDetailedGeometry := (widgets.geometry_quality = "detailed"); - $baseCredits := - $isV14 ? 30 : ($withTexture ? 30 : 20); - $credits := - $baseCredits - + (widgets.quad ? 5 : 0) - + ($isHdTexture ? 10 : 0) - + ($isDetailedGeometry ? 20 : 0); - {"type":"usd","usd": $round($credits * 0.01, 2)} - ) - """, - ), - ) - - @classmethod - async def execute( - cls, - image: Input.Image, - image_left: Input.Image | None = None, - image_back: Input.Image | None = None, - image_right: Input.Image | None = None, - model_version: str | None = None, - orientation: str | None = None, - texture: bool | None = None, - pbr: bool | None = None, - model_seed: int | None = None, - texture_seed: int | None = None, - texture_quality: str | None = None, - geometry_quality: str | None = None, - texture_alignment: str | None = None, - face_limit: int | None = None, - quad: bool | None = None, - ) -> IO.NodeOutput: - if image is None: - raise RuntimeError("front image for multiview is required") - images = [] - image_dict = {"image": image, "image_left": image_left, "image_back": image_back, "image_right": image_right} - if image_left is None and image_back is None and image_right is None: - raise RuntimeError("At least one of left, back, or right image must be provided for multiview") - for image_name in ["image", "image_left", "image_back", "image_right"]: - image_ = image_dict[image_name] - if image_ is not None: - images.append( - TripoFileReference( - root=TripoUrlReference( - url=(await upload_images_to_comfyapi(cls, image_, max_images=1))[0], type="jpeg" - ) - ) - ) - else: - images.append(TripoFileEmptyReference()) - response = await sync_op( - cls, - 
ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), - response_model=TripoTaskResponse, - data=TripoMultiviewToModelRequest( - type=TripoTaskType.MULTIVIEW_TO_MODEL, - files=images, - model_version=model_version, - orientation=orientation, - texture=texture, - pbr=pbr, - model_seed=model_seed, - texture_seed=texture_seed, - texture_quality=texture_quality, - geometry_quality=geometry_quality, - texture_alignment=texture_alignment, - face_limit=face_limit if face_limit != -1 else None, - quad=quad, - ), - ) - return await poll_until_finished(cls, response, average_duration=80) - - -class TripoTextureNode(IO.ComfyNode): - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="TripoTextureNode", - display_name="Tripo: Texture model", - category="api node/3d/Tripo", - inputs=[ - IO.Custom("MODEL_TASK_ID").Input("model_task_id"), - IO.Boolean.Input("texture", default=True, optional=True), - IO.Boolean.Input("pbr", default=True, optional=True), - IO.Int.Input("texture_seed", default=42, optional=True, advanced=True), - IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), - IO.Combo.Input( - "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True - ), - ], - outputs=[ - IO.String.Output(display_name="model_file"), # for backward compatibility only - IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), - IO.File3DGLB.Output(display_name="GLB"), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - is_output_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["texture_quality"]), - expr=""" - ( - $tq := widgets.texture_quality; - {"type":"usd","usd": ($contains($tq,"detailed") ? 
0.2 : 0.1)} - ) - """, - ), - ) - - @classmethod - async def execute( - cls, - model_task_id, - texture: bool | None = None, - pbr: bool | None = None, - texture_seed: int | None = None, - texture_quality: str | None = None, - texture_alignment: str | None = None, - ) -> IO.NodeOutput: - response = await sync_op( - cls, - endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), - response_model=TripoTaskResponse, - data=TripoTextureModelRequest( - original_model_task_id=model_task_id, - texture=texture, - pbr=pbr, - texture_seed=texture_seed, - texture_quality=texture_quality, - texture_alignment=texture_alignment, - ), - ) - return await poll_until_finished(cls, response, average_duration=80) - - -class TripoRefineNode(IO.ComfyNode): - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="TripoRefineNode", - display_name="Tripo: Refine Draft model", - category="api node/3d/Tripo", - description="Refine a draft model created by v1.4 Tripo models only.", - inputs=[ - IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"), - ], - outputs=[ - IO.String.Output(display_name="model_file"), # for backward compatibility only - IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), - IO.File3DGLB.Output(display_name="GLB"), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - is_output_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.3}""", - ), - ) - - @classmethod - async def execute(cls, model_task_id) -> IO.NodeOutput: - response = await sync_op( - cls, - endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), - response_model=TripoTaskResponse, - data=TripoRefineModelRequest(draft_model_task_id=model_task_id), - ) - return await poll_until_finished(cls, response, average_duration=240) - - -class TripoRigNode(IO.ComfyNode): - - @classmethod - def define_schema(cls): - return 
IO.Schema( - node_id="TripoRigNode", - display_name="Tripo: Rig model", - category="api node/3d/Tripo", - inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")], - outputs=[ - IO.String.Output(display_name="model_file"), # for backward compatibility only - IO.Custom("RIG_TASK_ID").Output(display_name="rig task_id"), - IO.File3DGLB.Output(display_name="GLB"), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - is_output_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.25}""", - ), - ) - - @classmethod - async def execute(cls, original_model_task_id) -> IO.NodeOutput: - response = await sync_op( - cls, - endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), - response_model=TripoTaskResponse, - data=TripoAnimateRigRequest(original_model_task_id=original_model_task_id, out_format="glb", spec="tripo"), - ) - return await poll_until_finished(cls, response, average_duration=180) - - -class TripoRetargetNode(IO.ComfyNode): - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="TripoRetargetNode", - display_name="Tripo: Retarget rigged model", - category="api node/3d/Tripo", - inputs=[ - IO.Custom("RIG_TASK_ID").Input("original_model_task_id"), - IO.Combo.Input( - "animation", - options=[ - "preset:idle", - "preset:walk", - "preset:run", - "preset:dive", - "preset:climb", - "preset:jump", - "preset:slash", - "preset:shoot", - "preset:hurt", - "preset:fall", - "preset:turn", - "preset:quadruped:walk", - "preset:hexapod:walk", - "preset:octopod:walk", - "preset:serpentine:march", - "preset:aquatic:march" - ], - ), - ], - outputs=[ - IO.String.Output(display_name="model_file"), # for backward compatibility only - IO.Custom("RETARGET_TASK_ID").Output(display_name="retarget task_id"), - IO.File3DGLB.Output(display_name="GLB"), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - 
IO.Hidden.unique_id, - ], - is_api_node=True, - is_output_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.1}""", - ), - ) - - @classmethod - async def execute(cls, original_model_task_id, animation: str) -> IO.NodeOutput: - response = await sync_op( - cls, - endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), - response_model=TripoTaskResponse, - data=TripoAnimateRetargetRequest( - original_model_task_id=original_model_task_id, - animation=animation, - out_format="glb", - bake_animation=True, - ), - ) - return await poll_until_finished(cls, response, average_duration=30) - - -class TripoConversionNode(IO.ComfyNode): - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="TripoConversionNode", - display_name="Tripo: Convert model", - category="api node/3d/Tripo", - inputs=[ - IO.Custom("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID").Input("original_model_task_id"), - IO.Combo.Input("format", options=["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"]), - IO.Boolean.Input("quad", default=False, optional=True, advanced=True), - IO.Int.Input( - "face_limit", - default=-1, - min=-1, - max=2000000, - optional=True, - advanced=True, - ), - IO.Int.Input( - "texture_size", - default=4096, - min=128, - max=4096, - optional=True, - advanced=True, - ), - IO.Combo.Input( - "texture_format", - options=["BMP", "DPX", "HDR", "JPEG", "OPEN_EXR", "PNG", "TARGA", "TIFF", "WEBP"], - default="JPEG", - optional=True, - advanced=True, - ), - IO.Boolean.Input("force_symmetry", default=False, optional=True, advanced=True), - IO.Boolean.Input("flatten_bottom", default=False, optional=True, advanced=True), - IO.Float.Input( - "flatten_bottom_threshold", - default=0.0, - min=0.0, - max=1.0, - optional=True, - advanced=True, - ), - IO.Boolean.Input("pivot_to_center_bottom", default=False, optional=True, advanced=True), - IO.Float.Input( - "scale_factor", - default=1.0, - min=0.0, - optional=True, - advanced=True, - ), - 
IO.Boolean.Input("with_animation", default=False, optional=True, advanced=True), - IO.Boolean.Input("pack_uv", default=False, optional=True, advanced=True), - IO.Boolean.Input("bake", default=False, optional=True, advanced=True), - IO.String.Input("part_names", default="", optional=True, advanced=True), # comma-separated list - IO.Combo.Input( - "fbx_preset", - options=["blender", "mixamo", "3dsmax"], - default="blender", - optional=True, - advanced=True, - ), - IO.Boolean.Input("export_vertex_colors", default=False, optional=True, advanced=True), - IO.Combo.Input( - "export_orientation", - options=["align_image", "default"], - default="default", - optional=True, - advanced=True, - ), - IO.Boolean.Input("animate_in_place", default=False, optional=True, advanced=True), - ], - outputs=[], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - is_output_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends( - widgets=[ - "quad", - "face_limit", - "texture_size", - "texture_format", - "force_symmetry", - "flatten_bottom", - "flatten_bottom_threshold", - "pivot_to_center_bottom", - "scale_factor", - "with_animation", - "pack_uv", - "bake", - "part_names", - "fbx_preset", - "export_vertex_colors", - "export_orientation", - "animate_in_place", - ], - ), - expr=""" - ( - $face := (widgets.face_limit != null) ? widgets.face_limit : -1; - $texSize := (widgets.texture_size != null) ? widgets.texture_size : 4096; - $flatThresh := (widgets.flatten_bottom_threshold != null) ? widgets.flatten_bottom_threshold : 0; - $scale := (widgets.scale_factor != null) ? widgets.scale_factor : 1; - $texFmt := (widgets.texture_format != "" ? widgets.texture_format : "jpeg"); - $part := widgets.part_names; - $fbx := (widgets.fbx_preset != "" ? widgets.fbx_preset : "blender"); - $orient := (widgets.export_orientation != "" ? 
widgets.export_orientation : "default"); - $advanced := - widgets.quad or - widgets.force_symmetry or - widgets.flatten_bottom or - widgets.pivot_to_center_bottom or - widgets.with_animation or - widgets.pack_uv or - widgets.bake or - widgets.export_vertex_colors or - widgets.animate_in_place or - ($face != -1) or - ($texSize != 4096) or - ($flatThresh != 0) or - ($scale != 1) or - ($texFmt != "jpeg") or - ($part != "") or - ($fbx != "blender") or - ($orient != "default"); - {"type":"usd","usd": ($advanced ? 0.1 : 0.05)} - ) - """, - ), - ) - - @classmethod - def validate_inputs(cls, input_types): - # The min and max of input1 and input2 are still validated because - # we didn't take `input1` or `input2` as arguments - if input_types["original_model_task_id"] not in ("MODEL_TASK_ID", "RIG_TASK_ID", "RETARGET_TASK_ID"): - return "original_model_task_id must be MODEL_TASK_ID, RIG_TASK_ID or RETARGET_TASK_ID type" - return True - - @classmethod - async def execute( - cls, - original_model_task_id, - format: str, - quad: bool, - force_symmetry: bool, - face_limit: int, - flatten_bottom: bool, - flatten_bottom_threshold: float, - texture_size: int, - texture_format: str, - pivot_to_center_bottom: bool, - scale_factor: float, - with_animation: bool, - pack_uv: bool, - bake: bool, - part_names: str, - fbx_preset: str, - export_vertex_colors: bool, - export_orientation: str, - animate_in_place: bool, - ) -> IO.NodeOutput: - if not original_model_task_id: - raise RuntimeError("original_model_task_id is required") - - # Parse part_names from comma-separated string to list - part_names_list = None - if part_names and part_names.strip(): - part_names_list = [name.strip() for name in part_names.split(',') if name.strip()] - - response = await sync_op( - cls, - endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), - response_model=TripoTaskResponse, - data=TripoConvertModelRequest( - original_model_task_id=original_model_task_id, - format=format, - quad=quad 
if quad else None, - force_symmetry=force_symmetry if force_symmetry else None, - face_limit=face_limit if face_limit != -1 else None, - flatten_bottom=flatten_bottom if flatten_bottom else None, - flatten_bottom_threshold=flatten_bottom_threshold if flatten_bottom_threshold != 0.0 else None, - texture_size=texture_size if texture_size != 4096 else None, - texture_format=texture_format if texture_format != "JPEG" else None, - pivot_to_center_bottom=pivot_to_center_bottom if pivot_to_center_bottom else None, - scale_factor=scale_factor if scale_factor != 1.0 else None, - with_animation=with_animation if with_animation else None, - pack_uv=pack_uv if pack_uv else None, - bake=bake if bake else None, - part_names=part_names_list, - fbx_preset=fbx_preset if fbx_preset != "blender" else None, - export_vertex_colors=export_vertex_colors if export_vertex_colors else None, - export_orientation=export_orientation if export_orientation != "default" else None, - animate_in_place=animate_in_place if animate_in_place else None, - ), - ) - return await poll_until_finished(cls, response, average_duration=30) - - -class TripoExtension(ComfyExtension): - @override - async def get_node_list(self) -> list[type[IO.ComfyNode]]: - return [ - TripoTextToModelNode, - TripoImageToModelNode, - TripoMultiviewToModelNode, - TripoTextureNode, - TripoRefineNode, - TripoRigNode, - TripoRetargetNode, - TripoConversionNode, - ] - - -async def comfy_entrypoint() -> TripoExtension: - return TripoExtension() diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 13fc1cc3682a..217e1d0fc651 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -20,6 +20,9 @@ sync_op, tensor_to_base64_string, ) +from comfy_api_nodes.util._helpers import get_google_auth_header + +GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/models" AVERAGE_DURATION_VIDEO_GEN = 32 MODELS_MAP = { @@ -119,15 +122,9 @@ def define_schema(cls): IO.Video.Output(), ], 
hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration_seconds"]), - expr="""{"type":"usd","usd": 0.5 * widgets.duration_seconds}""", - ), ) @classmethod @@ -179,7 +176,11 @@ async def execute( initial_response = await sync_op( cls, - ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"), + ApiEndpoint( + path=f"{GEMINI_BASE_URL}/{model}:predictLongRunning", + method="POST", + headers=get_google_auth_header(), + ), response_model=VeoGenVidResponse, data=VeoGenVidRequest( instances=instances, @@ -192,14 +193,18 @@ def status_extractor(response): # We'll check for errors after polling completes return "completed" if response.done else "pending" + # Poll the operation using the name from the submit response. + # Direct Google API uses GET /v1beta/{operation_name} instead of POST to a proxy. + operation_name = initial_response.name poll_response = await poll_op( cls, - ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"), + ApiEndpoint( + path=f"https://generativelanguage.googleapis.com/v1beta/{operation_name}", + method="GET", + headers=get_google_auth_header(), + ), response_model=VeoGenVidPollResponse, status_extractor=status_extractor, - data=VeoGenVidPollRequest( - operationName=initial_response.name, - ), poll_interval=5.0, estimated_duration=AVERAGE_DURATION_VIDEO_GEN, ) @@ -350,25 +355,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio"]), - expr=""" - ( - $m := widgets.model; - $a := widgets.generate_audio; - ($contains($m,"veo-3.0-fast-generate-001") or $contains($m,"veo-3.1-fast-generate")) - ? {"type":"usd","usd": ($a ? 
1.2 : 0.8)} - : ($contains($m,"veo-3.0-generate-001") or $contains($m,"veo-3.1-generate")) - ? {"type":"usd","usd": ($a ? 3.2 : 1.6)} - : {"type":"range_usd","min_usd":0.8,"max_usd":3.2} - ) - """, - ), ) @@ -437,35 +426,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration"]), - expr=""" - ( - $prices := { - "veo-3.1-fast-generate": { "audio": 0.15, "no_audio": 0.10 }, - "veo-3.1-generate": { "audio": 0.40, "no_audio": 0.20 } - }; - $m := widgets.model; - $ga := (widgets.generate_audio = "true"); - $seconds := widgets.duration; - $modelKey := - $contains($m, "veo-3.1-fast-generate") ? "veo-3.1-fast-generate" : - $contains($m, "veo-3.1-generate") ? "veo-3.1-generate" : - ""; - $audioKey := $ga ? "audio" : "no_audio"; - $modelPrices := $lookup($prices, $modelKey); - $pps := $lookup($modelPrices, $audioKey); - ($pps != null) - ? 
{"type":"usd","usd": $pps * $seconds} - : {"type":"range_usd","min_usd": 0.4, "max_usd": 3.2} - ) - """, - ), ) @classmethod @@ -485,7 +448,11 @@ async def execute( model = MODELS_MAP[model] initial_response = await sync_op( cls, - ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"), + ApiEndpoint( + path=f"{GEMINI_BASE_URL}/{model}:predictLongRunning", + method="POST", + headers=get_google_auth_header(), + ), response_model=VeoGenVidResponse, data=VeoGenVidRequest( instances=[ @@ -511,14 +478,16 @@ async def execute( ), ), ) + operation_name = initial_response.name poll_response = await poll_op( cls, - ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"), + ApiEndpoint( + path=f"https://generativelanguage.googleapis.com/v1beta/{operation_name}", + method="GET", + headers=get_google_auth_header(), + ), response_model=VeoGenVidPollResponse, status_extractor=lambda r: "completed" if r.done else "pending", - data=VeoGenVidPollRequest( - operationName=initial_response.name, - ), poll_interval=5.0, estimated_duration=AVERAGE_DURATION_VIDEO_GEN, ) diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index f04407eb583a..cb1e19787049 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -17,21 +17,19 @@ get_number_of_images, poll_op, sync_op, - upload_image_to_comfyapi, - upload_images_to_comfyapi, - upload_video_to_comfyapi, + upload_image_to_fal, + upload_images_to_fal, + upload_video_to_fal, validate_image_aspect_ratio, validate_image_dimensions, validate_images_aspect_ratio_closeness, validate_string, validate_video_duration, ) +from comfy_api_nodes.util._helpers import get_fal_auth_header +from comfy_api_nodes.util.client import fal_run -VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video" -VIDU_IMAGE_TO_VIDEO = "/proxy/vidu/img2video" -VIDU_REFERENCE_VIDEO = "/proxy/vidu/reference2video" -VIDU_START_END_VIDEO = "/proxy/vidu/start-end2video" -VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations" 
+FAL_VIDU_I2V = "fal-ai/vidu/q3-pro/image-to-video" async def execute_task( @@ -40,28 +38,17 @@ async def execute_task( payload: TaskCreationRequest | TaskExtendCreationRequest | TaskMultiFrameCreationRequest, max_poll_attempts: int = 320, ) -> list[TaskResult]: - task_creation_response = await sync_op( - cls, - endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"), - response_model=TaskCreationResponse, - data=payload, - ) - if task_creation_response.state == "failed": - raise RuntimeError(f"Vidu request failed. Code: {task_creation_response.code}") - response = await poll_op( - cls, - ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % task_creation_response.task_id), - response_model=TaskStatusResponse, - status_extractor=lambda r: r.state, - progress_extractor=lambda r: r.progress, - price_extractor=lambda r: r.credits * 0.005 if r.credits is not None else None, - max_poll_attempts=max_poll_attempts, - ) - if not response.creations: - raise RuntimeError( - f"Vidu request does not contain results. 
State: {response.state}, Error Code: {response.err_code}" - ) - return response.creations + # fal_run handles submit + poll + fetch + data = payload.model_dump(exclude_none=True) if hasattr(payload, 'model_dump') else payload + result = await fal_run(cls, FAL_VIDU_I2V, data) # TODO: verify fal.ai field names; use correct model per vidu_endpoint + # Parse fal.ai response to extract video URLs + videos = result.get("videos", []) + if not videos: + creations_data = result.get("creations", []) + if not creations_data: + raise RuntimeError("Vidu request does not contain results.") + return [TaskResult(**c) if isinstance(c, dict) else c for c in creations_data] + return [TaskResult(url=v.get("url", ""), id=v.get("id", "")) for v in videos] class ViduTextToVideoNode(IO.ComfyNode): @@ -126,14 +113,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.4}""", - ), ) @classmethod @@ -158,7 +140,7 @@ async def execute( resolution=resolution, movement_amplitude=movement_amplitude, ) - results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload) + results = await execute_task(cls, FAL_VIDU_I2V, payload) return IO.NodeOutput(await download_url_to_video_output(results[0].url)) @@ -224,14 +206,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.4}""", - ), ) @classmethod @@ -256,13 +233,12 @@ async def execute( resolution=resolution, movement_amplitude=movement_amplitude, ) - payload.images = await upload_images_to_comfyapi( - cls, + payload.images = await upload_images_to_fal( image, max_images=1, mime_type="image/png", ) - results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload) + results = await execute_task(cls, FAL_VIDU_I2V, 
payload) return IO.NodeOutput(await download_url_to_video_output(results[0].url)) @@ -332,14 +308,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.4}""", - ), ) @classmethod @@ -371,13 +342,12 @@ async def execute( resolution=resolution, movement_amplitude=movement_amplitude, ) - payload.images = await upload_images_to_comfyapi( - cls, + payload.images = await upload_images_to_fal( images, max_images=7, mime_type="image/png", ) - results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload) + results = await execute_task(cls, FAL_VIDU_I2V, payload) return IO.NodeOutput(await download_url_to_video_output(results[0].url)) @@ -446,14 +416,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.4}""", - ), ) @classmethod @@ -478,10 +443,10 @@ async def execute( movement_amplitude=movement_amplitude, ) payload.images = [ - (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0] + (await upload_images_to_fal(frame, max_images=1, mime_type="image/png"))[0] for frame in (first_frame, end_frame) ] - results = await execute_task(cls, VIDU_START_END_VIDEO, payload) + results = await execute_task(cls, FAL_VIDU_I2V, payload) return IO.NodeOutput(await download_url_to_video_output(results[0].url)) @@ -531,22 +496,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), - expr=""" - ( - $is1080 := widgets.resolution = "1080p"; - $base := $is1080 ? 0.1 : 0.075; - $perSec := $is1080 ? 
0.05 : 0.025; - {"type":"usd","usd": $base + $perSec * (widgets.duration - 1)} - ) - """, - ), ) @classmethod @@ -563,7 +515,7 @@ async def execute( validate_string(prompt, min_length=1, max_length=2000) results = await execute_task( cls, - VIDU_TEXT_TO_VIDEO, + FAL_VIDU_I2V, TaskCreationRequest( model=model, prompt=prompt, @@ -631,44 +583,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), - expr=""" - ( - $m := widgets.model; - $d := widgets.duration; - $is1080 := widgets.resolution = "1080p"; - $contains($m, "pro-fast") - ? ( - $base := $is1080 ? 0.08 : 0.04; - $perSec := $is1080 ? 0.02 : 0.01; - {"type":"usd","usd": $base + $perSec * ($d - 1)} - ) - : $contains($m, "pro") - ? ( - $base := $is1080 ? 0.275 : 0.075; - $perSec := $is1080 ? 0.075 : 0.05; - {"type":"usd","usd": $base + $perSec * ($d - 1)} - ) - : $contains($m, "turbo") - ? ( - $is1080 - ? {"type":"usd","usd": 0.175 + 0.05 * ($d - 1)} - : ( - $d <= 1 ? {"type":"usd","usd": 0.04} - : $d <= 2 ? 
{"type":"usd","usd": 0.05} - : {"type":"usd","usd": 0.05 + 0.05 * ($d - 2)} - ) - ) - : {"type":"usd","usd": 0.04} - ) - """, - ), ) @classmethod @@ -688,7 +605,7 @@ async def execute( validate_string(prompt, max_length=2000) results = await execute_task( cls, - VIDU_IMAGE_TO_VIDEO, + FAL_VIDU_I2V, TaskCreationRequest( model=model, prompt=prompt, @@ -696,8 +613,7 @@ async def execute( seed=seed, resolution=resolution, movement_amplitude=movement_amplitude, - images=await upload_images_to_comfyapi( - cls, + images=await upload_images_to_fal( image, max_images=1, mime_type="image/png", @@ -770,23 +686,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["audio", "duration", "resolution"]), - expr=""" - ( - $is1080 := widgets.resolution = "1080p"; - $base := $is1080 ? 0.375 : 0.125; - $perSec := $is1080 ? 0.05 : 0.025; - $audioCost := widgets.audio = true ? 
0.075 : 0; - {"type":"usd","usd": $base + $perSec * (widgets.duration - 1) + $audioCost} - ) - """, - ), ) @classmethod @@ -818,12 +720,10 @@ async def execute( subjects_param.append( SubjectReference( id=i, - images=await upload_images_to_comfyapi( - cls, + images=await upload_images_to_fal( subjects[i], max_images=3, mime_type="image/png", - wait_label=f"Uploading reference images for {i}", ), ), ) @@ -838,7 +738,7 @@ async def execute( movement_amplitude=movement_amplitude, subjects=subjects_param, ) - results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload) + results = await execute_task(cls, FAL_VIDU_I2V, payload) return IO.NodeOutput(await download_url_to_video_output(results[0].url)) @@ -889,43 +789,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), - expr=""" - ( - $m := widgets.model; - $d := widgets.duration; - $is1080 := widgets.resolution = "1080p"; - $contains($m, "pro-fast") - ? ( - $base := $is1080 ? 0.08 : 0.04; - $perSec := $is1080 ? 0.02 : 0.01; - {"type":"usd","usd": $base + $perSec * ($d - 1)} - ) - : $contains($m, "pro") - ? ( - $base := $is1080 ? 0.275 : 0.075; - $perSec := $is1080 ? 0.075 : 0.05; - {"type":"usd","usd": $base + $perSec * ($d - 1)} - ) - : $contains($m, "turbo") - ? ( - $is1080 - ? {"type":"usd","usd": 0.175 + 0.05 * ($d - 1)} - : ( - $d <= 2 ? 
{"type":"usd","usd": 0.05} - : {"type":"usd","usd": 0.05 + 0.05 * ($d - 2)} - ) - ) - : {"type":"usd","usd": 0.04} - ) - """, - ), ) @classmethod @@ -954,11 +820,11 @@ async def execute( resolution=resolution, movement_amplitude=movement_amplitude, images=[ - (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0] + (await upload_images_to_fal(frame, max_images=1, mime_type="image/png"))[0] for frame in (first_frame, end_frame) ], ) - results = await execute_task(cls, VIDU_START_END_VIDEO, payload) + results = await execute_task(cls, FAL_VIDU_I2V, payload) return IO.NodeOutput(await download_url_to_video_output(results[0].url)) @@ -1041,32 +907,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]), - expr=""" - ( - $m := widgets.model; - $d := $lookup(widgets, "model.duration"); - $res := $lookup(widgets, "model.resolution"); - $contains($m, "pro") - ? 
( - $base := $lookup({"720p": 0.15, "1080p": 0.3}, $res); - $perSec := $lookup({"720p": 0.05, "1080p": 0.075}, $res); - {"type":"usd","usd": $base + $perSec * ($d - 1)} - ) - : ( - $base := $lookup({"720p": 0.075, "1080p": 0.2}, $res); - $perSec := $lookup({"720p": 0.025, "1080p": 0.05}, $res); - {"type":"usd","usd": $base + $perSec * ($d - 1)} - ) - ) - """, - ), ) @classmethod @@ -1084,17 +927,17 @@ async def execute( if end_frame is not None: validate_image_aspect_ratio(end_frame, (1, 4), (4, 1)) validate_image_dimensions(end_frame, min_width=128, min_height=128) - image_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame") + image_url = await upload_image_to_fal(end_frame) results = await execute_task( cls, - "/proxy/vidu/extend", + FAL_VIDU_I2V, TaskExtendCreationRequest( model=model["model"], prompt=prompt, duration=model["duration"], seed=seed, resolution=model["resolution"], - video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading video"), + video_url=await upload_video_to_fal(video), images=[image_url] if image_url else None, ), max_poll_attempts=480, @@ -1176,57 +1019,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends( - widgets=[ - "model", - "resolution", - "frames", - "frames.duration1", - "frames.duration2", - "frames.duration3", - "frames.duration4", - "frames.duration5", - "frames.duration6", - "frames.duration7", - "frames.duration8", - "frames.duration9", - ] - ), - expr=""" - ( - $m := widgets.model; - $n := $number(widgets.frames); - $is1080 := widgets.resolution = "1080p"; - $d1 := $lookup(widgets, "frames.duration1"); - $d2 := $lookup(widgets, "frames.duration2"); - $d3 := $n >= 3 ? $lookup(widgets, "frames.duration3") : 0; - $d4 := $n >= 4 ? $lookup(widgets, "frames.duration4") : 0; - $d5 := $n >= 5 ? 
$lookup(widgets, "frames.duration5") : 0; - $d6 := $n >= 6 ? $lookup(widgets, "frames.duration6") : 0; - $d7 := $n >= 7 ? $lookup(widgets, "frames.duration7") : 0; - $d8 := $n >= 8 ? $lookup(widgets, "frames.duration8") : 0; - $d9 := $n >= 9 ? $lookup(widgets, "frames.duration9") : 0; - $totalDuration := $d1 + $d2 + $d3 + $d4 + $d5 + $d6 + $d7 + $d8 + $d9; - $contains($m, "pro") - ? ( - $base := $is1080 ? 0.3 : 0.15; - $perSec := $is1080 ? 0.075 : 0.05; - {"type":"usd","usd": $n * $base + $perSec * $totalDuration} - ) - : ( - $base := $is1080 ? 0.2 : 0.075; - $perSec := $is1080 ? 0.05 : 0.025; - {"type":"usd","usd": $n * $base + $perSec * $totalDuration} - ) - ) - """, - ), ) @classmethod @@ -1244,28 +1039,24 @@ async def execute( for i in range(1, frame_count + 1): validate_image_aspect_ratio(frames[f"end_image{i}"], (1, 4), (4, 1)) validate_string(frames[f"prompt{i}"], max_length=2000) - start_image_url = await upload_image_to_comfyapi( - cls, + start_image_url = await upload_image_to_fal( start_image, mime_type="image/png", - wait_label="Uploading start image", ) for i in range(1, frame_count + 1): image_settings.append( FrameSetting( prompt=frames[f"prompt{i}"], - key_image=await upload_image_to_comfyapi( - cls, + key_image=await upload_image_to_fal( frames[f"end_image{i}"], mime_type="image/png", - wait_label=f"Uploading end image({i})", ), duration=frames[f"duration{i}"], ) ) results = await execute_task( cls, - "/proxy/vidu/multiframe", + FAL_VIDU_I2V, TaskMultiFrameCreationRequest( model=model, seed=seed, @@ -1373,29 +1164,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]), - expr=""" - ( - $res := $lookup(widgets, "model.resolution"); - $d := $lookup(widgets, "model.duration"); - $contains(widgets.model, "turbo") - ? 
( - $rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res); - {"type":"usd","usd": $rate * $d} - ) - : ( - $rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res); - {"type":"usd","usd": $rate * $d} - ) - ) - """, - ), ) @classmethod @@ -1408,7 +1179,7 @@ async def execute( validate_string(prompt, min_length=1, max_length=2000) results = await execute_task( cls, - VIDU_TEXT_TO_VIDEO, + FAL_VIDU_I2V, TaskCreationRequest( model=model["model"], prompt=prompt, @@ -1513,29 +1284,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]), - expr=""" - ( - $res := $lookup(widgets, "model.resolution"); - $d := $lookup(widgets, "model.duration"); - $contains(widgets.model, "turbo") - ? ( - $rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res); - {"type":"usd","usd": $rate * $d} - ) - : ( - $rate := $lookup({"720p": 0.15, "1080p": 0.16, "2k": 0.2}, $res); - {"type":"usd","usd": $rate * $d} - ) - ) - """, - ), ) @classmethod @@ -1550,7 +1301,7 @@ async def execute( validate_string(prompt, max_length=2000) results = await execute_task( cls, - VIDU_IMAGE_TO_VIDEO, + FAL_VIDU_I2V, TaskCreationRequest( model=model["model"], prompt=prompt, @@ -1558,7 +1309,7 @@ async def execute( seed=seed, resolution=model["resolution"], audio=model["audio"], - images=[await upload_image_to_comfyapi(cls, image)], + images=[await upload_image_to_fal(image)], ), max_poll_attempts=720, ) @@ -1652,29 +1403,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]), - expr=""" - ( - $res := $lookup(widgets, "model.resolution"); - $d := $lookup(widgets, 
"model.duration"); - $contains(widgets.model, "turbo") - ? ( - $rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res); - {"type":"usd","usd": $rate * $d} - ) - : ( - $rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res); - {"type":"usd","usd": $rate * $d} - ) - ) - """, - ), ) @classmethod @@ -1696,11 +1427,11 @@ async def execute( resolution=model["resolution"], audio=model["audio"], images=[ - (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0] + (await upload_images_to_fal(frame, max_images=1, mime_type="image/png"))[0] for frame in (first_frame, end_frame) ], ) - results = await execute_task(cls, VIDU_START_END_VIDEO, payload) + results = await execute_task(cls, FAL_VIDU_I2V, payload) return IO.NodeOutput(await download_url_to_video_output(results[0].url)) diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index e2afe7f9cb33..d1d7c60d3661 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -13,10 +13,14 @@ poll_op, sync_op, tensor_to_base64_string, - upload_video_to_comfyapi, + upload_video_to_fal, validate_audio_duration, validate_video_duration, ) +from comfy_api_nodes.util._helpers import get_fal_auth_header +from comfy_api_nodes.util.client import fal_run + +FAL_WAN_I2V = "fal-ai/wan-pro/image-to-video" class Text2ImageInputField(BaseModel): @@ -241,14 +245,9 @@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.03}""", - ), ) @classmethod @@ -263,32 +262,17 @@ async def execute( prompt_extend: bool = True, watermark: bool = False, ): - initial_response = await sync_op( - cls, - ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/text2image/image-synthesis", method="POST"), - response_model=TaskCreationResponse, - data=Text2ImageTaskCreationRequest( - model=model, - input=Text2ImageInputField(prompt=prompt, 
negative_prompt=negative_prompt), - parameters=Txt2ImageParametersField( - size=f"{width}*{height}", - seed=seed, - prompt_extend=prompt_extend, - watermark=watermark, - ), - ), - ) - if not initial_response.output: - raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") - response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), - response_model=ImageTaskStatusResponse, - status_extractor=lambda x: x.output.task_status, - estimated_duration=9, - poll_interval=3, - ) - return IO.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url))) + result = await fal_run(cls, FAL_WAN_I2V, { + "model": model, + "input": {"prompt": prompt, "negative_prompt": negative_prompt}, + "parameters": { + "size": f"{width}*{height}", + "seed": seed, + "prompt_extend": prompt_extend, + "watermark": watermark, + }, + }, estimated_duration=9) # TODO: verify fal.ai field names; use correct fal model for wan text-to-image + return IO.NodeOutput(await download_url_to_image_tensor(result["images"][0]["url"])) class WanImageToImageApi(IO.ComfyNode): @@ -364,14 +348,9 @@ def define_schema(cls): IO.Image.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.03}""", - ), ) @classmethod @@ -392,31 +371,15 @@ async def execute( images = [] for i in image: images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096)) - initial_response = await sync_op( - cls, - ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/image2image/image-synthesis", method="POST"), - response_model=TaskCreationResponse, - data=Image2ImageTaskCreationRequest( - model=model, - input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images), - parameters=Image2ImageParametersField( - # size=f"{width}*{height}", 
- seed=seed, - watermark=watermark, - ), - ), - ) - if not initial_response.output: - raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") - response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), - response_model=ImageTaskStatusResponse, - status_extractor=lambda x: x.output.task_status, - estimated_duration=42, - poll_interval=4, - ) - return IO.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url))) + result = await fal_run(cls, FAL_WAN_I2V, { + "model": model, + "input": {"prompt": prompt, "negative_prompt": negative_prompt, "images": images}, + "parameters": { + "seed": seed, + "watermark": watermark, + }, + }, estimated_duration=42) # TODO: verify fal.ai field names; use correct fal model for wan image-to-image + return IO.NodeOutput(await download_url_to_image_tensor(result["images"][0]["url"])) class WanTextToVideoApi(IO.ComfyNode): @@ -528,22 +491,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration", "size"]), - expr=""" - ( - $ppsTable := { "480p": 0.05, "720p": 0.1, "1080p": 0.15 }; - $resKey := $substringBefore(widgets.size, ":"); - $pps := $lookup($ppsTable, $resKey); - { "type": "usd", "usd": $round($pps * widgets.duration, 2) } - ) - """, - ), ) @classmethod @@ -571,35 +521,20 @@ async def execute( validate_audio_duration(audio, 3.0, 29.0) audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame") - initial_response = await sync_op( - cls, - ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"), - response_model=TaskCreationResponse, - data=Text2VideoTaskCreationRequest( - model=model, - input=Text2VideoInputField(prompt=prompt, 
negative_prompt=negative_prompt, audio_url=audio_url), - parameters=Text2VideoParametersField( - size=f"{width}*{height}", - duration=duration, - seed=seed, - audio=generate_audio, - prompt_extend=prompt_extend, - watermark=watermark, - shot_type=shot_type, - ), - ), - ) - if not initial_response.output: - raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") - response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), - response_model=VideoTaskStatusResponse, - status_extractor=lambda x: x.output.task_status, - estimated_duration=120 * int(duration / 5), - poll_interval=6, - ) - return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) + result = await fal_run(cls, FAL_WAN_I2V, { + "model": model, + "input": {"prompt": prompt, "negative_prompt": negative_prompt, "audio_url": audio_url}, + "parameters": { + "size": f"{width}*{height}", + "duration": duration, + "seed": seed, + "audio": generate_audio, + "prompt_extend": prompt_extend, + "watermark": watermark, + "shot_type": shot_type, + }, + }, estimated_duration=120 * int(duration / 5)) # TODO: verify fal.ai field names; use correct fal model for wan text-to-video + return IO.NodeOutput(await download_url_to_video_output(result["video"]["url"])) class WanImageToVideoApi(IO.ComfyNode): @@ -704,21 +639,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), - expr=""" - ( - $ppsTable := { "480p": 0.05, "720p": 0.1, "1080p": 0.15 }; - $pps := $lookup($ppsTable, widgets.resolution); - { "type": "usd", "usd": $round($pps * widgets.duration, 2) } - ) - """, - ), ) @classmethod @@ -748,37 +671,20 @@ async def execute( if audio is not None: validate_audio_duration(audio, 3.0, 29.0) 
audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame") - initial_response = await sync_op( - cls, - ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"), - response_model=TaskCreationResponse, - data=Image2VideoTaskCreationRequest( - model=model, - input=Image2VideoInputField( - prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url - ), - parameters=Image2VideoParametersField( - resolution=resolution, - duration=duration, - seed=seed, - audio=generate_audio, - prompt_extend=prompt_extend, - watermark=watermark, - shot_type=shot_type, - ), - ), - ) - if not initial_response.output: - raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") - response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), - response_model=VideoTaskStatusResponse, - status_extractor=lambda x: x.output.task_status, - estimated_duration=120 * int(duration / 5), - poll_interval=6, - ) - return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) + result = await fal_run(cls, FAL_WAN_I2V, { + "model": model, + "input": {"prompt": prompt, "negative_prompt": negative_prompt, "img_url": image_url, "audio_url": audio_url}, + "parameters": { + "resolution": resolution, + "duration": duration, + "seed": seed, + "audio": generate_audio, + "prompt_extend": prompt_extend, + "watermark": watermark, + "shot_type": shot_type, + }, + }, estimated_duration=120 * int(duration / 5)) # TODO: verify fal.ai field names + return IO.NodeOutput(await download_url_to_video_output(result["video"]["url"])) class WanReferenceVideoApi(IO.ComfyNode): @@ -863,27 +769,9 @@ def define_schema(cls): IO.Video.Output(), ], hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - 
depends_on=IO.PriceBadgeDepends(widgets=["size", "duration"]), - expr=""" - ( - $rate := $contains(widgets.size, "1080p") ? 0.15 : 0.10; - $inputMin := 2 * $rate; - $inputMax := 5 * $rate; - $outputPrice := widgets.duration * $rate; - { - "type": "range_usd", - "min_usd": $inputMin + $outputPrice, - "max_usd": $inputMax + $outputPrice - } - ) - """, - ), ) @classmethod @@ -903,37 +791,20 @@ async def execute( for i in reference_videos: validate_video_duration(reference_videos[i], min_duration=2, max_duration=30) for i in reference_videos: - reference_video_urls.append(await upload_video_to_comfyapi(cls, reference_videos[i])) + reference_video_urls.append(await upload_video_to_fal(reference_videos[i])) width, height = RES_IN_PARENS.search(size).groups() - initial_response = await sync_op( - cls, - ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"), - response_model=TaskCreationResponse, - data=Reference2VideoTaskCreationRequest( - model=model, - input=Reference2VideoInputField( - prompt=prompt, negative_prompt=negative_prompt, reference_video_urls=reference_video_urls - ), - parameters=Reference2VideoParametersField( - size=f"{width}*{height}", - duration=duration, - shot_type=shot_type, - watermark=watermark, - seed=seed, - ), - ), - ) - if not initial_response.output: - raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") - response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), - response_model=VideoTaskStatusResponse, - status_extractor=lambda x: x.output.task_status, - poll_interval=6, - max_poll_attempts=280, - ) - return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) + result = await fal_run(cls, FAL_WAN_I2V, { + "model": model, + "input": {"prompt": prompt, "negative_prompt": negative_prompt, "reference_video_urls": reference_video_urls}, + "parameters": { + "size": 
f"{width}*{height}", + "duration": duration, + "shot_type": shot_type, + "watermark": watermark, + "seed": seed, + }, + }) # TODO: verify fal.ai field names; use correct fal model for wan reference-to-video + return IO.NodeOutput(await download_url_to_video_output(result["video"]["url"])) class WanApiExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_wavespeed.py b/comfy_api_nodes/nodes_wavespeed.py deleted file mode 100644 index c59fafd3bfd7..000000000000 --- a/comfy_api_nodes/nodes_wavespeed.py +++ /dev/null @@ -1,178 +0,0 @@ -from typing_extensions import override - -from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis.wavespeed import ( - FlashVSRRequest, - TaskCreatedResponse, - TaskResultResponse, - SeedVR2ImageRequest, -) -from comfy_api_nodes.util import ( - ApiEndpoint, - download_url_to_video_output, - poll_op, - sync_op, - upload_video_to_comfyapi, - validate_container_format_is_mp4, - validate_video_duration, - upload_images_to_comfyapi, - get_number_of_images, - download_url_to_image_tensor, -) - - -class WavespeedFlashVSRNode(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="WavespeedFlashVSRNode", - display_name="FlashVSR Video Upscale", - category="api node/video/WaveSpeed", - description="Fast, high-quality video upscaler that " - "boosts resolution and restores clarity for low-resolution or blurry footage.", - inputs=[ - IO.Video.Input("video"), - IO.Combo.Input("target_resolution", options=["720p", "1080p", "2K", "4K"]), - ], - outputs=[ - IO.Video.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["target_resolution"]), - expr=""" - ( - $price_for_1sec := {"720p": 0.012, "1080p": 0.018, "2k": 0.024, "4k": 0.032}; - { - "type":"usd", - "usd": $lookup($price_for_1sec, widgets.target_resolution), - 
"format":{"suffix": "/second", "approximate": true} - } - ) - """, - ), - ) - - @classmethod - async def execute( - cls, - video: Input.Video, - target_resolution: str, - ) -> IO.NodeOutput: - validate_container_format_is_mp4(video) - validate_video_duration(video, min_duration=5, max_duration=60 * 10) - initial_res = await sync_op( - cls, - ApiEndpoint(path="/proxy/wavespeed/api/v3/wavespeed-ai/flashvsr", method="POST"), - response_model=TaskCreatedResponse, - data=FlashVSRRequest( - target_resolution=target_resolution.lower(), - video=await upload_video_to_comfyapi(cls, video), - duration=video.get_duration(), - ), - ) - if initial_res.code != 200: - raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}") - final_response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"), - response_model=TaskResultResponse, - status_extractor=lambda x: "failed" if x.data is None else x.data.status, - poll_interval=10.0, - max_poll_attempts=480, - ) - if final_response.code != 200: - raise ValueError( - f"Task processing failed with code={final_response.code} and message={final_response.message}" - ) - return IO.NodeOutput(await download_url_to_video_output(final_response.data.outputs[0])) - - -class WavespeedImageUpscaleNode(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="WavespeedImageUpscaleNode", - display_name="WaveSpeed Image Upscale", - category="api node/image/WaveSpeed", - description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.", - inputs=[ - IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]), - IO.Image.Input("image"), - IO.Combo.Input("target_resolution", options=["2K", "4K", "8K"]), - ], - outputs=[ - IO.Image.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - 
price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model"]), - expr=""" - ( - $prices := {"seedvr2": 0.01, "ultimate": 0.06}; - {"type":"usd", "usd": $lookup($prices, widgets.model)} - ) - """, - ), - ) - - @classmethod - async def execute( - cls, - model: str, - image: Input.Image, - target_resolution: str, - ) -> IO.NodeOutput: - if get_number_of_images(image) != 1: - raise ValueError("Exactly one input image is required.") - if model == "SeedVR2": - model_path = "seedvr2/image" - else: - model_path = "ultimate-image-upscaler" - initial_res = await sync_op( - cls, - ApiEndpoint(path=f"/proxy/wavespeed/api/v3/wavespeed-ai/{model_path}", method="POST"), - response_model=TaskCreatedResponse, - data=SeedVR2ImageRequest( - target_resolution=target_resolution.lower(), - image=(await upload_images_to_comfyapi(cls, image, max_images=1))[0], - ), - ) - if initial_res.code != 200: - raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}") - final_response = await poll_op( - cls, - ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"), - response_model=TaskResultResponse, - status_extractor=lambda x: "failed" if x.data is None else x.data.status, - poll_interval=10.0, - max_poll_attempts=480, - ) - if final_response.code != 200: - raise ValueError( - f"Task processing failed with code={final_response.code} and message={final_response.message}" - ) - return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.outputs[0])) - - -class WavespeedExtension(ComfyExtension): - @override - async def get_node_list(self) -> list[type[IO.ComfyNode]]: - return [ - WavespeedFlashVSRNode, - WavespeedImageUpscaleNode, - ] - - -async def comfy_entrypoint() -> WavespeedExtension: - return WavespeedExtension() diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index 0cb9a47c780f..229aeb831e51 100644 --- a/comfy_api_nodes/util/__init__.py +++ 
b/comfy_api_nodes/util/__init__.py @@ -35,12 +35,13 @@ download_url_to_video_output, ) from .upload_helpers import ( - upload_3d_model_to_comfyapi, - upload_audio_to_comfyapi, - upload_file_to_comfyapi, - upload_image_to_comfyapi, - upload_images_to_comfyapi, - upload_video_to_comfyapi, + upload_3d_model_to_fal, + upload_audio_to_fal, + upload_file_to_fal, + upload_file_to_google, + upload_image_to_fal, + upload_images_to_fal, + upload_video_to_fal, ) from .validation_utils import ( get_image_dimensions, @@ -64,13 +65,14 @@ "poll_op_raw", "sync_op", "sync_op_raw", - # Upload helpers - "upload_3d_model_to_comfyapi", - "upload_audio_to_comfyapi", - "upload_file_to_comfyapi", - "upload_image_to_comfyapi", - "upload_images_to_comfyapi", - "upload_video_to_comfyapi", + # Upload helpers (fal.ai BYOK) + "upload_3d_model_to_fal", + "upload_audio_to_fal", + "upload_file_to_fal", + "upload_file_to_google", + "upload_image_to_fal", + "upload_images_to_fal", + "upload_video_to_fal", # Download helpers "download_url_as_bytesio", "download_url_to_bytesio", diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py index 648defe3deba..fd136282cbe4 100644 --- a/comfy_api_nodes/util/_helpers.py +++ b/comfy_api_nodes/util/_helpers.py @@ -8,11 +8,10 @@ from yarl import URL -from comfy.cli_args import args from comfy.model_management import processing_interrupted from comfy_api.latest import IO -from .common_exceptions import ProcessingInterrupted +from .common_exceptions import MissingApiKeyError, ProcessingInterrupted _HAS_PCT_ESC = re.compile(r"%[0-9A-Fa-f]{2}") # any % followed by 2 hex digits _HAS_BAD_PCT = re.compile(r"%(?![0-9A-Fa-f]{2})") # any % not followed by 2 hex digits @@ -27,16 +26,57 @@ def get_node_id(node_cls: type[IO.ComfyNode]) -> str: return node_cls.hidden.unique_id -def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]: - if node_cls.hidden.auth_token_comfy_org: - return {"Authorization": f"Bearer 
{node_cls.hidden.auth_token_comfy_org}"} - if node_cls.hidden.api_key_comfy_org: - return {"X-API-KEY": node_cls.hidden.api_key_comfy_org} - return {} +def get_google_auth_header() -> dict[str, str]: + """Return Google API auth header from GOOGLE_API_KEY env var. + Raises MissingApiKeyError if the variable is unset or empty. + """ + key = os.environ.get("GOOGLE_API_KEY", "").strip() + if not key: + raise MissingApiKeyError( + "GOOGLE_API_KEY environment variable is not set. " + "Set it to your Google AI API key (https://aistudio.google.com/apikey)." + ) + return {"x-goog-api-key": key} + + +def get_fal_auth_header() -> dict[str, str]: + """Return fal.ai auth header from FAL_API_KEY env var. + + Raises MissingApiKeyError if the variable is unset or empty. + """ + key = os.environ.get("FAL_API_KEY", "").strip() + if not key: + raise MissingApiKeyError( + "FAL_API_KEY environment variable is not set. " + "Set it to your fal.ai API key (https://fal.ai/dashboard/keys)." + ) + return {"Authorization": f"Key {key}"} -def default_base_url() -> str: - return getattr(args, "comfy_api_base", "https://api.comfy.org") + +# Domain allowlists for auth header safety -- prevents sending keys to wrong hosts +_GOOGLE_DOMAINS = (".googleapis.com",) +_FAL_DOMAINS = (".fal.run", ".fal.ai", ".fal.media") + + +def validate_auth_header_domain(url: str, headers: dict[str, str]) -> None: + """Raise ValueError if auth headers would be sent to a non-allowlisted domain. + + Prevents SSRF-style attacks where a crafted URL could exfiltrate API keys. 
+ """ + from urllib.parse import urlparse + hostname = urlparse(url).hostname or "" + if "x-goog-api-key" in headers: + if not any(hostname.endswith(d) for d in _GOOGLE_DOMAINS): + raise ValueError( + f"Refusing to send Google API key to non-Google domain: {hostname}" + ) + auth = headers.get("Authorization", "") + if auth.startswith("Key "): + if not any(hostname.endswith(d) for d in _FAL_DOMAINS): + raise ValueError( + f"Refusing to send fal.ai API key to non-fal domain: {hostname}" + ) async def sleep_with_interrupt( diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 79ffb77c14a6..ab79a4e21e86 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -21,16 +21,43 @@ from . import request_logger from ._helpers import ( - default_base_url, - get_auth_header, get_node_id, is_processing_interrupted, sleep_with_interrupt, + validate_auth_header_domain, ) from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted M = TypeVar("M", bound=BaseModel) +# --- Connection pooling --- +# Reuse aiohttp sessions per origin to avoid TLS handshake per request. +_session_pool: dict[str, aiohttp.ClientSession] = {} + +# fal.ai concurrency semaphore (standard tier: 2 concurrent tasks) +_fal_concurrency = asyncio.Semaphore(2) + + +def _get_pooled_session(url: str, timeout: aiohttp.ClientTimeout) -> aiohttp.ClientSession: + """Return a pooled aiohttp session for the given URL's origin. + + Sessions are lazily created and reused across requests to the same host. 
+ """ + parsed = urlparse(url) + key = f"{parsed.scheme}://{parsed.netloc}" + existing = _session_pool.get(key) + if existing is not None and not existing.closed: + return existing + sess = aiohttp.ClientSession( + timeout=timeout, + connector=aiohttp.TCPConnector( + limit_per_host=10, + keepalive_timeout=30, + ), + ) + _session_pool[key] = sess + return sess + class ApiEndpoint: def __init__( @@ -429,6 +456,111 @@ async def _ticker(): await ticker_task +# --------------------------------------------------------------------------- +# fal.ai queue helpers (thin wrappers around sync_op / poll_op) +# --------------------------------------------------------------------------- + +_FAL_MODEL_ID_PATTERN = __import__("re").compile(r"^[a-zA-Z0-9_-]+(/[a-zA-Z0-9_.\-]+)*$") + + +def _validate_fal_model_id(model_id: str) -> None: + """Validate model_id to prevent SSRF via path traversal.""" + if not model_id or not _FAL_MODEL_ID_PATTERN.match(model_id): + raise ValueError( + f"Invalid fal.ai model ID: {model_id!r}. " + "Must match pattern: fal-ai/model-name/variant" + ) + if ".." in model_id: + raise ValueError(f"Invalid fal.ai model ID (path traversal): {model_id!r}") + + +async def fal_submit( + cls: type[IO.ComfyNode], + model_id: str, + data: dict, +) -> dict: + """Submit a job to fal.ai queue. 
Returns the submit response dict.""" + from ._helpers import get_fal_auth_header + from ..apis.fal import FalQueueSubmitResponse + + _validate_fal_model_id(model_id) + resp = await sync_op( + cls, + ApiEndpoint( + f"https://queue.fal.run/{model_id}", + "POST", + headers=get_fal_auth_header(), + ), + response_model=FalQueueSubmitResponse, + data=BaseModel.model_construct(**data) if isinstance(data, dict) else data, + monitor_progress=False, + final_label_on_success=None, + ) + return resp + + +async def fal_poll( + cls: type[IO.ComfyNode], + status_url: str, + *, + poll_interval: float = 3.0, + estimated_duration: int | None = None, +) -> dict: + """Poll a fal.ai queue status URL until COMPLETED. Returns the status response dict.""" + from ._helpers import get_fal_auth_header + + resp = await poll_op( + cls, + ApiEndpoint(status_url, "GET", headers=get_fal_auth_header()), + response_model=__import__("comfy_api_nodes.apis.fal", fromlist=["FalQueueStatusResponse"]).FalQueueStatusResponse, + status_extractor=lambda r: "completed" if r.status == "COMPLETED" else ("queued" if r.status == "IN_QUEUE" else "processing"), + poll_interval=poll_interval, + estimated_duration=estimated_duration, + ) + return resp + + +async def fal_fetch_result( + cls: type[IO.ComfyNode], + response_url: str, +) -> dict: + """Fetch the final result from a completed fal.ai queue job.""" + from ._helpers import get_fal_auth_header + + result = await sync_op_raw( + cls, + ApiEndpoint(response_url, "GET", headers=get_fal_auth_header()), + monitor_progress=False, + final_label_on_success=None, + ) + if not isinstance(result, dict): + raise Exception("fal.ai result was not JSON") + return result + + +async def fal_run( + cls: type[IO.ComfyNode], + model_id: str, + data: dict, + *, + poll_interval: float = 3.0, + estimated_duration: int | None = None, +) -> dict: + """Combined submit + poll + fetch for fal.ai queue API. + + This is the primary interface for fal.ai-routed nodes. 
+ Returns the model-specific result dict. + """ + submit_resp = await fal_submit(cls, model_id, data) + await fal_poll( + cls, + submit_resp.status_url, + poll_interval=poll_interval, + estimated_duration=estimated_duration, + ) + return await fal_fetch_result(cls, submit_resp.response_url) + + def _display_text( node_cls: type[IO.ComfyNode], text: str | None, @@ -472,21 +604,12 @@ async def _diagnose_connectivity() -> dict[str, bool]: """Best-effort connectivity diagnostics to distinguish local vs. server issues.""" results = { "internet_accessible": False, - "api_accessible": False, } timeout = aiohttp.ClientTimeout(total=5.0) async with aiohttp.ClientSession(timeout=timeout) as session: with contextlib.suppress(ClientError, OSError): async with session.get("https://www.google.com") as resp: results["internet_accessible"] = resp.status < 500 - if not results["internet_accessible"]: - return results - - parsed = urlparse(default_base_url()) - health_url = f"{parsed.scheme}://{parsed.netloc}/health" - with contextlib.suppress(ClientError, OSError): - async with session.get(health_url) as resp: - results["api_accessible"] = resp.status < 500 return results @@ -510,24 +633,40 @@ def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, def _friendly_http_message(status: int, body: Any) -> str: if status == 401: - return "Unauthorized: Please login first to use this node." + return "Unauthorized: Invalid or missing API key. Check your GOOGLE_API_KEY or FAL_API_KEY environment variable." if status == 402: - return "Payment Required: Please add credits to your account to use this node." + return "Payment Required: Your API key's account needs billing enabled or credits added." + if status == 403: + return "Forbidden: Your API key does not have permission for this operation. Check that billing is enabled." if status == 409: - return "There is a problem with your account. Please contact support@comfy.org." 
+ return "Conflict: The server could not process this request. The resource may already exist or be in a conflicting state." if status == 429: return "Rate Limit Exceeded: The server returned 429 after all retry attempts. Please wait and try again." try: if isinstance(body, dict): + # Google API error format: {"error": {"code": N, "message": "...", "status": "..."}} err = body.get("error") if isinstance(err, dict): msg = err.get("message") - typ = err.get("type") + typ = err.get("type") or err.get("status") if msg and typ: return f"API Error: {msg} (Type: {typ})" if msg: return f"API Error: {msg}" - return f"API Error: {json.dumps(body)}" + # fal.ai error format: {"detail": [{"msg": "...", "type": "..."}]} + detail = body.get("detail") + if isinstance(detail, list) and detail: + first = detail[0] + if isinstance(first, dict): + fal_msg = first.get("msg", "") + fal_type = first.get("type", "") + if fal_msg: + return f"API Error: {fal_msg}" + (f" ({fal_type})" if fal_type else "") + # Fallback -- truncate to avoid leaking sensitive data in raw dumps + dumped = json.dumps(body) + if len(dumped) <= 300: + return f"API Error: {dumped}" + return f"API Error (status {status}): Response too large to display" else: txt = str(body) if len(txt) <= 200: @@ -573,8 +712,11 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): """Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors.""" url = cfg.endpoint.path parsed_url = urlparse(url) - if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? - url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/")) + if not parsed_url.scheme and not parsed_url.netloc: + raise ValueError( + f"Relative URL not supported in BYOK mode: {url!r}. " + "All API endpoints must use absolute URLs." 
+ ) method = cfg.endpoint.method params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None) @@ -611,11 +753,12 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt) payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"} - if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? - payload_headers.update(get_auth_header(cfg.node_cls)) if cfg.endpoint.headers: payload_headers.update(cfg.endpoint.headers) + # Validate auth headers are only sent to allowlisted provider domains + validate_auth_header_domain(url, payload_headers) + payload_kw: dict[str, Any] = {"headers": payload_headers} if method == "GET": payload_headers.pop("Content-Type", None) @@ -625,7 +768,7 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): monitor_task = asyncio.create_task(_monitor(stop_event, start_time)) timeout = aiohttp.ClientTimeout(total=cfg.timeout) - sess = aiohttp.ClientSession(timeout=timeout) + sess = _get_pooled_session(url, timeout) if cfg.content_type == "multipart/form-data" and method != "GET": # aiohttp will set Content-Type boundary; remove any fixed Content-Type @@ -863,7 +1006,7 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): error_message=f"ApiServerError: {str(e)}", ) raise ApiServerError( - f"The API server at {default_base_url()} is currently unreachable. " + f"The API server at {url} is currently unreachable. " f"The service may be experiencing issues." 
) from e finally: @@ -872,9 +1015,7 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): monitor_task.cancel() with contextlib.suppress(Exception): await monitor_task - if sess: - with contextlib.suppress(Exception): - await sess.close() + # Session is pooled -- do not close it here if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success: _display_time_progress( cfg.node_cls, diff --git a/comfy_api_nodes/util/common_exceptions.py b/comfy_api_nodes/util/common_exceptions.py index 0606a4407007..fe8eb080e1de 100644 --- a/comfy_api_nodes/util/common_exceptions.py +++ b/comfy_api_nodes/util/common_exceptions.py @@ -12,3 +12,7 @@ class ApiServerError(NetworkError): class ProcessingInterrupted(Exception): """Operation was interrupted by user/runtime via processing_interrupted().""" + + +class MissingApiKeyError(Exception): + """Required API key environment variable is not configured or is empty.""" diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index aa588d038b4d..cbe4b86d68ec 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -4,7 +4,7 @@ from io import BytesIO from pathlib import Path from typing import IO -from urllib.parse import urljoin, urlparse +from urllib.parse import urlparse import aiohttp import torch @@ -16,8 +16,6 @@ from . import request_logger from ._helpers import ( - default_base_url, - get_auth_header, is_processing_interrupted, sleep_with_interrupt, to_aiohttp_url, @@ -38,6 +36,7 @@ async def download_url_to_bytesio( retry_delay: float = 1.0, retry_backoff: float = 2.0, cls: type[COMFY_IO.ComfyNode] = None, + extra_headers: dict[str, str] | None = None, ) -> None: """Stream-download a URL to `dest`. @@ -46,8 +45,8 @@ async def download_url_to_bytesio( - a file-like object opened in binary write mode (must implement .write()), - a filesystem path (str | pathlib.Path), which will be opened with 'wb'. 
# Header names whose values must never appear in request/response logs.
# Cookies and proxy credentials added: both carry secrets and were
# previously written to the log files in clear text.
_SENSITIVE_HEADER_NAMES = frozenset({
    "authorization",
    "proxy-authorization",
    "cookie",
    "set-cookie",
    "x-goog-api-key",
    "x-api-key",
})
# Catch-all: any header whose name merely mentions credential material.
_SENSITIVE_HEADER_SUBSTRINGS = ("key", "token", "secret")


def _redact_headers(headers: dict) -> dict:
    """Return a copy of *headers* with sensitive values replaced by "[REDACTED]".

    Matching is case-insensitive on the header name; values are never
    inspected. Falsy input (None / empty mapping) is returned unchanged.
    """
    if not headers:
        return headers
    redacted: dict = {}
    for name, value in headers.items():
        lower = name.lower()
        sensitive = lower in _SENSITIVE_HEADER_NAMES or any(
            marker in lower for marker in _SENSITIVE_HEADER_SUBSTRINGS
        )
        redacted[name] = "[REDACTED]" if sensitive else value
    return redacted
_format_data_for_logging(data: Any) -> str: """Helper to format data (dict, str, bytes) for logging.""" if isinstance(data, bytes): @@ -101,7 +121,7 @@ def log_request_response( log_content.append(f"Method: {request_method}") log_content.append(f"URL: {request_url}") if request_headers: - log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}") + log_content.append(f"Headers:\n{_format_data_for_logging(_redact_headers(request_headers))}") if request_params: log_content.append(f"Params:\n{_format_data_for_logging(request_params)}") if request_data is not None: @@ -111,7 +131,7 @@ def log_request_response( if response_status_code is not None: log_content.append(f"Status Code: {response_status_code}") if response_headers: - log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}") + log_content.append(f"Headers:\n{_format_data_for_logging(_redact_headers(response_headers))}") if response_content is not None: log_content.append(f"Content:\n{_format_data_for_logging(response_content)}") if error_message: diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 6d1d107a18ef..3c9aee8c5ab8 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -1,26 +1,11 @@ -import asyncio -import contextlib import logging -import time import uuid from io import BytesIO from urllib.parse import urlparse import aiohttp import torch -from pydantic import BaseModel, Field -from comfy_api.latest import IO, Input, Types - -from . 
# ---------------------------------------------------------------------------
# BYOK provider-specific upload helpers
# ---------------------------------------------------------------------------

_FAL_UPLOAD_DOMAINS = (".fal.ai", ".fal.run", ".fal.media")


def _is_allowed_upload_host(hostname: str) -> bool:
    """True when *hostname* is fal's CDN or an AWS presigned-upload endpoint.

    BUG FIX: the previous substring test ("amazonaws.com" in host) also
    accepted hosts like "evilamazonaws.com"; require a proper suffix match.
    """
    host = (hostname or "").lower()
    if any(host.endswith(d) for d in _FAL_UPLOAD_DOMAINS):
        return True
    return host == "amazonaws.com" or host.endswith(".amazonaws.com")


async def upload_file_to_fal(file_bytes: BytesIO, mime_type: str) -> str:
    """Upload a file to the fal.ai CDN and return its public CDN URL.

    Two-step presigned flow: POST /storage/upload/initiate, then PUT the raw
    bytes to the returned presigned URL.

    Raises:
        MissingApiKeyError: if FAL_API_KEY is unset.
        ValueError: if fal returns an upload URL on an unexpected domain.
        Exception: on HTTP errors or a malformed initiate response.
    """
    from ._helpers import get_fal_auth_header

    headers = get_fal_auth_header()
    headers["Content-Type"] = "application/json"

    file_bytes.seek(0)
    data = file_bytes.read()

    timeout = aiohttp.ClientTimeout(total=120)
    async with aiohttp.ClientSession(timeout=timeout) as sess:
        # Step 1: initiate the upload and obtain the presigned URL pair.
        # NOTE(review): the generated file_name has no extension -- confirm
        # fal derives the type from content_type alone.
        async with sess.post(
            "https://rest.fal.ai/storage/upload/initiate",
            params={"storage_type": "fal-cdn-v3"},
            headers=headers,
            json={"content_type": mime_type, "file_name": uuid.uuid4().hex[:12]},
        ) as resp:
            if resp.status >= 400:
                body = await resp.text()
                raise Exception(f"fal.ai upload initiate failed ({resp.status}): {body[:300]}")
            result = await resp.json()

        upload_url = result.get("upload_url", "")
        file_url = result.get("file_url", "")
        # Fail loudly instead of PUT-ing to an empty URL.
        if not upload_url or not file_url:
            raise Exception("fal.ai upload initiate returned no upload_url/file_url")

        # Validate returned URL domain so the payload cannot be redirected.
        upload_host = urlparse(upload_url).hostname or ""
        if not _is_allowed_upload_host(upload_host):
            raise ValueError(f"fal.ai returned unexpected upload domain: {upload_host}")

        # Step 2: PUT the file bytes to the presigned URL.
        async with sess.put(upload_url, data=data, headers={"Content-Type": mime_type}) as resp:
            if resp.status >= 400:
                body = await resp.text()
                raise Exception(f"fal.ai file upload failed ({resp.status}): {body[:300]}")

    return file_url


async def upload_file_to_google(file_bytes: BytesIO, mime_type: str, display_name: str) -> str:
    """Upload a file via Google's resumable Files API; return its name/URI
    (e.g. 'files/abc123').

    Raises:
        MissingApiKeyError: if GOOGLE_API_KEY is unset.
        ValueError: if the resumable upload URL is on a non-Google host.
        Exception: on HTTP errors or a malformed upload response.
    """
    from ._helpers import get_google_auth_header

    auth = get_google_auth_header()
    file_bytes.seek(0)
    data = file_bytes.read()
    num_bytes = len(data)

    timeout = aiohttp.ClientTimeout(total=300)
    async with aiohttp.ClientSession(timeout=timeout) as sess:
        # Step 1: start the resumable upload and obtain the session URL.
        start_headers = {
            **auth,
            "X-Goog-Upload-Protocol": "resumable",
            "X-Goog-Upload-Command": "start",
            "X-Goog-Upload-Header-Content-Length": str(num_bytes),
            "X-Goog-Upload-Header-Content-Type": mime_type,
            "Content-Type": "application/json",
        }
        async with sess.post(
            "https://generativelanguage.googleapis.com/upload/v1beta/files",
            headers=start_headers,
            json={"file": {"display_name": display_name}},
        ) as resp:
            if resp.status >= 400:
                body = await resp.text()
                raise Exception(f"Google upload initiate failed ({resp.status}): {body[:300]}")
            upload_url = resp.headers.get("X-Goog-Upload-URL", "")
        if not upload_url:
            raise Exception("Google upload initiate did not return an upload URL")
        # Defense in depth: never PUT the payload to a non-Google host.
        upload_host = (urlparse(upload_url).hostname or "").lower()
        if not upload_host.endswith(".googleapis.com"):
            raise ValueError(f"Google returned unexpected upload domain: {upload_host}")

        # Step 2: upload the bytes and finalize in one shot.
        upload_headers = {
            "X-Goog-Upload-Command": "upload, finalize",
            "X-Goog-Upload-Offset": "0",
            "Content-Length": str(num_bytes),
        }
        async with sess.put(upload_url, data=data, headers=upload_headers) as resp:
            if resp.status >= 400:
                body = await resp.text()
                raise Exception(f"Google file upload failed ({resp.status}): {body[:300]}")
            result = await resp.json()

    # Extract the file URI from the response (shape differs across API surfaces).
    file_name = result.get("file", {}).get("name", "") or result.get("name", "")
    if not file_name:
        # BUG FIX: previously an empty string was returned silently, deferring
        # the failure to whichever API call consumed the bogus URI.
        raise Exception("Google upload response did not contain a file name")
    return file_name


async def upload_image_to_fal(image_tensor: "torch.Tensor", mime_type: str = "image/png") -> str:
    """Encode an image tensor and upload it to the fal.ai CDN; return the URL."""
    encoded = tensor_to_bytesio(image_tensor, mime_type=mime_type)
    return await upload_file_to_fal(encoded, mime_type)
wait_label=wait_label, - show_batch_index=False, - total_pixels=total_pixels, - ) - )[0] - - -async def upload_audio_to_comfyapi( - cls: type[IO.ComfyNode], - audio: Input.Audio, +async def upload_video_to_fal( + video, *, - container_format: str = "mp4", - codec_name: str = "aac", - mime_type: str = "audio/mp4", + container=None, + codec=None, + max_duration: int | None = None, ) -> str: - """ - Uploads a single audio input to ComfyUI API and returns its download URL. - Encodes the raw waveform into the specified format before uploading. - """ - sample_rate: int = audio["sample_rate"] - waveform: torch.Tensor = audio["waveform"] - audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) - audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name) - return await upload_file_to_comfyapi(cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type) + """Convert a video to bytes and upload to fal.ai CDN. Returns the CDN URL.""" + from comfy_api.latest import Types + if container is None: + container = Types.VideoContainer.MP4 + if codec is None: + codec = Types.VideoCodec.H264 -async def upload_video_to_comfyapi( - cls: type[IO.ComfyNode], - video: Input.Video, - *, - container: Types.VideoContainer = Types.VideoContainer.MP4, - codec: Types.VideoCodec = Types.VideoCodec.H264, - max_duration: int | None = None, - wait_label: str | None = "Uploading", -) -> str: - """ - Uploads a single video to ComfyUI API and returns its download URL. - Uses the specified container and codec for saving the video before upload. 
- """ if max_duration is not None: try: actual_duration = video.get_duration() @@ -154,234 +184,38 @@ async def upload_video_to_comfyapi( raise ValueError(f"Could not verify video duration from source: {e}") from e upload_mime_type = f"video/{container.value.lower()}" - filename = f"{uuid.uuid4()}.{container.value.lower()}" - - # Convert VideoInput to BytesIO using specified container/codec video_bytes_io = BytesIO() video.save_to(video_bytes_io, format=container, codec=codec) video_bytes_io.seek(0) + return await upload_file_to_fal(video_bytes_io, upload_mime_type) - return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label) - -_3D_MIME_TYPES = { - "glb": "model/gltf-binary", - "obj": "model/obj", - "fbx": "application/octet-stream", -} +async def upload_audio_to_fal( + audio, + *, + container_format: str = "mp4", + codec_name: str = "aac", + mime_type: str = "audio/mp4", +) -> str: + """Convert audio to bytes and upload to fal.ai CDN. Returns the CDN URL.""" + sample_rate: int = audio["sample_rate"] + waveform: torch.Tensor = audio["waveform"] + audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) + audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name) + return await upload_file_to_fal(audio_bytes_io, mime_type) -async def upload_3d_model_to_comfyapi( - cls: type[IO.ComfyNode], - model_3d: Types.File3D, +async def upload_3d_model_to_fal( + model_3d, file_format: str, ) -> str: - """Uploads a 3D model file to ComfyUI API and returns its download URL.""" - return await upload_file_to_comfyapi( - cls, - model_3d.get_data(), - f"{uuid.uuid4()}.{file_format}", - _3D_MIME_TYPES.get(file_format, "application/octet-stream"), - ) - - -async def upload_file_to_comfyapi( - cls: type[IO.ComfyNode], - file_bytes_io: BytesIO, - filename: str, - upload_mime_type: str | None, - wait_label: str | None = "Uploading", - progress_origin_ts: float | None = None, -) -> str: - """Uploads a 
single file to ComfyUI API and returns its download URL.""" - if upload_mime_type is None: - request_object = UploadRequest(file_name=filename) - else: - request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) - create_resp = await sync_op( - cls, - endpoint=ApiEndpoint(path="/customers/storage", method="POST"), - data=request_object, - response_model=UploadResponse, - final_label_on_success=None, - monitor_progress=False, - ) - await upload_file( - cls, - create_resp.upload_url, - file_bytes_io, - content_type=upload_mime_type, - wait_label=wait_label, - progress_origin_ts=progress_origin_ts, - ) - return create_resp.download_url - - -async def upload_file( - cls: type[IO.ComfyNode], - upload_url: str, - file: BytesIO | str, - *, - content_type: str | None = None, - max_retries: int = 3, - retry_delay: float = 1.0, - retry_backoff: float = 2.0, - wait_label: str | None = None, - progress_origin_ts: float | None = None, -) -> None: - """ - Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption. 
- - Raises: - ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception - """ - if isinstance(file, BytesIO): - with contextlib.suppress(Exception): - file.seek(0) - data = file.read() - elif isinstance(file, str): - with open(file, "rb") as f: - data = f.read() - else: - raise ValueError("file must be a BytesIO or a filesystem path string") - - headers: dict[str, str] = {} - skip_auto_headers: set[str] = set() - if content_type: - headers["Content-Type"] = content_type - else: - skip_auto_headers.add("Content-Type") # Don't let aiohttp add Content-Type, it can break the signed request - - attempt = 0 - delay = retry_delay - start_ts = progress_origin_ts if progress_origin_ts is not None else time.monotonic() - op_uuid = uuid.uuid4().hex[:8] - while True: - attempt += 1 - operation_id = _generate_operation_id("PUT", upload_url, attempt, op_uuid) - timeout = aiohttp.ClientTimeout(total=None) - stop_evt = asyncio.Event() - - async def _monitor(): - try: - while not stop_evt.is_set(): - if is_processing_interrupted(): - return - if wait_label: - _display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None) - await asyncio.sleep(1.0) - except asyncio.CancelledError: - return - - monitor_task = asyncio.create_task(_monitor()) - sess: aiohttp.ClientSession | None = None - try: - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - request_headers=headers or None, - request_params=None, - request_data=f"[File data {len(data)} bytes]", - ) - - sess = aiohttp.ClientSession(timeout=timeout) - req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers) - req_task = asyncio.create_task(req) - - done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) - - if monitor_task in done and req_task in pending: - req_task.cancel() - raise ProcessingInterrupted("Upload cancelled") - - try: - resp = await req_task - except 
asyncio.CancelledError: - raise ProcessingInterrupted("Upload cancelled") from None - - async with resp: - if resp.status >= 400: - with contextlib.suppress(Exception): - try: - body = await resp.json() - except Exception: - body = await resp.text() - msg = f"Upload failed with status {resp.status}" - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content=body, - error_message=msg, - ) - if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries: - await sleep_with_interrupt( - delay, - cls, - wait_label, - start_ts, - None, - display_callback=_display_time_progress if wait_label else None, - ) - delay *= retry_backoff - continue - raise Exception(f"Failed to upload (HTTP {resp.status}).") - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content="File uploaded successfully.", - ) - return - except asyncio.CancelledError: - raise ProcessingInterrupted("Task cancelled") from None - except (aiohttp.ClientError, OSError) as e: - if attempt <= max_retries: - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - request_headers=headers or None, - request_data=f"[File data {len(data)} bytes]", - error_message=f"{type(e).__name__}: {str(e)} (will retry)", - ) - await sleep_with_interrupt( - delay, - cls, - wait_label, - start_ts, - None, - display_callback=_display_time_progress if wait_label else None, - ) - delay *= retry_backoff - continue - - diag = await _diagnose_connectivity() - if not diag["internet_accessible"]: - raise LocalNetworkError( - "Unable to connect to the network. Please check your internet connection and try again." 
- ) from e - raise ApiServerError("The API service appears unreachable at this time.") from e - finally: - stop_evt.set() - if monitor_task: - monitor_task.cancel() - with contextlib.suppress(Exception): - await monitor_task - if sess: - with contextlib.suppress(Exception): - await sess.close() - - -def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str: - try: - parsed = urlparse(url) - slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_") - except Exception: - slug = "upload" - return f"{method}_{slug}_{op_uuid}_try{attempt}" + """Upload a 3D model to fal.ai CDN. Returns the CDN URL.""" + _3d_mime_types = { + "glb": "model/gltf-binary", + "obj": "model/obj", + "fbx": "application/octet-stream", + } + data = model_3d.get_data() + if not isinstance(data, BytesIO): + data = BytesIO(data) + return await upload_file_to_fal(data, _3d_mime_types.get(file_format, "application/octet-stream")) diff --git a/docs/MIGRATION_STATUS.md b/docs/MIGRATION_STATUS.md new file mode 100644 index 000000000000..a94280bdbfbf --- /dev/null +++ b/docs/MIGRATION_STATUS.md @@ -0,0 +1,61 @@ +# BYOK Migration Status + +**Date:** 2026-03-11 + +## Status: CORE MIGRATION COMPLETE + +All ComfyOrg proxy infrastructure has been replaced with direct BYOK API access. Zero references to `api.comfy.org` remain. 
+ +## What's Done + +### Infrastructure (Phase 1a + 1b) +- `MissingApiKeyError` exception in `common_exceptions.py` +- `get_google_auth_header()` / `get_fal_auth_header()` in `_helpers.py` +- `validate_auth_header_domain()` -- domain allowlist preventing SSRF +- `_redact_headers()` in `request_logger.py` -- redacts API keys from logs +- Per-host connection pooling in `client.py` +- `asyncio.Semaphore(2)` for fal.ai concurrency +- `.env` loading via `python-dotenv` in `main.py` +- Startup BYOK key status log in `nodes.py` +- **Old `get_auth_header()` / `default_base_url()` DELETED** -- no ComfyOrg auth remains +- **`_request_base` rejects relative URLs** -- all endpoints must be absolute +- **`_diagnose_connectivity()` no longer checks `api.comfy.org`** + +### Upload Helpers (Complete) +- `upload_file_to_fal()`, `upload_image_to_fal()`, `upload_images_to_fal()` +- `upload_video_to_fal()`, `upload_audio_to_fal()`, `upload_3d_model_to_fal()` +- `upload_file_to_google()` (resumable upload) +- **Old `upload_*_to_comfyapi` functions DELETED** -- ~350 lines removed + +### Google Direct (Phase 2) +- `nodes_gemini.py` -- all 4 nodes hit `generativelanguage.googleapis.com` directly +- `nodes_veo2.py` -- all 3 nodes use `:predictLongRunning` + poll + +### fal.ai Infrastructure (Phase 3) +- `apis/fal.py` -- Pydantic models for queue envelope +- `fal_run()`, `fal_submit()`, `fal_poll()`, `fal_fetch_result()` in `client.py` +- `nodes_fal.py` -- generic fal.ai node (any model by ID) + +### Node Migration (Phase 4) -- ALL 19 FILES COMPLETE +- 8 unavailable providers **deleted** (runway, tripo, magnific, topaz, moonvalley, grok, hitpaw, wavespeed) +- All 19 remaining node files fully migrated to `fal_run()` with proper fal.ai model IDs +- All upload calls migrated from `upload_*_to_comfyapi` to `upload_*_to_fal` +- All `__FAL_*__` placeholder strings eliminated +- All ComfyOrg hidden fields and price badges removed + +### Cleanup (Phase 6) -- COMPLETE +- `--comfy-api-base` CLI arg 
removed +- `.env` added to `.gitignore` +- `apis/rodin.py` deleted (unused) +- Dead imports cleaned from all migrated node files +- Old upload functions deleted from `upload_helpers.py` +- Old auth functions deleted from `_helpers.py` +- Zero references to `api.comfy.org` in codebase + +## What's Left + +| Task | Effort | +|------|--------| +| Phase 5: Node-level key status badges in frontend (optional) | Medium | +| Verify fal.ai field names at runtime (TODO comments in files) | Ongoing | +| Commit all changes | -- | diff --git a/docs/brainstorms/2026-03-10-byok-provider-migration-brainstorm.md b/docs/brainstorms/2026-03-10-byok-provider-migration-brainstorm.md new file mode 100644 index 000000000000..f3a007b6a5c1 --- /dev/null +++ b/docs/brainstorms/2026-03-10-byok-provider-migration-brainstorm.md @@ -0,0 +1,99 @@ +# Brainstorm: BYOK Provider Migration + +**Date:** 2026-03-10 +**Status:** Reviewed + +## What We're Building + +Migrating all API-based generation nodes in this ComfyUI fork from the ComfyOrg proxy (`api.comfy.org`) to a bring-your-own-key (BYOK) model with two providers: + +1. **Google API (direct)** — for Gemini, Veo 2/3, Nano Banana / ImageGen models +2. **fal.ai** — for everything else (Flux, SDXL, Kling, Runway, Stability, etc.) plus a generic node for any fal.ai model by ID + +The ComfyOrg proxy and auth system is completely removed. No fallback to ComfyOrg keys. + +## Why This Approach + +- **Personal use** — no need for ComfyOrg billing proxy; direct API access is cheaper and more transparent +- **Two keys cover everything** — `GOOGLE_API_KEY` for Google models, `FAL_API_KEY` for the long tail via fal.ai +- **fal.ai as aggregator** — fal.ai hosts hundreds of models under one API key, avoiding the need for 20+ separate provider keys + +## Key Decisions + +### 1. Approach: Rewire Existing Nodes +Modify existing `nodes_*.py` files in-place rather than replacing them. This preserves node names and ComfyUI workflow compatibility. + +### 2. 
Two Providers Only +- **Google direct API**: `nodes_gemini.py`, `nodes_veo2.py` → hit `generativelanguage.googleapis.com` / Vertex AI endpoints +- **fal.ai**: All other `nodes_*.py` files → hit `api.fal.ai` with equivalent model IDs +- **New generic node**: A `nodes_fal.py` with a flexible node that accepts any fal.ai model ID + +### 3. API Key Storage: Environment Variables +- `GOOGLE_API_KEY` — for all Google model nodes +- `FAL_API_KEY` — for all fal.ai routed nodes +- Read via `os.environ` at request time (not startup), so keys can be set/changed without restart + +### 4. No ComfyOrg Auth +- Remove `auth_token_comfy_org` and `api_key_comfy_org` from hidden inputs +- Remove `get_auth_header()` ComfyOrg logic +- Remove proxy URL routing in `_request_base` + +### 5. Node-Level Key Status Indicators +Each node shows a green/red badge in the UI indicating whether its required API key environment variable is set. This is the "dev tools helper" — no separate settings page needed. + +## Scope: Provider Mapping + +| Current Node File | Current Provider | BYOK Target | Required Key | +|---|---|---|---| +| `nodes_gemini.py` | Google Gemini | Google API direct | `GOOGLE_API_KEY` | +| `nodes_veo2.py` | Google Veo 2/3 | Google API direct | `GOOGLE_API_KEY` | +| `nodes_bfl.py` | Black Forest Labs (Flux) | fal.ai | `FAL_API_KEY` | +| `nodes_openai.py` | OpenAI | fal.ai | `FAL_API_KEY` | +| `nodes_stability.py` | Stability AI | fal.ai | `FAL_API_KEY` | +| `nodes_runway.py` | Runway | fal.ai | `FAL_API_KEY` | +| `nodes_kling.py` | Kling | fal.ai | `FAL_API_KEY` | +| `nodes_luma.py` | Luma | fal.ai | `FAL_API_KEY` | +| `nodes_minimax.py` | MiniMax | fal.ai | `FAL_API_KEY` | +| `nodes_ideogram.py` | Ideogram | fal.ai | `FAL_API_KEY` | +| `nodes_recraft.py` | Recraft | fal.ai | `FAL_API_KEY` | +| `nodes_elevenlabs.py` | ElevenLabs | fal.ai | `FAL_API_KEY` | +| `nodes_sora.py` | OpenAI Sora | fal.ai | `FAL_API_KEY` | +| `nodes_meshy.py` | Meshy | fal.ai | `FAL_API_KEY` | +| 
`nodes_wavespeed.py` | WaveSpeed | fal.ai | `FAL_API_KEY` | +| `nodes_ltxv.py` | LTX Video | fal.ai | `FAL_API_KEY` | +| `nodes_bria.py` | Bria | fal.ai | `FAL_API_KEY` | +| `nodes_bytedance.py` | ByteDance | fal.ai | `FAL_API_KEY` | +| Others (`rodin`, `tripo`, `magnific`, `topaz`, `pixverse`, `moonvalley`, `wan`, `hunyuan3d`, `vidu`, `grok`, `hitpaw`) | Various | fal.ai (if available) | `FAL_API_KEY` | + +**Note:** Not all providers in this table have confirmed fal.ai equivalents. The planning phase must audit fal.ai's model catalog to determine which nodes get working fal.ai routes vs. an "unavailable" badge. + +## Technical Details (High Level) + +### Auth System Changes +- `comfy_api_nodes/util/_helpers.py`: Replace `get_auth_header()` with `get_provider_auth_header(provider: str)` that reads from env vars +- `comfy_api_nodes/util/client.py`: `_request_base` always uses absolute URLs (no more relative proxy paths) +- `comfy_api/latest/_io.py`: Remove `auth_token_comfy_org` / `api_key_comfy_org` from `Hidden` enum and `HiddenHolder` + +### Google API Integration +- Gemini: `https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent` +- Veo: `https://generativelanguage.googleapis.com/v1beta/models/{model}:predictLongRunning` (or Vertex AI equivalent) +- Auth: `x-goog-api-key: {GOOGLE_API_KEY}` header + +### fal.ai Integration +- Base URL: `https://fal.run/{model_id}` (synchronous) or `https://queue.fal.run/{model_id}` (async/queued) +- Auth header: `Authorization: Key {FAL_API_KEY}` +- Submit + poll pattern maps well to existing `poll_op` infrastructure +- Generic node: user provides model ID string, node passes through inputs as JSON + +### Node UI: Key Status Badge +- Each node checks `os.environ.get("GOOGLE_API_KEY")` or `os.environ.get("FAL_API_KEY")` at render time +- Display as a colored indicator (green = key present, red = missing) on the node widget +- Implementation likely in the node's `define_schema()` or as a custom widget + +## 
Resolved Questions + +1. **fal.ai model coverage gaps** → Keep nodes for unavailable providers but show a clear red "unavailable" badge. They exist in the UI but cannot run. + +2. **File upload path** → Use provider-native upload mechanisms. fal.ai has its own upload endpoint; Google accepts base64 inline. Each provider path handles uploads correctly for that provider. + +3. **Request/response model changes** → Rewrite Pydantic models to match real provider API schemas (Google API and fal.ai). More work upfront but correct and maintainable. diff --git a/docs/plans/2026-03-10-feat-byok-provider-migration-plan.md b/docs/plans/2026-03-10-feat-byok-provider-migration-plan.md new file mode 100644 index 000000000000..4a465de50353 --- /dev/null +++ b/docs/plans/2026-03-10-feat-byok-provider-migration-plan.md @@ -0,0 +1,546 @@ +--- +title: "feat: BYOK Provider Migration -- Replace ComfyOrg Proxy with Direct API Keys" +type: feat +status: active +date: 2026-03-10 +origin: docs/brainstorms/2026-03-10-byok-provider-migration-brainstorm.md +--- + +# BYOK Provider Migration + +## Enhancement Summary + +**Deepened on:** 2026-03-10 +**Agents used:** security-sentinel, performance-oracle, architecture-strategist, pattern-recognition-specialist, code-simplicity-reviewer, kieran-python-reviewer, fal.ai-polling-researcher, comfyui-frontend-researcher, gemini-imagegen-skill + +### Key Improvements +1. **Split Phase 1 into 1a/1b** for incremental migration (add new auth before removing old) +2. **Connection pooling** -- current codebase creates a new aiohttp session per request; must fix before migration +3. **Security hardening** -- domain allowlist for auth headers, model_id validation, header redaction, custom exception type +4. **Simplification** -- delete unavailable providers instead of maintaining dead code; use `dict` for fal.ai results instead of per-provider Pydantic; drop `check_api_key()` (let auth functions raise); consider `fal-client` SDK +5. 
**Combined `fal_run()` helper** -- wraps submit+poll+fetch into one call to reduce duplication across 20+ files +6. **Gemini JPEG gotcha** -- Gemini returns JPEG by default; must check `mime_type`, not assume PNG + +### Critical Findings from Research +- **Security C-1:** Request logger at `request_logger.py:103` writes API keys to disk in plaintext. Must implement header redaction BEFORE any new auth code. +- **Security C-3:** Generic fal.ai node SSRF risk -- `model_id` input must be validated against strict regex, and auth headers must only be sent to allowlisted domains. +- **Performance P-1:** `_request_base` creates a new `aiohttp.ClientSession` per request (line 628). Post-migration, this means fresh TLS handshake per poll cycle to 3-5 different hosts. Implement per-host connection pooling. +- **Architecture A-1:** Phase 1 as written breaks all nodes at once. Split into additive (1a) and removal (1b) sub-phases. + +--- + +## Overview + +Replace the ComfyOrg proxy (`api.comfy.org`) with bring-your-own-key direct API access. Two environment variables cover all models: + +- `GOOGLE_API_KEY` -- Gemini, Imagen, Veo (direct to `generativelanguage.googleapis.com`) +- `FAL_API_KEY` -- everything else via fal.ai (Flux, Kling, Luma, MiniMax, etc.) + +ComfyOrg auth is completely removed. Each node shows a green/red/grey badge indicating key status. Nodes without fal.ai equivalents are deleted. + +## Problem Statement + +This fork is for personal use. The ComfyOrg proxy adds an unnecessary billing middleman. Direct API keys are cheaper and more transparent. Two keys (Google + fal.ai) cover the full model catalog. + +## Proposed Solution + +Rewire all existing `nodes_*.py` files in-place to hit real provider APIs instead of the ComfyOrg proxy. Replace the auth system, upload helpers, and Pydantic models. Add a generic fal.ai node for arbitrary model IDs. Delete node files for providers not available on fal.ai. 
(see brainstorm: `docs/brainstorms/2026-03-10-byok-provider-migration-brainstorm.md`) + +## Technical Approach + +### Architecture + +``` +Before: After: +Node → ApiEndpoint(/proxy/...) Node → ApiEndpoint(https://queue.fal.run/...) + → _request_base → _request_base + → prepend api.comfy.org → use absolute URL as-is + → add ComfyOrg auth headers → add provider auth from env var + → ComfyOrg proxy → provider API directly + → actual provider +``` + +### Provider Routing Map + +| Current Node File | BYOK Target | fal.ai Model ID | Status | +|---|---|---|---| +| `nodes_gemini.py` | Google direct | N/A | Direct API | +| `nodes_veo2.py` | Google direct | N/A | Direct API | +| `nodes_bfl.py` | fal.ai | `fal-ai/flux-pro/v1.1-ultra`, `fal-ai/flux-kontext/pro`, etc. | Available | +| `nodes_openai.py` | fal.ai | `fal-ai/gpt-image-1/text-to-image`, `fal-ai/gpt-image-1.5/edit` | Available | +| `nodes_stability.py` | fal.ai | `fal-ai/stable-diffusion-v35-medium`, `fal-ai/fast-sdxl` | Available | +| `nodes_kling.py` | fal.ai | `fal-ai/kling-video/v2/master/image-to-video`, etc. | Available | +| `nodes_luma.py` | fal.ai | `fal-ai/luma-dream-machine/ray-2` | Available | +| `nodes_minimax.py` | fal.ai | `fal-ai/minimax/video-01-director` | Available | +| `nodes_ideogram.py` | fal.ai | `fal-ai/ideogram/v3` | Available | +| `nodes_recraft.py` | fal.ai | `fal-ai/recraft/v3/text-to-image` | Available | +| `nodes_elevenlabs.py` | fal.ai | `fal-ai/elevenlabs/tts/turbo-v2.5`, etc. 
| Available | +| `nodes_sora.py` | fal.ai | `fal-ai/sora-2/text-to-video` | Available | +| `nodes_meshy.py` | fal.ai | `fal-ai/meshy/v6/image-to-3d` | Available | +| `nodes_ltxv.py` | fal.ai | `fal-ai/ltx-video-v097` | Available | +| `nodes_bytedance.py` | fal.ai | `fal-ai/seedream-4.5` | Available | +| `nodes_pixverse.py` | fal.ai | `fal-ai/pixverse/v3.5/image-to-video` | Available | +| `nodes_wan.py` | fal.ai | `fal-ai/wan-pro/image-to-video` | Available | +| `nodes_hunyuan3d.py` | fal.ai | `fal-ai/hunyuan3d/v2` | Available | +| `nodes_vidu.py` | fal.ai | `fal-ai/vidu/q3-pro/image-to-video` | Available | +| `nodes_bria.py` | fal.ai | `fal-ai/bria/text-to-image/hd`, `fal-ai/bria/eraser`, `fal-ai/bria/background/remove` | Available | +| `nodes_rodin.py` | fal.ai | `fal-ai/hyper3d/rodin/v2` | Available | +| `nodes_runway.py` | N/A | Not on fal.ai | **Delete** | +| `nodes_tripo.py` | N/A | Not on fal.ai | **Delete** | +| `nodes_magnific.py` | N/A | Not on fal.ai | **Delete** | +| `nodes_topaz.py` | N/A | Not on fal.ai | **Delete** | +| `nodes_moonvalley.py` | N/A | Not on fal.ai | **Delete** | +| `nodes_grok.py` | N/A | Not on fal.ai | **Delete** | +| `nodes_hitpaw.py` | N/A | Not on fal.ai | **Delete** | +| `nodes_wavespeed.py` | N/A | Not on fal.ai | **Delete** | + +### Google API Endpoints + +All via `generativelanguage.googleapis.com` with `x-goog-api-key` header: + +| Operation | Endpoint | Response Pattern | +|---|---|---| +| Gemini text/image gen | `POST /v1beta/models/{model}:generateContent` | Sync -- `candidates[].content.parts[]` | +| Imagen image gen | `POST /v1beta/models/{model}:predict` | Sync -- `predictions[].bytesBase64Encoded` | +| Veo video gen (submit) | `POST /v1beta/models/{model}:predictLongRunning` | Returns `{"name": "models/.../operations/..."}` | +| Veo video gen (poll) | `GET /v1beta/{operation_name}` | `{"done": true/false, "response": {...}}` | +| Veo video download | `GET /v1beta/files/{id}:download?alt=media` | Binary video data | +| 
File upload | `POST /upload/v1beta/files` | Resumable upload, returns file URI | + +### fal.ai API Endpoints + +All via `queue.fal.run` (async) or `fal.run` (sync) with `Authorization: Key {FAL_API_KEY}`: + +| Operation | Endpoint | Response Pattern | +|---|---|---| +| Submit (async) | `POST queue.fal.run/{model_id}` | `{"request_id": "...", "status_url": "...", "response_url": "..."}` | +| Poll status | `GET queue.fal.run/{model_id}/requests/{id}/status` | `{"status": "IN_QUEUE"/"IN_PROGRESS"/"COMPLETED"}` | +| Fetch result | `GET queue.fal.run/{model_id}/requests/{id}` | Model-specific output (images, video, etc.) | +| Submit (sync) | `POST fal.run/{model_id}` | Direct result (for fast models only) | +| File upload | `POST rest.fal.ai/storage/upload/initiate?storage_type=fal-cdn-v3` | Returns `{upload_url, file_url}` | +| File upload (PUT) | `PUT {upload_url}` | Uploads to presigned URL | +| Cancel | `PUT queue.fal.run/{model_id}/requests/{id}/cancel` | Best-effort; only works while IN_QUEUE | + +### Implementation Phases + +#### Phase 1a: Add New Infrastructure (Additive) + +Add new auth helpers and upload functions **alongside** existing ones. No existing code is removed yet, so all current nodes continue working throughout this phase. + +**Files to modify:** + +1. **`comfy_api_nodes/util/_helpers.py`** (92 lines) + - Keep existing `get_auth_header()` and `default_base_url()` for now + - Add `get_google_auth_header() -> dict[str, str]`: reads `GOOGLE_API_KEY` from env, returns `{"x-goog-api-key": key}` + - Add `get_fal_auth_header() -> dict[str, str]`: reads `FAL_API_KEY` from env, returns `{"Authorization": f"Key {key}"}` + - Both raise `MissingApiKeyError` if key is missing or empty (after `.strip()`) + +2. 
**`comfy_api_nodes/util/common_exceptions.py`** + - Add `MissingApiKeyError(Exception)` -- do NOT use `EnvironmentError` (it aliases `OSError` and would be caught by existing `except (ClientError, OSError)` blocks in `client.py:811` and `upload_helpers.py:343`) + +3. **`comfy_api_nodes/util/client.py`** (951 lines) + - **Connection pooling** (CRITICAL): Replace per-request `aiohttp.ClientSession()` creation at line 628 with a per-host session registry. Current code creates and destroys a session for every HTTP call. Post-migration, each Veo poll cycle (5+ requests) would incur separate TLS handshakes to the same host. + - Add domain allowlist for auth headers: `x-goog-api-key` only sent to `*.googleapis.com`; `Authorization: Key` only sent to `*.fal.run`, `*.fal.ai`, `*.fal.media` + - Add fal.ai concurrency semaphore: `asyncio.Semaphore(2)` to prevent wasteful 429 retry storms when >2 fal.ai nodes execute concurrently + - `_friendly_http_message` (line 511-519): Add provider-appropriate messages for Google (403 = invalid API key) and fal.ai (401 = invalid key). Keep existing ComfyOrg messages until Phase 1b. + - Sanitize error response bodies before including in exceptions -- parse known provider error formats (Google's `error.message`, fal.ai's `detail[].msg`), strip raw JSON dumps that could contain reflected auth info + +4. **`comfy_api_nodes/util/request_logger.py`** (MUST DO FIRST) + - Implement `_redact_headers(headers: dict) -> dict` that replaces values of `Authorization`, `x-goog-api-key`, `X-API-KEY`, and any header matching `*key*`/`*token*` with `[REDACTED]` + - Apply to BOTH request headers (line 103) and response headers (line 114) + - This MUST be implemented before any new auth headers are introduced + +5. **`comfy_api_nodes/util/upload_helpers.py`** (388 lines) + - Add `upload_file_to_fal(file_bytes: BytesIO, mime_type: str) -> str`: POST to `rest.fal.ai/storage/upload/initiate`, PUT file to returned URL, return CDN URL. 
Validate returned `upload_url` domain matches `*.fal.ai`/`*.fal.run`. + - Add `upload_file_to_google(file_bytes: BytesIO, mime_type: str, display_name: str) -> str`: POST to Google Files API resumable upload, return `files/{name}` URI + - Add thin convenience wrappers: `upload_image_to_fal(cls, image_tensor) -> str` that handles tensor-to-BytesIO conversion + - Keep existing `upload_images_to_comfyapi` intact until Phase 1b + - Use streaming upload for files >10MB (don't read entire file into memory via `.read()`) + +6. **`comfy_api_nodes/util/download_helpers.py`** (298 lines) + - Add optional `headers: dict | None = None` parameter to `download_url_to_bytesio` so callers can pass provider-specific auth headers for authenticated downloads (Google requires `x-goog-api-key` on download URLs) + - Keep existing relative URL detection for now (removed in Phase 1b) + +7. **Startup key status log** (add to node init path) + ```python + for var in ("GOOGLE_API_KEY", "FAL_API_KEY"): + status = "configured" if os.environ.get(var, "").strip() else "NOT SET" + logging.info("BYOK: %s: %s", var, status) + ``` + +**Acceptance criteria for Phase 1a:** +- [x] `get_google_auth_header()` reads from `GOOGLE_API_KEY` env var, raises `MissingApiKeyError` if empty/unset +- [x] `get_fal_auth_header()` reads from `FAL_API_KEY` env var, raises `MissingApiKeyError` if empty/unset +- [x] API keys are redacted in request AND response logs +- [x] Auth headers only sent to allowlisted domains +- [x] Upload helpers for fal.ai and Google exist alongside old ones +- [x] Connection pooling implemented with per-host session registry +- [x] fal.ai concurrency semaphore initialized +- [x] All existing nodes still work (nothing removed yet) + +#### Phase 1b: Remove Old Infrastructure + +After Phases 2-4 migrate all nodes, remove the old ComfyOrg auth system. 
+
+- [x] Remove `get_auth_header()` and `default_base_url()` from `_helpers.py` (deleted; no ComfyOrg auth remains)
+- [x] Remove `auth_token_comfy_org` and `api_key_comfy_org` from `Hidden` enum in `_io.py`
+- [x] Remove auto-injection in `Schema.finalize()` (line 1524-1529)
+- [x] Make `_request_base` reject relative URLs (raise `ValueError`)
+- [ ] Remove relative URL detection in `download_helpers.py`
+- [x] Remove old ComfyOrg messages from `_friendly_http_message` (replaced with Google/fal.ai messages)
+- [x] Remove `_diagnose_connectivity` health check against `api.comfy.org/health`
+- [x] Delete old `upload_images_to_comfyapi` and related functions
+
+#### Phase 2: Google Direct Nodes
+
+Migrate Gemini and Veo nodes to hit Google APIs directly.
+
+**Files to modify:**
+
+1. **`comfy_api_nodes/nodes_gemini.py`** (1012 lines)
+   - Replace `GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"` (line 48) with `GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/models"`
+   - Update all `ApiEndpoint(path=...)` calls to use absolute Google API URLs
+   - For text/image generation: `{GEMINI_BASE_URL}/{model}:generateContent`
+   - For Imagen: `{GEMINI_BASE_URL}/{model}:predict`
+   - Add `headers=get_google_auth_header()` to every `ApiEndpoint`
+   - Replace `upload_images_to_comfyapi()` calls (line 99-114) with inline base64 `inlineData` for all images (Google supports up to 100MB per request)
+   - Remove `uploadImagesToStorage` field usage
+   - Remove `hidden=[IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org]` from all schemas
+
+2. **`comfy_api_nodes/apis/gemini.py`** (242 lines)
+   - Remove `uploadImagesToStorage` field from `GeminiImageGenerateContentRequest` (line 148)
+   - Verify all request/response models match Google's direct API format (they're already close)
+   - Add Imagen-specific request/response models (`ImagenPredictRequest`, `ImagenPredictResponse`) if not already present
+
+3.
**`comfy_api_nodes/nodes_veo2.py`** (561 lines) + - Replace proxy paths `/proxy/veo/{model}/generate` with `{GEMINI_BASE_URL}/{model}:predictLongRunning` + - Replace polling path `/proxy/veo/{model}/poll` with `GET https://generativelanguage.googleapis.com/v1beta/{operation_name}` (where `operation_name` comes from the submit response) + - Add `headers=get_google_auth_header()` to all endpoints + - Update download to use `GET /v1beta/files/{id}:download?alt=media` with API key header (pass via new `headers` parameter on download helpers) + - For image inputs (first/last frame), use inline base64 instead of ComfyOrg upload + - Use adaptive polling: 2s for first 10s, 5s for 10-60s, 10s after 60s + +4. **`comfy_api_nodes/apis/veo.py`** (100 lines) + - Update `VeoGenVidResponse` to match direct API response: `{"name": "models/.../operations/..."}` instead of proxy wrapper + - Update `VeoGenVidPollResponse` to match direct API: `{"done": bool, "response": {"generateVideoResponse": {"generatedSamples": [...]}}}` or `{"error": {...}}` + - Update `VeoGenVidRequest` if the proxy was transforming the request format + +### Research Insights for Phase 2 + +**Gemini Image Generation Gotcha:** Gemini returns JPEG by default regardless of what you name the output file. Always check `part.inline_data.mime_type` -- do not assume PNG. Saving a JPEG as `.png` creates a JPEG with a PNG extension, causing "Image does not match media type" errors downstream. + +**Two image gen systems:** Gemini native image gen (`generateContent` with `responseModalities: ["IMAGE"]`) returns inline base64 in `candidates[].content.parts[].inline_data`. Imagen (`predict`) returns in `predictions[].bytesBase64Encoded`. These are different endpoints and response shapes. + +**File size strategy:** Use inline base64 for images <1MB (thumbnails, reference images). Use Google Files API resumable upload for anything larger. 
Current code splits at image 10 (first 10 as fileUri, rest as inline) -- similar threshold but base on size, not count. + +**Acceptance criteria for Phase 2:** +- [x] Gemini text generation works via direct Google API +- [x] Gemini image generation (Nano Banana) works via direct Google API +- [x] Imagen image generation works via `:predict` endpoint +- [x] Veo 2/3 video generation works via `:predictLongRunning` + poll +- [x] Image inputs use base64 inline data (no ComfyOrg upload) +- [ ] JPEG vs PNG mime type handled correctly in Gemini responses + +#### Phase 3: fal.ai Infrastructure + Generic Node + +Build the fal.ai integration layer and the generic fal.ai node. **Also resolve TBD entries** for `nodes_bria.py` and `nodes_rodin.py` before Phase 4 begins. + +**Files to create:** + +1. **`comfy_api_nodes/apis/fal.py`** (new file) + - `FalQueueSubmitResponse`: `request_id`, `response_url`, `status_url`, `cancel_url` + - `FalQueueStatusResponse`: `status` (IN_QUEUE/IN_PROGRESS/COMPLETED), `queue_position`, `logs`, `response_url` + - `FalError`: `detail` array with `loc`, `msg`, `type`, `ctx` + - Use Pydantic default `extra = "ignore"` (NOT `extra = "allow"` -- silently swallows typos) + +2. **`comfy_api_nodes/nodes_fal.py`** (new file) + - `FalGenericNode(IO.ComfyNode)`: accepts `model_id` (string), `input_json` (string, valid JSON), optional `image` (IMAGE tensor) + - **Validate `model_id`** against strict regex: `^[a-zA-Z0-9_-]+(/[a-zA-Z0-9_.-]+)*$` -- reject `..`, `://`, `?`, `#`, whitespace (prevents SSRF) + - Schema: `is_api_node=True` + - Execute: parse input JSON, upload image if provided, submit to fal.ai queue, poll until complete, download result images/video immediately (fal.ai CDN URLs expire after 7 days) + - `FalExtension(ComfyExtension)` + `comfy_entrypoint()` + +**fal.ai helpers -- add to `comfy_api_nodes/util/client.py`** (thin wrappers around existing `sync_op`/`poll_op`, NOT a separate file): + +3. 
**`fal_run(cls, model_id, data, *, estimated_duration=None) -> dict`** + - Combined submit + poll + fetch in one call. Wraps the three steps below. + - 20+ node files will call this; having it as one function prevents duplication. + +4. **`fal_submit(cls, model_id, data) -> FalQueueSubmitResponse`** + - `sync_op(cls, ApiEndpoint(f"https://queue.fal.run/{model_id}", "POST", headers=get_fal_auth_header()), data=data, response_model=FalQueueSubmitResponse)` + +5. **`fal_poll(cls, status_url) -> FalQueueStatusResponse`** + - Use existing `poll_op` with `status_extractor=lambda r: r["status"]`, `completed_statuses=["COMPLETED"]`, `failed_statuses=[]` + - Map fal statuses: `IN_QUEUE`/`IN_PROGRESS` → keep polling, `COMPLETED` → done + +6. **`fal_fetch_result(cls, response_url) -> dict`** + - `sync_op_raw` GET to response URL with fal auth header + - Returns `dict[str, Any]` (model-specific output) + +7. **Resolve TBD entries:** + - [ ] Verify `nodes_bria.py` against fal.ai model catalog -- assign model ID or move to delete list + - [ ] Verify `nodes_rodin.py` against fal.ai model catalog -- assign model ID or move to delete list + - TBD resolution is a **prerequisite** for Phase 4 + +### Research Insights for Phase 3 + +**Consider `fal-client` SDK:** The official `fal-client` Python SDK (`pip install fal-client`) handles retry, cancellation, timeout, and auth automatically. `fal_client.subscribe()` / `submit()` + `iter_events()` would replace the raw HTTP integration. Tradeoff: adds a dependency but eliminates manual queue management. For personal use, the SDK simplicity may be worth it. + +**fal.ai CDN URLs expire after 7 days.** Always download result bytes immediately after completion. Never persist fal.ai URLs as references. + +**Cancellation is best-effort.** `PUT .../cancel` only works when status is `IN_QUEUE`. Once processing starts, the request completes (and you're billed). Always attempt cancellation on user interrupt, but also stop your local polling loop. 
+ +**Concurrency: 2 tasks on standard tier.** Additional requests are queued server-side (not rejected). The client-side semaphore from Phase 1a prevents wasteful 429 retries. + +**Acceptance criteria for Phase 3:** +- [x] Generic fal.ai node can submit to any model by ID +- [x] `fal_run()` submit/poll/fetch cycle works end-to-end +- [x] File upload to fal.ai CDN works (images, video, audio) +- [x] Generic node handles both image and video outputs +- [x] fal.ai errors are translated to user-friendly messages +- [x] `model_id` validated against strict regex (no SSRF) +- [x] TBD entries for `nodes_bria.py` and `nodes_rodin.py` resolved (both available on fal.ai) + +#### Phase 4: Migrate Non-Google Nodes to fal.ai + +Rewire each existing node file to use fal.ai instead of the ComfyOrg proxy. This is the largest phase by file count but follows a repeatable pattern. + +**Migration pattern per node file:** + +For each `nodes_*.py`: +1. Replace proxy endpoint paths with fal.ai model IDs (use module-level constants: `FAL_MODEL_FLUX_PRO_ULTRA = "fal-ai/flux-pro/v1.1-ultra"`) +2. Replace `sync_op`/`poll_op` calls with `fal_run()` (or `fal_submit`+`fal_poll`+`fal_fetch_result` for unusual flows) +3. Replace `upload_images_to_comfyapi` calls with `upload_file_to_fal` +4. Remove ComfyOrg hidden fields from schema +5. Update response parsing to extract from fal.ai's `dict` output (use plain dict access, not new Pydantic models) +6. Remove `price_badge` entries (ComfyOrg pricing no longer applies) +7. Download result URLs immediately (images, video, audio) -- don't store fal.ai CDN URLs + +**For each corresponding `apis/*.py`:** +- For fal.ai-routed providers, **delete the per-provider Pydantic models** (they modeled the ComfyOrg proxy's API). Use `dict[str, Any]` for fal.ai results and build request dicts inline. The fal.ai queue envelope models in `apis/fal.py` are shared by all fal.ai nodes. 
+ +**Migration order (simple → complex):** + +Batch 1 -- Synchronous image generation (simplest, establishes patterns): +- [x] `nodes_recraft.py` → `fal-ai/recraft/v3/text-to-image` (6 nodes) -- placeholder paths, needs fal.ai schema verification +- [x] `nodes_ideogram.py` → `fal-ai/ideogram/v3` (4 nodes) -- fully migrated to fal_run +- [x] `nodes_bfl.py` → `fal-ai/flux-pro/v1.1-ultra`, `fal-ai/flux-kontext/pro`, etc. (10 nodes) -- fully migrated to fal_run +- [x] `nodes_stability.py` → `fal-ai/stable-diffusion-v35-medium` (5 nodes) -- fully migrated to fal_run +- [x] `nodes_bria.py` → `fal-ai/bria/text-to-image/hd` (4 nodes) -- fully migrated to fal_run + +Batch 2 -- OpenAI / Sora: +- [x] `nodes_openai.py` → `fal-ai/gpt-image-1/text-to-image` (4 nodes) -- fully migrated to fal_run +- [x] `nodes_sora.py` → `fal-ai/sora-2/text-to-video` (2 nodes) -- fully migrated to fal_run + +Batch 3 -- Video generation (async/poll): +- [x] `nodes_kling.py` → `fal-ai/kling-video/v2/master/*` (24 nodes) -- placeholder paths, needs fal.ai schema verification +- [x] `nodes_luma.py` → `fal-ai/luma-dream-machine/ray-2` (5 nodes) -- fully migrated to fal_run +- [x] `nodes_minimax.py` → `fal-ai/minimax/video-01-director` (4 nodes) -- fully migrated to fal_run +- [x] `nodes_ltxv.py` → `fal-ai/ltx-video-v097` (3 nodes) -- fully migrated to fal_run +- [x] `nodes_pixverse.py` → `fal-ai/pixverse/v3.5/*` (4 nodes) -- fully migrated to fal_run +- [x] `nodes_wan.py` → `fal-ai/wan-pro/*` (5 nodes) -- fully migrated to fal_run +- [x] `nodes_vidu.py` → `fal-ai/vidu/q3-pro/*` (13 nodes) -- fully migrated to fal_run +- [x] `nodes_bytedance.py` → `fal-ai/seedream-4.5` (4 nodes) -- fully migrated to fal_run + +Batch 4 -- Audio: +- [x] `nodes_elevenlabs.py` → `fal-ai/elevenlabs/tts/*` (7 nodes) -- placeholder paths, needs fal.ai schema verification + +Batch 5 -- 3D: +- [x] `nodes_meshy.py` → `fal-ai/meshy/v6/image-to-3d` (3 nodes) -- placeholder paths, needs fal.ai schema verification +- [x] 
`nodes_hunyuan3d.py` → `fal-ai/hunyuan3d/v2` (6 nodes) -- placeholder paths, needs fal.ai schema verification +- [x] `nodes_rodin.py` → `fal-ai/hyper3d/rodin/v2` (3 nodes) -- placeholder paths, needs fal.ai schema verification + +Batch 6 -- Delete unavailable providers: +- [x] Delete `nodes_runway.py` + `apis/runway.py` +- [x] Delete `nodes_tripo.py` + `apis/tripo.py` +- [x] Delete `nodes_magnific.py` + `apis/magnific.py` +- [x] Delete `nodes_topaz.py` + `apis/topaz.py` +- [x] Delete `nodes_moonvalley.py` + `apis/moonvalley.py` +- [x] Delete `nodes_grok.py` + `apis/grok.py` +- [x] Delete `nodes_hitpaw.py` + `apis/hitpaw.py` +- [x] Delete `nodes_wavespeed.py` + `apis/wavespeed.py` + +### Research Insights for Phase 4 + +**Drift risk:** `nodes_kling.py` (3277 lines, 24 nodes) is larger than all Batch 1 files combined. Migrating it will likely force infrastructure refinements that affect earlier-migrated nodes. Run a **consistency sweep** after the last file to ensure early and late batches follow the same patterns. + +**`nodes_recraft.py` exception:** Uses `multipart/form-data` file uploads via a custom `recraft_multipart_parser` (lines 73-119), not `upload_images_to_comfyapi`. The standard 7-step migration pattern doesn't cover this. Document as an exception. + +**fal.ai parameter names differ from original providers.** Each fal.ai model wrapper may use different field names than the native API. For example, Kling's native `model_name` field may be different in fal.ai's wrapper. Verify each model's fal.ai input schema during migration by checking `https://fal.ai/models/{model_id}/api`. + +#### Phase 5: Node-Level Key Status Badges + +Add green/red badges to each API node in the ComfyUI frontend. 
+
+**Backend (`comfy_api/latest/_io.py`):**
+- Add `required_api_key: str | None` to `Schema` and `NodeInfoV1` (static metadata). Store the **opaque provider name** (`"google"` / `"fal"`), not the env var name -- this field is exposed via `/object_info`, and the security note below forbids leaking env var names. Map provider name → env var server-side.
+- Add `api_key_status: str | None` to `NodeInfoV1`, computed server-side in `get_v1_info()`:
+  - Look up the env var for the provider, then: if `os.environ.get(env_var, "").strip()` → `"configured"` (green badge)
+  - Else → `"missing"` (red badge)
+
+**Frontend (JS extension via `WEB_DIRECTORY`):**
+- Register a `Comfy.KeyStatusBadge` extension using `nodeCreated` hook
+- Read `api_key_status` from `node.constructor.nodeData`
+- Push an `LGraphBadge` to `node.badges` array with appropriate color:
+  - Green (`#4CAF50`): key configured
+  - Red (`#f44336`): key missing
+- This uses ComfyUI's existing badge system (same as price badges) -- no core frontend modification needed
+
+### Research Insights for Phase 5
+
+**Do NOT create a separate `/api/key-status` endpoint.** The security review found it leaks key presence without authentication (ComfyUI has no auth middleware). Instead, embed `api_key_status` in the existing `/object_info` response via `NodeInfoV1`. The frontend already consumes this data for node rendering.
+
+**Use opaque provider names** in any exposed data: `"google"` / `"fal"`, not env var names like `"GOOGLE_API_KEY"`. (This is why `required_api_key` above carries the provider name, with the env var lookup kept server-side.)
+
+**Badge refreshes on page reload only** (no polling needed). This is consistent with how other ComfyUI badges work.
+
+**Simplification alternative:** For personal use, the startup log from Phase 1a may be sufficient. Phase 5 can be deferred or skipped entirely -- the execute-time error from `MissingApiKeyError` already tells you when a key is missing.
+
+**Acceptance criteria for Phase 5:**
+- [ ] Each API node displays correct badge color based on env var status
+- [ ] Badge updates on page refresh
+- [ ] No separate key-status endpoint (embedded in /object_info)
+
+#### Phase 6: Cleanup
+
+Final sweep for straggling references after all nodes are migrated and Phase 1b is complete.
+ +- [x] Remove `--comfy-api-base` CLI argument from `comfy/cli_args.py` +- [ ] Remove price badges (ComfyOrg proxy pricing no longer applies) +- [ ] Grep codebase for remaining references to `api.comfy.org`, `comfy_org`, `proxy/` -- remove all +- [ ] Clean up `apis/__init__.py` (auto-generated from ComfyOrg OpenAPI spec -- remove or regenerate) +- [ ] Delete unused old upload helper functions +- [ ] Delete `apis/*.py` files for fal.ai-routed providers (replaced by `dict` access) + +## System-Wide Impact + +### Interaction Graph + +1. Node `execute()` calls `get_google_auth_header()` or `get_fal_auth_header()` → reads `os.environ` → validates non-empty → constructs auth header +2. Domain allowlist check verifies auth header matches target URL domain +3. Node calls `sync_op` / `fal_run()` with absolute URL + auth headers → `_request_base` sends via pooled connection → provider API directly +4. For uploads: node calls `upload_file_to_fal()` or uses base64 inline → file reaches provider +5. For downloads: node calls `download_url_to_image_output()` with absolute URL + optional auth headers → downloads from provider CDN + +### Error Propagation + +- Missing env var → `get_*_auth_header()` raises `MissingApiKeyError` → caught by node → displayed in UI +- Empty env var → same `MissingApiKeyError` with message "...is set but empty. Please provide a valid API key." 
+- Invalid key → provider returns 401/403 → `_friendly_http_message` translates (Google: "Invalid API key", fal.ai: "Invalid API key") → displayed in UI +- Rate limit → provider returns 429 → existing retry logic in `_request_base` handles with exponential backoff +- fal.ai concurrency limit → client-side semaphore queues locally before sending +- Content policy → Google returns `candidates[].finishReason == "SAFETY"`, fal.ai returns `422` with `content_policy_violation` type → node-level handling +- Error response bodies sanitized before display (no raw JSON dumps with reflected auth info) + +### State Lifecycle Risks + +- No persistent state changes -- env vars are read-only, no database, no session state +- File uploads to fal.ai CDN are ephemeral (7-day default, configurable via `X-Fal-Object-Lifecycle-Preference` header) +- Google Files API uploads persist for 48 hours then auto-delete +- fal.ai result URLs must be downloaded immediately; never persisted as references + +### API Surface Parity + +- All `nodes_*.py` files share the same `ApiEndpoint` → `sync_op`/`fal_run()` → `_request_base` pipeline +- The auth change in `_helpers.py` affects every API node +- The upload change in `upload_helpers.py` affects 21 node files + +## Acceptance Criteria + +### Functional Requirements + +- [ ] All Google nodes (Gemini, Imagen, Veo) work with `GOOGLE_API_KEY` env var +- [ ] All fal.ai-routed nodes work with `FAL_API_KEY` env var +- [ ] Generic fal.ai node accepts arbitrary model ID and JSON input +- [ ] File uploads work via provider-native mechanisms +- [ ] Async/poll flows work for Veo and fal.ai queue +- [ ] Error messages are provider-appropriate (no ComfyOrg references) + +### Non-Functional Requirements + +- [ ] API keys are never logged in debug output (request AND response headers redacted) +- [ ] Missing/empty key produces a clear, actionable `MissingApiKeyError` +- [ ] Auth headers only sent to allowlisted provider domains (no SSRF) +- [ ] Generic fal.ai 
node validates `model_id` input (no path traversal)
+- [ ] No references to `api.comfy.org` remain in the codebase
+- [ ] All provider connections use default TLS verification (never `ssl=False`)
+
+### Quality Gates
+
+- [ ] Each migrated node tested with a real API call (at least one per provider)
+- [ ] All existing node names preserved for workflow compatibility
+- [ ] Generic fal.ai node tested with at least 3 different model IDs
+- [ ] Consistency sweep after Phase 4 verifies uniform patterns across all migrated files
+
+## Dependencies & Prerequisites
+
+- **Google API key** with paid tier (required for Veo and Imagen)
+- **fal.ai API key** (standard tier, 2 concurrent tasks)
+- **Python venv** -- this fork runs in its own virtual environment (see setup below)
+- One new external library dependency: `python-dotenv` (used by `main.py` to load `.env`; add it to `requirements.txt`). The existing `aiohttp` client handles all HTTP. Optionally add `fal-client` SDK for simplified fal.ai integration.
+
+**Time-sensitive:** Google Veo preview models (`veo-3.1-generate-preview`, `veo-3-generate-preview`, `veo-2-generate-preview`) are scheduled for deprecation on **April 2, 2026** (~3 weeks from plan date). Phase 2 should use preview model IDs initially but be prepared to update to GA model IDs as soon as they're published.
+
+### Virtual Environment Setup
+
+This fork uses its own isolated venv to avoid conflicts with system Python or other projects:
+
+```bash
+cd /Users/mkorovkin/workplace/mimos/comfyui-custom-fork
+python3 -m venv .venv
+source .venv/bin/activate
+pip install -r requirements.txt
+# Optional: pip install fal-client (if using fal-client SDK)
+```
+
+All development and execution should happen within this venv. Add `.venv/` to `.gitignore`.
+ +## Risk Analysis & Mitigation + +| Risk | Impact | Mitigation | +|---|---|---| +| fal.ai model IDs change or become unavailable | Nodes break silently | Pin model IDs as module-level constants, easy to update | +| Google API response format changes | Pydantic validation fails | Use default `extra = "ignore"` on Pydantic models; define all consumed fields explicitly | +| fal.ai rate limit (2 concurrent tasks on standard tier) | Queue contention | Client-side `asyncio.Semaphore(2)` prevents wasteful retries; upgrade to premium if needed | +| Some nodes' fal.ai parameter names differ from original provider | Wrong inputs sent | Verify each model's fal.ai input schema at `fal.ai/models/{id}/api` during migration | +| Google Veo preview models deprecated April 2026 | Veo stops working | Monitor for GA model availability, update model IDs | +| SSRF via generic fal.ai node `model_id` | API key exfiltration | Strict regex validation + domain allowlist for auth headers | +| Gemini returns JPEG when PNG expected | Image format errors downstream | Always check `part.inline_data.mime_type`; convert explicitly if PNG needed | +| Connection pool exhaustion under load | Request failures | Configure `limit_per_host=10` for Google, `limit_per_host=4` for fal.ai | +| Early-migrated nodes drift from late-migrated patterns | Inconsistency | Consistency sweep after Phase 4; `nodes_kling.py` gets own sub-batch | + +## Sources & References + +### Origin + +- **Brainstorm document:** [docs/brainstorms/2026-03-10-byok-provider-migration-brainstorm.md](docs/brainstorms/2026-03-10-byok-provider-migration-brainstorm.md) -- Key decisions: BYOK for everything, env vars for keys, rewire existing nodes, node-level badges, delete unavailable providers + +### Internal References + +- Auth system: `comfy_api_nodes/util/_helpers.py:30-35` (`get_auth_header`) +- HTTP client: `comfy_api_nodes/util/client.py:572-891` (`_request_base`) +- Session creation (perf issue): 
`comfy_api_nodes/util/client.py:628` +- Hidden fields: `comfy_api/latest/_io.py:1322-1337` (`Hidden` enum) +- Schema finalize: `comfy_api/latest/_io.py:1524-1529` (auto-inject) +- Upload helpers: `comfy_api_nodes/util/upload_helpers.py:188-217` +- Download helpers: `comfy_api_nodes/util/download_helpers.py:62-67` +- Request logger (key leak): `comfy_api_nodes/util/request_logger.py:103-104` +- Node discovery: `nodes.py:2463-2472` (`init_builtin_api_nodes`) +- Gemini nodes: `comfy_api_nodes/nodes_gemini.py:48` (`GEMINI_BASE_ENDPOINT`) +- Veo nodes: `comfy_api_nodes/nodes_veo2.py` +- Error messages: `comfy_api_nodes/util/client.py:511-519` (`_friendly_http_message`) +- Custom exceptions: `comfy_api_nodes/util/common_exceptions.py` + +### External References + +- Google Generative AI API: https://ai.google.dev/gemini-api/docs +- Google Veo documentation: https://ai.google.dev/gemini-api/docs/video +- Google Files API: https://ai.google.dev/gemini-api/docs/files +- Google Imagen: https://ai.google.dev/gemini-api/docs/imagen +- fal.ai Queue API: https://docs.fal.ai/model-apis/model-endpoints/queue +- fal.ai Authentication: https://docs.fal.ai/reference/platform-apis/authentication +- fal.ai Error Reference: https://docs.fal.ai/model-apis/errors +- fal.ai Model Explorer: https://fal.ai/explore/models +- fal-client Python SDK: https://pypi.org/project/fal-client/ +- ComfyUI Extension API: https://docs.comfy.org/custom-nodes/js/javascript_objects_and_hijacking +- ComfyUI-fal-API (community integration): https://github.com/gokayfem/ComfyUI-fal-API diff --git a/main.py b/main.py index 8905fd09aeff..a7dfb523cbf1 100644 --- a/main.py +++ b/main.py @@ -2,6 +2,8 @@ comfy.options.enable_args_parsing() import os +from dotenv import load_dotenv +load_dotenv() # Load GOOGLE_API_KEY, FAL_API_KEY, etc. 
from .env import importlib.util import shutil import importlib.metadata diff --git a/nodes.py b/nodes.py index 0ef23b640521..52025ea2087f 100644 --- a/nodes.py +++ b/nodes.py @@ -2461,6 +2461,11 @@ async def init_builtin_extra_nodes(): async def init_builtin_api_nodes(): + # BYOK: Log API key status at startup + for _var in ("GOOGLE_API_KEY", "FAL_API_KEY"): + _status = "configured" if os.environ.get(_var, "").strip() else "NOT SET" + logging.info("BYOK: %s: %s", _var, _status) + api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_api_nodes") api_nodes_files = sorted(glob.glob(os.path.join(api_nodes_dir, "nodes_*.py")))