Skip to content

Commit b198073

Browse files
authored
feat: Add EnsureDownloaded fluent method (#121)
* feat: Add EnsureDownloaded fluent method Fixes #115 - add EnsureDownloaded method * cleanup and use EnsureDownloaded in the example * error is throw when invalid model type is provided
1 parent b55e64d commit b198073

File tree

9 files changed

+153
-209
lines changed

9 files changed

+153
-209
lines changed

Examples/Examples/Chat/ChatExample.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ public async Task Start()
1212
// Using strongly-typed model
1313
await AIHub.Chat()
1414
.WithModel<Gemma2_2b>()
15+
.EnsureModelDownloaded()
1516
.WithMessage("Where do hedgehogs goes at night?")
1617
.CompleteAsync(interactive: true);
1718
}

src/MaIN.Core/Hub/Contexts/AgentContext.cs

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
using MaIN.Core.Hub.Contexts.Interfaces.AgentContext;
2+
using MaIN.Core.Hub.Utils;
23
using MaIN.Domain.Configuration;
34
using MaIN.Domain.Entities;
45
using MaIN.Domain.Entities.Agents;
56
using MaIN.Domain.Entities.Agents.AgentSource;
6-
using MaIN.Domain.Models;
7-
using MaIN.Services.Services.Abstract;
8-
using MaIN.Services.Services.Models;
9-
using MaIN.Core.Hub.Utils;
107
using MaIN.Domain.Entities.Agents.Knowledge;
118
using MaIN.Domain.Entities.Tools;
129
using MaIN.Domain.Exceptions.Agents;
10+
using MaIN.Domain.Models;
11+
using MaIN.Domain.Models.Abstract;
1312
using MaIN.Services.Constants;
13+
using MaIN.Services.Services.Abstract;
14+
using MaIN.Services.Services.Models;
1415

1516
namespace MaIN.Core.Hub.Contexts;
1617

@@ -20,6 +21,7 @@ public sealed class AgentContext : IAgentBuilderEntryPoint, IAgentConfigurationB
2021
private InferenceParams? _inferenceParams;
2122
private MemoryParams? _memoryParams;
2223
private bool _disableCache;
24+
private bool _ensureModelDownloaded;
2325
private readonly Agent _agent;
2426
internal Knowledge? _knowledge;
2527

@@ -60,8 +62,8 @@ internal AgentContext(IAgentService agentService, Agent existingAgent)
6062
public async Task<Agent?> GetAgentById(string id) => await _agentService.GetAgentById(id);
6163
public async Task Delete() => await _agentService.DeleteAgent(_agent.Id);
6264
public async Task<bool> Exists() => await _agentService.AgentExists(_agent.Id);
63-
64-
65+
66+
6567
public IAgentConfigurationBuilder WithModel(string model)
6668
{
6769
_agent.Model = model;
@@ -82,18 +84,18 @@ public async Task<IAgentContextExecutor> FromExisting(string agentId)
8284
{
8385
throw new AgentNotFoundException(agentId);
8486
}
85-
87+
8688
var context = new AgentContext(_agentService, existingAgent);
8789
context.LoadExistingKnowledgeIfExists();
8890
return context;
8991
}
90-
92+
9193
public IAgentConfigurationBuilder WithInitialPrompt(string prompt)
9294
{
9395
_agent.Context.Instruction = prompt;
9496
return this;
9597
}
96-
98+
9799
public IAgentConfigurationBuilder WithId(string id)
98100
{
99101
_agent.Id = id;
@@ -112,6 +114,12 @@ public IAgentConfigurationBuilder DisableCache()
112114
return this;
113115
}
114116

