Skip to content

Commit cbd2297

Browse files
committed
feat: Add EnsureDownloaded fluent method
Fixes #115 - add EnsureDownloaded method
1 parent 0c6dde0 commit cbd2297

File tree

8 files changed

+128
-182
lines changed

8 files changed

+128
-182
lines changed

src/MaIN.Core/Hub/Contexts/AgentContext.cs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
using MaIN.Domain.Entities.Agents;
55
using MaIN.Domain.Entities.Agents.AgentSource;
66
using MaIN.Domain.Models;
7+
using MaIN.Domain.Models.Abstract;
78
using MaIN.Services.Services.Abstract;
89
using MaIN.Services.Services.Models;
910
using MaIN.Core.Hub.Utils;
1011
using MaIN.Domain.Entities.Agents.Knowledge;
1112
using MaIN.Domain.Entities.Tools;
1213
using MaIN.Domain.Exceptions.Agents;
14+
using MaIN.Domain.Exceptions.Models;
1315
using MaIN.Services.Constants;
1416

1517
namespace MaIN.Core.Hub.Contexts;
@@ -20,6 +22,7 @@ public sealed class AgentContext : IAgentBuilderEntryPoint, IAgentConfigurationB
2022
private InferenceParams? _inferenceParams;
2123
private MemoryParams? _memoryParams;
2224
private bool _disableCache;
25+
private bool _ensureModelDownloaded;
2326
private readonly Agent _agent;
2427
internal Knowledge? _knowledge;
2528

@@ -112,6 +115,12 @@ public IAgentConfigurationBuilder DisableCache()
112115
return this;
113116
}
114117

