forked from LykosAI/StabilityMatrix
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathPipInstallArgs.cs
More file actions
150 lines (121 loc) · 5.21 KB
/
PipInstallArgs.cs
File metadata and controls
150 lines (121 loc) · 5.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Diagnostics.Contracts;
using System.Text.RegularExpressions;
using StabilityMatrix.Core.Extensions;
using StabilityMatrix.Core.Processes;
namespace StabilityMatrix.Core.Python;
/// <summary>
/// Immutable builder for the argument list of a <c>pip install</c> invocation.
/// Every <c>With*</c> method returns a new instance; the receiver is not modified.
/// </summary>
[SuppressMessage("ReSharper", "StringLiteralTypo")]
public partial record PipInstallArgs : ProcessArgsBuilder
{
    /// <summary>Creates a builder seeded with the given arguments.</summary>
    /// <param name="arguments">Initial arguments, forwarded to the base builder.</param>
    public PipInstallArgs(params Argument[] arguments)
        : base(arguments) { }

    /// <summary>Adds the <c>torch</c> package.</summary>
    /// <param name="version">
    /// Text appended verbatim after the package name, so it must include the operator
    /// (for example <c>==2.1.0</c>). Empty installs without a version constraint.
    /// </param>
    public PipInstallArgs WithTorch(string version = "") =>
        this.AddArg(new Argument("torch", $"torch{version}"));

    /// <summary>Adds the <c>torch-directml</c> package.</summary>
    /// <param name="version">Text appended verbatim after the package name (operator included).</param>
    public PipInstallArgs WithTorchDirectML(string version = "") =>
        this.AddArg(new Argument("torch-directml", $"torch-directml{version}"));

    /// <summary>Adds the <c>torchvision</c> package.</summary>
    /// <param name="version">Text appended verbatim after the package name (operator included).</param>
    public PipInstallArgs WithTorchVision(string version = "") =>
        this.AddArg(new Argument("torchvision", $"torchvision{version}"));

    /// <summary>Adds the <c>torchaudio</c> package.</summary>
    /// <param name="version">Text appended verbatim after the package name (operator included).</param>
    public PipInstallArgs WithTorchAudio(string version = "") =>
        this.AddArg(new Argument("torchaudio", $"torchaudio{version}"));

    /// <summary>Adds the <c>xformers</c> package.</summary>
    /// <param name="version">Text appended verbatim after the package name (operator included).</param>
    public PipInstallArgs WithXFormers(string version = "") =>
        this.AddArg(new Argument("xformers", $"xformers{version}"));

    /// <summary>Adds <c>--extra-index-url</c> followed by <paramref name="indexUrl"/>.</summary>
    /// <param name="indexUrl">Full URL of the additional package index.</param>
    public PipInstallArgs WithExtraIndex(string indexUrl) =>
        this.AddKeyedArgs("--extra-index-url", ["--extra-index-url", indexUrl]);

    /// <summary>
    /// Adds an extra index pointing at <c>https://download.pytorch.org/whl/{index}</c>.
    /// </summary>
    /// <param name="index">Path segment appended to the PyTorch wheel index URL.</param>
    public PipInstallArgs WithTorchExtraIndex(string index) =>
        WithExtraIndex($"https://download.pytorch.org/whl/{index}");

    /// <summary>
    /// Parses the text of a requirements file and adds one argument per remaining entry.
    /// </summary>
    /// <remarks>
    /// Blank lines and lines starting with <c>#</c> are dropped, inline comments are removed,
    /// and spaces around a version operator are collapsed (see <c>NormalizePackageSpecifier</c>).
    /// Entries starting with <c>-</c> are added quoted; all others are added as-is.
    /// </remarks>
    /// <param name="requirements">Full text content of the requirements file.</param>
    /// <param name="excludePattern">
    /// Optional regular expression; entries matching <c>^{excludePattern}$</c> are skipped.
    /// The pattern is embedded without a surrounding group, so a top-level <c>|</c> binds
    /// looser than the anchors.
    /// </param>
    /// <returns>A new instance with the parsed entries appended.</returns>
    public PipInstallArgs WithParsedFromRequirementsTxt(
        string requirements,
        [StringSyntax(StringSyntaxAttribute.Regex)] string? excludePattern = null
    )
    {
        var requirementsEntries = requirements
            .SplitLines(StringSplitOptions.TrimEntries | StringSplitOptions.RemoveEmptyEntries)
            .Where(s => !s.StartsWith('#'))
            // Only whitespace followed by '#' starts a comment; a bare '#' can be part of a URL
            // fragment such as "#egg=name" or "#subdirectory=dir" and must be preserved.
            .Select(StripInlineComment)
            .Select(s => s.Trim())
            .Where(s => !string.IsNullOrWhiteSpace(s))
            .Select(NormalizePackageSpecifier);

        if (excludePattern is not null)
        {
            // NOTE(review): anchors are not grouped around the pattern; left unchanged because
            // existing callers may depend on how alternations currently match.
            var excludeRegex = new Regex($"^{excludePattern}$");
            requirementsEntries = requirementsEntries.Where(s => !excludeRegex.IsMatch(s));
        }

        return this.AddArgs(requirementsEntries.Select(ToRequirementArgument).ToArray());
    }

