Skip to content

Commit 7e6d116

Browse files
Merge pull request #61 from wisedev-code/feat/model-context-to-handle-model-related-ops
Feat/model context to handle model related ops
2 parents 6ed0e1b + 0a80473 commit 7e6d116

9 files changed

Lines changed: 341 additions & 9 deletions

File tree

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
using MaIN.Core;
22
using MaIN.Core.Hub;
33
using MaIN.Domain.Entities;
4+
using OpenAI.Models;
45

56
MaINBootstrapper.Initialize();
67

7-
await AIHub.Chat()
8-
.WithModel("gemma2:2b")
9-
.WithMessage("Hello, World!")
10-
.CompleteAsync(interactive: true);
8+
var model = AIHub.Model();
119

10+
var m = model.GetModel("gemma3:4b");
11+
12+
var x = model.GetModel("llama3.2:3b");
13+
await model.DownloadAsync(x.Name);
1214

1315

Releases/0.2.6.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# 0.2.6 release
2+
3+
- Implement model context to manage LLMs from code,
4+
- Filled missing gaps in docs

src/MaIN.Core/.nuspec

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
<package>
33
<metadata>
44
<id>MaIN.NET</id>
5-
<version>0.2.5</version>
5+
<version>0.2.6</version>
66
<authors>Wisedev</authors>
77
<owners>Wisedev</owners>
88
<icon>favicon.png</icon>

src/MaIN.Core/Bootstrapper.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,11 @@ public static IServiceCollection AddAIHub(this IServiceCollection services)
4545
sp.GetRequiredService<IAgentFlowService>(),
4646
sp.GetRequiredService<IMcpService>()
4747
);
48-
48+
49+
var settings = sp.GetRequiredService<MaINSettings>();
50+
var httpClientFactory = sp.GetRequiredService<IHttpClientFactory>();
4951
// Initialize AIHub with the services
50-
AIHub.Initialize(aiServices);
52+
AIHub.Initialize(aiServices, settings, httpClientFactory );
5153
return aiServices;
5254
}
5355
);

src/MaIN.Core/Hub/AiHub.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,32 @@
11
using LLama.Native;
22
using MaIN.Core.Hub.Contexts;
33
using MaIN.Core.Interfaces;
4+
using MaIN.Domain.Configuration;
45
using MaIN.Services.Services.Abstract;
56

67
namespace MaIN.Core.Hub;
78