118+
public IAgentConfigurationBuilder EnsureModelDownloaded()
119+
{
120+
_ensureModelDownloaded = true;
121+
return this;
122+
}
123+
115124
public IAgentConfigurationBuilder WithSource(IAgentSource source, AgentSourceType type)
116125
{
117126
_agent.Context.Source = new AgentSource()
@@ -200,11 +209,20 @@ public IAgentConfigurationBuilder WithBehaviour(string name, string instruction)
200209

201210
public async Task<IAgentContextExecutor> CreateAsync(bool flow = false, bool interactiveResponse = false)
202211
{
212+
if (_ensureModelDownloaded && !string.IsNullOrWhiteSpace(_agent.Model))
213+
{
214+
var model = ModelRegistry.GetById(_agent.Model);
215+
if (model is LocalModel)
216+
{
217+
await AIHub.Model().EnsureDownloadedAsync(_agent.Model);
218+
}
219+
}
220+
203221
await _agentService.CreateAgent(_agent, flow, interactiveResponse, _inferenceParams, _memoryParams, _disableCache);
204222
return this;
205223
}
206-
207-
public IAgentContextExecutor Create(bool flow = false, bool interactiveResponse = false)
224+
225+
public IAgentContextExecutor Create(bool flow = false, bool interactiveResponse = false) // I think it should be removed as there is a deadlock risk.
208226
{
209227
_ = _agentService.CreateAgent(_agent, flow, interactiveResponse, _inferenceParams, _memoryParams, _disableCache).Result;
210228
return this;

src/MaIN.Core/Hub/Contexts/ChatContext.cs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using MaIN.Core.Hub.Contexts.Interfaces.ChatContext;
2+
using MaIN.Core.Hub.Contexts.Interfaces.ModelContext;
23
using MaIN.Domain.Configuration;
34
using MaIN.Domain.Entities;
45
using MaIN.Domain.Entities.Tools;
@@ -10,6 +11,7 @@
1011
using MaIN.Services.Constants;
1112
using MaIN.Services.Services.Abstract;
1213
using MaIN.Services.Services.Models;
14+
using MaIN.Core.Hub;
1315
using FileInfo = MaIN.Domain.Entities.FileInfo;
1416

1517
namespace MaIN.Core.Hub.Contexts;
@@ -18,6 +20,7 @@ public sealed class ChatContext : IChatBuilderEntryPoint, IChatMessageBuilder, I
1820
{
1921
private readonly IChatService _chatService;
2022
private bool _preProcess;
23+
private bool _ensureModelDownloaded;
2124
private readonly Chat _chat;
2225
private List<FileInfo> _files = [];
2326

@@ -88,7 +91,13 @@ public IChatMessageBuilder EnableVisual()
8891
_chat.Visual = true;
8992
return this;
9093
}
91-
94+
95+
public IChatMessageBuilder EnsureModelDownloaded()
96+
{
97+
_ensureModelDownloaded = true;
98+
return this;
99+
}
100+
92101
public IChatConfigurationBuilder WithInferenceParams(InferenceParams inferenceParams)
93102
{
94103
_chat.InterferenceParams = inferenceParams;
@@ -208,7 +217,12 @@ public async Task<ChatResult> CompleteAsync(
208217
{
209218
throw new EmptyChatException(_chat.Id);
210219
}
211-
220+
221+
if (_ensureModelDownloaded && _chat.ModelInstance is LocalModel)
222+
{
223+
await AIHub.Model().EnsureDownloadedAsync(_chat.ModelId);
224+
}
225+
212226
_chat.Messages.Last().Files = _files;
213227
if(_preProcess)
214228
{
@@ -244,10 +258,7 @@ private async Task<bool> ChatExists(string id)
244258
return false;
245259
}
246260
}
247-
248-
IChatMessageBuilder IChatMessageBuilder.EnableVisual() => EnableVisual();
249-
250-
261+
251262
public string GetChatId() => _chat.Id;
252263

253264
public async Task<Chat> GetCurrentChat()

src/MaIN.Core/Hub/Contexts/Interfaces/AgentContext/IAgentConfigurationBuilder.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@ namespace MaIN.Core.Hub.Contexts.Interfaces.AgentContext;
99

1010
public interface IAgentConfigurationBuilder : IAgentActions
1111
{
12+
/// <summary>
13+
/// Flags the agent to automatically ensure the selected local model is downloaded before creation.
14+
/// If the model is already present the download is skipped; cloud models are silently ignored.
15+
/// The actual download is deferred until <see cref="CreateAsync"/> is called.
16+
/// </summary>
17+
/// <returns>The context instance implementing <see cref="IAgentConfigurationBuilder"/> for method chaining.</returns>
18+
IAgentConfigurationBuilder EnsureModelDownloaded();
19+
1220
/// <summary>
1321
/// Sets the initial prompt for the agent. This prompt serves as an instruction or context that guides the agent's behavior during its execution.
1422
/// </summary>

src/MaIN.Core/Hub/Contexts/Interfaces/ChatContext/IChatMessageBuilder.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@ public interface IChatMessageBuilder : IChatActions
1010
/// </summary>
1111
/// <returns>The context instance implementing <see cref="IChatMessageBuilder"/> for method chaining.</returns>
1212
IChatMessageBuilder EnableVisual();
13+
14+
/// <summary>
15+
/// Flags the chat to automatically ensure the selected local model is downloaded before completing.
16+
/// If the model is already present the download is skipped; cloud models are silently ignored.
17+
/// The actual download is deferred until <see cref="IChatConfigurationBuilder.CompleteAsync"/> is called.
18+
/// </summary>
19+
/// <returns>The context instance implementing <see cref="IChatMessageBuilder"/> for method chaining.</returns>
20+
IChatMessageBuilder EnsureModelDownloaded();
1321

1422
/// <summary>
1523
/// Adds a user message to the chat. This method captures the message content and assigns the "User" role to it.

src/MaIN.Core/Hub/Contexts/Interfaces/ModelContext/IModelContext.cs

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public interface IModelContext
3939
/// </summary>
4040
/// <param name="modelId">The id of the model to check for existence.</param>
4141
/// <returns>A boolean value indicating whether the model file exists locally.</returns>
42-
bool Exists(string modelId);
42+
bool IsDownloaded(string modelId);
4343

4444
/// <summary>
4545
/// Asynchronously downloads a known model from its configured download URL. This method handles the complete download process
@@ -52,33 +52,27 @@ public interface IModelContext
5252
Task<IModelContext> DownloadAsync(string modelId, CancellationToken cancellationToken = default);
5353

5454
/// <summary>
55-
/// Asynchronously downloads a custom model from a specified URL. This method allows downloading models that are not part
56-
/// of the known models collection, adding them to the system after download.
55+
/// Ensures a known local model is downloaded before use. If the model is already present on disk the call
56+
/// returns immediately; if not, the model is downloaded. Cloud models are silently skipped.
57+
/// Thread-safe: concurrent calls for the same model will not trigger duplicate downloads.
5758
/// </summary>
58-
/// <param name="model">The name to assign to the downloaded model.</param>
59-
/// <param name="url">The URL from which to download the model.</param>
60-
/// <returns>A task that represents the asynchronous download operation that completes when the download finishes,
61-
/// returning the context instance implementing <see cref="IModelContext"/> for method chaining.</returns>
62-
Task<IModelContext> DownloadAsync(string model, string url, CancellationToken cancellationToken = default);
63-
64-
/// <summary>
65-
/// Synchronously downloads a known model from its configured download URL. This is the blocking version of the download operation
66-
/// with progress tracking.
67-
/// </summary>
68-
/// <param name="modelName">The name of the model to download.</param>
69-
/// <returns>The context instance implementing <see cref="IModelContext"/> for method chaining.</returns>
70-
[Obsolete("Use DownloadAsync instead")]
71-
IModelContext Download(string modelName);
59+
/// <param name="modelId">The id of the model to ensure is downloaded.</param>
60+
/// <param name="cancellationToken">Optional cancellation token to abort the download operation.</param>
61+
/// <returns>A task that represents the asynchronous operation, returning the context instance implementing
62+
/// <see cref="IModelContext"/> for method chaining.</returns>
63+
Task<IModelContext> EnsureDownloadedAsync(string modelId, CancellationToken cancellationToken = default);
7264

7365
/// <summary>
74-
/// Synchronously downloads a custom model from a specified URL. This method provides blocking download functionality
75-
/// for custom models not in the known models collection.
66+
/// Ensures a known local model is downloaded before use using a strongly-typed model reference.
67+
/// If the model is already present on disk the call returns immediately; if not, the model is downloaded.
68+
/// Cloud models are silently skipped.
69+
/// Thread-safe: concurrent calls for the same model will not trigger duplicate downloads.
7670
/// </summary>
77-
/// <param name="model">The name to assign to the downloaded model.</param>
78-
/// <param name="url">The URL from which to download the model.</param>
79-
/// <returns>The context instance implementing <see cref="IModelContext"/> for method chaining.</returns>
80-
[Obsolete("Use DownloadAsync instead")]
81-
IModelContext Download(string model, string url);
71+
/// <typeparam name="TModel">A <see cref="LocalModel"/> type with a parameterless constructor.</typeparam>
72+
/// <param name="cancellationToken">Optional cancellation token to abort the download operation.</param>
73+
/// <returns>A task that represents the asynchronous operation, returning the context instance implementing
74+
/// <see cref="IModelContext"/> for method chaining.</returns>
75+
Task<IModelContext> EnsureDownloadedAsync<TModel>(CancellationToken cancellationToken = default) where TModel : LocalModel, new();
8276

8377
/// <summary>
8478
/// Loads a model into the memory cache for faster access during inference operations. This method preloads the model to avoid loading

0 commit comments

Comments
 (0)