diff --git a/common/arg.cpp b/common/arg.cpp index c21598e7687..10119ca12e5 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -600,9 +600,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context common_params_handle_model(params.vocoder.model, params.hf_token, params.offline); } - // model is required (except for server) + // model is required (except for server, or when using --endpoint in CLI) // TODO @ngxson : maybe show a list of available models in CLI in this case - if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !skip_model_download && !params.usage && !params.completion) { + if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !skip_model_download && !params.usage && !params.completion && params.endpoint.empty()) { throw std::invalid_argument("error: --model is required\n"); } @@ -1398,6 +1398,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.show_timings = value; } ).set_examples({LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SHOW_TIMINGS")); + add_opt(common_arg( + {"--endpoint"}, "URL", + string_format("connect to a running llama-server at URL instead of loading a model locally (e.g. http://localhost:8080)"), + [](common_params & params, const std::string & value) { + params.endpoint = value; + } + ).set_examples({LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_ENDPOINT")); add_opt(common_arg( {"-f", "--file"}, "FNAME", "a file containing the prompt (default: none)", diff --git a/common/common.h b/common/common.h index a564b3b8c2b..9bb086315a7 100644 --- a/common/common.h +++ b/common/common.h @@ -555,6 +555,10 @@ struct common_params { bool single_turn = false; // single turn chat conversation + // remote server endpoint for CLI (e.g. "http://localhost:8080") + // when set, CLI connects to a running server instead of loading a model + std::string endpoint = ""; // NOLINT + ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V diff --git a/tools/cli/CMakeLists.txt b/tools/cli/CMakeLists.txt index 7e01abb81b9..885892ea093 100644 --- a/tools/cli/CMakeLists.txt +++ b/tools/cli/CMakeLists.txt @@ -1,9 +1,11 @@ set(TARGET llama-cli) -add_executable(${TARGET} cli.cpp) -target_link_libraries(${TARGET} PRIVATE server-context PUBLIC llama-common ${CMAKE_THREAD_LIBS_INIT}) +add_executable(${TARGET} cli.cpp cli-backend.cpp) +target_link_libraries(${TARGET} PRIVATE cpp-httplib PUBLIC llama-common ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_17) include_directories(../server) +include_directories(../mtmd) +include_directories(../../vendor) if(LLAMA_TOOLS_INSTALL) install(TARGETS ${TARGET} RUNTIME) diff --git a/tools/cli/cli-backend.cpp b/tools/cli/cli-backend.cpp new file mode 100644 index 00000000000..3ffd4a7eb92 --- /dev/null +++ b/tools/cli/cli-backend.cpp @@ -0,0 +1,805 @@ +#include "cli-backend.h" + +#include "common.h" +#include "console.h" +#include "http.h" +#include "log.h" +#include "server-common.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#include +#else +#include +#include +#include +#include +#include +#endif + +#if defined(__APPLE__) && defined(__MACH__) +#include +#include +#endif + +// shared with cli.cpp +extern std::atomic g_is_interrupted; + +// base64 encoding for multimodal content over HTTP +static std::string base64_encode(const std::string & in) { + static const char base64_chars[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + std::string out; + out.reserve(4 * ((in.size() + 2) / 3)); + int i = 0; + int j = 0; + unsigned char arr3[3]; + unsigned char arr4[4]; + size_t in_len = in.size(); + const unsigned char * bytes = reinterpret_cast(in.data()); + while (in_len--) { + arr3[i++] = *(bytes++); + if (i == 3) { + arr4[0] = (arr3[0] & 0xfc) >> 2; + arr4[1] = ((arr3[0] & 0x03) << 4) + ((arr3[1] & 0xf0) >> 4); + arr4[2] = ((arr3[1] & 0x0f) << 2) + ((arr3[2] & 0xc0) >> 6); + arr4[3] = arr3[2] & 0x3f; + for (i = 0; i < 4; i++) { + out += base64_chars[arr4[i]]; + } + i = 0; + } + } + if (i) { + for (j = i; j < 3; j++) { + arr3[j] = '\0'; + } + arr4[0] = (arr3[0] & 0xfc) >> 2; + arr4[1] = ((arr3[0] & 0x03) << 4) + ((arr3[1] & 0xf0) >> 4); + arr4[2] = ((arr3[1] & 0x0f) << 2) + ((arr3[2] & 0xc0) >> 6); + for (j = 0; j < i + 1; j++) { + out += base64_chars[arr4[j]]; + } + while (i++ < 3) { + out += '='; + } + } + return out; +} + +// get path to current executable +static std::filesystem::path get_current_exec_path() { +#if defined(_WIN32) + wchar_t buf[32768] = { 0 }; + DWORD len = GetModuleFileNameW(nullptr, buf, _countof(buf)); + if (len == 0 || len >= _countof(buf)) { + throw std::runtime_error("GetModuleFileNameW failed or path too long"); + } + return std::filesystem::path(buf); +#elif defined(__APPLE__) && defined(__MACH__) + char small_path[PATH_MAX]; + uint32_t size = sizeof(small_path); + + if (_NSGetExecutablePath(small_path, &size) == 0) { + try { + return std::filesystem::canonical(std::filesystem::path(small_path)); + } catch (...) { + return std::filesystem::path(small_path); + } + } else { + std::vector buf(size); + if (_NSGetExecutablePath(buf.data(), &size) == 0) { + try { + return std::filesystem::canonical(std::filesystem::path(buf.data())); + } catch (...) { + return std::filesystem::path(buf.data()); + } + } + throw std::runtime_error("_NSGetExecutablePath failed after buffer resize"); + } +#else + char path[FILENAME_MAX]; + ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX); + if (count <= 0) { + throw std::runtime_error("failed to resolve /proc/self/exe"); + } + return std::filesystem::path(std::string(path, count)); +#endif +} + +// get path to llama-server executable +static std::filesystem::path get_server_exec_path() { + std::filesystem::path exec_path = get_current_exec_path(); + std::filesystem::path exec_dir = exec_path.parent_path(); + +#if defined(_WIN32) + return exec_dir / "llama-server.exe"; +#else + return exec_dir / "llama-server"; +#endif +} + +// check if a port is available +static bool is_port_available(int port) { +#ifdef _WIN32 + typedef SOCKET native_socket_t; +#define INVALID_SOCKET_VAL INVALID_SOCKET +#define CLOSE_SOCKET(s) closesocket(s) + WSADATA wsaData; + if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) { + return false; + } +#else + typedef int native_socket_t; +#define INVALID_SOCKET_VAL (-1) +#define CLOSE_SOCKET(s) close(s) +#endif + + native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock == INVALID_SOCKET_VAL) { +#ifdef _WIN32 + WSACleanup(); +#endif + return false; + } + + struct sockaddr_in serv_addr; + std::memset(&serv_addr, 0, sizeof(serv_addr)); + serv_addr.sin_family = AF_INET; + serv_addr.sin_addr.s_addr = htonl(INADDR_ANY); + serv_addr.sin_port = htons(port); + + bool available = bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) == 0; + + CLOSE_SOCKET(sock); +#ifdef _WIN32 + WSACleanup(); +#endif + + return available; +} + +// helper to convert vector to char ** +static std::vector to_char_ptr_array(const std::vector & vec) { + std::vector result; + result.reserve(vec.size() + 1); + for (const auto & s : vec) { + result.push_back(const_cast(s.c_str())); + } + result.push_back(nullptr); + return result; +} + +// +// cli_backend implementation +// + +cli_backend::cli_backend() = default; + +// List of CLI-specific arguments that should NOT be passed to llama-server +// These are arguments that the CLI consumes directly +static const std::vector CLI_SPECIFIC_ARGS = { + // Connection + "--endpoint", + // Input (prompt related) + "-p", "--prompt", + "-f", "--file", + "-sys", "--system-prompt", + "-sysf", "--system-prompt-file", + "--image", + // CLI behavior/display + "--show-timings", "--no-show-timings", + "--simple-io", + "--color", "--no-color", + "--multiline-input", + "-cnv", "--conversation", "--no-conversation", + "-i", "--interactive", + "--verbose-prompt", +}; + +// Check if an argument is CLI-specific +static bool is_cli_specific_arg(const std::string & arg) { + return std::any_of(CLI_SPECIFIC_ARGS.begin(), CLI_SPECIFIC_ARGS.end(), [&] (auto & x) { + return (arg == x || arg.find(x + "=") == 0); + + }); +} + +bool cli_backend::spawn_local_server(int argc, char ** argv) { + // 1. Find available port starting from 8080 + int port = 8080; + const int max_port = 65535; + while (port < max_port && !is_port_available(port)) { + port++; + } + if (port >= max_port) { + console::error("Failed to find an available port\n"); + return false; + } + + // 2. Get path to llama-server executable + std::filesystem::path server_path; + try { + server_path = get_server_exec_path(); + } catch (const std::exception & e) { + console::error("Failed to find llama-server executable: %s\n", e.what()); + return false; + } + + if (!std::filesystem::exists(server_path)) { + console::error("llama-server executable not found at: %s\n", server_path.string().c_str()); + return false; + } + + // 3. Build command line arguments + // Start with the server executable + std::vector args; + args.push_back(server_path.string()); + + // Add --host 127.0.0.1 to bind only to localhost + args.push_back("--host"); + args.push_back("127.0.0.1"); + + // Add the available port + args.push_back("--port"); + args.push_back(std::to_string(port)); + + // Filter and copy arguments from original argv + // Skip argv[0] (program name) + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + // Skip CLI-specific arguments + if (is_cli_specific_arg(arg)) { + // Skip the value for arguments that take one + // Check if this arg doesn't contain '=' and next arg doesn't start with '-' + if (arg.find('=') == std::string::npos && + (arg == "-p" || arg == "--prompt" || + arg == "-f" || arg == "--file" || + arg == "-sys" || arg == "--system-prompt" || + arg == "-sysf" || arg == "--system-prompt-file" || + arg == "--image" || arg == "--endpoint")) { + if (i + 1 < argc && argv[i + 1][0] != '-') { + i++; // Skip the value too + } + } + continue; + } + + args.push_back(arg); + } + + // 4. Spawn subprocess in its own session/process group + // so that Ctrl+C (SIGINT) from the terminal doesn't reach the server. + server_proc = std::make_shared(); + std::vector server_argv = to_char_ptr_array(args); + + int options = subprocess_option_no_window + | subprocess_option_combined_stdout_stderr + | subprocess_option_new_session; + + int spawn_result = subprocess_create_ex(server_argv.data(), options, nullptr, server_proc.get()); + + if (spawn_result != 0) { + console::error("Failed to spawn llama-server subprocess\n"); + server_proc.reset(); + return false; + } + + is_local_server = true; + + // 5. Wait for server to be ready by polling /props endpoint + endpoint_url = "http://127.0.0.1:" + std::to_string(port); + + const int max_retries = 60; // 30 seconds total (500ms per retry) + const int retry_delay_ms = 500; + + for (int i = 0; i < max_retries && !g_is_interrupted.load(); i++) { + std::this_thread::sleep_for(std::chrono::milliseconds(retry_delay_ms)); + + try { + auto http_res = common_http_client(endpoint_url); + httplib::Client & cli = http_res.first; + cli.set_connection_timeout(1, 0); + cli.set_read_timeout(1, 0); + + auto res = cli.Get("/props"); + if (res && (res->status == 200 || res->status == 404)) { + // Server is ready, parse metadata + auto data = json::parse(res->body); + model_name = json_value(data, "model_alias", std::string("unknown")); + build_info_ = json_value(data, "build_info", std::string("")); + + if (data.contains("modalities")) { + auto mods = data.at("modalities"); + has_vision_ = json_value(mods, "vision", false); + has_audio_ = json_value(mods, "audio", false); + } + + return true; + } + } catch (...) { + LOG_DBG("Server not yet ready, still polling..."); + // Server not ready yet, continue polling + } + } + + // Timeout - terminate subprocess + console::error("Timeout waiting for local server to start\n"); + subprocess_terminate(server_proc.get()); + subprocess_destroy(server_proc.get()); + server_proc.reset(); + is_local_server = false; + return false; +} + +bool cli_backend::connect() { + // strip trailing slash + while (!endpoint_url.empty() && endpoint_url.back() == '/') { + endpoint_url.pop_back(); + } + + // query /props to get server metadata + try { + auto http_res = common_http_client(endpoint_url); + httplib::Client & cli = http_res.first; + cli.set_connection_timeout(5, 0); + cli.set_read_timeout(10, 0); + + auto res = cli.Get("/props"); + if (!res || res->status != 200) { + console::error("Failed to connect to server at %s\n", endpoint_url.c_str()); + if (res) { + console::error("HTTP status: %d\n", res->status); + } else { + console::error("Connection error: %s\n", httplib::to_string(res.error()).c_str()); + } + return false; + } + + auto data = json::parse(res->body); + model_name = json_value(data, "model_alias", std::string("unknown")); + build_info_ = json_value(data, "build_info", std::string("")); + + if (data.contains("modalities")) { + auto mods = data.at("modalities"); + has_vision_ = json_value(mods, "vision", false); + has_audio_ = json_value(mods, "audio", false); + } + + return true; + } catch (const std::exception & e) { + console::error("Failed to connect to server at %s: %s\n", endpoint_url.c_str(), e.what()); + return false; + } +} + +std::string cli_backend::get_model_name() const { return model_name; } +bool cli_backend::has_vision() const { return has_vision_; } +bool cli_backend::has_audio() const { return has_audio_; } +std::string cli_backend::get_build_info() const { return build_info_; } + +std::string cli_backend::generate_completion( + const json & messages, + const common_params & params, + bool verbose_prompt, + result_timings & out_timings) { + // build the OAI chat completion request + json request_body = { + {"messages", messages}, + {"stream", true}, + }; + + // sampling parameters + { + const auto & s = params.sampling; + if (s.temp != 0.8f) { request_body["temperature"] = s.temp; } + if (s.top_k != 40) { request_body["top_k"] = s.top_k; } + if (s.top_p != 0.95f) { request_body["top_p"] = s.top_p; } + if (s.min_p != 0.05f) { request_body["min_p"] = s.min_p; } + if (s.penalty_repeat != 1.0f) { request_body["repeat_penalty"] = s.penalty_repeat; } + if (s.penalty_present != 0.0f) { request_body["presence_penalty"] = s.penalty_present; } + if (s.penalty_freq != 0.0f) { request_body["frequency_penalty"] = s.penalty_freq; } + if (s.seed != LLAMA_DEFAULT_SEED) { request_body["seed"] = s.seed; } + } + + if (params.n_predict >= 0) { + request_body["max_tokens"] = params.n_predict; + } + + if (!params.antiprompt.empty()) { + request_body["stop"] = params.antiprompt; + } + + // reasoning budget + if (params.sampling.reasoning_budget_tokens >= 0) { + request_body["thinking_budget_tokens"] = params.sampling.reasoning_budget_tokens; + } + + // reasoning/thinking control via chat_template_kwargs + if (params.enable_reasoning == 0) { + request_body["chat_template_kwargs"] = json{{"enable_thinking", false}}; + } + + if (verbose_prompt) { + console::set_display(DISPLAY_TYPE_PROMPT); + console::log("POST /v1/chat/completions %s\n\n", request_body.dump().c_str()); + console::set_display(DISPLAY_TYPE_RESET); + } + + // do the HTTP request with SSE streaming + std::string curr_content; + bool is_thinking = false; + std::thread req_thread; + + try { + // Reset interrupt flag at start of new request + if (g_is_interrupted.load()) { + LOG_DBG("Resetting interrupt flag at start of new request"); + g_is_interrupted.store(false); + this->was_interrupted = true; + // Longer delay after interruption to let server clean up slot + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + } + + // After interruption, poll server until ready + if (this->was_interrupted) { + LOG_DBG("Polling server slots after interruption"); + auto test_res = common_http_client(endpoint_url); + httplib::Client & test_cli = test_res.first; + test_cli.set_connection_timeout(2, 0); + int retries = 0; + bool connected = false; + while (retries < 100) { + auto res = test_cli.Get("/slots?fail_on_no_slot=1"); + if (res && (res->status == 200 || res->status == 404)) { + LOG_DBG("Server ready after %d retries (status=%d)", retries, res ? res->status : 0); + connected = true; + break; + } + // If cannot connect at all, server may have crashed + if (!res) { + LOG_DBG("Cannot connect to server (retry %d)", retries); + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + retries++; + } + this->was_interrupted = false; + if (!connected) { + LOG_DBG("Failed to reconnect to server after interruption"); + console::error("Server unavailable after interruption\n"); + return ""; + } + } + LOG_DBG("Connecting to %s", endpoint_url.c_str()); + auto http_res = common_http_client(endpoint_url); + LOG_DBG("Connected successfully"); + httplib::Client cli = std::move(http_res.first); + cli.set_connection_timeout(10, 0); + cli.set_read_timeout(600, 0); // 10 min read timeout for long generations + cli.set_write_timeout(10, 0); + + std::string body_str = request_body.dump(); + + auto done = std::make_shared>(false); + auto error_msg = std::make_shared(); + + httplib::Request req; + req.method = "POST"; + req.path = "/v1/chat/completions"; + req.set_header("Content-Type", "application/json"); + req.body = body_str; + + // collect chunks into a thread-safe queue + auto chunks = std::make_shared>(); + auto chunks_mutex = std::make_shared(); + auto chunks_cv = std::make_shared(); + + req.content_receiver = [chunks, chunks_mutex, chunks_cv](const char * data, size_t data_length, size_t, size_t) -> bool { + std::string chunk(data, data_length); + { + std::lock_guard lock(*chunks_mutex); + chunks->push_back(std::move(chunk)); + } + chunks_cv->notify_one(); + return true; + }; + + // send request in a separate thread + req_thread = std::thread([&cli, req = std::move(req), done, error_msg]() mutable { + LOG_DBG("HTTP request thread started"); + auto result = cli.send(req); + LOG_DBG("HTTP request done, error=%d", (int)result.error()); + // Don't set error if interrupted (Error::Read or Error::Connection expected) + if (result.error() != httplib::Error::Success) { + auto err_str = httplib::to_string(result.error()); + // Only set error if not a connection error (likely from interruption) + if (err_str.find("read") == std::string::npos && + err_str.find("connection") == std::string::npos) { + *error_msg = err_str; + LOG_DBG("HTTP error set: %s", error_msg->c_str()); + } else { + LOG_DBG("HTTP connection error (likely from interrupt): %s", err_str.c_str()); + } + } else if (result && result->status != 200) { + try { + auto err_json = json::parse(result->body); + if (err_json.contains("error") && err_json["error"].contains("message")) { + *error_msg = err_json["error"]["message"].get(); + } else { + *error_msg = "HTTP " + std::to_string(result->status) + ": " + result->body.substr(0, 500); + } + } catch (...) { + *error_msg = "HTTP " + std::to_string(result->status) + ": " + result->body.substr(0, 500); + } + } + done->store(true); + }); + + // process SSE stream + console::spinner::start(); + + std::string sse_buffer; + bool first_chunk = true; + bool stream_done = false; + bool interrupted = false; + + while (!stream_done && !interrupted) { + // Check for interrupt at start of each iteration + if (g_is_interrupted.load()) { + LOG_DBG("Interrupt detected, stopping client"); + interrupted = true; + cli.stop(); + break; + } + + // wait for data + std::string new_data; + { + std::unique_lock lock(*chunks_mutex); + if (chunks->empty()) { + chunks_cv->wait_for(lock, std::chrono::milliseconds(100), [&] { + return !chunks->empty() || done->load() || g_is_interrupted.load(); + }); + if (chunks->empty()) { + if (done->load()) { + break; + } + if (g_is_interrupted.load()) { + interrupted = true; + cli.stop(); + break; + } + continue; + } + } + for (auto & chunk : *chunks) { + new_data += chunk; + } + chunks->clear(); + } + + sse_buffer += new_data; + + // process SSE lines + size_t pos; + while ((pos = sse_buffer.find("\n\n")) != std::string::npos) { + std::string event = sse_buffer.substr(0, pos); + sse_buffer.erase(0, pos + 2); + + // look for data: line + size_t data_pos = event.find("data: "); + if (data_pos == std::string::npos) { continue; } + + std::string data_line = event.substr(data_pos + 6); + + if (data_line == "[DONE]") { + stream_done = true; + break; + } + + try { + auto chunk_json = json::parse(data_line); + if (!chunk_json.contains("choices")) { continue; } + + auto & choices = chunk_json["choices"]; + if (choices.empty()) { continue; } + + // Extract timings if present (usually in the final chunk) + if (chunk_json.contains("timings")) { + auto & t = chunk_json["timings"]; + out_timings.prompt_n = json_value(t, "prompt_n", out_timings.prompt_n); + out_timings.prompt_ms = json_value(t, "prompt_ms", out_timings.prompt_ms); + out_timings.prompt_per_second = json_value(t, "prompt_per_second", out_timings.prompt_per_second); + out_timings.predicted_n = json_value(t, "predicted_n", out_timings.predicted_n); + out_timings.predicted_ms = json_value(t, "predicted_ms", out_timings.predicted_ms); + out_timings.predicted_per_second = json_value(t, "predicted_per_second", out_timings.predicted_per_second); + } + + // Check for completion + if (choices[0].contains("finish_reason") && !choices[0]["finish_reason"].is_null()) { + stream_done = true; + break; + } + + if (!choices[0].contains("delta")) { continue; } + + auto & delta = choices[0]["delta"]; + + if (first_chunk) { + first_chunk = false; + console::spinner::stop(); + } + + // Handle content (only display if not interrupted) + if (delta.contains("content") && !delta["content"].is_null()) { + std::string content_delta = delta["content"].get(); + curr_content += content_delta; + if (!interrupted) { + if (is_thinking) { + console::log("\n[End thinking]\n\n"); + console::set_display(DISPLAY_TYPE_RESET); + is_thinking = false; + } + console::log("%s", content_delta.c_str()); + console::flush(); + } + } + + // Handle reasoning_content (only display if not interrupted) + if (delta.contains("reasoning_content") && !delta["reasoning_content"].is_null()) { + std::string reasoning_delta = delta["reasoning_content"].get(); + if (!interrupted) { + console::set_display(DISPLAY_TYPE_REASONING); + if (!is_thinking) { + console::log("[Start thinking]\n"); + } + is_thinking = true; + console::log("%s", reasoning_delta.c_str()); + console::flush(); + } + } + } catch (const json::parse_error & e) { + // Ignore parsing errors for malformed chunks + LOG_DBG("Ignoring malformed chunk due to JSON parse error: %s\n", e.what()); + } + } + } + + console::spinner::stop(); + + // Wait for request thread to finish (if not already done) + if (req_thread.joinable()) { + req_thread.join(); + } + + // Reset interrupt flag if we were interrupted + if (interrupted) { + g_is_interrupted.store(false); + this->was_interrupted = true; + // Don't show error messages when user intentionally interrupted + // Give server a moment to clean up the aborted request + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + return curr_content; + } + + // Check for errors (only if not interrupted) + // When interrupted, we expect connection errors - don't show them + if (!interrupted && !error_msg->empty()) { + LOG_DBG("Error during request: %s", error_msg->c_str()); + console::error("Error: %s\n", error_msg->c_str()); + return curr_content; + } + + LOG_DBG("Request completed successfully, content length: %zu", curr_content.length()); + return curr_content; + } catch (const std::exception & e) { + console::spinner::stop(); + // Make sure to join the thread on exception + if (req_thread.joinable()) { + req_thread.join(); + } + // Don't show error if we were interrupted + if (!g_is_interrupted.load()) { + LOG_DBG("Exception during request: %s", e.what()); + console::error("Error in generate_completion: %s\n", e.what()); + } else { + LOG_DBG("Exception during request but interrupted, suppressing: %s", e.what()); + } + return curr_content; + } +} + +std::string cli_backend::load_text_file(const std::string & fname) { + std::ifstream file(fname, std::ios::binary); + if (!file) { + return ""; + } + return std::string((std::istreambuf_iterator(file)), std::istreambuf_iterator()); +} + +json cli_backend::load_media_file(const std::string & fname) { + std::ifstream file(fname, std::ios::binary); + if (!file) { + return json::object(); + } + + std::string data((std::istreambuf_iterator(file)), std::istreambuf_iterator()); + if (data.empty()) { + return json::object(); + } + + // detect file type from extension + std::string ext; + size_t dot_pos = fname.find_last_of('.'); + if (dot_pos != std::string::npos) { + ext = fname.substr(dot_pos); + std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); + } + + // base64 encode + std::string b64 = base64_encode(data); + + // Media type info struct + struct media_type_info { + std::string media_type; + std::string content_type; + }; + + // Map of extension to media type info + static const std::unordered_map MEDIA_TYPE_MAP = { + {".png", {"image/png", "image_url"}}, + {".jpg", {"image/jpeg", "image_url"}}, + {".jpeg", {"image/jpeg", "image_url"}}, + {".gif", {"image/gif", "image_url"}}, + {".webp", {"image/webp", "image_url"}}, + {".wav", {"audio/wav", "input_audio"}}, + {".mp3", {"audio/mpeg", "input_audio"}}, + }; + + const auto it = MEDIA_TYPE_MAP.find(ext); + if (it == MEDIA_TYPE_MAP.end()) { + // unknown type, treat as binary + return json::object(); + } + const auto & info = it->second; + + // build the OAI content part + if (info.content_type == "image_url") { + return json{ + {"type", "image_url"}, + {"image_url", { + {"url", "data:" + info.media_type + ";base64," + b64} + }} + }; + } + // audio + return json{ + {"type", "input_audio"}, + {"input_audio", { + {"data", b64}, + {"format", ext == ".mp3" ? "mp3" : "wav"} + }} + }; +} + +void cli_backend::terminate() { + // terminate subprocess if running local server + if (is_local_server && server_proc) { + subprocess_terminate(server_proc.get()); + subprocess_destroy(server_proc.get()); + server_proc.reset(); + is_local_server = false; + } +} diff --git a/tools/cli/cli-backend.h b/tools/cli/cli-backend.h new file mode 100644 index 00000000000..779abd20f70 --- /dev/null +++ b/tools/cli/cli-backend.h @@ -0,0 +1,82 @@ +#pragma once + +#include "common.h" +#include "console.h" + +#include + +// forward declaration for subprocess +struct subprocess_s; + +using json = nlohmann::ordered_json; + +// result_timings struct - copied from server-task.h for CLI use +struct result_timings { + int32_t cache_n = -1; + + int32_t prompt_n = -1; + double prompt_ms = 0.0; + double prompt_per_token_ms = 0.0; + double prompt_per_second = 0.0; + + int32_t predicted_n = -1; + double predicted_ms = 0.0; + double predicted_per_token_ms = 0.0; + double predicted_per_second = 0.0; + + // Optional speculative metrics + int32_t predicted_draft_n = -1; + double predicted_draft_ms = 0.0; + double draft_per_token_ms = 0.0; +}; + +// +// Backend interface — connects to llama-server via HTTP(S) +// Can connect to external server or spawn local server subprocess +// + +struct cli_backend { + std::string endpoint_url; + std::string model_name; + bool has_vision_ = false; + bool has_audio_ = false; + std::string build_info_; + + // subprocess management for local server + std::shared_ptr server_proc; + bool is_local_server = false; + bool was_interrupted = false; // track if last request was interrupted + + cli_backend(); + + // connect to a running server at endpoint_url + bool connect(); + + // spawn a local server subprocess + // argc/argv are the original command-line arguments + // CLI-specific arguments are filtered out, rest passed to llama-server + bool spawn_local_server(int argc, char ** argv); + + // model / server info + std::string get_model_name() const; + bool has_vision() const; + bool has_audio() const; + std::string get_build_info() const; + + // chat completion (streaming), returns assistant content text + std::string generate_completion( + const json & messages, + const common_params & params, + bool verbose_prompt, + result_timings & out_timings); + + // load a local text file, return its contents (empty string on failure) + static std::string load_text_file(const std::string & fname); + + // load a local media file, return the OAI content part JSON for it + // returns empty JSON object on failure + static json load_media_file(const std::string & fname); + + // cleanup - terminates subprocess if running local server + void terminate(); +}; diff --git a/tools/cli/cli.cpp b/tools/cli/cli.cpp index 369c24216b7..e1e68493064 100644 --- a/tools/cli/cli.cpp +++ b/tools/cli/cli.cpp @@ -2,12 +2,11 @@ #include "common.h" #include "arg.h" #include "console.h" -#include "fit.h" -// #include "log.h" +#include "log.h" -#include "server-common.h" -#include "server-context.h" -#include "server-task.h" +#include + +#include "cli-backend.h" #include #include @@ -35,7 +34,7 @@ const char * LLAMA_ASCII_LOGO = R"( ▀▀ ▀▀ )"; -static std::atomic g_is_interrupted = false; +std::atomic g_is_interrupted = false; static bool should_stop() { return g_is_interrupted.load(); } @@ -53,176 +52,6 @@ static void signal_handler(int) { } #endif -struct cli_context { - server_context ctx_server; - json messages = json::array(); - std::vector input_files; - task_params defaults; - bool verbose_prompt; - - // thread for showing "loading" animation - std::atomic loading_show; - - cli_context(const common_params & params) { - defaults.sampling = params.sampling; - defaults.speculative = params.speculative; - defaults.n_keep = params.n_keep; - defaults.n_predict = params.n_predict; - defaults.antiprompt = params.antiprompt; - - defaults.stream = true; // make sure we always use streaming mode - defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way - // defaults.return_progress = true; // TODO: show progress - - verbose_prompt = params.verbose_prompt; - } - - std::string generate_completion(result_timings & out_timings) { - server_response_reader rd = ctx_server.get_response_reader(); - auto chat_params = format_chat(); - { - // TODO: reduce some copies here in the future - server_task task = server_task(SERVER_TASK_TYPE_COMPLETION); - task.id = rd.get_new_id(); - task.index = 0; - task.params = defaults; // copy - task.cli_prompt = chat_params.prompt; // copy - task.cli_files = input_files; // copy - task.cli = true; - - // chat template settings - task.params.chat_parser_params = common_chat_parser_params(chat_params); - task.params.chat_parser_params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; - if (!chat_params.parser.empty()) { - task.params.chat_parser_params.parser.load(chat_params.parser); - } - - // reasoning budget sampler - if (!chat_params.thinking_end_tag.empty()) { - const llama_vocab * vocab = llama_model_get_vocab( - llama_get_model(ctx_server.get_llama_context())); - - task.params.sampling.reasoning_budget_tokens = defaults.sampling.reasoning_budget_tokens; - task.params.sampling.generation_prompt = chat_params.generation_prompt; - - if (!chat_params.thinking_start_tag.empty()) { - task.params.sampling.reasoning_budget_start = - common_tokenize(vocab, chat_params.thinking_start_tag, false, true); - } - task.params.sampling.reasoning_budget_end = - common_tokenize(vocab, chat_params.thinking_end_tag, false, true); - task.params.sampling.reasoning_budget_forced = - common_tokenize(vocab, defaults.sampling.reasoning_budget_message + chat_params.thinking_end_tag, false, true); - } - - rd.post_task({std::move(task)}); - } - - if (verbose_prompt) { - console::set_display(DISPLAY_TYPE_PROMPT); - console::log("%s\n\n", chat_params.prompt.c_str()); - console::set_display(DISPLAY_TYPE_RESET); - } - - // wait for first result - console::spinner::start(); - server_task_result_ptr result = rd.next(should_stop); - - console::spinner::stop(); - std::string curr_content; - bool is_thinking = false; - - while (result) { - if (should_stop()) { - break; - } - if (result->is_error()) { - json err_data = result->to_json(); - if (err_data.contains("message")) { - console::error("Error: %s\n", err_data["message"].get().c_str()); - } else { - console::error("Error: %s\n", err_data.dump().c_str()); - } - return curr_content; - } - auto res_partial = dynamic_cast(result.get()); - if (res_partial) { - out_timings = std::move(res_partial->timings); - for (const auto & diff : res_partial->oaicompat_msg_diffs) { - if (!diff.content_delta.empty()) { - if (is_thinking) { - console::log("\n[End thinking]\n\n"); - console::set_display(DISPLAY_TYPE_RESET); - is_thinking = false; - } - curr_content += diff.content_delta; - console::log("%s", diff.content_delta.c_str()); - console::flush(); - } - if (!diff.reasoning_content_delta.empty()) { - console::set_display(DISPLAY_TYPE_REASONING); - if (!is_thinking) { - console::log("[Start thinking]\n"); - } - is_thinking = true; - console::log("%s", diff.reasoning_content_delta.c_str()); - console::flush(); - } - } - } - auto res_final = dynamic_cast(result.get()); - if (res_final) { - out_timings = std::move(res_final->timings); - break; - } - result = rd.next(should_stop); - } - g_is_interrupted.store(false); - // server_response_reader automatically cancels pending tasks upon destruction - return curr_content; - } - - // TODO: support remote files in the future (http, https, etc) - std::string load_input_file(const std::string & fname, bool is_media) { - std::ifstream file(fname, std::ios::binary); - if (!file) { - return ""; - } - if (is_media) { - raw_buffer buf; - buf.assign((std::istreambuf_iterator(file)), std::istreambuf_iterator()); - input_files.push_back(std::move(buf)); - return get_media_marker(); - } else { - std::string content((std::istreambuf_iterator(file)), std::istreambuf_iterator()); - return content; - } - } - - common_chat_params format_chat() { - auto meta = ctx_server.get_meta(); - auto & chat_params = meta.chat_params; - - auto caps = common_chat_templates_get_caps(chat_params.tmpls.get()); - - common_chat_templates_inputs inputs; - inputs.messages = common_chat_msgs_parse_oaicompat(messages); - inputs.tools = {}; // TODO - inputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE; - inputs.json_schema = ""; // TODO - inputs.grammar = ""; // TODO - inputs.use_jinja = chat_params.use_jinja; - inputs.parallel_tool_calls = caps["supports_parallel_tool_calls"]; - inputs.add_generation_prompt = true; - inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; - inputs.force_pure_content = chat_params.force_pure_content; - inputs.enable_thinking = chat_params.enable_thinking ? common_chat_templates_support_enable_thinking(chat_params.tmpls.get()) : false; - - // Apply chat template to the list of messages - return common_chat_templates_apply(chat_params.tmpls.get(), inputs); - } -}; - // TODO?: Make this reusable, enums, docs static const std::array cmds = { "/audio ", @@ -359,12 +188,6 @@ int main(int argc, char ** argv) { console::error("please use llama-completion instead\n"); } - // struct that contains llama context and inference - cli_context ctx_cli(params); - - llama_backend_init(); - llama_numa_init(params.numa); - // TODO: avoid using atexit() here by making `console` a singleton console::init(params.simple_io, params.use_color); atexit([]() { console::cleanup(); }); @@ -386,33 +209,48 @@ int main(int argc, char ** argv) { SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); #endif - console::log("\nLoading model... "); // followed by loading animation - console::spinner::start(); - if (!ctx_cli.ctx_server.load_model(params)) { - console::spinner::stop(); - console::error("\nFailed to load the model\n"); - return 1; - } + // shared state between backend and the interactive loop + json messages = json::array(); + std::vector pending_media_parts; + bool was_generating = false; - console::spinner::stop(); - console::log("\n"); + std::unique_ptr backend = std::make_unique(); - std::thread inference_thread([&ctx_cli]() { - ctx_cli.ctx_server.start_loop(); - }); + // Connect to server - either external or spawn local + if (!params.endpoint.empty()) { + backend->endpoint_url = params.endpoint; + console::log("\nConnecting to %s ... ", backend->endpoint_url.c_str()); + console::spinner::start(); + if (!backend->connect()) { + console::spinner::stop(); + console::error("\nFailed to connect to server\n"); + return 1; + } + console::spinner::stop(); + console::log("connected\n"); + } else { + console::log("\nStarting local server... "); + console::spinner::start(); + if (!backend->spawn_local_server(argc, argv)) { + console::spinner::stop(); + console::error("\nFailed to start local server\n"); + return 1; + } + console::spinner::stop(); + console::log("started on %s\n", backend->endpoint_url.c_str()); + } - auto inf = ctx_cli.ctx_server.get_meta(); std::string modalities = "text"; - if (inf.has_inp_image) { + if (backend->has_vision()) { modalities += ", vision"; } - if (inf.has_inp_audio) { + if (backend->has_audio()) { modalities += ", audio"; } auto add_system_prompt = [&]() { if (!params.system_prompt.empty()) { - ctx_cli.messages.push_back({ + messages.push_back({ {"role", "system"}, {"content", params.system_prompt} }); @@ -422,8 +260,9 @@ int main(int argc, char ** argv) { console::log("\n"); console::log("%s\n", LLAMA_ASCII_LOGO); - console::log("build : %s\n", inf.build_info.c_str()); - console::log("model : %s\n", inf.model_name.c_str()); + console::log("server : %s\n", backend->endpoint_url.c_str()); + console::log("build : %s\n", backend->get_build_info().c_str()); + console::log("model : %s\n", backend->get_model_name().c_str()); console::log("modalities : %s\n", modalities.c_str()); if (!params.system_prompt.empty()) { console::log("using custom system prompt\n"); @@ -435,33 +274,43 @@ int main(int argc, char ** argv) { console::log(" /clear clear the chat history\n"); console::log(" /read add a text file\n"); console::log(" /glob add text files using globbing pattern\n"); - if (inf.has_inp_image) { + if (backend->has_vision()) { console::log(" /image add an image file\n"); } - if (inf.has_inp_audio) { + if (backend->has_audio()) { console::log(" /audio add an audio file\n"); } console::log("\n"); - // interactive loop std::string cur_msg; + // helper: build a user message JSON with optional media parts + auto build_user_content = [&](const std::string & text, const std::vector & media_parts) -> json { + if (media_parts.empty()) { + return json{{"type", "text"}, {"text", text}}; + } + // OAI multipart content + json content = json::array(); + // add media first, then text + for (const auto & part : media_parts) { + content.push_back(part); + } + if (!text.empty()) { + content.push_back(json{{"type", "text"}, {"text", text}}); + } + return content; + }; + auto add_text_file = [&](const std::string & fname) -> bool { - std::string marker = ctx_cli.load_input_file(fname, false); - if (marker.empty()) { + std::string content = backend->load_text_file(fname); + if (content.empty()) { console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str()); return false; } - if (inf.fim_sep_token != LLAMA_TOKEN_NULL) { - cur_msg += common_token_to_piece(ctx_cli.ctx_server.get_llama_context(), inf.fim_sep_token, true); - cur_msg += fname; - cur_msg.push_back('\n'); - } else { - cur_msg += "--- File: "; - cur_msg += fname; - cur_msg += " ---\n"; - } - cur_msg += marker; + cur_msg += "--- File: "; + cur_msg += fname; + cur_msg += " ---\n"; + cur_msg += content; console::log("Loaded text from '%s'\n", fname.c_str()); return true; }; @@ -478,15 +327,15 @@ int main(int argc, char ** argv) { buffer += line; } while (another_line); } else { - // process input prompt from args + // process input prompt from args — load any media files for (auto & fname : params.image) { - std::string marker = ctx_cli.load_input_file(fname, true); - if (marker.empty()) { + json media_part = backend->load_media_file(fname); + if (media_part.empty()) { console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str()); break; } + pending_media_parts.push_back(media_part); console::log("Loaded media from '%s'\n", fname.c_str()); - cur_msg += marker; } buffer = params.prompt; if (buffer.size() > 500) { @@ -499,13 +348,20 @@ int main(int argc, char ** argv) { console::set_display(DISPLAY_TYPE_RESET); console::log("\n"); + // Check for interrupt if (should_stop()) { g_is_interrupted.store(false); + if (was_generating) { + // Interrupt during generation: cancel it, continue chatting + was_generating = false; + continue; + } + // Interrupt at prompt: exit break; } // remove trailing newline - if (!buffer.empty() &&buffer.back() == '\n') { + if (!buffer.empty() && buffer.back() == '\n') { buffer.pop_back(); } @@ -520,32 +376,30 @@ int main(int argc, char ** argv) { if (string_starts_with(buffer, "/exit")) { break; } else if (string_starts_with(buffer, "/regen")) { - if (ctx_cli.messages.size() >= 2) { - size_t last_idx = ctx_cli.messages.size() - 1; - ctx_cli.messages.erase(last_idx); + if (messages.size() >= 2) { + size_t last_idx = messages.size() - 1; + messages.erase(last_idx); add_user_msg = false; } else { console::error("No message to regenerate.\n"); continue; } } else if (string_starts_with(buffer, "/clear")) { - ctx_cli.messages.clear(); + messages.clear(); add_system_prompt(); - - ctx_cli.input_files.clear(); + pending_media_parts.clear(); console::log("Chat history cleared.\n"); continue; } else if ( - (string_starts_with(buffer, "/image ") && inf.has_inp_image) || - (string_starts_with(buffer, "/audio ") && inf.has_inp_audio)) { - // just in case (bad copy-paste for example), we strip all trailing/leading spaces + (string_starts_with(buffer, "/image ") && backend->has_vision()) || + (string_starts_with(buffer, "/audio ") && backend->has_audio())) { std::string fname = string_strip(buffer.substr(7)); - std::string marker = ctx_cli.load_input_file(fname, true); - if (marker.empty()) { + json media_part = backend->load_media_file(fname); + if (media_part.empty()) { console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str()); continue; } - cur_msg += marker; + pending_media_parts.push_back(media_part); console::log("Loaded media from '%s'\n", fname.c_str()); continue; } else if (string_starts_with(buffer, "/read ")) { @@ -612,25 +466,44 @@ int main(int argc, char ** argv) { // generate response if (add_user_msg) { - ctx_cli.messages.push_back({ - {"role", "user"}, - {"content", cur_msg} - }); + // always use multipart content for consistency + json user_content = build_user_content(cur_msg, pending_media_parts); + if (pending_media_parts.empty()) { + // simple text message + messages.push_back({ + {"role", "user"}, + {"content", cur_msg} + }); + } else { + // multipart message + messages.push_back({ + {"role", "user"}, + {"content", user_content} + }); + } + pending_media_parts.clear(); cur_msg.clear(); } result_timings timings; - std::string assistant_content = ctx_cli.generate_completion(timings); - ctx_cli.messages.push_back({ - {"role", "assistant"}, - {"content", assistant_content} - }); + was_generating = true; + std::string assistant_content = backend->generate_completion(messages, params, params.verbose_prompt, timings); + was_generating = false; + // Only add assistant message if we got content (not interrupted/error) + if (!assistant_content.empty() || !g_is_interrupted.load()) { + messages.push_back({ + {"role", "assistant"}, + {"content", assistant_content} + }); + } console::log("\n"); if (params.show_timings) { - console::set_display(DISPLAY_TYPE_INFO); - console::log("\n"); - console::log("[ Prompt: %.1f t/s | Generation: %.1f t/s ]\n", timings.prompt_per_second, timings.predicted_per_second); - console::set_display(DISPLAY_TYPE_RESET); + if (timings.prompt_n >= 0 && timings.predicted_n >= 0) { + console::set_display(DISPLAY_TYPE_INFO); + console::log("\n"); + console::log("[ Prompt: %.1f t/s | Generation: %.1f t/s ]\n", timings.prompt_per_second, timings.predicted_per_second); + console::set_display(DISPLAY_TYPE_RESET); + } } if (params.single_turn) { @@ -641,12 +514,7 @@ int main(int argc, char ** argv) { console::set_display(DISPLAY_TYPE_RESET); console::log("\nExiting...\n"); - ctx_cli.ctx_server.terminate(); - inference_thread.join(); - - // bump the log level to display timings - common_log_set_verbosity_thold(LOG_LEVEL_INFO); - common_memory_breakdown_print(ctx_cli.ctx_server.get_llama_context()); + backend->terminate(); return 0; } diff --git a/vendor/sheredom/subprocess.h b/vendor/sheredom/subprocess.h index 3e40bae046a..b95a9e70c80 100644 --- a/vendor/sheredom/subprocess.h +++ b/vendor/sheredom/subprocess.h @@ -89,7 +89,12 @@ enum subprocess_option_e { // Search for program names in the PATH variable. Always enabled on Windows. // Note: this will **not** search for paths in any provided custom environment // and instead uses the PATH of the spawning process. - subprocess_option_search_user_path = 0x10 + subprocess_option_search_user_path = 0x10, + + // Spawn the child in a new session/process group so that terminal signals + // (e.g. SIGINT from Ctrl+C) are not forwarded to the child process. + // On Unix this uses setsid(); on Windows it uses CREATE_NEW_PROCESS_GROUP. + subprocess_option_new_session = 0x20 }; #if defined(__cplusplus) @@ -499,6 +504,7 @@ int subprocess_create_ex(const char *const commandLine[], int options, const unsigned long startFUseStdHandles = 0x00000100; const unsigned long handleFlagInherit = 0x00000001; const unsigned long createNoWindow = 0x08000000; + const unsigned long createNewProcessGroup = 0x00000200; struct subprocess_subprocess_information_s processInfo; struct subprocess_security_attributes_s saAttr = {sizeof(saAttr), SUBPROCESS_NULL, 1}; @@ -529,6 +535,10 @@ int subprocess_create_ex(const char *const commandLine[], int options, flags |= createNoWindow; } + if (subprocess_option_new_session == (options & subprocess_option_new_session)) { + flags |= createNewProcessGroup; + } + if (subprocess_option_inherit_environment != (options & subprocess_option_inherit_environment)) { if (SUBPROCESS_NULL == environment) { @@ -873,22 +883,36 @@ int subprocess_create_ex(const char *const commandLine[], int options, #pragma clang diagnostic ignored "-Wcast-qual" #pragma clang diagnostic ignored "-Wold-style-cast" #endif + // Prepare spawn attributes if new session is requested + posix_spawnattr_t *attrp = SUBPROCESS_NULL; + posix_spawnattr_t attr; + if (subprocess_option_new_session == (options & subprocess_option_new_session)) { + posix_spawnattr_init(&attr); + short pflags = POSIX_SPAWN_SETSID; + posix_spawnattr_setflags(&attr, pflags); + attrp = &attr; + } + if (subprocess_option_search_user_path == (options & subprocess_option_search_user_path)) { - if (0 != posix_spawnp(&child, commandLine[0], &actions, SUBPROCESS_NULL, + if (0 != posix_spawnp(&child, commandLine[0], &actions, attrp, SUBPROCESS_CONST_CAST(char *const *, commandLine), used_environment)) { + if (attrp) posix_spawnattr_destroy(attrp); posix_spawn_file_actions_destroy(&actions); return -1; } } else { - if (0 != posix_spawn(&child, commandLine[0], &actions, SUBPROCESS_NULL, + if (0 != posix_spawn(&child, commandLine[0], &actions, attrp, SUBPROCESS_CONST_CAST(char *const *, commandLine), used_environment)) { + if (attrp) posix_spawnattr_destroy(attrp); posix_spawn_file_actions_destroy(&actions); return -1; } } + + if (attrp) posix_spawnattr_destroy(attrp); #ifdef __clang__ #pragma clang diagnostic pop #endif