117+
public IAgentConfigurationBuilder EnsureModelDownloaded()
118+
{
119+
_ensureModelDownloaded = true;
120+
return this;
121+
}
122+
115123
public IAgentConfigurationBuilder WithSource(IAgentSource source, AgentSourceType type)
116124
{
117125
_agent.Context.Source = new AgentSource()
@@ -121,7 +129,7 @@ public IAgentConfigurationBuilder WithSource(IAgentSource source, AgentSourceTyp
121129
};
122130
return this;
123131
}
124-
132+
125133
public IAgentConfigurationBuilder WithName(string name)
126134
{
127135
_agent.Name = name;
@@ -143,7 +151,7 @@ public IAgentConfigurationBuilder WithMcpConfig(Mcp mcpConfig)
143151
_agent.Context.McpConfig = mcpConfig;
144152
return this;
145153
}
146-
154+
147155
public IAgentConfigurationBuilder WithInferenceParams(InferenceParams inferenceParams)
148156
{
149157
_inferenceParams = inferenceParams;
@@ -174,7 +182,7 @@ public IAgentConfigurationBuilder WithKnowledge(KnowledgeBuilder knowledge)
174182
_knowledge = knowledge.ForAgent(_agent).Build();
175183
return this;
176184
}
177-
185+
178186
public IAgentConfigurationBuilder WithKnowledge(Knowledge knowledge)
179187
{
180188
_knowledge = knowledge;
@@ -189,7 +197,7 @@ public IAgentConfigurationBuilder WithInMemoryKnowledge(Func<KnowledgeBuilder, K
189197
_knowledge = knowledgeConfig(builder).Build();
190198
return this;
191199
}
192-
200+
193201
public IAgentConfigurationBuilder WithBehaviour(string name, string instruction)
194202
{
195203
_agent.Behaviours ??= new Dictionary<string, string>();
@@ -200,10 +208,15 @@ public IAgentConfigurationBuilder WithBehaviour(string name, string instruction)
200208

201209
public async Task<IAgentContextExecutor> CreateAsync(bool flow = false, bool interactiveResponse = false)
202210
{
211+
if (_ensureModelDownloaded && !string.IsNullOrWhiteSpace(_agent.Model))
212+
{
213+
await AIHub.Model().EnsureDownloadedAsync(_agent.Model);
214+
}
215+
203216
await _agentService.CreateAgent(_agent, flow, interactiveResponse, _inferenceParams, _memoryParams, _disableCache);
204217
return this;
205218
}
206-
219+
207220
public IAgentContextExecutor Create(bool flow = false, bool interactiveResponse = false)
208221
{
209222
_ = _agentService.CreateAgent(_agent, flow, interactiveResponse, _inferenceParams, _memoryParams, _disableCache).Result;
@@ -215,7 +228,7 @@ public IAgentConfigurationBuilder WithTools(ToolsConfiguration toolsConfiguratio
215228
_agent.ToolsConfiguration = toolsConfiguration;
216229
return this;
217230
}
218-
231+
219232
internal void LoadExistingKnowledgeIfExists()
220233
{
221234
_knowledge ??= new Knowledge(_agent);
@@ -229,7 +242,7 @@ internal void LoadExistingKnowledgeIfExists()
229242
Console.WriteLine("Knowledge cannot be loaded - new one will be created");
230243
}
231244
}
232-
245+
233246
public async Task<ChatResult> ProcessAsync(Chat chat, bool translate = false)
234247
{
235248
if (_knowledge == null)
@@ -247,7 +260,7 @@ public async Task<ChatResult> ProcessAsync(Chat chat, bool translate = false)
247260
CreatedAt = DateTime.Now
248261
};
249262
}
250-
263+
251264
public async Task<ChatResult> ProcessAsync(
252265
string message,
253266
bool translate = false,
@@ -276,8 +289,8 @@ public async Task<ChatResult> ProcessAsync(
276289
CreatedAt = DateTime.Now
277290
};
278291
}
279-
280-
public async Task<ChatResult> ProcessAsync(Message message,
292+
293+
public async Task<ChatResult> ProcessAsync(Message message,
281294
bool translate = false,
282295
Func<LLMTokenValue, Task>? tokenCallback = null,
283296
Func<ToolInvocation, Task>? toolCallback = null)
@@ -288,7 +301,7 @@ public async Task<ChatResult> ProcessAsync(Message message,
288301
}
289302
var chat = await _agentService.GetChatByAgent(_agent.Id);
290303
chat.Messages.Add(message);
291-
var result = await _agentService.Process(chat, _agent.Id, _knowledge, translate, tokenCallback, toolCallback);;
304+
var result = await _agentService.Process(chat, _agent.Id, _knowledge, translate, tokenCallback, toolCallback);
292305
var messageResult = result.Messages.LastOrDefault()!;
293306
return new ChatResult()
294307
{
@@ -298,7 +311,7 @@ public async Task<ChatResult> ProcessAsync(Message message,
298311
CreatedAt = DateTime.Now
299312
};
300313
}
301-
314+
302315
public async Task<ChatResult> ProcessAsync(
303316
IEnumerable<Message> messages,
304317
bool translate = false,
@@ -317,7 +330,7 @@ public async Task<ChatResult> ProcessAsync(
317330
chat.Messages.Add(systemMsg);
318331
}
319332
chat.Messages.AddRange(messages);
320-
var result = await _agentService.Process(chat, _agent.Id, _knowledge, translate, tokenCallback, toolCallback);;
333+
var result = await _agentService.Process(chat, _agent.Id, _knowledge, translate, tokenCallback, toolCallback);
321334
var messageResult = result.Messages.LastOrDefault()!;
322335
return new ChatResult()
323336
{
@@ -335,7 +348,7 @@ public static async Task<AgentContext> FromExisting(IAgentService agentService,
335348
{
336349
throw new AgentNotFoundException(agentId);
337350
}
338-
351+
339352
var context = new AgentContext(agentService, existingAgent);
340353
context.LoadExistingKnowledgeIfExists();
341354
return context;
@@ -345,8 +358,8 @@ public static async Task<AgentContext> FromExisting(IAgentService agentService,
345358
public static class AgentExtensions
346359
{
347360
public static async Task<ChatResult> ProcessAsync(
348-
this Task<AgentContext> agentTask,
349-
string message,
361+
this Task<AgentContext> agentTask,
362+
string message,
350363
bool translate = false)
351364
{
352365
var agent = await agentTask;

src/MaIN.Core/Hub/Contexts/ChatContext.cs

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ public sealed class ChatContext : IChatBuilderEntryPoint, IChatMessageBuilder, I
1818
{
1919
private readonly IChatService _chatService;
2020
private bool _preProcess;
21+
private bool _ensureModelDownloaded;
2122
private readonly Chat _chat;
2223
private List<FileInfo> _files = [];
2324

@@ -88,7 +89,13 @@ public IChatMessageBuilder EnableVisual()
8889
_chat.Visual = true;
8990
return this;
9091
}
91-
92+
93+
public IChatMessageBuilder EnsureModelDownloaded()
94+
{
95+
_ensureModelDownloaded = true;
96+
return this;
97+
}
98+
9299
public IChatConfigurationBuilder WithInferenceParams(InferenceParams inferenceParams)
93100
{
94101
_chat.InterferenceParams = inferenceParams;
@@ -194,7 +201,7 @@ public IChatConfigurationBuilder DisableCache()
194201
_chat.Properties.AddProperty(ServiceConstants.Properties.DisableCacheProperty);
195202
return this;
196203
}
197-
204+
198205
public async Task<ChatResult> CompleteAsync(
199206
bool translate = false, // Move to WithTranslate
200207
bool interactive = false, // Move to WithInteractive
@@ -208,13 +215,18 @@ public async Task<ChatResult> CompleteAsync(
208215
{
209216
throw new EmptyChatException(_chat.Id);
210217
}
211-
218+
219+
if (_ensureModelDownloaded)
220+
{
221+
await AIHub.Model().EnsureDownloadedAsync(_chat.ModelId);
222+
}
223+
212224
_chat.Messages.Last().Files = _files;
213-
if(_preProcess)
225+
if (_preProcess)
214226
{
215227
_chat.Messages.Last().Properties.AddProperty(ServiceConstants.Properties.PreProcessProperty);
216228
}
217-
229+
218230
if (!await ChatExists(_chat.Id))
219231
{
220232
await _chatService.Create(_chat);
@@ -227,8 +239,8 @@ public async Task<ChatResult> CompleteAsync(
227239
public async Task<IChatConfigurationBuilder> FromExisting(string chatId)
228240
{
229241
var existing = await _chatService.GetById(chatId);
230-
return existing == null
231-
? throw new ChatNotFoundException(chatId)
242+
return existing == null
243+
? throw new ChatNotFoundException(chatId)
232244
: new ChatContext(_chatService, existing);
233245
}
234246

@@ -244,12 +256,9 @@ private async Task<bool> ChatExists(string id)
244256
return false;
245257
}
246258
}
247-
248-
IChatMessageBuilder IChatMessageBuilder.EnableVisual() => EnableVisual();
249259

250-
251260
public string GetChatId() => _chat.Id;
252-
261+
253262
public async Task<Chat> GetCurrentChat()
254263
{
255264
if (_chat.Id == null)
@@ -271,7 +280,7 @@ public async Task DeleteChat()
271280

272281
await _chatService.Delete(_chat.Id);
273282
}
274-
283+
275284
public List<MessageShort> GetChatHistory()
276285
{
277286
return [.. _chat.Messages.Select(x => new MessageShort()

src/MaIN.Core/Hub/Contexts/Interfaces/AgentContext/IAgentConfigurationBuilder.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@ namespace MaIN.Core.Hub.Contexts.Interfaces.AgentContext;
99

1010
public interface IAgentConfigurationBuilder : IAgentActions
1111
{
12+
/// <summary>
13+
/// Flags the agent to automatically ensure the selected local model is downloaded before creation.
14+
/// If the model is already present the download is skipped; cloud models are silently ignored.
15+
/// The actual download is deferred until <see cref="CreateAsync"/> is called.
16+
/// </summary>
17+
/// <returns>The context instance implementing <see cref="IAgentConfigurationBuilder"/> for method chaining.</returns>
18+
IAgentConfigurationBuilder EnsureModelDownloaded();
19+
1220
/// <summary>
1321
/// Sets the initial prompt for the agent. This prompt serves as an instruction or context that guides the agent's behavior during its execution.
1422
/// </summary>

src/MaIN.Core/Hub/Contexts/Interfaces/ChatContext/IChatMessageBuilder.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@ public interface IChatMessageBuilder : IChatActions
1010
/// </summary>
1111
/// <returns>The context instance implementing <see cref="IChatMessageBuilder"/> for method chaining.</returns>
1212
IChatMessageBuilder EnableVisual();
13+
14+
/// <summary>
15+
/// Flags the chat to automatically ensure the selected local model is downloaded before completing.
16+
/// If the model is already present the download is skipped; cloud models are silently ignored.
17+
/// The actual download is deferred until <see cref="IChatConfigurationBuilder.CompleteAsync"/> is called.
18+
/// </summary>
19+
/// <returns>The context instance implementing <see cref="IChatMessageBuilder"/> for method chaining.</returns>
20+
IChatMessageBuilder EnsureModelDownloaded();
1321

1422
/// <summary>
1523
/// Adds a user message to the chat. This method captures the message content and assigns the "User" role to it.

src/MaIN.Core/Hub/Contexts/Interfaces/ModelContext/IModelContext.cs

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -52,33 +52,27 @@ public interface IModelContext
5252
Task<IModelContext> DownloadAsync(string modelId, CancellationToken cancellationToken = default);
5353

5454
/// <summary>
55-
/// Asynchronously downloads a custom model from a specified URL. This method allows downloading models that are not part
56-
/// of the known models collection, adding them to the system after download.
55+
/// Ensures a known local model is downloaded before use. If the model is already present on disk the call
56+
/// returns immediately; if not, the model is downloaded. Cloud models are silently skipped.
57+
/// Thread-safe: concurrent calls for the same model will not trigger duplicate downloads.
5758
/// </summary>
58-
/// <param name="model">The name to assign to the downloaded model.</param>
59-
/// <param name="url">The URL from which to download the model.</param>
60-
/// <returns>A task that represents the asynchronous download operation that completes when the download finishes,
61-
/// returning the context instance implementing <see cref="IModelContext"/> for method chaining.</returns>
62-
Task<IModelContext> DownloadAsync(string model, string url, CancellationToken cancellationToken = default);
63-
64-
/// <summary>
65-
/// Synchronously downloads a known model from its configured download URL. This is the blocking version of the download operation
66-
/// with progress tracking.
67-
/// </summary>
68-
/// <param name="modelName">The name of the model to download.</param>
69-
/// <returns>The context instance implementing <see cref="IModelContext"/> for method chaining.</returns>
70-
[Obsolete("Use DownloadAsync instead")]
71-
IModelContext Download(string modelName);
59+
/// <param name="modelId">The id of the model to ensure is downloaded.</param>
60+
/// <param name="cancellationToken">Optional cancellation token to abort the download operation.</param>
61+
/// <returns>A task that represents the asynchronous operation, returning the context instance implementing
62+
/// <see cref="IModelContext"/> for method chaining.</returns>
63+
Task<IModelContext> EnsureDownloadedAsync(string modelId, CancellationToken cancellationToken = default);
7264

7365
/// <summary>
74-
/// Synchronously downloads a custom model from a specified URL. This method provides blocking download functionality
75-
/// for custom models not in the known models collection.
66+
/// Ensures a known local model is downloaded before use using a strongly-typed model reference.
67+
/// If the model is already present on disk the call returns immediately; if not, the model is downloaded.
68+
/// Cloud models are silently skipped.
69+
/// Thread-safe: concurrent calls for the same model will not trigger duplicate downloads.
7670
/// </summary>
77-
/// <param name="model">The name to assign to the downloaded model.</param>
78-
/// <param name="url">The URL from which to download the model.</param>
79-
/// <returns>The context instance implementing <see cref="IModelContext"/> for method chaining.</returns>
80-
[Obsolete("Use DownloadAsync instead")]
81-
IModelContext Download(string model, string url);
71+
/// <typeparam name="TModel">A <see cref="LocalModel"/> type with a parameterless constructor.</typeparam>
72+
/// <param name="cancellationToken">Optional cancellation token to abort the download operation.</param>
73+
/// <returns>A task that represents the asynchronous operation, returning the context instance implementing
74+
/// <see cref="IModelContext"/> for method chaining.</returns>
75+
Task<IModelContext> EnsureDownloadedAsync<TModel>(CancellationToken cancellationToken = default) where TModel : LocalModel, new();
8276

8377
/// <summary>
8478
/// Loads a model into the memory cache for faster access during inference operations. This method preloads the model to avoid loading

0 commit comments

Comments
 (0)