89
public static class AIHub
910
{
1011
private static IAIHubServices? _services;
12+
private static MaINSettings _settings = null!;
13+
private static IHttpClientFactory _httpClientFactory = null!;
1114

12-
internal static void Initialize(IAIHubServices services)
15+
internal static void Initialize(IAIHubServices services,
16+
MaINSettings settings,
17+
IHttpClientFactory httpClientFactory)
1318
{
1419
_services = services;
20+
_settings = settings;
21+
_httpClientFactory = httpClientFactory;
1522
}
1623

1724
private static IAIHubServices Services =>
1825
_services ??
1926
throw new InvalidOperationException(
2027
"AIHub has not been initialized. Make sure to call AddAIHub() in your service configuration.");
2128

29+
public static ModelContext Model() => new ModelContext(_settings, _httpClientFactory);
2230
public static ChatContext Chat() => new(Services.ChatService);
2331
public static AgentContext Agent() => new(Services.AgentService);
2432
public static FlowContext Flow() => new(Services.FlowService, Services.AgentService);
Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
using System.Diagnostics;
2+
using System.Net;
3+
using MaIN.Domain.Configuration;
4+
using MaIN.Domain.Models;
5+
using MaIN.Services.Constants;
6+
using MaIN.Services.Services.LLMService.Utils;
7+
8+
namespace MaIN.Core.Hub.Contexts;
9+
10+
public class ModelContext
11+
{
12+
private readonly MaINSettings _settings;
13+
private readonly IHttpClientFactory _httpClientFactory;
14+
15+
private const int DefaultBufferSize = 8192;
16+
private const int FileStreamBufferSize = 65536;
17+
private const int ProgressUpdateIntervalMilliseconds = 1000;
18+
private const string MissingModelName = "Model name cannot be null or empty";
19+
private static readonly TimeSpan DefaultHttpTimeout = TimeSpan.FromMinutes(30);
20+
21+
internal ModelContext(MaINSettings settings, IHttpClientFactory httpClientFactory)
22+
{
23+
_settings = settings ?? throw new ArgumentNullException(nameof(settings));
24+
_httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory));
25+
}
26+
27+
public List<Model> GetAll() => KnownModels.All();
28+
29+
public Model GetModel(string model) => KnownModels.GetModel(model);
30+
31+
public Model GetEmbeddingModel() => KnownModels.GetEmbeddingModel();
32+
33+
public bool Exists(string modelName)
34+
{
35+
if (string.IsNullOrWhiteSpace(modelName))
36+
{
37+
throw new ArgumentException(nameof(modelName));
38+
}
39+
40+
var model = KnownModels.GetModel(modelName);
41+
var modelPath = GetModelFilePath(model.FileName);
42+
return File.Exists(modelPath);
43+
}
44+
45+
public async Task<ModelContext> DownloadAsync(string modelName, CancellationToken cancellationToken = default)
46+
{
47+
if (string.IsNullOrWhiteSpace(modelName))
48+
{
49+
throw new ArgumentException(MissingModelName, nameof(modelName));
50+
}
51+
52+
var model = KnownModels.GetModel(modelName);
53+
await DownloadModelAsync(model.DownloadUrl!, model.FileName, cancellationToken);
54+
return this;
55+
}
56+
57+
public async Task<ModelContext> DownloadAsync(string model, string url)
58+
{
59+
if (string.IsNullOrWhiteSpace(model))
60+
{
61+
throw new ArgumentException(MissingModelName, nameof(model));
62+
}
63+
64+
if (string.IsNullOrWhiteSpace(url))
65+
{
66+
throw new ArgumentException("URL cannot be null or empty", nameof(url));
67+
}
68+
69+
var fileName = $"{model}.gguf";
70+
await DownloadModelAsync(url, fileName, CancellationToken.None);
71+
72+
var filePath = GetModelFilePath(fileName);
73+
KnownModels.AddModel(model, filePath);
74+
return this;
75+
}
76+
77+
public ModelContext Download(string modelName)
78+
{
79+
if (string.IsNullOrWhiteSpace(modelName))
80+
{
81+
throw new ArgumentException(MissingModelName, nameof(modelName));
82+
}
83+
84+
var model = KnownModels.GetModel(modelName);
85+
DownloadModelSync(model.DownloadUrl!, model.FileName);
86+
return this;
87+
}
88+
89+
public ModelContext Download(string model, string url)
90+
{
91+
if (string.IsNullOrWhiteSpace(model))
92+
{
93+
throw new ArgumentException(MissingModelName, nameof(model));
94+
}
95+
96+
if (string.IsNullOrWhiteSpace(url))
97+
{
98+
throw new ArgumentException("URL cannot be null or empty", nameof(url));
99+
}
100+
101+
var fileName = $"{model}.gguf";
102+
DownloadModelSync(url, fileName);
103+
104+
var filePath = GetModelFilePath(fileName);
105+
KnownModels.AddModel(model, filePath);
106+
return this;
107+
}
108+
109+
public ModelContext LoadToCache(Model model)
110+
{
111+
ArgumentNullException.ThrowIfNull(model);
112+
113+
var modelsPath = ResolvePath(_settings.ModelsPath);
114+
ModelLoader.GetOrLoadModel(modelsPath, model.FileName);
115+
return this;
116+
}
117+
118+
public async Task<ModelContext> LoadToCacheAsync(Model model)
119+
{
120+
ArgumentNullException.ThrowIfNull(model);
121+
122+
var modelsPath = ResolvePath(_settings.ModelsPath);
123+
await ModelLoader.GetOrLoadModelAsync(modelsPath, model.FileName);
124+
return this;
125+
}
126+
127+
private async Task DownloadModelAsync(string url, string fileName, CancellationToken cancellationToken)
128+
{
129+
using var httpClient = CreateConfiguredHttpClient();
130+
var filePath = GetModelFilePath(fileName);
131+
132+
Console.WriteLine($"Starting download of {fileName}...");
133+
134+
try
135+
{
136+
using var response = await httpClient.GetAsync(url, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
137+
response.EnsureSuccessStatusCode();
138+
139+
await DownloadWithProgressAsync(response, filePath, fileName, cancellationToken);
140+
}
141+
catch (Exception ex)
142+
{
143+
Console.WriteLine($"Download failed: {ex.Message}");
144+
145+
if (File.Exists(filePath))
146+
{
147+
File.Delete(filePath);
148+
}
149+
throw;
150+
}
151+
}
152+
153+
private async Task DownloadWithProgressAsync(HttpResponseMessage response, string filePath, string fileName, CancellationToken cancellationToken)
154+
{
155+
var totalBytes = response.Content.Headers.ContentLength;
156+
var totalBytesRead = 0L;
157+
var buffer = new byte[DefaultBufferSize];
158+
var progressStopwatch = Stopwatch.StartNew();
159+
var totalStopwatch = Stopwatch.StartNew();
160+
161+
if (totalBytes.HasValue)
162+
{
163+
Console.WriteLine($"File size: {FormatBytes(totalBytes.Value)}");
164+
}
165+
166+
await using var fileStream = new FileStream(filePath, FileMode.Create, FileAccess.Write, FileShare.None, FileStreamBufferSize);
167+
await using var contentStream = await response.Content.ReadAsStreamAsync(cancellationToken);
168+
169+
while (true)
170+
{
171+
var bytesRead = await contentStream.ReadAsync(buffer, 0, buffer.Length, cancellationToken);
172+
if (bytesRead == 0) break;
173+
174+
await fileStream.WriteAsync(buffer, 0, bytesRead, cancellationToken);
175+
totalBytesRead += bytesRead;
176+
177+
if (ShouldUpdateProgress(progressStopwatch))
178+
{
179+
ShowProgress(totalBytesRead, totalBytes, totalStopwatch);
180+
progressStopwatch.Restart();
181+
}
182+
}
183+
184+
ShowFinalProgress(totalBytesRead, totalStopwatch, fileName);
185+
}
186+
187+
private void DownloadModelSync(string url, string fileName)
188+
{
189+
var filePath = GetModelFilePath(fileName);
190+
191+
Console.WriteLine($"Starting download of {fileName}...");
192+
193+
using var webClient = CreateConfiguredWebClient();
194+
var totalStopwatch = Stopwatch.StartNew();
195+
var progressStopwatch = Stopwatch.StartNew();
196+
197+
webClient.DownloadProgressChanged += (sender, e) =>
198+
{
199+
if (ShouldUpdateProgress(progressStopwatch))
200+
{
201+
ShowProgress(e.BytesReceived, e.TotalBytesToReceive > 0 ? e.TotalBytesToReceive : null, totalStopwatch);
202+
progressStopwatch.Restart();
203+
}
204+
};
205+
206+
webClient.DownloadFileCompleted += (sender, e) =>
207+
{
208+
totalStopwatch.Stop();
209+
if (e.Error != null)
210+
{
211+
Console.WriteLine($"\nDownload failed: {e.Error.Message}");
212+
}
213+
else
214+
{
215+
var totalTime = totalStopwatch.Elapsed;
216+
Console.WriteLine($"\nDownload completed: {fileName}. Time: {totalTime:hh\\:mm\\:ss}");
217+
}
218+
};
219+
220+
try
221+
{
222+
webClient.DownloadFile(url, filePath);
223+
}
224+
catch (Exception ex)
225+
{
226+
Console.WriteLine($"Download failed: {ex.Message}");
227+
228+
if (File.Exists(filePath))
229+
{
230+
File.Delete(filePath);
231+
}
232+
233+
throw;
234+
}
235+
}
236+
237+
private HttpClient CreateConfiguredHttpClient()
238+
{
239+
var httpClient = _httpClientFactory.CreateClient(ServiceConstants.HttpClients.ModelContextDownloadClient);
240+
httpClient.Timeout = DefaultHttpTimeout;
241+
return httpClient;
242+
}
243+
244+
private static WebClient CreateConfiguredWebClient()
245+
{
246+
var webClient = new WebClient();
247+
webClient.Headers.Add("User-Agent", "YourApp/1.0");
248+
return webClient;
249+
}
250+
251+
private string GetModelFilePath(string fileName) => Path.Combine(ResolvePath(_settings.ModelsPath), fileName);
252+
253+
private static bool ShouldUpdateProgress(Stopwatch progressStopwatch) =>
254+
progressStopwatch.ElapsedMilliseconds >= ProgressUpdateIntervalMilliseconds;
255+
256+
private static void ShowProgress(long totalBytesRead, long? totalBytes, Stopwatch totalStopwatch)
257+
{
258+
var elapsedSeconds = totalStopwatch.Elapsed.TotalSeconds;
259+
var speed = elapsedSeconds > 0 ? totalBytesRead / elapsedSeconds : 0;
260+
261+
if (totalBytes.HasValue)
262+
{
263+
var progressPercentage = (double)totalBytesRead / totalBytes.Value * 100;
264+
var eta = speed > 0 ? TimeSpan.FromSeconds((totalBytes.Value - totalBytesRead) / speed) : TimeSpan.Zero;
265+
266+
Console.Write($"\rProgress: {progressPercentage:F1}% ({FormatBytes(totalBytesRead)}/{FormatBytes(totalBytes.Value)}) " +
267+
$"Speed: {FormatBytes((long)speed)}/s ETA: {eta:hh\\:mm\\:ss}");
268+
}
269+
else
270+
{
271+
Console.Write($"\rDownloaded: {FormatBytes(totalBytesRead)} Speed: {FormatBytes((long)speed)}/s");
272+
}
273+
}
274+
275+
private static void ShowFinalProgress(long totalBytesRead, Stopwatch totalStopwatch, string fileName)
276+
{
277+
totalStopwatch.Stop();
278+
var totalTime = totalStopwatch.Elapsed;
279+
var avgSpeed = totalTime.TotalSeconds > 0 ? totalBytesRead / totalTime.TotalSeconds : 0;
280+
281+
Console.WriteLine($"\nDownload completed: {fileName}. " +
282+
$"Total size: {FormatBytes(totalBytesRead)}, " +
283+
$"Time: {totalTime:hh\\:mm\\:ss}, " +
284+
$"Average speed: {FormatBytes((long)avgSpeed)}/s");
285+
}
286+
287+
private static string FormatBytes(long bytes)
288+
{
289+
if (bytes == 0) return "0 Bytes";
290+
291+
const int scale = 1024;
292+
string[] orders = ["GB", "MB", "KB", "Bytes"];
293+
var max = (long)Math.Pow(scale, orders.Length - 1);
294+
295+
foreach (var order in orders)
296+
{
297+
if (bytes >= max)
298+
return $"{decimal.Divide(bytes, max):##.##} {order}";
299+
max /= scale;
300+
}
301+
302+
return "0 Bytes";
303+
}
304+
305+
private string ResolvePath(string? settingsModelsPath) =>
306+
settingsModelsPath
307+
?? Environment.GetEnvironmentVariable("MaIN_ModelsPath")
308+
?? throw new InvalidOperationException("Models path not found in settings or environment variables");
309+
}

0 commit comments

Comments
 (0)