|
| 1 | +using Microsoft.AI.Foundry.Local; |
| 2 | +using Betalgo.Ranul.OpenAI.ObjectModels.RequestModels; |
| 3 | +using System.Security.Cryptography; |
| 4 | +using System.Text; |
| 5 | +using System.Text.Json; |
| 6 | + |
| 7 | +// --------------------------------------------------------------------------- |
| 8 | +// Private Catalog sample — registers a customer MDS catalog with a self-signed |
| 9 | +// JWT, lists models (public + private), lets you pick one, and runs a streaming |
| 10 | +// chat completion. |
| 11 | +// |
| 12 | +// Usage: |
| 13 | +// PrivateCatalog (interactive — pick from list) |
| 14 | +// PrivateCatalog --model phi-4 (pick by alias) |
| 15 | +// PrivateCatalog --model Phi-4-generic-cpu:1 (pick by exact variant id) |
| 16 | +// PrivateCatalog --list (list models and exit) |
| 17 | +// PrivateCatalog --customer cust2 (override MdsCustomer) |
| 18 | +// PrivateCatalog --prompt "Hello!" (custom prompt) |
| 19 | +// --------------------------------------------------------------------------- |
// --- Parse command-line arguments ---
// Options that take a value consume the following argument. A missing value or
// an unrecognized argument aborts with a message: silently ignoring a typo such
// as "--modle" would otherwise drop the user into the interactive picker.
string? cliModel = null;
string cliPrompt = "Why is the sky blue?";
bool listOnly = false;
string? cliCustomer = null;

for (int i = 0; i < args.Length; i++)
{
    switch (args[i])
    {
        case "-m":
        case "--model":
            if (i + 1 < args.Length) { cliModel = args[++i]; }
            else { Console.WriteLine("Error: --model requires a value."); return; }
            break;
        case "-p":
        case "--prompt":
            if (i + 1 < args.Length) { cliPrompt = args[++i]; }
            else { Console.WriteLine("Error: --prompt requires a value."); return; }
            break;
        case "-c":
        case "--customer":
            if (i + 1 < args.Length) { cliCustomer = args[++i]; }
            else { Console.WriteLine("Error: --customer requires a value."); return; }
            break;
        case "-l":
        case "--list":
            listOnly = true;
            break;
        case "-h":
        case "--help":
            Console.WriteLine("Usage: PrivateCatalog [options]");
            Console.WriteLine(" -m, --model <name> Model alias or variant id");
            Console.WriteLine(" -c, --customer <name> Customer name (default: from appsettings)");
            Console.WriteLine(" -p, --prompt <text> Prompt (default: \"Why is the sky blue?\")");
            Console.WriteLine(" -l, --list List models and exit");
            Console.WriteLine(" -h, --help Show this help and exit");
            return;
        default:
            // Anything else is a typo or an unsupported positional argument.
            Console.WriteLine($"Error: unknown argument '{args[i]}'. Use --help for usage.");
            return;
    }
}
| 58 | + |
CancellationToken ct = default;

// --- Load config ---
// JsonDocument is IDisposable, so it is held in a using declaration; the three
// settings are copied out as plain strings immediately.
using var settingsDoc = JsonDocument.Parse(
    File.ReadAllText(Path.Combine(AppContext.BaseDirectory, "appsettings.json")));
var settings = settingsDoc.RootElement;
var mdsHost = settings.GetProperty("MdsHost").GetString()!;
// --customer wins over appsettings; MdsCustomer is only read when no override is given.
var mdsCustomer = cliCustomer ?? settings.GetProperty("MdsCustomer").GetString()!;
var mdsKeyDir = settings.GetProperty("MdsKeyDir").GetString()!;

// --- Derive customer resources (same convention as mds/scripts/download_model.py) ---
// These are machine identifiers, so lowercase with the invariant culture:
// culture-sensitive ToLower() produces a different string under cultures such
// as tr-TR ("I" -> dotless "ı"), which would yield a wrong registry/key name.
var customerLower = mdsCustomer.ToLowerInvariant();
var safeName = customerLower.Replace(" ", "").Replace("-", "");
var registryName = $"mds-{customerLower}-registry";
var issuer = $"https://mds{safeName}jwks.blob.core.windows.net/jwks";
var kid = $"mds-{customerLower}-key-1";
var keyPath = Path.Combine(mdsKeyDir, $"{customerLower}-key.pem");

if (!File.Exists(keyPath))
{
    Console.WriteLine($"Error: Private key not found at {keyPath}");
    Console.WriteLine("Run mds/scripts/create_jwks_storage.py --customer <name> first.");
    return;
}

