mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-25 13:50:20 +00:00
60bc8866b1
* common: refactor models handling * remote preset * cont * rm skip_download option * missing header * fix plan.model_files * fix --offline case * move hf_plan to download * refactor * rm redundant curr_ex, add comments * adapt
2134 lines
78 KiB
C++
2134 lines
78 KiB
C++
#include "server-common.h"
|
|
#include "server-models.h"
|
|
#include "server-context.h"
|
|
|
|
#include "build-info.h"
|
|
#include "preset.h"
|
|
#include "download.h"
|
|
|
|
#include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h
|
|
#include <sheredom/subprocess.h>
|
|
|
|
#include <functional>
|
|
#include <optional>
|
|
#include <algorithm>
|
|
#include <thread>
|
|
#include <mutex>
|
|
#include <condition_variable>
|
|
#include <cstring>
|
|
#include <cstdlib>
|
|
#include <atomic>
|
|
#include <chrono>
|
|
#include <queue>
|
|
#include <filesystem>
|
|
#include <random>
|
|
#include <sstream>
|
|
#include <cstring>
|
|
|
|
#ifdef _WIN32
|
|
#include <winsock2.h>
|
|
#include <windows.h>
|
|
#else
|
|
#include <sys/socket.h>
|
|
#include <netinet/in.h>
|
|
#include <arpa/inet.h>
|
|
#include <unistd.h>
|
|
extern char **environ;
|
|
#endif
|
|
|
|
#if defined(__APPLE__) && defined(__MACH__)
|
|
// macOS: use _NSGetExecutablePath to get the executable path
|
|
#include <mach-o/dyld.h>
|
|
#include <limits.h>
|
|
#endif
|
|
|
|
#define DEFAULT_STOP_TIMEOUT 10 // seconds
|
|
|
|
#define CMD_ROUTER_TO_CHILD_EXIT "cmd_router_to_child:exit"
|
|
#define CMD_CHILD_TO_ROUTER_STATE "cmd_child_to_router:state:" // followed by json string
|
|
|
|
// address for child process, this is needed because router may run on 0.0.0.0
|
|
// ref: https://github.com/ggml-org/llama.cpp/issues/17862
|
|
#define CHILD_ADDR "127.0.0.1"
|
|
|
|
struct server_subproc {
|
|
std::optional<subprocess_s> sproc; // empty while in DOWNLOADING state
|
|
std::atomic<bool> stopped{false}; // set to cancel a download or signal child process exit
|
|
|
|
subprocess_s & get() {
|
|
GGML_ASSERT(sproc.has_value() && "subprocess not initialized");
|
|
return sproc.value();
|
|
}
|
|
|
|
bool is_alive() {
|
|
return sproc.has_value() && subprocess_alive(&sproc.value());
|
|
}
|
|
|
|
void request_exit() {
|
|
if (sproc.has_value()) {
|
|
FILE * stdin_file = subprocess_stdin(&sproc.value());
|
|
if (stdin_file) {
|
|
fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
|
|
fflush(stdin_file);
|
|
}
|
|
}
|
|
stopped.store(true, std::memory_order_relaxed);
|
|
}
|
|
|
|
void terminate() {
|
|
if (!sproc.has_value()) {
|
|
return;
|
|
}
|
|
#if defined(_WIN32)
|
|
if (sproc->hProcess == NULL) {
|
|
return;
|
|
}
|
|
#else
|
|
if (sproc->child <= 0) {
|
|
return;
|
|
}
|
|
#endif
|
|
subprocess_terminate(&sproc.value());
|
|
}
|
|
};
|
|
|
|
|
|
static std::filesystem::path get_server_exec_path() {
|
|
#if defined(_WIN32)
|
|
wchar_t buf[32768] = { 0 }; // Large buffer to handle long paths
|
|
DWORD len = GetModuleFileNameW(nullptr, buf, _countof(buf));
|
|
if (len == 0 || len >= _countof(buf)) {
|
|
throw std::runtime_error("GetModuleFileNameW failed or path too long");
|
|
}
|
|
return std::filesystem::path(buf);
|
|
#elif defined(__APPLE__) && defined(__MACH__)
|
|
char small_path[PATH_MAX];
|
|
uint32_t size = sizeof(small_path);
|
|
|
|
if (_NSGetExecutablePath(small_path, &size) == 0) {
|
|
// resolve any symlinks to get absolute path
|
|
try {
|
|
return std::filesystem::canonical(std::filesystem::path(small_path));
|
|
} catch (...) {
|
|
return std::filesystem::path(small_path);
|
|
}
|
|
} else {
|
|
// buffer was too small, allocate required size and call again
|
|
std::vector<char> buf(size);
|
|
if (_NSGetExecutablePath(buf.data(), &size) == 0) {
|
|
try {
|
|
return std::filesystem::canonical(std::filesystem::path(buf.data()));
|
|
} catch (...) {
|
|
return std::filesystem::path(buf.data());
|
|
}
|
|
}
|
|
throw std::runtime_error("_NSGetExecutablePath failed after buffer resize");
|
|
}
|
|
#else
|
|
char path[FILENAME_MAX];
|
|
ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX);
|
|
if (count <= 0) {
|
|
throw std::runtime_error("failed to resolve /proc/self/exe");
|
|
}
|
|
return std::filesystem::path(std::string(path, count));
|
|
#endif
|
|
}
|
|
|
|
static void unset_reserved_args(common_preset & preset, bool unset_model_args) {
|
|
preset.unset_option("LLAMA_ARG_SSL_KEY_FILE");
|
|
preset.unset_option("LLAMA_ARG_SSL_CERT_FILE");
|
|
preset.unset_option("LLAMA_API_KEY");
|
|
preset.unset_option("LLAMA_ARG_MODELS_DIR");
|
|
preset.unset_option("LLAMA_ARG_MODELS_MAX");
|
|
preset.unset_option("LLAMA_ARG_MODELS_PRESET");
|
|
preset.unset_option("LLAMA_ARG_MODELS_AUTOLOAD");
|
|
if (unset_model_args) {
|
|
preset.unset_option("LLAMA_ARG_MODEL");
|
|
preset.unset_option("LLAMA_ARG_MMPROJ");
|
|
preset.unset_option("LLAMA_ARG_ALIAS");
|
|
preset.unset_option("LLAMA_ARG_HF_REPO");
|
|
}
|
|
}
|
|
|
|
#ifdef _WIN32
|
|
static std::string wide_to_utf8(const wchar_t * ws) {
|
|
if (!ws || !*ws) {
|
|
return {};
|
|
}
|
|
|
|
const int len = static_cast<int>(std::wcslen(ws));
|
|
const int bytes = WideCharToMultiByte(CP_UTF8, 0, ws, len, nullptr, 0, nullptr, nullptr);
|
|
if (bytes == 0) {
|
|
return {};
|
|
}
|
|
|
|
std::string utf8(bytes, '\0');
|
|
WideCharToMultiByte(CP_UTF8, 0, ws, len, utf8.data(), bytes, nullptr, nullptr);
|
|
|
|
return utf8;
|
|
}
|
|
#endif
|
|
|
|
static std::vector<std::string> get_environment() {
|
|
std::vector<std::string> env;
|
|
|
|
#ifdef _WIN32
|
|
LPWCH env_block = GetEnvironmentStringsW();
|
|
if (!env_block) {
|
|
return env;
|
|
}
|
|
for (LPWCH e = env_block; *e; e += wcslen(e) + 1) {
|
|
env.emplace_back(wide_to_utf8(e));
|
|
}
|
|
FreeEnvironmentStringsW(env_block);
|
|
#else
|
|
if (environ == nullptr) {
|
|
return env;
|
|
}
|
|
for (char ** e = environ; *e != nullptr; e++) {
|
|
env.emplace_back(*e);
|
|
}
|
|
#endif
|
|
|
|
return env;
|
|
}
|
|
|
|
void server_model_meta::update_args(common_preset_context & ctx_preset, std::string bin_path) {
|
|
// update params
|
|
unset_reserved_args(preset, false);
|
|
preset.set_option(ctx_preset, "LLAMA_ARG_HOST", CHILD_ADDR);
|
|
preset.set_option(ctx_preset, "LLAMA_ARG_PORT", std::to_string(port));
|
|
preset.set_option(ctx_preset, "LLAMA_ARG_ALIAS", name);
|
|
// TODO: maybe validate preset before rendering ?
|
|
// render args
|
|
args = preset.to_args(bin_path);
|
|
|
|
// unified binary dispatches by subcommand, re-inject it right after the
|
|
// binary path so the child starts as 'llama serve ...' not 'llama ...'
|
|
const char * app_cmd = std::getenv("LLAMA_APP_CMD");
|
|
if (app_cmd != nullptr && app_cmd[0] != '\0' && !bin_path.empty()) {
|
|
args.insert(args.begin() + 1, app_cmd);
|
|
}
|
|
}
|
|
|
|
void server_model_meta::update_caps() {
|
|
try {
|
|
common_params params;
|
|
preset.apply_to_params(params, {
|
|
"LLAMA_ARG_MODEL",
|
|
"LLAMA_ARG_MODEL_URL",
|
|
"LLAMA_ARG_MMPROJ",
|
|
"LLAMA_ARG_MMPROJ_URL",
|
|
"LLAMA_ARG_HF_REPO",
|
|
"LLAMA_ARG_HF_REPO_FILE",
|
|
});
|
|
params.offline = true;
|
|
common_models_handler handler = common_models_handler_init(params, LLAMA_EXAMPLE_SERVER);
|
|
common_models_handler_apply(handler, params); // note: this won't download the model because offline=true
|
|
if (params.mmproj.path.empty()) {
|
|
multimodal = { false, false };
|
|
} else {
|
|
multimodal = mtmd_get_cap_from_file(params.mmproj.path.c_str());
|
|
}
|
|
} catch (const std::exception & e) {
|
|
LOG_WRN("failed to initialize common_params for multimodal capability detection: %s\n", e.what());
|
|
multimodal = { false, false };
|
|
}
|
|
}
|
|
|
|
//
|
|
// server_models
|
|
//
|
|
|
|
server_models::server_models(
|
|
const common_params & params,
|
|
int argc,
|
|
char ** argv)
|
|
: ctx_preset(LLAMA_EXAMPLE_SERVER),
|
|
base_params(params),
|
|
base_env(get_environment()),
|
|
base_preset(ctx_preset.load_from_args(argc, argv)) {
|
|
// clean up base preset
|
|
unset_reserved_args(base_preset, true);
|
|
// set binary path
|
|
try {
|
|
bin_path = get_server_exec_path().string();
|
|
} catch (const std::exception & e) {
|
|
bin_path = argv[0];
|
|
LOG_WRN("failed to get server executable path: %s\n", e.what());
|
|
LOG_WRN("using original argv[0] as fallback: %s\n", argv[0]);
|
|
}
|
|
load_models();
|
|
}
|
|
|
|
void server_models::add_model(server_model_meta && meta) {
|
|
if (mapping.find(meta.name) != mapping.end()) {
|
|
throw std::runtime_error(string_format("model '%s' appears multiple times", meta.name.c_str()));
|
|
}
|
|
|
|
// check model name does not conflict with existing aliases
|
|
for (const auto & [key, inst] : mapping) {
|
|
if (inst.meta.aliases.count(meta.name)) {
|
|
throw std::runtime_error(string_format("model name '%s' conflicts with alias of model '%s'",
|
|
meta.name.c_str(), key.c_str()));
|
|
}
|
|
}
|
|
|
|
// parse aliases from preset's --alias option (comma-separated)
|
|
std::string alias_str;
|
|
if (meta.preset.get_option("LLAMA_ARG_ALIAS", alias_str) && !alias_str.empty()) {
|
|
for (auto & alias : string_split<std::string>(alias_str, ',')) {
|
|
alias = string_strip(alias);
|
|
if (!alias.empty()) {
|
|
meta.aliases.insert(alias);
|
|
}
|
|
}
|
|
}
|
|
|
|
// parse tags from preset's --tags option (comma-separated)
|
|
std::string tags_str;
|
|
if (meta.preset.get_option("LLAMA_ARG_TAGS", tags_str) && !tags_str.empty()) {
|
|
for (auto & tag : string_split<std::string>(tags_str, ',')) {
|
|
tag = string_strip(tag);
|
|
if (!tag.empty()) {
|
|
meta.tags.insert(tag);
|
|
}
|
|
}
|
|
}
|
|
|
|
// validate aliases do not conflict with existing names or aliases
|
|
for (const auto & alias : meta.aliases) {
|
|
if (mapping.find(alias) != mapping.end()) {
|
|
throw std::runtime_error(string_format("alias '%s' for model '%s' conflicts with existing model name",
|
|
alias.c_str(), meta.name.c_str()));
|
|
}
|
|
for (const auto & [key, inst] : mapping) {
|
|
if (inst.meta.aliases.count(alias)) {
|
|
throw std::runtime_error(string_format("alias '%s' for model '%s' conflicts with alias of model '%s'",
|
|
alias.c_str(), meta.name.c_str(), key.c_str()));
|
|
}
|
|
}
|
|
}
|
|
|
|
meta.update_args(ctx_preset, bin_path); // render args
|
|
meta.update_caps();
|
|
std::string name = meta.name;
|
|
mapping[name] = instance_t{
|
|
/* subproc */ std::make_shared<server_subproc>(),
|
|
/* th */ std::thread(),
|
|
/* meta */ std::move(meta)
|
|
};
|
|
}
|
|
|
|
void server_models::notify_sse(const std::string & event, const std::string & model_id, const json & data) {
|
|
std::unique_ptr<server_task_result_router> result = std::make_unique<server_task_result_router>();
|
|
result->data = {
|
|
{"model", model_id},
|
|
{"event", event},
|
|
};
|
|
if (!data.is_null()) {
|
|
result->data["data"] = data;
|
|
}
|
|
SRV_DBG("notifying SSE clients about event '%s' for model '%s': %s\n", event.c_str(), model_id.c_str(), safe_json_to_str(result->data).c_str());
|
|
sse.broadcast(std::move(result));
|
|
}
|
|
|
|
void server_models::load_models() {
|
|
// Phase 1: load presets from all sources - pure I/O, no lock needed
|
|
// 1. cached models
|
|
common_presets cached_models = ctx_preset.load_from_cache();
|
|
SRV_INF("Loaded %zu cached model presets\n", cached_models.size());
|
|
// 2. local models from --models-dir
|
|
common_presets local_models;
|
|
if (!base_params.models_dir.empty()) {
|
|
local_models = ctx_preset.load_from_models_dir(base_params.models_dir);
|
|
SRV_INF("Loaded %zu local model presets from %s\n", local_models.size(), base_params.models_dir.c_str());
|
|
}
|
|
// 3. custom-path models from presets
|
|
common_preset global = {};
|
|
common_presets custom_presets = {};
|
|
if (!base_params.models_preset.empty()) {
|
|
custom_presets = ctx_preset.load_from_ini(base_params.models_preset, global);
|
|
SRV_INF("Loaded %zu custom model presets from %s\n", custom_presets.size(), base_params.models_preset.c_str());
|
|
}
|
|
|
|
// cascade, apply global preset first
|
|
cached_models = ctx_preset.cascade(global, cached_models);
|
|
local_models = ctx_preset.cascade(global, local_models);
|
|
custom_presets = ctx_preset.cascade(global, custom_presets);
|
|
|
|
// note: if a model exists in both cached and local, local takes precedence
|
|
common_presets final_presets;
|
|
std::unordered_map<std::string, server_model_source> source_map;
|
|
for (const auto & [name, preset] : cached_models) {
|
|
final_presets[name] = preset;
|
|
source_map[name] = SERVER_MODEL_SOURCE_CACHE;
|
|
}
|
|
for (const auto & [name, preset] : local_models) {
|
|
final_presets[name] = preset;
|
|
source_map[name] = SERVER_MODEL_SOURCE_MODELS_DIR;
|
|
}
|
|
for (const auto & [name, custom] : custom_presets) {
|
|
if (final_presets.find(name) != final_presets.end()) {
|
|
final_presets[name].merge(custom);
|
|
} else {
|
|
final_presets[name] = custom;
|
|
}
|
|
source_map[name] = SERVER_MODEL_SOURCE_PRESET;
|
|
}
|
|
|
|
// overlay router's own CLI args on top of every model preset so that
|
|
// e.g. `llama-server --temp 0` is honoured by all child processes
|
|
for (auto & [name, preset] : final_presets) {
|
|
preset.merge(base_preset);
|
|
}
|
|
|
|
auto get_source = [&](const std::string & name) {
|
|
return source_map.count(name) ? source_map.at(name) : SERVER_MODEL_SOURCE_PRESET;
|
|
};
|
|
|
|
// Helpers that read `mapping` - must be called while holding the lock.
|
|
std::unordered_set<std::string> custom_names;
|
|
for (const auto & [name, preset] : custom_presets) custom_names.insert(name);
|
|
auto join_set = [](const std::set<std::string> & s) {
|
|
std::string result;
|
|
for (const auto & v : s) {
|
|
if (!result.empty()) result += ", ";
|
|
result += v;
|
|
}
|
|
return result;
|
|
};
|
|
auto log_available_models = [&]() {
|
|
SRV_INF("Available models (%zu) (*: custom preset)\n", mapping.size());
|
|
for (const auto & [name, inst] : mapping) {
|
|
bool has_custom = custom_names.find(name) != custom_names.end();
|
|
std::string info;
|
|
if (!inst.meta.aliases.empty()) info += " (aliases: " + join_set(inst.meta.aliases) + ")";
|
|
if (!inst.meta.tags.empty()) info += " [tags: " + join_set(inst.meta.tags) + "]";
|
|
SRV_INF(" %c %s%s\n", has_custom ? '*' : ' ', name.c_str(), info.c_str());
|
|
}
|
|
};
|
|
auto apply_stop_timeout = [&]() {
|
|
for (auto & [name, inst] : mapping) {
|
|
std::string val;
|
|
if (inst.meta.preset.get_option(COMMON_ARG_PRESET_STOP_TIMEOUT, val)) {
|
|
try {
|
|
inst.meta.stop_timeout = std::stoi(val);
|
|
} catch (...) {
|
|
SRV_WRN("invalid stop-timeout value '%s' for model '%s', using default %d seconds\n",
|
|
val.c_str(), name.c_str(), DEFAULT_STOP_TIMEOUT);
|
|
inst.meta.stop_timeout = DEFAULT_STOP_TIMEOUT;
|
|
}
|
|
}
|
|
}
|
|
};
|
|
// update_args() injects HOST/PORT/ALIAS, so strip them before comparing presets
|
|
auto preset_options_for_compare = [](common_preset p) {
|
|
p.unset_option("LLAMA_ARG_HOST");
|
|
p.unset_option("LLAMA_ARG_PORT");
|
|
p.unset_option("LLAMA_ARG_ALIAS");
|
|
return p.options;
|
|
};
|
|
|
|
// Phase 2: acquire the lock once for all mapping mutations.
|
|
// We temporarily release it only when calling functions that acquire it internally
|
|
// (unload, load) or when joining threads (the monitoring thread calls update_status
|
|
// which locks the mutex, so joining while holding it would deadlock).
|
|
std::unique_lock<std::mutex> lk(mutex);
|
|
|
|
need_reload = false;
|
|
bool is_first_load = mapping.empty();
|
|
|
|
if (is_first_load) {
|
|
// FIRST LOAD: add all models, then unlock for autoloading
|
|
for (const auto & [name, preset] : final_presets) {
|
|
server_model_meta meta{
|
|
/* source */ get_source(name),
|
|
/* preset */ preset,
|
|
/* name */ name,
|
|
/* aliases */ {},
|
|
/* tags */ {},
|
|
/* port */ 0,
|
|
/* status */ SERVER_MODEL_STATUS_UNLOADED,
|
|
/* last_used */ 0,
|
|
/* args */ std::vector<std::string>(),
|
|
/* loaded_info */ {},
|
|
/* progress */ {},
|
|
/* exit_code */ 0,
|
|
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
|
|
/* multimodal */ mtmd_caps{false, false},
|
|
// /* need_download */ false,
|
|
};
|
|
add_model(std::move(meta));
|
|
}
|
|
apply_stop_timeout();
|
|
log_available_models();
|
|
|
|
std::vector<std::string> models_to_load;
|
|
for (const auto & [name, inst] : mapping) {
|
|
std::string val;
|
|
if (inst.meta.preset.get_option(COMMON_ARG_PRESET_LOAD_ON_STARTUP, val) && common_arg_utils::is_truthy(val)) {
|
|
models_to_load.push_back(name);
|
|
}
|
|
}
|
|
if ((int)models_to_load.size() > base_params.models_max) {
|
|
throw std::runtime_error(string_format(
|
|
"number of models to load on startup (%zu) exceeds models_max (%d)",
|
|
models_to_load.size(), base_params.models_max));
|
|
}
|
|
|
|
lk.unlock();
|
|
for (const auto & name : models_to_load) {
|
|
SRV_INF("(startup) loading model %s\n", name.c_str());
|
|
load(name);
|
|
}
|
|
} else {
|
|
// RELOAD: diff the new preset list against the current mapping and reconcile
|
|
is_reloading = true;
|
|
|
|
// find running models whose source was removed or whose preset changed
|
|
std::vector<std::string> to_unload;
|
|
for (const auto & [name, inst] : mapping) {
|
|
if (!inst.meta.is_running()) continue;
|
|
auto it = final_presets.find(name);
|
|
if (it == final_presets.end()) {
|
|
to_unload.push_back(name); // removed from source
|
|
} else if (preset_options_for_compare(inst.meta.preset) != preset_options_for_compare(it->second)) {
|
|
to_unload.push_back(name); // preset changed
|
|
}
|
|
}
|
|
|
|
// unload() acquires the lock internally, so release before each call
|
|
for (const auto & name : to_unload) {
|
|
SRV_INF("(reload) unloading model name=%s (source updated or removed)\n", name.c_str());
|
|
lk.unlock();
|
|
unload(name);
|
|
lk.lock();
|
|
}
|
|
|
|
// wait for all targeted models to reach UNLOADED; cv.wait handles unlock/relock
|
|
cv.wait(lk, [&]() {
|
|
for (const auto & name : to_unload) {
|
|
auto it = mapping.find(name);
|
|
if (it != mapping.end() && it->second.meta.is_running()) return false;
|
|
}
|
|
return true;
|
|
});
|
|
|
|
// collect all threads to join in one pass while the lock is held:
|
|
// - monitoring threads from just-unloaded models (to_unload)
|
|
// - threads of already-UNLOADED models that are being removed from source
|
|
std::vector<std::thread> threads_to_join;
|
|
for (const auto & name : to_unload) {
|
|
auto it = mapping.find(name);
|
|
if (it != mapping.end() && it->second.th.joinable()) {
|
|
threads_to_join.push_back(std::move(it->second.th));
|
|
}
|
|
}
|
|
for (auto & [name, inst] : mapping) {
|
|
if (inst.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
|
|
continue; // downloading models are not from config sources, leave them alone
|
|
}
|
|
if (final_presets.find(name) == final_presets.end() && !inst.meta.is_running() && inst.th.joinable()) {
|
|
threads_to_join.push_back(std::move(inst.th));
|
|
}
|
|
}
|
|
|
|
// join outside the lock - monitoring thread calls update_status (needs lock)
|
|
lk.unlock();
|
|
for (auto & th : threads_to_join) th.join();
|
|
lk.lock();
|
|
|
|
// erase models no longer in any source
|
|
for (auto it = mapping.begin(); it != mapping.end(); ) {
|
|
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
|
|
++it; // download thread is still busy, skip
|
|
} else if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADED) {
|
|
// download finished, safe to erase
|
|
if (it->second.th.joinable()) {
|
|
it->second.th.join();
|
|
}
|
|
it = mapping.erase(it);
|
|
} else if (final_presets.find(it->first) == final_presets.end()) {
|
|
SRV_INF("(reload) removing model name=%s (no longer in source)\n", it->first.c_str());
|
|
GGML_ASSERT(!it->second.th.joinable()); // must have been joined above
|
|
it = mapping.erase(it);
|
|
} else {
|
|
++it;
|
|
}
|
|
}
|
|
|
|
// update presets for non-running models still in source
|
|
for (auto & [name, inst] : mapping) {
|
|
if (inst.meta.is_running()) continue;
|
|
auto it = final_presets.find(name);
|
|
if (it == final_presets.end()) continue; // erased above
|
|
|
|
inst.meta.preset = it->second;
|
|
|
|
// re-parse aliases, then validate against other models
|
|
std::set<std::string> new_aliases;
|
|
std::string alias_str;
|
|
if (inst.meta.preset.get_option("LLAMA_ARG_ALIAS", alias_str) && !alias_str.empty()) {
|
|
for (auto & alias : string_split<std::string>(alias_str, ',')) {
|
|
alias = string_strip(alias);
|
|
if (!alias.empty()) new_aliases.insert(alias);
|
|
}
|
|
}
|
|
inst.meta.aliases.clear();
|
|
for (const auto & alias : new_aliases) {
|
|
bool conflict = false;
|
|
for (const auto & [other_name, other_inst] : mapping) {
|
|
if (other_name == name) continue;
|
|
if (other_name == alias || other_inst.meta.aliases.count(alias)) {
|
|
SRV_WRN("(reload) alias '%s' for model '%s' conflicts with model '%s', skipping\n",
|
|
alias.c_str(), name.c_str(), other_name.c_str());
|
|
conflict = true;
|
|
break;
|
|
}
|
|
}
|
|
if (!conflict) inst.meta.aliases.insert(alias);
|
|
}
|
|
|
|
// re-parse tags
|
|
inst.meta.tags.clear();
|
|
std::string tags_str;
|
|
if (inst.meta.preset.get_option("LLAMA_ARG_TAGS", tags_str) && !tags_str.empty()) {
|
|
for (auto & tag : string_split<std::string>(tags_str, ',')) {
|
|
tag = string_strip(tag);
|
|
if (!tag.empty()) inst.meta.tags.insert(tag);
|
|
}
|
|
}
|
|
|
|
inst.meta.exit_code = 0; // clear failed state so the model can be reloaded
|
|
inst.meta.update_args(ctx_preset, bin_path);
|
|
inst.meta.update_caps();
|
|
}
|
|
|
|
// add models that are new in this reload
|
|
std::vector<std::string> newly_added;
|
|
for (const auto & [name, preset] : final_presets) {
|
|
if (mapping.find(name) == mapping.end()) {
|
|
server_model_meta meta{
|
|
/* source */ get_source(name),
|
|
/* preset */ preset,
|
|
/* name */ name,
|
|
/* aliases */ {},
|
|
/* tags */ {},
|
|
/* port */ 0,
|
|
/* status */ SERVER_MODEL_STATUS_UNLOADED,
|
|
/* last_used */ 0,
|
|
/* args */ std::vector<std::string>(),
|
|
/* loaded_info */ {},
|
|
/* progress */ {},
|
|
/* exit_code */ 0,
|
|
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
|
|
/* multimodal */ mtmd_caps{false, false},
|
|
// /* need_download */ false,
|
|
};
|
|
add_model(std::move(meta));
|
|
newly_added.push_back(name);
|
|
}
|
|
}
|
|
|
|
apply_stop_timeout();
|
|
|
|
// clear reload flag before unlocking for autoload - load() blocks on !is_reloading,
|
|
// so clearing it here (while still locked) prevents a deadlock in the autoload calls below
|
|
is_reloading = false;
|
|
cv.notify_all();
|
|
|
|
log_available_models();
|
|
|
|
// collect autoload candidates while still under the lock
|
|
std::vector<std::string> to_autoload;
|
|
for (const auto & name : newly_added) {
|
|
auto it = mapping.find(name);
|
|
if (it != mapping.end()) {
|
|
std::string val;
|
|
if (it->second.meta.preset.get_option(COMMON_ARG_PRESET_LOAD_ON_STARTUP, val) && common_arg_utils::is_truthy(val)) {
|
|
to_autoload.push_back(name);
|
|
}
|
|
}
|
|
}
|
|
|
|
lk.unlock();
|
|
for (const auto & name : to_autoload) {
|
|
SRV_INF("(reload) loading new model %s\n", name.c_str());
|
|
load(name);
|
|
}
|
|
|
|
notify_sse("models_reload", "*");
|
|
}
|
|
}
|
|
|
|
void server_models::update_meta(const std::string & name, const server_model_meta & meta) {
|
|
std::lock_guard<std::mutex> lk(mutex);
|
|
auto it = mapping.find(name);
|
|
if (it != mapping.end()) {
|
|
it->second.meta = meta;
|
|
}
|
|
cv.notify_all(); // notify wait_until_loading_finished
|
|
}
|
|
|
|
bool server_models::has_model(const std::string & name) {
|
|
std::lock_guard<std::mutex> lk(mutex);
|
|
if (mapping.find(name) != mapping.end()) {
|
|
return true;
|
|
}
|
|
for (const auto & [key, inst] : mapping) {
|
|
if (inst.meta.aliases.count(name)) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
std::optional<server_model_meta> server_models::get_meta(const std::string & name) {
|
|
std::unique_lock<std::mutex> lk(mutex);
|
|
if (need_reload) {
|
|
lk.unlock();
|
|
load_models();
|
|
lk.lock();
|
|
}
|
|
|
|
auto it = mapping.find(name);
|
|
if (it != mapping.end()) {
|
|
return it->second.meta;
|
|
}
|
|
for (const auto & [key, inst] : mapping) {
|
|
if (inst.meta.aliases.count(name)) {
|
|
return inst.meta;
|
|
}
|
|
}
|
|
return std::nullopt;
|
|
}
|
|
|
|
static int get_free_port() {
|
|
#ifdef _WIN32
|
|
WSADATA wsaData;
|
|
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
|
|
return -1;
|
|
}
|
|
typedef SOCKET native_socket_t;
|
|
#define INVALID_SOCKET_VAL INVALID_SOCKET
|
|
#define CLOSE_SOCKET(s) closesocket(s)
|
|
#else
|
|
typedef int native_socket_t;
|
|
#define INVALID_SOCKET_VAL -1
|
|
#define CLOSE_SOCKET(s) close(s)
|
|
#endif
|
|
|
|
native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
|
|
if (sock == INVALID_SOCKET_VAL) {
|
|
#ifdef _WIN32
|
|
WSACleanup();
|
|
#endif
|
|
return -1;
|
|
}
|
|
|
|
struct sockaddr_in serv_addr;
|
|
std::memset(&serv_addr, 0, sizeof(serv_addr));
|
|
serv_addr.sin_family = AF_INET;
|
|
serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
|
|
serv_addr.sin_port = htons(0);
|
|
|
|
if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) {
|
|
CLOSE_SOCKET(sock);
|
|
#ifdef _WIN32
|
|
WSACleanup();
|
|
#endif
|
|
return -1;
|
|
}
|
|
|
|
#ifdef _WIN32
|
|
int namelen = sizeof(serv_addr);
|
|
#else
|
|
socklen_t namelen = sizeof(serv_addr);
|
|
#endif
|
|
if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) {
|
|
CLOSE_SOCKET(sock);
|
|
#ifdef _WIN32
|
|
WSACleanup();
|
|
#endif
|
|
return -1;
|
|
}
|
|
|
|
int port = ntohs(serv_addr.sin_port);
|
|
|
|
CLOSE_SOCKET(sock);
|
|
#ifdef _WIN32
|
|
WSACleanup();
|
|
#endif
|
|
|
|
return port;
|
|
}
|
|
|
|
// helper to convert vector<string> to char **
|
|
// pointers are only valid as long as the original vector is valid
|
|
static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & vec) {
|
|
std::vector<char *> result;
|
|
result.reserve(vec.size() + 1);
|
|
for (const auto & s : vec) {
|
|
result.push_back(const_cast<char*>(s.c_str()));
|
|
}
|
|
result.push_back(nullptr);
|
|
return result;
|
|
}
|
|
|
|
std::vector<server_model_meta> server_models::get_all_meta() {
|
|
std::unique_lock<std::mutex> lk(mutex);
|
|
if (need_reload) {
|
|
lk.unlock();
|
|
load_models();
|
|
lk.lock();
|
|
}
|
|
|
|
std::vector<server_model_meta> result;
|
|
result.reserve(mapping.size());
|
|
for (const auto & [name, inst] : mapping) {
|
|
result.push_back(inst.meta);
|
|
}
|
|
return result;
|
|
}
|
|
|
|
void server_models::unload_lru() {
|
|
if (base_params.models_max <= 0) {
|
|
return; // no limit
|
|
}
|
|
// remove one of the servers if we passed the models_max (least recently used - LRU)
|
|
std::string lru_model_name = "";
|
|
int64_t lru_last_used = ggml_time_ms();
|
|
size_t count_active = 0;
|
|
{
|
|
std::unique_lock<std::mutex> lk(mutex);
|
|
for (const auto & m : mapping) {
|
|
if (m.second.meta.is_running()) {
|
|
count_active++;
|
|
if (m.second.meta.last_used < lru_last_used) {
|
|
lru_model_name = m.first;
|
|
lru_last_used = m.second.meta.last_used;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if (!lru_model_name.empty() && count_active >= (size_t)base_params.models_max) {
|
|
SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str());
|
|
unload(lru_model_name);
|
|
// wait for unload to complete
|
|
{
|
|
std::unique_lock<std::mutex> lk(mutex);
|
|
cv.wait(lk, [this, &lru_model_name]() {
|
|
return mapping[lru_model_name].meta.status == SERVER_MODEL_STATUS_UNLOADED;
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
void server_models::load(const std::string & name) {
|
|
load(name, load_options{});
|
|
}
|
|
|
|
void server_models::load(const std::string & name, const load_options & opts) {
|
|
if (!opts.custom_meta.has_value()) {
|
|
if (!has_model(name)) {
|
|
throw std::runtime_error("model name=" + name + " is not found");
|
|
}
|
|
unload_lru();
|
|
}
|
|
|
|
std::unique_lock<std::mutex> lk(mutex);
|
|
// edge case: block until any in-progress reload has finished so we always load
|
|
// against the freshest preset and a consistent mapping state
|
|
cv.wait(lk, [this]() { return !is_reloading; });
|
|
|
|
auto meta = opts.custom_meta.has_value() ? *opts.custom_meta : mapping[name].meta;
|
|
if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
|
|
SRV_INF("model %s is not ready\n", name.c_str());
|
|
return;
|
|
}
|
|
|
|
// Re-check capacity under the lock to prevent concurrent loads from
|
|
// exceeding models_max. Without this, the window between unload_lru()
|
|
// releasing its lock and this lock_guard acquiring allows multiple
|
|
// threads to each observe capacity and all proceed to load.
|
|
if (base_params.models_max > 0) {
|
|
size_t count_active = 0;
|
|
for (const auto & m : mapping) {
|
|
if (m.second.meta.is_running()) {
|
|
count_active++;
|
|
}
|
|
}
|
|
if (count_active >= (size_t)base_params.models_max) {
|
|
throw std::runtime_error("model limit reached, try again later");
|
|
}
|
|
}
|
|
|
|
// prepare new instance info
|
|
instance_t inst;
|
|
inst.meta = meta;
|
|
inst.meta.port = get_free_port();
|
|
inst.meta.status = SERVER_MODEL_STATUS_LOADING;
|
|
inst.meta.loaded_info = json{};
|
|
inst.meta.last_used = ggml_time_ms();
|
|
|
|
if (inst.meta.port <= 0) {
|
|
throw std::runtime_error("failed to get a port number");
|
|
}
|
|
|
|
inst.subproc = std::make_shared<server_subproc>();
|
|
{
|
|
SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port);
|
|
|
|
inst.meta.update_args(ctx_preset, bin_path); // render args
|
|
|
|
std::vector<std::string> child_args = inst.meta.args; // copy
|
|
std::vector<std::string> child_env = base_env; // copy
|
|
child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));
|
|
|
|
if (opts.mode == SERVER_CHILD_MODE_DOWNLOAD) {
|
|
inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING;
|
|
child_env.push_back("LLAMA_SERVER_CHILD_MODE=download");
|
|
child_env.push_back("LLAMA_ARG_HF_REPO=" + name);
|
|
}
|
|
|
|
SRV_INF("%s", "spawning server instance with args:\n");
|
|
for (const auto & arg : child_args) {
|
|
SRV_INF(" %s\n", arg.c_str());
|
|
}
|
|
inst.meta.args = child_args; // save for debugging
|
|
|
|
std::vector<char *> argv = to_char_ptr_array(child_args);
|
|
std::vector<char *> envp = to_char_ptr_array(child_env);
|
|
|
|
// TODO @ngxson : maybe separate stdout and stderr in the future
|
|
// so that we can use stdout for commands and stderr for logging
|
|
int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr;
|
|
inst.subproc->sproc.emplace();
|
|
int result = subprocess_create_ex(argv.data(), options, envp.data(), &inst.subproc->get());
|
|
if (result != 0) {
|
|
throw std::runtime_error("failed to spawn server instance");
|
|
}
|
|
}
|
|
|
|
// start a thread to manage the child process
|
|
// captured variables are guaranteed to be destroyed only after the thread is joined
|
|
inst.th = std::thread([
|
|
this, name,
|
|
child_proc = inst.subproc,
|
|
port = inst.meta.port,
|
|
stop_timeout = inst.meta.stop_timeout,
|
|
child_mode = opts.mode
|
|
]() {
|
|
FILE * stdin_file = subprocess_stdin(&child_proc->get());
|
|
FILE * stdout_file = subprocess_stdout(&child_proc->get()); // combined stdout/stderr
|
|
|
|
std::thread log_thread([&]() {
|
|
// read stdout/stderr and forward to main server log
|
|
// also handle status report from child process
|
|
std::vector<char> vec_buf(128 * 1024); // large buffer for storing info
|
|
char * buffer = vec_buf.data();
|
|
if (stdout_file) {
|
|
while (fgets(buffer, vec_buf.size(), stdout_file) != nullptr) {
|
|
LOG("[%5d] %s", port, buffer);
|
|
std::string str(buffer);
|
|
if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_STATE)) {
|
|
this->handle_child_state(name, str);
|
|
}
|
|
}
|
|
} else {
|
|
SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str());
|
|
}
|
|
});
|
|
|
|
std::thread stopping_thread([&]() {
|
|
// thread to monitor explicit stop requests; child crash is signalled via child_proc->stopped
|
|
auto is_stopping = [this, &name]() {
|
|
return this->stopping_models.find(name) != this->stopping_models.end();
|
|
};
|
|
{
|
|
std::unique_lock<std::mutex> lk(this->mutex);
|
|
this->cv_stop.wait(lk, [&]() {
|
|
return is_stopping() || child_proc->stopped.load(std::memory_order_acquire);
|
|
});
|
|
}
|
|
// child crashed or finished on its own, skip graceful shutdown sequence
|
|
if (child_proc->stopped.load(std::memory_order_acquire)) {
|
|
return;
|
|
}
|
|
SRV_INF("stopping model instance name=%s\n", name.c_str());
|
|
fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
|
|
fflush(stdin_file);
|
|
int64_t start_time = ggml_time_ms();
|
|
while (true) {
|
|
std::unique_lock<std::mutex> lk(this->mutex);
|
|
if (!is_stopping() || child_proc->stopped.load(std::memory_order_acquire)) {
|
|
return;
|
|
}
|
|
int64_t elapsed = ggml_time_ms() - start_time;
|
|
if (elapsed >= stop_timeout * 1000) {
|
|
lk.unlock();
|
|
SRV_WRN("force-killing model instance name=%s after %d seconds timeout\n", name.c_str(), stop_timeout);
|
|
child_proc->terminate();
|
|
return;
|
|
}
|
|
this->cv_stop.wait_for(lk, std::chrono::seconds(1), [&]() {
|
|
return !is_stopping() || child_proc->stopped.load(std::memory_order_acquire);
|
|
});
|
|
}
|
|
});
|
|
|
|
// we reach here when the child process exits (stdout EOF)
|
|
// note: we cannot join() prior to this point because it will close stdin_file
|
|
if (log_thread.joinable()) {
|
|
log_thread.join();
|
|
}
|
|
|
|
child_proc->stopped.store(true, std::memory_order_release);
|
|
{
|
|
std::lock_guard<std::mutex> lk(this->mutex);
|
|
stopping_models.erase(name);
|
|
cv_stop.notify_all();
|
|
}
|
|
if (stopping_thread.joinable()) {
|
|
stopping_thread.join();
|
|
}
|
|
|
|
// get the exit code
|
|
int exit_code = 0;
|
|
subprocess_join(&child_proc->get(), &exit_code);
|
|
subprocess_destroy(&child_proc->get());
|
|
|
|
// update status and exit code
|
|
if (child_mode == SERVER_CHILD_MODE_DOWNLOAD) {
|
|
// instance will be cleaned up on next load_models() call
|
|
} else {
|
|
this->update_status(name, {
|
|
SERVER_MODEL_STATUS_UNLOADED,
|
|
exit_code
|
|
});
|
|
}
|
|
SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code);
|
|
});
|
|
|
|
// clean up old process/thread if exists
|
|
{
|
|
auto & old_instance = mapping[name];
|
|
// old process should have exited already, but just in case, we clean it up here
|
|
if (old_instance.subproc && old_instance.subproc->is_alive()) {
|
|
SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
|
|
old_instance.subproc->terminate(); // force kill
|
|
}
|
|
if (old_instance.th.joinable()) {
|
|
old_instance.th.join();
|
|
}
|
|
}
|
|
|
|
notify_sse("model_status", name, {
|
|
{"status", server_model_status_to_string(inst.meta.status)},
|
|
});
|
|
|
|
mapping[name] = std::move(inst);
|
|
cv.notify_all();
|
|
}
|
|
|
|
void server_models::unload(const std::string & name) {
|
|
std::unique_lock<std::mutex> lk(mutex);
|
|
auto it = mapping.find(name);
|
|
if (it != mapping.end()) {
|
|
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
|
|
SRV_INF("cancelling download for model name=%s\n", name.c_str());
|
|
it->second.subproc->request_exit();
|
|
// for convenience, we wait the status change here
|
|
wait(lk, name, [](const server_model_meta & new_meta) {
|
|
return new_meta.status != SERVER_MODEL_STATUS_DOWNLOADING;
|
|
});
|
|
} else if (it->second.meta.is_running()) {
|
|
SRV_INF("stopping model instance name=%s\n", name.c_str());
|
|
stopping_models.insert(name);
|
|
if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) {
|
|
// special case: if model is in loading state, unloading means force-killing it
|
|
SRV_WRN("model name=%s is still loading, force-killing\n", name.c_str());
|
|
it->second.subproc->terminate();
|
|
}
|
|
cv_stop.notify_all();
|
|
// status change will be handled by the managing thread
|
|
} else {
|
|
SRV_WRN("model instance name=%s is not running\n", name.c_str());
|
|
}
|
|
}
|
|
}
|
|
|
|
void server_models::unload_all() {
|
|
std::vector<std::thread> to_join;
|
|
{
|
|
std::lock_guard<std::mutex> lk(mutex);
|
|
for (auto & [name, inst] : mapping) {
|
|
if (inst.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
|
|
SRV_INF("cancelling download for model name=%s\n", name.c_str());
|
|
inst.subproc->stopped.store(true, std::memory_order_relaxed);
|
|
} else if (inst.meta.is_running()) {
|
|
SRV_INF("stopping model instance name=%s\n", name.c_str());
|
|
stopping_models.insert(name);
|
|
cv_stop.notify_all();
|
|
// status change will be handled by the managing thread
|
|
}
|
|
// moving the thread to join list to avoid deadlock
|
|
to_join.push_back(std::move(inst.th));
|
|
}
|
|
}
|
|
for (auto & th : to_join) {
|
|
if (th.joinable()) {
|
|
th.join();
|
|
}
|
|
}
|
|
}
|
|
|
|
void server_models::update_status(const std::string & name, const update_status_args & args) {
|
|
std::unique_lock<std::mutex> lk(mutex);
|
|
auto it = mapping.find(name);
|
|
if (it != mapping.end()) {
|
|
auto & meta = it->second.meta;
|
|
meta.status = args.status;
|
|
meta.exit_code = args.exit_code;
|
|
if (!args.loaded_info.is_null()) {
|
|
meta.loaded_info = args.loaded_info;
|
|
}
|
|
if (!args.progress.is_null()) {
|
|
meta.progress = args.progress;
|
|
}
|
|
}
|
|
// broadcast status change to SSE
|
|
{
|
|
json data = {
|
|
{"status", server_model_status_to_string(args.status)},
|
|
};
|
|
if (args.status == SERVER_MODEL_STATUS_UNLOADED) {
|
|
data["exit_code"] = args.exit_code;
|
|
}
|
|
if (!args.loaded_info.is_null()) {
|
|
data["info"] = args.loaded_info;
|
|
}
|
|
if (!args.progress.is_null()) {
|
|
data["progress"] = args.progress;
|
|
}
|
|
// note: notify_sse doesn't acquire the lock, so no deadlock here
|
|
notify_sse("status_change", name, data);
|
|
}
|
|
cv.notify_all();
|
|
}
|
|
|
|
void server_models::update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok) {
|
|
json curr;
|
|
{
|
|
std::lock_guard<std::mutex> lk(mutex);
|
|
auto it = mapping.find(name);
|
|
if (it != mapping.end()) {
|
|
if (done) {
|
|
// mark the instance to be erased on next load_models() call
|
|
it->second.meta.status = SERVER_MODEL_STATUS_DOWNLOADED;
|
|
need_reload = true;
|
|
} else {
|
|
json & info = it->second.meta.loaded_info;
|
|
if (!info.contains("progress")) {
|
|
info["progress"] = json{};
|
|
}
|
|
info["progress"][progress.url] = {
|
|
{"done", progress.downloaded},
|
|
{"total", progress.total},
|
|
};
|
|
curr = it->second.meta.loaded_info; // copy
|
|
}
|
|
}
|
|
}
|
|
if (done) {
|
|
cv.notify_all(); // notify in case unload() is waiting for download to be cancelled
|
|
notify_sse(ok ? "download_finished" : "download_failed", name, {});
|
|
} else {
|
|
notify_sse("download_progress", name, curr);
|
|
}
|
|
}
|
|
|
|
bool server_models::remove(const std::string & name) {
|
|
// do everything under one lock acquisition; avoid get_meta() /
|
|
// unload() because they can trigger load_models() which erases
|
|
// transient DOWNLOADING / DOWNLOADED entries as a side-effect
|
|
std::unique_lock<std::mutex> lk(mutex);
|
|
|
|
auto it = mapping.find(name);
|
|
if (it == mapping.end()) {
|
|
throw std::runtime_error("model name=" + name + " is not found");
|
|
}
|
|
if (it->second.meta.source != SERVER_MODEL_SOURCE_CACHE) {
|
|
throw std::runtime_error("model name=" + name + " is not removable (not from cache)");
|
|
}
|
|
|
|
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
|
|
// cancel in-flight download
|
|
SRV_INF("cancelling download for model name=%s\n", name.c_str());
|
|
it->second.subproc->request_exit();
|
|
} else if (it->second.meta.is_running()) {
|
|
// stop running instance
|
|
SRV_INF("stopping model instance name=%s\n", name.c_str());
|
|
stopping_models.insert(name);
|
|
if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) {
|
|
it->second.subproc->terminate();
|
|
}
|
|
cv_stop.notify_all();
|
|
}
|
|
|
|
// wait until the monitoring thread finishes
|
|
wait(lk, name, [](const server_model_meta & meta) {
|
|
return meta.status == SERVER_MODEL_STATUS_UNLOADED
|
|
|| meta.status == SERVER_MODEL_STATUS_DOWNLOADED;
|
|
});
|
|
|
|
// re-find after wait - load_models() may have erased the entry during the wait
|
|
it = mapping.find(name);
|
|
if (it == mapping.end()) {
|
|
// load_models() already joined the thread and erased the entry;
|
|
// we just need to clean up the cached files on disk
|
|
lk.unlock();
|
|
bool ok = common_download_remove(name);
|
|
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial");
|
|
notify_sse("model_remove", name, {});
|
|
return true;
|
|
}
|
|
|
|
// join before erasing - thread no longer acquires this mutex
|
|
if (it->second.th.joinable()) {
|
|
it->second.th.join();
|
|
}
|
|
|
|
// remove from disk (best-effort: cancelled downloads may have no cached files)
|
|
bool ok = common_download_remove(name);
|
|
mapping.erase(name);
|
|
if (!ok) {
|
|
SRV_WRN("removing model name=%s from disk returned false (no cached files?)\n", name.c_str());
|
|
}
|
|
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial");
|
|
notify_sse("model_remove", name, {});
|
|
return true;
|
|
}
|
|
|
|
void server_models::wait(const std::string & name, std::function<bool(const server_model_meta &)> predicate) {
|
|
std::unique_lock<std::mutex> lk(mutex);
|
|
wait(lk, name, predicate);
|
|
}
|
|
|
|
void server_models::wait(std::unique_lock<std::mutex> & lk, const std::string & name, std::function<bool(const server_model_meta &)> predicate) {
|
|
cv.wait(lk, [this, &name, &predicate]() {
|
|
auto it = mapping.find(name);
|
|
if (it != mapping.end()) {
|
|
return predicate(it->second.meta);
|
|
|
|
}
|
|
// model was removed from mapping by another code path (e.g. load_models()).
|
|
// nothing left to wait for - tell the caller to proceed.
|
|
return true;
|
|
});
|
|
}
|
|
|
|
bool server_models::ensure_model_ready(const std::string & name) {
|
|
auto meta = get_meta(name);
|
|
if (!meta.has_value()) {
|
|
throw std::runtime_error("model name=" + name + " is not found");
|
|
}
|
|
if (meta->is_ready()) {
|
|
return false; // ready for taking requests
|
|
}
|
|
if (meta->status == SERVER_MODEL_STATUS_SLEEPING) {
|
|
return false; // child is sleeping but still running; new request will wake it up
|
|
}
|
|
if (meta->status == SERVER_MODEL_STATUS_UNLOADED) {
|
|
SRV_INF("model name=%s is not loaded, loading...\n", name.c_str());
|
|
load(name);
|
|
}
|
|
|
|
// wait for loading to complete
|
|
SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str());
|
|
wait(name, [&meta](const server_model_meta & new_meta) {
|
|
if (new_meta.status != SERVER_MODEL_STATUS_LOADING) {
|
|
meta = new_meta; // update meta for final check after wait
|
|
return true;
|
|
}
|
|
return false;
|
|
});
|
|
|
|
// check final status
|
|
if (!meta.has_value() || meta->is_failed()) {
|
|
throw std::runtime_error("model name=" + name + " failed to load");
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
server_http_res_ptr server_models::proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used) {
|
|
auto meta = get_meta(name);
|
|
if (!meta.has_value()) {
|
|
throw std::runtime_error("model name=" + name + " is not found");
|
|
}
|
|
if (!meta->is_running()) {
|
|
throw std::invalid_argument("model name=" + name + " is not running");
|
|
}
|
|
if (update_last_used) {
|
|
std::unique_lock<std::mutex> lk(mutex);
|
|
mapping[name].meta.last_used = ggml_time_ms();
|
|
}
|
|
SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port);
|
|
std::string proxy_path = req.path;
|
|
if (!req.query_string.empty()) {
|
|
proxy_path += '?' + req.query_string;
|
|
}
|
|
auto proxy = std::make_unique<server_http_proxy>(
|
|
method,
|
|
"http",
|
|
CHILD_ADDR,
|
|
meta->port,
|
|
proxy_path,
|
|
req.headers,
|
|
req.body,
|
|
req.files,
|
|
req.should_stop,
|
|
base_params.timeout_read,
|
|
base_params.timeout_write
|
|
);
|
|
return proxy;
|
|
}
|
|
|
|
void server_models::handle_child_state(const std::string & name, const std::string & raw_input) {
|
|
server_state state;
|
|
json payload;
|
|
|
|
try {
|
|
json data = json::parse(raw_input.substr(strlen(CMD_CHILD_TO_ROUTER_STATE)));
|
|
state = server_state_from_str(json_value(data, "state", std::string()));
|
|
payload = json_value(data, "payload", json{});
|
|
} catch (const std::exception & e) {
|
|
SRV_ERR("failed to parse child state update for name=%s: %s\n", name.c_str(), e.what());
|
|
return;
|
|
}
|
|
|
|
switch (state) {
|
|
case SERVER_STATE_DOWNLOADING:
|
|
{
|
|
std::string result = json_value(payload, "result", std::string());
|
|
std::string url = json_value(payload, "url", std::string());
|
|
auto request_exit = [&]() {
|
|
std::lock_guard<std::mutex> lk(mutex);
|
|
auto it = mapping.find(name);
|
|
if (it != mapping.end()) {
|
|
return it->second.subproc->request_exit();
|
|
}
|
|
};
|
|
if (result == "download_finished") {
|
|
update_download_progress(name, {}, true, true);
|
|
request_exit();
|
|
} else if (result == "download_failed") {
|
|
update_download_progress(name, {}, true, false);
|
|
request_exit();
|
|
} else if (!url.empty()) {
|
|
common_download_progress p;
|
|
p.url = url;
|
|
p.downloaded = json_value(payload, "downloaded", (size_t)0);
|
|
p.total = json_value(payload, "total", (size_t)0);
|
|
update_download_progress(name, p, false);
|
|
}
|
|
} break;
|
|
case SERVER_STATE_LOADING:
|
|
{
|
|
update_status(name, {
|
|
SERVER_MODEL_STATUS_LOADING,
|
|
0,
|
|
nullptr, // no loaded_info yet
|
|
payload,
|
|
});
|
|
} break;
|
|
case SERVER_STATE_READY:
|
|
{
|
|
update_status(name, {
|
|
SERVER_MODEL_STATUS_LOADED,
|
|
0,
|
|
// note: payload can be empty if this is a wakeup from sleep
|
|
payload.size() > 0 ? payload : nullptr,
|
|
{}, // reset progress info
|
|
});
|
|
} break;
|
|
case SERVER_STATE_SLEEPING:
|
|
{
|
|
update_status(name, { SERVER_MODEL_STATUS_SLEEPING });
|
|
} break;
|
|
default:
|
|
// should never happen, but just in case
|
|
GGML_ASSERT(false && "unexpected state from child server");
|
|
}
|
|
}
|
|
|
|
//
|
|
// server_child
|
|
//
|
|
|
|
bool server_child::is_child() {
|
|
const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT");
|
|
return router_port != nullptr;
|
|
}
|
|
|
|
server_child_mode server_child::get_mode() {
|
|
const char * mode = std::getenv("LLAMA_SERVER_CHILD_MODE");
|
|
std::string mode_str(mode ? mode : "");
|
|
if (mode_str == "download") {
|
|
return SERVER_CHILD_MODE_DOWNLOAD;
|
|
} else {
|
|
return SERVER_CHILD_MODE_NORMAL;
|
|
}
|
|
}
|
|
|
|
struct server_download_state : public common_download_callback {
|
|
server_child * self;
|
|
std::function<bool()> should_stop;
|
|
std::atomic<int64_t> last_progress_time{0}; // multiple files downloading in different threads
|
|
bool is_ok = false;
|
|
|
|
server_download_state(server_child * s) : self(s) {}
|
|
|
|
bool run(common_params & params) {
|
|
try {
|
|
common_models_handler handler = common_models_handler_init(params, LLAMA_EXAMPLE_SERVER);
|
|
common_models_handler_apply(handler, params, this);
|
|
is_ok = true;
|
|
} catch (const std::exception & e) {
|
|
auto model_name = params.model.get_name();
|
|
SRV_ERR("download failed for model name=%s: %s\n", model_name.c_str(), e.what());
|
|
is_ok = false;
|
|
}
|
|
return is_ok;
|
|
}
|
|
void on_progress(const common_download_progress & p) {
|
|
json data = {
|
|
{"url", p.url},
|
|
{"downloaded", p.downloaded},
|
|
{"total", p.total},
|
|
};
|
|
self->notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), data);
|
|
}
|
|
void on_start(const common_download_progress & p) override {
|
|
on_progress(p);
|
|
}
|
|
void on_update(const common_download_progress & p) override {
|
|
int64_t now = ggml_time_ms();
|
|
// throttle progress updates to avoid flooding logs
|
|
if (now - last_progress_time.load(std::memory_order_relaxed) >= 100) {
|
|
on_progress(p);
|
|
last_progress_time.store(now, std::memory_order_relaxed);
|
|
}
|
|
}
|
|
void on_done(const common_download_progress & p, bool) override {
|
|
on_progress(p);
|
|
}
|
|
bool is_cancelled() const override {
|
|
return should_stop ? should_stop() : false;
|
|
}
|
|
};
|
|
|
|
int server_child::run_download(common_params & params) {
|
|
auto cancelled = std::make_shared<std::atomic<bool>>(false);
|
|
|
|
// monitor stdin for cancellation command from the router
|
|
std::thread signal_thread = setup([cancelled](int) {
|
|
cancelled->store(true, std::memory_order_relaxed);
|
|
});
|
|
|
|
server_download_state dl(this);
|
|
dl.should_stop = [cancelled]() {
|
|
return cancelled->load(std::memory_order_relaxed);
|
|
};
|
|
|
|
bool ok = dl.run(params);
|
|
|
|
notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), {
|
|
{"result", ok ? "download_finished" : "download_failed"},
|
|
});
|
|
|
|
// router should send CMD_ROUTER_TO_CHILD_EXIT after receiving the result
|
|
if (signal_thread.joinable()) {
|
|
signal_thread.join();
|
|
}
|
|
|
|
SRV_INF("download completed %s\n", ok ? "successfully" : "with errors");
|
|
return 0;
|
|
}
|
|
|
|
std::thread server_child::setup(const std::function<void(int)> & shutdown_handler) {
|
|
// setup thread for monitoring stdin
|
|
return std::thread([shutdown_handler]() {
|
|
// wait for EOF on stdin
|
|
SRV_INF("%s", "child server monitoring thread started, waiting for EOF on stdin...\n");
|
|
bool eof = false;
|
|
while (true) {
|
|
std::string line;
|
|
if (!std::getline(std::cin, line)) {
|
|
// EOF detected, that means the router server is unexpectedly exit or killed
|
|
eof = true;
|
|
break;
|
|
}
|
|
if (line.find(CMD_ROUTER_TO_CHILD_EXIT) != std::string::npos) {
|
|
SRV_INF("%s", "exit command received, exiting...\n");
|
|
shutdown_handler(0);
|
|
break;
|
|
}
|
|
}
|
|
if (eof) {
|
|
SRV_INF("%s", "EOF on stdin detected, forcing shutdown...\n");
|
|
exit(1);
|
|
}
|
|
});
|
|
}
|
|
|
|
void server_child::notify_to_router(const std::string & state, const json & payload) {
|
|
json data = {
|
|
{"state", state},
|
|
{"payload", payload},
|
|
};
|
|
std::lock_guard<std::mutex> lk(mtx_stdout);
|
|
common_log_pause(common_log_main());
|
|
fflush(stdout);
|
|
fprintf(stdout, "%s%s\n", CMD_CHILD_TO_ROUTER_STATE, safe_json_to_str(data).c_str());
|
|
fflush(stdout);
|
|
common_log_resume(common_log_main());
|
|
}
|
|
|
|
|
|
//
|
|
// server_models_routes
|
|
//
|
|
|
|
// RAII wrapper similar to server_response_reader, but doesn't use server_queue
|
|
static std::atomic<int> sse_client_id_counter = 0;
|
|
struct server_models_sse_client {
|
|
server_response & queue_results;
|
|
int client_id;
|
|
server_models_sse_client(server_response & q)
|
|
: queue_results(q), client_id(sse_client_id_counter.fetch_add(1, std::memory_order_relaxed)) {
|
|
SRV_DBG("new SSE client connected, assigned client_id=%d\n", client_id);
|
|
queue_results.add_waiting_task_id(client_id);
|
|
}
|
|
~server_models_sse_client() {
|
|
SRV_DBG("SSE client disconnected, removing client_id=%d\n", client_id);
|
|
queue_results.remove_waiting_task_id(client_id);
|
|
}
|
|
|
|
// return nullptr if should_stop() is true before receiving a result
|
|
// note: if one error is received, it will stop further processing and return error result
|
|
server_task_result_ptr next(const std::function<bool()> & should_stop) {
|
|
while (true) {
|
|
static const int http_polling_seconds = 1; // check should_stop every 1 second
|
|
server_task_result_ptr result = queue_results.recv_with_timeout({client_id}, http_polling_seconds);
|
|
if (result == nullptr) {
|
|
// timeout, check stop condition
|
|
if (should_stop()) {
|
|
return nullptr;
|
|
}
|
|
// continue waiting otherwise
|
|
} else {
|
|
SRV_DBG("recv result for client_id=%d: %s\n", client_id, safe_json_to_str(result->to_json()).c_str());
|
|
return result;
|
|
}
|
|
}
|
|
// should not reach here
|
|
}
|
|
};
|
|
|
|
static void res_ok(std::unique_ptr<server_http_res> & res, const json & response_data) {
|
|
res->status = 200;
|
|
res->data = safe_json_to_str(response_data);
|
|
}
|
|
|
|
static void res_err(std::unique_ptr<server_http_res> & res, const json & error_data) {
|
|
res->status = json_value(error_data, "code", 500);
|
|
res->data = safe_json_to_str({{ "error", error_data }});
|
|
}
|
|
|
|
static bool router_validate_model(std::string & name, server_models & models, bool models_autoload, std::unique_ptr<server_http_res> & res) {
|
|
if (name.empty()) {
|
|
res_err(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST));
|
|
return false;
|
|
}
|
|
auto meta = models.get_meta(name);
|
|
if (!meta.has_value()) {
|
|
res_err(res, format_error_response(string_format("model '%s' not found", name.c_str()), ERROR_TYPE_INVALID_REQUEST));
|
|
return false;
|
|
}
|
|
// resolve alias to canonical model name
|
|
name = meta->name;
|
|
if (models_autoload) {
|
|
models.ensure_model_ready(name);
|
|
} else {
|
|
if (!meta->is_running()) {
|
|
res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
static bool is_autoload(const common_params & params, const server_http_req & req) {
|
|
std::string autoload = req.get_param("autoload");
|
|
if (autoload.empty()) {
|
|
return params.models_autoload;
|
|
} else {
|
|
return autoload == "true" || autoload == "1";
|
|
}
|
|
}
|
|
|
|
void server_models_routes::init_routes() {
|
|
this->get_router_props = [this](const server_http_req & req) {
|
|
std::string name = req.get_param("model");
|
|
if (name.empty()) {
|
|
// main instance
|
|
auto res = std::make_unique<server_http_res>();
|
|
res_ok(res, {
|
|
// TODO: add support for this on web UI
|
|
{"role", "router"},
|
|
{"max_instances", params.models_max},
|
|
{"models_autoload", params.models_autoload},
|
|
// this is a dummy response to make sure the UI doesn't break
|
|
{"model_alias", "llama-server"},
|
|
{"model_path", "none"},
|
|
{"default_generation_settings", {
|
|
{"params", json{}},
|
|
{"n_ctx", 0},
|
|
}},
|
|
// New key
|
|
{"ui_settings", ui_settings},
|
|
{"build_info", std::string(llama_build_info())},
|
|
{"cors_proxy_enabled", params.ui_mcp_proxy},
|
|
});
|
|
return res;
|
|
}
|
|
return proxy_get(req);
|
|
};
|
|
|
|
this->proxy_get = [this](const server_http_req & req) {
|
|
std::string method = "GET";
|
|
std::string name = req.get_param("model");
|
|
bool autoload = is_autoload(params, req);
|
|
auto error_res = std::make_unique<server_http_res>();
|
|
if (!router_validate_model(name, models, autoload, error_res)) {
|
|
return error_res;
|
|
}
|
|
return models.proxy_request(req, method, name, false);
|
|
};
|
|
|
|
this->proxy_post = [this](const server_http_req & req) {
|
|
std::string method = "POST";
|
|
json body = json::parse(req.body);
|
|
std::string name = json_value(body, "model", std::string());
|
|
bool autoload = is_autoload(params, req);
|
|
auto error_res = std::make_unique<server_http_res>();
|
|
if (!router_validate_model(name, models, autoload, error_res)) {
|
|
return error_res;
|
|
}
|
|
return models.proxy_request(req, method, name, true); // update last usage for POST request only
|
|
};
|
|
|
|
this->post_router_models_load = [this](const server_http_req & req) {
|
|
auto res = std::make_unique<server_http_res>();
|
|
json body = json::parse(req.body);
|
|
std::string name = json_value(body, "model", std::string());
|
|
auto meta = models.get_meta(name);
|
|
if (!meta.has_value()) {
|
|
res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND));
|
|
return res;
|
|
}
|
|
if (meta->is_running()) {
|
|
res_err(res, format_error_response("model is already running", ERROR_TYPE_INVALID_REQUEST));
|
|
return res;
|
|
}
|
|
models.load(meta->name);
|
|
res_ok(res, {{"success", true}});
|
|
return res;
|
|
};
|
|
|
|
this->get_router_models = [this](const server_http_req & req) {
|
|
bool reload = !req.get_param("reload", "").empty();
|
|
if (reload) {
|
|
models.load_models();
|
|
}
|
|
auto res = std::make_unique<server_http_res>();
|
|
json models_json = json::array();
|
|
auto all_models = models.get_all_meta();
|
|
std::time_t t = std::time(0);
|
|
for (const auto & meta : all_models) {
|
|
json status {
|
|
{"value", server_model_status_to_string(meta.status)},
|
|
{"args", meta.args},
|
|
};
|
|
if (!meta.preset.name.empty()) {
|
|
common_preset preset_copy = meta.preset;
|
|
unset_reserved_args(preset_copy, false);
|
|
preset_copy.unset_option("LLAMA_ARG_HOST");
|
|
preset_copy.unset_option("LLAMA_ARG_PORT");
|
|
preset_copy.unset_option("LLAMA_ARG_ALIAS");
|
|
preset_copy.unset_option("LLAMA_ARG_TAGS");
|
|
status["preset"] = preset_copy.to_ini();
|
|
}
|
|
if (meta.is_failed()) {
|
|
status["exit_code"] = meta.exit_code;
|
|
status["failed"] = true;
|
|
}
|
|
|
|
// pi coding agent multimodal compatibility
|
|
json input_modalities = json::array({"text"});
|
|
if (meta.multimodal.inp_vision) {
|
|
input_modalities.push_back("image");
|
|
}
|
|
if (meta.multimodal.inp_audio) {
|
|
input_modalities.push_back("audio");
|
|
}
|
|
json architecture {
|
|
{"input_modalities", input_modalities},
|
|
{"output_modalities", json::array({"text"})},
|
|
};
|
|
|
|
json model_info = json {
|
|
{"id", meta.name},
|
|
{"aliases", meta.aliases},
|
|
{"tags", meta.tags},
|
|
{"object", "model"}, // for OAI-compat
|
|
{"owned_by", "llamacpp"}, // for OAI-compat
|
|
{"created", t}, // for OAI-compat
|
|
{"status", status},
|
|
{"architecture", architecture},
|
|
{"source", server_model_source_to_string(meta.source)},
|
|
{"can_remove", meta.source == SERVER_MODEL_SOURCE_CACHE},
|
|
// {"need_download", meta.need_download},
|
|
// TODO: add other fields, may require reading GGUF metadata
|
|
};
|
|
|
|
// merge with loaded_info from the child process if available
|
|
if (meta.is_running()) {
|
|
for (auto it = meta.loaded_info.begin(); it != meta.loaded_info.end(); ++it) {
|
|
if (!model_info.contains(it.key())) {
|
|
model_info[it.key()] = it.value();
|
|
}
|
|
}
|
|
}
|
|
models_json.push_back(model_info);
|
|
}
|
|
res_ok(res, {
|
|
{"data", models_json},
|
|
{"object", "list"},
|
|
});
|
|
return res;
|
|
};
|
|
|
|
this->post_router_models_unload = [this](const server_http_req & req) {
|
|
auto res = std::make_unique<server_http_res>();
|
|
json body = json::parse(req.body);
|
|
std::string name = json_value(body, "model", std::string());
|
|
auto model = models.get_meta(name);
|
|
if (!model.has_value()) {
|
|
res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
|
|
return res;
|
|
}
|
|
if (!model->is_running() && model->status != SERVER_MODEL_STATUS_DOWNLOADING) {
|
|
res_err(res, format_error_response("model is not running", ERROR_TYPE_INVALID_REQUEST));
|
|
return res;
|
|
}
|
|
models.unload(model->name);
|
|
res_ok(res, {{"success", true}});
|
|
return res;
|
|
};
|
|
|
|
this->get_router_models_sse = [this](const server_http_req & req) {
|
|
auto res = std::make_unique<server_http_res>();
|
|
res->status = 200;
|
|
res->content_type = "text/event-stream";
|
|
auto sse_client = std::make_shared<server_models_sse_client>(models.sse);
|
|
res->next = [this, sse_client, &req](std::string & output) -> bool {
|
|
auto result = sse_client->next([&]() {
|
|
return stopping.load(std::memory_order_relaxed) || req.should_stop();
|
|
});
|
|
if (result == nullptr) {
|
|
return false; // client disconnected or should_stop
|
|
}
|
|
output = "data: " + safe_json_to_str(result->to_json()) + "\n\n";
|
|
return true; // listen for the next event
|
|
};
|
|
return res;
|
|
};
|
|
|
|
this->post_router_models = [this](const server_http_req & req) {
|
|
auto res = std::make_unique<server_http_res>();
|
|
|
|
json body = json::parse(req.body);
|
|
std::string name = json_value(body, "model", std::string());
|
|
if (name.empty()) {
|
|
throw std::invalid_argument("model must be a non-empty string");
|
|
}
|
|
|
|
common_params p;
|
|
p.model.hf_repo = name;
|
|
p.hf_token = params.hf_token;
|
|
|
|
// validate by fetching metadata
|
|
bool ok = false;
|
|
try {
|
|
common_models_handler_init(p, LLAMA_EXAMPLE_SERVER);
|
|
ok = true;
|
|
} catch (...) {
|
|
SRV_ERR("unknown error while validating model '%s'\n", name.c_str());
|
|
// other exceptions will be handled by the outer ex_wrapper()
|
|
throw;
|
|
}
|
|
|
|
if (!ok) {
|
|
throw std::invalid_argument("model validation failed, unable to download");
|
|
}
|
|
|
|
// reject if model already exists
|
|
if (models.has_model(name)) {
|
|
throw std::invalid_argument("model '" + name + "' already exists");
|
|
}
|
|
|
|
// then, proceed with the actual download
|
|
SRV_INF("starting download for model '%s'\n", name.c_str());
|
|
{
|
|
server_models::load_options load_opts;
|
|
load_opts.mode = SERVER_CHILD_MODE_DOWNLOAD;
|
|
load_opts.custom_meta = server_model_meta{};
|
|
load_opts.custom_meta->source = SERVER_MODEL_SOURCE_CACHE;
|
|
load_opts.custom_meta->name = name;
|
|
models.load(name, load_opts);
|
|
}
|
|
|
|
res_ok(res, {{"success", true}});
|
|
return res;
|
|
};
|
|
|
|
this->del_router_models = [this](const server_http_req & req) {
|
|
auto res = std::make_unique<server_http_res>();
|
|
|
|
std::string name = req.get_param("model");
|
|
if (name.empty()) {
|
|
throw std::invalid_argument("model must be a non-empty string");
|
|
}
|
|
|
|
models.remove(name); // throws on error
|
|
|
|
res_ok(res, {{"success", true}});
|
|
return res;
|
|
};
|
|
}
|
|
|
|
|
|
|
|
//
|
|
// server_http_proxy
|
|
//
|
|
|
|
// simple implementation of a pipe
|
|
// used for streaming data between threads
|
|
template<typename T>
|
|
struct pipe_t {
|
|
std::mutex mutex;
|
|
std::condition_variable cv;
|
|
std::queue<T> queue;
|
|
std::atomic<bool> writer_closed{false};
|
|
std::atomic<bool> reader_closed{false};
|
|
void close_write() {
|
|
writer_closed.store(true, std::memory_order_relaxed);
|
|
cv.notify_all();
|
|
}
|
|
void close_read() {
|
|
reader_closed.store(true, std::memory_order_relaxed);
|
|
cv.notify_all();
|
|
}
|
|
bool read(T & output, const std::function<bool()> & should_stop) {
|
|
std::unique_lock<std::mutex> lk(mutex);
|
|
constexpr auto poll_interval = std::chrono::milliseconds(500);
|
|
while (true) {
|
|
if (!queue.empty()) {
|
|
output = std::move(queue.front());
|
|
queue.pop();
|
|
return true;
|
|
}
|
|
if (writer_closed.load()) {
|
|
return false; // clean EOF
|
|
}
|
|
if (should_stop()) {
|
|
close_read(); // signal broken pipe to writer
|
|
return false; // cancelled / reader no longer alive
|
|
}
|
|
cv.wait_for(lk, poll_interval);
|
|
}
|
|
}
|
|
bool write(T && data) {
|
|
std::lock_guard<std::mutex> lk(mutex);
|
|
if (reader_closed.load()) {
|
|
return false; // broken pipe
|
|
}
|
|
queue.push(std::move(data));
|
|
cv.notify_one();
|
|
return true;
|
|
}
|
|
};
|
|
|
|
static std::string to_lower_copy(const std::string & value) {
|
|
std::string lowered(value.size(), '\0');
|
|
std::transform(value.begin(), value.end(), lowered.begin(), [](unsigned char c) { return std::tolower(c); });
|
|
return lowered;
|
|
}
|
|
|
|
static bool should_strip_proxy_header(const std::string & header_name) {
|
|
// Headers that get duplicated when router forwards child responses
|
|
if (header_name == "server" ||
|
|
header_name == "transfer-encoding" ||
|
|
header_name == "content-length" || // quick fix for https://github.com/ggml-org/llama.cpp/issues/17710
|
|
header_name == "keep-alive") {
|
|
return true;
|
|
}
|
|
|
|
// Router injects CORS, child also sends them: duplicate
|
|
if (header_name.rfind("access-control-", 0) == 0) {
|
|
return true;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
static std::string generate_multipart_boundary() {
|
|
thread_local std::mt19937 gen(std::random_device{}());
|
|
static const char chars[] = "0123456789abcdefghijklmnopqrstuvwxyz";
|
|
std::uniform_int_distribution<> dis(0, sizeof(chars) - 2);
|
|
std::string boundary = "----llama-cpp-proxy-";
|
|
for (int i = 0; i < 16; i++) {
|
|
boundary += chars[dis(gen)];
|
|
}
|
|
return boundary;
|
|
}
|
|
|
|
static std::string build_multipart_body(
|
|
const json & form_fields,
|
|
const std::map<std::string, uploaded_file> & files,
|
|
const std::string & boundary) {
|
|
static auto sanitize_field = [](const std::string & text) {
|
|
std::string result;
|
|
result.reserve(text.size());
|
|
for (char c : text) {
|
|
if (c != '\n' && c != '\r' && c != '"') {
|
|
result += c;
|
|
}
|
|
}
|
|
return result;
|
|
};
|
|
|
|
std::ostringstream body;
|
|
|
|
for (const auto & [key, value] : form_fields.items()) {
|
|
if (value.is_array()) {
|
|
for (const auto & item : value) {
|
|
body << "--" << boundary << "\r\n";
|
|
body << "Content-Disposition: form-data; name=\"" << sanitize_field(key) << "\"\r\n";
|
|
body << "\r\n";
|
|
if (!item.is_string()) {
|
|
throw std::invalid_argument("expected string");
|
|
}
|
|
body << item.get<std::string>() << "\r\n";
|
|
}
|
|
} else {
|
|
body << "--" << boundary << "\r\n";
|
|
body << "Content-Disposition: form-data; name=\"" << sanitize_field(key) << "\"\r\n";
|
|
body << "\r\n";
|
|
if (!value.is_string()) {
|
|
throw std::invalid_argument("expected string");
|
|
}
|
|
body << value.get<std::string>() << "\r\n";
|
|
}
|
|
}
|
|
|
|
for (const auto & [key, file] : files) {
|
|
body << "--" << boundary << "\r\n";
|
|
body << "Content-Disposition: form-data; name=\"" << sanitize_field(key) << "\"";
|
|
if (!file.filename.empty()) {
|
|
body << "; filename=\"" << sanitize_field(file.filename) << "\"";
|
|
}
|
|
body << "\r\n";
|
|
if (!file.content_type.empty()) {
|
|
body << "Content-Type: " << sanitize_field(file.content_type) << "\r\n";
|
|
} else {
|
|
body << "Content-Type: application/octet-stream\r\n";
|
|
}
|
|
body << "\r\n";
|
|
body.write(reinterpret_cast<const char*>(file.data.data()), file.data.size());
|
|
body << "\r\n";
|
|
}
|
|
|
|
body << "--" << boundary << "--\r\n";
|
|
return body.str();
|
|
}
|
|
|
|
server_http_proxy::server_http_proxy(
|
|
const std::string & method,
|
|
const std::string & scheme,
|
|
const std::string & host,
|
|
int port,
|
|
const std::string & path,
|
|
const std::map<std::string, std::string> & headers,
|
|
const std::string & body,
|
|
const std::map<std::string, uploaded_file> & files,
|
|
const std::function<bool()> should_stop,
|
|
int32_t timeout_read,
|
|
int32_t timeout_write
|
|
) {
|
|
// shared between reader and writer threads
|
|
auto cli = std::make_shared<httplib::ClientImpl>(host, port);
|
|
auto pipe = std::make_shared<pipe_t<msg_t>>();
|
|
|
|
if (scheme == "https") {
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
cli.reset(new httplib::SSLClient(host, port));
|
|
#else
|
|
throw std::runtime_error("HTTPS requested but CPPHTTPLIB_OPENSSL_SUPPORT is not defined");
|
|
#endif
|
|
}
|
|
|
|
// setup Client
|
|
cli->set_follow_location(true);
|
|
cli->set_connection_timeout(timeout_read, 0); // use --timeout value instead of hardcoded 5 s
|
|
cli->set_write_timeout(timeout_read, 0); // reversed for cli (client) vs srv (server)
|
|
cli->set_read_timeout(timeout_write, 0);
|
|
this->status = 500; // to be overwritten upon response
|
|
this->cleanup = [pipe]() {
|
|
pipe->close_read();
|
|
pipe->close_write();
|
|
};
|
|
|
|
// wire up the receive end of the pipe
|
|
this->next = [pipe, should_stop](std::string & out) -> bool {
|
|
msg_t msg;
|
|
bool has_next = pipe->read(msg, should_stop);
|
|
if (!msg.data.empty()) {
|
|
out = std::move(msg.data);
|
|
}
|
|
return has_next; // false if EOF or pipe broken
|
|
};
|
|
|
|
// wire up the HTTP client
|
|
// note: do NOT capture `this` pointer, as it may be destroyed before the thread ends
|
|
httplib::ResponseHandler response_handler = [pipe, cli](const httplib::Response & response) {
|
|
msg_t msg;
|
|
msg.status = response.status;
|
|
for (const auto & [key, value] : response.headers) {
|
|
const auto lowered = to_lower_copy(key);
|
|
if (should_strip_proxy_header(lowered)) {
|
|
continue;
|
|
}
|
|
if (lowered == "content-type") {
|
|
msg.content_type = value;
|
|
continue;
|
|
}
|
|
msg.headers[key] = value;
|
|
}
|
|
return pipe->write(std::move(msg)); // send headers first
|
|
};
|
|
httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) {
|
|
// send data chunks
|
|
// returns false if pipe is closed / broken (signal to stop receiving)
|
|
return pipe->write({{}, 0, std::string(data, data_length), ""});
|
|
};
|
|
|
|
// when files are present, the body was converted from multipart form data to JSON
|
|
// we need to reconstruct the multipart body for the downstream server
|
|
std::string effective_body = body;
|
|
std::string override_content_type;
|
|
bool has_files = !files.empty();
|
|
|
|
if (has_files) {
|
|
json form_fields = json::parse(body, nullptr, false);
|
|
if (!form_fields.is_discarded()) {
|
|
auto boundary = generate_multipart_boundary();
|
|
effective_body = build_multipart_body(form_fields, files, boundary);
|
|
override_content_type = "multipart/form-data; boundary=" + boundary;
|
|
} else {
|
|
throw std::runtime_error("failed to parse multipart form fields JSON");
|
|
}
|
|
}
|
|
|
|
// prepare the request to destination server
|
|
httplib::Request req;
|
|
{
|
|
req.method = method;
|
|
req.path = path;
|
|
for (const auto & [key, value] : headers) {
|
|
const auto lowered = to_lower_copy(key);
|
|
if (lowered == "accept-encoding") {
|
|
// disable Accept-Encoding to avoid compressed responses
|
|
continue;
|
|
}
|
|
if (lowered == "transfer-encoding") {
|
|
// the body is already decoded
|
|
continue;
|
|
}
|
|
if (lowered == "content-length") {
|
|
// let httplib calculate Content-Length from the actual body
|
|
continue;
|
|
}
|
|
if (lowered == "content-type") {
|
|
if (has_files) {
|
|
// we set our own Content-Type with the new boundary
|
|
continue;
|
|
}
|
|
// when no files but the original request was multipart,
|
|
// the body is now JSON, so correct the Content-Type
|
|
if (value.find("multipart/form-data") != std::string::npos) {
|
|
override_content_type = "application/json; charset=utf-8";
|
|
continue;
|
|
}
|
|
}
|
|
if (lowered == "host") {
|
|
bool is_default_port = (scheme == "https" && port == 443) || (scheme == "http" && port == 80);
|
|
req.set_header(key, is_default_port ? host : host + ":" + std::to_string(port));
|
|
} else {
|
|
req.set_header(key, value);
|
|
}
|
|
}
|
|
req.body = effective_body;
|
|
if (!override_content_type.empty()) {
|
|
req.set_header("Content-Type", override_content_type);
|
|
}
|
|
req.response_handler = response_handler;
|
|
req.content_receiver = content_receiver;
|
|
}
|
|
|
|
// start the proxy thread
|
|
SRV_DBG("start proxy thread %s %s\n", req.method.c_str(), req.path.c_str());
|
|
this->thread = std::thread([cli, pipe, req]() {
|
|
auto result = cli->send(std::move(req));
|
|
if (result.error() != httplib::Error::Success) {
|
|
auto err_str = httplib::to_string(result.error());
|
|
SRV_ERR("http client error: %s\n", err_str.c_str());
|
|
pipe->write({{}, 500, "", ""}); // header
|
|
pipe->write({{}, 0, "proxy error: " + err_str, ""}); // body
|
|
}
|
|
pipe->close_write(); // signal EOF to reader
|
|
SRV_DBG("%s", "client request thread ended\n");
|
|
});
|
|
this->thread.detach();
|
|
|
|
// wait for the first chunk (headers)
|
|
{
|
|
msg_t header;
|
|
if (pipe->read(header, should_stop)) {
|
|
SRV_DBG("%s", "received response headers\n");
|
|
this->status = header.status;
|
|
this->headers = std::move(header.headers);
|
|
if (!header.content_type.empty()) {
|
|
this->content_type = std::move(header.content_type);
|
|
}
|
|
} else {
|
|
SRV_DBG("%s", "no response headers received (request cancelled?)\n");
|
|
}
|
|
}
|
|
}
|