Skip to content

Commit 21b6732

Browse files
committed
refactor: remove static ROCm profile forwarders, add TorchIndex-aware helper methods
- Remove static forwarder methods (GetCompatibility, HasSupport, ShouldApplyLaunchEnvironment/ShouldUseWindowsNativeInstall, BuildLaunchEnvironment, GetLaunchNoticeLines) from InvokeAiWindowsRocmProfile and ReforgeWindowsRocmProfile - Add ShouldApplyWindowsLaunchEnvironment(TorchIndex) and GetWindowsLaunchNoticeLines(TorchIndex) to IRocmPackageHelper and RocmPackageHelper; make no-arg overload private - Update InvokeAI and Reforge call sites to use IRocmPackageHelper directly - Migrate ComfyUI.EmitWindowsRocmLaunchNotice to use GetWindowsLaunchNoticeLines(TorchIndex), removing ShouldShowWindowsRocmLaunchNotice
1 parent 6ae3937 commit 21b6732

7 files changed

Lines changed: 38 additions & 108 deletions

File tree

StabilityMatrix.Core/Models/Packages/ComfyUI.cs

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -654,24 +654,13 @@ private void EmitWindowsRocmLaunchNotice(
654654
Action<ProcessOutput>? onConsoleOutput
655655
)
656656
{
657-
if (!ShouldShowWindowsRocmLaunchNotice(installedPackage))
658-
return;
659-
660-
foreach (var line in rocmPackageHelper.GetWindowsLaunchNoticeLines())
657+
var torchIndex = installedPackage.PreferredTorchIndex ?? GetRecommendedTorchVersion();
658+
foreach (var line in rocmPackageHelper.GetWindowsLaunchNoticeLines(torchIndex))
661659
{
662660
onConsoleOutput?.Invoke(ProcessOutput.FromStdOutLine($"{line}{Environment.NewLine}"));
663661
}
664662
}
665663

666-
private bool ShouldShowWindowsRocmLaunchNotice(InstalledPackage installedPackage)
667-
{
668-
if (!Compat.IsWindows || !HasWindowsRocmSupport())
669-
return false;
670-
671-
var torchIndex = installedPackage.PreferredTorchIndex ?? GetRecommendedTorchVersion();
672-
return torchIndex == TorchIndex.Rocm;
673-
}
674-
675664
protected ProcessArgs NormalizeLaunchArguments(
676665
InstalledPackage installedPackage,
677666
ProcessArgs fallbackArguments

StabilityMatrix.Core/Models/Packages/InvokeAI.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ public override TorchIndex GetRecommendedTorchVersion()
120120
return TorchIndex.Mps;
121121
}
122122

123-
if (Compat.IsWindows && InvokeAiWindowsRocmProfile.HasSupport(rocmPackageHelper))
123+
if (Compat.IsWindows && rocmPackageHelper.GetCompatibility().IsCompatible)
124124
{
125125
return TorchIndex.Rocm;
126126
}
@@ -235,7 +235,7 @@ public override async Task InstallPackage(
235235
var isSupportedWindowsRocmInstall =
236236
Compat.IsWindows
237237
&& torchVersion == TorchIndex.Rocm
238-
&& InvokeAiWindowsRocmProfile.HasSupport(rocmPackageHelper);
238+
&& rocmPackageHelper.GetCompatibility().IsCompatible;
239239
var isLegacyNvidiaGpu =
240240
SettingsManager.Settings.PreferredGpu?.IsLegacyNvidiaGpu() ?? HardwareHelper.HasLegacyNvidiaGpu();
241241
var fallbackIndex = torchVersion switch
@@ -662,17 +662,17 @@ InstalledPackage installedPackage
662662
env = GetEnvVars(env, installPath);
663663

664664
var selectedTorchIndex = installedPackage.PreferredTorchIndex ?? GetRecommendedTorchVersion();
665-
if (!InvokeAiWindowsRocmProfile.ShouldApplyLaunchEnvironment(rocmPackageHelper, selectedTorchIndex))
665+
if (!rocmPackageHelper.ShouldApplyWindowsLaunchEnvironment(selectedTorchIndex))
666666
{
667667
return env;
668668
}
669669

670-
return env.SetItems(InvokeAiWindowsRocmProfile.BuildLaunchEnvironment(rocmPackageHelper));
670+
return env.SetItems(rocmPackageHelper.BuildLaunchEnvironment(InvokeAiWindowsRocmProfile.Profile));
671671
}
672672

673673
private IReadOnlyList<string> GetLaunchNoticeLines(InstalledPackage installedPackage)
674674
{
675675
var selectedTorchIndex = installedPackage.PreferredTorchIndex ?? GetRecommendedTorchVersion();
676-
return InvokeAiWindowsRocmProfile.GetLaunchNoticeLines(rocmPackageHelper, selectedTorchIndex);
676+
return rocmPackageHelper.GetWindowsLaunchNoticeLines(selectedTorchIndex);
677677
}
678678
}

