Skip to content

Commit 357b185

Browse files
author
Jicheng Lu
committed
add image variation
1 parent b3fd13a commit 357b185

11 files changed

Lines changed: 145 additions & 109 deletions

File tree

src/Infrastructure/BotSharp.Abstraction/Files/IBotSharpFileService.cs

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,8 @@ Task<IEnumerable<MessageFileModel>> GetChatFiles(string conversationId, string s
4444
#endregion
4545

4646
#region Image
47-
47+
Task<RoleDialogModel> GenerateImage(string? provider, string? model, string text);
48+
Task<RoleDialogModel> VarifyImage(string? provider, string? model, BotSharpFile file);
4849
#endregion
4950

5051
#region Pdf
@@ -54,7 +55,7 @@ Task<IEnumerable<MessageFileModel>> GetChatFiles(string conversationId, string s
5455
/// <param name="prompt"></param>
5556
/// <param name="files">Pdf files</param>
5657
/// <returns></returns>
57-
Task<string> InstructPdf(string? provider, string? model, string? modelId, string prompt, List<BotSharpFile> files);
58+
Task<string> ReadPdf(string? provider, string? model, string? modelId, string prompt, List<BotSharpFile> files);
5859
#endregion
5960

6061
#region User

src/Infrastructure/BotSharp.Abstraction/MLTasks/IImageVariation.cs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,5 +15,5 @@ public interface IImageVariation
1515
/// <param name="model">deployment name</param>
1616
void SetModelName(string model);
1717

18-
RoleDialogModel GetImageVariation(Agent agent, RoleDialogModel message, Stream image, string imageFileName);
18+
Task<RoleDialogModel> GetImageVariation(Agent agent, RoleDialogModel message, Stream image, string imageFileName);
1919
}
Lines changed: 41 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,46 @@
1+
using Microsoft.AspNetCore.StaticFiles;
2+
using System.IO;
3+
14
namespace BotSharp.Core.Files.Services;
25

36
public partial class BotSharpFileService
47
{
8+
public string GetDirectory(string conversationId)
9+
{
10+
var dir = Path.Combine(_dbSettings.FileRepository, CONVERSATION_FOLDER, conversationId, "attachments");
11+
if (!Directory.Exists(dir))
12+
{
13+
Directory.CreateDirectory(dir);
14+
}
15+
return dir;
16+
}
17+
18+
public (string, byte[]) GetFileInfoFromData(string data)
19+
{
20+
if (string.IsNullOrEmpty(data))
21+
{
22+
return (string.Empty, new byte[0]);
23+
}
24+
25+
var typeStartIdx = data.IndexOf(':');
26+
var typeEndIdx = data.IndexOf(';');
27+
var contentType = data.Substring(typeStartIdx + 1, typeEndIdx - typeStartIdx - 1);
28+
29+
var base64startIdx = data.IndexOf(',');
30+
var base64Str = data.Substring(base64startIdx + 1);
31+
32+
return (contentType, Convert.FromBase64String(base64Str));
33+
}
34+
35+
public string GetFileContentType(string filePath)
36+
{
37+
string contentType;
38+
var provider = new FileExtensionContentTypeProvider();
39+
if (!provider.TryGetContentType(filePath, out contentType))
40+
{
41+
contentType = string.Empty;
42+
}
43+
44+
return contentType;
45+
}
546
}
Lines changed: 57 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,57 @@
1+
using System.IO;
2+
3+
namespace BotSharp.Core.Files.Services;
4+
5+
public partial class BotSharpFileService
6+
{
7+
public async Task<RoleDialogModel> GenerateImage(string? provider, string? model, string text)
8+
{
9+
var completion = CompletionProvider.GetImageGeneration(_services, provider: provider ?? "openai", model: model ?? "dall-e-3");
10+
var message = await completion.GetImageGeneration(new Agent()
11+
{
12+
Id = Guid.Empty.ToString(),
13+
}, new RoleDialogModel(AgentRole.User, text));
14+
return message;
15+
}
16+
17+
public async Task<RoleDialogModel> VarifyImage(string? provider, string? model, BotSharpFile file)
18+
{
19+
if (string.IsNullOrWhiteSpace(file?.FileUrl) && string.IsNullOrWhiteSpace(file?.FileData))
20+
{
21+
throw new ArgumentException($"Please fill in at least file url or file data!");
22+
}
23+
24+
var completion = CompletionProvider.GetImageVariation(_services, provider: provider ?? "openai", model: model ?? "dall-e-2");
25+
var bytes = await DownloadFile(file);
26+
using var stream = new MemoryStream();
27+
stream.Write(bytes, 0, bytes.Length);
28+
stream.Position = 0;
29+
30+
var message = await completion.GetImageVariation(new Agent()
31+
{
32+
Id = Guid.Empty.ToString()
33+
}, new RoleDialogModel(AgentRole.User, string.Empty), stream, file.FileName ?? string.Empty);
34+
stream.Close();
35+
36+
return message;
37+
}
38+
39+
#region Private methods
40+
private async Task<byte[]> DownloadFile(BotSharpFile file)
41+
{
42+
var bytes = new byte[0];
43+
if (!string.IsNullOrEmpty(file.FileUrl))
44+
{
45+
var http = _services.GetRequiredService<IHttpClientFactory>();
46+
using var client = http.CreateClient();
47+
bytes = await client.GetByteArrayAsync(file.FileUrl);
48+
}
49+
else if (!string.IsNullOrEmpty(file.FileData))
50+
{
51+
(_, bytes) = GetFileInfoFromData(file.FileData);
52+
}
53+
54+
return bytes;
55+
}
56+
#endregion
57+
}

src/Infrastructure/BotSharp.Core/Files/Services/BotSharpFileService.ImageGeneration.cs

Lines changed: 0 additions & 5 deletions
This file was deleted.

src/Infrastructure/BotSharp.Core/Files/Services/BotSharpFileService.ImageVariation.cs

Lines changed: 0 additions & 5 deletions
This file was deleted.

src/Infrastructure/BotSharp.Core/Files/Services/BotSharpFileService.Pdf.cs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ namespace BotSharp.Core.Files.Services;
44

55
public partial class BotSharpFileService
66
{
7-
public async Task<string> InstructPdf(string? provider, string? model, string? modelId, string prompt, List<BotSharpFile> files)
7+
public async Task<string> ReadPdf(string? provider, string? model, string? modelId, string prompt, List<BotSharpFile> files)
88
{
99
var content = string.Empty;
1010

@@ -22,7 +22,7 @@ public async Task<string> InstructPdf(string? provider, string? model, string? m
2222

2323
try
2424
{
25-
var pdfFiles = await SaveFiles(sessionDir, files);
25+
var pdfFiles = await DownloadFiles(sessionDir, files);
2626
var images = await ConvertPdfToImages(pdfFiles);
2727
if (images.IsNullOrEmpty()) return content;
2828

@@ -60,7 +60,7 @@ private string GetSessionDirectory(string id)
6060
return dir;
6161
}
6262

63-
private async Task<IEnumerable<string>> SaveFiles(string dir, List<BotSharpFile> files, string extension = "pdf")
63+
private async Task<IEnumerable<string>> DownloadFiles(string dir, List<BotSharpFile> files, string extension = "pdf")
6464
{
6565
if (string.IsNullOrWhiteSpace(dir) || files.IsNullOrEmpty())
6666
{

src/Infrastructure/BotSharp.Core/Files/Services/BotSharpFileService.cs

Lines changed: 0 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -41,45 +41,6 @@ public BotSharpFileService(
4141
_baseDir = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, dbSettings.FileRepository);
4242
}
4343

44-
public string GetDirectory(string conversationId)
45-
{
46-
var dir = Path.Combine(_dbSettings.FileRepository, CONVERSATION_FOLDER, conversationId, "attachments");
47-
if (!Directory.Exists(dir))
48-
{
49-
Directory.CreateDirectory(dir);
50-
}
51-
return dir;
52-
}
53-
54-
public (string, byte[]) GetFileInfoFromData(string data)
55-
{
56-
if (string.IsNullOrEmpty(data))
57-
{
58-
return (string.Empty, new byte[0]);
59-
}
60-
61-
var typeStartIdx = data.IndexOf(':');
62-
var typeEndIdx = data.IndexOf(';');
63-
var contentType = data.Substring(typeStartIdx + 1, typeEndIdx - typeStartIdx - 1);
64-
65-
var base64startIdx = data.IndexOf(',');
66-
var base64Str = data.Substring(base64startIdx + 1);
67-
68-
return (contentType, Convert.FromBase64String(base64Str));
69-
}
70-
71-
public string GetFileContentType(string filePath)
72-
{
73-
string contentType;
74-
var provider = new FileExtensionContentTypeProvider();
75-
if (!provider.TryGetContentType(filePath, out contentType))
76-
{
77-
contentType = string.Empty;
78-
}
79-
80-
return contentType;
81-
}
82-
8344
#region Private methods
8445
private bool ExistDirectory(string? dir)
8546
{

src/Infrastructure/BotSharp.OpenAPI/Controllers/InstructModeController.cs

Lines changed: 27 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -109,18 +109,14 @@ public async Task<string> MultiModalCompletion([FromBody] IncomingMessageModel i
109109
[HttpPost("/instruct/image-generation")]
110110
public async Task<ImageGenerationViewModel> ImageGeneration([FromBody] IncomingMessageModel input)
111111
{
112+
var fileService = _services.GetRequiredService<IBotSharpFileService>();
112113
var state = _services.GetRequiredService<IConversationStateService>();
113114
input.States.ForEach(x => state.SetState(x.Key, x.Value, activeRounds: x.ActiveRounds, source: StateSource.External));
114115
var imageViewModel = new ImageGenerationViewModel();
115116

116117
try
117118
{
118-
var completion = CompletionProvider.GetImageGeneration(_services, provider: input.Provider ?? "openai", model: input.Model ?? "dall-e-3");
119-
var message = await completion.GetImageGeneration(new Agent()
120-
{
121-
Id = Guid.Empty.ToString(),
122-
}, new RoleDialogModel(AgentRole.User, input.Text));
123-
119+
var message = await fileService.GenerateImage(input.Provider, input.Model, input.Text);
124120
imageViewModel.Content = message.Content;
125121
imageViewModel.Images = message.GeneratedImages.Select(x => ImageViewModel.ToViewModel(x)).ToList();
126122
return imageViewModel;
@@ -134,33 +130,30 @@ public async Task<ImageGenerationViewModel> ImageGeneration([FromBody] IncomingM
134130
}
135131
}
136132

137-
//[HttpPost("/instruct/image-variation")]
138-
//public ImageGenerationViewModel ImageVariation([FromBody] IncomingMessageModel input)
139-
//{
140-
// var state = _services.GetRequiredService<IConversationStateService>();
141-
// input.States.ForEach(x => state.SetState(x.Key, x.Value, activeRounds: x.ActiveRounds, source: StateSource.External));
142-
// var imageViewModel = new ImageGenerationViewModel();
143-
144-
// try
145-
// {
146-
// var completion = CompletionProvider.GetImageVariation(_services, provider: input.Provider ?? "openai", model: input.Model ?? "dall-e-2");
147-
// var message = completion.GetImageVariation(new Agent()
148-
// {
149-
// Id = Guid.Empty.ToString(),
150-
// }, new RoleDialogModel(AgentRole.User, input.Text));
151-
152-
// imageViewModel.Content = message.Content;
153-
// imageViewModel.Images = message.GeneratedImages.Select(x => ImageViewModel.ToViewModel(x)).ToList();
154-
// return imageViewModel;
155-
// }
156-
// catch (Exception ex)
157-
// {
158-
// var error = $"Error in image generation. {ex.Message}";
159-
// _logger.LogError(error);
160-
// imageViewModel.Message = error;
161-
// return imageViewModel;
162-
// }
163-
//}
133+
[HttpPost("/instruct/image-variation")]
134+
public async Task<ImageGenerationViewModel> ImageVariation([FromBody] IncomingMessageModel input)
135+
{
136+
var fileService = _services.GetRequiredService<IBotSharpFileService>();
137+
var state = _services.GetRequiredService<IConversationStateService>();
138+
input.States.ForEach(x => state.SetState(x.Key, x.Value, activeRounds: x.ActiveRounds, source: StateSource.External));
139+
var imageViewModel = new ImageGenerationViewModel();
140+
141+
try
142+
{
143+
var file = input.Files.FirstOrDefault(x => !string.IsNullOrWhiteSpace(x.FileUrl) || !string.IsNullOrWhiteSpace(x.FileData));
144+
var message = await fileService.VarifyImage(input.Provider, input.Model, file);
145+
imageViewModel.Content = message.Content;
146+
imageViewModel.Images = message.GeneratedImages.Select(x => ImageViewModel.ToViewModel(x)).ToList();
147+
return imageViewModel;
148+
}
149+
catch (Exception ex)
150+
{
151+
var error = $"Error in image variation. {ex.Message}";
152+
_logger.LogError(error);
153+
imageViewModel.Message = error;
154+
return imageViewModel;
155+
}
156+
}
164157

165158
[HttpPost("/instruct/pdf-completion")]
166159
public async Task<PdfCompletionViewModel> PdfCompletion([FromBody] IncomingMessageModel input)
@@ -172,7 +165,7 @@ public async Task<PdfCompletionViewModel> PdfCompletion([FromBody] IncomingMessa
172165
try
173166
{
174167
var fileService = _services.GetRequiredService<IBotSharpFileService>();
175-
var content = await fileService.InstructPdf(input.Provider, input.Model, input.ModelId, input.Text, input.Files);
168+
var content = await fileService.ReadPdf(input.Provider, input.Model, input.ModelId, input.Text, input.Files);
176169
viewModel.Content = content;
177170
return viewModel;
178171
}

src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageGenerationProvider.cs

Lines changed: 11 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -26,18 +26,10 @@ public ImageGenerationProvider(
2626
}
2727

2828

29-
public async Task<RoleDialogModel> GetImageGeneration(Agent agent, List<RoleDialogModel> conversations)
29+
public async Task<RoleDialogModel> GetImageGeneration(Agent agent, RoleDialogModel message)
3030
{
31-
var contentHooks = _services.GetServices<IContentGeneratingHook>().ToList();
32-
33-
// Before
34-
foreach (var hook in contentHooks)
35-
{
36-
await hook.BeforeGenerating(agent, conversations);
37-
}
38-
3931
var client = ProviderHelper.GetClient(Provider, _model, _services);
40-
var (prompt, imageCount, options) = PrepareOptions(conversations);
32+
var (prompt, imageCount, options) = PrepareOptions(message);
4133
var imageClient = client.GetImageClient(_model);
4234

4335
var response = imageClient.GenerateImages(prompt, imageCount, options);
@@ -71,11 +63,12 @@ public async Task<RoleDialogModel> GetImageGeneration(Agent agent, List<RoleDial
7163
var responseMessage = new RoleDialogModel(AgentRole.Assistant, content)
7264
{
7365
CurrentAgentId = agent.Id,
74-
MessageId = conversations.LastOrDefault()?.MessageId ?? string.Empty,
66+
MessageId = message?.MessageId ?? string.Empty,
7567
GeneratedImages = images
7668
};
7769

7870
// After
71+
var contentHooks = _services.GetServices<IContentGeneratingHook>().ToList();
7972
foreach (var hook in contentHooks)
8073
{
8174
await hook.AfterGenerated(responseMessage, new TokenStatsModel
@@ -91,9 +84,14 @@ public async Task<RoleDialogModel> GetImageGeneration(Agent agent, List<RoleDial
9184
return responseMessage;
9285
}
9386

94-
private (string, int, ImageGenerationOptions) PrepareOptions(List<RoleDialogModel> conversations)
87+
public void SetModelName(string model)
9588
{
96-
var prompt = conversations.LastOrDefault()?.Payload ?? conversations.LastOrDefault()?.Content ?? string.Empty;
89+
_model = model;
90+
}
91+
92+
private (string, int, ImageGenerationOptions) PrepareOptions(RoleDialogModel message)
93+
{
94+
var prompt = message?.Payload ?? message?.Content ?? string.Empty;
9795

9896
var state = _services.GetRequiredService<IConversationStateService>();
9997
var size = state.GetState("image_size");
@@ -112,11 +110,6 @@ public async Task<RoleDialogModel> GetImageGeneration(Agent agent, List<RoleDial
112110
return (prompt, count, options);
113111
}
114112

115-
public void SetModelName(string model)
116-
{
117-
_model = model;
118-
}
119-
120113
private GeneratedImageSize GetImageSize(string size)
121114
{
122115
var value = !string.IsNullOrEmpty(size) ? size : "1024x1024";

0 commit comments

Comments
 (0)