diff --git a/docusaurus/docs/llm-providers/exo.md b/docusaurus/docs/llm-providers/exo.md new file mode 100644 index 00000000..ab73b73e --- /dev/null +++ b/docusaurus/docs/llm-providers/exo.md @@ -0,0 +1,228 @@ +--- +sidebar_position: 3 +title: Exo - Distributed AI Cluster +description: Run large AI models across multiple Apple Silicon devices using Exo with DevoxxGenie. +keywords: [devoxxgenie, exo, distributed, cluster, apple silicon, thunderbolt, mlx, local llm] +image: /img/devoxxgenie-social-card.jpg +--- + +# Exo - Distributed AI Cluster + +![Exo dashboard showing a multi-node cluster with Tensor sharding](/img/exo-dashboard-cluster-view.png) + +[Exo](https://github.com/exo-explore/exo) lets you run frontier AI models by clustering multiple devices together. Models that are too large for a single machine get split across your cluster, enabling you to run models like MiniMax M2.5 (173GB) or Llama 3.3 70B locally. + +DevoxxGenie integrates directly with Exo, automatically managing model instances so you can focus on coding. + +## Why Exo? + +| Feature | Benefit | +|---------|---------| +| **Distributed inference** | Run models too large for any single device | +| **Automatic device discovery** | Devices find each other without manual setup | +| **Thunderbolt / RDMA support** | Near-native speed between devices | +| **Pipeline & Tensor parallelism** | Up to 3.2x speedup on 4 devices | +| **MLX backend** | Optimized for Apple Silicon | +| **No API costs** | Run everything locally, privately | + +## Prerequisites + +- **macOS Tahoe 26.2+** on all devices +- **Apple Silicon Macs** (M1/M2/M3/M4 series) +- **Thunderbolt cable** connecting devices (for best performance) +- **Exo installed** on all devices in the cluster + +## Installing Exo + +### Option 1: Native App (Recommended) + +Download the latest `EXO-latest.dmg` from the [Exo releases page](https://github.com/exo-explore/exo/releases). 
The app runs in the background and includes a web dashboard for cluster management. + +### Option 2: From Source + +```bash +git clone https://github.com/exo-explore/exo +cd exo/dashboard && npm install && npm run build && cd .. +uv run exo +``` + +:::tip +Install Exo on **every device** in your cluster. They will discover each other automatically over the network. +::: + +## Setting Up Your Cluster + +### 1. Connect Devices + +Connect your Macs via **Thunderbolt** for best performance. Exo also works over regular networking but Thunderbolt provides significantly lower latency. + +**Example cluster:** +- MacBook Pro M4 Max (128GB RAM) + Mac Studio M1 Ultra (128GB RAM) +- Combined: 256GB RAM for model inference + +### 2. Enable RDMA (Optional, for Thunderbolt 5) + +For Thunderbolt 5 connections, enable RDMA for maximum performance: + +1. Boot into Recovery Mode +2. Run: `rdma_ctl enable` +3. Restart + +### 3. Start Exo + +Launch Exo on each device. Open the Exo dashboard at `http://localhost:52415` to verify your cluster: + +- All devices should appear in the cluster view +- Thunderbolt connections should show as active +- Device memory and GPU stats should be visible + +![Exo Dashboard showing a two-node cluster with a MiniMax M2.5 instance running](/img/exo-instanceready.png) + +*The Exo dashboard showing a MacBook Pro M4 Max and Mac Studio M1 Ultra cluster, with a MiniMax M2.5 instance ready for chat.* + +### 4. Download Models + +Use the Exo dashboard to download models to your cluster. Models are stored in `~/.exo/models/` on each device. + +:::info +DevoxxGenie only shows models that are **fully downloaded** on your cluster. If you don't see a model in the dropdown, check the Exo dashboard to verify the download completed. +::: + +## Configuring DevoxxGenie + +### 1. Enable Exo Provider + +1. Open IntelliJ IDEA **Settings** > **Tools** > **DevoxxGenie** > **Large Language Models** +2. Find **Exo URL** in the Local LLM Providers section +3. 
**Enable** the checkbox +4. Set the URL (default: `http://localhost:52415/v1/`) + +### 2. Select a Model + +1. In the DevoxxGenie panel, select **Exo** from the provider dropdown +2. Choose a model from the model dropdown (only downloaded models are shown) +3. A **background task** will start to prepare the model instance across your cluster + +:::note +When you select a model, DevoxxGenie automatically: +1. Previews placements across your cluster +2. Creates an optimal instance (Pipeline or Tensor sharding) +3. Waits for all runners to warm up +4. Notifies you when the instance is ready +::: + +![DevoxxGenie warming up Exo model runners with progress bar](/img/exo-warmingup.png) + +*DevoxxGenie shows a progress bar while the Exo model instance is loading across your cluster. The notification confirms the downloaded models are available.* + +### 3. Start Chatting + +Once the instance is ready (you'll see a notification), you can start chatting with the model just like any other provider. + +![DevoxxGenie with Exo cluster panel showing two nodes and active instance](/img/exo-view.png) + +*The Exo cluster panel appears above the chat when Exo is selected, showing connected nodes with memory usage, GPU stats, and the active model instance status. Click the header to collapse it.* + +## How It Works + +### Instance Management + +When you select an Exo model in DevoxxGenie: + +1. **Placement preview** - DevoxxGenie queries Exo to find how the model can be distributed across your cluster +2. **Instance creation** - The optimal placement is selected and an instance is created +3. **Model loading** - Each device loads its portion of the model into memory +4. **Runner warmup** - The inference pipeline is initialized +5. 
**Ready** - The model is ready for chat via the OpenAI-compatible API + +### Automatic Recovery + +If the Exo instance disconnects or gets recycled: +- DevoxxGenie detects the disconnection automatically +- A new instance is prepared in the background +- You'll be notified when the model is ready again + +### Sharding Strategies + +Exo supports two sharding strategies: + +- **Pipeline parallelism** - Model layers are split across devices. Device A processes layers 0-30, Device B processes layers 31-62. +- **Tensor parallelism** - Each layer is split across devices. Both devices process every layer together, achieving higher throughput. + +DevoxxGenie automatically selects the best available strategy. + +## Supported Models + +Exo supports any model from the [MLX Community on HuggingFace](https://huggingface.co/mlx-community). Popular models include: + +| Model | Size | Min Cluster RAM | +|-------|------|----------------| +| Llama 3.2 1B Instruct 4bit | ~1 GB | 8 GB | +| Llama 3.2 3B Instruct 4bit | ~2 GB | 8 GB | +| Llama 3.3 70B Instruct 4bit | ~39 GB | 48 GB | +| MiniMax M2.5 6bit | ~173 GB | 192 GB | +| Qwen3 Coder 480B 4bit | ~276 GB | 320 GB | + +:::tip +Start with smaller models to verify your cluster works, then move to larger ones. The Exo dashboard shows real-time memory usage across your cluster. +::: + +## API Compatibility + +Exo exposes an **OpenAI-compatible API** at `http://localhost:52415/v1/`. 
You can also use it directly with curl: + +```bash +curl http://localhost:52415/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "mlx-community/MiniMax-M2.5-6bit", + "messages": [{"role": "user", "content": "Hello!"}], + "max_tokens": 100 + }' +``` + +Exo also supports: +- `/v1/chat/completions` (OpenAI format) +- `/v1/messages` (Claude format) +- `/ollama/api/chat` (Ollama format) +- Streaming responses (SSE) + +## Troubleshooting + +### No Models in Dropdown + +- **Exo not running**: Verify Exo is running on at least one device (`http://localhost:52415` should load the dashboard) +- **No downloads**: Models must be downloaded via the Exo dashboard first +- **Wrong URL**: Check the Exo URL in settings is `http://localhost:52415/v1/` + +### Instance Creation Fails + +- **"No valid placement found"**: Not enough combined RAM across your cluster for the selected model +- **"No cycles found"**: Devices are not connected. Check Thunderbolt cables and that Exo is running on all devices +- **Timeout**: Large models take time to load. The default timeout is 120 seconds + +### "Method Not Allowed" Error + +The Exo URL must end with `/v1/`. Update it in Settings > DevoxxGenie > Large Language Models > Exo URL. + +### Chat Works Once Then Fails + +Exo may recycle instances after inactivity. DevoxxGenie will automatically detect this and re-prepare the instance. Wait for the "Instance ready" notification before chatting again. 
+ +### Slow Response Times + +- Use **Thunderbolt** connections instead of network for lower latency +- Enable **RDMA** for Thunderbolt 5 devices +- Try **Pipeline** sharding (default) for single-request latency +- Try **Tensor** sharding for throughput with concurrent requests + +## Hardware Recommendations + +### Minimum Cluster +- 2x Mac Mini M4 (16GB each) - Run models up to ~28GB + +### Recommended Cluster +- 2x Mac Studio M2 Ultra (192GB each) - Run most frontier models + +### High-End Cluster +- 4x Mac Studio M4 Ultra (256GB each) - Run the largest available models with tensor parallelism diff --git a/docusaurus/docs/llm-providers/local-models.md b/docusaurus/docs/llm-providers/local-models.md index b39204ab..01af4ec5 100644 --- a/docusaurus/docs/llm-providers/local-models.md +++ b/docusaurus/docs/llm-providers/local-models.md @@ -19,7 +19,8 @@ DevoxxGenie integrates with these local LLM providers: 3. [GPT4All](#gpt4all) 4. [Llama.cpp](#llamacpp) 5. [Jan](#jan) -6. [Custom OpenAI-compatible Providers](#custom-openai-compatible-providers) +6. [Exo](/docs/llm-providers/exo) (Distributed AI cluster) +7. 
[Custom OpenAI-compatible Providers](#custom-openai-compatible-providers) ## Ollama @@ -264,6 +265,7 @@ Choose the provider that best matches your needs: - **Performance**: Llama.cpp offers the most control over optimization - **Customization**: LM Studio and Llama.cpp provide the most options - **All-in-one**: Jan provides model management + chat in a single app +- **Large models**: [Exo](/docs/llm-providers/exo) lets you run frontier models across multiple devices ## Troubleshooting diff --git a/docusaurus/docs/llm-providers/overview.md b/docusaurus/docs/llm-providers/overview.md index 6fa4e351..581eaa55 100644 --- a/docusaurus/docs/llm-providers/overview.md +++ b/docusaurus/docs/llm-providers/overview.md @@ -29,6 +29,7 @@ Supported local providers: - [GPT4All](local-models.md#gpt4all) - [Llama.cpp](local-models.md#llamacpp) - [Jan](local-models.md#jan) +- [Exo](exo.md) (Distributed AI cluster) - [Custom OpenAI-compatible](local-models.md#custom-openai-compatible-providers) ### Cloud LLM Providers diff --git a/docusaurus/sidebars.js b/docusaurus/sidebars.js index 047e0655..135d220e 100644 --- a/docusaurus/sidebars.js +++ b/docusaurus/sidebars.js @@ -55,6 +55,7 @@ const sidebars = { items: [ 'llm-providers/overview', 'llm-providers/local-models', + 'llm-providers/exo', 'llm-providers/cloud-models', 'llm-providers/custom-providers', 'features/acp-runners', diff --git a/docusaurus/static/img/exo-dashboard-cluster-view.png b/docusaurus/static/img/exo-dashboard-cluster-view.png new file mode 100644 index 00000000..9754042d Binary files /dev/null and b/docusaurus/static/img/exo-dashboard-cluster-view.png differ diff --git a/docusaurus/static/img/exo-instanceready.png b/docusaurus/static/img/exo-instanceready.png new file mode 100644 index 00000000..7e5322aa Binary files /dev/null and b/docusaurus/static/img/exo-instanceready.png differ diff --git a/docusaurus/static/img/exo-view.png b/docusaurus/static/img/exo-view.png new file mode 100644 index 00000000..68db9aae 
Binary files /dev/null and b/docusaurus/static/img/exo-view.png differ diff --git a/docusaurus/static/img/exo-warmingup.png b/docusaurus/static/img/exo-warmingup.png new file mode 100644 index 00000000..9f33c145 Binary files /dev/null and b/docusaurus/static/img/exo-warmingup.png differ diff --git a/src/main/java/com/devoxx/genie/chatmodel/ChatModelFactoryProvider.java b/src/main/java/com/devoxx/genie/chatmodel/ChatModelFactoryProvider.java index 1416fe3d..31f13759 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/ChatModelFactoryProvider.java +++ b/src/main/java/com/devoxx/genie/chatmodel/ChatModelFactoryProvider.java @@ -14,6 +14,7 @@ import com.devoxx.genie.chatmodel.cloud.openai.OpenAIChatModelFactory; import com.devoxx.genie.chatmodel.cloud.openrouter.OpenRouterChatModelFactory; import com.devoxx.genie.chatmodel.local.customopenai.CustomOpenAIChatModelFactory; +import com.devoxx.genie.chatmodel.local.exo.ExoChatModelFactory; import com.devoxx.genie.chatmodel.local.gpt4all.GPT4AllChatModelFactory; import com.devoxx.genie.chatmodel.local.jan.JanChatModelFactory; import com.devoxx.genie.chatmodel.local.acprunners.AcpRunnersChatModelFactory; @@ -51,6 +52,7 @@ private ChatModelFactoryProvider() { case "AzureOpenAI" -> new AzureOpenAIChatModelFactory(); case "Bedrock" -> new BedrockModelFactory(); case "CustomOpenAI" -> new CustomOpenAIChatModelFactory(); + case "Exo" -> new ExoChatModelFactory(); case "DeepInfra" -> new DeepInfraChatModelFactory(); case "DeepSeek" -> new DeepSeekChatModelFactory(); case "Google" -> new GoogleChatModelFactory(); diff --git a/src/main/java/com/devoxx/genie/chatmodel/ChatModelProvider.java b/src/main/java/com/devoxx/genie/chatmodel/ChatModelProvider.java index da1f60a6..82fdb6cc 100644 --- a/src/main/java/com/devoxx/genie/chatmodel/ChatModelProvider.java +++ b/src/main/java/com/devoxx/genie/chatmodel/ChatModelProvider.java @@ -87,6 +87,9 @@ private void setLocalBaseUrl(@NotNull LanguageModel languageModel, case LLaMA: 
customChatModel.setBaseUrl(stateService.getLlamaCPPUrl()); break; + case Exo: + customChatModel.setBaseUrl(stateService.getExoModelUrl()); + break; case CustomOpenAI: customChatModel.setBaseUrl(stateService.getCustomOpenAIUrl()); break; diff --git a/src/main/java/com/devoxx/genie/chatmodel/local/exo/ExoChatModelFactory.java b/src/main/java/com/devoxx/genie/chatmodel/local/exo/ExoChatModelFactory.java new file mode 100644 index 00000000..8525d727 --- /dev/null +++ b/src/main/java/com/devoxx/genie/chatmodel/local/exo/ExoChatModelFactory.java @@ -0,0 +1,130 @@ +package com.devoxx.genie.chatmodel.local.exo; + +import com.devoxx.genie.chatmodel.local.LocalChatModelFactory; +import com.devoxx.genie.model.CustomChatModel; +import com.devoxx.genie.model.LanguageModel; +import com.devoxx.genie.model.enumarations.ModelProvider; +import com.devoxx.genie.model.exo.ExoModelEntryDTO; +import com.devoxx.genie.ui.settings.DevoxxGenieStateService; +import com.devoxx.genie.ui.util.NotificationUtil; +import com.intellij.openapi.application.ApplicationManager; +import com.intellij.openapi.progress.ProgressIndicator; +import com.intellij.openapi.progress.ProgressManager; +import com.intellij.openapi.progress.Task; +import com.intellij.openapi.project.Project; +import com.intellij.openapi.project.ProjectManager; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.StreamingChatModel; +import org.jetbrains.annotations.NotNull; + +import java.io.IOException; +import java.util.concurrent.atomic.AtomicBoolean; + +public class ExoChatModelFactory extends LocalChatModelFactory { + + private static final AtomicBoolean instanceReady = new AtomicBoolean(false); + private static volatile String preparedModelId = null; + private static final AtomicBoolean preparing = new AtomicBoolean(false); + + public ExoChatModelFactory() { + super(ModelProvider.Exo); + } + + @Override + public ChatModel createChatModel(@NotNull CustomChatModel customChatModel) { + 
verifyAndReprepareIfNeeded(customChatModel.getModelName()); + return createOpenAiChatModel(customChatModel); + } + + @Override + public StreamingChatModel createStreamingChatModel(@NotNull CustomChatModel customChatModel) { + verifyAndReprepareIfNeeded(customChatModel.getModelName()); + return createOpenAiStreamingChatModel(customChatModel); + } + + @Override + protected String getModelUrl() { + return DevoxxGenieStateService.getInstance().getExoModelUrl(); + } + + @Override + protected ExoModelEntryDTO[] fetchModels() throws IOException { + return ExoModelService.getInstance().getModels(); + } + + @Override + protected LanguageModel buildLanguageModel(Object model) throws IOException { + ExoModelEntryDTO exoModel = (ExoModelEntryDTO) model; + int contextWindow = exoModel.getContextLength(); + if (contextWindow <= 0) { + contextWindow = 4096; + } + return LanguageModel.builder() + .provider(modelProvider) + .modelName(exoModel.getId()) + .displayName(exoModel.getName() != null ? exoModel.getName() : exoModel.getId()) + .inputCost(0) + .outputCost(0) + .inputMaxTokens(contextWindow) + .apiKeyUsed(false) + .build(); + } + + /** + * Quick check if the instance is still alive. If not, triggers re-preparation. + * This is called from createChatModel on the EDT, so it must NOT block. + */ + private void verifyAndReprepareIfNeeded(String modelId) { + if (modelId == null) return; + + // Quick async check — don't block the EDT + ApplicationManager.getApplication().executeOnPooledThread(() -> { + if (!ExoModelService.getInstance().isInstanceRunning(modelId)) { + instanceReady.set(false); + prepareInstanceAsync(modelId, ProjectManager.getInstance().getDefaultProject()); + } + }); + } + + /** + * Prepares an Exo instance for the given model in the background. + * Shows a progress bar and notifies when ready. 
+ */ + public static void prepareInstanceAsync(String modelId, Project project) { + if (modelId == null || modelId.isBlank()) return; + + // Don't start multiple preparations + if (!preparing.compareAndSet(false, true)) return; + + instanceReady.set(false); + preparedModelId = modelId; + + ProgressManager.getInstance().run(new Task.Backgroundable( + project, "Exo: Starting instance for " + modelId, true + ) { + @Override + public void run(@NotNull ProgressIndicator indicator) { + try { + ExoModelService.getInstance().ensureInstanceWithProgress(modelId, indicator); + instanceReady.set(true); + NotificationUtil.sendNotification(project, + "Exo instance ready for " + modelId); + } catch (IOException e) { + instanceReady.set(false); + NotificationUtil.sendNotification(project, + "Failed to start Exo instance: " + e.getMessage()); + } finally { + preparing.set(false); + } + } + }); + } + + public static boolean isInstanceReady() { + return instanceReady.get(); + } + + public static String getPreparedModelId() { + return preparedModelId; + } +} diff --git a/src/main/java/com/devoxx/genie/chatmodel/local/exo/ExoModelService.java b/src/main/java/com/devoxx/genie/chatmodel/local/exo/ExoModelService.java new file mode 100644 index 00000000..a2115afb --- /dev/null +++ b/src/main/java/com/devoxx/genie/chatmodel/local/exo/ExoModelService.java @@ -0,0 +1,475 @@ +package com.devoxx.genie.chatmodel.local.exo; + +import com.devoxx.genie.chatmodel.local.LocalLLMProvider; +import com.devoxx.genie.model.exo.ExoModelDTO; +import com.devoxx.genie.model.exo.ExoModelEntryDTO; +import com.devoxx.genie.service.exception.UnsuccessfulRequestException; +import com.devoxx.genie.ui.settings.DevoxxGenieStateService; +import com.devoxx.genie.util.HttpClientProvider; +import com.google.gson.Gson; +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import okhttp3.*; +import org.jetbrains.annotations.NotNull; + +import 
com.intellij.openapi.application.ApplicationManager; +import com.intellij.openapi.progress.ProgressIndicator; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static com.devoxx.genie.util.HttpUtil.ensureEndsWithSlash; + +public class ExoModelService implements LocalLLMProvider { + + private static final Gson gson = new Gson(); + private static final MediaType JSON_TYPE = MediaType.parse("application/json"); + + /** + * Returns the Exo API base URL (without /v1/ suffix). + * The settings URL includes /v1/ for Langchain4j OpenAI compatibility, + * but the Exo management API endpoints (/state, /models, /instance) are at the root. + */ + private static String getExoApiBaseUrl() { + String url = DevoxxGenieStateService.getInstance().getExoModelUrl(); + if (url.endsWith("/v1/")) { + url = url.substring(0, url.length() - 3); // Remove "v1/" + } else if (url.endsWith("/v1")) { + url = url.substring(0, url.length() - 2); // Remove "v1" + } + return ensureEndsWithSlash(url); + } + + @NotNull + public static ExoModelService getInstance() { + return ApplicationManager.getApplication().getService(ExoModelService.class); + } + + /** + * Returns only models that are downloaded (on disk) across all Exo cluster nodes. + * Uses the /state endpoint's downloads section to find DownloadCompleted models, + * then enriches with metadata from /models. + */ + @Override + public ExoModelEntryDTO[] getModels() throws IOException { + String baseUrl = getExoApiBaseUrl(); + + // Get downloaded model IDs from /state (DownloadCompleted entries) + Set downloadedModelIds = getDownloadedModelIds(baseUrl); + + if (downloadedModelIds.isEmpty()) { + throw new IOException("No downloaded models found. 
Use the Exo dashboard to download models first."); + } + + // Fetch full model metadata from /models and filter to only downloaded ones + List result = new ArrayList<>(); + String url = baseUrl + "models"; + Request request = new Request.Builder().url(url).build(); + try (Response response = HttpClientProvider.getClient().newCall(request).execute()) { + if (response.isSuccessful() && response.body() != null) { + String json = response.body().string(); + ExoModelDTO dto = gson.fromJson(json, ExoModelDTO.class); + if (dto.getData() != null) { + for (ExoModelEntryDTO model : dto.getData()) { + if (downloadedModelIds.contains(model.getId())) { + result.add(model); + } + } + } + } + } + + // If /models didn't match, create entries from state info + if (result.isEmpty()) { + for (String modelId : downloadedModelIds) { + ExoModelEntryDTO dto = new ExoModelEntryDTO(); + dto.setId(modelId); + dto.setName(modelId.contains("/") ? modelId.substring(modelId.indexOf('/') + 1) : modelId); + dto.setContextLength(0); + result.add(dto); + } + } + + return result.toArray(new ExoModelEntryDTO[0]); + } + + /** + * Extracts model IDs that have DownloadCompleted status from /state. 
+ */ + private Set getDownloadedModelIds(String baseUrl) throws IOException { + Set modelIds = new HashSet<>(); + JsonObject state = fetchState(baseUrl); + JsonObject downloads = state.getAsJsonObject("downloads"); + if (downloads == null) return modelIds; + + for (var nodeEntry : downloads.entrySet()) { + JsonArray nodeDownloads = nodeEntry.getValue().getAsJsonArray(); + for (JsonElement dl : nodeDownloads) { + JsonObject dlObj = dl.getAsJsonObject(); + if (dlObj.has("DownloadCompleted")) { + JsonObject completed = dlObj.getAsJsonObject("DownloadCompleted"); + String modelId = extractModelIdFromDownload(completed); + if (modelId != null) { + modelIds.add(modelId); + } + } + } + } + return modelIds; + } + + private String extractModelIdFromDownload(JsonObject downloadInfo) { + try { + JsonObject shardMetadata = downloadInfo.getAsJsonObject("shardMetadata"); + if (shardMetadata == null) return null; + for (var entry : shardMetadata.entrySet()) { + JsonObject shard = entry.getValue().getAsJsonObject(); + if (shard.has("modelCard")) { + return shard.getAsJsonObject("modelCard").get("modelId").getAsString(); + } + } + } catch (Exception ignored) { + } + return null; + } + + /** + * Quick check if an instance for the given model is currently running and has ready runners. + * Used to detect when exo has recycled/disconnected an instance. 
+ */ + public boolean isInstanceRunning(String modelId) { + try { + String baseUrl = getExoApiBaseUrl(); + JsonObject state = fetchState(baseUrl); + JsonObject instances = state.getAsJsonObject("instances"); + if (instances == null || instances.entrySet().isEmpty()) return false; + + JsonObject runners = state.getAsJsonObject("runners"); + for (var entry : instances.entrySet()) { + String instModelId = extractModelIdFromInstance(entry.getValue()); + if (modelId.equals(instModelId)) { + String instanceId = extractInstanceId(entry.getValue()); + return instanceId != null && areInstanceRunnersReady(instances, runners, instanceId); + } + } + } catch (Exception ignored) { + } + return false; + } + + /** + * Ensures a model instance is running on Exo. + * Previews placements and creates an instance if none exists. + */ + public void ensureInstance(String modelId) throws IOException { + String baseUrl = getExoApiBaseUrl(); + + // Check if an instance already exists for this model + JsonObject state = fetchState(baseUrl); + JsonObject instances = state.getAsJsonObject("instances"); + if (instances != null) { + for (var entry : instances.entrySet()) { + String instanceModelId = extractModelIdFromInstance(entry.getValue()); + if (modelId.equals(instanceModelId)) { + // Instance exists — check if its runners are ready + String instanceId = extractInstanceId(entry.getValue()); + JsonObject runners = state.getAsJsonObject("runners"); + if (instanceId != null && areInstanceRunnersReady(instances, runners, instanceId)) { + return; // Instance exists and is ready + } + // Instance exists but runners not ready — wait for them + waitForInstanceReady(baseUrl, modelId, 120); + return; + } + } + } + + // Preview placements + String previewUrl = baseUrl + "instance/previews?model_id=" + modelId; + Request previewRequest = new Request.Builder().url(previewUrl).build(); + String previewBody; + try (Response previewResponse = 
HttpClientProvider.getClient().newCall(previewRequest).execute()) { + if (!previewResponse.isSuccessful() || previewResponse.body() == null) { + throw new IOException("Failed to preview placements for model: " + modelId); + } + previewBody = previewResponse.body().string(); + } + + // Find first valid placement + JsonObject previewObj = gson.fromJson(previewBody, JsonObject.class); + JsonArray previews = previewObj.getAsJsonArray("previews"); + JsonElement chosenPreview = null; + + for (JsonElement p : previews) { + JsonObject preview = p.getAsJsonObject(); + if (preview.get("error").isJsonNull() && preview.get("instance") != null && !preview.get("instance").isJsonNull()) { + chosenPreview = p; + break; + } + } + + if (chosenPreview == null) { + throw new IOException("No valid placement found for model: " + modelId + + ". Ensure enough devices are connected in the Exo cluster."); + } + + // Create instance + String instanceUrl = baseUrl + "instance"; + RequestBody body = RequestBody.create(chosenPreview.toString(), JSON_TYPE); + Request instanceRequest = new Request.Builder().url(instanceUrl).post(body).build(); + + try (Response instanceResponse = HttpClientProvider.getClient().newCall(instanceRequest).execute()) { + if (!instanceResponse.isSuccessful()) { + String errorBody = instanceResponse.body() != null ? instanceResponse.body().string() : "unknown"; + throw new IOException("Failed to create Exo instance: " + errorBody); + } + } + + // Wait for the instance runners to be ready + waitForInstanceReady(baseUrl, modelId, 120); + } + + /** + * Same as ensureInstance but reports progress to an IntelliJ ProgressIndicator. 
+ */ + public void ensureInstanceWithProgress(String modelId, ProgressIndicator indicator) throws IOException { + String baseUrl = getExoApiBaseUrl(); + + indicator.setText("Checking for existing instance..."); + indicator.setFraction(0.1); + + // Check if an instance already exists for this model + JsonObject state = fetchState(baseUrl); + JsonObject instances = state.getAsJsonObject("instances"); + if (instances != null) { + for (var entry : instances.entrySet()) { + String instanceModelId = extractModelIdFromInstance(entry.getValue()); + if (modelId.equals(instanceModelId)) { + String instanceId = extractInstanceId(entry.getValue()); + JsonObject runners = state.getAsJsonObject("runners"); + if (instanceId != null && areInstanceRunnersReady(instances, runners, instanceId)) { + indicator.setText("Instance already running"); + indicator.setFraction(1.0); + return; + } + indicator.setText("Waiting for runners to warm up..."); + indicator.setFraction(0.5); + waitForInstanceReadyWithProgress(baseUrl, modelId, 120, indicator); + return; + } + } + } + + indicator.setText("Previewing placements across cluster..."); + indicator.setFraction(0.2); + if (indicator.isCanceled()) return; + + // Preview placements + String previewUrl = baseUrl + "instance/previews?model_id=" + modelId; + Request previewRequest = new Request.Builder().url(previewUrl).build(); + String previewBody; + try (Response previewResponse = HttpClientProvider.getClient().newCall(previewRequest).execute()) { + if (!previewResponse.isSuccessful() || previewResponse.body() == null) { + throw new IOException("Failed to preview placements for model: " + modelId); + } + previewBody = previewResponse.body().string(); + } + + JsonObject previewObj = gson.fromJson(previewBody, JsonObject.class); + JsonArray previews = previewObj.getAsJsonArray("previews"); + JsonElement chosenPreview = null; + for (JsonElement p : previews) { + JsonObject preview = p.getAsJsonObject(); + if (preview.get("error").isJsonNull() && 
preview.get("instance") != null && !preview.get("instance").isJsonNull()) { + chosenPreview = p; + break; + } + } + + if (chosenPreview == null) { + throw new IOException("No valid placement found for model: " + modelId + + ". Ensure enough devices are connected in the Exo cluster."); + } + + indicator.setText("Creating instance..."); + indicator.setFraction(0.3); + if (indicator.isCanceled()) return; + + // Create instance + String instanceUrl = baseUrl + "instance"; + RequestBody body = RequestBody.create(chosenPreview.toString(), JSON_TYPE); + Request instanceRequest = new Request.Builder().url(instanceUrl).post(body).build(); + + try (Response instanceResponse = HttpClientProvider.getClient().newCall(instanceRequest).execute()) { + if (!instanceResponse.isSuccessful()) { + String errorBody = instanceResponse.body() != null ? instanceResponse.body().string() : "unknown"; + throw new IOException("Failed to create Exo instance: " + errorBody); + } + } + + indicator.setText("Loading model across cluster..."); + indicator.setFraction(0.4); + + waitForInstanceReadyWithProgress(baseUrl, modelId, 120, indicator); + } + + private void waitForInstanceReadyWithProgress(String baseUrl, String modelId, + int timeoutSeconds, ProgressIndicator indicator) throws IOException { + long deadline = System.currentTimeMillis() + (timeoutSeconds * 1000L); + long startTime = System.currentTimeMillis(); + + while (System.currentTimeMillis() < deadline) { + if (indicator.isCanceled()) return; + + JsonObject state = fetchState(baseUrl); + JsonObject instances = state.getAsJsonObject("instances"); + JsonObject runners = state.getAsJsonObject("runners"); + + if (instances != null) { + for (var entry : instances.entrySet()) { + String instModelId = extractModelIdFromInstance(entry.getValue()); + if (modelId.equals(instModelId)) { + String instanceId = extractInstanceId(entry.getValue()); + if (instanceId != null && areInstanceRunnersReady(instances, runners, instanceId)) { + 
indicator.setText("Exo instance ready"); + indicator.setFraction(1.0); + return; + } + } + } + } + + // Update progress (0.4 to 0.95 over the timeout period) + double elapsed = (System.currentTimeMillis() - startTime) / (double) (timeoutSeconds * 1000L); + indicator.setFraction(0.4 + elapsed * 0.55); + indicator.setText("Warming up model runners..."); + + try { + Thread.sleep(2000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IOException("Interrupted while waiting for Exo runners", e); + } + } + throw new IOException("Timed out waiting for Exo runners to become ready"); + } + + private JsonObject fetchState(String baseUrl) throws IOException { + String stateUrl = baseUrl + "state"; + Request request = new Request.Builder().url(stateUrl).build(); + try (Response response = HttpClientProvider.getClient().newCall(request).execute()) { + if (!response.isSuccessful() || response.body() == null) { + throw new IOException("Failed to fetch Exo state"); + } + return gson.fromJson(response.body().string(), JsonObject.class); + } + } + + private String extractModelIdFromInstance(JsonElement instanceValue) { + try { + JsonObject inst = instanceValue.getAsJsonObject(); + for (var entry : inst.entrySet()) { + JsonObject inner = entry.getValue().getAsJsonObject(); + JsonObject shardAssignments = inner.getAsJsonObject("shardAssignments"); + if (shardAssignments != null) { + return shardAssignments.get("modelId").getAsString(); + } + } + } catch (Exception ignored) { + } + return null; + } + + private String extractInstanceId(JsonElement instanceValue) { + try { + JsonObject inst = instanceValue.getAsJsonObject(); + for (var entry : inst.entrySet()) { + JsonObject inner = entry.getValue().getAsJsonObject(); + if (inner.has("instanceId")) { + return inner.get("instanceId").getAsString(); + } + } + } catch (Exception ignored) { + } + return null; + } + + /** + * Checks if the runners belonging to a specific instance are all ready. 
+ * Only looks at runners mapped to this instance via nodeToRunner, ignoring stale runners. + */ + private boolean areInstanceRunnersReady(JsonObject instances, JsonObject runners, String instanceId) { + if (runners == null || instances == null) return false; + + // Find runner IDs belonging to this instance + Set instanceRunnerIds = new HashSet<>(); + for (var entry : instances.entrySet()) { + String iid = extractInstanceId(entry.getValue()); + if (instanceId.equals(iid)) { + try { + JsonObject inst = entry.getValue().getAsJsonObject(); + for (var inner : inst.entrySet()) { + JsonObject shardAssignments = inner.getValue().getAsJsonObject().getAsJsonObject("shardAssignments"); + if (shardAssignments != null) { + JsonObject nodeToRunner = shardAssignments.getAsJsonObject("nodeToRunner"); + if (nodeToRunner != null) { + for (var nr : nodeToRunner.entrySet()) { + instanceRunnerIds.add(nr.getValue().getAsString()); + } + } + } + } + } catch (Exception ignored) { + } + break; + } + } + + if (instanceRunnerIds.isEmpty()) return false; + + // Check only those runners + return instanceRunnerIds.stream().allMatch(runnerId -> { + JsonElement runner = runners.get(runnerId); + if (runner == null) return false; + JsonObject r = runner.getAsJsonObject(); + return r.has("RunnerReady") || r.has("RunnerRunning"); + }); + } + + /** + * Waits for the instance running the given model to have all its runners ready. 
+ */ + private void waitForInstanceReady(String baseUrl, String modelId, int timeoutSeconds) throws IOException { + long deadline = System.currentTimeMillis() + (timeoutSeconds * 1000L); + + while (System.currentTimeMillis() < deadline) { + JsonObject state = fetchState(baseUrl); + JsonObject instances = state.getAsJsonObject("instances"); + JsonObject runners = state.getAsJsonObject("runners"); + + if (instances != null) { + for (var entry : instances.entrySet()) { + String instModelId = extractModelIdFromInstance(entry.getValue()); + if (modelId.equals(instModelId)) { + String instanceId = extractInstanceId(entry.getValue()); + if (instanceId != null && areInstanceRunnersReady(instances, runners, instanceId)) { + return; + } + } + } + } + + try { + Thread.sleep(2000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IOException("Interrupted while waiting for Exo runners", e); + } + } + throw new IOException("Timed out waiting for Exo runners to become ready"); + } +} diff --git a/src/main/java/com/devoxx/genie/model/Constant.java b/src/main/java/com/devoxx/genie/model/Constant.java index 3f83cd9a..718d3271 100644 --- a/src/main/java/com/devoxx/genie/model/Constant.java +++ b/src/main/java/com/devoxx/genie/model/Constant.java @@ -40,6 +40,7 @@ private Constant() { public static final String GPT4ALL_MODEL_URL = "http://localhost:4891/v1/"; public static final String JAN_MODEL_URL = "http://localhost:1337/v1/"; public static final String LLAMA_CPP_MODEL_URL = "http://localhost:8080"; + public static final String EXO_MODEL_URL = "http://localhost:52415/v1/"; // ActionCommands public static final String SUBMIT_ACTION = "submit"; diff --git a/src/main/java/com/devoxx/genie/model/enumarations/ModelProvider.java b/src/main/java/com/devoxx/genie/model/enumarations/ModelProvider.java index 20ccb6ea..5aa2631b 100644 --- a/src/main/java/com/devoxx/genie/model/enumarations/ModelProvider.java +++ 
b/src/main/java/com/devoxx/genie/model/enumarations/ModelProvider.java @@ -13,6 +13,7 @@ public enum ModelProvider { Jan("Jan", Type.LOCAL), LLaMA("LLaMA.c++", Type.LOCAL), LMStudio("LMStudio", Type.LOCAL), + Exo("Exo", Type.LOCAL), Ollama("Ollama", Type.LOCAL), CLIRunners("CLI Runners", Type.LOCAL), ACPRunners("ACP Runners", Type.LOCAL), diff --git a/src/main/java/com/devoxx/genie/model/exo/ExoModelDTO.java b/src/main/java/com/devoxx/genie/model/exo/ExoModelDTO.java new file mode 100644 index 00000000..296152fc --- /dev/null +++ b/src/main/java/com/devoxx/genie/model/exo/ExoModelDTO.java @@ -0,0 +1,11 @@ +package com.devoxx.genie.model.exo; + +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +public class ExoModelDTO { + private String object; + private ExoModelEntryDTO[] data; +} diff --git a/src/main/java/com/devoxx/genie/model/exo/ExoModelEntryDTO.java b/src/main/java/com/devoxx/genie/model/exo/ExoModelEntryDTO.java new file mode 100644 index 00000000..9e47f1c2 --- /dev/null +++ b/src/main/java/com/devoxx/genie/model/exo/ExoModelEntryDTO.java @@ -0,0 +1,23 @@ +package com.devoxx.genie.model.exo; + +import com.google.gson.annotations.SerializedName; +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +public class ExoModelEntryDTO { + private String id; + private String name; + @SerializedName("context_length") + private int contextLength; + @SerializedName("storage_size_megabytes") + private int storageSizeMegabytes; + @SerializedName("supports_tensor") + private boolean supportsTensor; + private String family; + private String quantization; + @SerializedName("base_model") + private String baseModel; + private String[] capabilities; +} diff --git a/src/main/java/com/devoxx/genie/service/DevoxxGenieSettingsService.java b/src/main/java/com/devoxx/genie/service/DevoxxGenieSettingsService.java index 562fda88..699eb7fd 100644 --- a/src/main/java/com/devoxx/genie/service/DevoxxGenieSettingsService.java +++ 
b/src/main/java/com/devoxx/genie/service/DevoxxGenieSettingsService.java @@ -203,6 +203,10 @@ public interface DevoxxGenieSettingsService { void setLlamaCPPUrl(String text); + String getExoModelUrl(); + + void setExoModelUrl(String url); + Boolean getUseFileInEditor(); void setUseFileInEditor(Boolean useFileInEditor); diff --git a/src/main/java/com/devoxx/genie/ui/component/JEditorPaneUtils.java b/src/main/java/com/devoxx/genie/ui/component/JEditorPaneUtils.java index 7780ca5d..a97bf2a5 100644 --- a/src/main/java/com/devoxx/genie/ui/component/JEditorPaneUtils.java +++ b/src/main/java/com/devoxx/genie/ui/component/JEditorPaneUtils.java @@ -14,6 +14,7 @@ import javax.swing.event.HyperlinkListener; import javax.swing.text.html.HTMLEditorKit; import javax.swing.text.html.StyleSheet; +import java.awt.*; import java.util.List; import static com.devoxx.genie.ui.util.DevoxxGenieColorsUtil.PROMPT_BG_COLOR; @@ -56,6 +57,39 @@ private JEditorPaneUtils() { return editorPane; } + /** + * Creates a JEditorPane for conversation content with a custom background color. + */ + public static @NotNull JEditorPane createConversationJEditorPane(@NotNull CharSequence content, + HyperlinkListener hyperlinkListener, + StyleSheet styleSheet, + Color backgroundColor) { + JEditorPane editorPane = new JEditorPane(); + editorPane.addHyperlinkListener(hyperlinkListener != null ? hyperlinkListener : BrowserHyperlinkListener.INSTANCE); + editorPane.setContentType("text/html"); + + HTMLEditorKitBuilder htmlEditorKitBuilder = + new HTMLEditorKitBuilder() + .withWordWrapViewFactory() + .withFontResolver(EditorCssFontResolver.getGlobalInstance()); + + HTMLEditorKit editorKit = htmlEditorKitBuilder.build(); + editorKit.getStyleSheet().addStyleSheet(styleSheet); + editorPane.setEditorKit(editorKit); + + editorPane.setEditable(false); + editorPane.setForeground(JBColor.foreground()); + editorPane.setBackground(backgroundColor != null ? 
backgroundColor : PROMPT_BG_COLOR); + editorPane.setText(colorizeSeparators(content.toString())); + + UIUtil.doNotScrollToCaret(editorPane); + UIUtil.invokeLaterIfNeeded(() -> { + editorPane.revalidate(); + editorPane.setCaretPosition(editorPane.getDocument().getLength()); + }); + return editorPane; + } + @Contract(pure = true) private static String colorizeSeparators(String html) { String body = UIUtil.getHtmlBody(html); diff --git a/src/main/java/com/devoxx/genie/ui/panel/ExoClusterPanel.java b/src/main/java/com/devoxx/genie/ui/panel/ExoClusterPanel.java new file mode 100644 index 00000000..7ccdce98 --- /dev/null +++ b/src/main/java/com/devoxx/genie/ui/panel/ExoClusterPanel.java @@ -0,0 +1,380 @@ +package com.devoxx.genie.ui.panel; + +import com.devoxx.genie.ui.settings.DevoxxGenieStateService; +import com.devoxx.genie.util.HttpClientProvider; +import com.google.gson.Gson; +import com.google.gson.JsonObject; +import com.intellij.openapi.Disposable; +import com.intellij.openapi.application.ApplicationManager; +import com.intellij.ui.JBColor; +import com.intellij.util.concurrency.AppExecutorUtil; +import com.intellij.util.ui.JBUI; +import okhttp3.Request; +import okhttp3.Response; +import org.jetbrains.annotations.NotNull; + +import javax.swing.*; +import java.awt.*; +import java.awt.event.MouseAdapter; +import java.awt.event.MouseEvent; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; + +import static com.devoxx.genie.util.HttpUtil.ensureEndsWithSlash; + +/** + * Collapsible panel that displays Exo cluster node status above the chat window. + * Header bar always shows instance status; node cards can be toggled. 
+ */ +public class ExoClusterPanel extends JPanel implements Disposable { + + private static final Gson gson = new Gson(); + private static final JBColor NODE_BG = new JBColor(new Color(45, 45, 48), new Color(45, 45, 48)); + private static final JBColor NODE_BORDER = new JBColor(new Color(80, 80, 85), new Color(80, 80, 85)); + private static final JBColor MEMORY_BAR_BG = new JBColor(new Color(60, 60, 65), new Color(60, 60, 65)); + private static final JBColor MEMORY_BAR_FG = new JBColor(new Color(78, 154, 241), new Color(78, 154, 241)); + private static final JBColor ACTIVE_DOT = new JBColor(new Color(80, 200, 80), new Color(80, 200, 80)); + private static final JBColor INACTIVE_DOT = new JBColor(new Color(120, 120, 120), new Color(120, 120, 120)); + private static final JBColor CONNECTION_LINE = new JBColor(new Color(200, 180, 50), new Color(200, 180, 50)); + private static final JBColor INSTANCE_COLOR = new JBColor(new Color(200, 180, 50), new Color(200, 180, 50)); + private static final int HEADER_HEIGHT = 22; + + private final List nodes = new ArrayList<>(); + private String activeModelId = null; + private String instanceStatus = null; + private ScheduledFuture refreshTask; + private boolean expanded = true; + + public ExoClusterPanel() { + setLayout(new BorderLayout()); + setOpaque(true); + setBackground(JBColor.background()); + setBorder(JBUI.Borders.compound( + JBUI.Borders.customLineBottom(NODE_BORDER), + JBUI.Borders.empty(0, 8, 2, 8) + )); + setVisible(false); + + addMouseListener(new MouseAdapter() { + @Override + public void mouseClicked(MouseEvent e) { + // Click on the header area toggles expand/collapse + if (e.getY() <= HEADER_HEIGHT) { + expanded = !expanded; + revalidate(); + repaint(); + // Propagate size change to parent + Container parent = getParent(); + if (parent != null) { + parent.revalidate(); + parent.repaint(); + } + } + } + }); + + setCursor(Cursor.getPredefinedCursor(Cursor.HAND_CURSOR)); + } + + public void startPolling() { + 
stopPolling(); + refreshTask = AppExecutorUtil.getAppScheduledExecutorService() + .scheduleWithFixedDelay(this::refreshState, 0, 5, TimeUnit.SECONDS); + setVisible(true); + } + + public void stopPolling() { + if (refreshTask != null) { + refreshTask.cancel(false); + refreshTask = null; + } + setVisible(false); + } + + private void refreshState() { + try { + String url = getExoApiBaseUrl(); + Request request = new Request.Builder().url(url + "state").build(); + try (Response response = HttpClientProvider.getClient().newCall(request).execute()) { + if (response.isSuccessful() && response.body() != null) { + JsonObject state = gson.fromJson(response.body().string(), JsonObject.class); + parseState(state); + ApplicationManager.getApplication().invokeLater(() -> { + revalidate(); + repaint(); + }); + } + } + } catch (Exception ignored) { + } + } + + private void parseState(@NotNull JsonObject state) { + nodes.clear(); + activeModelId = null; + instanceStatus = null; + + JsonObject identities = state.getAsJsonObject("nodeIdentities"); + JsonObject nodeMemory = state.getAsJsonObject("nodeMemory"); + JsonObject topology = state.getAsJsonObject("topology"); + var topoNodes = topology != null ? 
topology.getAsJsonArray("nodes") : null; + + if (identities != null) { + for (var entry : identities.entrySet()) { + String nodeId = entry.getKey(); + JsonObject info = entry.getValue().getAsJsonObject(); + + boolean inTopology = false; + if (topoNodes != null) { + for (var tn : topoNodes) { + if (tn.getAsString().equals(nodeId)) { + inTopology = true; + break; + } + } + } + if (!inTopology) continue; + + NodeInfo node = new NodeInfo(); + node.name = info.get("friendlyName").getAsString(); + node.chip = info.get("chipId").getAsString(); + + if (nodeMemory != null && nodeMemory.has(nodeId)) { + JsonObject mem = nodeMemory.getAsJsonObject(nodeId); + node.ramTotal = mem.getAsJsonObject("ramTotal").get("inBytes").getAsLong(); + node.ramAvailable = mem.getAsJsonObject("ramAvailable").get("inBytes").getAsLong(); + } + + JsonObject nodeSystem = state.getAsJsonObject("nodeSystem"); + if (nodeSystem != null && nodeSystem.has(nodeId)) { + JsonObject sys = nodeSystem.getAsJsonObject(nodeId); + node.gpuUsage = sys.has("gpuUsage") ? sys.get("gpuUsage").getAsDouble() : 0; + node.temperature = sys.has("temp") ? 
sys.get("temp").getAsDouble() : 0; + } + + nodes.add(node); + } + } + + JsonObject instances = state.getAsJsonObject("instances"); + if (instances != null && !instances.entrySet().isEmpty()) { + for (var entry : instances.entrySet()) { + JsonObject inst = entry.getValue().getAsJsonObject(); + for (var inner : inst.entrySet()) { + JsonObject details = inner.getValue().getAsJsonObject(); + JsonObject shards = details.getAsJsonObject("shardAssignments"); + if (shards != null && shards.has("modelId")) { + activeModelId = shards.get("modelId").getAsString(); + } + } + } + + JsonObject runners = state.getAsJsonObject("runners"); + if (runners != null && !runners.entrySet().isEmpty()) { + boolean allReady = runners.entrySet().stream() + .allMatch(e -> { + JsonObject r = e.getValue().getAsJsonObject(); + return r.has("RunnerReady") || r.has("RunnerRunning"); + }); + boolean anyWarming = runners.entrySet().stream() + .anyMatch(e -> { + JsonObject r = e.getValue().getAsJsonObject(); + return r.has("RunnerWarmingUp") || r.has("RunnerLoading"); + }); + if (allReady) { + instanceStatus = "Ready"; + } else if (anyWarming) { + instanceStatus = "Warming up"; + } else { + instanceStatus = "Loading"; + } + } + } + } + + @Override + protected void paintComponent(Graphics g) { + super.paintComponent(g); + if (nodes.isEmpty()) return; + + Graphics2D g2 = (Graphics2D) g.create(); + g2.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON); + g2.setRenderingHint(RenderingHints.KEY_TEXT_ANTIALIASING, RenderingHints.VALUE_TEXT_ANTIALIAS_ON); + + // --- Header bar (always visible) --- + drawHeader(g2); + + // --- Node cards (only when expanded) --- + if (expanded) { + int panelWidth = getWidth() - 16; + int nodeWidth = Math.min(180, (panelWidth - (nodes.size() - 1) * 30) / Math.max(nodes.size(), 1)); + int nodeHeight = 64; + int startX = (getWidth() - (nodes.size() * nodeWidth + (nodes.size() - 1) * 30)) / 2; + int startY = HEADER_HEIGHT + 4; + + // Connection 
lines + if (nodes.size() > 1) { + g2.setColor(CONNECTION_LINE); + g2.setStroke(new BasicStroke(1.5f, BasicStroke.CAP_ROUND, BasicStroke.JOIN_ROUND, + 0, new float[]{4, 4}, 0)); + for (int i = 0; i < nodes.size() - 1; i++) { + int x1 = startX + i * (nodeWidth + 30) + nodeWidth; + int x2 = startX + (i + 1) * (nodeWidth + 30); + int y = startY + nodeHeight / 2; + g2.drawLine(x1, y, x2, y); + } + g2.setStroke(new BasicStroke(1)); + } + + // Node cards + for (int i = 0; i < nodes.size(); i++) { + NodeInfo node = nodes.get(i); + int x = startX + i * (nodeWidth + 30); + drawNodeCard(g2, node, x, startY, nodeWidth, nodeHeight); + } + } + + g2.dispose(); + } + + private void drawHeader(Graphics2D g2) { + int y = 4; + + // Toggle arrow + g2.setFont(g2.getFont().deriveFont(Font.PLAIN, 10f)); + g2.setColor(JBColor.GRAY); + String arrow = expanded ? "\u25BC" : "\u25B6"; // ▼ or ▶ + g2.drawString(arrow, 4, y + 12); + + // "Exo Cluster" label + node count + g2.setFont(g2.getFont().deriveFont(Font.BOLD, 11f)); + g2.setColor(JBColor.foreground()); + g2.drawString("Exo Cluster", 18, y + 12); + + g2.setFont(g2.getFont().deriveFont(Font.PLAIN, 10f)); + g2.setColor(JBColor.GRAY); + g2.drawString(nodes.size() + " node" + (nodes.size() != 1 ? "s" : ""), 90, y + 12); + + // Instance status on the right + if (activeModelId != null) { + String modelName = activeModelId.contains("/") + ? activeModelId.substring(activeModelId.indexOf('/') + 1) + : activeModelId; + String label = modelName + (instanceStatus != null ? " — " + instanceStatus : ""); + + g2.setFont(g2.getFont().deriveFont(Font.PLAIN, 10f)); + FontMetrics fm = g2.getFontMetrics(); + int labelWidth = fm.stringWidth(label) + 16; + int labelX = getWidth() - labelWidth - 12; + + // Status dot + g2.setColor("Ready".equals(instanceStatus) ? 
ACTIVE_DOT : INSTANCE_COLOR); + g2.fillOval(labelX, y + 5, 7, 7); + + // Label + g2.setColor(INSTANCE_COLOR); + g2.drawString(label, labelX + 11, y + 12); + } + } + + private void drawNodeCard(Graphics2D g2, NodeInfo node, int x, int y, int w, int h) { + g2.setColor(NODE_BG); + g2.fillRoundRect(x, y, w, h, 8, 8); + + g2.setColor(NODE_BORDER); + g2.drawRoundRect(x, y, w, h, 8, 8); + + int padding = 6; + int textX = x + padding; + + g2.setFont(g2.getFont().deriveFont(Font.BOLD, 11f)); + g2.setColor(JBColor.foreground()); + String name = truncate(node.name, g2.getFontMetrics(), w - padding * 2); + g2.drawString(name, textX, y + 15); + + g2.setFont(g2.getFont().deriveFont(Font.PLAIN, 9f)); + g2.setColor(JBColor.GRAY); + String chip = truncate(node.chip, g2.getFontMetrics(), w - padding * 2); + g2.drawString(chip, textX, y + 27); + + if (node.ramTotal > 0) { + int barX = textX; + int barY = y + 33; + int barW = w - padding * 2; + int barH = 8; + double usedRatio = 1.0 - (double) node.ramAvailable / node.ramTotal; + + g2.setColor(MEMORY_BAR_BG); + g2.fillRoundRect(barX, barY, barW, barH, 4, 4); + + g2.setColor(MEMORY_BAR_FG); + g2.fillRoundRect(barX, barY, (int) (barW * usedRatio), barH, 4, 4); + + long usedGB = (node.ramTotal - node.ramAvailable) / (1024L * 1024 * 1024); + long totalGB = node.ramTotal / (1024L * 1024 * 1024); + String memText = usedGB + "/" + totalGB + "GB"; + g2.setFont(g2.getFont().deriveFont(Font.PLAIN, 9f)); + g2.setColor(JBColor.GRAY); + g2.drawString(memText, textX, y + 54); + + if (node.gpuUsage > 0 || node.temperature > 0) { + String stats = String.format("%.0f%% %.0f\u00B0C", node.gpuUsage * 100, node.temperature); + int statsWidth = g2.getFontMetrics().stringWidth(stats); + g2.drawString(stats, x + w - padding - statsWidth, y + 54); + } + } + + g2.setColor(node.ramTotal > 0 ? 
ACTIVE_DOT : INACTIVE_DOT); + g2.fillOval(x + w - 12, y + 4, 6, 6); + } + + private String truncate(String text, FontMetrics fm, int maxWidth) { + if (fm.stringWidth(text) <= maxWidth) return text; + for (int i = text.length() - 1; i > 0; i--) { + String truncated = text.substring(0, i) + ".."; + if (fm.stringWidth(truncated) <= maxWidth) return truncated; + } + return ".."; + } + + @Override + public Dimension getPreferredSize() { + if (nodes.isEmpty()) return new Dimension(0, 0); + int height = HEADER_HEIGHT; + if (expanded) { + height += 72; // node cards + } + return new Dimension(super.getPreferredSize().width, height); + } + + @Override + public Dimension getMinimumSize() { + return getPreferredSize(); + } + + private static String getExoApiBaseUrl() { + String url = DevoxxGenieStateService.getInstance().getExoModelUrl(); + if (url.endsWith("/v1/")) { + url = url.substring(0, url.length() - 3); + } else if (url.endsWith("/v1")) { + url = url.substring(0, url.length() - 2); + } + return ensureEndsWithSlash(url); + } + + @Override + public void dispose() { + stopPolling(); + } + + private static class NodeInfo { + String name = ""; + String chip = ""; + long ramTotal = 0; + long ramAvailable = 0; + double gpuUsage = 0; + double temperature = 0; + } +} diff --git a/src/main/java/com/devoxx/genie/ui/panel/LlmProviderPanel.java b/src/main/java/com/devoxx/genie/ui/panel/LlmProviderPanel.java index 8a91c858..ea5474f6 100644 --- a/src/main/java/com/devoxx/genie/ui/panel/LlmProviderPanel.java +++ b/src/main/java/com/devoxx/genie/ui/panel/LlmProviderPanel.java @@ -2,6 +2,7 @@ import com.devoxx.genie.chatmodel.ChatModelFactory; import com.devoxx.genie.chatmodel.ChatModelFactoryProvider; +import com.devoxx.genie.chatmodel.local.exo.ExoChatModelFactory; import com.devoxx.genie.model.Constant; import com.devoxx.genie.model.LanguageModel; import com.devoxx.genie.model.enumarations.ModelProvider; @@ -108,6 +109,7 @@ public LlmProviderPanel(@NotNull Project project, String 
tabId) { restoreLastSelectedLanguageModel(); modelProviderComboBox.addActionListener(this::handleModelProviderSelectionChange); + modelNameComboBox.addActionListener(this::handleModelNameSelectionChange); add(toolPanel, BorderLayout.CENTER); isInitializationComplete = true; @@ -141,6 +143,7 @@ public void addModelProvidersToComboBox() { case GPT4All -> stateService.isGpt4AllEnabled(); case Jan -> stateService.isJanEnabled(); case LLaMA -> stateService.isLlamaCPPEnabled(); + case Exo -> stateService.isExoEnabled(); case CustomOpenAI -> stateService.isCustomOpenAIUrlEnabled(); case OpenAI -> stateService.isOpenAIEnabled(); case Mistral -> stateService.isMistralEnabled(); @@ -356,6 +359,21 @@ public void llmSettingsChanged() { ); } + /** + * When a model is selected for the Exo provider, start preparing the instance in the background. + */ + private void handleModelNameSelectionChange(@NotNull ActionEvent e) { + if (!e.getActionCommand().equals(Constant.COMBO_BOX_CHANGED) || !isInitializationComplete || isUpdatingModelNames) return; + + ModelProvider provider = (ModelProvider) modelProviderComboBox.getSelectedItem(); + if (provider != ModelProvider.Exo) return; + + LanguageModel selectedModel = (LanguageModel) modelNameComboBox.getSelectedItem(); + if (selectedModel != null) { + ExoChatModelFactory.prepareInstanceAsync(selectedModel.getModelName(), project); + } + } + /** * Process the model provider selection change. * Set the model provider and update the model names. 
diff --git a/src/main/java/com/devoxx/genie/ui/settings/DevoxxGenieStateService.java b/src/main/java/com/devoxx/genie/ui/settings/DevoxxGenieStateService.java index 4d5a7b64..dcbc568c 100644 --- a/src/main/java/com/devoxx/genie/ui/settings/DevoxxGenieStateService.java +++ b/src/main/java/com/devoxx/genie/ui/settings/DevoxxGenieStateService.java @@ -87,6 +87,7 @@ public static DevoxxGenieStateService getInstance() { private String gpt4allModelUrl = GPT4ALL_MODEL_URL; private String janModelUrl = JAN_MODEL_URL; private String llamaCPPUrl = LLAMA_CPP_MODEL_URL; + private String exoModelUrl = EXO_MODEL_URL; // Local custom OpenAI-compliant LLM fields private String customOpenAIUrl = ""; @@ -99,6 +100,7 @@ public static DevoxxGenieStateService getInstance() { private boolean isGpt4AllEnabled = true; private boolean isJanEnabled = true; private boolean isLlamaCPPEnabled = true; + private boolean isExoEnabled = false; // Local custom OpenAI-compliant LLM fields private boolean isCustomOpenAIUrlEnabled = false; @@ -537,6 +539,7 @@ public Integer getSubAgentParallelism() { case "gpt4allModelUrl" -> getGpt4allModelUrl(); case "lmStudioModelUrl" -> getLmstudioModelUrl(); case "ollamaModelUrl" -> getOllamaModelUrl(); + case "exoModelUrl" -> getExoModelUrl(); default -> null; }; } diff --git a/src/main/java/com/devoxx/genie/ui/settings/agent/AgentSettingsComponent.java b/src/main/java/com/devoxx/genie/ui/settings/agent/AgentSettingsComponent.java index a3bf505b..27df3f89 100644 --- a/src/main/java/com/devoxx/genie/ui/settings/agent/AgentSettingsComponent.java +++ b/src/main/java/com/devoxx/genie/ui/settings/agent/AgentSettingsComponent.java @@ -284,6 +284,7 @@ private void populateProviderComboBox() { case GPT4All -> state.isGpt4AllEnabled(); case Jan -> state.isJanEnabled(); case LLaMA -> state.isLlamaCPPEnabled(); + case Exo -> state.isExoEnabled(); case CustomOpenAI -> state.isCustomOpenAIUrlEnabled(); case OpenAI -> state.isOpenAIEnabled(); case Mistral -> 
state.isMistralEnabled(); @@ -512,6 +513,7 @@ private void populateRowProviderComboBox() { case GPT4All -> state.isGpt4AllEnabled(); case Jan -> state.isJanEnabled(); case LLaMA -> state.isLlamaCPPEnabled(); + case Exo -> state.isExoEnabled(); case CustomOpenAI -> state.isCustomOpenAIUrlEnabled(); case OpenAI -> state.isOpenAIEnabled(); case Mistral -> state.isMistralEnabled(); diff --git a/src/main/java/com/devoxx/genie/ui/settings/llm/LLMProvidersComponent.java b/src/main/java/com/devoxx/genie/ui/settings/llm/LLMProvidersComponent.java index 3b975ca0..fccdd732 100644 --- a/src/main/java/com/devoxx/genie/ui/settings/llm/LLMProvidersComponent.java +++ b/src/main/java/com/devoxx/genie/ui/settings/llm/LLMProvidersComponent.java @@ -99,6 +99,10 @@ public class LLMProvidersComponent extends AbstractSettingsComponent { @Getter private final JCheckBox llamaCPPEnabledCheckBox = new JCheckBox("", stateService.isLlamaCPPEnabled()); @Getter + private final JTextField exoModelUrlField = new JTextField(stateService.getExoModelUrl()); + @Getter + private final JCheckBox exoEnabledCheckBox = new JCheckBox("", stateService.isExoEnabled()); + @Getter private final JCheckBox customOpenAIUrlEnabledCheckBox = new JCheckBox("", stateService.isCustomOpenAIUrlEnabled()); @Getter private final JCheckBox customOpenAIForceHttp11CheckBox = new JCheckBox("", stateService.isCustomOpenAIForceHttp11()); @@ -186,6 +190,9 @@ public JPanel createPanel() { createTextWithLinkButton(janModelUrlField, "https://jan.ai/download")); addProviderSettingRow(panel, gbc, "LLaMA.c++ URL", llamaCPPEnabledCheckBox, createTextWithLinkButton(llamaCPPModelUrlField, "https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md")); + addProviderSettingRow(panel, gbc, "Exo URL", exoEnabledCheckBox, + createTextWithLinkButton(exoModelUrlField, "https://genie.devoxx.com/docs/llm-providers/exo")); + addHintText(panel, gbc, "Distributed AI cluster — auto-creates model instances across connected devices"); 
addProviderSettingRow(panel, gbc, "Custom OpenAI URL", customOpenAIUrlEnabledCheckBox, customOpenAIUrlField); addProviderSettingRow(panel, gbc, "Custom OpenAI Model", customOpenAIModelNameEnabledCheckBox, customOpenAIModelNameField); addProviderSettingRow(panel, gbc, "Custom OpenAI API Key", enableCustomOpenAIApiKeyCheckBox, customOpenAIApiKeyField); @@ -251,6 +258,7 @@ public void addListeners() { gpt4AllEnabledCheckBox.addItemListener(e -> updateUrlFieldState(gpt4AllEnabledCheckBox, gpt4AllModelUrlField)); janEnabledCheckBox.addItemListener(e -> updateUrlFieldState(janEnabledCheckBox, janModelUrlField)); llamaCPPEnabledCheckBox.addItemListener(e -> updateUrlFieldState(llamaCPPEnabledCheckBox, llamaCPPModelUrlField)); + exoEnabledCheckBox.addItemListener(e -> updateUrlFieldState(exoEnabledCheckBox, exoModelUrlField)); customOpenAIUrlEnabledCheckBox.addItemListener(e -> updateUrlFieldState(customOpenAIUrlEnabledCheckBox, customOpenAIUrlField)); customOpenAIModelNameEnabledCheckBox.addItemListener(e -> updateUrlFieldState(customOpenAIModelNameEnabledCheckBox, customOpenAIModelNameField)); diff --git a/src/main/java/com/devoxx/genie/ui/settings/llm/LLMProvidersConfigurable.java b/src/main/java/com/devoxx/genie/ui/settings/llm/LLMProvidersConfigurable.java index 3da57389..fb2ba7ed 100644 --- a/src/main/java/com/devoxx/genie/ui/settings/llm/LLMProvidersConfigurable.java +++ b/src/main/java/com/devoxx/genie/ui/settings/llm/LLMProvidersConfigurable.java @@ -105,6 +105,8 @@ public boolean isModified() { isModified |= stateService.isGpt4AllEnabled() != llmSettingsComponent.getGpt4AllEnabledCheckBox().isSelected(); isModified |= stateService.isJanEnabled() != llmSettingsComponent.getJanEnabledCheckBox().isSelected(); isModified |= stateService.isLlamaCPPEnabled() != llmSettingsComponent.getLlamaCPPEnabledCheckBox().isSelected(); + isModified |= stateService.isExoEnabled() != llmSettingsComponent.getExoEnabledCheckBox().isSelected(); + isModified |= 
isFieldModified(llmSettingsComponent.getExoModelUrlField(), stateService.getExoModelUrl()); isModified |= stateService.isCustomOpenAIUrlEnabled() != llmSettingsComponent.getCustomOpenAIUrlEnabledCheckBox().isSelected(); isModified |= stateService.isCustomOpenAIModelNameEnabled() != llmSettingsComponent.getCustomOpenAIModelNameEnabledCheckBox().isSelected(); @@ -186,6 +188,8 @@ public void apply() { settings.setGpt4AllEnabled(llmSettingsComponent.getGpt4AllEnabledCheckBox().isSelected()); settings.setJanEnabled(llmSettingsComponent.getJanEnabledCheckBox().isSelected()); settings.setLlamaCPPEnabled(llmSettingsComponent.getLlamaCPPEnabledCheckBox().isSelected()); + settings.setExoEnabled(llmSettingsComponent.getExoEnabledCheckBox().isSelected()); + settings.setExoModelUrl(llmSettingsComponent.getExoModelUrlField().getText()); settings.setCustomOpenAIUrlEnabled(llmSettingsComponent.getCustomOpenAIUrlEnabledCheckBox().isSelected()); settings.setCustomOpenAIModelNameEnabled(llmSettingsComponent.getCustomOpenAIModelNameEnabledCheckBox().isSelected()); @@ -294,6 +298,8 @@ public void reset() { llmSettingsComponent.getGpt4AllEnabledCheckBox().setSelected(settings.isGpt4AllEnabled()); llmSettingsComponent.getJanEnabledCheckBox().setSelected(settings.isJanEnabled()); llmSettingsComponent.getLlamaCPPEnabledCheckBox().setSelected(settings.isLlamaCPPEnabled()); + llmSettingsComponent.getExoEnabledCheckBox().setSelected(settings.isExoEnabled()); + llmSettingsComponent.getExoModelUrlField().setText(settings.getExoModelUrl()); llmSettingsComponent.getCustomOpenAIUrlEnabledCheckBox().setSelected(settings.isCustomOpenAIUrlEnabled()); llmSettingsComponent.getCustomOpenAIModelNameEnabledCheckBox().setSelected(settings.isCustomOpenAIModelNameEnabled()); diff --git a/src/main/java/com/devoxx/genie/ui/window/DevoxxGenieToolWindowContent.java b/src/main/java/com/devoxx/genie/ui/window/DevoxxGenieToolWindowContent.java index 4f91aa68..1cc5ac40 100644 --- 
a/src/main/java/com/devoxx/genie/ui/window/DevoxxGenieToolWindowContent.java +++ b/src/main/java/com/devoxx/genie/ui/window/DevoxxGenieToolWindowContent.java @@ -9,6 +9,7 @@ import com.devoxx.genie.ui.component.border.AnimatedGlowingBorder; import com.devoxx.genie.ui.listener.GlowingListener; import com.devoxx.genie.ui.listener.SettingsChangeListener; +import com.devoxx.genie.ui.panel.ExoClusterPanel; import com.devoxx.genie.ui.panel.LlmProviderPanel; import com.devoxx.genie.ui.panel.PromptOutputPanel; import com.devoxx.genie.ui.panel.SubmitPanel; @@ -72,6 +73,7 @@ public class DevoxxGenieToolWindowContent implements SettingsChangeListener, Glo private SubmitPanel submitPanel; @Getter private PromptOutputPanel promptOutputPanel; + private ExoClusterPanel exoClusterPanel; private boolean isInitializationComplete = false; /** @@ -123,6 +125,7 @@ private void initializeComponents() { llmProviderPanel = new LlmProviderPanel(project, tabId); promptOutputPanel = new PromptOutputPanel(project, resourceBundle, tabId); submitPanel = new SubmitPanel(this); + exoClusterPanel = new ExoClusterPanel(); ExternalPromptService.getInstance(project).setPromptInputArea(submitPanel.getPromptInputArea()); } @@ -135,6 +138,21 @@ private void setupLayout() { private void setupListeners() { llmProviderPanel.getModelNameComboBox().addActionListener(this::processModelNameSelection); + llmProviderPanel.getModelProviderComboBox().addActionListener(this::updateExoClusterVisibility); + + // Show cluster panel if Exo is already selected on startup + updateExoClusterVisibility(null); + } + + private void updateExoClusterVisibility(ActionEvent e) { + ModelProvider provider = (ModelProvider) llmProviderPanel.getModelProviderComboBox().getSelectedItem(); + if (provider == ModelProvider.Exo) { + exoClusterPanel.startPolling(); + } else { + exoClusterPanel.stopPolling(); + } + contentPanel.revalidate(); + contentPanel.repaint(); } @Override @@ -155,6 +173,7 @@ public void stopGlowing() { private 
@NotNull JPanel createTopPanel() { JPanel topPanel = new JPanel(new BorderLayout()); topPanel.add(llmProviderPanel, BorderLayout.NORTH); + topPanel.add(exoClusterPanel, BorderLayout.CENTER); return topPanel; } diff --git a/src/main/resources/META-INF/plugin.xml b/src/main/resources/META-INF/plugin.xml index db1805af..2928da13 100644 --- a/src/main/resources/META-INF/plugin.xml +++ b/src/main/resources/META-INF/plugin.xml @@ -543,6 +543,7 @@ + diff --git a/src/test/java/com/devoxx/genie/chatmodel/local/exo/ExoChatModelFactoryTest.java b/src/test/java/com/devoxx/genie/chatmodel/local/exo/ExoChatModelFactoryTest.java new file mode 100644 index 00000000..370aae9c --- /dev/null +++ b/src/test/java/com/devoxx/genie/chatmodel/local/exo/ExoChatModelFactoryTest.java @@ -0,0 +1,185 @@ +package com.devoxx.genie.chatmodel.local.exo; + +import com.devoxx.genie.model.CustomChatModel; +import com.devoxx.genie.model.LanguageModel; +import com.devoxx.genie.model.enumarations.ModelProvider; +import com.devoxx.genie.model.exo.ExoModelEntryDTO; +import com.devoxx.genie.service.mcp.MCPService; +import com.devoxx.genie.ui.settings.DevoxxGenieStateService; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.StreamingChatModel; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; + +import java.io.IOException; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.*; + +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +class ExoChatModelFactoryTest { + + private MockedStatic<DevoxxGenieStateService> mockedStateService; + private MockedStatic<MCPService> mockedMCPService; + private MockedStatic<ExoModelService> 
mockedExoModelService; + private DevoxxGenieStateService mockState; + private ExoModelService mockExoService; + + @BeforeEach + void setUp() { + mockState = mock(DevoxxGenieStateService.class); + when(mockState.getExoModelUrl()).thenReturn("http://localhost:52415/"); + when(mockState.getAgentModeEnabled()).thenReturn(false); + + mockedStateService = Mockito.mockStatic(DevoxxGenieStateService.class); + mockedStateService.when(DevoxxGenieStateService::getInstance).thenReturn(mockState); + + mockedMCPService = Mockito.mockStatic(MCPService.class); + mockedMCPService.when(MCPService::isMCPEnabled).thenReturn(false); + + mockExoService = mock(ExoModelService.class); + mockedExoModelService = Mockito.mockStatic(ExoModelService.class); + mockedExoModelService.when(ExoModelService::getInstance).thenReturn(mockExoService); + } + + @AfterEach + void tearDown() { + if (mockedStateService != null) mockedStateService.close(); + if (mockedMCPService != null) mockedMCPService.close(); + if (mockedExoModelService != null) mockedExoModelService.close(); + } + + @Test + void createChatModelShouldReturnNonNull() throws IOException { + // ensureInstance should not throw + doNothing().when(mockExoService).ensureInstance(anyString()); + + ExoChatModelFactory factory = new ExoChatModelFactory(); + + CustomChatModel customChatModel = new CustomChatModel(); + customChatModel.setModelName("mlx-community/MiniMax-M2.5-6bit"); + customChatModel.setTemperature(0.7); + customChatModel.setTopP(0.9); + customChatModel.setMaxTokens(256); + customChatModel.setMaxRetries(3); + customChatModel.setTimeout(30); + + ChatModel result = factory.createChatModel(customChatModel); + assertThat(result).isNotNull(); + } + + @Test + void createStreamingChatModelShouldReturnNonNull() throws IOException { + doNothing().when(mockExoService).ensureInstance(anyString()); + + ExoChatModelFactory factory = new ExoChatModelFactory(); + + CustomChatModel customChatModel = new CustomChatModel(); + 
customChatModel.setModelName("mlx-community/MiniMax-M2.5-6bit"); + customChatModel.setTemperature(0.7); + customChatModel.setTopP(0.9); + customChatModel.setTimeout(30); + + StreamingChatModel result = factory.createStreamingChatModel(customChatModel); + assertThat(result).isNotNull(); + } + + @Test + void createChatModelShouldCallEnsureInstance() throws IOException { + doNothing().when(mockExoService).ensureInstance(anyString()); + + ExoChatModelFactory factory = new ExoChatModelFactory(); + + CustomChatModel customChatModel = new CustomChatModel(); + customChatModel.setModelName("mlx-community/MiniMax-M2.5-6bit"); + + factory.createChatModel(customChatModel); + + verify(mockExoService).ensureInstance("mlx-community/MiniMax-M2.5-6bit"); + } + + @Test + void createChatModelShouldNotThrowWhenEnsureInstanceFails() throws IOException { + doThrow(new IOException("No valid placement found")).when(mockExoService).ensureInstance(anyString()); + + ExoChatModelFactory factory = new ExoChatModelFactory(); + + CustomChatModel customChatModel = new CustomChatModel(); + customChatModel.setModelName("mlx-community/MiniMax-M2.5-6bit"); + customChatModel.setTemperature(0.7); + customChatModel.setTopP(0.9); + customChatModel.setMaxTokens(256); + customChatModel.setMaxRetries(3); + customChatModel.setTimeout(30); + + // Should not throw — ensureInstance failure is handled gracefully + ChatModel result = factory.createChatModel(customChatModel); + assertThat(result).isNotNull(); + } + + @Test + void buildLanguageModelShouldMapFieldsCorrectly() throws IOException { + ExoChatModelFactory factory = new ExoChatModelFactory(); + + ExoModelEntryDTO dto = new ExoModelEntryDTO(); + dto.setId("mlx-community/Llama-3.2-1B-Instruct-4bit"); + dto.setName("Llama-3.2-1B-Instruct-4bit"); + dto.setContextLength(131072); + dto.setStorageSizeMegabytes(696); + + LanguageModel model = factory.buildLanguageModel(dto); + + assertThat(model.getProvider()).isEqualTo(ModelProvider.Exo); + 
assertThat(model.getModelName()).isEqualTo("mlx-community/Llama-3.2-1B-Instruct-4bit"); + assertThat(model.getDisplayName()).isEqualTo("Llama-3.2-1B-Instruct-4bit"); + assertThat(model.getInputMaxTokens()).isEqualTo(131072); + assertThat(model.getInputCost()).isEqualTo(0); + assertThat(model.getOutputCost()).isEqualTo(0); + assertThat(model.isApiKeyUsed()).isFalse(); + } + + @Test + void buildLanguageModelShouldUseDefaultContextWhenZero() throws IOException { + ExoChatModelFactory factory = new ExoChatModelFactory(); + + ExoModelEntryDTO dto = new ExoModelEntryDTO(); + dto.setId("mlx-community/SomeModel-4bit"); + dto.setName("SomeModel-4bit"); + dto.setContextLength(0); // Exo reports 0 for some models + + LanguageModel model = factory.buildLanguageModel(dto); + + assertThat(model.getInputMaxTokens()).isEqualTo(4096); + } + + @Test + void buildLanguageModelShouldFallbackToIdWhenNameIsNull() throws IOException { + ExoChatModelFactory factory = new ExoChatModelFactory(); + + ExoModelEntryDTO dto = new ExoModelEntryDTO(); + dto.setId("mlx-community/SomeModel-4bit"); + dto.setName(null); + dto.setContextLength(8192); + + LanguageModel model = factory.buildLanguageModel(dto); + + assertThat(model.getDisplayName()).isEqualTo("mlx-community/SomeModel-4bit"); + } + + @Test + void getModelUrlShouldReturnConfiguredUrl() { + ExoChatModelFactory factory = new ExoChatModelFactory(); + + // getModelUrl() is protected, but we test indirectly through createChatModel + // which uses it via createOpenAiChatModel. The URL is configured in mockState. 
+ assertThat(mockState.getExoModelUrl()).isEqualTo("http://localhost:52415/"); + } +} diff --git a/src/test/java/com/devoxx/genie/chatmodel/local/exo/ExoModelServiceTest.java b/src/test/java/com/devoxx/genie/chatmodel/local/exo/ExoModelServiceTest.java new file mode 100644 index 00000000..24f62a3e --- /dev/null +++ b/src/test/java/com/devoxx/genie/chatmodel/local/exo/ExoModelServiceTest.java @@ -0,0 +1,211 @@ +package com.devoxx.genie.chatmodel.local.exo; + +import com.devoxx.genie.model.exo.ExoModelEntryDTO; +import com.devoxx.genie.ui.settings.DevoxxGenieStateService; +import com.devoxx.genie.util.HttpClientProvider; +import com.google.gson.Gson; +import okhttp3.*; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; + +import java.io.IOException; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +class ExoModelServiceTest { + + private MockedStatic<DevoxxGenieStateService> mockedStateService; + private MockedStatic<HttpClientProvider> mockedHttpClient; + private DevoxxGenieStateService mockState; + private OkHttpClient mockClient; + + private static final String MODELS_RESPONSE = """ + { + "object": "list", + "data": [ + { + "id": "mlx-community/Llama-3.2-1B-Instruct-4bit", + "name": "Llama-3.2-1B-Instruct-4bit", + "context_length": 131072, + "storage_size_megabytes": 696, + "supports_tensor": true, + "family": "llama", + "quantization": "4bit", + "base_model": "Llama 3.2 1B Instruct", + 
"capabilities": ["text"] + }, + { + "id": "mlx-community/MiniMax-M2.5-6bit", + "name": "MiniMax-M2.5-6bit", + "context_length": 196608, + "storage_size_megabytes": 173000, + "supports_tensor": true, + "family": "minimax", + "quantization": "6bit", + "base_model": "MiniMax M2.5", + "capabilities": ["text", "thinking"] + } + ] + } + """; + + private static final String STATE_NO_INSTANCES = """ + { + "instances": {}, + "runners": {} + } + """; + + private static final String STATE_WITH_INSTANCE = """ + { + "instances": { + "abc-123": { + "MlxRingInstance": { + "instanceId": "abc-123", + "shardAssignments": { + "modelId": "mlx-community/MiniMax-M2.5-6bit" + } + } + } + }, + "runners": { + "runner-1": {"RunnerReady": {}} + } + } + """; + + private static final String PREVIEW_RESPONSE = """ + { + "previews": [ + { + "model_id": "mlx-community/MiniMax-M2.5-6bit", + "sharding": "Pipeline", + "instance_meta": "MlxRing", + "instance": {"MlxRingInstance": {"instanceId": "new-123"}}, + "memory_delta_by_node": {}, + "error": null + } + ] + } + """; + + private static final String PREVIEW_NO_VALID = """ + { + "previews": [ + { + "model_id": "mlx-community/MiniMax-M2.5-6bit", + "sharding": "Pipeline", + "instance_meta": "MlxRing", + "instance": null, + "memory_delta_by_node": null, + "error": "No cycles found with sufficient memory" + } + ] + } + """; + + @BeforeEach + void setUp() { + mockState = mock(DevoxxGenieStateService.class); + when(mockState.getExoModelUrl()).thenReturn("http://localhost:52415/"); + + mockedStateService = Mockito.mockStatic(DevoxxGenieStateService.class); + mockedStateService.when(DevoxxGenieStateService::getInstance).thenReturn(mockState); + + mockClient = mock(OkHttpClient.class); + mockedHttpClient = Mockito.mockStatic(HttpClientProvider.class); + mockedHttpClient.when(HttpClientProvider::getClient).thenReturn(mockClient); + } + + @AfterEach + void tearDown() { + if (mockedStateService != null) mockedStateService.close(); + if (mockedHttpClient != 
null) mockedHttpClient.close(); + } + + private Call mockCall(String responseBody, int code) throws IOException { + Call mockCall = mock(Call.class); + Response response = new Response.Builder() + .request(new Request.Builder().url("http://localhost:52415/models").build()) + .protocol(Protocol.HTTP_1_1) + .code(code) + .message("OK") + .body(ResponseBody.create(responseBody, MediaType.parse("application/json"))) + .build(); + when(mockCall.execute()).thenReturn(response); + return mockCall; + } + + @Test + void getModelsShouldParseResponseCorrectly() throws IOException { + Call mockCall = mockCall(MODELS_RESPONSE, 200); + when(mockClient.newCall(any(Request.class))).thenReturn(mockCall); + + ExoModelService service = new ExoModelService(); + ExoModelEntryDTO[] models = service.getModels(); + + assertThat(models).hasSize(2); + assertThat(models[0].getId()).isEqualTo("mlx-community/Llama-3.2-1B-Instruct-4bit"); + assertThat(models[0].getName()).isEqualTo("Llama-3.2-1B-Instruct-4bit"); + assertThat(models[0].getContextLength()).isEqualTo(131072); + assertThat(models[0].getFamily()).isEqualTo("llama"); + assertThat(models[1].getId()).isEqualTo("mlx-community/MiniMax-M2.5-6bit"); + assertThat(models[1].getContextLength()).isEqualTo(196608); + } + + @Test + void getModelsShouldThrowOnHttpError() throws IOException { + Call mockCall = mockCall("error", 500); + when(mockClient.newCall(any(Request.class))).thenReturn(mockCall); + + ExoModelService service = new ExoModelService(); + + assertThatThrownBy(service::getModels) + .isInstanceOf(IOException.class); + } + + @Test + void ensureInstanceShouldSkipWhenInstanceAlreadyExists() throws IOException { + // State check returns existing instance for this model + Call stateCall = mockCall(STATE_WITH_INSTANCE, 200); + when(mockClient.newCall(any(Request.class))).thenReturn(stateCall); + + ExoModelService service = new ExoModelService(); + // Should return without calling preview/create + 
service.ensureInstance("mlx-community/MiniMax-M2.5-6bit"); + + // Only 1 call should be made (state check), not 3 (state + preview + create) + Mockito.verify(mockClient, Mockito.times(1)).newCall(any(Request.class)); + } + + @Test + void ensureInstanceShouldThrowWhenNoValidPlacement() throws IOException { + // First call: state check returns no instances + Call stateCall = mockCall(STATE_NO_INSTANCES, 200); + // Second call: preview returns no valid placements + Call previewCall = mockCall(PREVIEW_NO_VALID, 200); + + when(mockClient.newCall(any(Request.class))) + .thenReturn(stateCall) + .thenReturn(previewCall); + + ExoModelService service = new ExoModelService(); + + assertThatThrownBy(() -> service.ensureInstance("mlx-community/MiniMax-M2.5-6bit")) + .isInstanceOf(IOException.class) + .hasMessageContaining("No valid placement found"); + } +} diff --git a/src/test/java/com/devoxx/genie/model/exo/ExoModelDTOTest.java b/src/test/java/com/devoxx/genie/model/exo/ExoModelDTOTest.java new file mode 100644 index 00000000..69b4a7cb --- /dev/null +++ b/src/test/java/com/devoxx/genie/model/exo/ExoModelDTOTest.java @@ -0,0 +1,89 @@ +package com.devoxx.genie.model.exo; + +import com.google.gson.Gson; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +class ExoModelDTOTest { + + private static final Gson gson = new Gson(); + + private static final String FULL_RESPONSE = """ + { + "object": "list", + "data": [ + { + "id": "mlx-community/Llama-3.2-1B-Instruct-4bit", + "name": "Llama-3.2-1B-Instruct-4bit", + "context_length": 131072, + "storage_size_megabytes": 696, + "supports_tensor": true, + "family": "llama", + "quantization": "4bit", + "base_model": "Llama 3.2 1B Instruct", + "capabilities": ["text"] + }, + { + "id": "mlx-community/MiniMax-M2.5-6bit", + "name": "MiniMax-M2.5-6bit", + "context_length": 196608, + "storage_size_megabytes": 173000, + "supports_tensor": true, + "family": "minimax", + "quantization": "6bit", + 
"base_model": "MiniMax M2.5", + "capabilities": ["text", "thinking"] + } + ] + } + """; + + @Test + void shouldDeserializeFullResponse() { + ExoModelDTO dto = gson.fromJson(FULL_RESPONSE, ExoModelDTO.class); + + assertThat(dto.getObject()).isEqualTo("list"); + assertThat(dto.getData()).hasSize(2); + } + + @Test + void shouldDeserializeModelEntry() { + ExoModelDTO dto = gson.fromJson(FULL_RESPONSE, ExoModelDTO.class); + ExoModelEntryDTO first = dto.getData()[0]; + + assertThat(first.getId()).isEqualTo("mlx-community/Llama-3.2-1B-Instruct-4bit"); + assertThat(first.getName()).isEqualTo("Llama-3.2-1B-Instruct-4bit"); + assertThat(first.getContextLength()).isEqualTo(131072); + assertThat(first.getStorageSizeMegabytes()).isEqualTo(696); + assertThat(first.isSupportsTensor()).isTrue(); + assertThat(first.getFamily()).isEqualTo("llama"); + assertThat(first.getQuantization()).isEqualTo("4bit"); + assertThat(first.getBaseModel()).isEqualTo("Llama 3.2 1B Instruct"); + assertThat(first.getCapabilities()).containsExactly("text"); + } + + @Test + void shouldDeserializeModelWithMultipleCapabilities() { + ExoModelDTO dto = gson.fromJson(FULL_RESPONSE, ExoModelDTO.class); + ExoModelEntryDTO second = dto.getData()[1]; + + assertThat(second.getId()).isEqualTo("mlx-community/MiniMax-M2.5-6bit"); + assertThat(second.getContextLength()).isEqualTo(196608); + assertThat(second.getCapabilities()).containsExactly("text", "thinking"); + } + + @Test + void shouldHandleEmptyDataArray() { + String json = """ + { + "object": "list", + "data": [] + } + """; + + ExoModelDTO dto = gson.fromJson(json, ExoModelDTO.class); + + assertThat(dto.getData()).isEmpty(); + } +} diff --git a/src/test/java/com/devoxx/genie/ui/panel/LlmProviderPanelTest.java b/src/test/java/com/devoxx/genie/ui/panel/LlmProviderPanelTest.java index bef96f9d..1143ff58 100644 --- a/src/test/java/com/devoxx/genie/ui/panel/LlmProviderPanelTest.java +++ b/src/test/java/com/devoxx/genie/ui/panel/LlmProviderPanelTest.java @@ -84,6 
+84,7 @@ void setUp() { lenient().when(stateService.isGpt4AllEnabled()).thenReturn(false); lenient().when(stateService.isJanEnabled()).thenReturn(false); lenient().when(stateService.isLlamaCPPEnabled()).thenReturn(false); + lenient().when(stateService.isExoEnabled()).thenReturn(false); lenient().when(stateService.isCustomOpenAIUrlEnabled()).thenReturn(false); lenient().when(stateService.isOpenAIEnabled()).thenReturn(true); lenient().when(stateService.isMistralEnabled()).thenReturn(false);