Skip to content

Commit e56fe94

Browse files
author
Piotr Stachaczynski
committed
feat: add custom km support
1 parent 537b68b commit e56fe94

7 files changed

Lines changed: 40 additions & 39 deletions

File tree

src/MaIN.Services/Services/LLMService/DeepSeekService.cs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,7 @@ protected override void ValidateApiKey()
6161

6262
chat.Messages.Last().Content = message.Content;
6363
chat.Messages.Last().Files = [];
64-
var result = await Send(chat, new ChatRequestOptions()
65-
{
66-
InteractiveUpdates = true
67-
}, cancellationToken);
64+
var result = await Send(chat, new ChatRequestOptions(), cancellationToken);
6865
chat.Messages.Last().Content = lastMsg.Content;
6966
return result;
7067
}

src/MaIN.Services/Services/LLMService/Factory/IImageGenServiceFactory.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ namespace MaIN.Services.Services.LLMService.Factory;
55

66
public interface IImageGenServiceFactory
77
{
8-
IImageGenService CreateService(BackendType backendType);
8+
IImageGenService? CreateService(BackendType backendType);
99
}

src/MaIN.Services/Services/LLMService/Factory/ImageGenServiceFactory.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,22 @@ namespace MaIN.Services.Services.LLMService.Factory;
77

88
public class ImageGenServiceFactory(IServiceProvider serviceProvider) : IImageGenServiceFactory
99
{
10-
public IImageGenService CreateService(BackendType backendType)
10+
public IImageGenService? CreateService(BackendType backendType)
1111
{
1212
return backendType switch
1313
{
1414
BackendType.OpenAi => new OpenAiImageGenService(serviceProvider.GetRequiredService<IHttpClientFactory>(),
1515
serviceProvider.GetRequiredService<MaINSettings>()),
1616
BackendType.Gemini => new GeminiImageGenService(serviceProvider.GetRequiredService<IHttpClientFactory>(),
1717
serviceProvider.GetRequiredService<MaINSettings>()),
18-
BackendType.DeepSeek => throw new NotSupportedException("DeepSeek does not support image generation."),
19-
BackendType.GroqCloud => throw new NotSupportedException("Groq Cloud does not support image generation."),
20-
BackendType.Anthropic => throw new NotSupportedException("Anthropic does not support image generation."),
18+
BackendType.DeepSeek => null,
19+
BackendType.GroqCloud => null,
20+
BackendType.Anthropic => null,
2121
BackendType.Self => new ImageGenService(serviceProvider.GetRequiredService<IHttpClientFactory>(),
2222
serviceProvider.GetRequiredService<MaINSettings>()),
2323

2424
// Add other backends as needed
25-
_ => throw new ArgumentOutOfRangeException(nameof(backendType))
25+
_ => throw new NotSupportedException("This backend does not support image generation."),
2626
};
2727
}
2828
}

src/MaIN.Services/Services/LLMService/GeminiService.cs

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@ public sealed class GeminiService(
2222
: OpenAiCompatibleService(notificationService, httpClientFactory, memoryFactory, memoryService, logger)
2323
{
2424
private readonly MaINSettings _settings = settings ?? throw new ArgumentNullException(nameof(settings));
25-
private readonly IHttpClientFactory _httpClientFactory = httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory));
25+
26+
private readonly IHttpClientFactory _httpClientFactory =
27+
httpClientFactory ?? throw new ArgumentNullException(nameof(httpClientFactory));
28+
2629
private readonly IMemoryService _memoryService = memoryService;
2730
private readonly IMemoryFactory _memoryFactory = memoryFactory;
2831

@@ -43,10 +46,10 @@ public override async Task<string[]> GetCurrentModels()
4346
var modelsResponse = JsonSerializer.Deserialize<GeminiModelsResponse>(responseJson);
4447