StabilityMatrix.Core/Models/Packages/Reforge.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ public override TorchIndex GetRecommendedTorchVersion()
8282
{
8383
var preferRocm =
8484
(Compat.IsLinux && (SettingsManager.Settings.PreferredGpu?.IsAmd ?? HardwareHelper.PreferRocm()))
85-
|| ReforgeWindowsRocmProfile.HasSupport(rocmPackageHelper);
85+
|| rocmPackageHelper.GetCompatibility().IsCompatible;
8686

8787
if (AvailableTorchIndices.Contains(TorchIndex.Rocm) && preferRocm)
8888
{
@@ -103,7 +103,7 @@ public override async Task InstallPackage(
103103
{
104104
var torchIndex = options.PythonOptions.TorchIndex ?? GetRecommendedTorchVersion();
105105

106-
if (!ReforgeWindowsRocmProfile.ShouldUseWindowsNativeInstall(rocmPackageHelper, torchIndex))
106+
if (!rocmPackageHelper.ShouldApplyWindowsLaunchEnvironment(torchIndex))
107107
{
108108
await base.InstallPackage(
109109
installLocation,
@@ -162,17 +162,17 @@ InstalledPackage installedPackage
162162
env = base.GetEnvVars(env, installedPackage);
163163
env = env.SetItem("STABLE_DIFFUSION_REPO", StableDiffusionRepoOverride);
164164

165-
if (!ReforgeWindowsRocmProfile.ShouldUseWindowsNativeInstall(rocmPackageHelper, selectedTorchIndex))
165+
if (!rocmPackageHelper.ShouldApplyWindowsLaunchEnvironment(selectedTorchIndex))
166166
{
167167
return env;
168168
}
169169

170-
return env.SetItems(ReforgeWindowsRocmProfile.BuildLaunchEnvironment(rocmPackageHelper));
170+
return env.SetItems(rocmPackageHelper.BuildLaunchEnvironment(ReforgeWindowsRocmProfile.Profile));
171171
}
172172

173173
protected override IReadOnlyList<string> GetLaunchNoticeLines(InstalledPackage installedPackage)
174174
{
175175
var selectedTorchIndex = installedPackage.PreferredTorchIndex ?? GetRecommendedTorchVersion();
176-
return ReforgeWindowsRocmProfile.GetLaunchNoticeLines(rocmPackageHelper, selectedTorchIndex);
176+
return rocmPackageHelper.GetWindowsLaunchNoticeLines(selectedTorchIndex);
177177
}
178178
}
Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
using StabilityMatrix.Core.Helper;
21
using StabilityMatrix.Core.Models.Packages;
32
using StabilityMatrix.Core.Python;
4-
using StabilityMatrix.Core.Services.Rocm;
53

64
namespace StabilityMatrix.Core.Models.Rocm;
75

@@ -37,44 +35,4 @@ public static RocmPackageProfile CreateInstallProfile(PyVersion pyVersion)
3735
InstallConfig = new PipInstallConfig { PostTorchInstallPipArgs = [TritonWindowsPackage] },
3836
};
3937
}
40-
41-
public static RocmCompatibilityResult GetCompatibility(IRocmPackageHelper rocmPackageHelper)
42-
{
43-
return rocmPackageHelper.GetCompatibility();
44-
}
45-
46-
public static bool HasSupport(IRocmPackageHelper rocmPackageHelper)
47-
{
48-
return GetCompatibility(rocmPackageHelper).IsCompatible;
49-
}
50-
51-
public static bool ShouldApplyLaunchEnvironment(
52-
IRocmPackageHelper rocmPackageHelper,
53-
TorchIndex selectedTorchIndex
54-
)
55-
{
56-
if (!Compat.IsWindows || selectedTorchIndex != TorchIndex.Rocm)
57-
{
58-
return false;
59-
}
60-
61-
return GetCompatibility(rocmPackageHelper).IsCompatible;
62-
}
63-
64-
public static IReadOnlyDictionary<string, string> BuildLaunchEnvironment(
65-
IRocmPackageHelper rocmPackageHelper
66-
)
67-
{
68-
return rocmPackageHelper.BuildLaunchEnvironment(Profile);
69-
}
70-
71-
public static IReadOnlyList<string> GetLaunchNoticeLines(
72-
IRocmPackageHelper rocmPackageHelper,
73-
TorchIndex selectedTorchIndex
74-
)
75-
{
76-
return ShouldApplyLaunchEnvironment(rocmPackageHelper, selectedTorchIndex)
77-
? rocmPackageHelper.GetWindowsLaunchNoticeLines()
78-
: [];
79-
}
8038
}

StabilityMatrix.Core/Models/Rocm/ReforgeWindowsRocmProfile.cs

Lines changed: 2 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -47,30 +47,12 @@ public static RocmPackageProfile CreateProfile(IEnumerable<string> requirementsF
4747
};
4848
}
4949

