Skip to content

Commit 217732e

Browse files
committed
Refactor: replace hardcoded image gen model constants with central ModelRegistry
Introduce new image-generation models and centralize image-gen logic: added Models constants (Imagen4_0_FastGenerate, Flux1Shnell) and new Cloud/Local model records for those models. Replace ad-hoc FLUX checks with ModelRegistry.TryGetById(...).HasImageGeneration across ChatMapper, AgentService, ChatService, and AgentStateManager. Update image generation services (Gemini, OpenAI, Vertex, Xai) to resolve default model IDs from Models, pass the resolved model through to ChatResult, and remove per-service hardcoded model constants. Also add necessary using/import adjustments and improve chat model availability handling in ChatService.
1 parent 8c26bf9 commit 217732e

12 files changed

Lines changed: 55 additions & 57 deletions

File tree

src/MaIN.Domain/Models/Concrete/CloudModels.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ public sealed record Gemini2_5Pro() : CloudModel(
103103
public string? MMProjectName => null;
104104
}
105105

106+
public sealed record GeminiImagen4_0FastGenerate() : CloudModel(
107+
Models.Gemini.Imagen4_0_FastGenerate,
108+
BackendType.Gemini,
109+
"Imagen 4.0 Fast (Gemini)",
110+
4000,
111+
"Google's fast image generation model via Gemini API"), IImageGenerationModel;
112+
106113
// ===== Vertex AI Models =====
107114

108115
public sealed record VertexGemini2_5Pro() : CloudModel(

src/MaIN.Domain/Models/Concrete/LocalModels.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,16 @@ public sealed record Olmo2_7b() : LocalModel(
293293
8192,
294294
"Open-source 7B model for research, benchmarking, and academic studies");
295295

296+
// ===== Image Generation =====
297+
298+
public sealed record Flux1Shnell() : LocalModel(
299+
Models.Local.Flux1Shnell,
300+
"FLUX.1_Shnell",
301+
null,
302+
"FLUX.1 Schnell",
303+
4096,
304+
"Fast local image generation model"), IImageGenerationModel;
305+
296306
// ===== Embedding Model =====
297307

298308
public sealed record Mxbai_Embedding() : LocalModel(

src/MaIN.Domain/Models/Models.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ public static class Gemini
2626
public const string Gemini2_5Pro = "gemini-2.5-pro";
2727
public const string Gemini2_5Flash = "gemini-2.5-flash";
2828
public const string Gemini2_0Flash = "gemini-2.0-flash";
29+
public const string Imagen4_0_FastGenerate = "imagen-4.0-fast-generate-001";
2930
}
3031

3132
public static class Xai

src/MaIN.Services/Mappers/ChatMapper.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using MaIN.Domain.Entities;
22
using MaIN.Domain.Models;
3+
using MaIN.Domain.Models.Abstract;
34
using MaIN.Services.Dtos;
4-
using MaIN.Services.Services.ImageGenServices;
55
using FileInfo = MaIN.Domain.Entities.FileInfo;
66

77
namespace MaIN.Services.Mappers;
@@ -44,7 +44,7 @@ public static Chat ToDomain(this ChatDto chat)
4444
Name = chat.Name!,
4545
ModelId = chat.Model!,
4646
Messages = chat.Messages?.Select(m => m.ToDomain()).ToList()!,
47-
ImageGen = chat.Model == ImageGenService.LocalImageModels.FLUX,
47+
ImageGen = ModelRegistry.TryGetById(chat.Model!, out var m) && m!.HasImageGeneration,
4848
Type = Enum.Parse<ChatType>(chat.Type.ToString()),
4949
Properties = chat.Properties
5050
};