4548
return modelsResponse?.Models?
46-
.Where(m => m.Name!.StartsWith("models/gemini", StringComparison.InvariantCultureIgnoreCase))
47-
.Select(m => m.Name![7..]) // remove "models/" part => get baseModelId
48-
.ToArray()
49-
?? [];
49+
.Where(m => m.Name!.StartsWith("models/gemini", StringComparison.InvariantCultureIgnoreCase))
50+
.Select(m => m.Name![7..]) // remove "models/" part => get baseModelId
51+
.ToArray()
52+
?? [];
5053
}
5154

5255
protected override string GetApiKey()
@@ -57,7 +60,8 @@ protected override string GetApiKey()
5760

5861
protected override void ValidateApiKey()
5962
{
60-
if (string.IsNullOrEmpty(_settings.GeminiKey) && string.IsNullOrEmpty(Environment.GetEnvironmentVariable("GEMINI_API_KEY")))
63+
if (string.IsNullOrEmpty(_settings.GeminiKey) &&
64+
string.IsNullOrEmpty(Environment.GetEnvironmentVariable("GEMINI_API_KEY")))
6165
{
6266
throw new InvalidOperationException("Gemini Key not configured");
6367
}
@@ -80,9 +84,10 @@ protected override void ValidateApiKey()
8084
{
8185
var jsonGrammarConverter = new GBNFToJsonConverter();
8286
var jsonGrammar = jsonGrammarConverter.ConvertToJson(chat.MemoryParams.Grammar);
83-
userQuery = $"{userQuery} | Respond only using the following JSON format: \n{jsonGrammar}\n. Do not add explanations, code tags, or any extra content.";
87+
userQuery =
88+
$"{userQuery} | For your next response only, please respond using exactly the following JSON format: \n{jsonGrammar}\n. Do not include any explanations, code blocks, or additional content. After this single JSON response, resume your normal conversational style.";
8489
}
85-
90+
8691
var retrievedContext = await kernel.AskAsync(userQuery, cancellationToken: cancellationToken);
8792
chat.Messages.Last().MarkProcessed();
8893
await kernel.DeleteIndexAsync(cancellationToken: cancellationToken);
@@ -92,12 +97,10 @@ protected override void ValidateApiKey()
9297

9398
file class GeminiModelsResponse
9499
{
95-
[JsonPropertyName("models")]
96-
public List<GeminiModel>? Models { get; set; }
100+
[JsonPropertyName("models")] public List<GeminiModel>? Models { get; set; }
97101
}
98102

99103
file class GeminiModel
100104
{
101-
[JsonPropertyName("name")]
102-
public string? Name { get; set; }
105+
[JsonPropertyName("name")] public string? Name { get; set; }
103106
}

src/MaIN.Services/Services/LLMService/GroqCloudService.cs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,7 @@ protected override void ValidateApiKey()
5555

5656
chat.Messages.Last().Content = message.Content;
5757
chat.Messages.Last().Files = [];
58-
var result = await Send(chat, new ChatRequestOptions()
59-
{
60-
InteractiveUpdates = true
61-
}, cancellationToken);
58+
var result = await Send(chat, new ChatRequestOptions(), cancellationToken);
6259
chat.Messages.Last().Content = lastMsg.Content;
6360
return result;
6461
}