50-
public static RocmCompatibilityResult GetCompatibility(IRocmPackageHelper rocmPackageHelper)
51-
{
52-
return rocmPackageHelper.GetCompatibility();
53-
}
54-
55-
public static bool HasSupport(IRocmPackageHelper rocmPackageHelper)
56-
{
57-
return GetCompatibility(rocmPackageHelper).IsCompatible;
58-
}
59-
60-
public static bool ShouldUseWindowsNativeInstall(
61-
IRocmPackageHelper rocmPackageHelper,
62-
TorchIndex selectedTorchIndex
63-
)
64-
{
65-
return Compat.IsWindows && selectedTorchIndex == TorchIndex.Rocm && HasSupport(rocmPackageHelper);
66-
}
67-
6850
public static void ApplyWindowsRocmLaunchDefaults(
6951
List<LaunchOptionDefinition> launchOptions,
7052
IRocmPackageHelper rocmPackageHelper
7153
)
7254
{
73-
if (!(Compat.IsWindows && HasSupport(rocmPackageHelper)))
55+
if (!(Compat.IsWindows && rocmPackageHelper.GetCompatibility().IsCompatible))
7456
{
7557
return;
7658
}
@@ -89,7 +71,7 @@ IRocmPackageHelper rocmPackageHelper
8971

9072
public static string? GetPreferredCrossAttentionArgument(IRocmPackageHelper rocmPackageHelper)
9173
{
92-
var compatibility = GetCompatibility(rocmPackageHelper);
74+
var compatibility = rocmPackageHelper.GetCompatibility();
9375
if (!compatibility.IsCompatible)
9476
{
9577
return null;
@@ -99,21 +81,4 @@ IRocmPackageHelper rocmPackageHelper
9981
? "--attention-quad"
10082
: "--attention-pytorch";
10183
}
102-
103-
public static IReadOnlyDictionary<string, string> BuildLaunchEnvironment(
104-
IRocmPackageHelper rocmPackageHelper
105-
)
106-
{
107-
return rocmPackageHelper.BuildLaunchEnvironment(Profile);
108-
}
109-
110-
public static IReadOnlyList<string> GetLaunchNoticeLines(
111-
IRocmPackageHelper rocmPackageHelper,
112-
TorchIndex selectedTorchIndex
113-
)
114-
{
115-
return ShouldUseWindowsNativeInstall(rocmPackageHelper, selectedTorchIndex)
116-
? rocmPackageHelper.GetWindowsLaunchNoticeLines()
117-
: [];
118-
}
11984
}

StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,16 @@ public interface IRocmPackageHelper
2323
IReadOnlyDictionary<string, string> BuildLaunchEnvironment(RocmPackageProfile profile);
2424

2525
/// <summary>
26-
/// Returns shared Windows ROCm launch notice lines for helper-managed packages.
26+
/// Returns shared Windows ROCm launch notice lines if the current machine and selected torch index
27+
/// qualify for the Windows native ROCm launch environment; otherwise returns an empty list.
2728
/// </summary>
28-
IReadOnlyList<string> GetWindowsLaunchNoticeLines();
29+
IReadOnlyList<string> GetWindowsLaunchNoticeLines(TorchIndex selectedTorchIndex);
30+
31+
/// <summary>
32+
/// Returns true when the current machine is Windows, the selected torch index is ROCm,
33+
/// and the machine is compatible with Windows native ROCm.
34+
/// </summary>
35+
bool ShouldApplyWindowsLaunchEnvironment(TorchIndex selectedTorchIndex);
2936

3037
/// <summary>
3138
/// Ensures a usable Windows ROCm SDK devel package is installed from the ROCm multi-arch index,

StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,25 @@ public IReadOnlyDictionary<string, string> BuildLaunchEnvironment(RocmPackagePro
101101
return mergedEnvironment;
102102
}
103103

104-
/// <summary>
105-
/// Returns the shared informational notice lines shown when launching Windows ROCm packages.
106-
/// </summary>
107-
public IReadOnlyList<string> GetWindowsLaunchNoticeLines()
104+
private IReadOnlyList<string> GetWindowsLaunchNoticeLines()
108105
{
109106
return WindowsLaunchNoticeLines;
110107
}
111108

109+
/// <inheritdoc />
110+
public bool ShouldApplyWindowsLaunchEnvironment(TorchIndex selectedTorchIndex)
111+
{
112+
if (!Compat.IsWindows || selectedTorchIndex != TorchIndex.Rocm)
113+
return false;
114+
return GetCompatibility().IsCompatible;
115+
}
116+
117+
/// <inheritdoc />
118+
public IReadOnlyList<string> GetWindowsLaunchNoticeLines(TorchIndex selectedTorchIndex)
119+
{
120+
return ShouldApplyWindowsLaunchEnvironment(selectedTorchIndex) ? GetWindowsLaunchNoticeLines() : [];
121+
}
122+
112123
/// <summary>
113124
/// Ensures <c>rocm-sdk-devel</c> is installed from the ROCm multi-arch index.
114125
/// It prefers a build whose date token matches the installed ROCm torch build and falls back to the latest available build when no exact match is available.

0 commit comments

Comments
 (0)