Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ai_diffusion/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ class CustomStyleInput:
class CustomWorkflowInput:
workflow: dict
params: dict[str, Any]
selection_bounds: Bounds
positive_evaluated: str = ""
negative_evaluated: str = ""
models: CheckpointInput | None = None
Expand Down
30 changes: 19 additions & 11 deletions ai_diffusion/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,12 @@ def filename(self) -> str:
def check_color_mode(self) -> tuple[Literal[True], None] | tuple[Literal[False], str]:
return True, None

def user_selection_bounds(self) -> Bounds | None:
raise NotImplementedError

def create_mask_from_selection(
self, mod: SelectionModifiers
) -> tuple[Mask, Bounds] | tuple[None, None]:
self, selection_bounds: Bounds | None, mod: SelectionModifiers
) -> Mask | None:
raise NotImplementedError

def get_image(
Expand Down Expand Up @@ -196,20 +199,25 @@ def check_color_mode(self):
return False, msg_fmt.format("depth", "8-bit integer", depth)
return True, None

def create_mask_from_selection(self, mod: SelectionModifiers):
def user_selection_bounds(self):
user_selection = self._doc.selection()
if not user_selection:
return None, None
return None

if _selection_is_entire_document(user_selection, self.extent):
return None, None
return None

selection = user_selection.duplicate()
original_bounds = Bounds(
selection.x(), selection.y(), selection.width(), selection.height()
bounds = Bounds(
user_selection.x(), user_selection.y(), user_selection.width(), user_selection.height()
)
original_bounds = Bounds.clamp(original_bounds, self.extent)
size_factor = original_bounds.extent.diagonal
return Bounds.clamp(bounds, self.extent)

def create_mask_from_selection(self, selection_bounds: Bounds | None, mod: SelectionModifiers):
if selection_bounds is None:
return None

selection = self._doc.selection().duplicate()
size_factor = selection_bounds.extent.diagonal
pad_px = max(int(mod.feather_rel * size_factor), mod.feather_min_px)
pad_px += mod.pad_offset_px
pad_px += int(mod.pad_rel * size_factor)
Expand All @@ -223,7 +231,7 @@ def create_mask_from_selection(self, mod: SelectionModifiers):
)
bounds = Bounds.clamp(bounds, self.extent)
data = selection.pixelData(*bounds)
return Mask(bounds, data), original_bounds
return Mask(bounds, data)

def get_image(self, bounds: Bounds | None = None, exclude_layers: list[Layer] | None = None):
excluded: list[Layer] = []
Expand Down
19 changes: 14 additions & 5 deletions ai_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,8 @@ def _prepare_workflow(self, dryrun=False):
region_layer = None

smod = get_selection_modifiers(arch, self.inpaint.mode, strength)
mask, selection_bounds = self._doc.create_mask_from_selection(smod)
selection_bounds = self._doc.user_selection_bounds()
mask = self._doc.create_mask_from_selection(selection_bounds, smod)
bounds = Bounds(0, 0, *extent)
if mask is None: # Check for region inpaint
region_layer = regions.get_active_region_layer(use_parent=not self.region_only)
Expand Down Expand Up @@ -456,7 +457,8 @@ def _prepare_live_workflow(self):

image = None
smod = get_selection_modifiers(self.arch, inpaint.mode, strength, min_mask_size)
mask, selection_bounds = self._doc.create_mask_from_selection(smod)
selection_bounds = self._doc.user_selection_bounds()
mask = self._doc.create_mask_from_selection(selection_bounds, smod)
inpaint = calc_selection_pre_process(inpaint, selection_bounds, smod)

bounds = Bounds(0, 0, *self._doc.extent)
Expand Down Expand Up @@ -517,11 +519,16 @@ async def _generate_custom(self, previous_input: WorkflowInput | None):
bounds = canvas_bounds
mask = None

select_bounds = self._doc.user_selection_bounds()

if selection_node := next(wf.find(type="ETN_KritaSelection"), None):
mods = get_selection_modifiers(Arch.sdxl, InpaintMode.fill, self.strength)
mask, select_bounds = self._doc.create_mask_from_selection(mods)
mask = self._doc.create_mask_from_selection(select_bounds, mods)
mask, bounds = self.custom.prepare_mask(selection_node, mask, select_bounds, bounds)

if select_bounds is None:
select_bounds = canvas_bounds

img_input = ImageInput.from_extent(bounds.extent)
img_input.initial_image = self._get_current_image(bounds, exclude_internal=not is_live)
img_input.hires_mask = mask.to_image(bounds.extent) if mask else None
Expand All @@ -530,7 +537,8 @@ async def _generate_custom(self, previous_input: WorkflowInput | None):
self.layers, canvas_bounds, client.models, is_live, is_anim
)

