@@ -59,13 +59,15 @@ public class LLMService(IOptions<MaINSettings> options, INotificationService not
5959 GpuLayerCount = 30 ,
6060 } ;
6161
62- var session = newSession ? GetOrCreateSession ( chat . Id , ( ) =>
63- {
64- var context = llmModel . CreateContext ( parameters ) ;
65- var history = new ChatHistory ( ) ;
66- var executor = new InteractiveExecutor ( context ) ;
67- return new ChatSession ( executor , history ) ;
68- } ) : new ChatSession ( new InteractiveExecutor ( llmModel . CreateContext ( parameters ) ) ) ;
62+ var session = newSession
63+ ? GetOrCreateSession ( chat . Id , ( ) =>
64+ {
65+ var context = llmModel . CreateContext ( parameters ) ;
66+ var history = new ChatHistory ( ) ;
67+ var executor = new InteractiveExecutor ( context ) ;
68+ return new ChatSession ( executor , history ) ;
69+ } )
70+ : new ChatSession ( new InteractiveExecutor ( llmModel . CreateContext ( parameters ) ) ) ;
6971
7072 // Add all messages to the session history.
7173 AddMessagesToHistory ( session , chat . Messages ) ;
@@ -83,8 +85,11 @@ public class LLMService(IOptions<MaINSettings> options, INotificationService not
8385 if ( lastMessage . Files ? . Any ( ) ?? false )
8486 {
8587#pragma warning disable SKEXP0001
86- var textData = lastMessage . Files . Where ( x => x . Content is not null ) . ToDictionary ( x => x . Name , x => x . Content ) ;
87- var fileData = lastMessage . Files . Where ( x => x . Path is not null ) . ToDictionary ( x => x . Name , x => x . Path ) ; //shity coode TODO
88+ var textData = lastMessage . Files . Where ( x => x . Content is not null )
89+ . ToDictionary ( x => x . Name , x => x . Content ) ;
90+ var fileData =
91+ lastMessage . Files . Where ( x => x . Path is not null )
92+ . ToDictionary ( x => x . Name , x => x . Path ) ; //shity coode TODO
8893 var result = await AskMemory ( chat , textData ! , fileData ! ) ;
8994 resultBuilder . Append ( result ! . Message . Content ) ;
9095#pragma warning restore SKEXP0001
@@ -104,18 +109,19 @@ await notificationService.DispatchNotification(
104109 false ) ,
105110 "ReceiveMessageUpdate" ) ;
106111 }
112+
107113 resultBuilder . Append ( text ) ;
108114 }
109115 }
110116
111117 if ( interactiveUpdates )
112118 {
113- await notificationService . DispatchNotification ( NotificationMessageBuilder . CreateChatCompletion (
119+ await notificationService . DispatchNotification ( NotificationMessageBuilder . CreateChatCompletion (
114120 chat . Id ,
115121 resultBuilder . ToString ( ) ,
116122 true ) , "ReceiveMessageUpdate" ) ;
117123 }
118-
124+
119125 var chatResult = new ChatResult
120126 {
121127 Done = true ,
@@ -143,17 +149,17 @@ await notificationService.DispatchNotification( NotificationMessageBuilder.Creat
143149 using var context = model . CreateContext ( parameters ) ;
144150
145151 // Llava Init
146- var inferenceParams = new InferenceParams ( ) { AntiPrompts = new [ ] { model . Vocab . EOT . ToString ( ) ?? "User:" } } ;
152+ var inferenceParams = new InferenceParams ( ) { AntiPrompts = new [ ] { model . Vocab . EOT . ToString ( ) ?? "User:" } } ;
147153 var ex = new InteractiveExecutor ( context ) ;
148- ex . Context . NativeHandle . KvCacheRemove ( LLamaSeqId . Zero , - 1 , - 1 ) ;
154+ ex . Context . NativeHandle . KvCacheRemove ( LLamaSeqId . Zero , - 1 , - 1 ) ;
149155 ex . Images . Add ( chat . Messages ! . Last ( ) . Images ) ;
150156 var result = new StringBuilder ( ) ;
151157 await foreach ( var text in ex . InferAsync ( chat . Messages ! . Last ( ) . Content , inferenceParams ) )
152158 {
153159 Console . Write ( text ) ;
154160 result . Append ( text ) ;
155161 }
156-
162+
157163 var chatResult = new ChatResult
158164 {
159165 Done = true ,
@@ -200,7 +206,7 @@ private void AddMessagesToHistory(ChatSession session, List<Message> messages)
200206 }
201207
202208 [ Experimental ( "SKEXP0001" ) ]
203- public async Task < ChatResult ? > AskMemory ( Chat chat ,
209+ public async Task < ChatResult ? > AskMemory ( Chat chat ,
204210 Dictionary < string , string > ? textData = null ,
205211 Dictionary < string , string > ? fileData = null ,
206212 List < string > ? memory = null )
@@ -239,7 +245,7 @@ private void AddMessagesToHistory(ChatSession session, List<Message> messages)
239245 var result = await kernelMemory . AskAsync ( userMsg . Content ) ;
240246
241247 await kernelMemory . DeleteIndexAsync ( ) ;
242-
248+
243249 var chatResult = new ChatResult ( )
244250 {
245251 Done = true ,
@@ -251,15 +257,16 @@ private void AddMessagesToHistory(ChatSession session, List<Message> messages)
251257 Role = AuthorRole . Assistant . ToString ( )
252258 }
253259 } ;
254-
260+
255261 generator . Dispose ( ) ;
256262
257263 return chatResult ;
258264 }
259265
260266
261267 [ Experimental ( "KMEXP01" ) ]
262- private static IKernelMemory CreateMemory ( string modelName , string path , out KernelMemFix . LlamaSharpTextGenerator generator )
268+ private static IKernelMemory CreateMemory ( string modelName , string path ,
269+ out KernelMemFix . LlamaSharpTextGenerator generator )
263270 {
264271 InferenceParams infParams = new ( ) { AntiPrompts = [ "INFO" , "<|im_end|>" , "Question:" ] } ;
265272
@@ -282,14 +289,14 @@ private static IKernelMemory CreateMemory(string modelName, string path, out Ker
282289
283290 return new KernelMemoryBuilder ( )
284291 //.WithLLamaSharpDefaults2(lsConfig)
285- . WithLLamaSharpMaINTemp ( lsConfig , Path . Combine ( path , modelName ) , out generator )
292+ . WithLLamaSharpMaINTemp ( lsConfig , path , modelName , out generator )
286293 . WithSearchClientConfig ( searchClientConfig )
287294 . WithCustomImageOcr ( new OcrWrapper ( ) )
288295 . With ( parseOptions )
289296 . Build ( ) ;
290297 }
291298
292- private async Task < LLamaWeights > GetOrLoadModelAsync ( string path , string modelKey )
299+ internal static async Task < LLamaWeights > GetOrLoadModelAsync ( string path , string modelKey )
293300 {
294301 if ( modelCache . TryGetValue ( modelKey , out var cachedModel ) )
295302 {
@@ -328,120 +335,116 @@ public Task CleanSessionCache(string id)
328335}
329336
330337internal static class KernelMemFix
331- {
338+ {
332339 [ Experimental ( "KMEXP00" ) ]
333340 public sealed class LlamaSharpTextGenerator : ITextGenerator , ITextTokenizer , IDisposable
334- {
335- private readonly StatelessExecutor _executor ;
336- private readonly LLamaWeights _weights ;
337- private readonly bool _ownsWeights ;
338- private readonly LLamaContext _context ;
339- private readonly bool _ownsContext ;
340- private readonly InferenceParams ? _defaultInferenceParams ;
341-
342- public int MaxTokenTotal { get ; }
343-
344-
345- public LlamaSharpTextGenerator (
346- LLamaWeights weights ,
347- LLamaContext context ,
348- StatelessExecutor ? executor = null ,
349- InferenceParams ? inferenceParams = null )
350341 {
351- this . _weights = weights ;
352- this . _context = context ;
353- this . _executor = executor ?? new StatelessExecutor ( this . _weights , this . _context . Params ) ;
354- this . _defaultInferenceParams = inferenceParams ;
355- this . MaxTokenTotal = ( int ) this . _context . ContextSize ;
356- }
342+ private readonly StatelessExecutor _executor ;
343+ private readonly LLamaWeights _weights ;
344+ private readonly bool _ownsWeights ;
345+ private readonly LLamaContext _context ;
346+ private readonly bool _ownsContext ;
347+ private readonly InferenceParams ? _defaultInferenceParams ;
357348
358- public void Dispose ( )
359- {
360- if ( this . _ownsWeights )
361- this . _weights . Dispose ( ) ;
362- if ( ! this . _ownsContext )
363- return ;
364- this . _context . Dispose ( ) ;
365- }
349+ public int MaxTokenTotal { get ; }
366350
367- public IAsyncEnumerable < GeneratedTextContent > GenerateTextAsync ( string prompt , TextGenerationOptions options ,
368- CancellationToken cancellationToken = default )
369- {
370- return _executor
371- . InferAsync ( prompt , OptionsToParams ( options , _defaultInferenceParams ) , cancellationToken : cancellationToken )
372- . Select ( a => new GeneratedTextContent ( a ) ) ;
373- }
374351
375- private static InferenceParams OptionsToParams (
376- TextGenerationOptions options ,
377- InferenceParams ? defaultParams )
378- {
379- if ( defaultParams != ( InferenceParams ) null )
380- return defaultParams with
352+ public LlamaSharpTextGenerator (
353+ LLamaWeights weights ,
354+ LLamaContext context ,
355+ StatelessExecutor ? executor = null ,
356+ InferenceParams ? inferenceParams = null )
381357 {
382- AntiPrompts = ( IReadOnlyList < string > ) defaultParams . AntiPrompts . Concat < string > ( ( IEnumerable < string > ) options . StopSequences ) . ToList < string > ( ) . AsReadOnly ( ) ,
383- MaxTokens = options . MaxTokens ?? defaultParams . MaxTokens ,
384- SamplingPipeline = ( ISamplingPipeline ) new DefaultSamplingPipeline ( )
385- {
386- Temperature = ( float ) options . Temperature ,
387- FrequencyPenalty = ( float ) options . FrequencyPenalty ,
388- PresencePenalty = ( float ) options . PresencePenalty ,
389- TopP = ( float ) options . NucleusSampling
390- }
391- } ;
392- return new InferenceParams ( )
393- {
394- AntiPrompts = ( IReadOnlyList < string > ) options . StopSequences . ToList < string > ( ) . AsReadOnly ( ) ,
395- MaxTokens = options . MaxTokens . GetValueOrDefault ( 1024 ) ,
396- SamplingPipeline = ( ISamplingPipeline ) new DefaultSamplingPipeline ( )
358+ this . _weights = weights ;
359+ this . _context = context ;
360+ this . _executor = executor ?? new StatelessExecutor ( this . _weights , this . _context . Params ) ;
361+ this . _defaultInferenceParams = inferenceParams ;
362+ this . MaxTokenTotal = ( int ) this . _context . ContextSize ;
363+ }
364+
365+ public void Dispose ( )
397366 {
398- Temperature = ( float ) options . Temperature ,
399- FrequencyPenalty = ( float ) options . FrequencyPenalty ,
400- PresencePenalty = ( float ) options . PresencePenalty ,
401- TopP = ( float ) options . NucleusSampling
367+ if ( this . _ownsWeights )
368+ this . _weights . Dispose ( ) ;
369+ if ( ! this . _ownsContext )
370+ return ;
371+ this . _context . Dispose ( ) ;
402372 }
403- } ;
404- }
405373
406- public int CountTokens ( string text ) => this . _context . Tokenize ( text , special : true ) . Length ;
374+ public IAsyncEnumerable < GeneratedTextContent > GenerateTextAsync ( string prompt , TextGenerationOptions options ,
375+ CancellationToken cancellationToken = default )
376+ {
377+ return _executor
378+ . InferAsync ( prompt , OptionsToParams ( options , _defaultInferenceParams ) ,
379+ cancellationToken : cancellationToken )
380+ . Select ( a => new GeneratedTextContent ( a ) ) ;
381+ }
407382
408- public IReadOnlyList < string > GetTokens ( string text )
409- {
410- LLamaToken [ ] source = this . _context . Tokenize ( text , special : true ) ;
411- StreamingTokenDecoder decoder = new StreamingTokenDecoder ( this . _context ) ;
412- Func < LLamaToken , string > selector = ( Func < LLamaToken , string > ) ( x =>
413- {
414- decoder . Add ( x ) ;
415- return decoder . Read ( ) ;
416- } ) ;
417- return ( IReadOnlyList < string > ) ( ( IEnumerable < LLamaToken > ) source ) . Select < LLamaToken , string > ( selector ) . ToList < string > ( ) ;
383+ private static InferenceParams OptionsToParams (
384+ TextGenerationOptions options ,
385+ InferenceParams ? defaultParams )
386+ {
387+ if ( defaultParams != ( InferenceParams ) null )
388+ return defaultParams with
389+ {
390+ AntiPrompts = ( IReadOnlyList < string > ) defaultParams . AntiPrompts
391+ . Concat < string > ( ( IEnumerable < string > ) options . StopSequences ) . ToList < string > ( ) . AsReadOnly ( ) ,
392+ MaxTokens = options . MaxTokens ?? defaultParams . MaxTokens ,
393+ SamplingPipeline = ( ISamplingPipeline ) new DefaultSamplingPipeline ( )
394+ {
395+ Temperature = ( float ) options . Temperature ,
396+ FrequencyPenalty = ( float ) options . FrequencyPenalty ,
397+ PresencePenalty = ( float ) options . PresencePenalty ,
398+ TopP = ( float ) options . NucleusSampling
399+ }
400+ } ;
401+ return new InferenceParams ( )
402+ {
403+ AntiPrompts = ( IReadOnlyList < string > ) options . StopSequences . ToList < string > ( ) . AsReadOnly ( ) ,
404+ MaxTokens = options . MaxTokens . GetValueOrDefault ( 1024 ) ,
405+ SamplingPipeline = ( ISamplingPipeline ) new DefaultSamplingPipeline ( )
406+ {
407+ Temperature = ( float ) options . Temperature ,
408+ FrequencyPenalty = ( float ) options . FrequencyPenalty ,
409+ PresencePenalty = ( float ) options . PresencePenalty ,
410+ TopP = ( float ) options . NucleusSampling
411+ }
412+ } ;
413+ }
414+
415+ public int CountTokens ( string text ) => this . _context . Tokenize ( text , special : true ) . Length ;
416+
417+ public IReadOnlyList < string > GetTokens ( string text )
418+ {
419+ LLamaToken [ ] source = this . _context . Tokenize ( text , special : true ) ;
420+ StreamingTokenDecoder decoder = new StreamingTokenDecoder ( this . _context ) ;
421+ Func < LLamaToken , string > selector = ( Func < LLamaToken , string > ) ( x =>
422+ {
423+ decoder . Add ( x ) ;
424+ return decoder . Read ( ) ;
425+ } ) ;
426+ return ( IReadOnlyList < string > ) ( ( IEnumerable < LLamaToken > ) source ) . Select < LLamaToken , string > ( selector )
427+ . ToList < string > ( ) ;
428+ }
418429 }
419- }
420430
421431 [ Experimental ( "KMEXP00" ) ]
422432 public static IKernelMemoryBuilder WithLLamaSharpTextGeneration (
423433 this IKernelMemoryBuilder builder ,
424434 LlamaSharpTextGenerator textGenerator )
425435 {
426- builder . AddSingleton ( ( ITextGenerator ) textGenerator ) ;
436+ builder . AddSingleton ( ( ITextGenerator ) textGenerator ) ;
427437 return builder ;
428438 }
429-
430- private static readonly ConcurrentDictionary < string , LLamaWeights > ModelCache = new ( ) ;
439+
440+ public static LLamaWeights ? Weights = null ;
431441
432442 [ Experimental ( "KMEXP01" ) ]
433443 public static IKernelMemoryBuilder WithLLamaSharpMaINTemp ( this IKernelMemoryBuilder builder ,
434- LLamaSharpConfig config , string modelPath , out LlamaSharpTextGenerator generator )
444+ LLamaSharpConfig config , string path , string modelName , out LlamaSharpTextGenerator generator )
435445 {
436- // Create ModelParams for the first model.
437- var parameters1 = new ModelParams ( modelPath )
438- {
439- ContextSize = 1024 ,
440- GpuLayerCount = 55 ,
441- } ;
442-
443446 // Load the first model with caching.
444- var model = GetOrLoadModel ( parameters1 ) ;
447+ var model = LLMService . GetOrLoadModelAsync ( path , modelName ) . Result ;
445448
446449 // Create ModelParams for the second model.
447450 ModelParams parameters2 = new ModelParams ( config . ModelPath )
@@ -453,23 +456,16 @@ public static IKernelMemoryBuilder WithLLamaSharpMaINTemp(this IKernelMemoryBuil
453456 //SplitMode = new GPUSplitMode?(config.SplitMode)
454457 } ;
455458
456- // Load the second model with caching.
457- var weights = GetOrLoadModel ( parameters2 ) ;
459+ Weights ??= LLamaWeights . LoadFromFile ( parameters2 ) ;
458460
459461 var context = model . CreateContext ( parameters2 ) ;
460462 StatelessExecutor executor = new StatelessExecutor ( model , parameters2 ) ;
461463
462464 generator = new LlamaSharpTextGenerator ( model , context , executor ,
463465 config . DefaultInferenceParams ) ;
464-
465- builder . WithLLamaSharpTextEmbeddingGeneration ( new LLamaSharpTextEmbeddingGenerator ( config , weights ) ) ;
466+
467+ builder . WithLLamaSharpTextEmbeddingGeneration ( new LLamaSharpTextEmbeddingGenerator ( config , Weights ) ) ;
466468 builder . WithLLamaSharpTextGeneration ( generator ) ;
467469 return builder ;
468470 }
469-
470- private static LLamaWeights GetOrLoadModel ( ModelParams modelParams )
471- {
472- return LLamaWeights . LoadFromFile ( modelParams ) ;
473- }
474-
475471}
0 commit comments