mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-26 06:10:19 +00:00
mtmd: add load progress callback (#24865)
This commit is contained in:
+33
-2
@@ -1045,8 +1045,17 @@ struct clip_model_loader {
|
||||
bool has_vision = false;
|
||||
bool has_audio = false;
|
||||
|
||||
mtmd_progress_callback progress_callback = nullptr;
|
||||
void * progress_callback_user_data = nullptr;
|
||||
|
||||
// TODO @ngxson : we should not pass clip_ctx here, it should be clip_model
|
||||
clip_model_loader(const char * fname, bool skip_tensors = false) : fname(fname) {
|
||||
clip_model_loader(const char * fname,
|
||||
bool skip_tensors = false,
|
||||
mtmd_progress_callback progress_cb = nullptr,
|
||||
void * progress_user_data = nullptr)
|
||||
: fname(fname),
|
||||
progress_callback(progress_cb),
|
||||
progress_callback_user_data(progress_user_data) {
|
||||
struct ggml_context * meta = nullptr;
|
||||
|
||||
struct gguf_init_params params = {
|
||||
@@ -2790,10 +2799,22 @@ struct clip_model_loader {
|
||||
if (!ctx_clip.no_alloc) {
|
||||
std::vector<uint8_t> read_buf;
|
||||
|
||||
// start loading event
|
||||
if (progress_callback){
|
||||
progress_callback(0.0, progress_callback_user_data);
|
||||
}
|
||||
|
||||
// compute total tensor data size for progress reporting
|
||||
size_t total_data_size = 0;
|
||||
for (auto & t : tensors_to_load) {
|
||||
total_data_size += ggml_nbytes(t);
|
||||
}
|
||||
|
||||
// alloc memory and offload data
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend);
|
||||
ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft));
|
||||
ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
|
||||
size_t data_loaded = 0;
|
||||
for (auto & t : tensors_to_load) {
|
||||
ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
|
||||
GGML_ASSERT(cur && "tensor not found in ctx_data");
|
||||
@@ -2814,6 +2835,13 @@ struct clip_model_loader {
|
||||
fin.read(reinterpret_cast<char *>(read_buf.data()), num_bytes);
|
||||
ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
|
||||
}
|
||||
data_loaded += num_bytes;
|
||||
if (progress_callback && total_data_size > 0) {
|
||||
const float progress = (float)data_loaded / (float)total_data_size;
|
||||
if (!progress_callback(progress, progress_callback_user_data)) {
|
||||
throw std::runtime_error(string_format("%s: model loading cancelled by progress_callback\n", __func__));
|
||||
}
|
||||
}
|
||||
}
|
||||
fin.close();
|
||||
|
||||
@@ -3105,7 +3133,10 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
|
||||
clip_ctx * ctx_audio = nullptr;
|
||||
|
||||
try {
|
||||
clip_model_loader loader(fname);
|
||||
clip_model_loader loader(fname,
|
||||
/* skip_tensors */ false,
|
||||
ctx_params.progress_callback,
|
||||
ctx_params.progress_callback_user_data);
|
||||
bool skip_audio = false;
|
||||
|
||||
if (loader.has_vision) {
|
||||
|
||||
@@ -54,6 +54,8 @@ struct clip_context_params {
|
||||
ggml_backend_sched_eval_callback cb_eval;
|
||||
void * cb_eval_user_data;
|
||||
bool no_alloc;
|
||||
mtmd_progress_callback progress_callback;
|
||||
void * progress_callback_user_data;
|
||||
};
|
||||
|
||||
struct clip_init_result {
|
||||
|
||||
@@ -251,6 +251,8 @@ mtmd_context_params mtmd_context_params_default() {
|
||||
/* cb_eval */ nullptr,
|
||||
/* cb_eval_user_data */ nullptr,
|
||||
/* batch_max_tokens */ 1024,
|
||||
/* progress_callback */ nullptr,
|
||||
/* progress_callback_user_data */ nullptr,
|
||||
};
|
||||
return params;
|
||||
}
|
||||
@@ -345,6 +347,8 @@ struct mtmd_context {
|
||||
/* cb_eval */ ctx_params.cb_eval,
|
||||
/* cb_eval_user_data */ ctx_params.cb_eval_user_data,
|
||||
/* no_alloc */ no_alloc,
|
||||
/* progress_callback */ ctx_params.progress_callback,
|
||||
/* progress_callback_user_data */ ctx_params.progress_callback_user_data,
|
||||
};
|
||||
|
||||
auto res = clip_init(mmproj_fname, ctx_clip_params);
|
||||
@@ -2133,8 +2137,12 @@ std::map<ggml_backend_dev_t, size_t> mtmd_get_memory_usage(const char * mmproj_f
|
||||
mtmd::context_ptr ctx;
|
||||
auto saved_log_callback = g_logger_state.log_callback;
|
||||
auto saved_log_user_data = g_logger_state.log_callback_user_data;
|
||||
|
||||
ctx_params.progress_callback = nullptr;
|
||||
|
||||
try {
|
||||
mtmd_log_set(stub_log_callback, nullptr); // suppress logging
|
||||
// TODO @ngxson : fix no_alloc here
|
||||
ctx.reset(new mtmd_context(mmproj_fname, nullptr, ctx_params));
|
||||
mtmd_log_set(saved_log_callback, saved_log_user_data); // restore log callback
|
||||
std::map<ggml_backend_dev_t, size_t> total_mem;
|
||||
|
||||
@@ -83,6 +83,8 @@ typedef struct mtmd_input_chunks mtmd_input_chunks;
|
||||
typedef struct mtmd_input_text mtmd_input_text;
|
||||
typedef struct mtmd_batch mtmd_batch;
|
||||
|
||||
typedef bool (*mtmd_progress_callback)(float progress, void * user_data);
|
||||
|
||||
struct mtmd_context_params {
|
||||
bool use_gpu;
|
||||
bool print_timings;
|
||||
@@ -104,6 +106,12 @@ struct mtmd_context_params {
|
||||
int32_t batch_max_tokens; // maximum number of output tokens in a batch
|
||||
// (note: this is not a hard-limit, the first image will always be added even if it exceeds this limit)
|
||||
// (default: 1024)
|
||||
|
||||
// Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
|
||||
// If the provided progress_callback returns true, model loading continues.
|
||||
// If it returns false, model loading is immediately aborted.
|
||||
mtmd_progress_callback progress_callback;
|
||||
void * progress_callback_user_data;
|
||||
};
|
||||
|
||||
MTMD_API const char * mtmd_default_marker(void);
|
||||
|
||||
@@ -833,8 +833,6 @@ private:
|
||||
|
||||
bool sleeping = false;
|
||||
|
||||
int64_t t_last_load_progress_ms = 0;
|
||||
|
||||
void destroy() {
|
||||
spec.reset();
|
||||
ctx_dft.reset();
|
||||
@@ -865,12 +863,18 @@ private:
|
||||
sleeping = new_state;
|
||||
}
|
||||
|
||||
struct load_progress_data {
|
||||
server_context_impl * ctx;
|
||||
std::string stage;
|
||||
int64_t t_last_load_progress_ms = 0;
|
||||
load_progress_data(server_context_impl * ctx, const std::string & stage) : ctx(ctx), stage(stage) {}
|
||||
};
|
||||
static bool load_progress_callback(float progress, void * user_data) {
|
||||
auto * ctx = static_cast<server_context_impl *>(user_data);
|
||||
GGML_ASSERT(ctx);
|
||||
auto * d = static_cast<load_progress_data *>(user_data);
|
||||
GGML_ASSERT(d);
|
||||
// always emit the first and final sample; throttle the rest to one per 200ms
|
||||
{
|
||||
auto & t_last = ctx->t_last_load_progress_ms;
|
||||
auto & t_last = d->t_last_load_progress_ms;
|
||||
const int64_t t_now = ggml_time_ms();
|
||||
const bool first = t_last == 0;
|
||||
const bool done = progress >= 1.0f;
|
||||
@@ -880,9 +884,9 @@ private:
|
||||
}
|
||||
t_last = t_now;
|
||||
}
|
||||
if (ctx->callback_state) {
|
||||
ctx->callback_state(SERVER_STATE_LOADING, {
|
||||
{"stage", "text_model"},
|
||||
if (d->ctx->callback_state) {
|
||||
d->ctx->callback_state(SERVER_STATE_LOADING, {
|
||||
{"stage", d->stage},
|
||||
{"value", progress},
|
||||
});
|
||||
}
|
||||
@@ -892,6 +896,9 @@ private:
|
||||
// load the model and initialize llama_context
|
||||
// this may also be called to resume from sleeping state
|
||||
bool load_model(common_params & params) {
|
||||
load_progress_data load_progress_text(this, "text_model");
|
||||
load_progress_data load_progress_mmproj(this, "mmproj_model");
|
||||
|
||||
bool is_resume = sleeping;
|
||||
|
||||
SRV_INF("loading model '%s'\n", params.model.path.c_str());
|
||||
@@ -912,6 +919,9 @@ private:
|
||||
mparams.image_max_tokens = params_base.image_max_tokens;
|
||||
mparams.batch_max_tokens = params_base.mtmd_batch_max_tokens;
|
||||
mparams.media_marker = get_media_marker();
|
||||
// progress callback
|
||||
mparams.progress_callback = load_progress_callback;
|
||||
mparams.progress_callback_user_data = &load_progress_mmproj;
|
||||
}
|
||||
|
||||
// optionally get the memory usage of mmproj
|
||||
@@ -1023,9 +1033,8 @@ private:
|
||||
|
||||
// attach a progress callback
|
||||
{
|
||||
t_last_load_progress_ms = 0;
|
||||
params_base.load_progress_callback = load_progress_callback;
|
||||
params_base.load_progress_callback_user_data = this;
|
||||
params_base.load_progress_callback_user_data = &load_progress_text;
|
||||
}
|
||||
|
||||
llama_init = common_init_from_params(params_base);
|
||||
@@ -1120,10 +1129,6 @@ private:
|
||||
}
|
||||
|
||||
if (has_mmproj) {
|
||||
if (callback_state) {
|
||||
callback_state(SERVER_STATE_LOADING, {{"stage", "mmproj_model"}});
|
||||
}
|
||||
|
||||
if (!is_resume) {
|
||||
mtmd_helper_log_set(common_log_default_callback, nullptr);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user