mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
Merge remote-tracking branch 'origin/main' into user/xiweny/merge_0901
Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
This commit is contained in:
commit
62a78973a8
@ -9,5 +9,6 @@ examples/**/.git
|
||||
examples/**/*.bin
|
||||
examples/**/*.engine
|
||||
examples/**/*.onnx
|
||||
examples/**/*.safetensors
|
||||
examples/**/c-model
|
||||
examples/models/core/gpt/gpt*
|
||||
|
||||
11
.github/CODEOWNERS
vendored
11
.github/CODEOWNERS
vendored
@ -1,10 +1,5 @@
|
||||
# This file defines code ownership rules for the repository.
|
||||
|
||||
# The following rule should only be uncommented on release branches (e.g., release/0.19).
|
||||
# The rule below requires that any PR to release/**/* branches must be approved by at least one member
|
||||
# of the NVIDIA/trt-llm-release-branch-approval team, regardless of who else approves the PR.
|
||||
# Without approval from a member of this team, PRs cannot be merged to release branches.
|
||||
# * @NVIDIA/trt-llm-release-branch-approval
|
||||
|
||||
## TensorRT-LLM Infra
|
||||
### CI
|
||||
@ -160,3 +155,9 @@ docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers
|
||||
# from a member of this team, PRs affecting public APIs cannot be merged to main or release branches.
|
||||
/tests/unittest/api_stability/ @NVIDIA/trt-llm-noncommitted-api-review-committee
|
||||
/tests/unittest/api_stability/references_committed/ @NVIDIA/trt-llm-committed-api-review-committee
|
||||
|
||||
# The following rule should only be uncommented on release branches (e.g., release/0.19).
|
||||
# The rule below requires that any PR to release/**/* branches must be approved by at least one member
|
||||
# of the NVIDIA/trt-llm-release-branch-approval team, regardless of who else approves the PR.
|
||||
# Without approval from a member of this team, PRs cannot be merged to release branches.
|
||||
# * @NVIDIA/trt-llm-release-branch-approval
|
||||
|
||||
@ -363,6 +363,18 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
|
||||
}
|
||||
return block_pool_pointers;
|
||||
})
|
||||
.def("get_block_scale_pool_pointers",
|
||||
[](tbk::BaseKVCacheManager& self)
|
||||
{
|
||||
std::optional<at::Tensor> block_scale_pool_pointers{std::nullopt};
|
||||
auto tensor = self.getBlockScalePoolPointers();
|
||||
if (tensor)
|
||||
{
|
||||
std::shared_ptr<tensorrt_llm::runtime::ITensor> _tensor = std::move(tensor);
|
||||
block_scale_pool_pointers = tr::Torch::tensor(_tensor);
|
||||
}
|
||||
return block_scale_pool_pointers;
|
||||
})
|
||||
.def("get_layer_to_pool_mapping",
|
||||
[](tbk::BaseKVCacheManager& self)
|
||||
{
|
||||
|
||||
@ -392,8 +392,8 @@ public:
|
||||
std::vector<int64_t> output_shape = {num_rows, unpadded_hidden_size_val};
|
||||
auto output = torch::empty(output_shape, input.options().dtype(mOutputDtype));
|
||||
|
||||
WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
|
||||
static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode);
|
||||
WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
|
||||
static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode, stream);
|
||||
|
||||
auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
|
||||
kernels::MoeMinLatencyParams min_latency_params{};
|
||||
@ -553,8 +553,8 @@ public:
|
||||
min_latency_params.experts_to_token_score = static_cast<float*>(experts_to_token_score.data_ptr());
|
||||
min_latency_params.active_expert_global_ids = static_cast<int*>(active_expert_global_ids.data_ptr());
|
||||
|
||||
WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
|
||||
static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode);
|
||||
WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
|
||||
static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode, stream);
|
||||
|
||||
auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
|
||||
|
||||
@ -709,6 +709,7 @@ private:
|
||||
// e.g. 16 nvfp4 elements are packed into a single int64 element
|
||||
int64_t mInnerDimMultiplier;
|
||||
char* mProfileWorkspace = nullptr;
|
||||
WorkspaceInfo workspace_info;
|
||||
|
||||
bool mUseDeepSeekFP8BlockScaling = false;
|
||||
bool mUseW4GroupScaling = false;
|
||||
@ -757,9 +758,9 @@ private:
|
||||
mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile);
|
||||
}
|
||||
|
||||
WorkspaceInfo getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
|
||||
WorkspaceInfo const& getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
|
||||
int num_experts, int experts_per_token, ActivationType activation_type,
|
||||
kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode)
|
||||
kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode, cudaStream_t stream)
|
||||
{
|
||||
size_t moe_workspace_size = mKernelRunner->getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts,
|
||||
experts_per_token, activation_type, parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling,
|
||||
@ -768,15 +769,29 @@ private:
|
||||
|
||||
std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size};
|
||||
|
||||
size_t total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
|
||||
int64_t const total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
|
||||
|
||||
WorkspaceInfo info{};
|
||||
info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
|
||||
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
|
||||
info.src_to_dest_map
|
||||
= common::nextWorkspacePtr(static_cast<int8_t*>(info.workspace.data_ptr()), moe_workspace_size);
|
||||
bool is_capturing = tensorrt_llm::common::isCapturing(stream);
|
||||
// Always allocate workspace when capturing cuda graph to avoid illegal memory access during replay
|
||||
if (is_capturing || workspace_info.workspace.numel() < total_workspace_size)
|
||||
{
|
||||
if (is_capturing)
|
||||
{
|
||||
TLLM_LOG_DEBUG(
|
||||
"Allocating MoE workspace with %ld bytes size during cuda graph capture", total_workspace_size);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_LOG_DEBUG("MoE workspace size is not enough, increase the size from %ld bytes to %ld bytes",
|
||||
workspace_info.workspace.numel(), total_workspace_size);
|
||||
}
|
||||
workspace_info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
|
||||
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
|
||||
}
|
||||
workspace_info.src_to_dest_map
|
||||
= common::nextWorkspacePtr(static_cast<int8_t*>(workspace_info.workspace.data_ptr()), moe_workspace_size);
|
||||
|
||||
return info;
|
||||
return workspace_info;
|
||||
}
|
||||
|
||||
kernels::QuantParams getQuantParams(int64_t const num_experts_on_rank, int64_t const hidden_size,
|
||||
|
||||
@ -197,7 +197,8 @@ FROM wheel AS tritonbuild
|
||||
WORKDIR /src/tensorrt_llm
|
||||
RUN pip install /src/tensorrt_llm/build/tensorrt_llm*.whl
|
||||
COPY ./triton_backend/ ./triton_backend/
|
||||
RUN bash ./triton_backend/inflight_batcher_llm/scripts/build.sh
|
||||
ARG TRITON_BASE_TAG
|
||||
RUN bash ./triton_backend/inflight_batcher_llm/scripts/build.sh -s "r${TRITON_BASE_TAG%-py3}"
|
||||
|
||||
|
||||
FROM release AS tritonrelease
|
||||
|
||||
@ -138,7 +138,7 @@ CODE_DIR ?= /code/tensorrt_llm
|
||||
EXTRA_VOLUMES ?=
|
||||
CCACHE_DIR ?= $(CODE_DIR)/cpp/.ccache
|
||||
CONAN_DIR ?= $(CODE_DIR)/cpp/.conan
|
||||
USER_CACHE_DIR ?= $(HOME_DIR)/.cache
|
||||
USER_CACHE_DIR ?= $(shell readlink -f "${HOME_DIR}/.cache")
|
||||
RUN_CMD ?=
|
||||
CONTAINER_NAME ?= tensorrt_llm
|
||||
WORK_DIR ?= $(CODE_DIR)
|
||||
|
||||
@ -160,6 +160,12 @@ Welcome to TensorRT-LLM's Documentation!
|
||||
blogs/XQA-kernel.md
|
||||
blogs/tech_blog/*
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
:caption: Use TensorRT Engine
|
||||
:hidden:
|
||||
|
||||
legacy/tensorrt_quickstart.md
|
||||
|
||||
Indices and tables
|
||||
==================
|
||||
|
||||
9
docs/source/legacy/tensorrt_quickstart.md
Normal file
9
docs/source/legacy/tensorrt_quickstart.md
Normal file
@ -0,0 +1,9 @@
|
||||
# LLM API with TensorRT Engine
|
||||
A simple inference example with TinyLlama using the LLM API:
|
||||
|
||||
```{literalinclude} ../../examples/llm-api/_tensorrt_engine/quickstart_example.py
|
||||
:language: python
|
||||
:linenos:
|
||||
```
|
||||
|
||||
For more advanced usage including distributed inference, multimodal, and speculative decoding, please refer to this [README](../../../examples/llm-api/README.md).
|
||||
33
examples/llm-api/_tensorrt_engine/quickstart_example.py
Normal file
33
examples/llm-api/_tensorrt_engine/quickstart_example.py
Normal file
@ -0,0 +1,33 @@
|
||||
from tensorrt_llm import LLM, SamplingParams
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# Model could accept HF model name, a path to local HF model,
|
||||
# or TensorRT Model Optimizer's quantized checkpoints like nvidia/Llama-3.1-8B-Instruct-FP8 on HF.
|
||||
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
# Create a sampling params.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
for output in llm.generate(prompts, sampling_params):
|
||||
print(
|
||||
f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}"
|
||||
)
|
||||
|
||||
# Got output like
|
||||
# Prompt: 'Hello, my name is', Generated text: '\n\nJane Smith. I am a student pursuing my degree in Computer Science at [university]. I enjoy learning new things, especially technology and programming'
|
||||
# Prompt: 'The president of the United States is', Generated text: 'likely to nominate a new Supreme Court justice to fill the seat vacated by the death of Antonin Scalia. The Senate should vote to confirm the'
|
||||
# Prompt: 'The capital of France is', Generated text: 'Paris.'
|
||||
# Prompt: 'The future of AI is', Generated text: 'an exciting time for us. We are constantly researching, developing, and improving our platform to create the most advanced and efficient model available. We are'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -29,8 +29,7 @@ def example_cuda_graph_config():
|
||||
cuda_graph_config=cuda_graph_config, # Enable CUDA graphs
|
||||
max_batch_size=4,
|
||||
max_seq_len=512,
|
||||
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.8,
|
||||
enable_block_reuse=True))
|
||||
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.5))
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
@ -56,7 +55,7 @@ def example_kv_cache_config():
|
||||
max_batch_size=8,
|
||||
max_seq_len=1024,
|
||||
kv_cache_config=KvCacheConfig(
|
||||
free_gpu_memory_fraction=0.85,
|
||||
free_gpu_memory_fraction=0.5,
|
||||
enable_block_reuse=True))
|
||||
|
||||
prompts = [
|
||||
|
||||
@ -1,11 +1,17 @@
|
||||
from tensorrt_llm import LLM, SamplingParams
|
||||
from tensorrt_llm import BuildConfig, SamplingParams
|
||||
from tensorrt_llm._tensorrt_engine import LLM # NOTE the change
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
build_config = BuildConfig()
|
||||
build_config.max_batch_size = 256
|
||||
build_config.max_num_tokens = 1024
|
||||
|
||||
# Model could accept HF model name, a path to local HF model,
|
||||
# or TensorRT Model Optimizer's quantized checkpoints like nvidia/Llama-3.1-8B-Instruct-FP8 on HF.
|
||||
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
|
||||
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
build_config=build_config)
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
|
||||
@ -122,6 +122,15 @@ def add_multimodal_args(parser):
|
||||
" ├── __init__.py"
|
||||
" ├── <model_name>.py"
|
||||
" └── <sub_dirs>"))
|
||||
# Add multiturn conversation related parameters
|
||||
parser.add_argument("--multiturn",
|
||||
action="store_true",
|
||||
help="Enable multi-turn conversation mode.")
|
||||
parser.add_argument(
|
||||
"--conversation_turns",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Number of conversation turns for automated testing.")
|
||||
return parser
|
||||
|
||||
|
||||
@ -188,6 +197,80 @@ def main():
|
||||
f"Unsupported model_type: {model_type} found!\n" \
|
||||
f"Supported types: {MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types()}"
|
||||
|
||||
# If multiturn mode is enabled
|
||||
if args.multiturn:
|
||||
# Run predefined multiturn conversation examples
|
||||
assert args.prompt is not None, "Please provide a prompt for multiturn conversation."
|
||||
assert args.media is not None, "Please provide media for multiturn conversation."
|
||||
# Determine how many turns to run
|
||||
max_turns = min(args.conversation_turns, len(args.prompt))
|
||||
generated_outputs = [] # Store generated outputs for return
|
||||
|
||||
# Initialize conversation history with the first prompt
|
||||
conversation_history = args.prompt[0] if args.prompt else ""
|
||||
|
||||
for i in range(max_turns):
|
||||
print(f"\n--- Turn {i+1} ---")
|
||||
|
||||
try:
|
||||
# Use multimodal input loader to process input with conversation context
|
||||
# Use accumulated conversation history instead of just the current prompt
|
||||
cur_prompt = conversation_history
|
||||
inputs = default_multimodal_input_loader(
|
||||
tokenizer=llm.tokenizer,
|
||||
model_dir=llm._hf_model_dir,
|
||||
model_type=model_type,
|
||||
modality=args.modality,
|
||||
prompts=[cur_prompt],
|
||||
media=args.media,
|
||||
image_data_format="pt",
|
||||
num_frames=8,
|
||||
device="cpu")
|
||||
|
||||
lora_request = None
|
||||
if args.load_lora:
|
||||
if model_class is None:
|
||||
raise ValueError(
|
||||
"model_class must be provided when load_lora is True"
|
||||
)
|
||||
lora_request = model_class.lora_request(
|
||||
len(inputs), args.modality, llm._hf_model_dir)
|
||||
|
||||
# Generate response
|
||||
outputs = llm.generate(inputs,
|
||||
sampling_params,
|
||||
lora_request=lora_request)
|
||||
assert outputs and len(
|
||||
outputs) > 0 and outputs[0].outputs and len(
|
||||
outputs[0].outputs) > 0
|
||||
response = outputs[0].outputs[0].text.strip()
|
||||
|
||||
# Store generated output
|
||||
generated_outputs.append({
|
||||
"turn": i + 1,
|
||||
"user_input": cur_prompt,
|
||||
"assistant_response": response,
|
||||
"media": args.media
|
||||
})
|
||||
|
||||
conversation_history = conversation_history + "\n" + response
|
||||
if i + 1 < len(args.prompt):
|
||||
conversation_history = conversation_history + "\n" + args.prompt[
|
||||
i + 1]
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in turn {i+1}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
for i, output in enumerate(generated_outputs):
|
||||
print(
|
||||
f"[{i}] Prompt: {output['user_input']!r}, Generated text: {output['assistant_response']!r}"
|
||||
)
|
||||
return
|
||||
|
||||
# Original single-turn processing logic
|
||||
# set prompts and media to example prompts and images if they are not provided
|
||||
if args.prompt is None:
|
||||
args.prompt = example_medias_and_prompts[args.modality]["prompt"]
|
||||
|
||||
@ -28,6 +28,11 @@ Please refer to [this guide](https://nvidia.github.io/TensorRT-LLM/installation/
|
||||
- [Evaluation](#evaluation)
|
||||
- [Serving](#serving)
|
||||
- [trtllm-serve](#trtllm-serve)
|
||||
- [B200 FP4 min-latency config](#b200-fp4-min-latency-config)
|
||||
- [B200 FP4 max-throughput config](#b200-fp4-max-throughput-config)
|
||||
- [B200 FP8 min-latency config](#b200-fp8-min-latency-config)
|
||||
- [B200 FP8 max-throughput config](#b200-fp8-max-throughput-config)
|
||||
- [Launch trtllm-serve OpenAI-compatible API server](#launch-trtllm-serve-openai-compatible-api-server)
|
||||
- [Disaggregated Serving](#disaggregated-serving)
|
||||
- [Dynamo](#dynamo)
|
||||
- [tensorrtllm\_backend for triton inference server (Prototype)](#tensorrtllm_backend-for-triton-inference-server-prototype)
|
||||
@ -228,56 +233,111 @@ trtllm-eval --model <YOUR_MODEL_DIR> \
|
||||
## Serving
|
||||
### trtllm-serve
|
||||
|
||||
Take max-throughput scenario on B200 as an example, the settings are extracted from the [blog](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/Best_perf_practice_on_DeepSeek-R1_in_TensorRT-LLM.md#b200-max-throughput). **For users' own models and cases, the specific settings could be different to get best performance.**
|
||||
Below are example B200 serving configurations for both min-latency and max-throughput in FP4 and FP8. If you want to explore configurations, see the [blog](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/Best_perf_practice_on_DeepSeek-R1_in_TensorRT-LLM.md). **Treat these as starting points—tune for your model and workload to achieve the best performance.**
|
||||
|
||||
To serve the model using `trtllm-serve`:
|
||||
|
||||
#### B200 FP4 min-latency config
|
||||
```bash
|
||||
cat >./extra-llm-api-config.yml <<EOF
|
||||
cuda_graph_config:
|
||||
enable_padding: true
|
||||
max_batch_size: 1024
|
||||
enable_attention_dp: false
|
||||
kv_cache_config:
|
||||
dtype: fp8
|
||||
stream_interval: 10
|
||||
EOF
|
||||
```
|
||||
|
||||
#### B200 FP4 max-throughput config
|
||||
```bash
|
||||
cat >./extra-llm-api-config.yml <<EOF
|
||||
cuda_graph_config:
|
||||
enable_padding: true
|
||||
batch_sizes:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 8
|
||||
- 16
|
||||
- 32
|
||||
- 64
|
||||
- 128
|
||||
- 256
|
||||
- 384
|
||||
print_iter_log: true
|
||||
- 1024
|
||||
- 896
|
||||
- 512
|
||||
- 256
|
||||
- 128
|
||||
- 64
|
||||
- 32
|
||||
- 16
|
||||
- 8
|
||||
- 4
|
||||
- 2
|
||||
- 1
|
||||
kv_cache_config:
|
||||
dtype: fp8
|
||||
stream_interval: 10
|
||||
enable_attention_dp: true
|
||||
EOF
|
||||
```
|
||||
|
||||
#### B200 FP8 min-latency config
|
||||
```bash
|
||||
cat >./extra-llm-api-config.yml <<EOF
|
||||
cuda_graph_config:
|
||||
enable_padding: true
|
||||
max_batch_size: 1024
|
||||
enable_attention_dp: false
|
||||
kv_cache_config:
|
||||
dtype: fp8
|
||||
free_gpu_memory_fraction: 0.8
|
||||
stream_interval: 10
|
||||
moe_config:
|
||||
backend: DEEPGEMM
|
||||
max_num_tokens: 37376
|
||||
EOF
|
||||
```
|
||||
|
||||
#### B200 FP8 max-throughput config
|
||||
```bash
|
||||
cat >./extra-llm-api-config.yml <<EOF
|
||||
cuda_graph_config:
|
||||
enable_padding: true
|
||||
max_batch_size: 512
|
||||
enable_attention_dp: true
|
||||
kv_cache_config:
|
||||
dtype: fp8
|
||||
free_gpu_memory_fraction: 0.8
|
||||
stream_interval: 10
|
||||
moe_config:
|
||||
backend: DEEPGEMM
|
||||
EOF
|
||||
```
|
||||
#### Launch trtllm-serve OpenAI-compatible API server
|
||||
```bash
|
||||
trtllm-serve \
|
||||
deepseek-ai/DeepSeek-V3 \
|
||||
deepseek-ai/DeepSeek-R1 \
|
||||
--host localhost \
|
||||
--port 8000 \
|
||||
--backend pytorch \
|
||||
--max_batch_size 384 \
|
||||
--max_num_tokens 1536 \
|
||||
--max_batch_size 1024 \
|
||||
--max_num_tokens 8192 \
|
||||
--tp_size 8 \
|
||||
--ep_size 8 \
|
||||
--pp_size 1 \
|
||||
--kv_cache_free_gpu_memory_fraction 0.85 \
|
||||
--kv_cache_free_gpu_memory_fraction 0.9 \
|
||||
--extra_llm_api_options ./extra-llm-api-config.yml
|
||||
```
|
||||
It's possible seeing OOM issues on some configs. Considering reducing `kv_cache_free_gpu_mem_fraction` to a smaller value as a workaround. We're working on the investigation and addressing the problem. If you are using max-throughput config, reduce `max_num_tokens` to `3072` to avoid OOM issues.
|
||||
|
||||
To query the server, you can start with a `curl` command:
|
||||
```bash
|
||||
curl http://localhost:8000/v1/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "deepseek-ai/DeepSeek-V3",
|
||||
"model": "deepseek-ai/DeepSeek-R1",
|
||||
"prompt": "Where is New York?",
|
||||
"max_tokens": 16,
|
||||
"temperature": 0
|
||||
}'
|
||||
```
|
||||
|
||||
For DeepSeek-R1, use the model name `deepseek-ai/DeepSeek-R1`.
|
||||
For DeepSeek-R1 FP4, use the model name `nvidia/DeepSeek-R1-FP4-v2`.
|
||||
For DeepSeek-V3, use the model name `deepseek-ai/DeepSeek-V3`.
|
||||
|
||||
### Disaggregated Serving
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@ We first describe three runtime modes for running multimodal models and how to r
|
||||
- [CogVLM](#cogvlm)
|
||||
- [Deplot](#deplot)
|
||||
- [Fuyu](#fuyu)
|
||||
- [Gemma3](#gemma3)
|
||||
- [InternLM-XComposer2](#internlm-xcomposer2)
|
||||
- [InternVL2](#internvl2)
|
||||
- [Kosmos-2](#kosmos-2)
|
||||
@ -352,6 +353,75 @@ Currently, CogVLM only support bfloat16 precision.
|
||||
--engine_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu
|
||||
```
|
||||
|
||||
## Gemma3
|
||||
|
||||
**NOTE: We only support Gemma3 VLMs in Pytorch workflow.**
|
||||
|
||||
Gemma3VL decoder requires a custom attention mask while processing images. During the context phase:
|
||||
- Text tokens attend to other tokens in a causal fashion (standard autoregressive behavior)
|
||||
- Image tokens attend to other tokens in a causal fashion AND attend to other tokens from the same image in a bidirectional manner
|
||||
|
||||
**Reference:** [Gemma3 Model Documentation](https://huggingface.co/docs/transformers/en/model_doc/gemma3)
|
||||
|
||||
We support this custom mask with FlashInfer attention backend.
|
||||
|
||||
### Requirements
|
||||
|
||||
To ensure expected behavior with Gemma3VL, the following configurations are **required**:
|
||||
- **Attention Backend**: Use the FlashInfer attention backend
|
||||
- **Chunked Prefill**: Must be disabled
|
||||
- **KV Cache Reuse**: Must be disabled
|
||||
|
||||
### Quick Start
|
||||
|
||||
#### 1. Download Model Weights
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="gemma-3-27b-it"
|
||||
git clone https://huggingface.co/google/${MODEL_NAME}
|
||||
```
|
||||
|
||||
#### 2. Interactive Testing
|
||||
|
||||
Use the `quickstart_multimodal.py` script for quick testing:
|
||||
|
||||
```bash
|
||||
python3 examples/llm-api/quickstart_multimodal.py \
|
||||
--model_dir ${MODEL_NAME}/ \
|
||||
--modality image \
|
||||
--image_format pil \
|
||||
--attention_backend FLASHINFER \
|
||||
--disable_kv_cache_reuse
|
||||
```
|
||||
|
||||
#### 3. Model Serving
|
||||
|
||||
Serve the model using `trtllm-serve` with the required llmapi arguments mentioned in a yaml file:
|
||||
|
||||
```bash
|
||||
# Create the configuration file
|
||||
cat > extra-llm-api-options.yaml << 'EOF'
|
||||
cuda_graph_config: null
|
||||
attn_backend: "FLASHINFER"
|
||||
enable_chunked_prefill: false
|
||||
kv_cache_config:
|
||||
enable_block_reuse: false
|
||||
EOF
|
||||
|
||||
# Serve the model
|
||||
trtllm-serve ${MODEL_NAME}/ \
|
||||
--backend pytorch \
|
||||
--tp_size 1 \
|
||||
--port 8000 \
|
||||
--max_batch_size 4 \
|
||||
--extra_llm_api_options extra-llm-api-options.yaml
|
||||
```
|
||||
|
||||
### Supported Model Variants
|
||||
|
||||
Currently supported Gemma3 variants: 4B, 12B, 27B
|
||||
|
||||
|
||||
## InternLM-XComposer2
|
||||
|
||||
**NOTE: We only support InternLM-XComposer-VL-7b for now**
|
||||
|
||||
@ -170,16 +170,11 @@ def cleanUpNodeResourcesMultiNodes(def pipeline, SlurmCluster cluster, String jo
|
||||
"-e 's/.*Submitted batch job \\([0-9]\\+\\).*/\\1/p' " +
|
||||
"-e 's/.*srun: job \\([0-9]\\+\\) queued.*/\\1/p' " +
|
||||
"-e 's/.*srun: job \\([0-9]\\+\\) has been allocated.*/\\1/p' " +
|
||||
"${slurmOutputFile} | tail -n1\""
|
||||
"${slurmOutputFile} | tail -n1 || true\""
|
||||
),
|
||||
returnStdout: true
|
||||
).trim()
|
||||
|
||||
if (!slurmJobID || !slurmJobID.isNumber()) {
|
||||
Utils.exec(pipeline, script: Utils.sshUserCmd(remote, "\"cat ${slurmOutputFile}\""))
|
||||
error("Slurm job did not submit successfully. No job ID found.")
|
||||
}
|
||||
|
||||
Utils.exec(pipeline, script: "echo Slurm job ID: ${slurmJobID}")
|
||||
|
||||
Utils.exec(pipeline, script: "echo Sleeping to allow slurm job termination; sleep 30")
|
||||
@ -196,10 +191,18 @@ def cleanUpNodeResourcesMultiNodes(def pipeline, SlurmCluster cluster, String jo
|
||||
pipeline,
|
||||
script: Utils.sshUserCmd(
|
||||
remote,
|
||||
"rm -rf /home/svc_tensorrt/bloom/scripts/${jobUID}"
|
||||
"\"rm -rf /home/svc_tensorrt/bloom/scripts/${jobUID} || true\""
|
||||
)
|
||||
)
|
||||
|
||||
if (!slurmJobID || !slurmJobID.isNumber()) {
|
||||
Utils.exec(pipeline, script: Utils.sshUserCmd(remote, "\"cat ${slurmOutputFile} || true\""))
|
||||
echo "Slurm job did not submit successfully. No job ID found."
|
||||
} else {
|
||||
def newSlurmOutputFile = slurmOutputFile.replace("%j", slurmJobID)
|
||||
Utils.exec(pipeline, script: Utils.sshUserCmd(remote, "\"mv ${slurmOutputFile} ${newSlurmOutputFile} || true\""))
|
||||
}
|
||||
|
||||
Utils.exec(pipeline, script: "echo Slurm job ID: ${slurmJobID} cleaned up")
|
||||
}
|
||||
}
|
||||
@ -214,6 +217,12 @@ def cleanUpNodeResources(def pipeline, SlurmCluster cluster, String nodeName, St
|
||||
allowAnyHosts: true,
|
||||
]
|
||||
|
||||
Utils.exec(pipeline, script: "echo Sleeping to allow docker stop; sleep 30")
|
||||
|
||||
CloudManager.destroyNode(nodeName)
|
||||
|
||||
Utils.exec(pipeline, script: "echo Sleeping to allow node destruction; sleep 30")
|
||||
|
||||
Utils.exec(pipeline, script: "apt-get update && apt-get install -y sshpass openssh-client")
|
||||
|
||||
Utils.exec(pipeline, script: "echo Slurm job ID: ${slurmJobID}")
|
||||
@ -230,7 +239,7 @@ def cleanUpNodeResources(def pipeline, SlurmCluster cluster, String nodeName, St
|
||||
pipeline,
|
||||
script: Utils.sshUserCmd(
|
||||
remote,
|
||||
"rm -rf /home/svc_tensorrt/bloom/scripts/agent-${nodeName}.jar /home/svc_tensorrt/bloom/scripts/${nodeName}-slurm_jenkins_agent_setup.sh"
|
||||
"\"rm -rf /home/svc_tensorrt/bloom/scripts/agent-${nodeName}.jar /home/svc_tensorrt/bloom/scripts/${nodeName}-slurm_jenkins_agent_setup.sh || true\""
|
||||
)
|
||||
)
|
||||
|
||||
@ -330,7 +339,7 @@ def runLLMTestlistOnSlurm(pipeline, platform, testList, config=VANILLA_CONFIG, p
|
||||
slurmJobID = jobIDs ? jobIDs[-1] : null
|
||||
|
||||
if (!slurmJobID || !slurmJobID.isNumber()) {
|
||||
error("Slurm job did not submit successfully. No job ID found.\nSubmission output:\n${slurmSubmitOutput}")
|
||||
echo "Slurm job did not submit successfully. No job ID found.\nSubmission output:\n${slurmSubmitOutput}"
|
||||
}
|
||||
Utils.exec(pipeline, script: "echo Slurm job ID: ${slurmJobID}")
|
||||
Utils.exec(pipeline, script: "echo Sleeping to allow agent initialization; sleep 30")
|
||||
@ -377,12 +386,22 @@ def runLLMTestlistOnSlurm(pipeline, platform, testList, config=VANILLA_CONFIG, p
|
||||
error "The Slurm node does not come online in the waiting period. Terminating the job."
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
if (e.getMessage()?.contains("Failed to kill container")) {
|
||||
echo "Known benign error ignored: ${e.getMessage()}"
|
||||
} else {
|
||||
throw e // Re-throw if it's a different IOException
|
||||
}
|
||||
} finally {
|
||||
stage('Clean up SLURM Resources') {
|
||||
Utils.exec(pipeline, script: "echo Sleeping to allow docker stop; sleep 30")
|
||||
CloudManager.destroyNode(nodeName)
|
||||
Utils.exec(pipeline, script: "echo Sleeping to allow node destruction; sleep 30")
|
||||
cleanUpNodeResources(pipeline, cluster, nodeName, slurmJobID)
|
||||
stage("Clean up SLURM Resources") {
|
||||
// Workaround to handle the interruption during clean up SLURM resources
|
||||
retry(3) {
|
||||
try {
|
||||
cleanUpNodeResources(pipeline, cluster, nodeName, slurmJobID)
|
||||
} catch (Exception e) {
|
||||
error "Error during clean up SLURM resources: ${e.getMessage()} and retrying."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -436,7 +455,7 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL
|
||||
def llmSrcLocal = "${llmPath}/TensorRT-LLM/src"
|
||||
def scriptRunNode = "${jobWorkspace}/${jobUID}-slurm_run.sh"
|
||||
def scriptLaunch = "${jobWorkspace}/${jobUID}-slurm_launch.sh"
|
||||
slurmOutputFile = "${jobWorkspace}/${jobUID}-slurm_output.log"
|
||||
slurmOutputFile = SlurmConfig.getOutputFilePath("/home/svc_tensorrt/slurm-logs", jobUID)
|
||||
def testListPathNode = "${jobWorkspace}/${testList}.txt"
|
||||
def waivesListPathNode = "${jobWorkspace}/waives.txt"
|
||||
def isAarch64 = config.contains("aarch64")
|
||||
@ -490,6 +509,7 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL
|
||||
|
||||
def srunCmd = SlurmConfig.generateMultiNodeCommand(partition, taskArgs, scriptRunNode)
|
||||
scriptLaunchDestPath = Utils.createTempLocation(pipeline, "./slurm_launch.sh")
|
||||
// TODO: check if the tee always returns 0
|
||||
def scriptContent = """#!/bin/bash
|
||||
export jobWorkspace=$jobWorkspace
|
||||
export tarName=$tarName
|
||||
@ -531,8 +551,15 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL
|
||||
} finally {
|
||||
uploadResults(pipeline, cluster, jobUID, stageName)
|
||||
|
||||
stage('Clean up SLURM Resources') {
|
||||
cleanUpNodeResourcesMultiNodes(pipeline, cluster, jobUID, slurmOutputFile)
|
||||
stage("Clean up SLURM Resources") {
|
||||
// Workaround to handle the interruption during clean up SLURM resources
|
||||
retry(3) {
|
||||
try {
|
||||
cleanUpNodeResourcesMultiNodes(pipeline, cluster, jobUID, slurmOutputFile)
|
||||
} catch (Exception e) {
|
||||
error "Error during clean up SLURM resources: ${e.getMessage()} and retrying."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -660,7 +687,7 @@ def cacheErrorAndUploadResult(stageName, taskRunner, finallyRunner, noResultIfSu
|
||||
if (stageIsInterrupted) {
|
||||
echo "Stage is interrupted, skip to upload test result."
|
||||
} else {
|
||||
sh 'if [ "$(id -u)" -eq 0 ]; then dmesg; fi'
|
||||
sh 'if [ "$(id -u)" -eq 0 ]; then dmesg || true; fi'
|
||||
if (noResultIfSuccess && !stageIsFailed) {
|
||||
// Clean up the workspace
|
||||
sh """
|
||||
@ -1555,7 +1582,7 @@ def runLLMTestlistOnPlatformImpl(pipeline, platform, testList, config=VANILLA_CO
|
||||
stage ("[${stageName}] Run Pytest")
|
||||
{
|
||||
echoNodeAndGpuInfo(pipeline, stageName)
|
||||
sh 'if [ "$(id -u)" -eq 0 ]; then dmesg -C; fi'
|
||||
sh 'if [ "$(id -u)" -eq 0 ]; then dmesg -C || true; fi'
|
||||
|
||||
def extraInternalEnv = ""
|
||||
def pytestTestTimeout = "3600"
|
||||
@ -2056,7 +2083,8 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
|
||||
fullSet += SBSATestConfigs.keySet()
|
||||
|
||||
SBSASlurmTestConfigs = [
|
||||
"GB200-PyTorch-1": ["gb200-single", "l0_gb200", 1, 1],
|
||||
// Disable GB200-PyTorch-1 due to OOM (https://nvbugspro.nvidia.com/bug/5490507)
|
||||
//"GB200-PyTorch-1": ["gb200-single", "l0_gb200", 1, 1],
|
||||
"GB200-4_GPUs-PyTorch-1": ["gb200-x4", "l0_gb200_multi_gpus", 1, 1, 4],
|
||||
"GB200-4_GPUs-PyTorch-Post-Merge-1": ["gb200-x4", "l0_gb200_multi_gpus", 1, 1, 4],
|
||||
"GB300-4_GPUs-PyTorch-Post-Merge-1": ["gb300-x4", "l0_gb300_multi_gpus", 1, 1, 4],
|
||||
|
||||
@ -768,8 +768,14 @@ class TrtllmAttentionMetadata(AttentionMetadata):
|
||||
self.kv_cache_block_offsets[:, :self.num_seqs].copy_(
|
||||
self.host_kv_cache_block_offsets[:, :self.num_seqs],
|
||||
non_blocking=True)
|
||||
|
||||
error_message = (
|
||||
f"The max KV cache length of input sequences ({self.kv_lens[:self.num_seqs].max()}) "
|
||||
f"exceeds the KV cache manager's maximum supported length "
|
||||
f"({self.kv_cache_manager.max_seq_len}).")
|
||||
|
||||
assert self.kv_lens[:self.num_seqs].max(
|
||||
) <= self.kv_cache_manager.max_seq_len, f"Please set max_seq_len to at least {self.kv_lens[:self.num_seqs].max()} for kv cache manager."
|
||||
) <= self.kv_cache_manager.max_seq_len, error_message
|
||||
|
||||
self.kv_lens_cuda_runtime = self.kv_lens_cuda[:self.num_seqs]
|
||||
self.kv_lens_runtime = self.kv_lens[:self.num_seqs]
|
||||
|
||||
@ -9,7 +9,6 @@ from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
|
||||
from tensorrt_llm._utils import nvtx_range
|
||||
|
||||
from ...._utils import mpi_rank, mpi_world_size
|
||||
from ....bindings.executor import ExecutorConfig
|
||||
from ....bindings.internal.batch_manager import CacheType
|
||||
from ....mapping import Mapping
|
||||
from ...distributed import MPIDist
|
||||
@ -259,7 +258,7 @@ class ADEngine(ModelEngine):
|
||||
return {"logits": logits_flat}
|
||||
|
||||
|
||||
def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir: str = None):
|
||||
def create_autodeploy_executor(ad_config: LlmArgs):
|
||||
"""Create an AutoDeploy executor from the given configuration and checkpoint directory.
|
||||
|
||||
This is the entrypoint API to the _autodeploy backend.
|
||||
@ -276,8 +275,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
|
||||
|
||||
# some config
|
||||
msg = "pytorch_backend_config must be an AD LlmArgs object"
|
||||
assert isinstance(executor_config.pytorch_backend_config, LlmArgs), msg
|
||||
ad_config: LlmArgs = executor_config.pytorch_backend_config
|
||||
assert isinstance(ad_config, LlmArgs), msg
|
||||
assert ad_config.max_beam_width <= 1, "_autodeploy + beam_search is not supported"
|
||||
|
||||
max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import glob
|
||||
import multiprocessing
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, List
|
||||
|
||||
import psutil
|
||||
@ -128,7 +129,7 @@ class HfWeightLoader(BaseWeightLoader):
|
||||
if len(local_file_names) == 0:
|
||||
return
|
||||
|
||||
max_processes = min(multiprocessing.cpu_count() * 2, 16,
|
||||
len(local_file_names))
|
||||
with multiprocessing.Pool(processes=max_processes) as pool:
|
||||
pool.map(self._prefetch_one_file, local_file_names)
|
||||
max_workers = min(multiprocessing.cpu_count() * 2, 16,
|
||||
len(local_file_names))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
list(executor.map(self._prefetch_one_file, local_file_names))
|
||||
|
||||
@ -284,17 +284,3 @@ class Gemma3VLM(PreTrainedModel):
|
||||
attn_metadata=attn_metadata)[-1]
|
||||
image_features = self.mm_projector(image_features)
|
||||
return image_features
|
||||
|
||||
|
||||
def _load_weights_into_hf_module(
|
||||
model: torch.nn.Module,
|
||||
weights: dict,
|
||||
prefix: str,
|
||||
model_name: str,
|
||||
) -> None:
|
||||
filtered_weights = filter_weights(prefix, weights)
|
||||
missing_keys, _ = model.load_state_dict(filtered_weights)
|
||||
if len(missing_keys) > 0:
|
||||
raise KeyError(
|
||||
f"Missing the following keys for the {model_name} in the checkpoint: "
|
||||
f"[{', '.join(missing_keys)}].")
|
||||
|
||||
@ -186,6 +186,9 @@ class Llama4Attention(Attention):
|
||||
mrope_config,
|
||||
attention_sinks=None)
|
||||
|
||||
if isinstance(attn_output, tuple):
|
||||
attn_output = Fp4QuantizedTensor(attn_output[0], attn_output[1])
|
||||
|
||||
attn_output = self.o_proj(attn_output,
|
||||
all_reduce_params=all_reduce_params)
|
||||
|
||||
@ -554,50 +557,60 @@ class Llama4DecoderLayer(DecoderLayer):
|
||||
hidden_states, residual)
|
||||
|
||||
if (self.fusion_config.POST_MOE_FUSION
|
||||
or self.fusion_config.POST_MLP_FUSION
|
||||
) and self.next_layer_layernorm is not None:
|
||||
# Get the scale for the next allreduce fusion op
|
||||
if self.next_attn is not None and (self.is_nvfp4
|
||||
or self.is_fp8_quant):
|
||||
scale = self.next_attn.qkv_proj.input_scale
|
||||
else:
|
||||
# Add just the fusion op to RESIDUAL_RMS_NORM due to this is the last decoder layer
|
||||
self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
|
||||
scale = None
|
||||
|
||||
# TODO: MIN_LATENCY_MODE is hardcoded to False
|
||||
if cutlass_min_latency_mode:
|
||||
shared_output = hidden_states[0]
|
||||
hidden_states_activated_experts = hidden_states[1]
|
||||
num_activated_experts_per_node = hidden_states[2]
|
||||
experts_to_token_score = hidden_states[3]
|
||||
|
||||
allreduce_output = self.moe_allreduce(
|
||||
residual,
|
||||
self.next_layer_layernorm.weight,
|
||||
device_num_experts=num_activated_experts_per_node,
|
||||
scale_input=experts_to_token_score,
|
||||
active_experts_token_input=hidden_states_activated_experts,
|
||||
token_input=shared_output,
|
||||
eps=self.next_layer_layernorm.variance_epsilon,
|
||||
)
|
||||
else:
|
||||
allreduce_output = self.all_reduce(
|
||||
or self.fusion_config.POST_MLP_FUSION):
|
||||
# If there is no extra layernorm, do another pure allreduce because
|
||||
# the allreduce in feed-forward module has been disabled.
|
||||
if self.next_layer_layernorm is None:
|
||||
hidden_states, residual = self.all_reduce(
|
||||
hidden_states,
|
||||
all_reduce_params=AllReduceParams(
|
||||
fusion_op=self.post_feed_forward_fusion_op,
|
||||
fusion_op=None,
|
||||
residual=residual,
|
||||
norm_weight=self.next_layer_layernorm.weight,
|
||||
scale=scale,
|
||||
eps=self.next_layer_layernorm.variance_epsilon,
|
||||
))
|
||||
|
||||
# Unpack the allreduce output
|
||||
if self.next_attn is not None and self.is_nvfp4:
|
||||
act_fp4, act_sf, residual = allreduce_output
|
||||
hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
|
||||
else:
|
||||
hidden_states, residual = allreduce_output
|
||||
# The next layernorm exists but it could be the last decoder layer.
|
||||
# Adjust the scale and fusion pattern.
|
||||
if self.next_attn is not None and (self.is_nvfp4
|
||||
or self.is_fp8_quant):
|
||||
scale = self.next_attn.qkv_proj.input_scale
|
||||
else:
|
||||
self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
|
||||
scale = None
|
||||
|
||||
# TODO: MIN_LATENCY_MODE is hardcoded to False
|
||||
if cutlass_min_latency_mode:
|
||||
shared_output = hidden_states[0]
|
||||
hidden_states_activated_experts = hidden_states[1]
|
||||
num_activated_experts_per_node = hidden_states[2]
|
||||
experts_to_token_score = hidden_states[3]
|
||||
|
||||
allreduce_output = self.moe_allreduce(
|
||||
residual,
|
||||
self.next_layer_layernorm.weight,
|
||||
device_num_experts=num_activated_experts_per_node,
|
||||
scale_input=experts_to_token_score,
|
||||
active_experts_token_input=
|
||||
hidden_states_activated_experts,
|
||||
token_input=shared_output,
|
||||
eps=self.next_layer_layernorm.variance_epsilon,
|
||||
)
|
||||
else:
|
||||
allreduce_output = self.all_reduce(
|
||||
hidden_states,
|
||||
all_reduce_params=AllReduceParams(
|
||||
fusion_op=self.post_feed_forward_fusion_op,
|
||||
residual=residual,
|
||||
norm_weight=self.next_layer_layernorm.weight,
|
||||
scale=scale,
|
||||
eps=self.next_layer_layernorm.variance_epsilon,
|
||||
))
|
||||
|
||||
# Unpack the allreduce output
|
||||
if self.next_attn is not None and self.is_nvfp4:
|
||||
act_fp4, act_sf, residual = allreduce_output
|
||||
hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
|
||||
else:
|
||||
hidden_states, residual = allreduce_output
|
||||
elif self.next_layer_layernorm:
|
||||
hidden_states, residual = self.next_layer_layernorm(
|
||||
hidden_states, residual)
|
||||
@ -710,6 +723,7 @@ class LlamaDecoderLayer(DecoderLayer):
|
||||
scale = self.mlp.gate_up_proj.input_scale
|
||||
else:
|
||||
scale = None
|
||||
|
||||
all_reduce_output = self.all_reduce(
|
||||
hidden_states,
|
||||
all_reduce_params=AllReduceParams(
|
||||
@ -752,25 +766,40 @@ class LlamaDecoderLayer(DecoderLayer):
|
||||
|
||||
spec_metadata.maybe_capture_hidden_states(self.layer_idx,
|
||||
hidden_states, residual)
|
||||
if self.POST_MLP_FUSION and self.next_attn is not None:
|
||||
if self.is_nvfp4 or self.is_fp8_quant:
|
||||
scale = self.next_attn.qkv_proj.input_scale
|
||||
|
||||
if self.POST_MLP_FUSION:
|
||||
# If there is no extra layernorm, do another pure allreduce.
|
||||
if self.next_layer_layernorm is None:
|
||||
hidden_states, residual = self.all_reduce(
|
||||
hidden_states,
|
||||
all_reduce_params=AllReduceParams(
|
||||
fusion_op=None,
|
||||
residual=residual,
|
||||
))
|
||||
else:
|
||||
scale = None
|
||||
all_reduce_output = self.all_reduce(
|
||||
hidden_states,
|
||||
all_reduce_params=AllReduceParams(
|
||||
fusion_op=self.post_mlp_fusion_op,
|
||||
residual=residual,
|
||||
norm_weight=self.next_layer_layernorm.weight,
|
||||
scale=scale,
|
||||
eps=self.next_layer_layernorm.variance_epsilon,
|
||||
))
|
||||
if self.is_nvfp4:
|
||||
act_fp4, act_sf, residual = all_reduce_output
|
||||
hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
|
||||
else:
|
||||
hidden_states, residual = all_reduce_output
|
||||
# The next layernorm exists but it could be the last decoder layer.
|
||||
# Adjust the scale and fusion pattern.
|
||||
if self.next_attn is not None and (self.is_nvfp4
|
||||
or self.is_fp8_quant):
|
||||
scale = self.next_attn.qkv_proj.input_scale
|
||||
else:
|
||||
self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
|
||||
scale = None
|
||||
|
||||
all_reduce_output = self.all_reduce(
|
||||
hidden_states,
|
||||
all_reduce_params=AllReduceParams(
|
||||
fusion_op=self.post_mlp_fusion_op,
|
||||
residual=residual,
|
||||
norm_weight=self.next_layer_layernorm.weight,
|
||||
scale=scale,
|
||||
eps=self.next_layer_layernorm.variance_epsilon,
|
||||
))
|
||||
if self.next_attn is not None and self.is_nvfp4:
|
||||
act_fp4, act_sf, residual = all_reduce_output
|
||||
hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
|
||||
else:
|
||||
hidden_states, residual = all_reduce_output
|
||||
elif self.next_layer_layernorm:
|
||||
hidden_states, residual = self.next_layer_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
@ -314,8 +314,10 @@ class Mistral3VLM(PreTrainedModel):
|
||||
self.llm = MistralForCausalLM(llm_model_config)
|
||||
|
||||
self._device = "cuda"
|
||||
vision_model_config = self._get_sub_model_config(
|
||||
model_config, "vision_config")
|
||||
# NOTE: current `modelopt` does not support quantizing the vision portion.
|
||||
vision_model_config = self._get_sub_model_config(model_config,
|
||||
"vision_config",
|
||||
quant_config=None)
|
||||
self._vision_tower = modeling_pixtral.PixtralVisionModel(
|
||||
vision_model_config)
|
||||
self._multi_modal_projector = Mistral3MultiModalProjector(model_config)
|
||||
@ -385,7 +387,7 @@ class Mistral3VLM(PreTrainedModel):
|
||||
f"Expected as many `pixel_values` ({len(pixel_values)}) and "
|
||||
f"`image_sizes` ({len(image_sizes)}) as number of multimodal parameters "
|
||||
f"({multimodal_params_len}).")
|
||||
batched_pixel_values, batched_image_sizes = self._batch_pixel_values(
|
||||
batched_pixel_values, batched_image_sizes = self.batch_pixel_values(
|
||||
pixel_values=pixel_values, image_sizes=image_sizes)
|
||||
mm_embeds = [
|
||||
self._get_image_features(pixel_values=batched_pixel_values,
|
||||
@ -411,12 +413,14 @@ class Mistral3VLM(PreTrainedModel):
|
||||
def _get_sub_model_config(
|
||||
model_config: ModelConfig[MistralConfig],
|
||||
name: str,
|
||||
**changes,
|
||||
) -> ModelConfig:
|
||||
# Extract the subconfig from the `transformers` config and shove it into our own
|
||||
# `ModelConfig` class.
|
||||
sub_model_config: ModelConfig[MistralConfig] = dataclasses.replace(
|
||||
model_config,
|
||||
pretrained_config=getattr(model_config.pretrained_config, name),
|
||||
**changes,
|
||||
)
|
||||
# Make sure some fields that are not explicitly included in the sub config, but present
|
||||
# in the top-level config, are replicated.
|
||||
@ -450,21 +454,38 @@ class Mistral3VLM(PreTrainedModel):
|
||||
# (the transformers one expected numpy arrays).
|
||||
@staticmethod
|
||||
@torch.inference_mode()
|
||||
def _batch_pixel_values(
|
||||
def batch_pixel_values(
|
||||
pixel_values: List[torch.Tensor],
|
||||
image_sizes: List[torch.Tensor],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# NOTES:
|
||||
# * `pixel_values` is a list of `[B_idx, C, H_idx, W_idx]` tensors, i.e. a batch of images as
|
||||
# padded + batched by the input processor.
|
||||
# The height (H_idx) and width (W_idx) of each element need not coincide.
|
||||
# * Similarly, each element in `image_sizes` describes the original image sizes prior to
|
||||
# padding for the corresponding element in `pixel_values`.
|
||||
|
||||
# The below creates a single `[sum(B_idx), 2]` tensor describing all image sizes, and then
|
||||
# calculates the maximum height / width across all of them.
|
||||
batched_image_sizes = torch.cat(image_sizes)
|
||||
max_shape = batched_image_sizes.max(dim=0).values
|
||||
|
||||
# This next step then pads the pixel values potentially a second time by using the `max_shape`
|
||||
# computed above. Note that as far as this function is concerned, the original sizes for
|
||||
# batching purposes can be deduced from looking at the tensors in `pixel_values`, NOT in
|
||||
# `image_sizes`.
|
||||
pixel_values = [
|
||||
torchvision.transforms.v2.functional.pad(
|
||||
image,
|
||||
# Per torchvision docs, this should be in LTRB order if it's a sequence of 4 numbers.
|
||||
padding=[0, 0, max_shape[1] - size[1], max_shape[0] - size[0]],
|
||||
padding=[
|
||||
0, 0, max_shape[1] - image.shape[-1],
|
||||
max_shape[0] - image.shape[-2]
|
||||
],
|
||||
# Values extracted from HF implementation.
|
||||
fill=0.0,
|
||||
padding_mode="constant",
|
||||
) for image, size in zip(pixel_values, batched_image_sizes)
|
||||
) for image in pixel_values
|
||||
]
|
||||
return torch.cat(pixel_values), batched_image_sizes
|
||||
|
||||
|
||||
@ -1,23 +1,14 @@
|
||||
# Plan for phi4-mm model support.
|
||||
# (done) step 1: support legacy inference pipeline for phi4-mm model.
|
||||
# (done) step 2: refactor the inference pipeline to use AGGREGATE mode (https://github.com/NVIDIA/TensorRT-LLM/pull/5522).
|
||||
# (todo) step 3: optimization
|
||||
# * use TRTLLM-attention to replace original pytorch attention in vision/audio encoders.
|
||||
# * use data parallel to accelerate inference.
|
||||
# (todo) step 2: refactor the inference pipeline to use AGGREGATE mode (https://github.com/NVIDIA/TensorRT-LLM/pull/5522).
|
||||
|
||||
import copy
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from PIL import Image
|
||||
|
||||
from tensorrt_llm.inputs.multimodal import MultimodalParams
|
||||
|
||||
from ...executor.request import LoRARequest
|
||||
from ...inputs import (ExtraProcessedInputs, InputProcessor,
|
||||
MultimodalPlaceholderMetadata,
|
||||
@ -32,361 +23,16 @@ from .modeling_auto import AutoModelForCausalLM
|
||||
from .modeling_multimodal_utils import fuse_input_embeds
|
||||
from .modeling_utils import register_auto_model
|
||||
|
||||
# Special token ids from the original Phi-4-multimodal-instruct implementation
|
||||
_IMAGE_SPECIAL_TOKEN_ID = 200010 # '<|endoftext10|>' from HF `modeling_phi4mm.py`
|
||||
_AUDIO_SPECIAL_TOKEN_ID = 200011 # '<|endoftext11|>' from HF `modeling_phi4mm.py`
|
||||
_PAD_TOKEN_ID = 199999 # '<|endoftext|>' from HF `special_tokens_map.json`
|
||||
_COMPATIBLE_IMAGE_SPECIAL_TOKEN_ID_RANGE = [-9999,
|
||||
-1] # from HF `modeling_phi4mm.py`
|
||||
_COMPATIBLE_AUDIO_SPECIAL_TOKEN_ID_RANGE = [float('-inf'), -10000
|
||||
] # from HF `modeling_phi4mm.py`
|
||||
|
||||
# Below classes will be loaded from HuggingFace codes, rather than using transformers version,
|
||||
# since transformers version is not compatible with checkpoints and configs from `microsoft/Phi-4-multimodal-instruct`.
|
||||
Phi4MMAudioEmbedding = None
|
||||
Phi4MMImageEmbedding = None
|
||||
Phi4MMConfig = None
|
||||
# Special tokens
|
||||
_IMAGE_SPECIAL_TOKEN_ID = 200010 # '<|endoftext10|>'
|
||||
_AUDIO_SPECIAL_TOKEN_ID = 200011 # '<|endoftext11|>'
|
||||
|
||||
|
||||
# Make this a runtime lookup rather than a module-wide constant for easier unit testing.
|
||||
def _is_disagg() -> bool:
|
||||
return os.getenv("TLLM_MULTIMODAL_DISAGGREGATED", "0") == "1"
|
||||
|
||||
|
||||
# Load the Phi4MM classes from HuggingFace Phi-4-multimodal-instruct repo.
|
||||
# Remove this function by using the transformers version of Phi4Multimodal when weights/configs are converted to transformers format.
|
||||
def _load_phi4mm_classes(local_path):
|
||||
"""Load Phi4MM classes from the specified local path."""
|
||||
global Phi4MMAudioEmbedding, Phi4MMImageEmbedding, Phi4MMConfig
|
||||
if Phi4MMAudioEmbedding is not None and Phi4MMImageEmbedding is not None and Phi4MMConfig is not None:
|
||||
return
|
||||
|
||||
# Add parent folder to sys.path to enable relative import.
|
||||
original_sys_path = sys.path.copy()
|
||||
package_folder = Path(local_path)
|
||||
parent_folder = str(package_folder.parent)
|
||||
if parent_folder not in sys.path:
|
||||
sys.path.insert(0, parent_folder)
|
||||
|
||||
try:
|
||||
# Import Phi4MMConfig from configuration_phi4mm.py.
|
||||
config_path = os.path.join(local_path, 'configuration_phi4mm.py')
|
||||
if not os.path.exists(config_path):
|
||||
raise FileNotFoundError(
|
||||
f"configuration_phi4mm.py not found at {local_path}.")
|
||||
spec = importlib.util.spec_from_file_location("hf_config", config_path)
|
||||
hf_config = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(hf_config)
|
||||
Phi4MMConfig = hf_config.Phi4MMConfig
|
||||
|
||||
# Import Phi4MMAudioEmbedding and Phi4MMImageEmbedding from modeling_phi4mm.py.
|
||||
modeling_phi4mm_path = os.path.join(local_path, 'modeling_phi4mm.py')
|
||||
if not os.path.exists(modeling_phi4mm_path):
|
||||
raise FileNotFoundError(
|
||||
f"modeling_phi4mm.py not found at {local_path}.")
|
||||
# `Phi-4-multimodal-instruct` as the package name to avoid relative import errors.
|
||||
# `hf_modeling_phi4mm` as the module name to avoid name conflicts.
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"Phi-4-multimodal-instruct.hf_modeling_phi4mm",
|
||||
modeling_phi4mm_path)
|
||||
hf_modeling_phi4mm = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(hf_modeling_phi4mm)
|
||||
Phi4MMAudioEmbedding = hf_modeling_phi4mm.Phi4MMAudioEmbedding
|
||||
Phi4MMImageEmbedding = hf_modeling_phi4mm.Phi4MMImageEmbedding
|
||||
finally:
|
||||
sys.path = original_sys_path
|
||||
|
||||
|
||||
class HFPhi4MultimodalEncoder(transformers.PreTrainedModel,
|
||||
transformers.generation.GenerationMixin):
|
||||
|
||||
# Copy and modify from https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/modeling_phi4mm.py::Phi4MMImageAudioEmbedding
|
||||
# Note: the HF implementation here will cause duplicated encoders on all GPUs for TP>1 scenario.
|
||||
# TODO: use TRTLLM-attention to replace original pytorch Flash_attn_2 in HFPhi4MultimodalEncoder.
|
||||
config_class = Phi4MMConfig
|
||||
base_model_prefix = "model"
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
|
||||
def __init__(self, config: transformers.PretrainedConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
self.padding_idx = config.pad_token_id
|
||||
|
||||
self.embed_tokens = torch.nn.Embedding(config.vocab_size,
|
||||
config.hidden_size,
|
||||
self.padding_idx)
|
||||
|
||||
self._attn_implementation = config._attn_implementation
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
embedding_config = {
|
||||
'embedding_cls': config.embd_layer['embedding_cls'],
|
||||
**config.embd_layer
|
||||
}
|
||||
# The default values are from HuggingFace Phi-4-multimodal-instruct codes.
|
||||
self.image_input_id = embedding_config.get('image_input_id', -1)
|
||||
self.audio_input_id = embedding_config.get('audio_input_id', -10000)
|
||||
if self.image_input_id == self.audio_input_id:
|
||||
raise ValueError(
|
||||
'image_input_id and audio_input_id should be different')
|
||||
|
||||
self.image_embd_layer_kwargs = embedding_config['image_embd_layer']
|
||||
self.image_embed = Phi4MMImageEmbedding(config,
|
||||
**self.image_embd_layer_kwargs)
|
||||
|
||||
self.audio_embd_layer_kwargs = embedding_config['audio_embd_layer']
|
||||
self.audio_embed = Phi4MMAudioEmbedding(config,
|
||||
**self.audio_embd_layer_kwargs)
|
||||
|
||||
def _replace_special_token_ids(self,
|
||||
input_ids: torch.Tensor) -> torch.Tensor:
|
||||
# Inplace-replacement for special token ids.
|
||||
torch.where(
|
||||
(input_ids >= _COMPATIBLE_IMAGE_SPECIAL_TOKEN_ID_RANGE[0])
|
||||
& (input_ids <= _COMPATIBLE_IMAGE_SPECIAL_TOKEN_ID_RANGE[1]),
|
||||
torch.tensor(_IMAGE_SPECIAL_TOKEN_ID),
|
||||
input_ids,
|
||||
out=input_ids,
|
||||
)
|
torch.where(
(input_ids >= _COMPATIBLE_AUDIO_SPECIAL_TOKEN_ID_RANGE[0])
& (input_ids <= _COMPATIBLE_AUDIO_SPECIAL_TOKEN_ID_RANGE[1]),
torch.tensor(_AUDIO_SPECIAL_TOKEN_ID),
input_ids,
out=input_ids,
)
return input_ids

def _batch_infer_image_embeds(
self, batched_input_ids: torch.Tensor,
multimodal_params: List[MultimodalParams]) -> torch.Tensor:
# Batch image inputs and attention mask with padding along dim=1 (patch num).
input_image_embeds_list, input_image_attn_mask_list, input_image_sizes_list = [], [], []
for mm_param in multimodal_params:
mm_data = mm_param.multimodal_data
input_image_embeds = mm_data["input_image_embeds"]
if input_image_embeds is not None and input_image_embeds.numel(
) > 0:
input_image_embeds_list.append(input_image_embeds)
input_image_attn_mask_list.append(
mm_data["image_attention_mask"])
input_image_sizes_list.append(mm_data["image_sizes"])
batched_image_hidden_states = None
if len(input_image_embeds_list) > 0:
# Padding image embeds/attn_masks along dim=1 (patch dimension).
b_list = [x.shape[0] for x in input_image_embeds_list]
p_list = [x.shape[1] for x in input_image_embeds_list]
c_i, h_i, w_i = input_image_embeds_list[0].shape[2:5]
h_i_attn, w_i_attn = input_image_attn_mask_list[0].shape[2:4]
total_b = sum(b_list)
max_p = max(p_list)
batched_image_embeds = torch.zeros(
(total_b, max_p, c_i, h_i, w_i),
dtype=input_image_embeds_list[0].dtype,
device=input_image_embeds_list[0].device)
batched_image_attn_mask = torch.zeros(
(total_b, max_p, h_i_attn, w_i_attn),
dtype=input_image_embeds_list[0].dtype,
device=input_image_embeds_list[0].device)
b_offset = 0
for i, tensor in enumerate(input_image_embeds_list):
b, p = tensor.shape[:2]
batched_image_embeds[b_offset:b_offset + b, :p] = tensor
if input_image_attn_mask_list[i] is not None:
batched_image_attn_mask[
b_offset:b_offset +
b, :p] = input_image_attn_mask_list[i]
else:
batched_image_attn_mask[b_offset:b_offset + b, :p] = 1
b_offset += b
batched_image_sizes = torch.cat(input_image_sizes_list, dim=0)
# Forward image encoder with batched image embeds.
batched_image_hidden_states = self.image_embed(
input_ids=batched_input_ids,
input_embeds=batched_image_embeds,
image_sizes=batched_image_sizes,
image_attention_mask=batched_image_attn_mask,
wte=self.embed_tokens,
)
return batched_image_hidden_states

def _batch_infer_audio_embeds(
self, batched_input_ids: torch.Tensor,
multimodal_params: List[MultimodalParams]) -> torch.Tensor:
# Batch audio inputs and attention mask with padding along dim=1 (patch num).
input_audio_embeds_list, input_audio_attn_mask_list, input_audio_sizes_list = [], [], []
for mm_param in multimodal_params:
mm_data = mm_param.multimodal_data
input_audio_embeds = mm_data["input_audio_embeds"]
if input_audio_embeds is not None and input_audio_embeds.numel(
) > 0:
input_audio_embeds_list.append(input_audio_embeds)
input_audio_attn_mask_list.append(
mm_data["audio_attention_mask"])
input_audio_sizes_list.append(mm_data["audio_embed_sizes"])
batched_audio_hidden_states = None
if len(input_audio_embeds_list) > 0:
b_list = [x.shape[0] for x in input_audio_embeds_list]
p_list = [x.shape[1] for x in input_audio_embeds_list]
d_a = input_audio_embeds_list[0].shape[2]
total_b = sum(b_list)
max_p = max(p_list)
batched_audio_embeds = torch.zeros(
(total_b, max_p, d_a),
dtype=input_audio_embeds_list[0].dtype,
device=input_audio_embeds_list[0].device)
batched_audio_attn_mask = torch.zeros(
(total_b, max_p),
dtype=input_audio_embeds_list[0].dtype,
device=input_audio_embeds_list[0].device)
b_offset = 0
for i, tensor in enumerate(input_audio_embeds_list):
b, p = tensor.shape[:2]
batched_audio_embeds[b_offset:b_offset + b, :p] = tensor
if input_audio_attn_mask_list[i] is not None:
batched_audio_attn_mask[
b_offset:b_offset +
b, :p] = input_audio_attn_mask_list[i]
else:
batched_audio_attn_mask[b_offset:b_offset + b, :p] = 1
b_offset += b
batched_audio_sizes = torch.cat(input_audio_sizes_list, dim=0)
# Forward audio encoder with batched audio embeds.
batched_audio_hidden_states = self.audio_embed(
input_ids=batched_input_ids,
input_embeds=batched_audio_embeds,
audio_embed_sizes=batched_audio_sizes,
audio_attention_mask=batched_audio_attn_mask,
wte=self.embed_tokens,
)
return batched_audio_hidden_states
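Editor's note: the two `_batch_infer_*` helpers above share one pattern: zero-pad each request's embeds along dim=1 (the patch count) up to the batch maximum, then run the encoder once on the stacked result. The sketch below is illustrative only and not part of this commit; the `pad_and_stack` helper name is hypothetical.

# --- illustrative sketch (not part of the commit) ---
import torch
from typing import List

def pad_and_stack(embeds: List[torch.Tensor]) -> torch.Tensor:
    """Zero-pad a list of [b_i, p_i, ...] tensors to a common patch count and stack them."""
    max_p = max(t.shape[1] for t in embeds)
    total_b = sum(t.shape[0] for t in embeds)
    out = torch.zeros((total_b, max_p, *embeds[0].shape[2:]),
                      dtype=embeds[0].dtype, device=embeds[0].device)
    offset = 0
    for t in embeds:
        b, p = t.shape[:2]
        out[offset:offset + b, :p] = t  # rows beyond p stay zero (padding)
        offset += b
    return out
# --- end of illustrative sketch ---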

def _encoding_per_request(
self, multimodal_params: List[MultimodalParams],
mm_token_ids: torch.Tensor) -> List[torch.FloatTensor]:
# Loop implementation.
mm_embeddings = []
for i in range(len(multimodal_params)):
input_ids = multimodal_params[i].multimodal_data["input_ids"]
input_image_embeds = multimodal_params[i].multimodal_data[
"input_image_embeds"]
input_audio_embeds = multimodal_params[i].multimodal_data[
"input_audio_embeds"]
image_sizes = multimodal_params[i].multimodal_data["image_sizes"]
image_attention_mask = multimodal_params[i].multimodal_data[
"image_attention_mask"]
audio_embed_sizes = multimodal_params[i].multimodal_data[
"audio_embed_sizes"]
audio_attention_mask = multimodal_params[i].multimodal_data[
"audio_attention_mask"]
audio_projection_mode = multimodal_params[i].multimodal_data[
"audio_projection_mode"]

input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
input_ids = self._replace_special_token_ids(input_ids)
image_position_mask = input_ids == _IMAGE_SPECIAL_TOKEN_ID
non_image_position_mask = ~image_position_mask

image_hidden_states = None
if input_image_embeds is not None:
image_hidden_states = self.image_embed(
input_ids=input_ids,
input_embeds=input_image_embeds,
image_sizes=image_sizes,
wte=self.embed_tokens,
image_attention_mask=image_attention_mask,
)
audio_hidden_states = None
if input_audio_embeds is not None:
audio_hidden_states = self.audio_embed(
input_ids=input_ids,
input_embeds=input_audio_embeds,
audio_embed_sizes=audio_embed_sizes,
audio_attention_mask=audio_attention_mask,
wte=self.embed_tokens,
audio_projection_mode=audio_projection_mode,
)

if input_image_embeds is not None and input_audio_embeds is not None:
dtype = image_hidden_states.dtype
hidden_states = image_hidden_states * image_position_mask.to(
dtype).unsqueeze(
-1) + audio_hidden_states * non_image_position_mask.to(
dtype).unsqueeze(-1)
elif input_image_embeds is not None:
hidden_states = image_hidden_states
elif input_audio_embeds is not None:
hidden_states = audio_hidden_states
else:
hidden_states = self.embed_tokens(input_ids)

# Postprocessing to get multimodal-only embeddings.
mm_token_mask = torch.isin(input_ids, mm_token_ids)
hidden_states = hidden_states[mm_token_mask]

mm_embeddings.append(hidden_states)
return mm_embeddings

def _encoding_batch_request(
self, multimodal_params: List[MultimodalParams],
mm_token_ids: torch.Tensor) -> List[torch.FloatTensor]:
# Batch input_ids.
input_ids_list = [
multimodal_params[i].multimodal_data["input_ids"]
for i in range(len(multimodal_params))
]
max_input_ids_len = max(
[input_ids.shape[1] for input_ids in input_ids_list])
batched_input_ids = torch.full(
(len(multimodal_params), max_input_ids_len),
_PAD_TOKEN_ID,
device=input_ids_list[0].device)
for i, input_ids in enumerate(input_ids_list):
batched_input_ids[i, :input_ids.shape[1]] = input_ids
batched_input_ids = batched_input_ids.view(-1, max_input_ids_len)
batched_input_ids = self._replace_special_token_ids(batched_input_ids)
image_position_mask = batched_input_ids == _IMAGE_SPECIAL_TOKEN_ID
non_image_position_mask = ~image_position_mask

# Batch inference for image and audio embeds.
batched_image_hidden_states = self._batch_infer_image_embeds(
batched_input_ids, multimodal_params)
batched_audio_hidden_states = self._batch_infer_audio_embeds(
batched_input_ids, multimodal_params)

# Combine different modalities into one.
if batched_image_hidden_states is not None and batched_audio_hidden_states is not None:
batched_hidden_states = batched_image_hidden_states * image_position_mask.unsqueeze(
-1
) + batched_audio_hidden_states * non_image_position_mask.unsqueeze(
-1)
elif batched_image_hidden_states is not None:
batched_hidden_states = batched_image_hidden_states
elif batched_audio_hidden_states is not None:
batched_hidden_states = batched_audio_hidden_states
else:
batched_hidden_states = self.embed_tokens(batched_input_ids)

# Postprocessing to get multimodal-only embeddings.
mm_token_mask = torch.isin(batched_input_ids, mm_token_ids)
batched_hidden_states = batched_hidden_states[mm_token_mask]
batched_hidden_states = [batched_hidden_states]
return batched_hidden_states
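Editor's note: the modality merge above relies on boolean position masks: image-token positions take the image encoder output, and every other position takes the audio encoder output. A small self-contained illustration follows; the token ID and tensors are made up and are not taken from the commit.

# --- illustrative sketch (not part of the commit) ---
import torch

IMAGE_TOKEN_ID = 200010  # hypothetical placeholder, not the real special-token ID
input_ids = torch.tensor([[1, IMAGE_TOKEN_ID, 2, IMAGE_TOKEN_ID]])
image_out = torch.full((1, 4, 8), 1.0)  # pretend image-encoder hidden states
audio_out = torch.full((1, 4, 8), 2.0)  # pretend audio-encoder hidden states

image_mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1)
merged = image_out * image_mask + audio_out * (~image_mask)
# positions 1 and 3 come from image_out, positions 0 and 2 from audio_out
# --- end of illustrative sketch ---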

@torch.inference_mode()
def forward(self, multimodal_params: List[MultimodalParams],
mm_token_ids: torch.Tensor) -> List[torch.FloatTensor]:
if os.getenv("PHI4_MM_PER_REQUEST_INFER", "0") == "1":
# Reference code path to check correctness of batch inference and further dev.
# (TODO) Remove this path after accuracy bench and data parallelism are supported.
return self._encoding_per_request(multimodal_params, mm_token_ids)
else:
# Batch inference as default path.
return self._encoding_batch_request(multimodal_params, mm_token_ids)

# Create a PreTrainedModel class for transformers=4.53.1 upgrade.
# Core idea is to provide `prepare_inputs_for_generation` method from `GenerationMixin`.
class NewPreTrainedModel(transformers.modeling_utils.PreTrainedModel,
transformers.generation.GenerationMixin):
pass


class Phi4MMInputProcessor(InputProcessor):
@ -396,11 +42,10 @@ class Phi4MMInputProcessor(InputProcessor):
model_config: transformers.PretrainedConfig,
tokenizer: transformers.AutoTokenizer,
trust_remote_code: bool = True):
if not trust_remote_code:
raise ValueError("trust_remote_code must be True for Phi4MM")
assert trust_remote_code, "trust_remote_code must be True for Phi4MM"

self.model_config = model_config
self.device = 'cpu'
self.device = 'cuda'

self.tokenizer = tokenizer
self.use_fast = True
@ -415,18 +60,37 @@ class Phi4MMInputProcessor(InputProcessor):
trust_remote_code=trust_remote_code,
use_fast=self.use_fast)

# Build pure-pytorch model architecture for multimodal encoder.
# Model weights are also loaded here.
OldPreTrainedModel = transformers.modeling_utils.PreTrainedModel
transformers.modeling_utils.PreTrainedModel = NewPreTrainedModel
# TODO: Make separate Phi4VisionEncoder and Phi4AudioEncoder, and move them to LLM-side.
ref_phi4mm_model = transformers.AutoModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True,
# Flash_attn_2 only supports bf16 or fp16 and set in HF config.
torch_dtype='auto',
_attn_implementation='flash_attention_2',
).eval()
transformers.modeling_utils.PreTrainedModel = OldPreTrainedModel
self.phi4mm_modal_encoder = ref_phi4mm_model.model.embed_tokens_extend.to(
self.device)
# Required by Phi4MMImageAudioEmbedding.
# See link: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/modeling_phi4mm.py#L701
self.phi4mm_wte = ref_phi4mm_model.model.embed_tokens.to(self.device)

@torch.inference_mode()
def __call__(
self, inputs: TextPrompt, sampling_params: SamplingParams
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
text_prompt, mm_data = inputs.get("prompt"), inputs.get(
"multi_modal_data", {})
text_prompt, mm_data, mm_processor_kwargs = inputs.get("prompt"), \
inputs.get("multi_modal_data", {}), inputs.get("mm_processor_kwargs", {})
images = mm_data.get("image", None)
audios = mm_data.get("audio", None)

if images is not None:
if isinstance(images[0], torch.Tensor):
# HF Phi4MM can only support PIL images. Convert normalized tensors (0-1) to PIL images (0-255).
# Convert normalized tensors (0-1) to PIL images (0-255).
images = [
Image.fromarray((image.permute(1, 2, 0) * 255).to(
torch.uint8).cpu().numpy()) for image in images
@ -447,16 +111,29 @@ class Phi4MMInputProcessor(InputProcessor):
else:
audio_projection_mode = 'speech'

# Will package inputs for language model forward in AGGREGATE mode.
# Processing with Phi4MMImageAudioEmbedding.
mm_features = self.phi4mm_modal_encoder(
input_ids=inputs['input_ids'],
input_embeds=None,
input_image_embeds=inputs['input_image_embeds'],
input_audio_embeds=inputs['input_audio_embeds'],
image_sizes=inputs['image_sizes'],
image_attention_mask=inputs['image_attention_mask'],
audio_embed_sizes=inputs['audio_embed_sizes'],
audio_attention_mask=inputs['audio_attention_mask'],
audio_projection_mode=audio_projection_mode,
wte=self.phi4mm_wte,
)

# Postprocessing to get multimodal-only embeddings.
image_token_mask = inputs['input_ids'] == _IMAGE_SPECIAL_TOKEN_ID
audio_token_mask = inputs['input_ids'] == _AUDIO_SPECIAL_TOKEN_ID
mm_token_mask = image_token_mask | audio_token_mask
mm_features = mm_features[mm_token_mask]

multimodal_data = {}
multimodal_data['input_ids'] = inputs['input_ids']
multimodal_data['input_image_embeds'] = inputs['input_image_embeds']
multimodal_data['image_sizes'] = inputs['image_sizes']
multimodal_data['image_attention_mask'] = inputs['image_attention_mask']
multimodal_data['input_audio_embeds'] = inputs['input_audio_embeds']
multimodal_data['audio_embed_sizes'] = inputs['audio_embed_sizes']
multimodal_data['audio_attention_mask'] = inputs['audio_attention_mask']
multimodal_data['audio_projection_mode'] = audio_projection_mode
multimodal_data["multimodal_embedding"] = mm_features

return inputs['input_ids'][0].to(torch.int32).tolist(), {
"multimodal_data": multimodal_data,
}

@ -477,11 +154,10 @@ class Phi4MMInputProcessor(InputProcessor):
class Phi4MMForCausalLM(transformers.PreTrainedModel):

_supports_flash_attn_2 = True
MM_TOKEN_IDS = torch.tensor(
[_IMAGE_SPECIAL_TOKEN_ID, _AUDIO_SPECIAL_TOKEN_ID])

def __init__(self, model_config: ModelConfig):
if _is_disagg():
raise ValueError(
"Phi4MM does not support disaggregated inference yet.")

config = model_config.pretrained_config
super().__init__(config)
@ -490,15 +166,6 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
if hasattr(self, "llm"):
return

if not _is_disagg():
_load_phi4mm_classes(config._name_or_path)

# Setup HFPhi4MultimodalEncoder in AGGREGATE mode.
self.hf_phi4mm_model = HFPhi4MultimodalEncoder(config).eval()
self.hf_phi4mm_model.to(config.torch_dtype)
# Required by HFPhi4MultimodalEncoder.
self.phi4mm_wte = self.hf_phi4mm_model.embed_tokens

# We use Phi3ForCausalLM as the language model.
llm_model_config = copy.deepcopy(model_config)
llm_model_config.pretrained_config.architectures = ["Phi3ForCausalLM"]
@ -512,18 +179,6 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
self.is_loaded = True

def load_weights(self, weights):
# Load weights into HFPhi4MultimodalEncoder.
if not _is_disagg():
filtered_weights = {}
for k, v in weights.items():
if k.startswith("model.embed_tokens."):
new_k = k.replace("model.embed_tokens.", "embed_tokens.")
filtered_weights[new_k] = v
elif k.startswith("model.embed_tokens_extend."):
new_k = k.replace("model.embed_tokens_extend.", "")
filtered_weights[new_k] = v
self.hf_phi4mm_model.load_state_dict(filtered_weights, strict=True)

# Filter out non-language model weights.
weights = {
k: v
@ -542,12 +197,8 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
else:
updated_weights[k] = weights[k]
weights = updated_weights
self.llm.load_weights(weights)

# Move mm_token_ids to the correct device.
self.mm_token_ids = torch.tensor(
[_IMAGE_SPECIAL_TOKEN_ID, _AUDIO_SPECIAL_TOKEN_ID],
device=self.device)
self.llm.load_weights(weights)

def infer_max_seq_len(self) -> int:
return self.llm.infer_max_seq_len()
@ -576,24 +227,17 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
)

multimodal_params = kwargs.get("multimodal_params", [])
mm_embedding = []
mm_embeds = []
if len(multimodal_params) > 0:
if not _is_disagg():
# Forward the multimodal data to HFPhi4MultimodalEncoder in AGGREGATE mode.
mm_embedding = self.hf_phi4mm_model(multimodal_params,
self.mm_token_ids)
else:
# Directly fetch the multimodal embedding for DISAGG mode.
# This path is not functional now. `multimodal_params` will be prepared in PyExecutor.
mm_embedding = [
multimodal_param.multimodal_data["multimodal_embedding"]
for multimodal_param in multimodal_params
]
mm_embeds = [
multimodal_param.multimodal_data["multimodal_embedding"]
for multimodal_param in multimodal_params
]
input_ids, input_embeds = fuse_input_embeds(
self.llm.model.embed_tokens,
input_ids,
mm_embedding,
mm_token_ids=self.mm_token_ids,
mm_embeds,
mm_token_ids=self.MM_TOKEN_IDS,
)

output_prob = self.llm.forward(

@ -1,4 +1,4 @@
from typing import Any, Dict, Generic, Optional, Tuple
from typing import Dict, Generic, Optional, Tuple

import torch
from torch import nn
@ -293,18 +293,6 @@ class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel, LlamaConfig]):
if self.load_lm_head_from_target:
self.lm_head = target_model.lm_head

# TODO: should input/position IDs be included in this? Keeping it implicit
# for now since the shapes/dtypes are the same across all models we have.
def get_warmup_extra_inputs(self, batch_size: int,
num_tokens: int) -> Dict[str, Any]:

hidden_states = torch.empty(batch_size * num_tokens,
self.model.hidden_size,
dtype=self.model.dtype,
device='cuda')

return {'hidden_states': hidden_states}

def apply_eagle3_fc(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Hack for eagle3. We might need to run a matmul to reduce
@ -650,9 +650,12 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
load_weights_vanilla_helper(module, weights)

scale_name = self._get_scale_name(weights)
weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
module.tp_rank,
module.tp_mode).squeeze()
full_weight_scale = weights[0][scale_name]
# modelopt fp8_pb_wo can have 2 extra singleton dimensions
if full_weight_scale.dim() == 4:
full_weight_scale = full_weight_scale.squeeze(1).squeeze(-1)
weight_scale = load_weight_shard(full_weight_scale, module.tp_size,
module.tp_rank, module.tp_mode)
copy_weight(module.weight_scale, weight_scale)
if "input_scale" in weights[0]:
copy_weight(module.input_scale, weights[0]["input_scale"])
@ -665,13 +668,23 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
fused_weight = torch.cat((q_weight, k_weight, v_weight))

scale_name = self._get_scale_name(weights)
q_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
full_q_scale = weights[0][scale_name]
full_k_scale = weights[1][scale_name]
full_v_scale = weights[2][scale_name]
# modelopt fp8_pb_wo can have 2 extra singleton dimensions
if full_q_scale.dim() == 4:
full_q_scale = full_q_scale.squeeze(1).squeeze(-1)
if full_k_scale.dim() == 4:
full_k_scale = full_k_scale.squeeze(1).squeeze(-1)
if full_v_scale.dim() == 4:
full_v_scale = full_v_scale.squeeze(1).squeeze(-1)
q_scale = load_weight_shard(full_q_scale, module.tp_size,
module.tp_rank, module.tp_mode)
k_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
k_scale = load_weight_shard(full_k_scale, module.tp_size,
module.tp_rank, module.tp_mode)
v_scale = load_weight_shard(weights[2][scale_name], module.tp_size,
v_scale = load_weight_shard(full_v_scale, module.tp_size,
module.tp_rank, module.tp_mode)
fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)).squeeze()
fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))

copy_weight(module.weight, fused_weight)
copy_weight(module.weight_scale, fused_fp8_block_scale)
@ -683,11 +696,18 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
fused_weight = torch.cat((gate_weight, up_weight))

scale_name = self._get_scale_name(weights)
left_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
full_left_scale = weights[0][scale_name]
full_right_scale = weights[1][scale_name]
# modelopt fp8_pb_wo can have 2 extra singleton dimensions
if full_left_scale.dim() == 4:
full_left_scale = full_left_scale.squeeze(1).squeeze(-1)
if full_right_scale.dim() == 4:
full_right_scale = full_right_scale.squeeze(1).squeeze(-1)
left_scale = load_weight_shard(full_left_scale, module.tp_size,
module.tp_rank, module.tp_mode)
right_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
right_scale = load_weight_shard(full_right_scale, module.tp_size,
module.tp_rank, module.tp_mode)
fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
fused_scale = torch.cat([left_scale, right_scale], dim=0)
copy_weight(module.weight, fused_weight)
copy_weight(module.weight_scale, fused_scale)
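Editor's note: the new branches all normalize modelopt fp8_pb_wo block scales, which may arrive as 4-D tensors with singleton dimensions at position 1 and at the end, down to 2-D before sharding and fusing. A hedged sketch of just that shape handling, using dummy tensors rather than real checkpoint data:

# --- illustrative sketch (not part of the commit) ---
import torch

def normalize_block_scale(scale: torch.Tensor) -> torch.Tensor:
    # e.g. [n_blocks, 1, k_blocks, 1] -> [n_blocks, k_blocks]; already-2D scales pass through
    if scale.dim() == 4:
        scale = scale.squeeze(1).squeeze(-1)
    return scale

q = torch.ones(4, 1, 2, 1)
k = torch.ones(2, 1, 2, 1)
v = torch.ones(2, 1, 2, 1)
fused = torch.cat([normalize_block_scale(s) for s in (q, k, v)])  # shape [8, 2]
# --- end of illustrative sketch ---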


@ -513,7 +513,8 @@ def create_py_executor_instance(
guided_decoder: Optional[GuidedDecoder] = None,
lora_config: Optional[LoraConfig] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
kv_connector_manager: Optional[KvCacheConnectorManager] = None
kv_connector_manager: Optional[KvCacheConnectorManager] = None,
max_seq_len: Optional[int] = None,
) -> PyExecutor:
kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)

@ -662,7 +663,8 @@ def create_py_executor_instance(
guided_decoder=guided_decoder,
start_worker=start_worker,
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
kv_connector_manager=kv_connector_manager)
kv_connector_manager=kv_connector_manager,
max_seq_len=max_seq_len)


def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
@ -10,7 +10,7 @@ import traceback
import weakref
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch._dynamo.config
@ -21,6 +21,7 @@ from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import \
from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
from tensorrt_llm._torch.speculative import (
get_num_extra_kv_tokens, update_spec_config_from_model_config)
from tensorrt_llm._torch.speculative.drafting_loops import ChainDrafter
from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
str_dtype_to_torch, torch_dtype_to_str,
@ -276,6 +277,8 @@ class PyTorchModelEngine(ModelEngine):
spec_config: Optional["DecodingBaseConfig"] = None,
lora_config: Optional[LoraConfig] = None,
is_draft_model: bool = False,
drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
torch.nn.Module]] = None,
):
self.ub_buffers = None
self.batch_size = batch_size
@ -311,7 +314,8 @@ class PyTorchModelEngine(ModelEngine):
max_num_tokens=max_num_tokens,
moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens,
moe_load_balancer=pytorch_backend_config.moe_load_balancer,
lora_config=lora_config)
lora_config=lora_config,
drafting_loop_wrapper=drafting_loop_wrapper)
# In case that some tests use stub models and override `_load_model`.
if not hasattr(self.model, 'extra_attrs'):
self.model.extra_attrs = {}
@ -403,7 +407,7 @@ class PyTorchModelEngine(ModelEngine):
dtype=torch.int,
device='cuda')
self.without_logits = self.spec_config.spec_dec_mode.without_logits(
)
) or self.model_is_wrapped
self.max_draft_len = spec_config.max_draft_len
else:
self.without_logits = False
@ -562,6 +566,15 @@ class PyTorchModelEngine(ModelEngine):
# Reset the global cuda graph dummy request to None in warmup.
self.cuda_graph_runner.padding_dummy_request = None

def get_num_extra_decoding_steps():
if isinstance(self.model, ChainDrafter):
return self.model.max_draft_len
else:
assert not self.model_is_wrapped, (
f"Please add logic to determine num_extra_decoding_steps for drafting loop {type(self.model)}"
)
return 0

def get_cuda_graph_warmup_request(batch_size, draft_len):
# Divide by max_beam_width to get an approximation of the number of requests that can be run in parallel.
available_blocks = kv_cache_manager.get_num_free_blocks(
@ -569,6 +582,8 @@ class PyTorchModelEngine(ModelEngine):
if available_blocks >= batch_size:
result = ScheduledRequests()
result.context_requests = []
num_extra_decoding_steps = get_num_extra_decoding_steps()

# Add (batch_size - 1) dummy requests with seq_len=1.
# Should only need one more page per request.
requests = kv_cache_manager.add_dummy_requests(
@ -576,7 +591,8 @@ class PyTorchModelEngine(ModelEngine):
is_gen=True,
max_num_draft_tokens=draft_len,
use_mrope=use_mrope,
max_beam_width=self.max_beam_width)
max_beam_width=self.max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps)
# Divide by max_beam_width to get an approximation of the number of tokens that can be added to the final request.
available_tokens = kv_cache_manager.get_num_available_tokens(
draft_len)
@ -592,13 +608,20 @@ class PyTorchModelEngine(ModelEngine):
if max_position_embeddings is not None:
token_num = min(token_num,
max_position_embeddings - draft_len)

assert token_num > num_extra_decoding_steps, (
"Cannot fuse drafting loop. We do not have enough KV cache space "
"for all of the draft tokens.")
token_num -= num_extra_decoding_steps

max_seq_len_request = kv_cache_manager.add_dummy_requests(
request_ids=[batch_size - 1],
token_nums=[token_num],
is_gen=True,
max_num_draft_tokens=draft_len,
use_mrope=use_mrope,
max_beam_width=self.max_beam_width)[0]
max_beam_width=self.max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps)[0]
# Add the longest request before all other seq_len=1 request to simulate the padding CUDA graph case.
# This batch contains both the longest request and the shortest requests,
# it also contains the maximum number of requests and the maximum token number,
@ -620,6 +643,13 @@ class PyTorchModelEngine(ModelEngine):
if num_tokens > self.max_num_tokens or num_tokens > available_tokens:
return None

num_extra_decoding_steps = get_num_extra_decoding_steps()
if num_extra_decoding_steps > 0:
# Disable autotuning for fused drafting loops for now.
# There are a few bugs that can cause illegal memory accesses
# during warmup.
return None

num_ctx_tokens = num_tokens - num_gen_tokens
num_ctx_requests = 0
ctx_requests = []
@ -905,6 +935,8 @@ class PyTorchModelEngine(ModelEngine):
moe_max_num_tokens: Optional[int] = None,
moe_load_balancer: Optional[MoeLoadBalancerConfig] = None,
lora_config: Optional[LoraConfig] = None,
drafting_loop_wrapper: Optional[Callable[
[torch.nn.Module], torch.nn.Module]] = None,
**kwargs) -> DecoderModelForCausalLM:
config = checkpoint_loader.load_config(
checkpoint_dir,
@ -1008,6 +1040,13 @@ class PyTorchModelEngine(ModelEngine):
logger.info("moe_load_balancer finalize model done")

torch.cuda.current_stream().synchronize()

if drafting_loop_wrapper is not None:
model = drafting_loop_wrapper(model)
self.model_is_wrapped = True
else:
self.model_is_wrapped = False

return model

def _call_load_weights(self, load_method, weights, weight_mapper):
@ -139,25 +139,25 @@ class BatchStatePP(BatchState):

class PyExecutor:

def __init__(
self,
resource_manager,
scheduler: RequestScheduler,
model_engine: ModelEngine,
sampler: Sampler,
dist: Distributed,
max_num_sequences: int,
drafter: Optional[Drafter] = None,
disable_overlap_scheduler: bool = False,
max_input_len: int = 2048,
max_batch_size: int = 8,
max_beam_width: int = 1,
max_draft_len: int = 0,
kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
guided_decoder: Optional[GuidedDecoder] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
start_worker: bool = True,
kv_connector_manager: Optional[KvCacheConnectorManager] = None):
def __init__(self,
resource_manager,
scheduler: RequestScheduler,
model_engine: ModelEngine,
sampler: Sampler,
dist: Distributed,
max_num_sequences: int,
drafter: Optional[Drafter] = None,
disable_overlap_scheduler: bool = False,
max_input_len: int = 2048,
max_batch_size: int = 8,
max_beam_width: int = 1,
max_draft_len: int = 0,
kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
guided_decoder: Optional[GuidedDecoder] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
start_worker: bool = True,
kv_connector_manager: Optional[KvCacheConnectorManager] = None,
max_seq_len: Optional[int] = None):
super(PyExecutor, self).__init__()
self.device_id = torch.cuda.current_device()
self.global_rank = global_mpi_rank()
@ -271,6 +271,7 @@ class PyExecutor:
)
self.draft_seq_slot_manager = SeqSlotManager(max_num_sequences)
self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold
self.max_seq_len = max_seq_len

self.worker_started = False
self.worker_lock = threading.Lock()
@ -14,9 +14,12 @@ from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.bindings.executor import (CapacitySchedulerPolicy,
ContextChunkingPolicy,
ExecutorConfig)
ExecutorConfig,
LogitsPostProcessorConfig,
ParallelConfig)
from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig, TorchLlmArgs
from tensorrt_llm.llmapi.tokenizer import TokenizerBase
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.mapping import Mapping
@ -209,12 +212,21 @@ def _get_mapping(executor_config: ExecutorConfig) -> Mapping:


def create_py_executor(
executor_config: ExecutorConfig,
checkpoint_dir: str = None,
lora_config: Optional[LoraConfig] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
kv_connector_config: Optional[KvCacheConnectorConfig] = None
llm_args: TorchLlmArgs,
checkpoint_dir: str = None,
tokenizer: Optional[TokenizerBase] = None,
lora_config: Optional[LoraConfig] = None,
kv_connector_config: Optional[KvCacheConnectorConfig] = None,
logits_post_processor_config: Optional[LogitsPostProcessorConfig] = None,
parallel_config: Optional[ParallelConfig] = None,
) -> PyExecutor:

executor_config = llm_args.get_executor_config(checkpoint_dir, tokenizer)
executor_config.logits_post_processor_config = logits_post_processor_config
executor_config.parallel_config = parallel_config

garbage_collection_gen0_threshold = llm_args.garbage_collection_gen0_threshold

_mangle_executor_config(executor_config)
pytorch_backend_config = executor_config.pytorch_backend_config

@ -260,13 +272,29 @@ def create_py_executor(
with mem_monitor.observe_creation_stage(
_ExecutorCreationStage.MODEL_ENGINE_DRAFT):
draft_spec_config = copy.copy(spec_config)
draft_pytorch_backend_config = copy.copy(pytorch_backend_config)
if spec_config.load_format == "dummy":
draft_pytorch_backend_config.load_format = LoadFormat.DUMMY
# The draft model won't have any draft tokens attached to
# generation requests when we invoke it autoregressively
draft_spec_config.max_draft_len = 0

use_chain_drafter = (
executor_config.guided_decoding_config is None
and not pytorch_backend_config.enable_mixed_sampler
and pytorch_backend_config.attn_backend == "TRTLLM")

if use_chain_drafter:

def drafting_loop_wrapper(model):
from tensorrt_llm._torch.speculative.drafting_loops import \
ChainDrafter

return ChainDrafter(spec_config.max_draft_len, model)
else:
drafting_loop_wrapper = None

draft_pytorch_backend_config = copy.copy(pytorch_backend_config)
if spec_config.load_format == "dummy":
draft_pytorch_backend_config.load_format = LoadFormat.DUMMY

draft_model_engine = PyTorchModelEngine(
model_path=spec_config.speculative_model_dir,
pytorch_backend_config=draft_pytorch_backend_config,
@ -282,6 +310,7 @@ def create_py_executor(
spec_config=draft_spec_config,
checkpoint_loader=executor_config.checkpoint_loader,
is_draft_model=True,
drafting_loop_wrapper=drafting_loop_wrapper,
)
draft_model_engine.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER
draft_model_engine.load_weights_from_target_model(
@ -484,6 +513,7 @@ def create_py_executor(
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
kv_connector_manager=kv_connector_manager
if not estimating_kv_cache else None,
max_seq_len=executor_config.max_seq_len,
)

if estimating_kv_cache:
@ -528,6 +558,7 @@ def create_py_executor(
garbage_collection_gen0_threshold=
garbage_collection_gen0_threshold,
kv_connector_manager=kv_connector_manager,
max_seq_len=executor_config.max_seq_len,
)

_adjust_torch_mem_fraction(executor_config.pytorch_backend_config)
@ -469,6 +469,11 @@ class KVCacheManager(BaseResourceManager):
max_num_draft_tokens: int = 0,
use_mrope: bool = False,
max_beam_width: int = 1,
# For capturable drafting loops. During normal inference, the draft model always
# has enough KV cache space to fit all of our draft tokens. During warmup, however,
# we need to make the KV cache manager aware that multiple autoregressive steps will
# occur.
num_extra_decoding_steps: int = 0,
):
beam_width = max_beam_width
requests = []
@ -502,6 +507,10 @@ class KVCacheManager(BaseResourceManager):
self.impl.add_sequence(req_id, token_num, beam_width, req)
for _ in range(self.num_extra_kv_tokens):
self.impl.add_token(req_id)

for _ in range(num_extra_decoding_steps):
self.impl.add_token(req_id)

if is_gen:
req.state = LlmRequestState.GENERATION_IN_PROGRESS
req.prompt_len = token_num - 1
@ -510,6 +519,7 @@ class KVCacheManager(BaseResourceManager):
if prepare_resource:
for _ in range(max_num_draft_tokens):
self.impl.add_token(req_id)

requests.append(req)
return requests
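Editor's note: the extra add_token loop above reserves one KV-cache slot per autoregressive step of a fused drafting loop during warmup. The toy counter below only mirrors that bookkeeping; the ToyKVCache class is hypothetical and is not the real manager API.

# --- illustrative sketch (not part of the commit) ---
class ToyKVCache:
    """Hypothetical stand-in that only tracks how many KV slots a request holds."""
    def __init__(self):
        self.tokens = {}
    def add_sequence(self, req_id, token_num):
        self.tokens[req_id] = token_num
    def add_token(self, req_id):
        self.tokens[req_id] += 1

cache = ToyKVCache()
cache.add_sequence(req_id=0, token_num=1)
for _ in range(3):           # num_extra_decoding_steps for a 3-step drafting loop
    cache.add_token(req_id=0)
assert cache.tokens[0] == 4  # prompt slot + one slot per extra decoding step
# --- end of illustrative sketch ---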

150
tensorrt_llm/_torch/speculative/drafting_loops.py
Normal file
@ -0,0 +1,150 @@
"""
This module contains capturable drafting loops for speculative decoding.

These are torch modules that wrap another draft model. The wrapped module
is supposed to invoke the draft model autoregressively and invoke
a sampling algorithm to obtain draft tokens. By structuring the code
like this, we are able to avoid host overhead: the entire drafting process
for speculation can be launched as a single CUDA graph.
"""

from contextlib import contextmanager

import torch

from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata
from tensorrt_llm._torch.speculative.eagle3 import Eagle3SpecMetadata
from tensorrt_llm._torch.speculative.interface import SpecMetadata


@contextmanager
def save_metadata_state(attn_metadata: AttentionMetadata,
spec_metadata: SpecMetadata) -> None:
batch_size = attn_metadata.num_seqs

if attn_metadata.is_cuda_graph:
seq_len = attn_metadata._seq_lens[:batch_size].clone()
seq_len_cuda = attn_metadata._seq_lens_cuda[:batch_size].clone()
kv_lens = attn_metadata.kv_lens_cuda.clone()

assert spec_metadata.is_cuda_graph
num_tokens = spec_metadata.num_tokens
if isinstance(spec_metadata, Eagle3SpecMetadata):
read_indices = spec_metadata.hidden_states_read_indices[:
batch_size].clone(
)
write_indices = spec_metadata.hidden_states_write_indices[:
batch_size].clone(
)

try:
yield
finally:
if attn_metadata.is_cuda_graph:
attn_metadata._seq_lens[:batch_size].copy_(seq_len[:batch_size])
attn_metadata._seq_lens_cuda[:batch_size].copy_(
seq_len_cuda[:batch_size])
attn_metadata.kv_lens_cuda[:batch_size].copy_(kv_lens[:batch_size])

spec_metadata.num_tokens = num_tokens
if isinstance(spec_metadata, Eagle3SpecMetadata):
spec_metadata.hidden_states_read_indices[:batch_size].copy_(
read_indices)
spec_metadata.hidden_states_write_indices[:batch_size].copy_(
write_indices)

# This restore has to happen even if the spec_metadata is not being used
# for CUDA graphs. It won't be reset by spec_metadata.prepare().
if isinstance(spec_metadata, Eagle3SpecMetadata):
spec_metadata.is_first_draft = True
spec_metadata.eagle3_resource_manager.is_first_draft = True


def prepare_for_generation(attn_metadata: AttentionMetadata,
spec_metadata: SpecMetadata,
last_tokens_idx: torch.Tensor) -> None:
batch_size = attn_metadata.num_seqs
attn_metadata._seq_lens[:batch_size].fill_(1)
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
attn_metadata.on_update()
attn_metadata.kv_lens_cuda[:batch_size] += 1

attn_metadata.host_request_types[:attn_metadata.num_contexts].fill_(1)
attn_metadata.num_contexts = 0

spec_metadata.num_tokens = batch_size

if isinstance(spec_metadata, Eagle3SpecMetadata):
spec_metadata.eagle3_resource_manager.is_first_draft = False
spec_metadata.is_first_draft = False

old_write_indices = spec_metadata.hidden_states_write_indices

spec_metadata.hidden_states_read_indices[:batch_size].copy_(
old_write_indices[last_tokens_idx])
spec_metadata.hidden_states_write_indices[:batch_size].copy_(
torch.arange(
batch_size,
dtype=spec_metadata.hidden_states_write_indices.dtype,
device=spec_metadata.hidden_states_write_indices.device))


class ChainDrafter(torch.nn.Module):

def __init__(self, max_draft_len: int, draft_model: torch.nn.Module):
super().__init__()
self.draft_model = draft_model
self.config = self.draft_model.config
self.model_config = self.draft_model.model_config
self.max_draft_len = max_draft_len

def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor,
attn_metadata: AttentionMetadata,
spec_metadata: SpecMetadata, **kwargs) -> None:

logits = self.draft_model.forward(input_ids=input_ids,
position_ids=position_ids,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata)

new_draft_tokens = [self.sample(logits)]

with save_metadata_state(attn_metadata, spec_metadata):
batch_size = attn_metadata.num_seqs
last_tokens_idx = torch.cumsum(
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
new_position_ids = position_ids[0, last_tokens_idx] + 1

prepare_for_generation(attn_metadata, spec_metadata,
last_tokens_idx)

for i in range(self.max_draft_len - 1):
logits = self.draft_model.forward(
input_ids=new_draft_tokens[-1],
position_ids=new_position_ids,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata)
new_draft_tokens.append(self.sample(logits))
new_position_ids += 1
attn_metadata.kv_lens_cuda[:batch_size] += 1
if i == 0 and isinstance(spec_metadata, Eagle3SpecMetadata):
spec_metadata.hidden_states_read_indices[:batch_size].copy_(
spec_metadata.hidden_states_write_indices[:batch_size])

return torch.stack(new_draft_tokens)

def sample(self, logits: torch.Tensor) -> torch.Tensor:
# TODO: inject the sampler here so we can support non-greedy
tokens = torch.argmax(logits, dim=-1)
if hasattr(self.draft_model.model, "d2t"):
d2t = self.draft_model.model.d2t.data
return tokens + d2t[tokens]

return tokens

def load_weights_from_target_model(self,
target_model: torch.nn.Module) -> None:
loader = getattr(self.draft_model, "load_weights_from_target_model",
None)
if callable(loader):
self.draft_model.load_weights_from_target_model(target_model)
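Editor's note: ChainDrafter.sample combines a greedy argmax with the Eagle3 d2t table, which maps the draft model's compact vocabulary indices back to target-model token IDs by adding a per-token offset. A tiny illustration with a made-up offset table (values are invented):

# --- illustrative sketch (not part of the commit) ---
import torch

logits = torch.tensor([[0.1, 2.0, 0.3],
                       [1.5, 0.2, 0.1]])      # [batch, draft_vocab]
d2t = torch.tensor([100, 250, 407])           # hypothetical draft->target offsets

draft_tokens = torch.argmax(logits, dim=-1)       # tensor([1, 0])
target_tokens = draft_tokens + d2t[draft_tokens]  # tensor([251, 100])
# --- end of illustrative sketch ---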

@ -71,6 +71,12 @@ class ModelDrafter(Drafter):
self._request_draft_logits = sampler.enable_mixed_sampler
self.guided_decoder = guided_decoder

self.use_static_draft_loop = draft_model_engine.model_is_wrapped
if self.use_static_draft_loop:
# TODO: enable sampling/guided decoding on static draft loop
assert guided_decoder is None
assert not sampler.enable_mixed_sampler

def _create_draft_request(self, request: LlmRequest,
input_tokens: Optional[List]) -> LlmRequest:
"""Create a draft request with common parameters."""
@ -236,6 +242,8 @@ class ModelDrafter(Drafter):
"""Check if CUDA graph should be disabled for the current forward pass."""
if previous_batch is not None:
return False
if self.use_static_draft_loop:
return False
return self.spec_config.spec_dec_mode.needs_kv_cache_recompute()

def _forward_draft_model(
@ -255,8 +263,10 @@ class ModelDrafter(Drafter):
resource_manager,
new_tensors_device=new_tensors_device)

# Handle d2t data if available
if hasattr(self.draft_model_engine.model.model, 'd2t'):
# Handle d2t data if available. Static drafting loops should incorporate d2t
# in their implementations.
if not self.use_static_draft_loop and hasattr(
self.draft_model_engine.model.model, 'd2t'):
outputs['d2t'] = self.draft_model_engine.model.model.d2t.data

return outputs
@ -365,8 +375,29 @@ class ModelDrafter(Drafter):
for req in scheduled_requests.all_requests()
}

# Initial forward pass
# Initial forward pass. May do the complete drafting loop
# if use_static_draft_loop is set.
outputs = self._forward_draft_model(draft_batch, resource_manager)

if self.use_static_draft_loop:
outputs_host = outputs.cpu()
for token_idx in range(self.max_draft_tokens):
for req_idx, req in enumerate(draft_batch.all_requests()):
target_model_req = req_id_to_old_request[
req.py_request_id]
if target_model_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
# Chunked prefill request in progress; no need to append draft tokens
continue

target_req = req_id_to_old_request[req.py_request_id]
target_req.py_draft_tokens.append(
outputs_host[token_idx][req_idx])

for req in draft_batch.all_requests():
self.draft_seq_slot_manager.free_resources(req)

return
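Editor's note: with the static draft loop, the wrapped engine returns a [max_draft_tokens, batch] tensor in one call, and the nested loops above scatter row token_idx, column req_idx into each target request's py_draft_tokens. A simplified sketch of that scatter using plain lists in place of request objects:

# --- illustrative sketch (not part of the commit) ---
import torch

outputs_host = torch.tensor([[11, 21],
                             [12, 22],
                             [13, 23]])  # [max_draft_tokens=3, batch=2]
draft_tokens_per_request = [[] for _ in range(outputs_host.shape[1])]

for token_idx in range(outputs_host.shape[0]):
    for req_idx in range(outputs_host.shape[1]):
        draft_tokens_per_request[req_idx].append(int(outputs_host[token_idx][req_idx]))

# request 0 gets [11, 12, 13]; request 1 gets [21, 22, 23]
# --- end of illustrative sketch ---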

self._execute_guided_decoder(draft_batch,
outputs['logits'],
d2t=outputs.get('d2t'))

@ -21,7 +21,7 @@ from .._utils import mpi_world_size
from ..bindings import executor as tllm
from ..builder import Engine
from ..disaggregated_params import DisaggregatedParams
from ..llmapi.llm_args import KvCacheConnectorConfig, TorchLlmArgs
from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig
from ..llmapi.llm_utils import KvCacheRetentionConfig
from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available,
need_spawn_mpi_workers)
@ -359,7 +359,7 @@ class GenerationExecutor(ABC):
kv_connector_config: Optional[KvCacheConnectorConfig] = None,
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[TorchLlmArgs] = None,
llm_args: Optional[BaseLlmArgs] = None,
) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
# local imports to avoid cyclic importing
from .proxy import GenerationExecutorProxy

@ -18,7 +18,8 @@ from .._utils import (KVCacheEventSerializer, global_mpi_rank, global_mpi_size,
mpi_comm, mpi_rank, nvtx_range_debug)
from ..bindings import executor as tllm
from ..builder import ConfigEncoder, Engine, EngineConfig
from ..llmapi.llm_args import KvCacheConnectorConfig, PybindMirror, TorchLlmArgs
from ..llmapi.llm_args import (BaseLlmArgs, KvCacheConnectorConfig,
PybindMirror, TorchLlmArgs)
from ..llmapi.mpi_session import set_mpi_session_cpp
from ..llmapi.tokenizer import TokenizerBase
from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
@ -64,7 +65,7 @@ class GenerationExecutorWorker(GenerationExecutor):
kv_connector_config: Optional[KvCacheConnectorConfig] = None,
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[TorchLlmArgs] = None,
llm_args: Optional[BaseLlmArgs] = None,
) -> None:
postproc_config = postproc_worker_config or PostprocWorkerConfig()
super().__init__(
@ -107,40 +108,55 @@ class GenerationExecutorWorker(GenerationExecutor):
device_ids = mpi_comm().allgather(device_id)
return comm_ranks, device_ids

def _create_py_executor(executor_config):
assert executor_config is None, "expect an empty executor_config is _create_py_executor"
executor_config = llm_args.get_executor_config(
hf_model_dir, tokenizer)
# Persist so downstream code (e.g., default max_tokens deduction) has access
self._executor_config = executor_config
executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
processor_batched=batched_logits_processor, replicate=False)
comm_ranks, device_ids = _get_comm_ranks_device_id()
executor_config.parallel_config = tllm.ParallelConfig(
participant_ids=comm_ranks, device_ids=device_ids)
args = {
"executor_config": executor_config,
"checkpoint_dir": executor_config.hf_model_dir,
}
def _create_py_executor():
args = {}
assert hasattr(
executor_config, "backend"
), "executor_config should be with backend in _create_py_executor"
if executor_config.backend == "pytorch":
self.llm_args, "backend"
), "llm_args should be with backend in _create_py_executor"
if self.llm_args.backend == "pytorch":
from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
create_py_executor
create_executor = create_py_executor
args["llm_args"] = self.llm_args
args["checkpoint_dir"] = hf_model_dir
args["tokenizer"] = tokenizer
args["lora_config"] = lora_config
args[
"garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
args["kv_connector_config"] = kv_connector_config
elif executor_config.backend == "_autodeploy":
args[
"logits_post_processor_config"] = tllm.LogitsPostProcessorConfig(
processor_batched=batched_logits_processor,
replicate=False)
comm_ranks, device_ids = _get_comm_ranks_device_id()
args["parallel_config"] = tllm.ParallelConfig(
participant_ids=comm_ranks, device_ids=device_ids)
elif self.llm_args.backend == "_autodeploy":
from tensorrt_llm._torch.auto_deploy.llm_args import \
LlmArgs as ADLlmArgs
from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
create_autodeploy_executor
create_executor = create_autodeploy_executor
assert isinstance(self.llm_args, ADLlmArgs)
args["ad_config"] = self.llm_args.get_pytorch_backend_config()
else:
raise ValueError(
f"Unsupported backend config: {executor_config.backend}")
return create_executor(**args)
f"Unsupported backend config: {self.llm_args.backend}")

# Define additional attributes that can be used later, such as in _deduce_max_tokens
self.mapping = self.llm_args.parallel_config.to_mapping()
self.checkpoint_loader = None
if self.llm_args.backend == "pytorch":
from tensorrt_llm._torch.pyexecutor.config import \
_construct_checkpoint_loader
self.checkpoint_loader = _construct_checkpoint_loader(
self.llm_args.backend, self.llm_args.checkpoint_loader,
self.llm_args.checkpoint_format)

_executor = create_executor(**args)
self.max_seq_len = self.llm_args.max_seq_len
if _executor.max_seq_len is not None:
# max_seq_len might be updated by model engine as in create_py_executor
self.max_seq_len = _executor.max_seq_len
return _executor

def _create_engine(executor_config):
if executor_config is None:
@ -164,8 +180,7 @@ class GenerationExecutorWorker(GenerationExecutor):
executor_config)

self.engine = _create_py_executor(
executor_config) if llm_args is not None else _create_engine(
executor_config)
) if self.llm_args is not None else _create_engine(executor_config)

self._lora_manager: Optional[LoraManager] = None
self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
@ -188,8 +203,9 @@ class GenerationExecutorWorker(GenerationExecutor):
if engine_config.build_config.max_prompt_embedding_table_size > 0:
self._prompt_adapter_manager = PromptAdapterManager()

if getattr(self._executor_config, "backend",
"") == "pytorch" and lora_config is not None:
if self.llm_args and getattr(
self.llm_args, "backend",
"") == "pytorch" and lora_config is not None:
from tensorrt_llm._torch.pyexecutor.resource_manager import \
ResourceManagerType
peft_cache_manager = self.engine.resource_manager.resource_managers.get(
@ -471,26 +487,43 @@ class GenerationExecutorWorker(GenerationExecutor):
assert request.id is not None

def _deduce_max_tokens(request: GenerationRequest,
executor_config: tllm.ExecutorConfig) -> int:
executor_config: tllm.ExecutorConfig,
llm_args: Optional[BaseLlmArgs] = None) -> int:
# deduce max_tokens when it's not set by user
max_tokens = request.sampling_params.max_tokens
query_token_len = len(
request.query_token_ids) if request.query_token_ids else 0
cp_size = 1 if (not hasattr(executor_config, "mapping")
or executor_config.mapping.cp_size
is None) else executor_config.mapping.cp_size
if not hasattr(executor_config, "max_seq_len"):

cp_size = 1
max_seq_len = None
if llm_args is not None:
# deduce max_tokens by llm args
assert executor_config is None, "An empty executor_config in _deduce_max_tokens is expected when LLM arguments are defined."
if hasattr(self,
"mapping") and self.mapping.cp_size is not None:
cp_size = self.mapping.cp_size
max_seq_len = getattr(self, "max_seq_len", None)
else:
# deduce max_tokens by executor config
if hasattr(executor_config, "mapping"
) and executor_config.mapping.cp_size is not None:
cp_size = executor_config.mapping.cp_size
max_seq_len = getattr(executor_config, "max_seq_len", None)
if max_seq_len is None:
logger.warning("`default_max_tokens` cannot be deduced")
if max_tokens is None:
raise ValueError(
"`max_tokens` must be set when `default_max_tokens` cannot be deduced"
)
else:
# use max_tokens if can't deduce default_max_tokens
return max_tokens
splited_prompt_len = int(len(prompt_token_ids) / cp_size)
default_max_tokens = executor_config.max_seq_len - splited_prompt_len - query_token_len
default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
if default_max_tokens <= 0:
logger.warning(
f"`default_max_tokens` ({default_max_tokens}) should be greater than 0, "
f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({executor_config.max_seq_len})"
f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({max_seq_len})"
f" - `splited_prompt_len` ({splited_prompt_len}) - `query_token_len` ({query_token_len})"
)
if max_tokens is None:
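Editor's note: the rewritten deduction above is default_max_tokens = max_seq_len - splited_prompt_len - query_token_len, where the prompt length is first divided by cp_size. A worked example with made-up numbers, shown only to make the arithmetic concrete:

# --- illustrative sketch (not part of the commit) ---
max_seq_len = 4096
prompt_token_ids = list(range(1024))
query_token_len = 0
cp_size = 2

splited_prompt_len = int(len(prompt_token_ids) / cp_size)                # 512
default_max_tokens = max_seq_len - splited_prompt_len - query_token_len  # 3584
# --- end of illustrative sketch ---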
@ -512,7 +545,8 @@ class GenerationExecutorWorker(GenerationExecutor):
executor_request = tllm.Request(
client_id=request.id,
input_token_ids=prompt_token_ids,
max_tokens=_deduce_max_tokens(request, self._executor_config),
max_tokens=_deduce_max_tokens(request, self._executor_config,
self.llm_args),
streaming=request.streaming,
sampling_config=request.sampling_params._get_sampling_config(),
end_id=-1 if request.sampling_params.ignore_eos else
@ -638,11 +672,19 @@ class GenerationExecutorWorker(GenerationExecutor):
self.engine.shutdown()
self.engine = None

if hasattr(
self._executor_config, "checkpoint_loader"
) and self._executor_config.checkpoint_loader is not None:
self._executor_config.checkpoint_loader.cleanup()
self._executor_config.checkpoint_loader = None
if self.llm_args is not None:
assert self._executor_config is None, "An empty executor_config is expected in shutdown when LLM arguments are defined."
if (self.llm_args.backend == "pytorch"
and hasattr(self, "checkpoint_loader")
and self.checkpoint_loader is not None):
self.checkpoint_loader.cleanup()
self.checkpoint_loader = None
else:
if hasattr(
self._executor_config, "checkpoint_loader"
) and self._executor_config.checkpoint_loader is not None:
self._executor_config.checkpoint_loader.cleanup()
self._executor_config.checkpoint_loader = None

# Check if there are any errors from the threads before shutdown.
self._handle_background_error()
@ -689,7 +731,7 @@ def worker_main(
kv_connector_config: Optional[KvCacheConnectorConfig] = None,
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[TorchLlmArgs] = None,
llm_args: Optional[BaseLlmArgs] = None,
) -> None:
mpi_comm().barrier()
print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",

@ -186,7 +186,8 @@ meta-llama/Llama-3.2-3B:
kv_cache_quant_algo: FP8
accuracy: 33.629
meta-llama/Llama-3.3-70B-Instruct:
- spec_dec_algo: Eagle
- quant_algo: FP8
spec_dec_algo: Eagle
accuracy: 33.244
- quant_algo: NVFP4
kv_cache_quant_algo: FP8
@ -216,6 +217,9 @@ mistralai/Mistral-7B-Instruct-v0.3:
accuracy: 31.201
mistralai/Mistral-Small-3.1-24B-Instruct-2503:
- accuracy: 29.20
- quant_algo: FP8
kv_cache_quant_algo: FP8
accuracy: 27.0
mistralai/Mistral-Nemo-12b-Base:
- accuracy: 28.906
- quant_algo: FP8

@ -26,7 +26,11 @@ meta-llama/Llama-4-Maverick-17B-128E-Instruct:
- accuracy: 92.20
- quant_algo: FP8
kv_cache_quant_algo: FP8
accuracy: 90.20
accuracy: 92.20
- quant_algo: FP8
kv_cache_quant_algo: FP8
spec_dec_algo: Eagle
accuracy: 92.20
meta-llama/Llama-4-Scout-17B-16E-Instruct:
- accuracy: 89.70
- quant_algo: NVFP4
@ -176,6 +180,9 @@ mistralai/Ministral-8B-Instruct-2410:
accuracy: 78.35
mistralai/Mistral-Small-3.1-24B-Instruct-2503:
- accuracy: 89.23
- quant_algo: FP8
kv_cache_quant_algo: FP8
accuracy: 89.23
microsoft/Phi-4-multimodal-instruct:
- accuracy: 81.19
microsoft/Phi-4-multimodal-instruct-long-rope:

@ -62,7 +62,8 @@ meta-llama/Llama-3.2-3B:
accuracy: 60.60
meta-llama/Llama-3.3-70B-Instruct:
- accuracy: 81.31
- spec_dec_algo: Eagle
- quant_algo: FP8
spec_dec_algo: Eagle
accuracy: 81.31
- quant_algo: NVFP4
kv_cache_quant_algo: FP8
@ -81,6 +82,9 @@ meta-llama/Llama-4-Maverick-17B-128E-Instruct:
kv_cache_quant_algo: FP8
spec_dec_algo: Eagle
accuracy: 86.40
- quant_algo: FP8
kv_cache_quant_algo: FP8
accuracy: 86.40
meta-llama/Llama-4-Scout-17B-16E-Instruct:
- accuracy: 80.00
- quant_algo: NVFP4
@ -113,6 +117,9 @@ mistralai/Mixtral-8x22B-v0.1:
accuracy: 77.63
mistralai/Mistral-Small-3.1-24B-Instruct-2503:
- accuracy: 81.7
- quant_algo: FP8
kv_cache_quant_algo: FP8
accuracy: 81.1
google/gemma-2-9b-it:
- accuracy: 73.05
google/gemma-3-1b-it:

@ -782,6 +782,7 @@ class TestLlama3_1_8B(CliFlowAccuracyTestHarness):
extra_build_args=extra_build_args)

@skip_pre_hopper
@skip_post_blackwell
@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize(
"gemm_allreduce", [False, pytest.param(True, marks=skip_no_nvls)],

@ -286,7 +286,7 @@ def run_parallel_test(model_name: str,
total_ctx_gpus = ctx_tp * ctx_pp * ctx_instances
total_gen_gpus = gen_tp * gen_pp * gen_instances
if total_ctx_gpus + total_gen_gpus > get_device_count():
pytest.fail(
pytest.skip(
f"Not enough devices for {ctx_instances} ctx instances (ctx_pp={ctx_pp}*ctx_tp={ctx_tp}) + {gen_instances} gen instances (gen_pp={gen_pp}*gen_tp={gen_tp}), total: {total_ctx_gpus + total_gen_gpus}"
)

@ -376,6 +376,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
task.evaluate(llm)

@pytest.mark.skip_less_device(2)
@skip_pre_hopper
def test_ngram(self):
speculative_decoding_config = {
"decoding_type": "NGram",
@ -424,6 +425,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
task.evaluate(llm)

@pytest.mark.skip_less_device(2)
@skip_pre_hopper
@parametrize_with_ids("overlap_scheduler", [True, False])
@parametrize_with_ids("eagle3_one_model", [True, False])
def test_eagle3(self, overlap_scheduler, eagle3_one_model):
@ -581,7 +583,6 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp,
tp, 1, 1, [get_accuracy_task(testset)])

@pytest.mark.skip_less_device(4)
@parametrize_with_ids("ctx_pp", [2, 4])
@parametrize_with_ids("gen_tp", [1, 2])
@pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
@ -592,19 +593,18 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1,
gen_tp, 1, 1, [get_accuracy_task(testset)])

@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
def test_multi_instance(self, testset):
return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, 1, 1, 1, 1,
2, 2, [get_accuracy_task(testset)])


@pytest.mark.skip_less_device_memory(140000)
@pytest.mark.timeout(3600)
class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"

@pytest.mark.skip_less_device_memory(140000)
@pytest.mark.timeout(3600)
@pytest.mark.skip_less_device(8)
@pytest.mark.parametrize("overlap_scheduler", [False, True])
def test_auto_dtype(self, overlap_scheduler):
@ -685,6 +685,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
@parametrize_with_ids("overlap_scheduler", [True, False])
@parametrize_with_ids("mtp_nextn",
[0, pytest.param(2, marks=skip_pre_hopper)])
@pytest.mark.skip_less_device(8)
def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
ctx_server_config = {"disable_overlap_scheduler": True}
gen_server_config = {"disable_overlap_scheduler": not overlap_scheduler}
@ -728,6 +729,7 @@ class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness):
MODEL_NAME = "google/gemma-3-1b-it"
MODEL_PATH = f"{llm_models_root()}/gemma/gemma-3-1b-it/"

@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("overlap_scheduler", [False, True])
def test_auto_dtype(self, overlap_scheduler):
pytest.skip(
@ -817,6 +819,8 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@skip_pre_hopper
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("overlap_scheduler", [False, True])
def test_auto_dtype(self, overlap_scheduler):
ctx_server_config = {

@ -568,25 +568,27 @@ class TestLlama3_3_70BInstruct(LlmapiAccuracyTestHarness):
task.evaluate(llm,
extra_evaluator_kwargs=dict(apply_chat_template=True))

@skip_pre_hopper
@pytest.mark.skip_less_mpi_world_size(8)
@parametrize_with_ids("eagle3_one_model", [True, False])
def test_eagle3_tp8(self, eagle3_one_model):
model_path = f"{llm_models_root()}/llama-3.3-models/Llama-3.3-70B-Instruct"
def test_fp8_eagle3_tp8(self, eagle3_one_model):
model_path = f"{llm_models_root()}/modelopt-hf-model-hub/Llama-3.3-70B-Instruct-fp8"
eagle_model_dir = f"{llm_models_root()}/EAGLE3-LLaMA3.3-Instruct-70B"
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6)
spec_config = EagleDecodingConfig(max_draft_len=4,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=eagle3_one_model)
pytorch_config = dict(disable_overlap_scheduler=True, )
pytorch_config = dict(
disable_overlap_scheduler=True,
cuda_graph_config=CudaGraphConfig(max_batch_size=1))
with LLM(model_path,
max_batch_size=16,
tensor_parallel_size=8,
speculative_config=spec_config,
kv_cache_config=kv_cache_config,
**pytorch_config) as llm:
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.skip_less_device(4)
@skip_pre_hopper
@ -911,8 +913,24 @@ class TestMistralSmall24B(LlmapiAccuracyTestHarness):
MODEL_NAME = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
MODEL_PATH = f"{llm_models_root()}/Mistral-Small-3.1-24B-Instruct-2503"

@pytest.mark.skip_less_device_memory(80000)
def test_auto_dtype(self):
with LLM(self.MODEL_PATH) as llm:
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
with LLM(self.MODEL_PATH, kv_cache_config=kv_cache_config) as llm:
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@skip_pre_ada
@pytest.mark.skip_less_device_memory(80000)
def test_fp8(self):
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
model_path = f"{llm_models_root()}/Mistral-Small-3.1-24B-Instruct-2503-fp8"
with LLM(model_path, kv_cache_config=kv_cache_config) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
@ -2804,8 +2822,6 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
extra_evaluator_kwargs = {
"fewshot_as_multiturn": True,
"apply_chat_template": True,
"scores_filter": "exact_match,flexible-extract",
"MAX_OUTPUT_LEN": 8192
}

MODEL_PATH = f"{llm_models_root()}/gpt_oss/gpt-oss-120b"
@ -2819,7 +2835,9 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
(True, True),
])
def test_w4_1gpu(self, moe_backend, cuda_graph, overlap_scheduler, mocker):
pytest.skip("https://nvbugs/5481087")
mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
mocker.patch.dict(GSM8K.EVALUATE_KWARGS,
{"scores_filter": "exact_match,flexible-extract"})
if moe_backend == "TRITON" and not IS_TRITON_KERNELS_AVAILABLE:
pytest.skip("Triton kernels are not available")

@ -2837,7 +2855,6 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):

with llm:
model_name = "GPT-OSS/MXFP4"
mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
task = GSM8K(model_name)
task.evaluate(llm,
extra_evaluator_kwargs=self.extra_evaluator_kwargs)
@ -2857,7 +2874,9 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
ids=["tp4", "ep4", "dp4"])
def test_w4_4gpus(self, moe_backend, tp_size, pp_size, ep_size,
attention_dp, cuda_graph, overlap_scheduler, mocker):
pytest.skip("https://nvbugs/5481087")
mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
mocker.patch.dict(GSM8K.EVALUATE_KWARGS,
{"scores_filter": "exact_match,flexible-extract"})
if moe_backend == "TRITON":
if not IS_TRITON_KERNELS_AVAILABLE:
pytest.skip("Triton kernels are not available")
@ -2878,7 +2897,6 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
with llm:
model_name = "GPT-OSS/MXFP4"
task = GSM8K(model_name)
mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
task.evaluate(llm,
extra_evaluator_kwargs=self.extra_evaluator_kwargs)

@ -2890,6 +2908,9 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
ids=["dp4"])
def test_w4a16(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
overlap_scheduler, monkeypatch, mocker):
mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
mocker.patch.dict(GSM8K.EVALUATE_KWARGS,
{"scores_filter": "exact_match,flexible-extract"})
if not IS_TRITON_KERNELS_AVAILABLE:
pytest.skip("Triton kernels are not available")
monkeypatch.setenv("OVERRIDE_QUANT_ALGO", "W4A16_MXFP4")
@ -2909,7 +2930,6 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
with llm:
model_name = "GPT-OSS/BF16"
task = GSM8K(model_name)
mocker.patch.object(GSM8K, {"MAX_OUTPUT_LEN": 8192})
task.evaluate(llm,
extra_evaluator_kwargs=self.extra_evaluator_kwargs)


@ -155,12 +155,14 @@ def test_llmapi_speculative_decoding_ngram(llm_root, engine_dir, llm_venv):
"llm_speculative_decoding.py", "NGRAM")


@pytest.mark.skip(reason="https://nvbugs/5365825")
@pytest.mark.skip(reason="https://nvbugs/5365825"
) # maybe unrelated, but this test will always timeout
def test_llmapi_sampling(llm_root, engine_dir, llm_venv):
_run_llmapi_example(llm_root, engine_dir, llm_venv, "llm_sampling.py")


@pytest.mark.skip(reason="https://nvbugs/5365825")
@pytest.mark.skip(reason="https://nvbugs/5365825"
) # maybe unrelated, but this test will always timeout
def test_llmapi_runtime(llm_root, engine_dir, llm_venv):
_run_llmapi_example(llm_root, engine_dir, llm_venv, "llm_runtime.py")

@ -172,3 +174,8 @@ def test_llmapi_kv_cache_connector(llm_root, llm_venv, model):
model_path = f"{llm_models_root()}/{model}"

venv_check_call(llm_venv, [str(script_path), model_path])


def test_llmapi_tensorrt_engine(llm_root, engine_dir, llm_venv):
_run_llmapi_example(llm_root, engine_dir, llm_venv,
"_tensorrt_engine/quickstart_example.py")

@ -17,8 +17,6 @@
Model pytorch yaml config for trtllm-bench perf tests
"""

from tensorrt_llm.llmapi import KvCacheConfig


def recursive_update(d, u):
for k, v in u.items():
@ -204,9 +202,10 @@ def get_model_yaml_config(model_label: str,
'swap_gate_up_proj_lora_b_weight'] = False
base_config.update(lora_config)

kv_cache_config = base_config.get('kv_cache_config', KvCacheConfig())
kv_cache_config = base_config.get('kv_cache_config', {})
if 'kv_cache_dtype' in base_config:
kv_cache_config.dtype = base_config.pop('kv_cache_dtype', 'auto')
kv_cache_dtype = base_config.pop('kv_cache_dtype', 'auto')
kv_cache_config['dtype'] = kv_cache_dtype
base_config.update({'kv_cache_config': kv_cache_config})

return base_config
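The hunk above switches the kv-cache override from a KvCacheConfig object to a plain dict; a minimal self-contained sketch of the resulting behaviour (input values made up for illustration):

# Illustrative only: mirrors the dict-based handling in get_model_yaml_config above.
base_config = {"kv_cache_dtype": "fp8"}
kv_cache_config = base_config.get("kv_cache_config", {})
if "kv_cache_dtype" in base_config:
    # Move the top-level dtype override under the kv_cache_config dict.
    kv_cache_config["dtype"] = base_config.pop("kv_cache_dtype", "auto")
base_config.update({"kv_cache_config": kv_cache_config})
assert base_config == {"kv_cache_config": {"dtype": "fp8"}}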

@ -451,7 +451,9 @@ class BenchRunner:
skip_engine_build: bool = False,
quant: Optional[str] = None,
extra_llm_api_options: Optional[str] = None,
use_mpirun: bool = False):
use_mpirun: bool = False,
concurrency: Optional[int] = None,
num_requests: int = 10):

llm_models = llm_models_root()
assert llm_models is not None
@ -476,12 +478,14 @@ class BenchRunner:
else:
self.mpirun_cmd = ""
self.engine_path = None
self.concurrency = concurrency
self.num_requests = num_requests

def __call__(self):
self.prepare_dataset()
if not (self.skip_engine_build or self.use_pytorch_backend):
self.build_engine()
self.run_bench()
return self.run_bench()

def prepare_dataset(self):
dataset_tool = Path(self.llm_root, "benchmarks", "cpp",
@ -504,7 +508,7 @@ class BenchRunner:
"--output-stdev",
"0",
"--num-requests",
"10",
str(self.num_requests),
]
print(f"Running command: {' '.join(command)}")
dataset_output = self.llm_venv.run_cmd(
@ -558,7 +562,47 @@ class BenchRunner:

if self.extra_llm_api_options:
benchmark_cmd += f" --extra_llm_api_options {self.extra_llm_api_options}"
check_call(benchmark_cmd, shell=True, env=self.llm_venv._new_env)
if self.concurrency:
benchmark_cmd += f" --concurrency {self.concurrency}"
if self.num_requests:
benchmark_cmd += f" --num_requests {self.num_requests}"

benchmark_output = check_output(benchmark_cmd,
shell=True,
env=self.llm_venv._new_env)
return self.parse_benchmark_output(benchmark_output)

def parse_benchmark_output(self, output):
"""Parse the benchmark output to extract key metrics."""
result = {
'concurrency': self.concurrency,
'num_requests': self.num_requests,
'throughput': 0,
'latency': 0
}

lines = output.split('\n')
for line in lines:
line = line.strip()
if 'total token throughput' in line.lower(
) and 'tokens/sec' in line.lower():
try:
throughput = line.split(":")[1].strip()
result['throughput'] = throughput
except (IndexError, ValueError) as e:
print(
f"Failed to parse throughput from line: {line}. Error: {e}"
)
elif 'total latency' in line.lower() and 'ms' in line.lower():
try:
latency = line.split(":")[1].strip()
result['latency'] = latency
except (IndexError, ValueError) as e:
print(
f"Failed to parse latency from line: {line}. Error: {e}"
)

return result
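parse_benchmark_output above keys off substrings in the benchmark log; a small standalone sketch of that matching, with the log lines invented for illustration (the real trtllm-bench output format is assumed, not verified):

# Hypothetical sample lines; only the substring checks mirror the code above.
sample = "Total Token Throughput (tokens/sec): 1234.5\nTotal Latency (ms): 987.6"
parsed = {"throughput": 0, "latency": 0}
for line in sample.split("\n"):
    low = line.strip().lower()
    if "total token throughput" in low and "tokens/sec" in low:
        parsed["throughput"] = line.split(":")[1].strip()
    elif "total latency" in low and "ms" in low:
        parsed["latency"] = line.split(":")[1].strip()
# parsed -> {"throughput": "1234.5", "latency": "987.6"}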


@pytest.mark.parametrize("model_name", ["meta-llama/Meta-Llama-3-8B-Instruct"],
@ -581,6 +625,67 @@ def test_trtllm_bench_llmapi_launch(llm_root, llm_venv, model_name,
runner()


@skip_pre_hopper
@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.parametrize("model_name", ["meta/Meta-Llama-3.1-8B"],
ids=["llama3_1-8b"])
@pytest.mark.parametrize("model_subdir", ["llama-3.1-model/Meta-Llama-3.1-8B"],
ids=["llama_v3_1"])
@pytest.mark.parametrize("use_pytorch_backend", [False], ids=["trt_backend"])
def test_trtllm_bench_mig_launch(llm_root, llm_venv, model_name, model_subdir,
use_pytorch_backend):
"run bench mark in MIG mode, check if the throughput is increasing by concurrency"
skip_engine_build = False
results = {}
concurrency_list = [1, 32, 64, 128]

for concurrency in concurrency_list:
num_requests = concurrency * 10
runner = BenchRunner(llm_root=llm_root,
llm_venv=llm_venv,
model_name=model_name,
model_subdir=model_subdir,
streaming=False,
use_pytorch_backend=use_pytorch_backend,
use_mpirun=False,
tp_size=1,
concurrency=concurrency,
num_requests=num_requests,
skip_engine_build=skip_engine_build)

output = runner()
results[concurrency] = output

print(f"\n=== Benchmark Results Comparison ===")
print(f"Model: {model_name}")
print(f"Backend: {'PyTorch' if use_pytorch_backend else 'TensorRT'}")
print(
f"{'Concurrency':<15} {'Throughput':<15} {'Latency':<15} {'Num Requests':<15}"
)
print("-" * 60)

for idx, val in enumerate(concurrency_list):
metrics = results.get(val)
if not isinstance(metrics, dict):
pytest.fail(
f"Unexpected benchmark result type for concurrency {val}: {type(metrics)}"
)
try:
throughput = float(metrics.get('throughput', 0))
latency = float(metrics.get('latency', 0))
num_requests = int(metrics.get('num_requests', 0))
except (ValueError, TypeError) as e:
pytest.fail(
f"Failed to parse benchmark results for concurrency {val}: {e}")
assert throughput > 0, f"Throughput is 0 for concurrency {val}"
assert latency > 0, f"Latency is 0 for concurrency {val}"
print(f"{val:<15} {throughput:<15} {latency:<15} {num_requests:<15}")
if idx > 0:
prev_throughput = float(results[concurrency_list[idx - 1]].get(
'throughput', 0))
assert throughput > prev_throughput * 1.3, f"Throughput is not increasing for concurrency {concurrency_list[idx]}"


@pytest.mark.parametrize(
"model_name, llama_model_root",
[pytest.param("TinyLlama-1.1B-Chat-v1.0", "TinyLlama-1.1B-Chat-v1.0")],
@ -2090,7 +2195,7 @@ def test_ptp_quickstart_advanced_8gpus(llm_root, llm_venv, model_name,
def test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k(
llm_root, llm_venv, model_name, model_path, cuda_graph):
print(f"Testing {model_name} on 8 GPUs.")
example_root = Path(os.path.join(llm_root, "examples", "pytorch"))
example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
cmd = [
str(example_root / "quickstart_advanced.py"),
"--enable_chunked_prefill",
@ -2115,10 +2220,12 @@ def test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k(
@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("model_name,model_path", [
("Llama3.1-70B-BF16", "llama-3.1-model/Meta-Llama-3.1-70B"),
('Nemotron-Super-49B-v1-BF16',
'nemotron-nas/Llama-3_3-Nemotron-Super-49B-v1'),
("Mixtral-8x7B-BF16", "Mixtral-8x7B-Instruct-v0.1"),
pytest.param('Llama3.1-70B-BF16',
'llama-3.1-model/Meta-Llama-3.1-70B',
marks=pytest.mark.skip_less_device_memory(95000)),
])
def test_ptp_quickstart_advanced_2gpus_sm120(llm_root, llm_venv, model_name,
model_path):
@ -2565,6 +2672,106 @@ def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name,
print("All answers are correct!")


@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.parametrize("model_name,model_path", [
("gemma-3-27b-it", "gemma/gemma-3-27b-it"),
("mistral-small-3.1-24b-instruct", "Mistral-Small-3.1-24B-Instruct-2503"),
("Phi-4-multimodal-instruct", "multimodals/Phi-4-multimodal-instruct"),
])
def test_ptp_quickstart_multimodal_multiturn(llm_root, llm_venv, model_name,
model_path):
example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
test_data_root = Path(
os.path.join(llm_models_root(), "multimodals", "test_data"))

print(f"Accuracy test {model_name} image mode with example inputs.")

# Define accuracy inputs for image modality
accuracy_inputs = {
"image": {
"prompt": [
"Describe what you see in this image.",
"How would you describe the atmosphere of this scene?",
],
"media": [
str(test_data_root / "inpaint.png"),
],
}
}

# Define expected keywords for each model
expected_keywords = {
"gemma-3-27b-it": {
"image": [
["half", "dome", "yosemite", "landmark", "rounded"],
["atmosphere", "peaceful", "majestic", "calm", "quiet"],
],
},
"mistral-small-3.1-24b-instruct": {
"image": [
["depicts", "landscape", "rock", "sky", "high", "altitude"],
["atmosphere", "serene", "majestic", "sense", "tranquility"],
],
},
"Phi-4-multimodal-instruct": {
"image": [
["depicts", "landscape", "mountain", "half", "dome"],
["atmosphere", "serene", "sense", "tranquility", "peace."],
],
},
}
# Build command for image modality
cmd = [
str(example_root / "quickstart_multimodal.py"),
"--model_dir",
f"{llm_models_root()}/{model_path}",
"--modality",
"image",
"--multiturn",
"--prompt",
*accuracy_inputs["image"]["prompt"],
"--media",
*accuracy_inputs["image"]["media"],
]

# Add model-specific configurations
if model_name == "gemma-3-27b-it":
# Gemma3 VLM needs a custom mask which is only supported by flashinfer backend currently.
# Custom mask involves bidirectional masking of image tokens in context phase. To get this
# correct, chunked prefill and kv cache reuse need to be turned off.
cmd.append("--image_format=pil")
cmd.append("--attention_backend=FLASHINFER")
cmd.append("--disable_kv_cache_reuse")
elif model_name == "Phi-4-multimodal-instruct":
# Set max_seq_len to 4096 to use short rope factor.
cmd.append("--max_seq_len=4096")
cmd.append("--load_lora")
cmd.append("--auto_model_name")
cmd.append("Phi4MMForCausalLM")

output = llm_venv.run_cmd(cmd, caller=check_output)
print("output:", output)
# Set match ratio based on model
match_ratio = 4.0 / 5
if model_name == "Phi-4-multimodal-instruct":
match_ratio = 0.6

# Check output accuracy
for prompt_output, prompt_keywords in zip(
parse_output(output), expected_keywords[model_name]["image"]):
matches = [
keyword in prompt_output.lower() for keyword in prompt_keywords
]
obs_match_ratio = 1. * sum(matches) / len(matches)
print("prompt_output:", prompt_output)
print("prompt_keywords:", prompt_keywords)
print("matches:", matches)
print("obs_match_ratio:", obs_match_ratio)
assert obs_match_ratio >= match_ratio, f"Incorrect output!\nGenerated \"{prompt_output}\"\nExpected keywords \"{prompt_keywords}\"\n Matched keywords: {matches}\n Observed match ratio {obs_match_ratio} below threshold {match_ratio}"

print("All answers are correct!")


@pytest.mark.parametrize("model_name,model_path", [
("BertForSequenceClassification", "bert/bert-base-uncased-yelp-polarity"),
])

@ -471,11 +471,12 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=True]
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_eagle3_tp8[eagle3_one_model=True]
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_eagle3_tp8[eagle3_one_model=False]
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=True]
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=False]
accuracy/test_llm_api_pytorch.py::TestMistral7B::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_fp8
accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_auto_dtype[tp8-cuda_graph=False]
accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_auto_dtype[tp8ep4-cuda_graph=True]
accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_auto_dtype[tp8ep8-cuda_graph=True]
@ -662,6 +663,9 @@ test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image_audio]
test_e2e.py::test_ptp_quickstart_multimodal_2gpu[gemma-3-27b-it-gemma/gemma-3-27b-it]
test_e2e.py::test_ptp_quickstart_multimodal_2gpu[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503]
test_e2e.py::test_ptp_quickstart_multimodal_2gpu[Phi-4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct]
test_e2e.py::test_ptp_quickstart_multimodal_multiturn[gemma-3-27b-it-gemma/gemma-3-27b-it]
test_e2e.py::test_ptp_quickstart_multimodal_multiturn[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503]
test_e2e.py::test_ptp_quickstart_multimodal_multiturn[Phi-4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct]
test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]

@ -58,6 +58,7 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-cutlass]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-triton]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm]
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4]
accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestKanana_Instruct::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestKimiK2::test_fp8_blockscale[latency]
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4
@ -86,10 +87,10 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_fp8_prequantized
accuracy/test_llm_api_pytorch.py::TestLlama3_2_3B::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestLlama3_2_3B::test_fp8_prequantized
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_eagle3_tp8[eagle3_one_model=False]
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_eagle3_tp8[eagle3_one_model=True]
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=True]
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=False]
accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_auto_dtype[tp4-cuda_graph=False]
accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_auto_dtype[tp4ep2-cuda_graph=True]
accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_auto_dtype[tp4ep4-cuda_graph=True]

@ -191,8 +191,8 @@ llm_perf_sanity:

tests:
#llama_v3.1_70b
#pytorch backend
- perf/test_perf.py::test_perf[llama_v3.1_70b-bench-pytorch-bfloat16-maxbs:1-input_output_len:512,32-quant:fp8-gpus:8]
#trt backend
- perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-maxbs:1-maxnt:544-input_output_len:512,32-quant:fp8-gpus:8]
#llama_v3.3_70b_instruct_fp8
#pytorch backend
- perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-pytorch-float8-input_output_len:128,128-gpus:8]

@ -226,6 +226,7 @@ l0_h100:
- accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_fp8_prequantized
- accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype
- accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype
- accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_fp8
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]

@ -30,4 +30,5 @@ l0_sanity_check:
- llmapi/test_llm_examples.py::test_llmapi_speculative_decoding_ngram
- llmapi/test_llm_examples.py::test_llmapi_sampling
- llmapi/test_llm_examples.py::test_llmapi_runtime
- llmapi/test_llm_examples.py::test_llmapi_tensorrt_engine
- examples/test_llm_api_with_mpi.py::test_llm_api_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]

@ -43,7 +43,6 @@ examples/test_multimodal.py::test_llm_multimodal_general[video-neva-pp:1-tp:1-bf
examples/test_whisper.py::test_llm_whisper_general[large-v3-enable_gemm_plugin-enable_attention_plugin-disable_weight_only-float16-nb:1-use_python_runtime] SKIP (https://nvbugs/4866931)
examples/test_nemotron.py::test_llm_nemotron_3_8b_1gpu[bfloat16-fp8] SKIP (https://nvbugs/4961624)
examples/test_mistral.py::test_llm_mistral_v1_1gpu[mistral-7b-v0.1-float16-max_attention_window_size_4096-chunked_summarization_long] SKIP (https://nvbugs/5321371)
test_e2e.py::test_openai_chat_structural_tag_example SKIP (https://nvbugspro.nvidia.com/bug/5375594)
cpp/test_e2e.py::test_model[fp8-chatglm-90] SKIP (https://nvbugs/5034830)
full:B200_PCIe/unittest/trt/functional SKIP (Disable for Blackwell)
full:B200_PCIe/unittest/trt/quantization SKIP (Disable for Blackwell)
@ -229,7 +228,6 @@ examples/test_qwen.py::test_llm_hf_qwen_quantization_1gpu[qwen2_vl_7b_instruct-f
accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype SKIP (https://nvbugs/5234043)
examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoder] SKIP (https://nvbugs/5355128)
examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoderplus] SKIP (https://nvbugs/5355128)
stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-GUARANTEED_NO_EVICT-pytorch-stress-test] SKIP (https://nvbugs/5375646)
full:L40S/accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype SKIP (https://nvbugs/5375620)
full:L20/accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype SKIP (https://nvbugs/5375620)
test_e2e.py::test_ptp_quickstart_advanced_8gpus[Llama3.1-405B-FP8-llama-3.1-model/Llama-3.1-405B-Instruct-FP8] SKIP (https://nvbugs/5380570)
@ -243,6 +241,11 @@ examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-re
examples/test_recurrentgemma.py::test_llm_recurrentgemma_2gpu[recurrentgemma-2b] SKIP (https://nvbugs/5401233)
test_e2e.py::test_openai_multi_chat_example SKIP (https://nvbugs/5409416)
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B] SKIP (https://nvbugs/5409420)
llmapi/test_llm_examples.py::test_llmapi_speculative_decoding_mtp SKIP (https://nvbugs/5410399)
unittest/trt/attention/test_gpt_attention.py -k "partition0" SKIP (https://nvbugs/5412456)
unittest/trt/attention/test_gpt_attention.py -k "partition1" SKIP (https://nvbugs/5412456)
unittest/trt/attention/test_gpt_attention.py -k "partition2" SKIP (https://nvbugs/5412456)
unittest/trt/attention/test_gpt_attention.py -k "partition3" SKIP (https://nvbugs/5412456)
examples/test_multimodal.py::test_llm_multimodal_general[kosmos-2-pp:1-tp:1-float16-bs:8-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5141288)
examples/test_qwen.py::test_llm_qwen_7b_int8_kv_1node_1gpus[qwen2_vl_7b_instruct-enable_gemm_plugin-enable_weight_only] SKIP (https://nvbugs/5419067)
examples/test_qwen.py::test_llm_qwen_awq_single_gpu_summary[qwen2_vl_7b_instruct-nb:4] SKIP (https://nvbugs/5419068)
@ -250,12 +253,11 @@ examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-re
examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:1-pp:1-float16-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity] SKIP (https://nvbugs/5421989)
examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:1-pp:1-float16-RobertaForSequenceClassification-bert/twitter-roberta-base-emotion] SKIP (https://nvbugs/5421989)
examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5431132)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5431139)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5434320)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5434320)
accuracy/test_llm_api.py::TestLlama3_2_1B::test_int4_awq_int8_kv_cache SKIP (https://nvbugs/5433541)
accuracy/test_llm_api.py::TestLlama3_2_1B::test_fp8_pp2 SKIP (https://nvbugs/5433541)
accuracy/test_llm_api_pytorch.py::TestPhi4MiniInstruct::test_auto_dtype SKIP (https://nvbugs/5433545)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5431139)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5434320)
examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-9b-it] SKIP (https://nvbugs/5434451)
examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-27b-it] SKIP (https://nvbugs/5434451)
examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-3-1b-it] SKIP (https://nvbugs/5434451)
@ -294,8 +296,6 @@ disaggregated/test_disaggregated.py::test_disaggregated_diff_max_tokens[TinyLlam
disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_single_gpu_mtp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5465642)
examples/test_multimodal.py::test_llm_multimodal_general[Mistral-Small-3.1-24B-Instruct-2503-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5431146)
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[latency] SKIP (https://nvbugs/5464461)
full:H100/accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_eagle3_tp8[eagle3_one_model=True] SKIP (https://nvbugs/5467815)
full:H100/accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_eagle3_tp8[eagle3_one_model=False] SKIP (https://nvbugs/5467815)
full:H100/accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp8[tp4-cuda_graph=True] SKIP (https://nvbugs/5467815)
full:H100/accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp8_chunked_prefill[tp4ep4-cuda_graph=True] SKIP (https://nvbugs/5467815)
accuracy/test_disaggregated_serving.py::TestQwen3_30B_A3B::test_mixed_ctx_gen_model[ctxpp2gentp2] SKIP (https://nvbugs/5470769)
@ -331,11 +331,8 @@ accuracy/test_cli_flow.py::TestPhi4MiniInstruct::test_tp2 SKIP (https://nvbugs/5
accuracy/test_cli_flow.py::TestLongAlpaca7B::test_auto_dtype SKIP (https://nvbugs/5481075)
accuracy/test_llm_api.py::TestPhi4MiniInstruct::test_fp8 SKIP (https://nvbugs/5465143)
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput] SKIP (https://nvbugs/5471106)
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass] SKIP (https://nvbugs/5481080)
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-cutlass] SKIP (https://nvbugs/5481080)
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-cutlass] SKIP (https://nvbugs/5481080)
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass] SKIP (https://nvbugs/5481080)
accuracy/test_llm_api_pytorch.py::TestEXAONE4::test_auto_dtype SKIP (https://nvbugs/5481090)
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass] SKIP (https://nvbugs/5481080)
test_e2e.py::test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k[Llama-4-Maverick-17B-128E-Instruct-FP8-llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8-False] SKIP (https://nvbugs/5481094)
test_e2e.py::test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k[Llama-4-Maverick-17B-128E-Instruct-FP8-llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8-True] SKIP (https://nvbugs/5481094)
test_e2e.py::test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k[Llama-4-Scout-17B-16E-Instruct-FP8-llama4-models/Llama-4-Scout-17B-16E-Instruct-FP8-True] SKIP (https://nvbugs/5481094)
@ -357,3 +354,14 @@ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True] SKIP (https://nvbugs/5488118)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True] SKIP (https://nvbugs/5488141)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5488118)
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen1.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_vl_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5359696)
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl SKIP (https://nvbugs/5413362)
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8] SKIP (https://nvbugs/5455140)
full:L40S/accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=False] SKIP (https://nvbugs/5347051)
full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] SKIP (https://nvbugs/5471106)
full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp2pp2] SKIP (https://nvbugs/5471108)
test_e2e.py::test_multi_nodes_eval[llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8-tp8pp2-mmlu] SKIP (https://nvbugs/5473781)
accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl SKIP (https://nvbugs/5413362)

@ -67,6 +67,7 @@ class DummyModelEngine(PyTorchModelEngine):
mapping = Mapping(world_size=tensorrt_llm.mpi_world_size(),
tp_size=tensorrt_llm.mpi_world_size(),
rank=tensorrt_llm.mpi_rank())
self.model_is_wrapped = False
super().__init__(model_path="",
pytorch_backend_config=pytorch_backend_config,
checkpoint_loader=None,

@ -443,3 +443,58 @@ def test_mistral_3_vlm_allclose_to_hf(mistral_small_3_1_24b_config, backend, use
torch.testing.assert_close(logits, ref.logits[:, -1].float(), atol=0.4, rtol=0.4)
if graph_runner is not None:
graph_runner.clear()


@pytest.mark.parametrize(
"in_shapes, image_sizes, expected_out_shape",
[
(
[(2, 3, 100, 150), (1, 3, 200, 100), (3, 3, 120, 180)],
[
[[92, 150], [100, 73]],
[[200, 100]],
[[37, 130], [120, 83], [73, 180]],
],
[6, 3, 200, 180],
),
# Single batch, single image.
(
[(1, 3, 64, 128)],
[[[64, 128]]],
[1, 3, 64, 128],
),
# Same max size across batches.
(
[(2, 3, 59, 59), (1, 3, 59, 59), (5, 3, 59, 59)],
[
[[13, 59], [59, 17]],
[[59, 59]],
[[19, 29], [59, 31], [17, 54], [13, 59], [11, 37]],
],
[8, 3, 59, 59],
),
],
)
def test_batch_pixel_values(in_shapes, image_sizes, expected_out_shape):
# Test case 1: Basic functionality with different sized images
pixel_values = [torch.randn(*shape) for shape in in_shapes]
image_sizes = [torch.tensor(size) for size in image_sizes]

batched_pixels, batched_sizes = modeling_mistral.Mistral3VLM.batch_pixel_values(
pixel_values, image_sizes
)

# Check output shapes
assert list(batched_pixels.shape) == expected_out_shape
assert list(batched_sizes.shape) == [expected_out_shape[0], 2]

# Check that the original image data is preserved (with padding).
start_idx = 0
for original_values in pixel_values:
batch_size = original_values.shape[0]
end_idx = start_idx + batch_size
orig_h, orig_w = original_values.shape[-2:]
padded_values = batched_pixels[start_idx:end_idx, :, :orig_h, :orig_w]
torch.testing.assert_close(padded_values, original_values)

start_idx += batch_size

@ -16,6 +16,7 @@ from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))


@pytest.mark.skip(reason="https://nvbugs/5461761")
@pytest.mark.parametrize(
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill",
[

@ -26,11 +26,7 @@ def temp_extra_llm_api_options_file(request):
temp_dir = tempfile.gettempdir()
temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
try:
extra_llm_api_options_dict = {
"guided_decoding_backend": "xgrammar",
"disable_overlap_scheduler":
True, # Guided decoding is not supported with overlap scheduler
}
extra_llm_api_options_dict = {"guided_decoding_backend": "xgrammar"}

with open(temp_file_path, "w") as f:
yaml.dump(extra_llm_api_options_dict, f)

@ -1,25 +1,28 @@
# Adapted from
# https://github.com/vllm-project/vllm/blob/aae6927be06dedbda39c6b0c30f6aa3242b84388/tests/entrypoints/openai/test_chat.py
import json
import os
import re
import tempfile

import jsonschema
import openai
import pytest
import yaml

from ..test_llm import get_model_path, similar
from ..test_llm import get_model_path
from .openai_server import RemoteOpenAIServer

pytestmark = pytest.mark.threadleak(enabled=False)


@pytest.fixture(scope="module", ids=["TinyLlama-1.1B-Chat"])
@pytest.fixture(scope="module")
def model_name():
return "llama-3.1-model/Llama-3.1-8B-Instruct"


@pytest.fixture(scope="module")
def temp_extra_llm_api_options_file(request):
def temp_extra_llm_api_options_file():
temp_dir = tempfile.gettempdir()
temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
try:
@ -37,7 +40,12 @@ def temp_extra_llm_api_options_file(request):
@pytest.fixture(scope="module")
def server(model_name: str, temp_extra_llm_api_options_file: str):
model_path = get_model_path(model_name)
args = ["--extra_llm_api_options", temp_extra_llm_api_options_file]

# Use small max_batch_size/max_seq_len/max_num_tokens to avoid OOM on A10/A30 GPUs.
args = [
"--max_batch_size=8", "--max_seq_len=1024", "--max_num_tokens=1024",
f"--extra_llm_api_options={temp_extra_llm_api_options_file}"
]
with RemoteOpenAIServer(model_path, args) as remote_server:
yield remote_server

@ -112,12 +120,7 @@ def tool_get_current_date():

def test_chat_structural_tag(client: openai.OpenAI, model_name: str,
tool_get_current_weather, tool_get_current_date):
messages = [
{
"role":
"system",
"content":
f"""
system_prompt = f"""
# Tool Instructions
- Always execute python code in messages that you share.
- When looking for real time information use relevant functions if available else fallback to brave_search
@ -140,20 +143,24 @@ Reminder:
- Only call one function at a time
- Put the entire function call reply on one line
- Always add your sources when using search results to answer the user query
You are a helpful assistant.""",
You are a helpful assistant."""
user_prompt = "You are in New York. Please get the current date and time, and the weather."

messages = [
{
"role": "system",
"content": system_prompt,
},
{
"role":
"user",
"content":
"You are in New York. Please get the current date and time, and the weather.",
"role": "user",
"content": user_prompt,
},
]

chat_completion = client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=100,
max_completion_tokens=256,
response_format={
"type":
"structural_tag",
@ -173,11 +180,18 @@ You are a helpful assistant.""",
"triggers": ["<function="],
},
)
assert chat_completion.id is not None
assert len(chat_completion.choices) == 1

message = chat_completion.choices[0].message
assert message.content is not None
assert message.role == "assistant"

reference = '<function=get_current_date>{"timezone": "America/New_York"}</function>\n<function=get_current_weather>{"city": "New York", "state": "NY", "unit": "fahrenheit"}</function>\n\nSources:\n- get_current_date function\n- get_current_weather function'
assert similar(chat_completion.choices[0].message.content, reference)
match = re.search(r'<function=get_current_weather>([\S\s]+?)</function>',
message.content)
params = json.loads(match.group(1))
jsonschema.validate(params,
tool_get_current_weather["function"]["parameters"])

match = re.search(r'<function=get_current_date>([\S\s]+?)</function>',
message.content)
params = json.loads(match.group(1))
jsonschema.validate(params, tool_get_current_date["function"]["parameters"])

@ -8,6 +8,7 @@ Help()
echo "h Print this Help."
echo "t Location of tensorrt library"
echo "u Option to build unit tests"
echo "s Triton short tag, e.g. 'r25.06'"
echo
}

@ -15,7 +16,7 @@ TRT_ROOT='/usr/local/tensorrt'
BUILD_UNIT_TESTS='false'

# Get the options
while getopts ":ht:u" option; do
while getopts ":ht:us:" option; do
case $option in
h) # display Help
Help
@ -24,6 +25,8 @@ while getopts ":ht:u" option; do
TRT_ROOT=$OPTARG;;
u) # Option to build unit tests
BUILD_UNIT_TESTS='true';;
s) # Triton short tag
TRITON_SHORT_TAG=$OPTARG;;
\?) # Invalid option
echo "Error: Invalid option"
echo ""
@ -35,14 +38,22 @@ done
echo "Using TRT_ROOT=${TRT_ROOT}"
echo "Using BUILD_UNIT_TESTS=${BUILD_UNIT_TESTS}"

DIRNAME="$(dirname "$(realpath "$0")")"
if [ -z "$TRITON_SHORT_TAG" ]; then
# Get TRITON_SHORT_TAG from docker/Dockerfile.multi
LLM_ROOT="${DIRNAME}/../../.."
TRITON_SHORT_TAG=$("$LLM_ROOT/jenkins/scripts/get_triton_tag.sh" "$LLM_ROOT")
fi
echo "Using TRITON_SHORT_TAG=${TRITON_SHORT_TAG}"

set -x
apt-get update
apt-get install -y --no-install-recommends rapidjson-dev

BUILD_DIR=$(dirname $0)/../build
mkdir $BUILD_DIR
BUILD_DIR=$(cd -- "$BUILD_DIR" && pwd)
cd $BUILD_DIR

BUILD_DIR=$(realpath "$DIRNAME/../build")
mkdir -p "$BUILD_DIR"
cd "$BUILD_DIR" || exit 1

export LD_LIBRARY_PATH="/usr/local/cuda/compat/lib.real:${LD_LIBRARY_PATH}"

@ -51,12 +62,13 @@ if [[ "$BUILD_UNIT_TESTS" == "true" ]]; then
BUILD_TESTS_ARG="-DBUILD_TESTS=ON -DUSE_CXX11_ABI=ON"
fi

# TODO: Remove specifying Triton version after cmake version is upgraded to 3.31.8
# Get TRITON_SHORT_TAG from docker/Dockerfile.multi
LLM_ROOT=$BUILD_DIR/../../..
LLM_ROOT=$(cd -- "$LLM_ROOT" && pwd)
TRITON_SHORT_TAG=$("$LLM_ROOT/jenkins/scripts/get_triton_tag.sh" "$LLM_ROOT")
cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install ${BUILD_TESTS_ARG} -DTRITON_COMMON_REPO_TAG=${TRITON_SHORT_TAG} -DTRITON_CORE_REPO_TAG=${TRITON_SHORT_TAG} -DTRITON_THIRD_PARTY_REPO_TAG=${TRITON_SHORT_TAG} -DTRITON_BACKEND_REPO_TAG=${TRITON_SHORT_TAG} ..
cmake -DCMAKE_INSTALL_PREFIX:PATH="$(pwd)/install" \
${BUILD_TESTS_ARG} \
-DTRITON_COMMON_REPO_TAG="${TRITON_SHORT_TAG}" \
-DTRITON_CORE_REPO_TAG="${TRITON_SHORT_TAG}" \
-DTRITON_THIRD_PARTY_REPO_TAG="${TRITON_SHORT_TAG}" \
-DTRITON_BACKEND_REPO_TAG="${TRITON_SHORT_TAG}" \
..
make install

mkdir -p /opt/tritonserver/backends/tensorrtllm