[https://nvbugs/5394409][feat] Support Mistral Small 3.1 multimodal in Triton Backend (#6714)

Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Signed-off-by: Dimitrios Bariamis <dbari@users.noreply.github.com>
Co-authored-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Co-authored-by: Iman Tabrizian <10105175+Tabrizian@users.noreply.github.com>
Authored by Dimitrios Bariamis on 2025-08-21 18:08:38 +02:00; committed by GitHub
parent 9a2b44d0f2
commit f49dafe0da
28 changed files with 821 additions and 80 deletions

.gitattributes (vendored)
View File

@ -7,3 +7,5 @@
triton_backend/tools/gpt/input_data.json filter=lfs diff=lfs merge=lfs -text
*cubin.cpp filter=lfs diff=lfs merge=lfs -text
docs/source/blogs/media/tech_blog3_mla_absorb.png filter=lfs diff=lfs merge=lfs -text
tests/integration/test_input_files/*.png filter=lfs diff=lfs merge=lfs -text
tests/integration/test_input_files/*.jpg filter=lfs diff=lfs merge=lfs -text

View File

@ -1763,3 +1763,75 @@ def prepare_rcca_nvbug_4714193_engine(tensorrt_llm_example_root,
assert os.path.exists(engine_dir), f"{engine_dir} does not exist."
return engine_dir
def prepare_mistral3_pixtral_engine(tensorrt_llm_multimodal_example_root,
tensorrt_llm_llama_example_root,
mistral_small_model_root):
# Convert Mistral3 from HF
model_base_name = os.path.basename(mistral_small_model_root.rstrip("/"))
ckpt_dir = os.path.join(tensorrt_llm_multimodal_example_root, "model_dir",
model_base_name)
convert_cmd = [
"python3",
f"{tensorrt_llm_llama_example_root}/convert_checkpoint.py",
"--dtype=bfloat16",
f"--model_dir={mistral_small_model_root}",
f"--output_dir={ckpt_dir}",
]
# Build Mistral3 LLM engine
engine_dir = os.path.join(tensorrt_llm_multimodal_example_root,
"engine_dir", model_base_name)
build_cmd = [
"trtllm-build",
f"--checkpoint_dir={ckpt_dir}",
"--max_batch_size=4",
"--max_input_len=8192",
"--max_seq_len=8192",
# Allow an arbitrary number of image tokens by setting:
# max_multimodal_len = max_batch_size * max_input_len
"--max_multimodal_len=32768",
"--use_paged_context_fmha=enable",
f"--output_dir={engine_dir}",
]
# Build Pixtral visual encoder engine
multimodal_engine_dir = os.path.join(tensorrt_llm_multimodal_example_root,
"tmp", "trt_engines", model_base_name,
"multimodal_encoder")
build_visual_engine_cmd = [
"python3",
"build_multimodal_engine.py",
"--model_type=pixtral",
f"--model_path={mistral_small_model_root}",
f"--output_dir={multimodal_engine_dir}",
"--max_batch_size=2",
]
append_timing_cache_args(build_cmd)
convert_cmd = " ".join(convert_cmd)
build_cmd = " ".join(build_cmd)
build_visual_engine_cmd = " ".join(build_visual_engine_cmd)
if not os.path.exists(engine_dir) or not os.path.exists(
multimodal_engine_dir):
check_call(install_requirement_cmd,
shell=True,
cwd=tensorrt_llm_llama_example_root)
check_call(convert_cmd, shell=True)
check_call(build_cmd, shell=True)
check_call(build_visual_engine_cmd,
shell=True,
cwd=tensorrt_llm_multimodal_example_root)
else:
print_info(f"Reusing engine: {engine_dir}")
print_info(f"Skipped: {convert_cmd}")
print_info(f"Skipped: {build_cmd}")
print_info(f"Skipped: {build_visual_engine_cmd}")
assert os.path.exists(engine_dir), f"{engine_dir} does not exist."
assert os.path.exists(
multimodal_engine_dir), f"{multimodal_engine_dir} does not exist."
return engine_dir, multimodal_engine_dir
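As a side note on the build flags above: the multimodal token budget is simply the product of the batch and input-length limits passed to trtllm-build. A minimal sanity-check sketch (the helper name is illustrative, not part of this change):

def _expected_max_multimodal_len(max_batch_size: int, max_input_len: int) -> int:
    # Reserve enough prompt-table slots for every position of every request
    # to potentially be an image token.
    return max_batch_size * max_input_len

assert _expected_max_multimodal_len(4, 8192) == 32768  # matches --max_multimodal_len above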

View File

@ -247,7 +247,8 @@ def modify_ib_config_pbtxt(REPO_PATH,
CROSS_KV_CACHE_FRACTION="",
ENCODER_INPUT_FEATURES_DTYPE="TYPE_FP16",
GUIDED_DECODING_BACKEND="",
XGRAMMAR_TOKENIZER_INFO_PATH=""):
XGRAMMAR_TOKENIZER_INFO_PATH="",
PROMPT_EMBEDDING_TABLE_DTYPE="TYPE_FP16"):
fill_template_py = os.path.join(llm_backend_repo_root, "tools",
"fill_template.py")
tensorrt_llm_config = os.path.join(llm_backend_repo_root, REPO_PATH,
@ -274,6 +275,7 @@ def modify_ib_config_pbtxt(REPO_PATH,
check_call(
f"python3 {fill_template_py} -i {multimodal_enc_config} triton_max_batch_size:{TRITON_MAX_BATCH_SIZE}," \
f"multimodal_model_path:{MULTIMODAL_ENGINE_PATH},encoder_input_features_data_type:{ENCODER_INPUT_FEATURES_DTYPE}," \
f"prompt_embedding_table_data_type:{PROMPT_EMBEDDING_TABLE_DTYPE}," \
f"hf_model_path:{TOKENIZER_PATH}",
shell=True)
check_call(
@ -305,6 +307,7 @@ def modify_ib_config_pbtxt(REPO_PATH,
f"lookahead_ngram_size:{EXECUTOR_LOOKAHEAD_NGRAM}," \
f"lookahead_verification_set_size:{EXECUTOR_LOOKAHEAD_VERIFICATION_SET}," \
f"encoder_input_features_data_type:{ENCODER_INPUT_FEATURES_DTYPE}," \
f"prompt_embedding_table_data_type:{PROMPT_EMBEDDING_TABLE_DTYPE}," \
f"participant_ids:{PARTICIPANT_IDS_DRAFT}," \
f"logits_datatype:TYPE_FP32'",
shell=True)
@ -329,6 +332,7 @@ def modify_ib_config_pbtxt(REPO_PATH,
f"lookahead_ngram_size:{EXECUTOR_LOOKAHEAD_NGRAM}," \
f"lookahead_verification_set_size:{EXECUTOR_LOOKAHEAD_VERIFICATION_SET}," \
f"encoder_input_features_data_type:{ENCODER_INPUT_FEATURES_DTYPE}," \
f"prompt_embedding_table_data_type:{PROMPT_EMBEDDING_TABLE_DTYPE}," \
f"participant_ids:{PARTICIPANT_IDS_TARGET}," \
f"logits_datatype:TYPE_FP32'",
shell=True)
@ -348,7 +352,8 @@ def modify_ib_config_pbtxt(REPO_PATH,
check_call(
f"python3 {fill_template_py} -i {tensorrt_llm_bls_config} triton_max_batch_size:{TRITON_MAX_BATCH_SIZE}," \
f"decoupled_mode:{DECOUPLED_MODE},accumulate_tokens:{ACCUMULATE_TOKEN},bls_instance_count:{BLS_INSTANCE_COUNT}," \
f"tensorrt_llm_model_name:{TENSORRT_LLM_TARGET_MODEL_NAME},tensorrt_llm_draft_model_name:{TENSORRT_LLM_DRAFT_MODEL_NAME},logits_datatype:TYPE_FP32",
f"tensorrt_llm_model_name:{TENSORRT_LLM_TARGET_MODEL_NAME},tensorrt_llm_draft_model_name:{TENSORRT_LLM_DRAFT_MODEL_NAME},logits_datatype:TYPE_FP32," \
f"prompt_embedding_table_data_type:{PROMPT_EMBEDDING_TABLE_DTYPE}",
shell=True)
check_call(
@ -363,6 +368,7 @@ def modify_ib_config_pbtxt(REPO_PATH,
f"gpu_weights_percent:{GPU_WEIGHTS_PERCENT},encoder_engine_dir:{ENCODER_ENGINE_PATH},max_queue_size:{MAX_QUEUE_SIZE}," \
f"enable_context_fmha_fp32_acc:{ENABLE_CONTEXT_FMHA_FP32_ACC}," \
f"encoder_input_features_data_type:{ENCODER_INPUT_FEATURES_DTYPE}," \
f"prompt_embedding_table_data_type:{PROMPT_EMBEDDING_TABLE_DTYPE}," \
f"participant_ids:{PARTICIPANT_IDS}," \
f"logits_datatype:TYPE_FP32,guided_decoding_backend:{GUIDED_DECODING_BACKEND},tokenizer_dir:{TOKENIZER_PATH},xgrammar_tokenizer_info_path:{XGRAMMAR_TOKENIZER_INFO_PATH}'",
shell=True)
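For reference, the new prompt_embedding_table_data_type placeholder is substituted by tools/fill_template.py in the same way as the existing placeholders; a standalone invocation might look like the sketch below (the model-repo path is illustrative):

import subprocess

# Illustrative only: set the prompt table dtype to BF16 in an already-prepared repo.
subprocess.check_call(
    "python3 tools/fill_template.py -i triton_repo/tensorrt_llm/config.pbtxt "
    "prompt_embedding_table_data_type:TYPE_BF16",
    shell=True)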

View File

@ -564,6 +564,19 @@ def tiny_llama_model_root():
return tiny_llama_model_root
@pytest.fixture(scope="session")
def mistral_small_3_1_24b_model_root():
models_root = llm_models_root()
assert models_root, "Did you set LLM_MODELS_ROOT?"
model_root = os.path.join(models_root,
"Mistral-Small-3.1-24B-Instruct-2503")
assert os.path.exists(
model_root
), f"{model_root} does not exist under NFS LLM_MODELS_ROOT dir"
return model_root
# Returns an array of total memory for each available device
@pytest.fixture(scope="session")
def total_gpu_memory_mib():

View File

@ -163,6 +163,7 @@ print_test_params () {
echo "DECODING_MODE: ${DECODING_MODE}"
echo "MAX_QUEUE_SIZE: ${MAX_QUEUE_SIZE}"
echo "ENABLE_CONTEXT_FMHA_FP32_ACC: ${ENABLE_CONTEXT_FMHA_FP32_ACC}"
echo "PROMPT_EMBEDDING_TABLE_DTYPE: ${PROMPT_EMBEDDING_TABLE_DTYPE}"
echo "run_all_tests: ${run_all_tests}"
echo "----------------------------------"
}
@ -180,26 +181,26 @@ fill_triton_repo () {
fi
echo "Filling triton repository at ${TRITON_REPO}/tensorrt_llm with engine ${DECODER_ENGINE_PATH}"
python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm/config.pbtxt triton_backend:${BACKEND},engine_dir:${DECODER_ENGINE_PATH},decoupled_mode:${DECOUPLED_MODE},max_tokens_in_paged_kv_cache:${MAX_TOKENS_IN_KV_CACHE},max_attention_window_size:${MAX_ATTENTION_WINDOW_SIZE},batch_scheduler_policy:${BATCH_SCHEDULER_POLICY},batching_strategy:${BATCHING_STRATEGY},kv_cache_free_gpu_mem_fraction:${KV_CACHE_FREE_GPU_MEM_FRACTION},exclude_input_in_output:${EXCLUDE_INPUT_IN_OUTPUT},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},max_beam_width:${MAX_BEAM_WIDTH},enable_kv_cache_reuse:${ENABLE_KV_CACHE_REUSE},normalize_log_probs:${NORMALIZE_LOG_PROBS},enable_chunked_context:${ENABLE_CHUNKED_CONTEXT},gpu_device_ids:${GPU_DEVICE_IDS},decoding_mode:${DECODING_MODE},max_queue_size:${MAX_QUEUE_SIZE},enable_context_fmha_fp32_acc:${ENABLE_CONTEXT_FMHA_FP32_ACC},encoder_input_features_data_type:${ENCODER_INPUT_FEATURES_DTYPE},logits_datatype:TYPE_FP32,lookahead_window_size:${LOOKAHEAD_WINDOW_SIZE},lookahead_ngram_size:${LOOKAHEAD_NGRAM_SIZE},lookahead_verification_set_size:${LOOKAHEAD_VERIFICATION_SET_SIZE}
python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm/config.pbtxt triton_backend:${BACKEND},engine_dir:${DECODER_ENGINE_PATH},decoupled_mode:${DECOUPLED_MODE},max_tokens_in_paged_kv_cache:${MAX_TOKENS_IN_KV_CACHE},max_attention_window_size:${MAX_ATTENTION_WINDOW_SIZE},batch_scheduler_policy:${BATCH_SCHEDULER_POLICY},batching_strategy:${BATCHING_STRATEGY},kv_cache_free_gpu_mem_fraction:${KV_CACHE_FREE_GPU_MEM_FRACTION},exclude_input_in_output:${EXCLUDE_INPUT_IN_OUTPUT},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},max_beam_width:${MAX_BEAM_WIDTH},enable_kv_cache_reuse:${ENABLE_KV_CACHE_REUSE},normalize_log_probs:${NORMALIZE_LOG_PROBS},enable_chunked_context:${ENABLE_CHUNKED_CONTEXT},gpu_device_ids:${GPU_DEVICE_IDS},decoding_mode:${DECODING_MODE},max_queue_size:${MAX_QUEUE_SIZE},enable_context_fmha_fp32_acc:${ENABLE_CONTEXT_FMHA_FP32_ACC},encoder_input_features_data_type:${ENCODER_INPUT_FEATURES_DTYPE},prompt_embedding_table_data_type:${PROMPT_EMBEDDING_TABLE_DTYPE},logits_datatype:TYPE_FP32,lookahead_window_size:${LOOKAHEAD_WINDOW_SIZE},lookahead_ngram_size:${LOOKAHEAD_NGRAM_SIZE},lookahead_verification_set_size:${LOOKAHEAD_VERIFICATION_SET_SIZE}
python3 tools/fill_template.py -i ${TRITON_REPO}/preprocessing/config.pbtxt tokenizer_dir:${TOKENIZER_PATH},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},preprocessing_instance_count:${PREPROCESSING_INSTANCE_COUNT}
python3 tools/fill_template.py -i ${TRITON_REPO}/postprocessing/config.pbtxt tokenizer_dir:${TOKENIZER_PATH},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},postprocessing_instance_count:${POSTPROCESSING_INSTANCE_COUNT}
python3 tools/fill_template.py -i ${TRITON_REPO}/ensemble/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},logits_datatype:TYPE_FP32
if [ "${DRAFT_ENGINE_PATH}" != "" ] && [ "${DRAFT_ENGINE_PATH}" != "skip" ] && [ "${TARGET_ENGINE_PATH}" != "" ] && [ "${TARGET_ENGINE_PATH}" != "skip" ]; then
python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},accumulate_tokens:${ACCUMULATE_TOKEN},bls_instance_count:${BLS_INSTANCE_COUNT},tensorrt_llm_model_name:${TENSORRT_LLM_TARGET_MODEL_NAME},logits_datatype:TYPE_FP32,tensorrt_llm_draft_model_name:${TENSORRT_LLM_DRAFT_MODEL_NAME}
python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},accumulate_tokens:${ACCUMULATE_TOKEN},bls_instance_count:${BLS_INSTANCE_COUNT},tensorrt_llm_model_name:${TENSORRT_LLM_TARGET_MODEL_NAME},logits_datatype:TYPE_FP32,tensorrt_llm_draft_model_name:${TENSORRT_LLM_DRAFT_MODEL_NAME},prompt_embedding_table_data_type:${PROMPT_EMBEDDING_TABLE_DTYPE}
else
python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},accumulate_tokens:${ACCUMULATE_TOKEN},bls_instance_count:${BLS_INSTANCE_COUNT},tensorrt_llm_model_name:${TENSORRT_LLM_MODEL_NAME},logits_datatype:TYPE_FP32,tensorrt_llm_draft_model_name:""
python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},accumulate_tokens:${ACCUMULATE_TOKEN},bls_instance_count:${BLS_INSTANCE_COUNT},tensorrt_llm_model_name:${TENSORRT_LLM_MODEL_NAME},logits_datatype:TYPE_FP32,tensorrt_llm_draft_model_name:"",prompt_embedding_table_data_type:${PROMPT_EMBEDDING_TABLE_DTYPE}
fi
if [ "${DRAFT_ENGINE_PATH}" != "" ] && [ "${DRAFT_ENGINE_PATH}" != "skip" ]; then
echo "Filling triton repository at ${TRITON_REPO}/tensorrt_llm_draft with engine ${DRAFT_ENGINE_PATH}"
python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm_draft/config.pbtxt triton_backend:${BACKEND},engine_dir:${DRAFT_ENGINE_PATH},decoupled_mode:${DECOUPLED_MODE},max_tokens_in_paged_kv_cache:${MAX_TOKENS_IN_KV_CACHE},max_attention_window_size:${MAX_ATTENTION_WINDOW_SIZE},batch_scheduler_policy:${BATCH_SCHEDULER_POLICY},batching_strategy:${BATCHING_STRATEGY},kv_cache_free_gpu_mem_fraction:${KV_CACHE_FREE_GPU_MEM_FRACTION},exclude_input_in_output:${EXCLUDE_INPUT_IN_OUTPUT},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},max_beam_width:${MAX_BEAM_WIDTH},enable_kv_cache_reuse:${ENABLE_KV_CACHE_REUSE},normalize_log_probs:${NORMALIZE_LOG_PROBS},enable_chunked_context:${ENABLE_CHUNKED_CONTEXT},gpu_device_ids:${GPU_DEVICE_IDS},decoding_mode:${DECODING_MODE},max_queue_size:${MAX_QUEUE_SIZE},encoder_input_features_data_type:${ENCODER_INPUT_FEATURES_DTYPE},logits_datatype:TYPE_FP32
python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm_draft/config.pbtxt triton_backend:${BACKEND},engine_dir:${DRAFT_ENGINE_PATH},decoupled_mode:${DECOUPLED_MODE},max_tokens_in_paged_kv_cache:${MAX_TOKENS_IN_KV_CACHE},max_attention_window_size:${MAX_ATTENTION_WINDOW_SIZE},batch_scheduler_policy:${BATCH_SCHEDULER_POLICY},batching_strategy:${BATCHING_STRATEGY},kv_cache_free_gpu_mem_fraction:${KV_CACHE_FREE_GPU_MEM_FRACTION},exclude_input_in_output:${EXCLUDE_INPUT_IN_OUTPUT},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},max_beam_width:${MAX_BEAM_WIDTH},enable_kv_cache_reuse:${ENABLE_KV_CACHE_REUSE},normalize_log_probs:${NORMALIZE_LOG_PROBS},enable_chunked_context:${ENABLE_CHUNKED_CONTEXT},gpu_device_ids:${GPU_DEVICE_IDS},decoding_mode:${DECODING_MODE},max_queue_size:${MAX_QUEUE_SIZE},encoder_input_features_data_type:${ENCODER_INPUT_FEATURES_DTYPE},prompt_embedding_table_data_type:${PROMPT_EMBEDDING_TABLE_DTYPE},logits_datatype:TYPE_FP32
fi
if [ "${TARGET_ENGINE_PATH}" != "" ] && [ "${TARGET_ENGINE_PATH}" != "skip" ]; then
echo "Filling triton repository at ${TRITON_REPO}/tensorrt_llm_target with engine ${TARGET_ENGINE_PATH}"
python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm_target/config.pbtxt triton_backend:${BACKEND},engine_dir:${TARGET_ENGINE_PATH},decoupled_mode:${DECOUPLED_MODE},max_tokens_in_paged_kv_cache:${MAX_TOKENS_IN_KV_CACHE},max_attention_window_size:${MAX_ATTENTION_WINDOW_SIZE},batch_scheduler_policy:${BATCH_SCHEDULER_POLICY},batching_strategy:${BATCHING_STRATEGY},kv_cache_free_gpu_mem_fraction:${KV_CACHE_FREE_GPU_MEM_FRACTION},exclude_input_in_output:${EXCLUDE_INPUT_IN_OUTPUT},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},max_beam_width:${MAX_BEAM_WIDTH},enable_kv_cache_reuse:true,normalize_log_probs:${NORMALIZE_LOG_PROBS},enable_chunked_context:${ENABLE_CHUNKED_CONTEXT},gpu_device_ids:${GPU_DEVICE_IDS},decoding_mode:${DECODING_MODE},max_queue_size:${MAX_QUEUE_SIZE},encoder_input_features_data_type:${ENCODER_INPUT_FEATURES_DTYPE},logits_datatype:TYPE_FP32
python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm_target/config.pbtxt triton_backend:${BACKEND},engine_dir:${TARGET_ENGINE_PATH},decoupled_mode:${DECOUPLED_MODE},max_tokens_in_paged_kv_cache:${MAX_TOKENS_IN_KV_CACHE},max_attention_window_size:${MAX_ATTENTION_WINDOW_SIZE},batch_scheduler_policy:${BATCH_SCHEDULER_POLICY},batching_strategy:${BATCHING_STRATEGY},kv_cache_free_gpu_mem_fraction:${KV_CACHE_FREE_GPU_MEM_FRACTION},exclude_input_in_output:${EXCLUDE_INPUT_IN_OUTPUT},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},max_beam_width:${MAX_BEAM_WIDTH},enable_kv_cache_reuse:true,normalize_log_probs:${NORMALIZE_LOG_PROBS},enable_chunked_context:${ENABLE_CHUNKED_CONTEXT},gpu_device_ids:${GPU_DEVICE_IDS},decoding_mode:${DECODING_MODE},max_queue_size:${MAX_QUEUE_SIZE},encoder_input_features_data_type:${ENCODER_INPUT_FEATURES_DTYPE},prompt_embedding_table_data_type:${PROMPT_EMBEDDING_TABLE_DTYPE},logits_datatype:TYPE_FP32
fi
@ -217,7 +218,7 @@ fill_triton_repo () {
cp all_models/multimodal/multimodal_encoders ${TRITON_REPO} -r
python3 tools/fill_template.py -i ${TRITON_REPO}/ensemble/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},logits_datatype:TYPE_FP32
python3 tools/fill_template.py -i ${TRITON_REPO}/preprocessing/config.pbtxt multimodal_model_path:${MULTIMODAL_ENGINE_PATH},engine_dir:${DECODER_ENGINE_PATH}
python3 tools/fill_template.py -i ${TRITON_REPO}/multimodal_encoders/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},multimodal_model_path:${MULTIMODAL_ENGINE_PATH},encoder_input_features_data_type:${ENCODER_INPUT_FEATURES_DTYPE},hf_model_path:${TOKENIZER_PATH}
python3 tools/fill_template.py -i ${TRITON_REPO}/multimodal_encoders/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},multimodal_model_path:${MULTIMODAL_ENGINE_PATH},encoder_input_features_data_type:${ENCODER_INPUT_FEATURES_DTYPE},prompt_embedding_table_data_type:${PROMPT_EMBEDDING_TABLE_DTYPE},hf_model_path:${TOKENIZER_PATH}
python3 tools/fill_template.py -i ${TRITON_REPO}/tensorrt_llm_bls/config.pbtxt multimodal_encoders_name:multimodal_encoders
fi
@ -649,6 +650,7 @@ TRITON_METRICS_PORT="8002"
GPU_DEVICE_IDS=""
DECODING_MODE="top_k_top_p"
MAX_QUEUE_SIZE="0"
PROMPT_EMBEDDING_TABLE_DTYPE="TYPE_FP16"
if [ "$MODEL" = "gpt-ib" ] || [ "$MODEL" = "mistral-ib" ] || [ "$MODEL" = "mistral-ib-mm" ]; then

View File

@ -1,4 +1,5 @@
import os
import re
import sys
import pytest
@ -3893,3 +3894,198 @@ def test_tiny_llama_ifb_token_counts(
print_info(
f"Successfully tested token count functionality for {TOKEN_COUNT_TEST} mode"
)
@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.parametrize("E2E_MODEL_NAME", ["ensemble", "tensorrt_llm_bls"])
@pytest.mark.parametrize("ACCUMULATE_TOKEN", ["False"])
@pytest.mark.parametrize("BLS_INSTANCE_COUNT", ["1"])
@pytest.mark.parametrize("PREPROCESSING_INSTANCE_COUNT", ["1"])
@pytest.mark.parametrize("POSTPROCESSING_INSTANCE_COUNT", ["1"])
@pytest.mark.parametrize("MAX_TOKENS_IN_KV_CACHE", [""])
@pytest.mark.parametrize("MAX_ATTENTION_WINDOW_SIZE", [""])
@pytest.mark.parametrize("BATCH_SCHEDULER_POLICY",
["max_utilization", "guaranteed_no_evict"])
@pytest.mark.parametrize("KV_CACHE_FREE_GPU_MEM_FRACTION", ["0.7"])
@pytest.mark.parametrize("CROSS_KV_CACHE_FRACTION", [""])
@pytest.mark.parametrize("ENABLE_TRT_OVERLAP", ["False"],
ids=["disableTrtOverlap"])
@pytest.mark.parametrize("BATCHING_STRATEGY", ["inflight_fused_batching"])
@pytest.mark.parametrize("DECOUPLED_MODE", ["True", "False"],
ids=["enableDecoupleMode", "disableDecoupleMode"])
@pytest.mark.parametrize("TRITON_MAX_BATCH_SIZE", ["1"])
@pytest.mark.parametrize("MAX_QUEUE_DELAY_MICROSECONDS", ["0"])
@pytest.mark.parametrize("ENABLE_KV_CACHE_REUSE", ["False"])
@pytest.mark.parametrize("NORMALIZE_LOG_PROBS", ["True"])
@pytest.mark.parametrize("ENABLE_CHUNKED_CONTEXT", ["False"])
@pytest.mark.parametrize("GPU_DEVICE_IDS", [""])
@pytest.mark.parametrize("DECODING_MODE", [""])
@pytest.mark.parametrize("MAX_BEAM_WIDTH", ["1"])
@pytest.mark.parametrize("EXCLUDE_INPUT_IN_OUTPUT", ["False"])
@pytest.mark.parametrize("PROMPT_EMBEDDING_TABLE_DTYPE",
["TYPE_BF16"]) # allow override later
@pytest.mark.parametrize("ENCODER_INPUT_FEATURES_DTYPE",
["TYPE_FP16"]) # pixtral uses fp16 vision by default
def test_mistral_small_3_1_24b_pixtral(
E2E_MODEL_NAME,
MAX_TOKENS_IN_KV_CACHE,
MAX_ATTENTION_WINDOW_SIZE,
BATCH_SCHEDULER_POLICY,
KV_CACHE_FREE_GPU_MEM_FRACTION,
CROSS_KV_CACHE_FRACTION,
ENABLE_TRT_OVERLAP,
BATCHING_STRATEGY,
DECOUPLED_MODE,
TRITON_MAX_BATCH_SIZE,
MAX_QUEUE_DELAY_MICROSECONDS,
MAX_BEAM_WIDTH,
ENABLE_KV_CACHE_REUSE,
NORMALIZE_LOG_PROBS,
ENABLE_CHUNKED_CONTEXT,
GPU_DEVICE_IDS,
DECODING_MODE,
PREPROCESSING_INSTANCE_COUNT,
POSTPROCESSING_INSTANCE_COUNT,
ACCUMULATE_TOKEN,
BLS_INSTANCE_COUNT,
EXCLUDE_INPUT_IN_OUTPUT,
PROMPT_EMBEDDING_TABLE_DTYPE,
ENCODER_INPUT_FEATURES_DTYPE,
tensorrt_llm_multimodal_example_root,
tensorrt_llm_llama_example_root,
mistral_small_3_1_24b_model_root,
llm_backend_multimodal_example_root,
llm_backend_venv,
llm_root,
):
if BATCHING_STRATEGY == "V1" and BATCH_SCHEDULER_POLICY == "max_utilization":
pytest.skip("Skipping. V1 doesn't support max_utilization.")
llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"]
# Build Engines (LLM + vision)
ENGINE_PATH, MULTIMODAL_ENGINE_DIR = prepare_mistral3_pixtral_engine(
tensorrt_llm_multimodal_example_root, tensorrt_llm_llama_example_root,
mistral_small_3_1_24b_model_root)
# Prepare model repo
new_model_repo = os.path.join(llm_backend_repo_root, "triton_repo")
prepare_ib_model_repo(llm_backend_repo_root, new_model_repo)
# Prepare multimodal specific repo
prepare_multimodal_model_repo(llm_backend_repo_root, new_model_repo,
"ensemble")
prepare_multimodal_model_repo(llm_backend_repo_root, new_model_repo,
"multimodal_encoders")
# Modify config.pbtxt
TOKENIZER_PATH = mistral_small_3_1_24b_model_root
modify_ib_config_pbtxt(
new_model_repo,
ENGINE_PATH,
TOKENIZER_PATH,
llm_backend_repo_root,
DECOUPLED_MODE,
MAX_TOKENS_IN_KV_CACHE,
MAX_ATTENTION_WINDOW_SIZE,
BATCH_SCHEDULER_POLICY,
BATCHING_STRATEGY,
KV_CACHE_FREE_GPU_MEM_FRACTION,
EXCLUDE_INPUT_IN_OUTPUT,
ENABLE_TRT_OVERLAP,
TRITON_MAX_BATCH_SIZE,
MAX_QUEUE_DELAY_MICROSECONDS,
MAX_BEAM_WIDTH,
ENABLE_KV_CACHE_REUSE,
NORMALIZE_LOG_PROBS,
ENABLE_CHUNKED_CONTEXT,
GPU_DEVICE_IDS,
DECODING_MODE,
PREPROCESSING_INSTANCE_COUNT,
POSTPROCESSING_INSTANCE_COUNT,
ACCUMULATE_TOKEN,
BLS_INSTANCE_COUNT,
MULTIMODAL_ENGINE_PATH=MULTIMODAL_ENGINE_DIR,
ENCODER_INPUT_FEATURES_DTYPE=ENCODER_INPUT_FEATURES_DTYPE,
PROMPT_EMBEDDING_TABLE_DTYPE=PROMPT_EMBEDDING_TABLE_DTYPE,
)
# Launch Triton Server
launch_server_py = os.path.join(llm_backend_repo_root, "scripts",
"launch_triton_server.py")
check_call(
f"PMIX_MCA_gds=hash python3 {launch_server_py} --world_size=1 --model_repo={new_model_repo}",
shell=True)
check_server_ready()
image_merlion = os.path.join(
llm_root,
"tests/integration/test_input_files/merlion.png",
)
image_football = os.path.join(
llm_root,
"tests/integration/test_input_files/pexels-franco-monsalvo-252430633-32285228.jpg",
)
image_hockey = os.path.join(
llm_root,
"tests/integration/test_input_files/pexels-ron-lach-8975010.jpg",
)
image_basketball = os.path.join(
llm_root,
"tests/integration/test_input_files/pexels-maxim-shklyaev-1511525-2914194.jpg",
)
test_cases = [
{
"text": "What is the capital of England?",
"image": "",
"match": re.compile("london", re.IGNORECASE)
},
{
"text": "In as few words as possible, what city is this?",
"image": image_merlion,
"match": re.compile("singapore", re.IGNORECASE)
},
{
"text":
"In as few words as possible, what sports are depicted in the images?",
"image":
",".join([image_football, image_hockey]),
"match":
re.compile("(football|soccer).*hockey", re.IGNORECASE | re.DOTALL)
},
{
"text":
"In as few words as possible, what sports are depicted in the images?",
"image":
",".join([image_football, image_hockey, image_basketball]),
"match":
re.compile("(football|soccer).*hockey.*basket",
re.IGNORECASE | re.DOTALL)
},
]
for test_case in test_cases:
TEXT = test_case["text"]
IMAGE = test_case["image"]
MATCH = test_case["match"]
# Run Test: use multimodal client; set model_type to pixtral
run_cmd = [
f"{llm_backend_multimodal_example_root}/client.py",
"--model_type=pixtral",
f"--text={TEXT}",
f"--image={IMAGE}",
"--request-output-len=128",
"--end-id=2",
]
if DECOUPLED_MODE == "True":
run_cmd += ["--streaming"]
if E2E_MODEL_NAME == "tensorrt_llm_bls":
run_cmd += ["--use_bls"]
output = venv_check_output(llm_backend_venv, run_cmd)
assert MATCH.search(
output), f"Test failed for input: {TEXT=}, {IMAGE=}, {output=}"
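A single parametrization of this test can be selected locally with pytest's -k filter, for example (assumes LLM_BACKEND_ROOT and LLM_MODELS_ROOT are exported; the filter expression is illustrative):

import pytest

pytest.main([
    "triton_server/test_triton_llm.py::test_mistral_small_3_1_24b_pixtral",
    "-k", "ensemble and enableDecoupleMode and max_utilization",
])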

Binary file not shown (image: 20 KiB before, 130 B after).

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f1f3b6a507ec92e8f47ac6d7c64e11b03fcba8c550bcb6851f80e261e8951431
size 1604159

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4bd1efd0c8fe48b421210cd132dc3b3b2902ccf1523bb9bec3a3883bb5c7a650
size 116299

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dd922b837bc92353d49a60df1dd933eddfe7546e2b16b365acaadb9b2a0a683b
size 72231

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:31c6fedadcb79990687d00d24350f774f4ad319439c89ed67d47c1df35a556fb
size 83652

View File

@ -99,3 +99,11 @@ l0_a100:
- triton_server/test_triton.py::test_eagle[eagle]
- triton_server/test_triton.py::test_llava_onevision[llava_onevision]
- triton_server/test_triton.py::test_qwen2_vl[qwen2_vl]
- triton_server/test_triton_llm.py::test_mistral_small_3_1_24b_pixtral[TYPE_FP16-TYPE_BF16-False-1---False-True-False-0-1-enableDecoupleMode-inflight_fused_batching-disableTrtOverlap--0.7-max_utilization---1-1-1-False-ensemble]
- triton_server/test_triton_llm.py::test_mistral_small_3_1_24b_pixtral[TYPE_FP16-TYPE_BF16-False-1---False-True-False-0-1-enableDecoupleMode-inflight_fused_batching-disableTrtOverlap--0.7-max_utilization---1-1-1-False-tensorrt_llm_bls]
- triton_server/test_triton_llm.py::test_mistral_small_3_1_24b_pixtral[TYPE_FP16-TYPE_BF16-False-1---False-True-False-0-1-enableDecoupleMode-inflight_fused_batching-disableTrtOverlap--0.7-guaranteed_no_evict---1-1-1-False-ensemble]
- triton_server/test_triton_llm.py::test_mistral_small_3_1_24b_pixtral[TYPE_FP16-TYPE_BF16-False-1---False-True-False-0-1-enableDecoupleMode-inflight_fused_batching-disableTrtOverlap--0.7-guaranteed_no_evict---1-1-1-False-tensorrt_llm_bls]
- triton_server/test_triton_llm.py::test_mistral_small_3_1_24b_pixtral[TYPE_FP16-TYPE_BF16-False-1---False-True-False-0-1-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap--0.7-max_utilization---1-1-1-False-ensemble]
- triton_server/test_triton_llm.py::test_mistral_small_3_1_24b_pixtral[TYPE_FP16-TYPE_BF16-False-1---False-True-False-0-1-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap--0.7-max_utilization---1-1-1-False-tensorrt_llm_bls]
- triton_server/test_triton_llm.py::test_mistral_small_3_1_24b_pixtral[TYPE_FP16-TYPE_BF16-False-1---False-True-False-0-1-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap--0.7-guaranteed_no_evict---1-1-1-False-ensemble]
- triton_server/test_triton_llm.py::test_mistral_small_3_1_24b_pixtral[TYPE_FP16-TYPE_BF16-False-1---False-True-False-0-1-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap--0.7-guaranteed_no_evict---1-1-1-False-tensorrt_llm_bls]

View File

@ -29,13 +29,13 @@ import io
import json
import os
from collections import defaultdict
from typing import List
from typing import Dict, List, Tuple
import numpy as np
import requests
import triton_python_backend_utils as pb_utils
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer, T5Tokenizer
from transformers import AutoConfig, AutoProcessor, AutoTokenizer, T5Tokenizer
class TritonPythonModel:
@ -136,9 +136,9 @@ class TritonPythonModel:
'model_type']
assert self.model_type in [
'llava', 'blip2-opt', 'vila', 'mllama', 'llava_onevision',
'qwen2_vl'
], f"[TensorRT-LLM][ERROR] Currently supported multi-modal models are llava, blip2-opt, vila, mllama, llava_onevision and qwen2_vl. Got {self.model_type}."
'llava', 'blip2-opt', 'pixtral', 'vila', 'mllama',
'llava_onevision', 'qwen2_vl'
], f"[TensorRT-LLM][ERROR] Currently supported multi-modal models are llava, blip2-opt, pixtral, vila, mllama, llava_onevision and qwen2_vl. Got {self.model_type}."
assert self.model_type != 'llava_onevision' or self.max_num_images is None or self.max_num_images <= 1, "LLaVA-OneVision does not currently support multi-image inference."
@ -151,10 +151,18 @@ class TritonPythonModel:
llm_model_config["pretrained_config"]["vocab_size"])
self._setup_ptable_shape(llm_model_config)
if self.model_type in ['mllama', 'llava_onevision', 'qwen2_vl']:
if self.model_type in [
'mllama', 'llava_onevision', 'qwen2_vl', 'pixtral'
]:
full_processor = AutoProcessor.from_pretrained(
tokenizer_dir, trust_remote_code=True)
self.hf_config = AutoConfig.from_pretrained(tokenizer_dir)
self.vision_preprocessor = VisionPreProcessor(
self.model_type,
AutoProcessor.from_pretrained(tokenizer_dir), model_config)
full_processor,
model_config,
self.hf_config,
)
# Parse model output configs and convert Triton types to numpy types
output_names = [
@ -285,7 +293,9 @@ class TritonPythonModel:
request, 'VIDEO_BYTES')
vision_processed_tensors = []
visual_tokens = []
if self.is_multimodal and (img_urls or image_bytes or video_bytes):
# Pixtral supports text-only input
if self.is_multimodal and (img_urls or image_bytes or video_bytes
or self.model_type == 'pixtral'):
assert self.vision_preprocessor != None, "Vision preprocessor for preparing images before encoding is None"
processed_tensors = {}
if self.model_type == 'mllama':
@ -317,6 +327,19 @@ class TritonPythonModel:
qwen2vl_input_length_tensor = processed_tensors.get(
"REQUEST_INPUT_LEN")
processed_tensors.pop("REQUEST_INPUT_LEN")
elif self.model_type == 'pixtral':
image_sizes = pb_utils.get_input_tensor_by_name(
request, 'IMAGE_SIZES')
processed_tensors, visual_tokens = self.vision_preprocessor.pixtral_process(
queries=query.astype(str).tolist(),
img_urls=img_urls,
image_bytes=image_bytes,
image_sizes=image_sizes,
)
pixtral_input_id_tensor = processed_tensors.pop("INPUT_IDS")
request_input_len = np.array(
[[len(input_ids_for_batch)]
for input_ids_for_batch in pixtral_input_id_tensor])
else:
raise ValueError(
"Unsupported model type for IMAGE_BYTES or IMAGE_URL inputs"
@ -330,8 +353,9 @@ class TritonPythonModel:
# Preprocessing input data.
# For the LLaVA_OneVision model, num_multimodal_features is not a fixed value
input_id, request_input_len = self._create_request(
query, visual_tokens)
if self.model_type != 'pixtral':
input_id, request_input_len = self._create_request(
query, visual_tokens)
if decoder_query is not None:
decoder_input_id, request_decoder_input_len = self._create_request(
decoder_query)
@ -362,6 +386,13 @@ class TritonPythonModel:
'INPUT_ID', qwen2vl_input_id_tensor)
request_input_len_tensor = pb_utils.Tensor.from_dlpack(
'REQUEST_INPUT_LEN', qwen2vl_input_length_tensor)
elif self.model_type == 'pixtral':
input_id_tensor = pb_utils.Tensor(
'INPUT_ID',
pixtral_input_id_tensor.numpy().astype(self.input_id_dtype))
request_input_len_tensor = pb_utils.Tensor(
'REQUEST_INPUT_LEN',
request_input_len.astype(self.request_input_len_dtype))
else:
input_id_tensor = pb_utils.Tensor(
'INPUT_ID', input_id.astype(self.input_id_dtype))
@ -719,7 +750,10 @@ class VisionPreProcessor:
def __init__(self,
vision_model_type,
vision_model_processor,
preprocessor_model_config={}):
preprocessor_model_config=None,
hf_config=None):
preprocessor_model_config = preprocessor_model_config or {}
# import libraries that are only relevant for multimodal models
import torch
from torch.utils.dlpack import from_dlpack
@ -767,6 +801,12 @@ class VisionPreProcessor:
self.vision_model_processor = vision_model_processor
self.vision_model_type = vision_model_type
if vision_model_type == 'pixtral':
assert hf_config is not None, "Pixtral model requires hf_config to be set"
self.vocab_size = hf_config.text_config.vocab_size
self.image_size = hf_config.vision_config.image_size
self.image_token_index = hf_config.image_token_index
def load_images_from_urls(self, img_urls):
images = []
for img_url in img_urls:
@ -777,10 +817,11 @@ class VisionPreProcessor:
image_data = base64.b64decode(image_base64)
# Create a BytesIO object from the decoded data
image_buffer = io.BytesIO(image_data)
images.append(Image.open(image_buffer))
images.append(Image.open(image_buffer).convert("RGB"))
else:
images.append(Image.open(
requests.get(img_url, stream=True).raw))
images.append(
Image.open(requests.get(img_url,
stream=True).raw).convert("RGB"))
return images
def mllama_process(self, queries, img_urls=None, image_bytes=None):
@ -879,6 +920,9 @@ class VisionPreProcessor:
mode='constant')
for image in preprocessor_outputs['PIXEL_VALUES']
]
# Add a dimension to image_sizes to match the dimensions defined in config.pbtxt
for elem in preprocessor_outputs['IMAGE_SIZES']:
elem.unsqueeze_(1)
for key, tensor_list in preprocessor_outputs.items():
val = self.convert_tensor_list_to_tensor(tensor_list)
if key in self.output_str_dtypes:
@ -1001,3 +1045,130 @@ class VisionPreProcessor:
val, self.output_str_dtypes[key])
vision_processed_tensors[key] = val
return vision_processed_tensors
def pixtral_process(self,
queries,
img_urls=None,
image_bytes=None,
image_sizes=None
) -> Tuple[Dict[str, "torch.Tensor"], List[int]]:
import torch
vision_processed_tensors = {}
if img_urls is not None:
assert image_sizes is None, "IMAGE_SIZES should not be supplied together with IMAGE_URL"
# download and read images
images = [
self.load_images_from_urls(urls)
for urls in img_urls.as_numpy()
]
images = [[np.array(img) for img in batch] for batch in images]
# pad to the max_h, max_w dimensions to create one tensor for all images
shapes = [img.shape for batch in images for img in batch]
assert all(
len(s) == 3
for s in shapes), "All input images must have three dimensions"
assert all(
s[-1] == shapes[0][-1] for s in shapes
), "All input images must have the same number of channels"
max_h, max_w = max(s[0] for s in shapes), max(s[1] for s in shapes)
for batch_idx in range(len(images)):
for image_idx in range(len(images[batch_idx])):
images[batch_idx][image_idx] = np.pad(
images[batch_idx][image_idx],
((0, max_h - images[batch_idx][image_idx].shape[0]),
(0, max_w - images[batch_idx][image_idx].shape[1]),
(0, 0)),
mode='constant',
)
images = np.array(images)
elif image_bytes is not None:
images = self.load_images_tensor(image_bytes)
else:
images = np.empty((len(queries), 0, 0, 0, 0), dtype=np.uint8)
batch_size = len(images)
assert len(
queries
) == batch_size, "Images must have the same batch size as queries."
if image_sizes is not None:
image_sizes = self.load_images_tensor(image_sizes)
else:
s = images.shape
image_sizes = np.array([[[s[2], s[3]]] * s[1]] * s[0])
preprocessor_outputs = {}
possible_output_names = ['PIXEL_VALUES', 'IMAGE_SIZES', 'INPUT_IDS']
visual_tokens = []
for batch_id in range(batch_size):
# Preprocess images and query
query = queries[batch_id]
if not isinstance(query, (str, bytes)):
query = query[0]
if isinstance(query, bytes):
query = query.decode("utf-8")
if "[IMG]" not in query:
query = "[IMG]" * len(images[batch_id]) + query
assert query.count("[IMG]") == len(
images[batch_id]
), "Number of [IMG] tags must match number of images"
if not query.startswith("[INST]"):
query = "[INST]" + query
if not query.endswith("[/INST]"):
query = query + "[/INST]"
sizes = image_sizes[batch_id]
curr_images = [
img[:sizes[idx][0], :sizes[idx][1], :]
for idx, img in enumerate(images[batch_id])
]
if not curr_images:
curr_images = None
processed_vision_data = self.vision_model_processor(
images=curr_images, text=query, return_tensors="pt")
visual_tokens.append(processed_vision_data['input_ids'].shape[1])
if "pixel_values" in processed_vision_data:
# Pad to self.image_size x self.image_size
processed_vision_data['pixel_values'] = torch.nn.functional.pad(
processed_vision_data['pixel_values'], (
0,
self.image_size -
processed_vision_data['pixel_values'].shape[-1],
0,
self.image_size -
processed_vision_data['pixel_values'].shape[-2],
),
mode='constant')
# Create vision output tensors
for key in possible_output_names:
val = processed_vision_data.get(key.lower())
if val is not None:
if key not in preprocessor_outputs:
preprocessor_outputs[key] = []
if key != 'INPUT_IDS':
val.unsqueeze_(0) # unsqueeze to add batch dimension
preprocessor_outputs[key].append(val)
for key, tensor_list in preprocessor_outputs.items():
val = self.convert_tensor_list_to_tensor(tensor_list)
if key in self.output_str_dtypes:
val = self.convert_tensor_to_str_dtype(
val, self.output_str_dtypes[key])
vision_processed_tensors[key] = val
# Replace all image tokens with a unique token_id > vocab_size.
# This shall be used to lookup the prompt table.
for batch_id in range(batch_size):
# Note: We reset replacer to vocab_size for each sample. This is as opposed to doing `replacer = vocab_size + img_idx * tokens_per_task`.
# That part of the look-up manipulation is done by the `task_ids` input to PromptEmbedding forward.
replacer = self.vocab_size
input_ids = vision_processed_tensors['INPUT_IDS'][batch_id]
for token_idx in range(len(input_ids)):
if input_ids[token_idx] == self.image_token_index:
input_ids[token_idx] = replacer
replacer += 1
return vision_processed_tensors, visual_tokens
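A toy illustration of the image-token replacement performed at the end of pixtral_process (the numeric values below are made up; the real vocab_size and image_token_index come from the HF config):

import numpy as np

vocab_size, image_token_index = 1000, 10          # assumed values for illustration
input_ids = np.array([1, 10, 10, 10, 42, 7])      # three image placeholder tokens
replacer = vocab_size
for token_idx in range(len(input_ids)):
    if input_ids[token_idx] == image_token_index:
        input_ids[token_idx] = replacer
        replacer += 1
# input_ids is now [1, 1000, 1001, 1002, 42, 7]; ids >= vocab_size select rows
# of the prompt embedding table produced by the multimodal encoder.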

View File

@ -55,7 +55,14 @@ input [
{
name: "IMAGE_URL"
data_type: TYPE_STRING
dims: [ 1 ]
dims: [ -1 ]
optional: true
},
# Required for pixtral
{
name: "IMAGE_SIZES"
data_type: TYPE_INT64
dims: [ -1, 2 ]
optional: true
},
{
@ -188,11 +195,11 @@ output [
data_type: TYPE_INT64
dims: [ -1, -1, -1 ]
},
# Required for image postprocessing in the llava_onevision model
# Required for image postprocessing in the llava_onevision and pixtral models
{
name: "IMAGE_SIZES"
data_type: TYPE_INT64
dims: [ 2 ]
dims: [ -1, 2 ]
},
# Indicates if the input is video in the llava_onevision model
{

View File

@ -280,7 +280,9 @@ def get_prompt_tuning_config_from_request(request,
kwargs = {}
prompt_embedding_table = get_input_tensor_by_name(request,
'prompt_embedding_table',
batch_size, batch_index)
batch_size,
batch_index,
force_on_torch=True)
prompt_table_extra_ids = get_input_tensor_by_name(request,
'prompt_table_extra_ids',
batch_size, batch_index)

View File

@ -319,7 +319,7 @@ input [
},
{
name: "prompt_embedding_table"
data_type: TYPE_FP16
data_type: ${prompt_embedding_table_data_type}
dims: [ -1, -1 ]
optional: true
allow_ragged_batch: true

View File

@ -103,6 +103,7 @@ class Request:
request_id: Optional[str] = None
mrope_rotary_cos_sin: Optional[np.ndarray] = None
mrope_position_deltas: Optional[np.ndarray] = None
image_sizes_input: Optional[np.ndarray] = None
def validate(self):
_validate_non_empty(self.text_input, "text_input is required")

View File

@ -165,7 +165,12 @@ class TritonDecoder(Decoder):
continue
triton_name = tensor.name()
if tensor.is_cpu():
value = tensor.as_numpy()
try:
value = tensor.as_numpy()
except pb_utils.TritonModelException as e:
# Use to_dlpack()/from_dlpack() if as_numpy() fails,
# e.g. in case of BF16 tensors
value = from_dlpack(tensor.to_dlpack())
else:
# If the tensor is in GPU memory make it torch.Tensor type
value = from_dlpack(tensor.to_dlpack())
@ -247,6 +252,7 @@ class TritonDecoder(Decoder):
"text_input": "QUERY",
"image_bytes_input": "IMAGE_BYTES",
"image_url_input": "IMAGE_URL",
"image_sizes_input": "IMAGE_SIZES",
"video_bytes_input": "VIDEO_BYTES",
"decoder_text_input": "DECODER_QUERY",
"max_tokens": "REQUEST_OUTPUT_LEN",

View File

@ -62,6 +62,13 @@ input [
dims: [ 1 ]
optional: true
},
# An arbitrary number of images for pixtral
{
name: "image_sizes_input"
data_type: TYPE_INT64
dims: [ -1, 2 ]
optional: true
},
{
name: "video_bytes_input"
data_type: TYPE_UINT8
@ -199,7 +206,7 @@ input [
},
{
name: "prompt_embedding_table"
data_type: TYPE_FP16
data_type: ${prompt_embedding_table_data_type}
dims: [ -1, -1 ]
optional: true
},

View File

@ -54,9 +54,16 @@ input [
{
name: "image_url_input"
data_type: TYPE_STRING
dims: [ 1 ]
dims: [ -1 ]
optional: true
},
# An arbitrary number of images for pixtral
{
name: "image_sizes_input"
data_type: TYPE_INT64
dims: [ -1, 2 ]
optional: true
},
{
name: "video_bytes_input"
data_type: TYPE_UINT8
@ -253,6 +260,10 @@ ensemble_scheduling {
key: "IMAGE_URL"
value: "image_url_input"
}
input_map {
key: "IMAGE_SIZES"
value: "image_sizes_input"
}
input_map {
key: "VIDEO_BYTES"
value: "video_bytes_input"

View File

@ -112,6 +112,8 @@ class TritonPythonModel:
self.image_session = Session.from_serialized_engine(engine_buffer)
self.vision_dtype_str = visual_config['builder_config']['precision']
self.vision_max_batch_size = visual_config['builder_config'][
'max_batch_size']
features_output_name = "OUT_PROMPT_EMBEDDING_TABLE"
if self.model_type == "mllama":
features_output_name = "ENCODER_INPUT_FEATURES"
@ -162,7 +164,21 @@ class TritonPythonModel:
self.vocab_size = hf_config.vocab_size
self.qwen2vl_utils = Qwen2VLUtils(hf_config)
def get_requests(self, request: List) -> Dict[str, torch.Tensor]:
if self.model_type == 'pixtral':
from transformers import AutoConfig
hf_model_path = model_config['parameters'].get(
'hf_model_path', None)
assert hf_model_path is not None and hf_model_path[
'string_value'] != "${hf_model_path}", "Need to provide hf_model_path for the Pixtral model"
hf_config = AutoConfig.from_pretrained(
hf_model_path['string_value'])
self.image_size = hf_config.vision_config.image_size
self.patch_size = hf_config.vision_config.patch_size
self.vocab_size = hf_config.text_config.vocab_size
self.spatial_merge_size = hf_config.spatial_merge_size
self.relevant_patch_size = self.patch_size * self.spatial_merge_size
def get_requests(self, request) -> Dict[str, torch.Tensor]:
"""
Processes the incoming request to extract and organize input tensors
for different model types.
@ -193,8 +209,10 @@ class TritonPythonModel:
img_tensor = (pb_utils.get_input_tensor_by_name(request, 'pixel_values')
or pb_utils.get_input_tensor_by_name(request, 'IMAGE'))
# mllama supports img_tensor is None case
assert img_tensor != None or self.model_type == 'mllama', "There is no preprocessed image tensor to encode"
# mllama and pixtral support the case where img_tensor is None
assert img_tensor is not None or self.model_type in [
'mllama', 'pixtral'
], "There is no preprocessed image tensor to encode"
if img_tensor is not None:
img_tensor = from_dlpack(img_tensor.to_dlpack())
@ -242,6 +260,9 @@ class TritonPythonModel:
image_sizes = from_dlpack(
pb_utils.get_input_tensor_by_name(
request, 'image_sizes').to_dlpack())
# Remove dimension 1, which was added to match the dimensions defined in config.pbtxt
assert image_sizes.shape[1] == 1
image_sizes.squeeze_(1)
from transformers.models.llava_onevision.modeling_llava_onevision import \
image_size_to_num_patches
image_num_patches = [
@ -276,6 +297,33 @@ class TritonPythonModel:
input_tensors['attention_mask_llm'].append(attention_mask)
input_tensors['image_grid_thw'].append(image_grid_thw)
elif self.model_type == 'pixtral':
if img_tensor is None:
input_tensors['pixel_values'].append(None)
else:
assert batch_size == 1, "Only batch size 1 is supported for Pixtral, because each batch can contain a different number of images"
d_min = torch.finfo(self.vision_output_dtype).min
total_images = img_tensor.shape[0] * img_tensor.shape[1]
num_patches = self.image_size // self.patch_size
input_tensors['input'].append(
img_tensor.view(-1, img_tensor.shape[2],
img_tensor.shape[3], img_tensor.shape[4]))
attention_mask_shape = (total_images, num_patches, num_patches)
attention_mask = torch.full(attention_mask_shape,
fill_value=d_min,
dtype=self.vision_output_dtype,
device="cuda")
image_sizes = from_dlpack(
pb_utils.get_input_tensor_by_name(
request,
'image_sizes').to_dlpack()).reshape(total_images, 2)
for image_idx in range(total_images):
image_h, image_w = image_sizes[image_idx][0], image_sizes[
image_idx][1]
attention_mask[image_idx, :image_h //
self.patch_size, :image_w //
self.patch_size] = 0
input_tensors['attention_mask'].append(attention_mask)
else:
input_tensors['input'].append(
img_tensor.view(-1, img_tensor.shape[2], img_tensor.shape[3],
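To make the Pixtral attention-mask construction above concrete, here is a toy sketch with assumed sizes (image_size=1024, patch_size=16; the real values come from the HF vision config):

import torch

image_size, patch_size = 1024, 16                 # assumed for illustration
num_patches = image_size // patch_size            # 64x64 patch grid per padded image
d_min = torch.finfo(torch.float16).min
attention_mask = torch.full((1, num_patches, num_patches), d_min, dtype=torch.float16)
image_h, image_w = 512, 256                       # true (unpadded) image size
attention_mask[0, :image_h // patch_size, :image_w // patch_size] = 0
# Only the top-left 32x16 block of patches attends; padding patches stay masked.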
@ -408,7 +456,7 @@ class TritonPythonModel:
f"encoder_output_lengths: {encoder_output_lengths}")
# True when the request does not have image input
output_tensors = [
response_tensors = [
pb_utils.Tensor.from_dlpack(
'ENCODER_INPUT_FEATURES',
to_dlpack(encoder_input_features)),
@ -417,16 +465,16 @@ class TritonPythonModel:
to_dlpack(encoder_output_lengths))
]
if cross_attention_mask is not None:
output_tensors.append(
response_tensors.append(
pb_utils.Tensor.from_dlpack(
'CROSS_ATTENTION_MASK',
to_dlpack(cross_attention_mask)))
output_tensors.append(
response_tensors.append(
pb_utils.Tensor.from_dlpack(
'SKIP_CROSS_ATTN_BLOCKS',
to_dlpack(skip_cross_attn_blocks)))
inference_response = pb_utils.InferenceResponse(
output_tensors=output_tensors)
output_tensors=response_tensors)
responses.append(inference_response)
elif self.model_type == 'llava_onevision':
for req_idx, embeddings in enumerate(
@ -443,6 +491,9 @@ class TritonPythonModel:
image_sizes = from_dlpack(
pb_utils.get_input_tensor_by_name(
request, 'image_sizes').to_dlpack())
# Remove dimension 1, which was added to match the dimensions defined in config.pbtxt
assert image_sizes.shape[1] == 1
image_sizes.squeeze_(1)
from transformers.models.llava_onevision.modeling_llava_onevision import \
image_size_to_num_patches
image_num_patches = [
@ -458,10 +509,10 @@ class TritonPythonModel:
embeddings, image_sizes, image_num_patches)
prompt_embedding_table_tensor = pb_utils.Tensor.from_dlpack(
'OUT_PROMPT_EMBEDDING_TABLE', to_dlpack(prompt_table))
output_tensors = [prompt_embedding_table_tensor]
response_tensors = [prompt_embedding_table_tensor]
inference_response = pb_utils.InferenceResponse(
output_tensors=output_tensors)
output_tensors=response_tensors)
responses.append(inference_response)
elif self.model_type == 'qwen2_vl':
image_grid_thw = other_vision_input_tensors.get('image_grid_thw')
@ -493,12 +544,92 @@ class TritonPythonModel:
'MROPE_ROTARY_COS_SIN', to_dlpack(mrope_rotary_cos_sin))
mrope_position_deltas_tensor = pb_utils.Tensor.from_dlpack(
'MROPE_POSITION_DELTAS', to_dlpack(mrope_position_deltas))
output_tensors = [
response_tensors = [
prompt_embedding_table_tensor, mrope_rotary_cos_sin_tensor,
mrope_position_deltas_tensor
]
inference_response = pb_utils.InferenceResponse(
output_tensors=output_tensors)
output_tensors=response_tensors)
responses.append(inference_response)
elif self.model_type == 'pixtral':
assert len(num_images) == len(batch_sizes) == len(
is_skip_encoders) == len(requests)
images_per_batch = [i * b for i, b in zip(num_images, batch_sizes)]
split_along = np.cumsum(images_per_batch).tolist()
if output_tensor is not None:
splitted_output_tensor = torch.tensor_split(output_tensor,
split_along,
dim=0)
visual_embed_dim = output_tensor.shape[-1]
output_img_size = self.image_size // self.relevant_patch_size
for req_idx, request in enumerate(requests):
if is_skip_encoders[req_idx]:
responses.append(
pb_utils.InferenceResponse(output_tensors=[]))
continue
response_tensors = []
assert splitted_output_tensor[req_idx].ndim == 3
current_output_tensor = splitted_output_tensor[req_idx].reshape(
batch_sizes[req_idx], num_images[req_idx],
splitted_output_tensor[req_idx].shape[-2],
splitted_output_tensor[req_idx].shape[-1])
image_sizes = from_dlpack(
pb_utils.get_input_tensor_by_name(
request, 'image_sizes').to_dlpack())
complete_visual_features = []
vocab_size = []
for batch_idx in range(batch_sizes[req_idx]):
batch_visual_features = []
for image_idx in range(num_images[req_idx]):
image_h = image_sizes[batch_idx][image_idx][0]
image_w = image_sizes[batch_idx][image_idx][1]
h_patches = image_h // self.relevant_patch_size
w_patches = image_w // self.relevant_patch_size
relevant_visual_features = torch.zeros(
1, h_patches * w_patches, visual_embed_dim)
visual_features = current_output_tensor[batch_idx][
image_idx].reshape(output_img_size, output_img_size,
visual_embed_dim)
flattened_features = visual_features[:h_patches, :
w_patches, :].flatten(
0, 1)
relevant_visual_features[
0, :h_patches * w_patches, :] = flattened_features
batch_visual_features.append(relevant_visual_features)
batch_visual_features = torch.cat(batch_visual_features,
dim=1)
vocab_size.append(batch_visual_features.shape[1])
complete_visual_features.append(batch_visual_features)
# Pad elements of complete_visual_features to have the same shape[1],
# to allow concatenation over batch dimension
max_vocab_size = max(vocab_size)
for batch_idx in range(batch_sizes[req_idx]):
complete_visual_features[
batch_idx] = torch.nn.functional.pad(
complete_visual_features[batch_idx],
(0, 0, 0, max_vocab_size -
complete_visual_features[batch_idx].shape[1]),
mode='constant')
complete_visual_features = torch.cat(complete_visual_features,
dim=0)
prompt_embedding_table_tensor = pb_utils.Tensor.from_dlpack(
'OUT_PROMPT_EMBEDDING_TABLE',
to_dlpack(
complete_visual_features.type(
self.vision_output_dtype)))
prompt_vocab_size_tensor = pb_utils.Tensor(
'OUT_PROMPT_VOCAB_SIZE',
np.array(vocab_size,
dtype=np.int32).reshape(batch_sizes[req_idx], 1))
response_tensors.extend(
[prompt_embedding_table_tensor, prompt_vocab_size_tensor])
inference_response = pb_utils.InferenceResponse(
output_tensors=response_tensors)
responses.append(inference_response)
else:
for req_idx, embeddings in enumerate(
@ -530,17 +661,67 @@ class TritonPythonModel:
prompt_vocab_size_tensor = pb_utils.Tensor(
'OUT_PROMPT_VOCAB_SIZE', prompt_vocab_size.astype(np.int32))
output_tensors = [
response_tensors = [
prompt_embedding_table_tensor, prompt_vocab_size_tensor
]
inference_response = pb_utils.InferenceResponse(
output_tensors=output_tensors)
output_tensors=response_tensors)
responses.append(inference_response)
# You should return a list of pb_utils.InferenceResponse. Length
# of this list must match the length of `requests` list.
return responses
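The patch arithmetic used in the Pixtral branch above, with assumed config values (the real image_size, patch_size, and spatial_merge_size are read from the HF config):

image_size, patch_size, spatial_merge_size = 1024, 16, 2   # assumed for illustration
relevant_patch_size = patch_size * spatial_merge_size      # 32
output_img_size = image_size // relevant_patch_size        # 32x32 merged-patch grid
image_h, image_w = 512, 256                                # true (unpadded) image size
h_patches, w_patches = image_h // relevant_patch_size, image_w // relevant_patch_size
# Only the top-left 16x8 block of the 32x32 feature grid is copied into the
# prompt embedding table; the rest of the padded grid is discarded.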
def run_vision_encoder(self, vit_input: Dict[str,
torch.Tensor]) -> torch.Tensor:
batch_size = [v.shape[0] for v in vit_input.values()]
assert all(
b == batch_size[0]
for b in batch_size), "Batch sizes of encoder inputs must match"
batch_size = batch_size[0]
embeddings = []
for start_idx in range(0, batch_size, self.vision_max_batch_size):
end_idx = min(start_idx + self.vision_max_batch_size, batch_size)
logger.debug(
f"Running encoder (max_batch_size={self.vision_max_batch_size}) "
+ f"with batch indices {start_idx}:{end_idx} of {batch_size}.")
# Slice the input tensors along the batch dimension
vit_input_batch = {
k: v[start_idx:end_idx]
for k, v in vit_input.items()
}
# Set up output tensors
vit_input_info = [
TensorInfo(key, torch_dtype_to_trt(val.dtype), val.shape)
for key, val in vit_input_batch.items()
]
vit_output_info = self.image_session.infer_shapes(vit_input_info)
vit_output_batch = {
t.name:
torch.empty(tuple(t.shape),
dtype=trt_dtype_to_torch(t.dtype),
device='cuda')
for t in vit_output_info
}
# Run the vision encoder
with torch.cuda.stream(self.vision_stream):
ok = self.image_session.run(vit_input_batch, vit_output_batch,
self.vision_stream.cuda_stream)
assert ok, "Runtime execution failed for vision encoder session"
embeddings.append(vit_output_batch['encoder_output'].to(
self.vision_output_dtype))
with torch.cuda.stream(self.vision_stream):
embeddings = torch.cat(embeddings, dim=0)
self.vision_stream.synchronize()
return embeddings
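The loop in run_vision_encoder splits an oversized vision batch into engine-sized chunks; a quick sketch of the slicing it produces:

batch_size, vision_max_batch_size = 5, 2          # illustrative values
slices = [(start, min(start + vision_max_batch_size, batch_size))
          for start in range(0, batch_size, vision_max_batch_size)]
print(slices)                                     # [(0, 2), (2, 4), (4, 5)]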
def execute(self, requests: List):
"""`execute` must be implemented in every Python model. `execute`
function receives a list of pb_utils.InferenceRequest as the only
@ -664,28 +845,8 @@ class TritonPythonModel:
vit_input['attention_mask'] = attention_mask_vit.to(
str_dtype_to_torch(self.vision_dtype_str)).to('cuda')
# Set up output tensors
vit_input_info = [
TensorInfo(key, torch_dtype_to_trt(val.dtype), val.shape)
for key, val in vit_input.items()
]
vit_output_info = self.image_session.infer_shapes(
vit_input_info)
vit_output = {
t.name:
torch.empty(tuple(t.shape),
dtype=trt_dtype_to_torch(t.dtype),
device='cuda')
for t in vit_output_info
}
# Run the vision encoder
with torch.cuda.stream(self.vision_stream):
ok = self.image_session.run(vit_input, vit_output,
self.vision_stream.cuda_stream)
assert ok, "Runtime execution failed for vision encoder session"
embeddings = vit_output['encoder_output'].to(
self.vision_output_dtype)
self.vision_stream.synchronize()
embeddings = self.run_vision_encoder(vit_input)
# Post process output and save in responses
responses.extend(
self.postprocess_output_tensors(embeddings,

View File

@ -72,13 +72,14 @@ input [
dims: [ 1 ]
optional: true
},
# input tensors for llava_onevision
# Required for llava_onevision and pixtral
{
name: "image_sizes"
data_type: TYPE_INT64
dims: [ 2 ]
dims: [ -1, 2 ]
optional: true
},
# Required for llava_onevision
{
name: "is_video_input"
data_type: TYPE_BOOL
@ -114,7 +115,7 @@ input [
output [
{
name: "OUT_PROMPT_EMBEDDING_TABLE"
data_type: TYPE_FP16
data_type: ${prompt_embedding_table_data_type}
dims: [ -1, -1 ]
},
{

View File

@ -0,0 +1 @@
transformers>=4.50.0

View File

@ -64,6 +64,12 @@ class MockTritonTensor:
else:
return False
def to_dlpack(self):
if self.is_cpu():
return self._tensor.__dlpack__()
else:
return self._tensor.to_dlpack()
@dataclass
class MockTritonError:

View File

@ -63,6 +63,12 @@ class MockTritonTensor:
else:
return False
def to_dlpack(self):
if self.is_cpu():
return self._tensor.__dlpack__()
else:
return self._tensor.to_dlpack()
@dataclass
class MockTritonError:

View File

@ -64,6 +64,12 @@ class MockTritonTensor:
else:
return False
def to_dlpack(self):
if self.is_cpu():
return self._tensor.__dlpack__()
else:
return self._tensor.to_dlpack()
@dataclass
class MockTritonResponse:

View File

@ -197,6 +197,7 @@ for NUM_GPU in "${NUM_GPUS_TO_TEST[@]}"; do
replace_config_tags '${max_queue_delay_microseconds}' "50000" "${MODEL_DIR}/tensorrt_llm/config.pbtxt"
replace_config_tags '${triton_backend}' "tensorrtllm" "${MODEL_DIR}/tensorrt_llm/config.pbtxt"
replace_config_tags '${encoder_input_features_data_type}' "TYPE_FP16" "${MODEL_DIR}/tensorrt_llm/config.pbtxt"
replace_config_tags '${prompt_embedding_table_data_type}' 'TYPE_FP16' "${MODEL_DIR}/tensorrt_llm/config.pbtxt"
replace_config_tags '${triton_max_batch_size}' "128" "${MODEL_DIR}/postprocessing/config.pbtxt"
replace_config_tags '${tokenizer_dir}' "${TOKENIZER_DIR}/" "${MODEL_DIR}/postprocessing/config.pbtxt"
replace_config_tags '${postprocessing_instance_count}' '1' "${MODEL_DIR}/postprocessing/config.pbtxt"

View File

@ -6,6 +6,8 @@ import io
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import List, Tuple
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
@ -19,8 +21,32 @@ from transformers import AutoProcessor, Blip2Processor
from utils import utils
def pixtral_pad_images(
image_list: List[Image.Image]) -> Tuple[np.ndarray, np.ndarray]:
if not image_list:
return np.empty((0, 0, 0, 0), dtype=np.uint8), np.empty((0, 2),
dtype=np.int64)
image_list_np = [np.array(img) for img in image_list]
shapes = [img.shape for img in image_list_np]
assert all(len(s) == 3
for s in shapes), "All input images must have three dimensions"
assert all(s[-1] == shapes[0][-1] for s in
shapes), "All input images must have the same number of channels"
max_h, max_w = max(s[0] for s in shapes), max(s[1] for s in shapes)
for i in range(len(image_list_np)):
image_list_np[i] = np.pad(image_list_np[i],
((0, max_h - image_list_np[i].shape[0]),
(0, max_w - image_list_np[i].shape[1]),
(0, 0)),
mode='constant')
raw_image = np.stack(image_list_np, axis=0)
image_sizes = np.array([s[:2] for s in shapes], dtype=np.int64)
return raw_image, image_sizes
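A small usage sketch for pixtral_pad_images with two dummy images of different sizes (PIL sizes are given as width x height, so the resulting arrays are HxWxC):

from PIL import Image

imgs = [Image.new("RGB", (640, 480)), Image.new("RGB", (320, 240))]
raw_image, image_sizes = pixtral_pad_images(imgs)
print(raw_image.shape)   # (2, 480, 640, 3): both images padded to the max H and W
print(image_sizes)       # [[480 640] [240 320]]: original (H, W) of each image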
def prepare_inputs(text_data,
image_data,
image_sizes,
request_output_len_data,
beam_width_data,
temperature_data,
@ -35,7 +61,6 @@ def prepare_inputs(text_data,
image_input_name="image_input"):
inputs = [
utils.prepare_tensor("text_input", text_data, grpcclient),
utils.prepare_tensor(image_input_name, image_data, grpcclient),
utils.prepare_tensor("max_tokens", request_output_len_data, grpcclient),
utils.prepare_tensor("beam_width", beam_width_data, grpcclient),
utils.prepare_tensor("temperature", temperature_data, grpcclient),
@ -45,6 +70,14 @@ def prepare_inputs(text_data,
utils.prepare_tensor("top_p", top_p_data, grpcclient),
utils.prepare_tensor("stream", streaming_data, grpcclient),
]
if image_data is not None:
inputs += [
utils.prepare_tensor(image_input_name, image_data, grpcclient),
]
if image_sizes is not None:
inputs += [
utils.prepare_tensor("image_sizes_input", image_sizes, grpcclient),
]
if repetition_penalty_data is not None:
inputs += [
utils.prepare_tensor("repetition_penalty", repetition_penalty_data,
@ -63,20 +96,16 @@ def prepare_inputs(text_data,
return inputs
def load_image(image_path):
def load_image(image_path) -> Image.Image:
if image_path.startswith("http") or image_path.startswith("https"):
image = Image.open(requests.get(image_path,
stream=True).raw).convert("RGB")
image_bytes = requests.get(image_path, stream=True).content
elif image_path.startswith("data:image/jpeg;base64,"):
image_base64 = image_path.split(",")[1]
# Decode the base64 string
image_data = base64.b64decode(image_base64)
# Create a BytesIO object from the decoded data
image_buffer = io.BytesIO(image_data)
image = Image.open(image_buffer).convert("RGB")
image_bytes = base64.b64decode(image_base64)
else:
image = Image.open(image_path).convert("RGB")
return image
image_bytes = Path(image_path).read_bytes()
return Image.open(io.BytesIO(image_bytes)).convert("RGB")
def load_video(video_path, num_of_frames):
@ -239,7 +268,7 @@ if __name__ == "__main__":
required=True,
choices=[
'blip2', 'llava', 'vila', 'mllama',
'llava_onevision', 'qwen2_vl'
'llava_onevision', 'qwen2_vl', 'pixtral'
],
help="Model type")
parser.add_argument("--hf_model_dir",
@ -249,11 +278,18 @@ if __name__ == "__main__":
help="path to the model directory")
FLAGS = parser.parse_args()
# load and process images or video
image_sizes = np.empty((0, 2), dtype=np.int64)
if 'vila' in FLAGS.model_type:
image_paths = FLAGS.image.split(",")
raw_image = []
for image_path in image_paths:
raw_image.append(load_image(image_path))
elif 'pixtral' in FLAGS.model_type:
image_paths = FLAGS.image.split(",") if FLAGS.image else []
raw_image = []
for image_path in image_paths:
raw_image.append(load_image(image_path))
raw_image, image_sizes = pixtral_pad_images(raw_image)
elif FLAGS.video is not None:
assert FLAGS.video_num_frames is not None, "Number of frames should be provided for video input."
raw_video = load_video(FLAGS.video, FLAGS.video_num_frames)
@ -303,6 +339,9 @@ if __name__ == "__main__":
FLAGS.text = image_tag + FLAGS.text
image_data = np.array([[raw_image]])
image_input_name = "image_bytes_input"
elif 'pixtral' in FLAGS.model_type:
image_data = np.array([raw_image])
image_input_name = "image_bytes_input"
elif 'llava_onevision' in FLAGS.model_type:
if FLAGS.video is not None:
image_data = np.array([raw_video])
@ -334,6 +373,9 @@ if __name__ == "__main__":
temperature_data = np.array(temperature, dtype=np.float32)
streaming = [[FLAGS.streaming]]
streaming_data = np.array(streaming, dtype=bool)
image_data = None if image_data.size == 0 else image_data
image_sizes_data = None if image_sizes.size == 0 else np.array(
[image_sizes], dtype=np.int64)
model_name = "ensemble"
if FLAGS.use_bls:
@ -356,6 +398,7 @@ if __name__ == "__main__":
inputs = prepare_inputs(text_data,
image_data,
image_sizes_data,
request_output_len_data,
beam_width_data,
temperature_data,