diff --git a/common/arg.cpp b/common/arg.cpp index 0fc94e5532..81752657ca 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3236,6 +3236,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.models_max = value; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODELS_MAX")); + add_opt(common_arg( + {"--models-memory-margin"}, "N", + string_format("for router server, MiB of memory to leave free, per device (default: %d, 0 = unlimited)", params.models_memory_margin), + [](common_params & params, int value) { + params.models_memory_margin = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODELS_MEMORY_MARGIN")); add_opt(common_arg( {"--models-autoload"}, {"--no-models-autoload"}, @@ -3486,6 +3493,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.offline = true; } ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_OFFLINE")); + add_opt(common_arg( + {"--measure-only"}, + "Load the model to measure memory requirements, print to stdout, then exit", + [](common_params & params) { + params.measure_only = true; + } + )); add_opt(common_arg( {"-lv", "--verbosity", "--log-verbosity"}, "N", string_format("Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:\n" diff --git a/common/common.h b/common/common.h index 2adb310b83..ec6add8012 100644 --- a/common/common.h +++ b/common/common.h @@ -532,6 +532,8 @@ struct common_params { int32_t control_vector_layer_start = -1; // layer range for control vector int32_t control_vector_layer_end = -1; // layer range for control vector bool offline = false; + bool skip_download = false; // skip model file downloading + bool measure_only = false; // load model with no_alloc to measure memory, print to stdout, then exit int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line @@ -660,6 +662,7 @@ struct common_params { std::string models_dir = ""; // directory containing models for the router server std::string models_preset = ""; // directory containing model presets for the router server int models_max = 4; // maximum number of models to load simultaneously + int models_memory_margin = 1024; // MiB of free memory to preserve per device (0 = disabled) bool models_autoload = true; // automatically load models when requested via the router server std::string models_preset_hf = ""; // show a warning about remote presets on router loaded (if not empty) diff --git a/src/llama-ext.h b/src/llama-ext.h index 348bbae957..4d4f47f2c5 100644 --- a/src/llama-ext.h +++ b/src/llama-ext.h @@ -90,7 +90,17 @@ LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * m LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx); -// Set whether the context outputs nextn embeddings or not +// Returns the projected memory use (model + context + compute) in bytes +// for the given device within this context. Returns 0 if the device is not used. +LLAMA_API uint64_t llama_context_device_memory( + const struct llama_context * ctx, + ggml_backend_dev_t device); + +// +// pre-norm embeddings (hidden state before the final output norm) +// + +// Set whether the context outputs pre-norm embeddings or not // If masked == true, output the embeddings only for the tokens with batch.logits != 0 // If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked); diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 81da00c0e0..b719fe8399 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -145,6 +145,7 @@ static void unset_reserved_args(common_preset & preset, bool unset_model_args) { preset.unset_option("LLAMA_API_KEY"); preset.unset_option("LLAMA_ARG_MODELS_DIR"); preset.unset_option("LLAMA_ARG_MODELS_MAX"); + preset.unset_option("LLAMA_ARG_MODELS_MEMORY_MARGIN"); preset.unset_option("LLAMA_ARG_MODELS_PRESET"); preset.unset_option("LLAMA_ARG_MODELS_AUTOLOAD"); if (unset_model_args) { @@ -260,9 +261,39 @@ server_models::server_models( bin_path = get_server_exec_path().string(); } catch (const std::exception & e) { bin_path = argv[0]; - LOG_WRN("failed to get server executable path: %s\n", e.what()); - LOG_WRN("using original argv[0] as fallback: %s\n", argv[0]); + SRV_WRN("failed to get server executable path: %s\n", e.what()); + SRV_WRN("using original argv[0] as fallback: %s\n", argv[0]); } + + const size_t memory_margin = (size_t) base_params.models_memory_margin * 1024 * 1024; + + if (memory_margin > 0) { + ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + ggml_backend_buffer_type_t cpu_buft = cpu_dev ? ggml_backend_dev_buffer_type(cpu_dev) : nullptr; + + const size_t n_devs = ggml_backend_dev_count(); + for (size_t i = 0; i < n_devs; i++) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + ggml_backend_buffer_type_t dev_buft = ggml_backend_dev_buffer_type(dev); + if (dev_buft) { + buft_by_name[ggml_backend_buft_name(dev_buft)] = dev_buft; + } + ggml_backend_buffer_type_t host_buft = ggml_backend_dev_host_buffer_type(dev); + if (host_buft && cpu_buft) { + buft_by_name[ggml_backend_buft_name(host_buft)] = cpu_buft; + } + + size_t free, total; + ggml_backend_dev_memory(dev, &free, &total); + if (total > 0 && dev_buft) { + const size_t available = (free > memory_margin) ? free - memory_margin : 0; + bmm_available[dev_buft] = available; + SRV_DBG("buft %s: available memory after margin=%zu MiB\n", + ggml_backend_buft_name(dev_buft), available / (1024 * 1024)); + } + } + } + load_models(); } @@ -456,6 +487,7 @@ void server_models::load_models() { /* port */ 0, /* status */ SERVER_MODEL_STATUS_UNLOADED, /* last_used */ 0, + /* bmm_req */ {}, /* args */ std::vector(), /* loaded_info */ {}, /* progress */ {}, @@ -623,6 +655,7 @@ void server_models::load_models() { /* port */ 0, /* status */ SERVER_MODEL_STATUS_UNLOADED, /* last_used */ 0, + /* bmm_req */ {}, /* args */ std::vector(), /* loaded_info */ {}, /* progress */ {}, @@ -797,30 +830,87 @@ std::vector server_models::get_all_meta() { return result; } -void server_models::unload_lru() { - if (base_params.models_max <= 0) { - return; // no limit - } - // remove one of the servers if we passed the models_max (least recently used - LRU) - std::string lru_model_name = ""; - int64_t lru_last_used = ggml_time_ms(); - size_t count_active = 0; - { - std::unique_lock lk(mutex); - for (const auto & m : mapping) { - if (m.second.meta.is_running()) { - count_active++; - if (m.second.meta.last_used < lru_last_used) { - lru_model_name = m.first; - lru_last_used = m.second.meta.last_used; - } +int server_models::can_fit(const buft_memory_map & bmm_req) const { + buft_memory_map bmm_total; + for (const auto & m : mapping) { + if (m.second.meta.is_running()) { + for (const auto & [buft, mem] : m.second.meta.bmm_req) { + bmm_total[buft] += mem; } } } - if (!lru_model_name.empty() && count_active >= (size_t)base_params.models_max) { - SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str()); + + auto get = [](const buft_memory_map & dmm, ggml_backend_buffer_type_t buft) -> size_t { + auto it = dmm.find(buft); + return it != dmm.end() ? it->second : 0; + }; + + int res = 0; + + for (const auto & [buft, limit] : bmm_available) { + const size_t mem_total = get(bmm_total, buft); + const size_t mem_new = get(bmm_req, buft); + + SRV_DBG("buft %s: total=%zu MiB, new=%zu MiB, limit=%zu MiB\n", + ggml_backend_buft_name(buft), + mem_total / (1024 * 1024), mem_new / (1024 * 1024), limit / (1024 * 1024)); + + if (mem_total + mem_new > limit) { + res++; + } + } + + return res; +} + +bool server_models::limits_exceeded(const buft_memory_map & bmm_req) const { + const bool check_active = base_params.models_max > 0; + const bool check_memory = base_params.models_memory_margin > 0; + + if (!check_active && !check_memory) { + return false; + } + + int count_active = 0; + for (const auto & m : mapping) { + if (m.second.meta.is_running()) { + count_active++; + } + } + + const bool active_exceeded = check_active && count_active >= base_params.models_max; + const bool memory_exceeded = check_memory && can_fit(bmm_req) > 0; + + return active_exceeded || memory_exceeded; +} + +void server_models::unload_lru(const buft_memory_map & bmm_req) { + if (base_params.models_memory_margin > 0) { + GGML_ASSERT(!bmm_available.empty()); + } + + while (true) { + std::string lru_model_name; + { + std::unique_lock lk(mutex); + if (!limits_exceeded(bmm_req)) { + break; + } + int64_t lru_last_used = ggml_time_ms(); + for (const auto & m : mapping) { + if (m.second.meta.is_running() && m.second.meta.last_used < lru_last_used) { + lru_model_name = m.first; + lru_last_used = m.second.meta.last_used; + } + } + } + + if (lru_model_name.empty()) { + break; + } + + SRV_INF("limits exceeded, removing LRU name=%s\n", lru_model_name.c_str()); unload(lru_model_name); - // wait for unload to complete { std::unique_lock lk(mutex); cv.wait(lk, [this, &lru_model_name]() { @@ -830,6 +920,88 @@ void server_models::unload_lru() { } } +buft_memory_map server_models::estimate_model_memory(const std::string & name) { + std::vector child_args; + std::vector child_env; + { + std::lock_guard lk(mutex); + auto & meta = mapping[name].meta; + child_args = meta.preset.to_args(bin_path); + child_env = base_env; + } + child_args.push_back("--measure-only"); + child_args.push_back("--offline"); + + SRV_INF("estimating memory for model name=%s\n", name.c_str()); + + std::vector argv = to_char_ptr_array(child_args); + std::vector envp = to_char_ptr_array(child_env); + + subprocess_s proc; + int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr; + if (subprocess_create_ex(argv.data(), options, envp.data(), &proc) != 0) { + SRV_ERR("failed to spawn measure process for model name=%s\n", name.c_str()); + return {}; + } + + buft_memory_map result; + FILE * out = subprocess_stdout(&proc); + if (out) { + char buffer[4096]; + while (fgets(buffer, sizeof(buffer), out) != nullptr) { + LOG("[measure:%s] %s", name.c_str(), buffer); + std::string line(buffer); + if (string_starts_with(line, "measure:")) { + std::istringstream iss(line.substr(strlen("measure:"))); + std::string buft_name; + size_t size = 0; + if (iss >> buft_name >> size) { + auto it = buft_by_name.find(buft_name); + if (it != buft_by_name.end()) { + result[it->second] += size; + } else { + SRV_WRN("unknown buft name '%s' from measure child for model name=%s\n", + buft_name.c_str(), name.c_str()); + } + } + } + } + } + + int exit_code = 0; + subprocess_join(&proc, &exit_code); + subprocess_destroy(&proc); + + if (exit_code != 0) { + SRV_ERR("measure process for model name=%s exited with code %d\n", name.c_str(), exit_code); + return {}; + } + + SRV_INF("memory estimation complete for model name=%s\n", name.c_str()); + return result; +} + +void server_models::join_completed_bg_tasks() { + std::vector> to_join; + { + std::lock_guard lk(mutex); + for (auto it = bg_tasks.begin(); it != bg_tasks.end(); ) { + if (it->second->done.load()) { + to_join.push_back(std::move(it->second)); + it = bg_tasks.erase(it); + } else { + ++it; + } + } + } + for (auto & task : to_join) { + if (task->th.joinable()) { + task->th.join(); + } + } +} + + void server_models::load(const std::string & name) { load(name, load_options{}); } @@ -839,9 +1011,30 @@ void server_models::load(const std::string & name, const load_options & opts) { if (!has_model(name)) { throw std::runtime_error("model name=" + name + " is not found"); } - unload_lru(); } + join_completed_bg_tasks(); + + buft_memory_map bmm_req; + if (base_params.models_memory_margin > 0) { + { + std::lock_guard lk(mutex); + bmm_req = mapping[name].meta.bmm_req; + } + if (bmm_req.empty()) { + bmm_req = estimate_model_memory(name); + if (bmm_req.empty()) { + SRV_WRN("failed to estimate memory for model %s, memory limits will not apply\n", name.c_str()); + } + { + std::lock_guard lk(mutex); + mapping[name].meta.bmm_req = bmm_req; + } + } + } + + unload_lru(bmm_req); + std::unique_lock lk(mutex); // edge case: block until any in-progress reload has finished so we always load // against the freshest preset and a consistent mapping state @@ -857,16 +1050,8 @@ void server_models::load(const std::string & name, const load_options & opts) { // exceeding models_max. Without this, the window between unload_lru() // releasing its lock and this lock_guard acquiring allows multiple // threads to each observe capacity and all proceed to load. - if (base_params.models_max > 0) { - size_t count_active = 0; - for (const auto & m : mapping) { - if (m.second.meta.is_running()) { - count_active++; - } - } - if (count_active >= (size_t)base_params.models_max) { - throw std::runtime_error("model limit reached, try again later"); - } + if (limits_exceeded(bmm_req)) { + throw std::runtime_error("model limit reached, try again later"); } // prepare new instance info @@ -1066,6 +1251,7 @@ void server_models::unload(const std::string & name) { void server_models::unload_all() { std::vector to_join; + std::vector> bg_to_join; { std::lock_guard lk(mutex); for (auto & [name, inst] : mapping) { @@ -1081,15 +1267,26 @@ void server_models::unload_all() { // moving the thread to join list to avoid deadlock to_join.push_back(std::move(inst.th)); } + for (auto & [name, task] : bg_tasks) { + bg_to_join.push_back(std::move(task)); + } + bg_tasks.clear(); } for (auto & th : to_join) { if (th.joinable()) { th.join(); } } + for (auto & task : bg_to_join) { + if (task && task->th.joinable()) { + task->th.join(); + } + } } void server_models::update_status(const std::string & name, const update_status_args & args) { + join_completed_bg_tasks(); + std::unique_lock lk(mutex); auto it = mapping.find(name); if (it != mapping.end()) { diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 62bed8725b..b1b3b28072 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -7,6 +7,7 @@ #include "server-http.h" #include "server-queue.h" +#include #include #include #include @@ -69,6 +70,8 @@ static std::string server_model_source_to_string(server_model_source source) { } } +using buft_memory_map = std::map; + struct server_model_meta { server_model_source source = SERVER_MODEL_SOURCE_CACHE; common_preset preset; @@ -78,6 +81,7 @@ struct server_model_meta { int port = 0; server_model_status status = SERVER_MODEL_STATUS_UNLOADED; int64_t last_used = 0; // for LRU unloading + buft_memory_map bmm_req; // bytes required per buffer type std::vector args; // args passed to the model instance, will be populated by render_args() json loaded_info; // info to be reflected via /v1/models endpoint ; if in DOWNLOADING state, it should contain download progress info json progress; // reflect load or download progress info, if any @@ -123,6 +127,13 @@ private: std::condition_variable cv_stop; std::set stopping_models; + // background tasks for download/estimate/load pipelines, keyed by model name + struct bg_task { + std::thread th; + std::atomic done{false}; + }; + std::map> bg_tasks; + // set to true while load_models() is executing a reload; load() will wait until clear bool is_reloading = false; @@ -174,10 +185,16 @@ private: std::vector base_env; common_preset base_preset; // base preset from llama-server CLI args + // available memory per buffer type + buft_memory_map bmm_available; + + // buft name -> buft lookup (host buffer types map to CPU buft) + std::unordered_map buft_by_name; + void update_meta(const std::string & name, const server_model_meta & meta); // unload least recently used models if the limit is reached - void unload_lru(); + void unload_lru(const buft_memory_map & bmm_req); // not thread-safe, caller must hold mutex void add_model(server_model_meta && meta); @@ -185,6 +202,21 @@ private: // notify SSE clients void notify_sse(const std::string & event, const std::string & model_id, const json & data = nullptr); + // return number of buffer types where the memory limit would be exceeded + // return 0 if the new model would fit + // not thread-safe, caller must hold mutex + int can_fit(const buft_memory_map & bmm_req) const; + + // check if active model count or memory limits would be exceeded + // not thread-safe, caller must hold mutex + bool limits_exceeded(const buft_memory_map & bmm_req) const; + + // estimate model memory by spawning a child process with --measure-only + // returns the buft memory map, or empty map on failure (caller must NOT hold mutex) + buft_memory_map estimate_model_memory(const std::string & name); + + // join and remove completed background tasks + void join_completed_bg_tasks(); public: // conv_id -> model tracker for the resumable stream routes, owns its lock conv_model_tracker conv_models; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index eafef86bac..1a9340ff2f 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -12,6 +12,8 @@ #include "llama.h" #include "log.h" +#include "../../src/llama-ext.h" + #include #include #include @@ -140,6 +142,44 @@ int llama_server(int argc, char ** argv) { // struct that contains llama context and inference server_context ctx_server; + if (params.measure_only) { + llama_model_params mparams = common_model_params_to_llama(params); + mparams.no_alloc = true; + mparams.use_mmap = false; + mparams.use_mlock = false; + + llama_model_ptr model{llama_model_load_from_file(params.model.path.c_str(), mparams)}; + if (!model) { + LOG_ERR("%s: failed to load model for measurement\n", __func__); + llama_backend_free(); + return 1; + } + + llama_context_params cparams = common_context_params_to_llama(params); + llama_context_ptr ctx{llama_init_from_model(model.get(), cparams)}; + if (!ctx) { + LOG_ERR("%s: failed to create context for measurement\n", __func__); + llama_backend_free(); + return 1; + } + + common_log_pause(common_log_main()); + for (const auto & [buft, data] : llama_get_memory_breakdown(ctx.get())) { + size_t total = data.total(); + if (total > 0) { + fprintf(stdout, "measure:%s %zu\n", ggml_backend_buft_name(buft), total); + } + } + fflush(stdout); + common_log_resume(common_log_main()); + + llama_backend_free(); + return 0; + } + + LOG_INF("build_info: %s\n", llama_build_info()); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + server_http_context ctx_http; if (!ctx_http.init(params)) { SRV_ERR("%s", "failed to initialize HTTP server\n");