Skip to content

Commit a9920ee

Browse files
author
Piotr Stachaczynski
committed
feat: minor fixes
1 parent 4334e7f commit a9920ee

1 file changed

Lines changed: 115 additions & 119 deletions

File tree

src/MaIN.Services/Services/LLMService/LLMService.cs

Lines changed: 115 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,15 @@ public class LLMService(IOptions<MaINSettings> options, INotificationService not
5959
GpuLayerCount = 30,
6060
};
6161

62-
var session = newSession ? GetOrCreateSession(chat.Id, () =>
63-
{
64-
var context = llmModel.CreateContext(parameters);
65-
var history = new ChatHistory();
66-
var executor = new InteractiveExecutor(context);
67-
return new ChatSession(executor, history);
68-
}) : new ChatSession(new InteractiveExecutor(llmModel.CreateContext(parameters)));
62+
var session = newSession
63+
? GetOrCreateSession(chat.Id, () =>
64+
{
65+
var context = llmModel.CreateContext(parameters);
66+
var history = new ChatHistory();
67+
var executor = new InteractiveExecutor(context);
68+
return new ChatSession(executor, history);
69+
})
70+
: new ChatSession(new InteractiveExecutor(llmModel.CreateContext(parameters)));
6971

7072
// Add all messages to the session history.
7173
AddMessagesToHistory(session, chat.Messages);
@@ -83,8 +85,11 @@ public class LLMService(IOptions<MaINSettings> options, INotificationService not
8385
if (lastMessage.Files?.Any() ?? false)
8486
{
8587
#pragma warning disable SKEXP0001
86-
var textData = lastMessage.Files.Where(x => x.Content is not null).ToDictionary(x => x.Name, x => x.Content);
87-
var fileData = lastMessage.Files.Where(x => x.Path is not null).ToDictionary(x => x.Name, x => x.Path); //shity coode TODO
88+
var textData = lastMessage.Files.Where(x => x.Content is not null)
89+
.ToDictionary(x => x.Name, x => x.Content);
90+
var fileData =
91+
lastMessage.Files.Where(x => x.Path is not null)
92+
.ToDictionary(x => x.Name, x => x.Path); //shity coode TODO
8893
var result = await AskMemory(chat, textData!, fileData!);
8994
resultBuilder.Append(result!.Message.Content);
9095
#pragma warning restore SKEXP0001
@@ -104,18 +109,19 @@ await notificationService.DispatchNotification(
104109
false),
105110
"ReceiveMessageUpdate");
106111
}
112+
107113
resultBuilder.Append(text);
108114
}
109115
}
110116