src/MaIN.Services/Services/AgentService.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
using MaIN.Domain.Repositories;
1010
using MaIN.Services.Constants;
1111
using MaIN.Services.Services.Abstract;
12-
using MaIN.Services.Services.ImageGenServices;
1312
using MaIN.Services.Services.LLMService.Factory;
1413
using MaIN.Services.Services.Models.Commands;
1514
using MaIN.Services.Services.Steps.Commands.Abstract;
@@ -101,7 +100,7 @@ public async Task<Agent> CreateAgent(Agent agent, bool flow = false, bool intera
101100
Id = Guid.NewGuid().ToString(),
102101
ModelId = agent.Model,
103102
Name = agent.Name,
104-
ImageGen = agent.Model == ImageGenService.LocalImageModels.FLUX,
103+
ImageGen = ModelRegistry.TryGetById(agent.Model, out var agentModel) && agentModel!.HasImageGeneration,
105104
ToolsConfiguration = agent.ToolsConfiguration,
106105
BackendParams = inferenceParams ?? new LocalInferenceParams(),
107106
MemoryParams = memoryParams ?? new MemoryParams(),

src/MaIN.Services/Services/ChatService.cs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
using MaIN.Domain.Models.Abstract;
77
using MaIN.Domain.Repositories;
88
using MaIN.Services.Services.Abstract;
9-
using MaIN.Services.Services.ImageGenServices;
109
using MaIN.Services.Services.LLMService;
1110
using MaIN.Services.Services.LLMService.Factory;
1211
using MaIN.Services.Services.Models;
@@ -34,14 +33,14 @@ public async Task<ChatResult> Completions(
3433
Func<LLMTokenValue?, Task>? changeOfValue = null,
3534
CancellationToken cancellationToken = default)
3635
{
37-
if (chat.ModelId == ImageGenService.LocalImageModels.FLUX)
36+
if (!ModelRegistry.TryGetById(chat.ModelId, out var model))
3837
{
39-
chat.ImageGen = true;
38+
throw new ChatModelNotAvailableException(chat.Id, chat.ModelId);
4039
}
4140

42-
if (!ModelRegistry.TryGetById(chat.ModelId, out var model))
41+
if (model!.HasImageGeneration)
4342
{
44-
throw new ChatModelNotAvailableException(chat.Id, chat.ModelId);
43+
chat.ImageGen = true;
4544
}
4645

4746
var backend = model!.Backend;

src/MaIN.Services/Services/ImageGenServices/GeminiImageGenService.cs

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using MaIN.Services.Constants;
44
using MaIN.Services.Services.Abstract;
55
using MaIN.Services.Services.Models;
6+
using ModelIds = MaIN.Domain.Models.Models;
67
using System.Net.Http.Headers;
78
using System.Net.Http.Json;
89
using System.Text.Json.Serialization;
@@ -22,22 +23,19 @@ internal class GeminiImageGenService(IHttpClientFactory httpClientFactory, MaINS
2223
string apiKey = _settings.GeminiKey ?? Environment.GetEnvironmentVariable(LLMApiRegistry.Gemini.ApiKeyEnvName)
2324
?? throw new APIKeyNotConfiguredException(LLMApiRegistry.Gemini.ApiName);
2425

25-
if (string.IsNullOrEmpty(chat.ModelId))
26-
{
27-
chat.ModelId = Models.IMAGEN_GENERATE;
28-
}
26+
var model = string.IsNullOrEmpty(chat.ModelId) ? ModelIds.Gemini.Imagen4_0_FastGenerate : chat.ModelId;
2927
client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", apiKey);
3028
var requestBody = new
3129
{
32-
model = chat.ModelId,
30+
model,
3331
prompt = BuildPromptFromChat(chat),
3432
response_format = "b64_json", // necessary for gemini api
3533
size = ServiceConstants.Defaults.ImageSize,
3634
};
3735

3836
using var response = await client.PostAsJsonAsync(ServiceConstants.ApiUrls.GeminiImageGenerations, requestBody);
3937
var imageBytes = await ProcessGeminiResponse(response);
40-
return CreateChatResult(imageBytes);
38+
return CreateChatResult(imageBytes, model);
4139
}
4240

4341
private static string BuildPromptFromChat(Chat chat)
@@ -61,7 +59,7 @@ private async Task<byte[]> ProcessGeminiResponse(HttpResponseMessage response)
6159
return Convert.FromBase64String(base64Image);
6260
}
6361

64-
private static ChatResult CreateChatResult(byte[] imageBytes)
62+
private static ChatResult CreateChatResult(byte[] imageBytes, string model)
6563
{
6664
return new ChatResult
6765
{
@@ -73,15 +71,10 @@ private static ChatResult CreateChatResult(byte[] imageBytes)
7371
Image = imageBytes,
7472
Type = MessageType.Image
7573
},
76-
Model = Models.IMAGEN_GENERATE,
74+
Model = model,
7775
CreatedAt = DateTime.UtcNow
7876
};
7977
}
80-
81-
private struct Models
82-
{
83-
public const string IMAGEN_GENERATE = "imagen-4.0-fast-generate-001";
84-
}
8578
}
8679

8780
file class GeminiImageResponse

src/MaIN.Services/Services/ImageGenServices/ImageGenService.cs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using MaIN.Services.Constants;
44
using MaIN.Services.Services.Abstract;
55
using MaIN.Services.Services.Models;
6+
using ModelIds = MaIN.Domain.Models.Models;
67

78
namespace MaIN.Services.Services.ImageGenServices;
89

@@ -48,13 +49,8 @@ private static ChatResult CreateChatResult(byte[] imageBytes)
4849
Image = imageBytes,
4950
Type = MessageType.Image
5051
},
51-
Model = LocalImageModels.FLUX,
52+
Model = ModelIds.Local.Flux1Shnell,
5253
CreatedAt = DateTime.UtcNow
5354
};
5455
}
55-
56-
internal struct LocalImageModels
57-
{
58-
public const string FLUX = "FLUX.1_Shnell";
59-
}
6056
}