custom_input = CustomWorkflowInput(wf.root, params)
custom_input = CustomWorkflowInput(wf.root, params, select_bounds)

metadata: dict[str, Any] = dict(self.custom.params)
job_params = JobParams(bounds, self.custom.job_name, metadata=metadata)

Expand Down Expand Up @@ -607,7 +615,8 @@ def generate_control_layer(self, control: ControlLayer):

try:
image = doc.get_image(Bounds(0, 0, *self._doc.extent))
mask, _ = doc.create_mask_from_selection(SelectionModifiers(pad_rel=0.25, multiple=64))
selection_bounds = doc.user_selection_bounds()
mask = doc.create_mask_from_selection(selection_bounds, SelectionModifiers(pad_rel=0.25, multiple=64))
bounds = mask.bounds if mask else None
perf = self._performance_settings(self._connection.client)
input = workflow.prepare_create_control_image(image, control.mode, perf, bounds)
Expand Down
5 changes: 5 additions & 0 deletions ai_diffusion/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1455,6 +1455,11 @@ def get_param(node: ComfyNode, expected_type: type | tuple[type, type] | None =
outputs[node.output(1)] = images.hires_mask is not None
outputs[node.output(2)] = bounds.x
outputs[node.output(3)] = bounds.y
case "ETN_KritaSelectionBounds":
outputs[node.output(0)] = input.selection_bounds.x
outputs[node.output(1)] = input.selection_bounds.y
outputs[node.output(2)] = input.selection_bounds.width
outputs[node.output(3)] = input.selection_bounds.height
case "ETN_Parameter":
outputs[node.output(0)] = get_param(node)
case "ETN_KritaImageLayer":
Expand Down
8 changes: 4 additions & 4 deletions tests/test_custom_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ def test_expand():
"style": style_input,
}

input = CustomWorkflowInput(workflow=ext.root, params=params)
input = CustomWorkflowInput(workflow=ext.root, params=params, selection_bounds=Bounds(0, 0, width, height))
images = ImageInput.from_extent(Extent(4, 4))
images.initial_image = Image.create(Extent(4, 4), Qt.GlobalColor.red)

Expand Down Expand Up @@ -628,7 +628,7 @@ def test_expand_animation():
"mask": in_masks,
}

input = CustomWorkflowInput(workflow=ext.root, params=params)
input = CustomWorkflowInput(workflow=ext.root, params=params, selection_bounds=Bounds(0, 0, 4, 4))
images = ImageInput.from_extent(Extent(4, 4))
models = ClientModels()

Expand Down Expand Up @@ -682,11 +682,11 @@ def test_expand_selection():
)

params = {}
input = CustomWorkflowInput(workflow=ext.root, params=params)
bounds = Bounds(2, 3, 8, 16) # selection from (2,2) to (6,6)
input = CustomWorkflowInput(workflow=ext.root, params=params, selection_bounds=bounds)
images = ImageInput.from_extent(Extent(8, 16))
images.initial_image = Image.create(Extent(8, 16), Qt.GlobalColor.red)
images.hires_mask = Image.create(Extent(8, 16), Qt.GlobalColor.green)
bounds = Bounds(2, 3, 8, 16) # selection from (2,2) to (6,6)
models = ClientModels()

w = ComfyWorkflow()
Expand Down
5 changes: 3 additions & 2 deletions tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,12 +1091,13 @@ def test_custom_workflow(qtapp, local_client: Client):
"2. Detail/2. Steps": 14,
"2. Detail/4. CFG": 3.5,
}
bounds = Bounds(0, 0, 512, 512)
job = WorkflowInput(
WorkflowKind.custom,
images=ImageInput.from_extent(Extent(512, 512)),
sampling=SamplingInput("custom", "custom", 1, 1000, seed=1234),
inpaint=InpaintParams(InpaintMode.fill, Bounds(0, 0, 512, 512)),
custom_workflow=CustomWorkflowInput(workflow_graph.root, params),
inpaint=InpaintParams(InpaintMode.fill, bounds),
custom_workflow=CustomWorkflowInput(workflow_graph.root, params, bounds),
)
assert job.images is not None
job.images.initial_image = Image.create(Extent(512, 512))
Expand Down