111117
if (interactiveUpdates)
112118
{
113-
await notificationService.DispatchNotification( NotificationMessageBuilder.CreateChatCompletion(
119+
await notificationService.DispatchNotification(NotificationMessageBuilder.CreateChatCompletion(
114120
chat.Id,
115121
resultBuilder.ToString(),
116122
true), "ReceiveMessageUpdate");
117123
}
118-
124+
119125
var chatResult = new ChatResult
120126
{
121127
Done = true,
@@ -143,17 +149,17 @@ await notificationService.DispatchNotification( NotificationMessageBuilder.Creat
143149
using var context = model.CreateContext(parameters);
144150

145151
// Llava Init
146-
var inferenceParams = new InferenceParams() { AntiPrompts = new[] { model.Vocab.EOT.ToString() ?? "User:" }};
152+
var inferenceParams = new InferenceParams() { AntiPrompts = new[] { model.Vocab.EOT.ToString() ?? "User:" } };
147153
var ex = new InteractiveExecutor(context);
148-
ex.Context.NativeHandle.KvCacheRemove( LLamaSeqId.Zero, -1, -1 );
154+
ex.Context.NativeHandle.KvCacheRemove(LLamaSeqId.Zero, -1, -1);
149155
ex.Images.Add(chat.Messages!.Last().Images);
150156
var result = new StringBuilder();
151157
await foreach (var text in ex.InferAsync(chat.Messages!.Last().Content, inferenceParams))
152158
{
153159
Console.Write(text);
154160
result.Append(text);
155161
}
156-
162+
157163
var chatResult = new ChatResult
158164
{
159165
Done = true,
@@ -200,7 +206,7 @@ private void AddMessagesToHistory(ChatSession session, List<Message> messages)
200206
}
201207

202208
[Experimental("SKEXP0001")]
203-
public async Task<ChatResult?> AskMemory(Chat chat,
209+
public async Task<ChatResult?> AskMemory(Chat chat,
204210
Dictionary<string, string>? textData = null,
205211
Dictionary<string, string>? fileData = null,
206212
List<string>? memory = null)
@@ -239,7 +245,7 @@ private void AddMessagesToHistory(ChatSession session, List<Message> messages)
239245
var result = await kernelMemory.AskAsync(userMsg.Content);
240246

241247
await kernelMemory.DeleteIndexAsync();
242-
248+
243249
var chatResult = new ChatResult()
244250
{
245251
Done = true,
@@ -251,15 +257,16 @@ private void AddMessagesToHistory(ChatSession session, List<Message> messages)
251257
Role = AuthorRole.Assistant.ToString()
252258
}
253259
};
254-
260+
255261
generator.Dispose();
256262

257263
return chatResult;
258264
}
259265

260266

261267
[Experimental("KMEXP01")]
262-
private static IKernelMemory CreateMemory(string modelName, string path, out KernelMemFix.LlamaSharpTextGenerator generator)
268+
private static IKernelMemory CreateMemory(string modelName, string path,
269+
out KernelMemFix.LlamaSharpTextGenerator generator)
263270
{
264271
InferenceParams infParams = new() { AntiPrompts = ["INFO", "<|im_end|>", "Question:"] };
265272

@@ -282,14 +289,14 @@ private static IKernelMemory CreateMemory(string modelName, string path, out Ker
282289

283290
return new KernelMemoryBuilder()
284291
//.WithLLamaSharpDefaults2(lsConfig)
285-
.WithLLamaSharpMaINTemp(lsConfig, Path.Combine(path, modelName), out generator)
292+
.WithLLamaSharpMaINTemp(lsConfig, path, modelName, out generator)
286293
.WithSearchClientConfig(searchClientConfig)
287294
.WithCustomImageOcr(new OcrWrapper())
288295
.With(parseOptions)
289296
.Build();
290297
}
291298

292-
private async Task<LLamaWeights> GetOrLoadModelAsync(string path, string modelKey)
299+
internal static async Task<LLamaWeights> GetOrLoadModelAsync(string path, string modelKey)
293300
{
294301
if (modelCache.TryGetValue(modelKey, out var cachedModel))
295302
{
@@ -328,120 +335,116 @@ public Task CleanSessionCache(string id)
328335
}
329336

330337
internal static class KernelMemFix
331-
{
338+
{
332339
[Experimental("KMEXP00")]
333340
public sealed class LlamaSharpTextGenerator : ITextGenerator, ITextTokenizer, IDisposable
334-
{
335-
private readonly StatelessExecutor _executor;
336-
private readonly LLamaWeights _weights;
337-
private readonly bool _ownsWeights;
338-
private readonly LLamaContext _context;
339-
private readonly bool _ownsContext;
340-
private readonly InferenceParams? _defaultInferenceParams;
341-
342-
public int MaxTokenTotal { get; }
343-
344-
345-
public LlamaSharpTextGenerator(
346-
LLamaWeights weights,
347-
LLamaContext context,
348-
StatelessExecutor? executor = null,
349-
InferenceParams? inferenceParams = null)
350341
{
351-
this._weights = weights;
352-
this._context = context;
353-
this._executor = executor ?? new StatelessExecutor(this._weights, this._context.Params);
354-
this._defaultInferenceParams = inferenceParams;
355-
this.MaxTokenTotal = (int) this._context.ContextSize;
356-
}
342+
private readonly StatelessExecutor _executor;
343+
private readonly LLamaWeights _weights;
344+
private readonly bool _ownsWeights;
345+
private readonly LLamaContext _context;
346+
private readonly bool _ownsContext;
347+
private readonly InferenceParams? _defaultInferenceParams;
357348

358-
public void Dispose()
359-
{
360-
if (this._ownsWeights)
361-
this._weights.Dispose();
362-
if (!this._ownsContext)
363-
return;
364-
this._context.Dispose();
365-
}
349+
public int MaxTokenTotal { get; }
366350

367-
public IAsyncEnumerable<GeneratedTextContent> GenerateTextAsync(string prompt, TextGenerationOptions options,
368-
CancellationToken cancellationToken = default)
369-
{
370-
return _executor
371-
.InferAsync(prompt, OptionsToParams(options, _defaultInferenceParams), cancellationToken: cancellationToken)
372-
.Select(a => new GeneratedTextContent(a));
373-
}
374351

375-
private static InferenceParams OptionsToParams(
376-
TextGenerationOptions options,
377-
InferenceParams? defaultParams)
378-
{
379-
if (defaultParams != (InferenceParams) null)
380-
return defaultParams with
352+
public LlamaSharpTextGenerator(
353+
LLamaWeights weights,
354+
LLamaContext context,
355+
StatelessExecutor? executor = null,
356+
InferenceParams? inferenceParams = null)
381357
{
382-
AntiPrompts = (IReadOnlyList<string>) defaultParams.AntiPrompts.Concat<string>((IEnumerable<string>) options.StopSequences).ToList<string>().AsReadOnly(),
383-
MaxTokens = options.MaxTokens ?? defaultParams.MaxTokens,
384-
SamplingPipeline = (ISamplingPipeline) new DefaultSamplingPipeline()
385-
{
386-
Temperature = (float) options.Temperature,
387-
FrequencyPenalty = (float) options.FrequencyPenalty,
388-
PresencePenalty = (float) options.PresencePenalty,
389-
TopP = (float) options.NucleusSampling
390-
}
391-
};
392-
return new InferenceParams()
393-
{
394-
AntiPrompts = (IReadOnlyList<string>) options.StopSequences.ToList<string>().AsReadOnly(),
395-
MaxTokens = options.MaxTokens.GetValueOrDefault(1024),
396-
SamplingPipeline = (ISamplingPipeline) new DefaultSamplingPipeline()
358+
this._weights = weights;
359+
this._context = context;
360+
this._executor = executor ?? new StatelessExecutor(this._weights, this._context.Params);
361+
this._defaultInferenceParams = inferenceParams;
362+
this.MaxTokenTotal = (int)this._context.ContextSize;
363+
}
364+
365+
public void Dispose()
397366
{
398-
Temperature = (float) options.Temperature,
399-
FrequencyPenalty = (float) options.FrequencyPenalty,
400-
PresencePenalty = (float) options.PresencePenalty,
401-
TopP = (float) options.NucleusSampling
367+
if (this._ownsWeights)
368+
this._weights.Dispose();
369+
if (!this._ownsContext)
370+
return;
371+
this._context.Dispose();
402372
}
403-
};
404-
}
405373

406-
public int CountTokens(string text) => this._context.Tokenize(text, special: true).Length;
374+
public IAsyncEnumerable<GeneratedTextContent> GenerateTextAsync(string prompt, TextGenerationOptions options,
375+
CancellationToken cancellationToken = default)
376+
{
377+
return _executor
378+
.InferAsync(prompt, OptionsToParams(options, _defaultInferenceParams),
379+
cancellationToken: cancellationToken)
380+
.Select(a => new GeneratedTextContent(a));
381+
}
407382

408-
public IReadOnlyList<string> GetTokens(string text)
409-
{
410-
LLamaToken[] source = this._context.Tokenize(text, special: true);
411-
StreamingTokenDecoder decoder = new StreamingTokenDecoder(this._context);
412-
Func<LLamaToken, string> selector = (Func<LLamaToken, string>) (x =>
413-
{
414-
decoder.Add(x);
415-
return decoder.Read();
416-
});
417-
return (IReadOnlyList<string>) ((IEnumerable<LLamaToken>) source).Select<LLamaToken, string>(selector).ToList<string>();
383+
private static InferenceParams OptionsToParams(
384+
TextGenerationOptions options,
385+
InferenceParams? defaultParams)
386+
{
387+
if (defaultParams != (InferenceParams)null)
388+
return defaultParams with
389+
{
390+
AntiPrompts = (IReadOnlyList<string>)defaultParams.AntiPrompts
391+
.Concat<string>((IEnumerable<string>)options.StopSequences).ToList<string>().AsReadOnly(),
392+
MaxTokens = options.MaxTokens ?? defaultParams.MaxTokens,
393+
SamplingPipeline = (ISamplingPipeline)new DefaultSamplingPipeline()
394+
{
395+
Temperature = (float)options.Temperature,
396+
FrequencyPenalty = (float)options.FrequencyPenalty,
397+
PresencePenalty = (float)options.PresencePenalty,
398+
TopP = (float)options.NucleusSampling
399+
}
400+
};
401+
return new InferenceParams()
402+
{
403+
AntiPrompts = (IReadOnlyList<string>)options.StopSequences.ToList<string>().AsReadOnly(),
404+
MaxTokens = options.MaxTokens.GetValueOrDefault(1024),
405+
SamplingPipeline = (ISamplingPipeline)new DefaultSamplingPipeline()
406+
{
407+
Temperature = (float)options.Temperature,
408+
FrequencyPenalty = (float)options.FrequencyPenalty,
409+
PresencePenalty = (float)options.PresencePenalty,
410+
TopP = (float)options.NucleusSampling
411+
}
412+
};
413+
}
414+
415+
public int CountTokens(string text) => this._context.Tokenize(text, special: true).Length;
416+
417+
public IReadOnlyList<string> GetTokens(string text)
418+
{
419+
LLamaToken[] source = this._context.Tokenize(text, special: true);
420+
StreamingTokenDecoder decoder = new StreamingTokenDecoder(this._context);
421+
Func<LLamaToken, string> selector = (Func<LLamaToken, string>)(x =>
422+
{
423+
decoder.Add(x);
424+
return decoder.Read();
425+
});
426+
return (IReadOnlyList<string>)((IEnumerable<LLamaToken>)source).Select<LLamaToken, string>(selector)
427+
.ToList<string>();
428+
}
418429
}
419-
}
420430

421431
[Experimental("KMEXP00")]
422432
public static IKernelMemoryBuilder WithLLamaSharpTextGeneration(
423433
this IKernelMemoryBuilder builder,
424434
LlamaSharpTextGenerator textGenerator)
425435
{
426-
builder.AddSingleton((ITextGenerator) textGenerator);
436+
builder.AddSingleton((ITextGenerator)textGenerator);
427437
return builder;
428438
}
429-
430-
private static readonly ConcurrentDictionary<string, LLamaWeights> ModelCache = new();
439+
440+
public static LLamaWeights? Weights = null;
431441

432442
[Experimental("KMEXP01")]
433443
public static IKernelMemoryBuilder WithLLamaSharpMaINTemp(this IKernelMemoryBuilder builder,
434-
LLamaSharpConfig config, string modelPath, out LlamaSharpTextGenerator generator)
444+
LLamaSharpConfig config, string path, string modelName, out LlamaSharpTextGenerator generator)
435445
{
436-
// Create ModelParams for the first model.
437-
var parameters1 = new ModelParams(modelPath)
438-
{
439-
ContextSize = 1024,
440-
GpuLayerCount = 55,
441-
};
442-
443446
// Load the first model with caching.
444-
var model = GetOrLoadModel(parameters1);
447+
var model = LLMService.GetOrLoadModelAsync(path, modelName).Result;
445448

446449
// Create ModelParams for the second model.
447450
ModelParams parameters2 = new ModelParams(config.ModelPath)
@@ -453,23 +456,16 @@ public static IKernelMemoryBuilder WithLLamaSharpMaINTemp(this IKernelMemoryBuil
453456
//SplitMode = new GPUSplitMode?(config.SplitMode)
454457
};
455458

456-
// Load the second model with caching.
457-
var weights = GetOrLoadModel(parameters2);
459+
Weights ??= LLamaWeights.LoadFromFile(parameters2);
458460

459461
var context = model.CreateContext(parameters2);
460462
StatelessExecutor executor = new StatelessExecutor(model, parameters2);
461463

462464
generator = new LlamaSharpTextGenerator(model, context, executor,
463465
config.DefaultInferenceParams);
464-
465-
builder.WithLLamaSharpTextEmbeddingGeneration(new LLamaSharpTextEmbeddingGenerator(config, weights));
466+
467+
builder.WithLLamaSharpTextEmbeddingGeneration(new LLamaSharpTextEmbeddingGenerator(config, Weights));
466468
builder.WithLLamaSharpTextGeneration(generator);
467469
return builder;
468470
}
469-
470-
private static LLamaWeights GetOrLoadModel(ModelParams modelParams)
471-
{
472-
return LLamaWeights.LoadFromFile(modelParams);
473-
}
474-
475471
}

0 commit comments

Comments
 (0)