diff --git a/common/arg.cpp b/common/arg.cpp index 5297d90753..276dbec8ba 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -301,6 +301,8 @@ static handle_model_result common_params_handle_model(struct common_params_model const common_download_opts & opts) { handle_model_result result; + // TODO @ngxson : refactor this into a new common_model_download_context + if (!model.docker_repo.empty()) { model.path = common_docker_resolve_model(model.docker_repo); } else if (!model.hf_repo.empty()) { @@ -396,7 +398,7 @@ static bool parse_bool_value(const std::string & value) { // CLI argument parsing functions // -bool common_params_handle_models(common_params & params, llama_example curr_ex, common_download_callback * callback) { +bool common_params_handle_models(common_params & params, llama_example curr_ex, const common_params_handle_models_params & handle_params) { const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(), params.speculative.types.end(), COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end(); @@ -407,9 +409,10 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex, opts.skip_download = params.skip_download; opts.download_mtp = spec_type_draft_mtp; opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty(); + opts.preset_only = handle_params.preset_only; - if (callback) { - opts.callback = callback; + if (handle_params.callback) { + opts.callback = handle_params.callback; } // sub-models (draft, mmproj, vocoder) are explicitly specified by the user, @@ -596,7 +599,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context if (!skip_model_download) { // handle model and download - common_params_handle_models(params, ctx_arg.ex); + common_params_handle_models(params, ctx_arg.ex, {}); // model is required (except for server) // TODO @ngxson : maybe show a list of available models in CLI in this case diff --git a/common/arg.h b/common/arg.h index c061fc60f7..fdfc04bc7a 100644 --- a/common/arg.h +++ b/common/arg.h @@ -130,6 +130,11 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map & args); +struct common_params_handle_models_params { + common_download_callback * callback = nullptr; + bool preset_only = false; // if true, only check & download remote preset (for router mode) +}; + // populate model paths (main model, mmproj, etc) from -hf if necessary // return true if the model is ready to use // throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc) @@ -137,7 +142,7 @@ void common_params_add_preset_options(std::vector & args); bool common_params_handle_models( common_params & params, llama_example curr_ex, - common_download_callback * callback = nullptr); + const common_params_handle_models_params & handle_params); // initialize argument parser context - used by test-arg-parser and preset common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); diff --git a/common/download.cpp b/common/download.cpp index f320462753..5b55c76a11 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -799,6 +799,7 @@ common_download_model_result common_download_model(const common_params_model & bool download_mmproj = opts.download_mmproj; bool download_mtp = opts.download_mtp; + bool preset_only = opts.preset_only; bool is_hf = !model.hf_repo.empty(); if (is_hf) { @@ -806,7 +807,8 @@ common_download_model_result common_download_model(const common_params_model & if (!hf.preset.path.empty()) { // if preset.ini exists, only download that file alone tasks.push_back({hf.preset.url, hf.preset.local_path}); - } else { + } else if (!preset_only) { + // only add other files if we're NOT in preset-only mode (normal run, non-router) for (const auto & f : hf.model_files) { tasks.push_back({f.url, f.local_path}); } diff --git a/common/download.h b/common/download.h index 8dbf07836f..755e34ea8c 100644 --- a/common/download.h +++ b/common/download.h @@ -55,6 +55,7 @@ struct common_download_opts { bool skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid bool download_mmproj = false; bool download_mtp = false; + bool preset_only = false; // if true, only check & download remote preset (for router mode) common_download_callback * callback = nullptr; }; diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index a87e4e423e..a4df3ef108 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -224,7 +224,7 @@ void server_model_meta::update_caps() { }); params.offline = true; // params.skip_download = true; // TODO: ideally, we should validate the model here, but it takes too much time - common_params_handle_models(params, LLAMA_EXAMPLE_SERVER); + common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {}); if (params.mmproj.path.empty()) { multimodal = { false, false }; } else { @@ -1393,7 +1393,9 @@ struct server_download_state : public common_download_callback { bool run(common_params & params) { try { - common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, this); + common_params_handle_models_params p; + p.callback = this; + common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, p); is_ok = true; } catch (const std::exception & e) { auto model_name = params.model.get_name(); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index dd4b1c507c..4165c1015e 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -89,6 +89,17 @@ int llama_server(int argc, char ** argv) { llama_backend_init(); llama_numa_init(params.numa); + // note: router mode also accepts -hf remote-preset, so we need to check that first + if (!params.model.hf_repo.empty()) { + try { + common_params_handle_models_params handle_params; + handle_params.preset_only = true; + common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, handle_params); + } catch (const std::exception & e) { + // ignored for now + } + } + // router server never loads a model and must not touch the GPU const bool is_router_server = params.model.path.empty() && params.model.hf_repo.empty(); @@ -263,7 +274,7 @@ int llama_server(int argc, char ** argv) { return child.run_download(params); } else if (!is_router_server) { // single-model mode (NOT spawned by router) - common_params_handle_models(params, LLAMA_EXAMPLE_SERVER); + common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {}); } // diff --git a/tools/server/tests/unit/test_router.py b/tools/server/tests/unit/test_router.py index 41e95f4c5f..94165e520e 100644 --- a/tools/server/tests/unit/test_router.py +++ b/tools/server/tests/unit/test_router.py @@ -256,6 +256,25 @@ def test_router_reload_models(): os.remove(preset_path) +def test_router_remote_preset(): + global server + server.model_hf_repo = "ggml-org/test-preset-ci" + server.model_hf_file = None + server.offline = False + server.start() + + # Should see preset models in GET /models + res = server.make_request("GET", "/models") + assert res.status_code == 200 + ids = {item["id"] for item in res.body.get("data", [])} + assert "tinygemma3-preset" in ids + assert "stories260K-test" in ids + + # Should be able to load a preset model + model_id = "tinygemma3-preset" + _load_model_and_wait(model_id) + + MODEL_DOWNLOAD_ID = "ggml-org/test-model-router-download:F16" MODEL_DOWNLOAD_TIMEOUT = 30