src/MaIN.Services/Services/ImageGenServices/OpenAiImageGenService.cs

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using MaIN.Services.Constants;
66
using MaIN.Services.Services.Abstract;
77
using MaIN.Services.Services.Models;
8+
using ModelIds = MaIN.Domain.Models.Models;
89
using System.Net.Http.Headers;
910
using System.Net.Http.Json;
1011
using System.Text.Json.Serialization;
@@ -22,21 +23,22 @@ public class OpenAiImageGenService(
2223
public async Task<ChatResult?> Send(Chat chat)
2324
{
2425
var client = _httpClientFactory.CreateClient(ServiceConstants.HttpClients.OpenAiClient);
25-
string apiKey = _settings.OpenAiKey ?? Environment.GetEnvironmentVariable(LLMApiRegistry.OpenAi.ApiKeyEnvName)
26+
string apiKey = _settings.OpenAiKey ?? Environment.GetEnvironmentVariable(LLMApiRegistry.OpenAi.ApiKeyEnvName)
2627
?? throw new APIKeyNotConfiguredException(LLMApiRegistry.OpenAi.ApiName);
27-
28+
29+
var model = string.IsNullOrEmpty(chat.ModelId) ? ModelIds.OpenAi.DallE3 : chat.ModelId;
2830
client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", apiKey);
2931
var requestBody = new
3032
{
31-
model = chat.ModelId,
33+
model,
3234
prompt = BuildPromptFromChat(chat),
3335
size = ServiceConstants.Defaults.ImageSize
3436
};
3537

3638
using var response = await client.PostAsJsonAsync(ServiceConstants.ApiUrls.OpenAiImageGenerations, requestBody);
3739

3840
byte[] imageBytes = await ProcessOpenAiResponse(response);
39-
return CreateChatResult(imageBytes);
41+
return CreateChatResult(imageBytes, model);
4042
}
4143

4244
private static string BuildPromptFromChat(Chat chat)
@@ -73,7 +75,7 @@ private async Task<byte[]> ProcessOpenAiResponse(HttpResponseMessage response)
7375
throw new InvalidOperationException("No image URL or base64 data returned from OpenAI");
7476
}
7577

76-
private static ChatResult CreateChatResult(byte[] imageBytes)
78+
private static ChatResult CreateChatResult(byte[] imageBytes, string model)
7779
{
7880
return new ChatResult
7981
{
@@ -85,15 +87,10 @@ private static ChatResult CreateChatResult(byte[] imageBytes)
8587
Image = imageBytes,
8688
Type = MessageType.Image
8789
},
88-
Model = Models.DALLE,
90+
Model = model,
8991
CreatedAt = DateTime.UtcNow
9092
};
9193
}
92-
93-
private struct Models
94-
{
95-
public const string DALLE = "dall-e-3";
96-
}
9794
}
9895

9996
file class OpenAiImageResponse

src/MaIN.Services/Services/ImageGenServices/VertexImageGenService.cs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using MaIN.Services.Services.Abstract;
66
using MaIN.Services.Services.LLMService.Auth;
77
using MaIN.Services.Services.Models;
8+
using ModelIds = MaIN.Domain.Models.Models;
89
using System.Net.Http.Headers;
910
using System.Net.Http.Json;
1011
using System.Text.Json.Serialization;
@@ -13,7 +14,6 @@ namespace MaIN.Services.Services.ImageGenServices;
1314

1415
internal class VertexImageGenService(IHttpClientFactory httpClientFactory, MaINSettings settings) : IImageGenService
1516
{
16-
private const string DefaultModel = "imagen-4.0-generate-001";
1717
private const string DefaultLocation = "us-central1";
1818

1919
public async Task<ChatResult?> Send(Chat chat)
@@ -76,7 +76,7 @@ internal class VertexImageGenService(IHttpClientFactory httpClientFactory, MaINS
7676
Image = imageBytes,
7777
Type = MessageType.Image
7878
},
79-
Model = chat.ModelId ?? $"google/{DefaultModel}",
79+
Model = string.IsNullOrEmpty(chat.ModelId) ? ModelIds.Vertex.Imagen4_0_Generate : chat.ModelId,
8080
CreatedAt = DateTime.UtcNow
8181
};
8282
}
@@ -93,12 +93,11 @@ private static string BuildPromptFromChat(Chat chat)
9393
/// </summary>
9494
private static string ExtractModelName(string? modelId)
9595
{
96-
if (string.IsNullOrEmpty(modelId))
97-
return DefaultModel;
96+
var resolved = string.IsNullOrEmpty(modelId) ? ModelIds.Vertex.Imagen4_0_Generate : modelId;
9897

99-
return modelId.StartsWith("google/", StringComparison.OrdinalIgnoreCase)
100-
? modelId["google/".Length..]
101-
: modelId;
98+
return resolved.StartsWith("google/", StringComparison.OrdinalIgnoreCase)
99+
? resolved["google/".Length..]
100+
: resolved;
102101
}
103102
}
104103

0 commit comments

Comments
 (0)