    /// <summary>
    /// Removes a trailing inline comment from a requirements line. A comment starts at the
    /// first <c>#</c> that is directly preceded by whitespace.
    /// </summary>
    /// <param name="line">A trimmed, non-empty line that does not itself start with <c>#</c>.</param>
    /// <returns>The text before the comment, or <paramref name="line"/> unchanged if there is none.</returns>
    private static string StripInlineComment(string line)
    {
        // Start at 1: index 0 has no preceding character, and full-line comments are filtered earlier.
        for (var i = 1; i < line.Length; i++)
        {
            if (line[i] == '#' && char.IsWhiteSpace(line[i - 1]))
                return line[..i];
        }

        return line;
    }

    /// <summary>
    /// Converts a parsed requirements entry to an argument; option-style entries
    /// (starting with <c>-</c>) are quoted, everything else is passed through unchanged.
    /// </summary>
    private static Argument ToRequirementArgument(string requirementEntry)
    {
        if (requirementEntry.StartsWith('-'))
            return Argument.Quoted(requirementEntry);

        return new Argument(requirementEntry);
    }

    /// <summary>
    /// Normalizes a package specifier by removing spaces around version constraint operators.
    /// </summary>
    /// <remarks>
    /// Only specifiers whose name consists of letters, digits, <c>-</c>, <c>_</c> and <c>.</c>
    /// directly followed by an operator are rewritten; anything else (for example names with
    /// extras in brackets) is returned unchanged. Whitespace inside the version part is kept.
    /// </remarks>
    /// <param name="specifier">The package specifier to normalize.</param>
    /// <returns>The normalized package specifier.</returns>
    private static string NormalizePackageSpecifier(string specifier)
    {
        // Skip normalization for special pip commands that start with a hyphen
        if (specifier.StartsWith('-'))
            return specifier;

        // Matches: package >= 1.0.0, package <= 1.0.0, package == 1.0.0, etc.
        var match = PackageSpecifierRegex().Match(specifier);
        if (!match.Success)
            return specifier;

        var packageName = match.Groups[1].Value;
        var versionOperator = match.Groups[2].Value;
        var version = match.Groups[3].Value;

        return $"{packageName}{versionOperator}{version}";
    }

    /// <summary>
    /// Applies user-defined overrides in order: <c>Update</c> replaces any existing argument for
    /// the same key with the override, <c>Remove</c> only deletes it. Other actions are ignored.
    /// </summary>
    /// <remarks>
    /// Overrides with a blank name are skipped. For <c>--extra-index-url</c> and
    /// <c>--index-url</c> the override's <c>Constraint</c> is set to <c>"="</c> before it is
    /// converted, which modifies the object held in the caller's list.
    /// </remarks>
    /// <param name="overrides">Overrides to apply.</param>
    /// <returns>A new instance with all overrides applied.</returns>
    public PipInstallArgs WithUserOverrides(List<PipPackageSpecifierOverride> overrides)
    {
        var newArgs = this;

        foreach (var pipOverride in overrides)
        {
            if (string.IsNullOrWhiteSpace(pipOverride.Name))
                continue;

            if (pipOverride.Name is "--extra-index-url" or "--index-url")
            {
                pipOverride.Constraint = "=";
            }

            var pipOverrideArg = pipOverride.ToArgument();

            if (pipOverride.Action is PipPackageSpecifierOverrideAction.Update)
            {
                // Drop the previous entry first so the override does not end up duplicated.
                newArgs = newArgs.RemovePipArgKey(pipOverrideArg.Key ?? pipOverrideArg.Value);
                newArgs = newArgs.AddArg(pipOverrideArg);
            }
            else if (pipOverride.Action is PipPackageSpecifierOverrideAction.Remove)
            {
                newArgs = newArgs.RemovePipArgKey(pipOverrideArg.Key ?? pipOverrideArg.Value);
            }
        }

        return newArgs;
    }

    /// <summary>
    /// Returns a copy without the arguments identified by <paramref name="argumentKey"/>.
    /// </summary>
    /// <remarks>
    /// Keyed arguments are removed when their key equals <paramref name="argumentKey"/>.
    /// Unkeyed arguments are removed when their value equals it or contains
    /// <c>{argumentKey}==</c>. The latter is a substring test, so a key such as <c>torch</c>
    /// also removes a value like <c>open-clip-torch==1.0</c>.
    /// </remarks>
    /// <param name="argumentKey">Argument key or package name to remove.</param>
    [Pure]
    public PipInstallArgs RemovePipArgKey(string argumentKey)
    {
        return this with
        {
            Arguments = Arguments
                .Where(
                    arg =>
                        arg.HasKey
                            ? (arg.Key != argumentKey)
                            : (arg.Value != argumentKey && !arg.Value.Contains($"{argumentKey}=="))
                )
                .ToImmutableList()
        };
    }

    // Declared explicitly because a record otherwise gets a compiler-generated ToString();
    // this keeps the base builder's implementation in effect.
    /// <inheritdoc />
    public override string ToString()
    {
        return base.ToString();
    }

    [GeneratedRegex(@"^([a-zA-Z0-9\-_.]+)\s*(>=|<=|==|>|<|!=|~=)\s*(.+)$")]
    private static partial Regex PackageSpecifierRegex();
}