// Self-sign the bearer token that is later presented to the private catalog.
var jwt = SignJwt(keyPath, kid, issuer, registryName);
Console.WriteLine($"Signed JWT for '{mdsCustomer}' (registry={registryName})");
| 84 | + |
// --- Init Foundry Local ---
// Build the manager configuration first, then create the singleton and fetch it.
var foundryConfig = new Configuration
{
    AppName = "private_catalog_sample",
    LogLevel = Microsoft.AI.Foundry.Local.LogLevel.Information,
};
await FoundryLocalManager.CreateAsync(foundryConfig, Utils.GetAppLogger());
var mgr = FoundryLocalManager.Instance;

Console.WriteLine("Registering execution providers...");
await mgr.DownloadAndRegisterEpsAsync();
Console.WriteLine("Done.");
| 93 | + |
// --- Register private catalog (falls back to public-only if it fails) ---
var catalog = await mgr.GetCatalogAsync();

Console.WriteLine($"\nRegistering private catalog at {mdsHost}...");
bool privateRegistered = false;
try
{
    // Credentials handed to the catalog: the self-signed JWT plus its audience.
    var privateCatalogOptions = new Dictionary<string, string>
    {
        ["BearerToken"] = jwt,
        ["Audience"] = "model-distribution-service",
    };
    var privateCatalogUri = new Uri(mdsHost);

    await catalog.AddCatalogAsync("private", privateCatalogUri, options: privateCatalogOptions);
    privateRegistered = true;
    Console.WriteLine("Private catalog registered.");
}
catch (Exception ex)
{
    // Best effort: any failure here only disables the private section.
    Console.WriteLine($"Warning: could not register private catalog ({ex.Message}).");
    Console.WriteLine("Continuing with the public catalog only.");
}
| 115 | + |
// --- List models (grouped by origin) ---
// A variant counts as private when its Uri mentions the customer's MDS registry
// name; everything else is treated as public. Classifying by Uri (rather than by
// a before/after snapshot) keeps working when registered catalogs persist
// across runs.
var allModels = await catalog.ListModelsAsync();
var allVariants = allModels.SelectMany(m => m.Variants).ToList();

bool IsPrivate(IModel v)
{
    return v.Info.Uri is { } uri
        && uri.Contains(registryName, StringComparison.OrdinalIgnoreCase);
}

// Single pass that splits the variants while preserving their relative order.
var publicVariants = new List<IModel>();
var privateVariants = new List<IModel>();
foreach (var variant in allVariants)
{
    (IsPrivate(variant) ? privateVariants : publicVariants).Add(variant);
}

// Rebuild in display order (public first, then private) so the number typed in
// the interactive picker maps 1:1 to what is printed below.
allVariants = new List<IModel>(publicVariants);
allVariants.AddRange(privateVariants);

Console.WriteLine($"\n=== Public Models ({publicVariants.Count}) ===");
for (int k = 0; k < publicVariants.Count; k++)
{
    Console.WriteLine($" [{k + 1}] {publicVariants[k].Alias} ({publicVariants[k].Id})");
}

if (privateRegistered)
{
    Console.WriteLine($"\n=== Private Models ({privateVariants.Count}) ===");
    if (privateVariants.Count == 0)
    {
        Console.WriteLine(" (none)");
    }

    // Numbering continues after the last public entry.
    for (int k = 0; k < privateVariants.Count; k++)
    {
        Console.WriteLine($" [{publicVariants.Count + k + 1}] {privateVariants[k].Alias} ({privateVariants[k].Id})");
    }
}
| 147 | + |
if (listOnly) return;

// --- Resolve a model (from --model or interactive prompt) ---
IModel? model = null;
string? input = cliModel;

if (string.IsNullOrWhiteSpace(input))
{
    Console.Write("\nEnter model number, alias, or variant id (q to quit): ");
    input = Console.ReadLine()?.Trim();
    if (string.IsNullOrEmpty(input) || input.Equals("q", StringComparison.OrdinalIgnoreCase)) return;

    if (int.TryParse(input, out int n))
    {
        // A number is a list position. Reject out-of-range values here instead of
        // letting them fall through to ResolveModel's substring search, where
        // e.g. "14" could silently select a model whose id merely contains "14".
        // NOTE(review): assumes no alias/variant id is purely numeric — confirm.
        if (n < 1 || n > allVariants.Count)
        {
            Console.WriteLine($"\nInvalid selection '{input}': not one of the listed model numbers.");
            return;
        }
        input = allVariants[n - 1].Id;
    }
}

