diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index c713703e01..fccc1e3487 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1045,8 +1045,17 @@ struct clip_model_loader { bool has_vision = false; bool has_audio = false; + mtmd_progress_callback progress_callback = nullptr; + void * progress_callback_user_data = nullptr; + // TODO @ngxson : we should not pass clip_ctx here, it should be clip_model - clip_model_loader(const char * fname, bool skip_tensors = false) : fname(fname) { + clip_model_loader(const char * fname, + bool skip_tensors = false, + mtmd_progress_callback progress_cb = nullptr, + void * progress_user_data = nullptr) + : fname(fname), + progress_callback(progress_cb), + progress_callback_user_data(progress_user_data) { struct ggml_context * meta = nullptr; struct gguf_init_params params = { @@ -2790,10 +2799,22 @@ struct clip_model_loader { if (!ctx_clip.no_alloc) { std::vector read_buf; + // start loading event + if (progress_callback){ + progress_callback(0.0, progress_callback_user_data); + } + + // compute total tensor data size for progress reporting + size_t total_data_size = 0; + for (auto & t : tensors_to_load) { + total_data_size += ggml_nbytes(t); + } + // alloc memory and offload data ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend); ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft)); ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + size_t data_loaded = 0; for (auto & t : tensors_to_load) { ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name); GGML_ASSERT(cur && "tensor not found in ctx_data"); @@ -2814,6 +2835,13 @@ struct clip_model_loader { fin.read(reinterpret_cast(read_buf.data()), num_bytes); ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); } + data_loaded += num_bytes; + if (progress_callback && total_data_size > 0) { + const float progress = (float)data_loaded / (float)total_data_size; + if (!progress_callback(progress, progress_callback_user_data)) { + throw std::runtime_error(string_format("%s: model loading cancelled by progress_callback\n", __func__)); + } + } } fin.close(); @@ -3105,7 +3133,10 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params clip_ctx * ctx_audio = nullptr; try { - clip_model_loader loader(fname); + clip_model_loader loader(fname, + /* skip_tensors */ false, + ctx_params.progress_callback, + ctx_params.progress_callback_user_data); bool skip_audio = false; if (loader.has_vision) { diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h index e0f1d298c8..967093a812 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -54,6 +54,8 @@ struct clip_context_params { ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; bool no_alloc; + mtmd_progress_callback progress_callback; + void * progress_callback_user_data; }; struct clip_init_result { diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index cbaac1d377..564bafc621 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -251,6 +251,8 @@ mtmd_context_params mtmd_context_params_default() { /* cb_eval */ nullptr, /* cb_eval_user_data */ nullptr, /* batch_max_tokens */ 1024, + /* progress_callback */ nullptr, + /* progress_callback_user_data */ nullptr, }; return params; } @@ -345,6 +347,8 @@ struct mtmd_context { /* cb_eval */ ctx_params.cb_eval, /* cb_eval_user_data */ ctx_params.cb_eval_user_data, /* no_alloc */ no_alloc, + /* progress_callback */ ctx_params.progress_callback, + /* progress_callback_user_data */ ctx_params.progress_callback_user_data, }; auto res = clip_init(mmproj_fname, ctx_clip_params); @@ -2133,8 +2137,12 @@ std::map mtmd_get_memory_usage(const char * mmproj_f mtmd::context_ptr ctx; auto saved_log_callback = g_logger_state.log_callback; auto saved_log_user_data = g_logger_state.log_callback_user_data; + + ctx_params.progress_callback = nullptr; + try { mtmd_log_set(stub_log_callback, nullptr); // suppress logging + // TODO @ngxson : fix no_alloc here ctx.reset(new mtmd_context(mmproj_fname, nullptr, ctx_params)); mtmd_log_set(saved_log_callback, saved_log_user_data); // restore log callback std::map total_mem; diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h index 2fd149e480..25d51ef58d 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -83,6 +83,8 @@ typedef struct mtmd_input_chunks mtmd_input_chunks; typedef struct mtmd_input_text mtmd_input_text; typedef struct mtmd_batch mtmd_batch; +typedef bool (*mtmd_progress_callback)(float progress, void * user_data); + struct mtmd_context_params { bool use_gpu; bool print_timings; @@ -104,6 +106,12 @@ struct mtmd_context_params { int32_t batch_max_tokens; // maximum number of output tokens in a batch // (note: this is not a hard-limit, the first image will always be added even if it exceeds this limit) // (default: 1024) + + // Called with a progress value between 0.0 and 1.0. Pass NULL to disable. + // If the provided progress_callback returns true, model loading continues. + // If it returns false, model loading is immediately aborted. + mtmd_progress_callback progress_callback; + void * progress_callback_user_data; }; MTMD_API const char * mtmd_default_marker(void); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 531b106e55..7db4cb1986 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -833,8 +833,6 @@ private: bool sleeping = false; - int64_t t_last_load_progress_ms = 0; - void destroy() { spec.reset(); ctx_dft.reset(); @@ -865,12 +863,18 @@ private: sleeping = new_state; } + struct load_progress_data { + server_context_impl * ctx; + std::string stage; + int64_t t_last_load_progress_ms = 0; + load_progress_data(server_context_impl * ctx, const std::string & stage) : ctx(ctx), stage(stage) {} + }; static bool load_progress_callback(float progress, void * user_data) { - auto * ctx = static_cast(user_data); - GGML_ASSERT(ctx); + auto * d = static_cast(user_data); + GGML_ASSERT(d); // always emit the first and final sample; throttle the rest to one per 200ms { - auto & t_last = ctx->t_last_load_progress_ms; + auto & t_last = d->t_last_load_progress_ms; const int64_t t_now = ggml_time_ms(); const bool first = t_last == 0; const bool done = progress >= 1.0f; @@ -880,9 +884,9 @@ private: } t_last = t_now; } - if (ctx->callback_state) { - ctx->callback_state(SERVER_STATE_LOADING, { - {"stage", "text_model"}, + if (d->ctx->callback_state) { + d->ctx->callback_state(SERVER_STATE_LOADING, { + {"stage", d->stage}, {"value", progress}, }); } @@ -892,6 +896,9 @@ private: // load the model and initialize llama_context // this may also be called to resume from sleeping state bool load_model(common_params & params) { + load_progress_data load_progress_text(this, "text_model"); + load_progress_data load_progress_mmproj(this, "mmproj_model"); + bool is_resume = sleeping; SRV_INF("loading model '%s'\n", params.model.path.c_str()); @@ -912,6 +919,9 @@ private: mparams.image_max_tokens = params_base.image_max_tokens; mparams.batch_max_tokens = params_base.mtmd_batch_max_tokens; mparams.media_marker = get_media_marker(); + // progress callback + mparams.progress_callback = load_progress_callback; + mparams.progress_callback_user_data = &load_progress_mmproj; } // optionally get the memory usage of mmproj @@ -1023,9 +1033,8 @@ private: // attach a progress callback { - t_last_load_progress_ms = 0; params_base.load_progress_callback = load_progress_callback; - params_base.load_progress_callback_user_data = this; + params_base.load_progress_callback_user_data = &load_progress_text; } llama_init = common_init_from_params(params_base); @@ -1120,10 +1129,6 @@ private: } if (has_mmproj) { - if (callback_state) { - callback_state(SERVER_STATE_LOADING, {{"stage", "mmproj_model"}}); - } - if (!is_resume) { mtmd_helper_log_set(common_log_default_callback, nullptr); }