forked from dotnet/dev-proxy
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathBaseLanguageModelClient.cs
More file actions
138 lines (108 loc) · 5.4 KB
/
BaseLanguageModelClient.cs
File metadata and controls
138 lines (108 loc) · 5.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using DevProxy.Abstractions.Prompty;
using DevProxy.Abstractions.Utils;
using Microsoft.Extensions.Logging;
using System.Collections.Concurrent;
using System.Globalization;
namespace DevProxy.Abstractions.LanguageModel;
/// <summary>
/// Base class for language model clients. Loads <c>.prompty</c> prompt files from the
/// <c>prompts</c> folder under <see cref="ProxyUtils.AppFolder"/>, caches the converted
/// messages and completion options per prompt file and parameter set, and caches the
/// result of the availability check. Derived classes supply the provider-specific
/// message conversion, completion calls and availability probe.
/// </summary>
/// <param name="configuration">Language model settings exposed to derived classes.</param>
/// <param name="logger">Logger used for diagnostics.</param>
public abstract class BaseLanguageModelClient(LanguageModelConfiguration configuration, ILogger logger) : ILanguageModelClient
{
    protected LanguageModelConfiguration Configuration { get; } = configuration;
    protected ILogger Logger { get; } = logger;

    // Result of IsEnabledCoreAsync; null until the first check has completed.
    private bool? _lmAvailable;

    // Loaded prompts keyed by file name + parameters (see GetPromptCacheKey).
    // A (null, null) entry records a prompt file that produced no messages.
    private readonly ConcurrentDictionary<string, (IEnumerable<ILanguageModelChatCompletionMessage>?, CompletionOptions?)> _promptCache = new();

    /// <summary>
    /// Loads (or reuses a cached copy of) the given prompt file, renders it with
    /// <paramref name="parameters"/>, and generates a chat completion from the result.
    /// </summary>
    /// <param name="promptFileName">Prompt file name relative to the prompts folder; <c>.prompty</c> is appended when missing.</param>
    /// <param name="parameters">Values substituted into the prompt template.</param>
    /// <param name="cancellationToken">Token used to cancel the request.</param>
    /// <returns>
    /// The completion response, or <see langword="null"/> when the prompt yields no messages,
    /// there is no configuration, or the language model is not available.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="promptFileName"/> or <paramref name="parameters"/> is null.</exception>
    /// <exception cref="FileNotFoundException">The prompt file does not exist.</exception>
    public virtual async Task<ILanguageModelCompletionResponse?> GenerateChatCompletionAsync(string promptFileName, Dictionary<string, object> parameters, CancellationToken cancellationToken)
    {
        ArgumentNullException.ThrowIfNull(promptFileName);
        ArgumentNullException.ThrowIfNull(parameters);

        if (!promptFileName.EndsWith(".prompty", StringComparison.OrdinalIgnoreCase))
        {
            Logger.LogDebug("Prompt file name '{PromptFileName}' does not end with '.prompty'. Appending the extension.", promptFileName);
            promptFileName += ".prompty";
        }

        var cacheKey = GetPromptCacheKey(promptFileName, parameters);
        // GetOrAdd may invoke the factory more than once under contention; LoadPrompt only
        // reads the file, so a duplicate load is harmless. A thrown exception is not cached.
        var (messages, options) = _promptCache.GetOrAdd(cacheKey, _ => LoadPrompt(promptFileName, parameters));
        if (messages is null || !messages.Any())
        {
            return null;
        }

        return await GenerateChatCompletionAsync(messages, options, cancellationToken);
    }

    /// <summary>
    /// Generates a chat completion for the given messages when the language model is configured and available.
    /// </summary>
    /// <returns>The completion response, or <see langword="null"/> when the model cannot be used.</returns>
    public async Task<ILanguageModelCompletionResponse?> GenerateChatCompletionAsync(IEnumerable<ILanguageModelChatCompletionMessage> messages, CompletionOptions? options, CancellationToken cancellationToken)
    {
        if (!await CanGenerateAsync(cancellationToken))
        {
            return null;
        }

        return await GenerateChatCompletionCoreAsync(messages, options, cancellationToken);
    }

    /// <summary>
    /// Generates a plain-text completion for <paramref name="prompt"/> when the language model is configured and available.
    /// </summary>
    /// <returns>The completion response, or <see langword="null"/> when the model cannot be used.</returns>
    public async Task<ILanguageModelCompletionResponse?> GenerateCompletionAsync(string prompt, CompletionOptions? options, CancellationToken cancellationToken)
    {
        if (!await CanGenerateAsync(cancellationToken))
        {
            return null;
        }

        return await GenerateCompletionCoreAsync(prompt, options, cancellationToken);
    }

    /// <summary>
    /// Reports whether the language model is available. The first result from
    /// <see cref="IsEnabledCoreAsync"/> is cached and reused for later calls.
    /// </summary>
    public async Task<bool> IsEnabledAsync(CancellationToken cancellationToken)
    {
        if (_lmAvailable is bool cached)
        {
            return cached;
        }

        var available = await IsEnabledCoreAsync(cancellationToken);
        _lmAvailable = available;
        return available;
    }

