diff --git a/common/chat-auto-parser-generator.cpp b/common/chat-auto-parser-generator.cpp index 6f825a131b..37ca55c8df 100644 --- a/common/chat-auto-parser-generator.cpp +++ b/common/chat-auto-parser-generator.cpp @@ -103,6 +103,10 @@ common_chat_params peg_generator::generate_parser(const common_chat_template & data.grammar_triggers = { { COMMON_GRAMMAR_TRIGGER_TYPE_WORD, trigger_marker } }; + if (autoparser.tools.format.openai_wrapper_trigger) { + // model emits the OpenAI function wrapper, trigger on it + data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "{\"type\": \"function\"," }); + } } } @@ -224,13 +228,13 @@ common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_cont auto single_tool_parser = p.standard_json_tools( format.per_call_start, format.per_call_end, inputs.tools, inputs.parallel_tool_calls, inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED, name_field, args_field, format.tools_array_wrapped, - format.fun_name_is_key, format.id_field, format.gen_id_field, format.parameter_order); + format.fun_name_is_key, format.id_field, format.gen_id_field, format.parameter_order, format.openai_wrapper_trigger); tools_parser = p.trigger_rule("tool-calls", p.one_or_more(single_tool_parser + p.space())); } else { tools_parser = p.standard_json_tools( format.section_start, format.section_end, inputs.tools, inputs.parallel_tool_calls, inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED, name_field, args_field, format.tools_array_wrapped, - format.fun_name_is_key, format.id_field, format.gen_id_field, format.parameter_order); + format.fun_name_is_key, format.id_field, format.gen_id_field, format.parameter_order, format.openai_wrapper_trigger); } // Handle content wrappers if present diff --git a/common/chat-auto-parser.h b/common/chat-auto-parser.h index 7858f6572f..9e8113f244 100644 --- a/common/chat-auto-parser.h +++ b/common/chat-auto-parser.h @@ -181,6 +181,7 @@ struct tool_format_analysis { bool fun_name_is_key = false; // In JSON format function name is JSON key, i.e. { "": { ... arguments ... } } bool tools_array_wrapped = false; // Tool calls wrapped in JSON array [...] + bool openai_wrapper_trigger = false; // model emits the OpenAI function wrapper, trigger on it std::string function_field = "function"; std::string name_field = "name"; diff --git a/common/chat-diff-analyzer.cpp b/common/chat-diff-analyzer.cpp index ecd9c807c1..b166ee5a18 100644 --- a/common/chat-diff-analyzer.cpp +++ b/common/chat-diff-analyzer.cpp @@ -165,6 +165,14 @@ static std::vector void { + if (tmpl.src.find("Respond in the format {\"name\": function name") != std::string::npos && + tmpl.src.find("Do not use variables.") != std::string::npos) { + analysis.tools.format.openai_wrapper_trigger = true; + LOG_DBG(ANSI_ORANGE "[Patch: JSON name/parameters tool instruction]\n" ANSI_RESET); + } + }, }); diff --git a/common/chat-peg-parser.cpp b/common/chat-peg-parser.cpp index 23b5b38412..a3aa765d1c 100644 --- a/common/chat-peg-parser.cpp +++ b/common/chat-peg-parser.cpp @@ -745,7 +745,8 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys( const std::string & effective_args_key, const std::string & call_id_key, const std::string & gen_call_id_key, - const std::vector & parameters_order) { + const std::vector & parameters_order, + bool accept_openai_wrapper) { auto tool_choices = choice(); auto name_key_parser = literal("\"" + effective_name_key + "\""); @@ -807,7 +808,13 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys( return idx_a < idx_b; }); - auto ordered_body = tool_open(literal("{")) + space(); + // accept an optional leading "type": "function" field when the model emits the OpenAI wrapper + common_peg_parser type_field = eps(); + if (accept_openai_wrapper) { + type_field = optional(literal("\"type\"") + space() + literal(":") + space() + + literal("\"function\"") + space() + literal(",") + space()); + } + auto ordered_body = tool_open(literal("{")) + space() + type_field; for (size_t i = 0; i < parser_pairs.size(); i++) { ordered_body = ordered_body + parser_pairs[i].first; if (i < parser_pairs.size() - 1) { @@ -870,7 +877,8 @@ common_peg_parser common_chat_peg_builder::standard_json_tools( bool function_is_key, const std::string & call_id_key, const std::string & gen_call_id_key, - const std::vector & parameters_order) { + const std::vector & parameters_order, + bool accept_openai_wrapper) { if (!tools.is_array() || tools.empty()) { return eps(); } @@ -888,7 +896,7 @@ common_peg_parser common_chat_peg_builder::standard_json_tools( if (!name_spec.first.empty() || !args_spec.first.empty()) { tool_choices = build_json_tools_nested_keys(tools, effective_name_key, effective_args_key, call_id_key, gen_call_id_key); } else { - tool_choices = build_json_tools_flat_keys(tools, effective_name_key, effective_args_key, call_id_key, gen_call_id_key, parameters_order); + tool_choices = build_json_tools_flat_keys(tools, effective_name_key, effective_args_key, call_id_key, gen_call_id_key, parameters_order, accept_openai_wrapper); } } diff --git a/common/chat-peg-parser.h b/common/chat-peg-parser.h index a4643fbea8..b3ffd7de2d 100644 --- a/common/chat-peg-parser.h +++ b/common/chat-peg-parser.h @@ -120,7 +120,8 @@ class common_chat_peg_builder : public common_peg_parser_builder { bool function_is_key = false, const std::string & call_id_key = "", const std::string & gen_call_id_key = "", - const std::vector & parameters_order = {}); + const std::vector & parameters_order = {}, + bool accept_openai_wrapper = false); // Legacy-compatible helper for building XML/tagged style tool calls // Used by tests and manual parsers @@ -157,7 +158,8 @@ class common_chat_peg_builder : public common_peg_parser_builder { const std::string & effective_args_key, const std::string & call_id_key, const std::string & gen_call_id_key, - const std::vector & parameters_order); + const std::vector & parameters_order, + bool accept_openai_wrapper); }; inline common_peg_arena build_chat_peg_parser( diff --git a/common/chat.cpp b/common/chat.cpp index b2d7ce23dd..ded8440e66 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -2678,8 +2678,9 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars } return msg; } - LOG_DBG("Failure during parsing following input at position %lu: \n\n=== START ===\n%s\n=== END ===\n", result.end, effective_input.c_str()); - throw std::runtime_error(std::string("Failed to parse input at pos ") + std::to_string(result.end) + ": " + effective_input.substr(result.end)); + LOG_WRN("%s: unparsed %s output: %s\n", __func__, common_chat_format_name(params.format), effective_input.substr(result.end).c_str()); + LOG_DBG("%s: full %s output triggering error:\n=== BEGIN ===\n%s\n=== END ===\n", __func__, common_chat_format_name(params.format), effective_input.c_str()); + throw std::runtime_error(std::string("The model produced output that does not match the expected ") + common_chat_format_name(params.format) + " format"); } common_chat_msg msg; diff --git a/common/peg-parser.cpp b/common/peg-parser.cpp index 310bebf735..d4b491a80e 100644 --- a/common/peg-parser.cpp +++ b/common/peg-parser.cpp @@ -1507,6 +1507,7 @@ static std::string gbnf_excluding_pattern(const std::vector & strin auto pieces = matcher.collect_prefix_and_next(); std::string pattern; + std::string trailing; // optional proper-prefix of a delimiter, allowed only at the very end for (size_t i = 0; i < pieces.size(); ++i) { if (i > 0) { pattern += " | "; @@ -1522,13 +1523,32 @@ static std::string gbnf_excluding_pattern(const std::vector & strin } if (!pre.empty()) { - pattern += gbnf_format_literal(common_unicode_cpts_to_utf8(pre)) + " [^" + cls + "]"; + std::string pre_literal = gbnf_format_literal(common_unicode_cpts_to_utf8(pre)); + pattern += pre_literal + " [^" + cls + "]"; + // Each interior alternative consumes a delimiter-prefix plus a disambiguating + // char, so the repetition alone cannot match a value that *ends* on a proper + // prefix of a delimiter (e.g. a trailing "\n" when the delimiter is + // "\n\n"). The runtime until() (greedy first-match) accepts such + // values, so without this the grammar would reject input the parser accepts. + // Allow the value to terminate on any proper prefix as an optional tail. + // This makes the grammar a slight superset of the runtime language (a value + // may end on the longest prefix, which greedy first-match would not itself + // produce); harmless for constrained generation, which only needs to admit + // every runtime-valid string. + if (!trailing.empty()) { + trailing += " | "; + } + trailing += pre_literal; } else { pattern += "[^" + cls + "]"; } } - return "(" + pattern + ")*"; + std::string result = "(" + pattern + ")*"; + if (!trailing.empty()) { + result += " (" + trailing + ")?"; + } + return result; } static std::unordered_set collect_reachable_rules( diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 5ab19a7d2c..6c149bf097 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -798,7 +798,7 @@ struct vk_device_struct { vk_pipeline pipeline_add_id_f32; - vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; + vk_pipeline pipeline_concat_i8, pipeline_concat_i16, pipeline_concat_i32, pipeline_concat_i64; vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32; vk_pipeline pipeline_scale_f32; vk_pipeline pipeline_sqr_f32; @@ -4996,9 +4996,10 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 1}, 1); ggml_vk_create_pipeline(device, device->pipeline_set_f32, "set_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_concat_i8, "concat_i8", concat_i8_len, concat_i8_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_concat_i16, "concat_i16", concat_i16_len, concat_i16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_concat_i64, "concat_i64", concat_i64_len, concat_i64_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1); ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1); @@ -10318,17 +10319,27 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_add_id_f32; } return nullptr; - case GGML_OP_CONCAT: - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_concat_f32; + case GGML_OP_CONCAT: { + if (src0->type != src1->type || src0->type != dst->type) { + return nullptr; } - if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { - return ctx->device->pipeline_concat_f16; + if (ggml_blck_size(src0->type) != 1) { + return nullptr; } - if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { + const size_t type_size = ggml_type_size(src0->type); + switch (type_size) { + case 1: + return ctx->device->pipeline_concat_i8; + case 2: + return ctx->device->pipeline_concat_i16; + case 4: return ctx->device->pipeline_concat_i32; + case 8: + return ctx->device->pipeline_concat_i64; + default: + return nullptr; } - return nullptr; + } case GGML_OP_UPSCALE: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { uint32_t mode = (ggml_get_op_params_i32(dst, 0) & (0xFF | GGML_SCALE_FLAG_ANTIALIAS)); @@ -17042,8 +17053,14 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_SET: return op->src[0]->type == op->src[1]->type && op->src[0]->type == op->type && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_I32); - case GGML_OP_CONCAT: - return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32); + case GGML_OP_CONCAT: { + if (op->src[0]->type != op->src[1]->type || op->src[0]->type != op->type) { + return false; + } + const size_t type_size = ggml_type_size(op->type); + return ggml_blck_size(op->type) == 1 && + (type_size == 1 || type_size == 2 || type_size == 4 || type_size == 8); + } case GGML_OP_ADD1: return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32) || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index dbbd0b1932..c1f9bd1d47 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -862,9 +862,10 @@ void process_shaders() { string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); - string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}); + string_to_spv("concat_i8", "concat.comp", {{"A_TYPE", "uint8_t"}, {"B_TYPE", "uint8_t"}, {"D_TYPE", "uint8_t"}}); + string_to_spv("concat_i16", "concat.comp", {{"A_TYPE", "uint16_t"}, {"B_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}}); + string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "uint"}, {"B_TYPE", "uint"}, {"D_TYPE", "uint"}}); + string_to_spv("concat_i64", "concat.comp", {{"A_TYPE", "uvec2"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "uvec2"}}); string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); diff --git a/tests/peg-parser/test-gbnf-generation.cpp b/tests/peg-parser/test-gbnf-generation.cpp index fe4bbbdd16..00111e6a19 100644 --- a/tests/peg-parser/test-gbnf-generation.cpp +++ b/tests/peg-parser/test-gbnf-generation.cpp @@ -129,7 +129,7 @@ void test_gbnf_generation(testing &t) { }); assert_gbnf_equal(t, R"""( - root ::= ([^<] | "<" [^/] | "])* + root ::= ([^<] | "<" [^/] | "])* ("<" | "type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) { + } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16) { // This is going to create some weird integers though. - ggml_backend_tensor_set(tensor, data.data(), 0, ggml_nbytes(tensor)); + ggml_backend_tensor_set(tensor, data.data(), 0, nels * ggml_type_size(tensor->type)); } else if (tensor->type == GGML_TYPE_I64) { // Integers with a size of 8 bytes can be set by mirroring the float data, the specific values are again not really meaningful. - const size_t nbytes_half = ggml_nbytes(tensor)/2; + const size_t nbytes_half = nels * sizeof(float); ggml_backend_tensor_set(tensor, data.data(), 0*nbytes_half, nbytes_half); ggml_backend_tensor_set(tensor, data.data(), 1*nbytes_half, nbytes_half); } else { diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index e65af46c63..548071c906 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -2024,6 +2024,61 @@ static void test_template_output_peg_parsers(bool detailed_debug) { }) .run(); + tst.test( + "\n" + "\n" + "\n" + "foo.c\n" + "\n" + "\n" + "#iclunde\n" + "\n" + "\n" + "#include\n" + "\n" + "\n" + "") + .enable_thinking(false) + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ + edit_tool + }) + .expect_tool_calls({ + { "edit", "{\"filename\": \"foo.c\", \"oldString\": \"#iclunde\", \"newString\": \"#include\"}", {} }, + }) + .run(); + + // a parameter value that itself ends in a newline (e.g. a source file with a + // trailing newline). The structural delimiter is "\n\n", so the value + // "#include\n" renders as "...#include\n\n\n". The trailing newline must + // be preserved faithfully (no stripping), and the generated grammar must admit a + // value ending on a delimiter prefix. Regression test for gbnf_excluding_pattern. + tst.test( + "\n" + "\n" + "\n" + "foo.c\n" + "\n" + "\n" + "#iclunde\n" + "\n" + "\n" + "#include\n" + "\n" + "\n" + "\n" + "") + .enable_thinking(false) + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ + edit_tool + }) + .expect_tool_calls({ + { "edit", "{\"filename\": \"foo.c\", \"oldString\": \"#iclunde\", \"newString\": \"#include\\n\"}", {} }, + }) + .run(); + + // test code that starts with indent tst.test( "\n" diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp index 2d11a33804..b5c4089232 100644 --- a/tools/mtmd/mtmd-helper.cpp +++ b/tools/mtmd/mtmd-helper.cpp @@ -247,7 +247,9 @@ int32_t mtmd_helper_decode_image_chunk( llama_pos n_past, llama_seq_id seq_id, int32_t n_batch, - llama_pos * new_n_past) { + llama_pos * new_n_past, + mtmd_helper_post_decode_callback callback, + void * user_data) { GGML_ASSERT(n_batch > 0); auto chunk_type = mtmd_input_chunk_get_type(chunk); const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio"; @@ -302,10 +304,23 @@ int32_t mtmd_helper_decode_image_chunk( int32_t ret = llama_decode(lctx, batch_embd_view); if (ret != 0) { LOG_ERR("failed to decode %s\n", name); - llama_set_causal_attn(lctx, true); // restore causal attn + if (use_non_causal) { + llama_set_causal_attn(lctx, true); + } return ret; } + if (callback != nullptr) { + ret = callback(batch_embd_view, user_data); + if (ret != 0) { + LOG_ERR("post-decode callback failed\n"); + if (use_non_causal) { + llama_set_causal_attn(lctx, true); + } + return ret; + } + } + LOG_INF("%s decoded (batch %d/%d) in %" PRId64 " ms\n", name, i_batch+1, n_img_batches, ggml_time_ms() - t1); i_batch++; @@ -379,7 +394,7 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, LOG_INF("%s slice encoded in %" PRId64 " ms\n", name, ggml_time_ms() - t0); float * embd = mtmd_get_output_embd(ctx); - ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past); + ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past, nullptr, nullptr); if (ret != 0) { LOG_ERR("failed to decode %s\n", name); llama_batch_free(text_batch); diff --git a/tools/mtmd/mtmd-helper.h b/tools/mtmd/mtmd-helper.h index 719aae9885..680a2317df 100644 --- a/tools/mtmd/mtmd-helper.h +++ b/tools/mtmd/mtmd-helper.h @@ -91,6 +91,8 @@ MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx, bool logits_last, llama_pos * new_n_past); +typedef int32_t (*mtmd_helper_post_decode_callback)(struct llama_batch batch, void * user_data); + // helper function to decode an image whose embeddings have already been calculated // this helper will handle batching and pre/post decoding setup (for ex. gemma 3 requires non-causal attention) // ret 0 on success, -1 on chunk not being a valid image chunk, 1 on decode failure @@ -101,7 +103,9 @@ MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx, llama_pos n_past, llama_seq_id seq_id, int32_t n_batch, - llama_pos * new_n_past); + llama_pos * new_n_past, + mtmd_helper_post_decode_callback callback, + void * user_data); // // video input helpers (requires ffmpeg/ffprobe installed on the system) diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 2b89a8bc5a..75729e62dd 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -539,37 +539,6 @@ bool server_tokens::validate(const struct llama_context * ctx) const { return true; } -int32_t server_tokens::process_chunk( - llama_context * ctx, - mtmd_context * mctx, - size_t idx, - llama_pos pos, - int32_t seq_id, - size_t & n_tokens_out) const { - const auto & chunk = find_chunk(idx); - const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE - ? "image" : "audio"; - SRV_INF("processing %s...\n", name); - int32_t n_batch = llama_n_batch(ctx); - int64_t t0 = ggml_time_ms(); - llama_pos new_n_past; // unused for now - int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, - chunk.get(), - pos, - seq_id, - n_batch, - true, // logits last - &new_n_past); - SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); - if (result != 0) { - LOG_ERR("mtmd_helper_eval failed with status %d", result); - n_tokens_out = 0; - return result; - } - n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get()); - return 0; -} - server_tokens server_tokens::clone() const { server_tokens res; res.has_mtmd = has_mtmd; diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 857ffe1479..f286b3d156 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -221,15 +221,6 @@ public: // make sure all text tokens are within the vocab range bool validate(const struct llama_context * ctx) const; - // encode and decode the image chunk - int32_t process_chunk( - llama_context * ctx, - mtmd_context * mctx, - size_t idx, - llama_pos pos, - int32_t seq_id, - size_t & n_tokens_out) const; - server_tokens clone() const; }; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 986b2f15d5..bcae39a109 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -15,11 +15,6 @@ #include "mtmd.h" #include "mtmd-helper.h" -#include "ggml-cpp.h" - -// TODO: tmp until the mtmd draft processing is refactored [TAG_MTMD_DRAFT_PROCESSING] -#include "../../src/llama-ext.h" - #include #include #include @@ -81,7 +76,6 @@ struct server_slot { // multimodal mtmd_context * mctx = nullptr; mtmd::batch_ptr mbatch = nullptr; - std::array mtgt = {nullptr, nullptr}; // [0] for main context, [1] for optional draft context // speculative decoding common_speculative * spec; @@ -244,15 +238,6 @@ struct server_slot { // clear multimodal state mbatch.reset(); - mtgt[0] = ctx_tgt; - mtgt[1] = nullptr; - if (ctx_dft && llama_get_ctx_other(ctx_dft) != ctx_tgt) { - // TODO: in the future, figure out how to infuse target embeddings to the images - // for now, we re-decode the same chunk in both ctx_tgt and ctx_dft - // maybe we simply need to call `common_speculative_process()` ? - // [TAG_MTMD_DRAFT_PROCESSING] - mtgt[1] = ctx_dft; - } } void init_sampler() const { @@ -598,32 +583,38 @@ struct server_slot { int process_mtmd_chunk(size_t idx, size_t & n_tokens_out) { GGML_ASSERT(mctx); const auto & input_tokens = task->tokens; - auto & chunk = input_tokens.find_chunk(idx); + const auto & chunk = input_tokens.find_chunk(idx); int32_t res = 0; auto try_decode = [&]() -> int32_t { if (mbatch) { float * embd = mtmd_batch_get_output_embd(mbatch.get(), chunk.get()); if (embd) { - for (auto * lctx : mtgt) { - if (lctx == nullptr) { - continue; - } - llama_pos new_n_past; // unused for now - res = mtmd_helper_decode_image_chunk( - mctx, - lctx, - chunk.get(), - embd, - prompt.tokens.pos_next(), - id, - llama_n_batch(lctx), - &new_n_past - ); - if (res != 0) { - SLT_ERR(*this, "failed to decode mtmd chunk, idx = %zu, res = %d\n", idx, res); - return -1; + void * cb_data = spec; + static auto cb = [](llama_batch batch, void * user_data) { + common_speculative * spec = static_cast(user_data); + if (!common_speculative_process(spec, batch)) { + return 1; } + return 0; + }; + + llama_pos new_n_past; // unused for now + res = mtmd_helper_decode_image_chunk( + mctx, + ctx_tgt, + chunk.get(), + embd, + prompt.tokens.pos_next(), + id, + llama_n_batch(ctx_tgt), + &new_n_past, + cb, + cb_data + ); + if (res != 0) { + SLT_ERR(*this, "failed to decode mtmd chunk, idx = %zu, res = %d\n", idx, res); + return -1; } n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get()); return 0; // success @@ -636,7 +627,8 @@ struct server_slot { res = try_decode(); if (res == 0) { return 0; - } else if (res < 0) { + } + if (res < 0) { // fatal error return res; } @@ -3350,48 +3342,6 @@ private: // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] // for now, always re-evaluate for simplicity // ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384 - // - // | spec type | need re-eval | - // | --- | --- | - // | draft model | no | because the draft model does not use embeddings from the target - // | MTP (std) | yes | - // | MTP Gemma4 | no | because the KV cache is shared - // | Eagle3 | yes | - // | DFlash | yes | https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4405406982 - // - // note: this logic is now moved in `common_speculative_process()` - // keeping the sketch here until for a bit, until the logic is finalized - // - //if (ctx_dft) { - // // TODO: update as needed for MTP, Eagle3, etc. - // const bool need_tgt_embd = false; - - // if (need_tgt_embd) { - // llama_synchronize(ctx_tgt); - // } - - // // the logic here varies depending on the speculative decoding method - // // - some draft contexts require embeddings from the target context, others don't - // // - some draft contexts involve an encoder step to transform the target embeddings to draft embeddings - // // TODO: extract this in a function ? - // { - // // TODO: hook the embeddings from the last target batch here - // if (llama_model_has_encoder(model_dft.get())) { - // //llama_encode(ctx_dft, ...); - - // GGML_ABORT("not implemented yet\n"); - // } - - // const int ret = llama_decode(ctx_dft.get(), batch_view); - - // if (ret != 0) { - // SRV_ERR("failed to decode draft batch, ret = %d\n", ret); - - // // TODO: handle error - // break; - // } - // } - //} if (!common_speculative_process(spec.get(), batch_view)) { SRV_ERR("%s", "failed to process speculative batch\n");