forked from LykosAI/StabilityMatrix
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTiledVAEModule.cs
More file actions
55 lines (47 loc) · 1.86 KB
/
TiledVAEModule.cs
File metadata and controls
55 lines (47 loc) · 1.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
using Injectio.Attributes;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
namespace StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
/// <summary>
/// Inference module that replaces the standard VAE decode with a <c>TiledVAEDecode</c> node,
/// using the tile and temporal settings exposed by its <see cref="TiledVAECardViewModel"/>.
/// </summary>
[ManagedService]
[RegisterTransient<TiledVAEModule>]
public class TiledVAEModule : ModuleBase
{
    /// <summary>
    /// Creates the module, sets its display title, and attaches one tiled VAE settings card.
    /// </summary>
    /// <param name="vmFactory">Factory used to resolve the settings card view model.</param>
    public TiledVAEModule(IServiceManager<ViewModelBase> vmFactory)
        : base(vmFactory)
    {
        Title = "Tiled VAE Decode";

        var settingsCard = vmFactory.Get<TiledVAECardViewModel>();
        AddCards(settingsCard);
    }

    /// <summary>
    /// Queues a pre-output action. When that action runs and the primary connection is still a
    /// latent, the action decodes the latent with a tiled VAE decode node. The decoded image then
    /// becomes the new primary connection.
    /// </summary>
    /// <param name="e">Step arguments whose <c>PreOutputActions</c> list receives the action.</param>
    protected override void OnApplyStep(ModuleApplyStepEventArgs e)
    {
        var settingsCard = GetCard<TiledVAECardViewModel>();

        e.PreOutputActions.Add(stepArgs =>
        {
            var nodeBuilder = stepArgs.Builder;

            // Nothing to decode unless the primary output is still in latent space (the T0 case).
            if (nodeBuilder.Connections.Primary is not { IsT0: true } primary)
            {
                return;
            }

            var latentSamples = primary.AsT0;
            var defaultVae = nodeBuilder.Connections.GetDefaultVAE();
            var nodeName = nodeBuilder.Nodes.GetUniqueName("TiledVAEDecode");

            // The card values are captured by the closure and read when this action runs,
            // not when the action is registered.
            var decodeNode = nodeBuilder.Nodes.AddTypedNode(
                new ComfyNodeBuilder.TiledVAEDecode
                {
                    Name = nodeName,
                    Samples = latentSamples,
                    Vae = defaultVae,
                    TileSize = settingsCard.TileSize,
                    Overlap = settingsCard.Overlap,
                    TemporalSize = settingsCard.TemporalSize,
                    TemporalOverlap = settingsCard.TemporalOverlap,
                }
            );

            // Later steps consume the decoded image instead of the latent.
            nodeBuilder.Connections.Primary = decodeNode.Output;
        });
    }
}