    /// <summary>Converts Prompty chat messages into the provider-specific message type.</summary>
    protected abstract IEnumerable<ILanguageModelChatCompletionMessage> ConvertMessages(IEnumerable<ChatMessage> messages);

    /// <summary>Provider-specific chat completion call; invoked only after availability checks pass.</summary>
    protected abstract Task<ILanguageModelCompletionResponse?> GenerateChatCompletionCoreAsync(IEnumerable<ILanguageModelChatCompletionMessage> messages, CompletionOptions? options, CancellationToken cancellationToken);

    /// <summary>Provider-specific text completion call; invoked only after availability checks pass.</summary>
    protected abstract Task<ILanguageModelCompletionResponse?> GenerateCompletionCoreAsync(string prompt, CompletionOptions? options, CancellationToken cancellationToken);

    /// <summary>Provider-specific availability probe; its result is cached by <see cref="IsEnabledAsync"/>.</summary>
    protected abstract Task<bool> IsEnabledCoreAsync(CancellationToken cancellationToken);

    /// <summary>
    /// Shared guard for the Generate* entry points: requires a configuration and an available model.
    /// </summary>
    private async Task<bool> CanGenerateAsync(CancellationToken cancellationToken)
    {
        if (Configuration is null)
        {
            return false;
        }

        if (!await IsEnabledAsync(cancellationToken))
        {
            Logger.LogDebug("Language model is not available.");
            return false;
        }

        return true;
    }

    /// <summary>
    /// Reads and renders a prompt file, returning the converted messages and the completion
    /// options declared in the file's model section.
    /// </summary>
    /// <returns>The messages and options, or <c>(null, null)</c> when the prompt renders no messages.</returns>
    /// <exception cref="FileNotFoundException">The prompt file does not exist.</exception>
    private (IEnumerable<ILanguageModelChatCompletionMessage>?, CompletionOptions?) LoadPrompt(string promptFileName, Dictionary<string, object> parameters)
    {
        Logger.LogDebug("Prompt file {PromptFileName} not in the cache. Loading...", promptFileName);

        var filePath = Path.Combine(ProxyUtils.AppFolder!, "prompts", promptFileName);
        if (!File.Exists(filePath))
        {
            throw new FileNotFoundException($"Prompt file '{filePath}' not found.");
        }

        Logger.LogDebug("Loading prompt file: {FilePath}", filePath);
        var promptContents = File.ReadAllText(filePath);
        var prompty = Prompt.FromMarkdown(promptContents);

        if (prompty.Prepare(parameters) is not IEnumerable<ChatMessage> promptyMessages ||
            !promptyMessages.Any())
        {
            Logger.LogError("No messages found in the prompt file: {FilePath}", filePath);
            return (null, null);
        }

        // Materialize once: the result is cached and enumerated on every cache hit,
        // so a lazy sequence would redo the conversion each time.
        var messages = ConvertMessages(promptyMessages).ToArray();

        var options = new CompletionOptions();
        if (prompty.Model?.Options is { } modelOptions)
        {
            if (modelOptions.TryGetValue("temperature", out var temperature))
            {
                options.Temperature = ToNullableDouble(temperature);
            }
            if (modelOptions.TryGetValue("top_p", out var topP))
            {
                options.TopP = ToNullableDouble(topP);
            }
        }

        return (messages, options);
    }

    /// <summary>
    /// Converts a loosely-typed option value to a <see cref="double"/>. Depending on how the
    /// prompt's front matter is deserialized, numbers may arrive as strings or integers rather
    /// than doubles. Returns <see langword="null"/> for missing or non-numeric values.
    /// </summary>
    private static double? ToNullableDouble(object? value) => value switch
    {
        double d => d,
        float f => f,
        decimal m => (double)m,
        int i => i,
        long l => l,
        string s when double.TryParse(s, NumberStyles.Float, CultureInfo.InvariantCulture, out var parsed) => parsed,
        _ => null,
    };

    /// <summary>
    /// Builds a cache key that uniquely identifies a prompt file rendered with a given
    /// parameter set. Keys are ordered ordinally, so insertion order does not matter.
    /// Each key and value is length-prefixed, so values containing the separators cannot
    /// collide with a different parameter set. Values are formatted with the invariant culture.
    /// NOTE(review): values whose ToString() is not content-based (e.g. the default type name)
    /// still cannot be told apart.
    /// </summary>
    private static string GetPromptCacheKey(string promptFileName, Dictionary<string, object> parameters)
    {
        var parts = parameters
            .OrderBy(kvp => kvp.Key, StringComparer.Ordinal)
            .Select(kvp =>
            {
                var value = Convert.ToString(kvp.Value, CultureInfo.InvariantCulture) ?? string.Empty;
                return $"{kvp.Key.Length}:{kvp.Key}={value.Length}:{value}";
            });

        return $"{promptFileName}:{string.Join(";", parts)}";
    }
}