src/MaIN.Services/Services/LLMService/OpenAiCompatibleService.cs

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,10 @@ await _notificationService.DispatchNotification(
121121
{
122122
var jsonGrammarConverter = new GBNFToJsonConverter();
123123
var jsonGrammar = jsonGrammarConverter.ConvertToJson(chat.MemoryParams.Grammar);
124-
userQuery = $"{userQuery} | Respond only using the following JSON format: \n{jsonGrammar}\n. Do not add explanations, code tags, or any extra content.";
124+
userQuery =
125+
$"{userQuery} | For your next response only, please respond using exactly the following JSON format: \n{jsonGrammar}\n. Do not include any explanations, code blocks, or additional content. After this single JSON response, resume your normal conversational style.";
125126
}
126-
127+
127128
var retrievedContext = await kernel.AskAsync(userQuery, cancellationToken: cancellationToken);
128129

129130
await kernel.DeleteIndexAsync(cancellationToken: cancellationToken);
@@ -143,13 +144,13 @@ public virtual async Task<string[]> GetCurrentModels()
143144
response.EnsureSuccessStatusCode();
144145

145146
var responseJson = await response.Content.ReadAsStringAsync();
146-
var modelsResponse = JsonSerializer.Deserialize<OpenAiModelsResponse>(responseJson,
147+
var modelsResponse = JsonSerializer.Deserialize<OpenAiModelsResponse>(responseJson,
147148
new JsonSerializerOptions { PropertyNameCaseInsensitive = true });
148149

149150
return (modelsResponse?.Data?
150151
.Select(m => m.Id)
151152
.Where(id => id != null)
152-
.ToArray()
153+
.ToArray()
153154
?? [])!;
154155
}
155156

@@ -208,7 +209,7 @@ private async Task ProcessStreamingChatAsync(
208209
{
209210
role = m.Role,
210211
content = chat.InterferenceParams.Grammar != null
211-
//I know that this is a bit ugly, but hey, it works
212+
//I know that this is a bit ugly, but hey, it works
212213
? $"{m.Content} | Respond only using the following JSON format: \n{new GBNFToJsonConverter().ConvertToJson(chat.InterferenceParams.Grammar)}\n. Do not add explanations, code tags, or any extra content."
213214
: m.Content
214215
}).ToArray(),
@@ -300,10 +301,13 @@ private async Task ProcessNonStreamingChatAsync(
300301
var requestBody = new
301302
{
302303
model = chat.Model,
303-
messages = conversation.Select(m => new { role = m.Role, content = chat.InterferenceParams.Grammar != null
304-
//I know that this is a bit ugly, but hey, it works
305-
? $"{m.Content} | Respond only using the following JSON format: \n{new GBNFToJsonConverter().ConvertToJson(chat.InterferenceParams.Grammar)}\n. Do not add explanations, code tags, or any extra content."
306-
: m.Content }).ToArray(),
304+
messages = conversation.Select(m => new
305+
{
306+
role = m.Role, content = chat.InterferenceParams.Grammar != null
307+
//I know that this is a bit ugly, but hey, it works
308+
? $"{m.Content} | Respond only using the following JSON format: \n{new GBNFToJsonConverter().ConvertToJson(chat.InterferenceParams.Grammar)}\n. Do not add explanations, code tags, or any extra content."
309+
: m.Content
310+
}).ToArray(),
307311
stream = false
308312
};
309313

@@ -368,7 +372,7 @@ public class ChatRequestOptions
368372
{
369373
public bool InteractiveUpdates { get; set; }
370374
public bool CreateSession { get; set; }
371-
public bool SaveConv {get; set; } = true;
375+
public bool SaveConv { get; set; } = true;
372376
public Func<LLMTokenValue, Task>? TokenCallback { get; set; }
373377
}
374378

@@ -409,7 +413,7 @@ file class Delta
409413
}
410414

411415
file class OpenAiModelsResponse
412-
{
416+
{
413417
public List<OpenAiModel>? Data { get; set; }
414418
}
415419

src/MaIN.Services/Services/Steps/Commands/AnswerCommandHandler.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ public class AnswerCommandHandler(
5353
return await ProcessKnowledgeQuery(command.Knowledge, command.Chat, command.AgentId);
5454
}
5555

56-
result = command.Chat!.Visual
57-
? await imageGenService.Send(command.Chat)
56+
result = command.Chat.Visual
57+
? await imageGenService!.Send(command.Chat)
5858
: await llmService.Send(command.Chat,
5959
new ChatRequestOptions { InteractiveUpdates = command.Chat.Interactive });
6060

0 commit comments

Comments
 (0)