model = await ResolveModel(catalog, allVariants, input!);
if (model == null)
{
    Console.WriteLine($"\nModel '{input}' not found.");
    return;
}
Console.WriteLine($"\nSelected: {model.Id}");
| 170 | +Console.WriteLine($"\nSelected: {model.Id}"); |
| 171 | + |
| 172 | +// --- Download / load / chat --- |
| 173 | +await model.DownloadAsync(p => |
| 174 | +{ |
| 175 | + Console.Write($"\rDownloading: {p:F1}%"); |
| 176 | + if (p >= 100f) Console.WriteLine(); |
| 177 | +}); |
| 178 | + |
| 179 | +Console.Write($"Loading {model.Id}..."); |
| 180 | +await model.LoadAsync(); |
| 181 | +Console.WriteLine(" done."); |
| 182 | + |
| 183 | +var chat = await model.GetChatClientAsync(); |
| 184 | +var messages = new List<ChatMessage> { new() { Role = "user", Content = cliPrompt } }; |
| 185 | + |
| 186 | +Console.WriteLine("Chat completion:"); |
| 187 | +await foreach (var chunk in chat.CompleteChatStreamingAsync(messages, ct)) |
| 188 | +{ |
| 189 | + Console.Write(chunk.Choices[0].Message.Content); |
| 190 | + Console.Out.Flush(); |
| 191 | +} |
| 192 | +Console.WriteLine(); |
| 193 | + |
| 194 | +await model.UnloadAsync(); |
| 195 | + |
| 196 | +// --------------------------------------------------------------------------- |
| 197 | +// Helpers |
| 198 | +// --------------------------------------------------------------------------- |
| 199 | + |
// Resolves user input to a model variant, trying in order:
//   1. an exact variant id lookup in the catalog,
//   2. an alias lookup (preferring a variant whose id contains "generic-cpu"),
//   3. a case-insensitive substring match on id or alias within allVariants.
// Returns null when nothing matches.
static async Task<IModel?> ResolveModel(
    ICatalog catalog, List<IModel> allVariants, string input)
{
    // 1. Exact variant id
    var exact = await catalog.GetModelVariantAsync(input);
    if (exact != null) return exact;

    // 2. Alias (prefer generic-cpu variant, otherwise the first variant)
    var byAlias = await catalog.GetModelAsync(input);
    if (byAlias != null)
    {
        var pick = byAlias.Variants.FirstOrDefault(v =>
                       v.Id.Contains("generic-cpu", StringComparison.OrdinalIgnoreCase))
                   ?? byAlias.Variants.FirstOrDefault();

        // A model with no variants used to throw on Variants[0]; fall through
        // to the substring search instead.
        if (pick != null)
        {
            return await catalog.GetModelVariantAsync(pick.Id);
        }
    }

    // 3. Substring match against the combined list (tolerates a missing alias)
    var match = allVariants.FirstOrDefault(v =>
        v.Id.Contains(input, StringComparison.OrdinalIgnoreCase) ||
        v.Alias?.Contains(input, StringComparison.OrdinalIgnoreCase) == true);
    return match != null ? await catalog.GetModelVariantAsync(match.Id) : null;
}
| 223 | + |
// Builds a compact RS256-signed JWT using the RSA private key stored in the PEM
// file at pemPath. The header carries kid; the payload carries issuer, a fixed
// subject and audience, issued-at / one-hour expiry timestamps, registryName,
// and wildcard model/version entitlements.
static string SignJwt(string pemPath, string kid, string issuer, string registryName)
{
    using var signingKey = RSA.Create();
    var pemText = File.ReadAllText(pemPath);
    signingKey.ImportFromPem(pemText);

    var issuedAt = DateTimeOffset.UtcNow;
    var expiresAt = issuedAt.AddHours(1);

    var headerJson = JsonSerializer.Serialize(new { alg = "RS256", typ = "JWT", kid });

    var entitlements = new Dictionary<string, object>
    {
        ["models"] = new[] { "*" },
        ["versions"] = new[] { "*" },
    };
    var claims = new Dictionary<string, object>
    {
        ["iss"] = issuer,
        ["sub"] = "foundry-local-sample",
        ["aud"] = "model-distribution-service",
        ["iat"] = issuedAt.ToUnixTimeSeconds(),
        ["exp"] = expiresAt.ToUnixTimeSeconds(),
        ["registry_name"] = registryName,
        ["entitlements"] = entitlements,
    };
    var payloadJson = JsonSerializer.Serialize(claims);

    // The signature covers "<base64url(header)>.<base64url(payload)>".
    var encodedHeader = B64Url(Encoding.UTF8.GetBytes(headerJson));
    var encodedPayload = B64Url(Encoding.UTF8.GetBytes(payloadJson));
    var signingInput = $"{encodedHeader}.{encodedPayload}";

    var signature = signingKey.SignData(
        Encoding.UTF8.GetBytes(signingInput),
        HashAlgorithmName.SHA256,
        RSASignaturePadding.Pkcs1);

    return $"{signingInput}.{B64Url(signature)}";
}
| 252 | + |
// Encodes bytes as base64url without padding: standard base64 with '+' -> '-',
// '/' -> '_', and trailing '=' removed.
static string B64Url(byte[] data)
{
    var standard = Convert.ToBase64String(data);
    var urlSafe = standard.Replace('+', '-').Replace('/', '_');
    return urlSafe.TrimEnd('=');
}
0 commit comments