diff --git a/cpp/include/tensorrt_llm/batch_manager/logitsPostProcessor.h b/cpp/include/tensorrt_llm/batch_manager/logitsPostProcessor.h index 048a84ecca..1916a915e3 100644 --- a/cpp/include/tensorrt_llm/batch_manager/logitsPostProcessor.h +++ b/cpp/include/tensorrt_llm/batch_manager/logitsPostProcessor.h @@ -47,7 +47,7 @@ public: bool operator()(DecoderInputBuffers& inputBuffers, bool replicateLogitsPostProcessor, runtime::WorldConfig const& worldConfig, CudaStreamPtr const& stream, - std::optional logitsPostProcessorBatched = std::nullopt) const; + std::optional const& logitsPostProcessorBatched = std::nullopt) const; }; } // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp index 23f20bde67..8b4a8773d1 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp @@ -204,6 +204,34 @@ private: } } + void sendResponse(std::vector const& blockHashes, std::map::iterator it) + { + auto reqId = mCurrentRequest.value(); + auto count = --mRemainSendCount[reqId]; + TLLM_CHECK(count >= 0); + if (count == 0) + { + mRemainSendCount.erase(reqId); + + // TODO(zhengd): pass the hashes directly instead of update llmRequest + auto llmRequest = it->second.mRequest; + llmRequest->setRequestedBlockHashes(std::move(blockHashes)); + + if (common::getEnvParallelCacheSend()) + { + // TODO: Use a thread pool and check for thread safety. + std::thread(&DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second)) + .detach(); + } + else + { + DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second)); + } + removeResponse(it); + } + mCurrentRequest = std::nullopt; + } + void response() noexcept { try @@ -237,40 +265,22 @@ private: auto it = getCurrentResponse(); if (it != mReadyResponses.end()) { - auto reqId = mCurrentRequest.value(); - auto count = --mRemainSendCount[reqId]; - TLLM_CHECK(count >= 0); - if (count == 0) - { - mRemainSendCount.erase(reqId); - - // TODO(zhengd): pass the hashes directly instead of update llmRequest - auto llmRequest = it->second.mRequest; - llmRequest->setRequestedBlockHashes(std::move(blockHashes)); - - if (common::getEnvParallelCacheSend()) - { - // TODO: Use a thread pool and check for thread safety. - std::thread( - &DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second)) - .detach(); - } - else - { - DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second)); - } - removeResponse(it); - } - mCurrentRequest = std::nullopt; + sendResponse(blockHashes, it); } else { - TLLM_CHECK_WITH_INFO(!mCurrentRequest.has_value(), - "This executor does not have a prepared KV cache for request ID: %zu, and the " - "mReadyResponses size is: %zu. 
mpi rank :%d ", - mCurrentRequest.value(), mReadyResponses.size(), mpi::MpiComm::world().getRank()); - std::unique_lock lk(mCondMutex); - mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); }); + auto it = getCurrentResponse(); + while (it == mReadyResponses.end()) + { + std::unique_lock lk(mCondMutex); + mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); }); + if (mTerminate) + { + break; + } + it = getCurrentResponse(); + } + sendResponse(blockHashes, it); } } } diff --git a/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp b/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp index dd34de0ef9..dbb90da326 100644 --- a/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp +++ b/cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp @@ -34,7 +34,7 @@ using SizeType32 = tensorrt_llm::runtime::SizeType32; bool LogitsPostProcessor::operator()(DecoderInputBuffers& inputBuffers, bool replicateLogitsPostProcessor, tr::WorldConfig const& worldConfig, CudaStreamPtr const& stream, - std::optional logitsPostProcessorBatched) const + std::optional const& logitsPostProcessorBatched) const { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); NVTX3_SCOPED_RANGE(LogitsPostProcessor); diff --git a/docs/source/commands/trtllm-serve/trtllm-serve.rst b/docs/source/commands/trtllm-serve/trtllm-serve.rst index b59a588cac..8b7d25e735 100644 --- a/docs/source/commands/trtllm-serve/trtllm-serve.rst +++ b/docs/source/commands/trtllm-serve/trtllm-serve.rst @@ -201,56 +201,60 @@ Metrics Endpoint .. note:: - This endpoint is beta maturity. + The metrics endpoint for the default PyTorch backend are in beta and are not as comprehensive as those for the TensorRT backend. - The statistics for the PyTorch backend are beta and not as comprehensive as those for the TensorRT backend. + Some fields, such as CPU memory usage, are not yet available for the PyTorch backend. - Some fields, such as CPU memory usage, are not available for the PyTorch backend. + Enabling ``enable_iter_perf_stats`` in the PyTorch backend can slightly impact performance, depending on the serving configuration. - Enabling ``enable_iter_perf_stats`` in the PyTorch backend can impact performance slightly, depending on the serving configuration. +The ``/metrics`` endpoint provides runtime iteration statistics such as GPU memory usage and KV cache details. -The ``/metrics`` endpoint provides runtime-iteration statistics such as GPU memory use and inflight-batching details. -For the TensorRT backend, these statistics are enabled by default. -However, for the PyTorch backend, you must explicitly enable iteration statistics logging by setting the `enable_iter_perf_stats` field in a YAML configuration file as shown in the following example: +For the default PyTorch backend, iteration statistics logging is enabled by setting the ``enable_iter_perf_stats`` field in a YAML file: .. code-block:: yaml - # extra-llm-api-config.yml - pytorch_backend_config: - enable_iter_perf_stats: true + # extra_llm_config.yaml + enable_iter_perf_stats: true -Then start the server and specify the ``--extra_llm_api_options`` argument with the path to the YAML file as shown in the following example: +Start the server and specify the ``--extra_llm_api_options`` argument with the path to the YAML file: .. 
code-block:: bash - trtllm-serve \ - --extra_llm_api_options \ - [--tp_size --pp_size --ep_size --host --port ] + trtllm-serve "TinyLlama/TinyLlama-1.1B-Chat-v1.0" --extra_llm_api_options extra_llm_config.yaml -After at least one inference request is sent to the server, you can fetch the runtime-iteration statistics by polling the `/metrics` endpoint: +After sending at least one inference request to the server, you can fetch runtime iteration statistics by polling the ``/metrics`` endpoint. +Since the statistics are stored in an internal queue and removed once retrieved, it's recommended to poll the endpoint shortly after each request and store the results if needed. .. code-block:: bash - curl -X GET http://:/metrics + curl -X GET http://localhost:8000/metrics -*Example Output* +Example output: .. code-block:: json - [ - { - "gpuMemUsage": 56401920000, - "inflightBatchingStats": { + [ + { + "gpuMemUsage": 76665782272, + "iter": 154, + "iterLatencyMS": 7.00688362121582, + "kvCacheStats": { + "allocNewBlocks": 3126, + "allocTotalBlocks": 3126, + "cacheHitRate": 0.00128, + "freeNumBlocks": 101253, + "maxNumBlocks": 101256, + "missedBlocks": 3121, + "reusedBlocks": 4, + "tokensPerBlock": 32, + "usedNumBlocks": 3 + }, + "numActiveRequests": 1 ... - }, - "iter": 1, - "iterLatencyMS": 16.505143404006958, - "kvCacheStats": { - ... - }, - "newActiveRequestsQueueLatencyMS": 0.0007503032684326172 - } -] + } + ] + + Syntax ------ diff --git a/docs/source/deployment-guide/quick-start-recipe-for-gpt-oss-on-trtllm.md b/docs/source/deployment-guide/quick-start-recipe-for-gpt-oss-on-trtllm.md index af109862b1..71180774f2 100644 --- a/docs/source/deployment-guide/quick-start-recipe-for-gpt-oss-on-trtllm.md +++ b/docs/source/deployment-guide/quick-start-recipe-for-gpt-oss-on-trtllm.md @@ -234,6 +234,28 @@ TODO: Use Chat Compeletions API / Responses API as the example after the PR is m We use OpenAI's official evaluation tool to test the model's accuracy. For more information see [https://github.com/openai/gpt-oss/tree/main/gpt_oss/evals](gpt-oss-eval). With the added support of Chat Completions and Responses API in `trtllm-serve,` `gpt_oss.evals` works directly without any modifications. +You need to set `enable_attention_dp`, `tp_size`, `ep_size`, `max_batch_size` and `max_num_tokens` when launching the trtllm server and set `reasoning-effort` when launching evaluation in gpt-oss. Below are some reference configurations for accuracy evaluation on B200. + +| **reasoning-effort** | **parallel configuration** | **max_batch_size** | **max_num_tokens** | +|:--------------------:|:--------------------------:|:------------------:|:------------------:| +| low/medium | DEP8 / DEP4 | 128 | 32768 | +| high | DEP8 / DEP4 | 2 | 133120 | +| low/medium | TP8 / TP4 | 1024 | 32768 | +| high | TP8 / TP4 | 720 | 133120 | + +Below is an example command for evaluating the accuracy of gpt-oss-120b with low and medium reasoning-effort on GPQA and AIME2025. + +```shell +# execute this command in gpt-oss +python -m gpt_oss.evals \ + --sampler chat_completions \ + --eval gpqa,aime25 \ + --model gpt-oss-120b \ + --reasoning-effort low,medium +``` + + + ## Benchmarking Performance To benchmark the performance of your TensorRT-LLM server you can leverage the built-in `benchmark_serving.py` script. To do this first creating a wrapper `bench.sh` script. 
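As a small illustration of the polling pattern described in the metrics documentation above (stats sit in an internal queue and are dropped once retrieved, so poll shortly after each request), here is a minimal sketch. It is not part of the patch; it assumes a server started as in the example (`localhost:8000`, TinyLlama) and the OpenAI-compatible `/v1/completions` route, and the printed field names are taken from the example output above for illustration only.

```python
import requests

BASE = "http://localhost:8000"  # matches the curl example in the docs above

# Trigger at least one inference request so iteration stats are produced.
resp = requests.post(
    f"{BASE}/v1/completions",
    json={
        "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # illustrative model name
        "prompt": "Hello, my name is",
        "max_tokens": 16,
    },
    timeout=60,
)
resp.raise_for_status()

# Poll /metrics right away: entries are removed once retrieved,
# so persist them here if they are needed later.
stats = requests.get(f"{BASE}/metrics", timeout=10).json()
for it in stats:
    print(it.get("iter"), it.get("iterLatencyMS"),
          it.get("kvCacheStats", {}).get("cacheHitRate"))
```
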
diff --git a/docs/source/legacy/tensorrt_quickstart.md b/docs/source/legacy/tensorrt_quickstart.md index df62aa38d7..e74a0f5e9e 100644 --- a/docs/source/legacy/tensorrt_quickstart.md +++ b/docs/source/legacy/tensorrt_quickstart.md @@ -1,7 +1,7 @@ # LLM API with TensorRT Engine A simple inference example with TinyLlama using the LLM API: -```{literalinclude} ../../examples/llm-api/_tensorrt_engine/quickstart_example.py +```{literalinclude} ../../../examples/llm-api/_tensorrt_engine/quickstart_example.py :language: python :linenos: ``` diff --git a/examples/llm-api/_tensorrt_engine/quickstart_example.py b/examples/llm-api/_tensorrt_engine/quickstart_example.py index 400a241c0e..a6ba9ec559 100644 --- a/examples/llm-api/_tensorrt_engine/quickstart_example.py +++ b/examples/llm-api/_tensorrt_engine/quickstart_example.py @@ -1,11 +1,17 @@ -from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm import BuildConfig, SamplingParams +from tensorrt_llm._tensorrt_engine import LLM # NOTE the change def main(): + build_config = BuildConfig() + build_config.max_batch_size = 256 + build_config.max_num_tokens = 1024 + # Model could accept HF model name, a path to local HF model, # or TensorRT Model Optimizer's quantized checkpoints like nvidia/Llama-3.1-8B-Instruct-FP8 on HF. - llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0") + llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + build_config=build_config) # Sample prompts. prompts = [ diff --git a/examples/llm-api/llm_mgmn_trtllm_bench.sh b/examples/llm-api/llm_mgmn_trtllm_bench.sh index de7ee73d53..5169c00ad3 100644 --- a/examples/llm-api/llm_mgmn_trtllm_bench.sh +++ b/examples/llm-api/llm_mgmn_trtllm_bench.sh @@ -76,6 +76,7 @@ srun -l \ # This is optional cat > /tmp/pytorch_extra_args.txt << EOF +cuda_graph_config: null print_iter_log: true enable_attention_dp: false EOF diff --git a/jenkins/BuildDockerImage.groovy b/jenkins/BuildDockerImage.groovy index c0e718882b..180fcfb5f5 100644 --- a/jenkins/BuildDockerImage.groovy +++ b/jenkins/BuildDockerImage.groovy @@ -94,6 +94,26 @@ def createKubernetesPodConfig(type, arch = "amd64", build_wheel = false) """ } + if (arch == "amd64") { + // For x86_64, we block some nodes to avoid unstable network access. + selectors += """ + affinity: + nodeAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + nodeSelectorTerms: + - matchExpressions: + - key: "kubernetes.io/hostname" + operator: NotIn + values: + - "sc-ipp-blossom-prod-k8w-105" + - "sc-ipp-blossom-prod-k8w-114" + - "sc-ipp-blossom-prod-k8w-115" + - "sc-ipp-blossom-prod-k8w-121" + - "sc-ipp-blossom-prod-k8w-123" + - "sc-ipp-blossom-prod-k8w-124" + """ + } + def archSuffix = arch == "arm64" ? "arm" : "amd" def jnlpImage = "urm.nvidia.com/sw-ipp-blossom-sre-docker-local/lambda/custom_jnlp_images_${archSuffix}_linux:jdk17" diff --git a/jenkins/L0_Test.groovy b/jenkins/L0_Test.groovy index 410248a18b..112e30b997 100644 --- a/jenkins/L0_Test.groovy +++ b/jenkins/L0_Test.groovy @@ -365,13 +365,24 @@ def runLLMTestlistOnSlurm(pipeline, platform, testList, config=VANILLA_CONFIG, p } stage('Checking if the Node is Online') { - def counter = 0 - // We submit the Slurm job with 5 hours timeout, and the K8S pod will be evicted after 22 hours. - // Let's use 15 hours to check if the node is online, and with 2 hours buffer. 
- while (!CloudManager.isNodeOnline(nodeName) && counter < 90) { - // Wait 10 minutes to check status of the node again - sleep(time: 10, unit: 'MINUTES') - counter++ + withCredentials([usernamePassword(credentialsId: 'svc_tensorrt', usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) { + def remote = [ + ip : cluster.ip, + host : cluster.host, + user : "${pipeline.USERNAME}", + passwd : "${pipeline.PASSWORD}", + allowAnyHosts: true, + ] + def counter = 0 + // We submit the Slurm job with 5 hours timeout, and the K8S pod will be evicted after 22 hours. + // Let's use 15 hours to check if the node is online, and with 2 hours buffer. + while (!CloudManager.isNodeOnline(nodeName) && counter < 90) { + // Wait 10 minutes to check status of the node again + sleep(time: 10, unit: 'MINUTES') + // Avoid the node being stuck in the held state. + Utils.exec(pipeline, Utils.sshUserCmd(remote, "\"scontrol release ${slurmJobID} || true\"")) + counter++ + } } if (CloudManager.isNodeOnline(nodeName)) { @@ -1157,7 +1168,7 @@ def transformMakoArgsToJson(optList) { def getMakoOpts(getMakoScript, makoArgs=[]) { // We want to save a map for the Mako opts - def turtleOutput = "" + def makoOutput = "" // Echo the command // NOTE: We redirect stderr to stdout so that we can capture @@ -1181,17 +1192,17 @@ def getMakoOpts(getMakoScript, makoArgs=[]) { // Capture the mako output, add timeout in case any hang timeout(time: 30, unit: 'MINUTES'){ - turtleOutput = sh(label: "Capture Mako Parameters", script: listMakoCmd, returnStdout: true) + makoOutput = sh(label: "Capture Mako Parameters", script: listMakoCmd, returnStdout: true) } } // Validate output - assert turtleOutput: "Mako opts not found - could not construct test db test list." + assert makoOutput: "Mako opts not found - could not construct test db test list." - // Split each line of turtle output into a list - def turtleOutList = turtleOutput.split("\n") + // Split each line of mako output into a list + def outputList = makoOutput.split("\n") - def makoOptsJson = transformMakoArgsToJson(turtleOutList) + def makoOptsJson = transformMakoArgsToJson(outputList) return makoOptsJson } @@ -1827,7 +1838,7 @@ def runLLMBuild(pipeline, cpu_arch, reinstall_dependencies=false, wheel_path="", if (env.alternativeTRT) { trtllm_utils.replaceWithAlternativeTRT(env.alternativeTRT, cpver) } - buildArgs = "--clean" + buildArgs = "--clean --nixl_root /opt/nvidia/nvda_nixl" if (cpu_arch == AARCH64_TRIPLE) { buildArgs += " -a '90-real;100-real;103-real;120-real'" } @@ -2062,9 +2073,9 @@ def launchTestJobs(pipeline, testFilter) "DGX_H200-4_GPUs-TensorRT-Post-Merge-1": ["dgx-h200-x4", "l0_dgx_h200", 1, 3, 4], "DGX_H200-4_GPUs-TensorRT-Post-Merge-2": ["dgx-h200-x4", "l0_dgx_h200", 2, 3, 4], "DGX_H200-4_GPUs-TensorRT-Post-Merge-3": ["dgx-h200-x4", "l0_dgx_h200", 3, 3, 4], - "RTXPro6000-Pytorch-Post-Merge-1": ["rtx-pro-6000", "l0_rtx_pro_6000", 1, 1], - "RTXPro6000-4_GPUs-Pytorch-Post-Merge-1": ["rtx-pro-6000-x4", "l0_rtx_pro_6000", 1, 2, 4], - "RTXPro6000-4_GPUs-Pytorch-Post-Merge-2": ["rtx-pro-6000-x4", "l0_rtx_pro_6000", 2, 2, 4], + //"RTXPro6000-Pytorch-Post-Merge-1": ["rtx-pro-6000", "l0_rtx_pro_6000", 1, 1], + //"RTXPro6000-4_GPUs-Pytorch-Post-Merge-1": ["rtx-pro-6000-x4", "l0_rtx_pro_6000", 1, 2, 4], + //"RTXPro6000-4_GPUs-Pytorch-Post-Merge-2": ["rtx-pro-6000-x4", "l0_rtx_pro_6000", 2, 2, 4], ] parallelJobs = x86TestConfigs.collectEntries{key, values -> [key, [createKubernetesPodConfig(key.contains("-CU12-") ? 
LLM_DOCKER_IMAGE_12_9 : LLM_DOCKER_IMAGE, values[0], "amd64", values[4] ?: 1, key.contains("Perf")), { diff --git a/tensorrt_llm/_torch/attention_backend/flashinfer.py b/tensorrt_llm/_torch/attention_backend/flashinfer.py index b8bf330488..74adc69c02 100644 --- a/tensorrt_llm/_torch/attention_backend/flashinfer.py +++ b/tensorrt_llm/_torch/attention_backend/flashinfer.py @@ -170,7 +170,8 @@ class FlashInferAttentionMetadata(AttentionMetadata): def create_cuda_graph_metadata(self, max_batch_size: int, sub_cross_metadata: bool = False, - max_draft_tokens: int = 0) -> Self: + max_draft_tokens: int = 0, + buffers=None) -> Self: metadata = super().create_cuda_graph_metadata(max_batch_size, sub_cross_metadata, max_draft_tokens) diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py index aa081e82dd..6a035ad477 100644 --- a/tensorrt_llm/_torch/attention_backend/interface.py +++ b/tensorrt_llm/_torch/attention_backend/interface.py @@ -140,6 +140,7 @@ class AttentionMetadata: # This buffer is currently only used for TrtllmAttentionMetadata. cache_indirection: Optional[torch.Tensor] = None + cuda_graph_buffers: dict[str, list[torch.Tensor]] = None _saved_tensors: Dict[str, torch.Tensor] = field(init=False, default_factory=dict) @@ -288,7 +289,8 @@ class AttentionMetadata: def create_cuda_graph_metadata(self, max_batch_size: int, sub_cross_metadata: bool = False, - max_draft_tokens: int = 0) -> Self: + max_draft_tokens: int = 0, + buffers=None) -> Self: """ Creates metadata for CUDA graph execution. CUDA graphs require to use pre-allocated buffers for all tensors in fields. @@ -300,6 +302,7 @@ class AttentionMetadata: cuda_graph_metadata = copy.copy(self) cuda_graph_metadata.is_cuda_graph = True + cuda_graph_metadata.cuda_graph_buffers = buffers if self.has_cross_sub_metadata: cuda_graph_metadata.cross = cuda_graph_metadata.cross.create_cuda_graph_metadata( max_batch_size, True) diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index a95519f22c..cdca67a7b9 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -600,13 +600,65 @@ class TrtllmAttentionMetadata(AttentionMetadata): def __post_init__(self) -> None: super().__post_init__() + self._post_init_with_buffers(self.cuda_graph_buffers) + + def _post_init_with_buffers(self, buffers) -> None: + # Set a default value, as max_num_sequences is not always set. if self.max_num_sequences is None: self.max_num_sequences = self.max_num_requests - self.prompt_lens_cuda = torch.empty( + def get_empty(tensor_shape: list[int], dtype: torch.dtype, + cache_name: str) -> torch.Tensor: + """ + Finds a compatible, reusable buffer from a cache or creates a new one. + + This function searches for a pre-allocated tensor (buffer) that can be + reused for an operation involving a tensor with the shape of `tensor_shape`. + + The compatibility rules are: The buffer's total elements must be >= tensor_shape's. + + If a compatible buffer is found, it's returned immediately. Otherwise, a new + buffer is allocated on the 'cuda' device with the give properties of 'tensor_shape' and 'dtype'. + + Args: + tensor_shape: The required shape. + dtype: The required dtype. + cache_name: The key for the specific list of buffers to search in. + + Returns: + An existing compatible buffer or a newly created one. + """ + if buffers is not None: + # Safely get the list of candidates. 
Defaults to an empty list if key is missing. + candidate_buffers = buffers.get(cache_name, []) + numel_like = math.prod(tensor_shape) + + for buffer in candidate_buffers: + numel_buffer = buffer.numel() + + # buffer just needs to be large enough. + if numel_buffer >= numel_like: + return buffer[0:numel_like].view( + tensor_shape) # Found a fit, return immediately. + + # If we get here, no suitable buffer was found in the cache. Create a new one. + new_buffer = torch.zeros(tensor_shape, device='cuda', dtype=dtype) + if buffers is not None: + buffers.setdefault(cache_name, []).append(new_buffer) + return new_buffer + + def get_empty_like(like_tensor: torch.Tensor, + cache_name: str) -> torch.Tensor: + return get_empty( + like_tensor.shape, + cache_name=cache_name, + dtype=like_tensor.dtype, + ) + + self.prompt_lens_cuda = get_empty( (self.max_num_sequences, ), - device='cuda', + cache_name="prompt_lens_cuda", dtype=torch.int, ) self.prompt_lens_cpu = torch.empty_like( @@ -614,7 +666,10 @@ class TrtllmAttentionMetadata(AttentionMetadata): device='cpu', pin_memory=True, ) - self.kv_lens_cuda = torch.empty_like(self.prompt_lens_cuda) + self.kv_lens_cuda = get_empty_like( + self.prompt_lens_cuda, + cache_name="kv_lens_cuda", + ) self.kv_lens = torch.empty_like(self.kv_lens_cuda, device='cpu', pin_memory=True) @@ -629,13 +684,13 @@ class TrtllmAttentionMetadata(AttentionMetadata): dtype=torch.int8, ) if self.kv_cache_manager is not None: - self.kv_cache_block_offsets = torch.empty( + self.kv_cache_block_offsets = get_empty( [ self.kv_cache_manager.num_pools, self.max_num_sequences, 2, self.kv_cache_manager.max_blocks_per_seq ], + cache_name="kv_cache_block_offsets", dtype=torch.int32, - device='cuda', ) self.host_kv_cache_block_offsets = torch.empty_like( self.kv_cache_block_offsets, @@ -645,27 +700,27 @@ class TrtllmAttentionMetadata(AttentionMetadata): self.block_ids_per_seq = None self.kv_block_ids_per_seq = None if self.enable_flash_mla: - self.block_ids_per_seq = torch.zeros( + self.block_ids_per_seq = get_empty( [ self.kv_cache_manager.max_batch_size, self.kv_cache_manager.max_blocks_per_seq ], + cache_name="block_ids_per_seq", dtype=torch.int32, - device='cuda', ) - self.kv_block_ids_per_seq = torch.zeros( + self.kv_block_ids_per_seq = get_empty( [ self.kv_cache_manager.max_batch_size, self.kv_cache_manager.max_blocks_per_seq ], + cache_name="kv_block_ids_per_seq", dtype=torch.int32, - device='cuda', ) if self.enable_context_mla_with_cached_kv: # for kv cache reuse/chunked context in MLA - self.ctx_cached_token_indptr = torch.zeros( + self.ctx_cached_token_indptr = get_empty( (self.max_num_requests + 1, ), - device='cuda', + cache_name="ctx_cached_token_indptr", dtype=torch.int64, ) self.host_ctx_cached_token_indptr = torch.zeros_like( @@ -673,9 +728,9 @@ class TrtllmAttentionMetadata(AttentionMetadata): device='cpu', pin_memory=True, ) - self.ctx_uncached_token_indptr = torch.zeros( + self.ctx_uncached_token_indptr = get_empty( (self.max_num_requests + 1, ), - device='cuda', + cache_name="ctx_uncached_token_indptr", dtype=torch.int64, ) self.host_ctx_uncached_token_indptr = torch.zeros_like( @@ -684,9 +739,9 @@ class TrtllmAttentionMetadata(AttentionMetadata): pin_memory=True, ) # context full seqlens include cached tokens and uncached tokens - self.ctx_kv_indptr = torch.zeros( + self.ctx_kv_indptr = get_empty( (self.max_num_requests + 1, ), - device='cuda', + cache_name="ctx_kv_indptr", dtype=torch.int64, ) self.host_ctx_kv_indptr = torch.zeros_like( @@ -1165,7 +1220,7 @@ class 
TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]): host_kv_cache_pool_pointers=metadata.host_kv_cache_pool_pointers, host_kv_cache_pool_mapping=metadata.host_kv_cache_pool_mapping, block_ids_per_seq=metadata.block_ids_per_seq, - workspace=metadata.workspace, + workspace=None, cache_indirection=metadata.cache_indirection, kv_scale_orig_quant=self.kv_scale_orig_quant, kv_scale_quant_orig=self.kv_scale_quant_orig, diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index aa1b250b3a..846386bb1f 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -371,7 +371,7 @@ class AutoTuner: if not is_cache_hit: logger.warning_once( f"[AutoTunner] Using the fallback tactic, due to cache miss on input shapes={input_shapes}", - key=(custom_op)) + key=custom_op) return (best_runner, best_tactic) diff --git a/tensorrt_llm/_torch/compilation/piecewise_optimizer.py b/tensorrt_llm/_torch/compilation/piecewise_optimizer.py index 32c37d5339..6139131e47 100644 --- a/tensorrt_llm/_torch/compilation/piecewise_optimizer.py +++ b/tensorrt_llm/_torch/compilation/piecewise_optimizer.py @@ -210,15 +210,9 @@ class PiecewiseRunner(object): runtime_input_addresses = [ i.data_ptr() for i in args if isinstance(i, torch.Tensor) ] - runtime_output_addresses = [ - i.data_ptr() for i in output if isinstance(i, torch.Tensor) - ] assert (entry.input_addresses == runtime_input_addresses ), f"{entry.input_addresses} vs\n {runtime_input_addresses}" - assert ( - entry.output_addresses == runtime_output_addresses - ), f"{entry.output_addresses} vs\n {runtime_output_addresses}" entry.cuda_graph.replay() diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index a4ce0092a0..f77d309805 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -494,7 +494,8 @@ class ModelConfig(Generic[TConfig]): architectures = self.pretrained_config.architectures if len(architectures ) == 1 and architectures[0] == "DeciLMForCausalLM": - mlp_hidden_size = self._infer_nemotron_ffn_mult() + mlp_hidden_size = self._infer_nemotron_ffn_mult( + ) // self.mapping.tp_size else: raise ValueError( f"Inferring mlp hidden size for model architecture: {architectures} isn't supported yet" diff --git a/tensorrt_llm/_torch/models/modeling_gemma3vl.py b/tensorrt_llm/_torch/models/modeling_gemma3vl.py index 37770d2f0d..15e93ad097 100644 --- a/tensorrt_llm/_torch/models/modeling_gemma3vl.py +++ b/tensorrt_llm/_torch/models/modeling_gemma3vl.py @@ -263,7 +263,9 @@ class Gemma3VLM(PreTrainedModel): embedding_layer=self.llm.model.embed_tokens, input_ids=input_ids, mm_embeds=mm_embeds, - mm_token_ids=self.image_token_ids) + mm_token_ids=self.image_token_ids, + **kwargs, + ) logits = self.llm.forward( attn_metadata=attn_metadata, input_ids=input_ids, @@ -284,3 +286,7 @@ class Gemma3VLM(PreTrainedModel): attn_metadata=attn_metadata)[-1] image_features = self.mm_projector(image_features) return image_features + + @property + def mm_token_ids(self): + return self.image_token_ids diff --git a/tensorrt_llm/_torch/models/modeling_hyperclovax.py b/tensorrt_llm/_torch/models/modeling_hyperclovax.py index a05784b9d8..975ccdb26e 100644 --- a/tensorrt_llm/_torch/models/modeling_hyperclovax.py +++ b/tensorrt_llm/_torch/models/modeling_hyperclovax.py @@ -1052,7 +1052,8 @@ class HCXVisionForCausalLM(PreTrainedModel): ] input_ids, input_embeds = fuse_input_embeds(self.llm.model.embed_tokens, - input_ids, mm_embeds) + input_ids, mm_embeds, + **kwargs) 
output_prob = self.llm.forward( attn_metadata=attn_metadata, input_ids=input_ids, diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index 1169feb0a6..5cb8f1b300 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -1280,7 +1280,8 @@ class Llama4ForConditionalGeneration(SpecDecOneEngineForCausalLM[Llama4Model, ] input_ids, inputs_embeds = fuse_input_embeds(self.model.embed_tokens, - input_ids, mm_embeds) + input_ids, mm_embeds, + **kwargs) return super().forward(attn_metadata, input_ids, position_ids, diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py index 9356076dc5..7e84fbde5c 100644 --- a/tensorrt_llm/_torch/models/modeling_llava_next.py +++ b/tensorrt_llm/_torch/models/modeling_llava_next.py @@ -302,7 +302,8 @@ class LlavaNextVisionModel(nn.Module): logger.warning_once( "Image feature shape does not line up with the provided patch size. " "You may be using the `default` vision_feature_select_strategy with a" - " visual encoder that does not have CLS.") + " visual encoder that does not have CLS.", + key="llava_next_vision_model_pack_image_features") image_feature = image_feature.view(num_patch_height, num_patch_width, height, @@ -474,7 +475,7 @@ class LlavaNextModel(PreTrainedModel): for multimodal_param in multimodal_params ] input_ids, inputs_embeds = fuse_input_embeds( - self.llm.model.embed_tokens, input_ids, mm_embeds) + self.llm.model.embed_tokens, input_ids, mm_embeds, **kwargs) logits = self.llm.forward(attn_metadata, input_ids, position_ids, inputs_embeds, return_context_logits) diff --git a/tensorrt_llm/_torch/models/modeling_mistral.py b/tensorrt_llm/_torch/models/modeling_mistral.py index 591edade2b..a079d9ee9e 100644 --- a/tensorrt_llm/_torch/models/modeling_mistral.py +++ b/tensorrt_llm/_torch/models/modeling_mistral.py @@ -408,6 +408,7 @@ class Mistral3VLM(PreTrainedModel): input_ids=input_ids, mm_embeds=mm_embeds, mm_token_ids=self._image_token_ids, + **kwargs, ) return self.llm.forward( @@ -501,6 +502,10 @@ class Mistral3VLM(PreTrainedModel): ] return torch.cat(pixel_values), batched_image_sizes + @property + def mm_token_ids(self): + return self._image_token_ids + # Original implementation: # https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/mistral3/modeling_mistral3.py#L66 diff --git a/tensorrt_llm/_torch/models/modeling_multimodal_utils.py b/tensorrt_llm/_torch/models/modeling_multimodal_utils.py index d6387f8190..3417cf988d 100644 --- a/tensorrt_llm/_torch/models/modeling_multimodal_utils.py +++ b/tensorrt_llm/_torch/models/modeling_multimodal_utils.py @@ -105,48 +105,84 @@ def find_uncached_mm_embeds( return sliced_mm_embeds +def filter_mm_token_from_input_ids( + input_ids: torch.IntTensor, + vocab_size: int, + mm_token_ids: Optional[torch.IntTensor] = None, +) -> Tuple[torch.IntTensor, torch.IntTensor]: + """ + Filter multimodal tokens from input_ids. + Args: + input_ids: shape [text_total_length + mm_total_length]. + vocab_size: size of the model's vocabulary + mm_token_ids: possible token ids for multimodal tokens, if known. If not known and set to None, it is assumed that the multimodal tokens are out-of-vocabulary tokens i.e. the `input_ids` contains tokens >= vocab_size that represent the multimodal tokens. + Note: + This function involves host-device synchronization due to torch.where() (= torch.nonzero) requiring + host allocation. 
The output indices reside on the same device as input_ids. + Returns: + text_token_indices: indices of text tokens in the input_ids + mm_token_indices: indices of multimodal tokens in the input_ids + """ + if mm_token_ids is None: + # NOTE: + # If mm_token_ids is None, it is assumed that the multimodal + # tokens are out-of-vocab tokens i.e. the `input_ids` contains + # tokens >= vocab_size that represent the multimodal tokens. + # Since mm_token_ids can be unbounded in this case, + # using torch.isin() may not be performant. + # This provides a more performant alternative while keeping + # the flexibility of still specifying all possible mm_token_ids, + # if the user wants to. + mm_token_mask = input_ids >= vocab_size + text_token_mask = input_ids < vocab_size + else: + mm_token_ids = mm_token_ids.to(input_ids.device, dtype=input_ids.dtype) + mm_token_mask = torch.isin(input_ids, mm_token_ids) + text_token_mask = ~mm_token_mask + # NOTE: torch.where() enforces a host sync + text_token_indices = torch.where(text_token_mask)[0] + mm_token_indices = torch.where(mm_token_mask)[0] + return text_token_indices, mm_token_indices + + def fuse_input_embeds( embedding_layer: Embedding, input_ids: torch.IntTensor, mm_embeds: List[torch.Tensor], mm_token_ids: Optional[torch.IntTensor] = None, + text_token_indices: Optional[torch.IntTensor] = None, + mm_token_indices: Optional[torch.IntTensor] = None, + **kwargs, ) -> Tuple[Optional[torch.FloatTensor], Optional[torch.FloatTensor]]: """ Fuse text and multimodal embeddings. input_ids is [text_total_length + mm_total_length] and mm_embed is [mm_total_length, hidden_dim]. We just need to fuse them into [text_total_length + mm_total_length, hidden_dim] by slice-and-assign to the corresponding entries. Args: + embedding_layer: embedding layer of the model. input_ids: shape [text_total_length + mm_total_length], flattened from List[(text_length1 + mm_total_length1), ..., (text_lengthi + mm_total_lengthi)]. For LLM model, the requests are inflight batched together, but the input_ids are flattened with padding removed. By the slice condition < vocab_size, we can easily separate text / multimodal tokens and naturally batched the LLM embedding lookup - mm_embed: List[(mm_total_length1, hidden_dim), ..., (mm_total_lengthi, hidden_dim)]. - mm_token_ids: possible token ids for multimodal tokens, if known. If not known and set to None, it is assumed that the multimodal tokens are out-of-vocabulary tokens i.e. the `input_ids` contains tokens >= vocab_size that represent the multimodal tokens. + mm_embeds: List[(mm_total_length1, hidden_dim), ..., (mm_total_lengthi, hidden_dim)]. + mm_token_ids: possible token ids for multimodal tokens, if known. If not known and set to None, it is assumed that the multimodal tokens are out-of-vocabulary tokens. Returns: - If (1) JIT test run, (2) non-multimodal run, i.e. all text-only requests, either context or generation phase (3) multimodal run, all requests in generation phase --> there is no multimodal data, return only the input_ids - If (4) multimodal run, mixed batch of context and generation requests, each context request has a multimodal feature --> return only the fused input_embeds of shape [total length, hidden_dim]. For text tokens, LLM embedding layer has already run. + Note: + - Precedence: If kwargs provide indices (text_token_indices and mm_token_indices), those are used. If any one of them is not provided, fallback to filtering method. 
Sentinel-/OOV-based filtering (e.g., tokens >= vocab_size) is used only when neither index tensor and mm_token_ids is provided. + - This function may involve host-device synchronization if indices are not provided and filtering is performed. See filter_mm_token_from_input_ids for details. """ if len(mm_embeds) == 0: return input_ids, None mm_embed = torch.cat(mm_embeds, dim=0) - if mm_token_ids is None: - # NOTE: - # If mm_token_ids is None, it is assumed that the multimodal - # tokens are out-of-vocab tokens i.e. the `input_ids` contains - # tokens >= vocab_size that represent the multimodal tokens. - # Since mm_token_ids is be unbounded in this case, - # using torch.isin() may not be performant. - # This provides a more performant alternative while keeping - # the flexibility of still specifying all possible mm_token_ids, - # if the user wants to. - vocab_size = embedding_layer.num_embeddings - mm_token_mask = input_ids >= vocab_size - text_token_mask = input_ids < vocab_size - else: - mm_token_ids = mm_token_ids.to(input_ids.device) - mm_token_mask = torch.isin(input_ids, mm_token_ids) - text_token_mask = ~mm_token_mask - text_token_indices = torch.where(text_token_mask)[0] - mm_token_indices = torch.where(mm_token_mask)[0] - if len(mm_token_indices) != mm_embed.shape[0]: + # TODO: support the case where only one index tensor is provided, the other is derived as the complement (try to avoid implicit host-device synchronization) + if text_token_indices is None or mm_token_indices is None: + # NOTE: This function involves host-device synchronization due to torch.where() used in filter_mm_token_from_input_ids. + text_token_indices, mm_token_indices = filter_mm_token_from_input_ids( + input_ids, + vocab_size=embedding_layer.num_embeddings, + mm_token_ids=mm_token_ids) + + if mm_token_indices.shape[0] != mm_embed.shape[0]: raise ValueError( f"Multimodal token count mismatch: found {len(mm_token_indices)} image tokens in input_ids " f"but received {mm_embed.shape[0]} image embeddings. 
" @@ -159,8 +195,7 @@ def fuse_input_embeds( device=text_embed.device, dtype=text_embed.dtype) - input_embeds[text_token_indices, :] = text_embed.to( - dtype=input_embeds.dtype, device=input_embeds.device) + input_embeds[text_token_indices, :] = text_embed input_embeds[mm_token_indices, :] = mm_embed.to(dtype=input_embeds.dtype, device=input_embeds.device) diff --git a/tensorrt_llm/_torch/models/modeling_phi4mm.py b/tensorrt_llm/_torch/models/modeling_phi4mm.py index bc449e1da5..38ee1eb110 100644 --- a/tensorrt_llm/_torch/models/modeling_phi4mm.py +++ b/tensorrt_llm/_torch/models/modeling_phi4mm.py @@ -594,6 +594,7 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel): input_ids, mm_embedding, mm_token_ids=self.mm_token_ids, + **kwargs, ) output_prob = self.llm.forward( diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py index 6d9493fafe..3df3eb75ec 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py +++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py @@ -612,7 +612,8 @@ class Qwen2VLModelBase(PreTrainedModel): 'mrope_position_deltas'] input_ids, input_embeds = fuse_input_embeds(self.llm.model.embed_tokens, - input_ids, mm_embeds) + input_ids, mm_embeds, + **kwargs) output_prob = self.llm.forward( attn_metadata=attn_metadata, input_ids=input_ids, diff --git a/tensorrt_llm/_torch/models/modeling_qwen3.py b/tensorrt_llm/_torch/models/modeling_qwen3.py index cbd2ebb983..8087ef30df 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen3.py +++ b/tensorrt_llm/_torch/models/modeling_qwen3.py @@ -48,6 +48,9 @@ class Qwen3Attention(QKNormRoPEAttention): rope=RopeParams.from_config(config), ) + # Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712) + # TODO: Consider adding disable_deep_gemm support to QKNormRoPEAttention if accuracy still remains + super().__init__( hidden_size=config.hidden_size, num_attention_heads=config.num_attention_heads, @@ -85,6 +88,7 @@ class Qwen3DecoderLayer(DecoderLayer): dtype=config.torch_dtype, config=model_config, ) + self.input_layernorm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype) diff --git a/tensorrt_llm/_torch/models/modeling_vila.py b/tensorrt_llm/_torch/models/modeling_vila.py index e69851cc2f..53f5969806 100644 --- a/tensorrt_llm/_torch/models/modeling_vila.py +++ b/tensorrt_llm/_torch/models/modeling_vila.py @@ -1186,7 +1186,7 @@ class VilaModel(PreTrainedModel): ] input_ids, inputs_embeds = fuse_input_embeds( - self.llm.model.embed_tokens, input_ids, mm_embeds) + self.llm.model.embed_tokens, input_ids, mm_embeds, **kwargs) logits = self.llm.forward(attn_metadata=attn_metadata, input_ids=input_ids, position_ids=position_ids, diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 20d87188de..d390dde474 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -216,6 +216,7 @@ class Attention(nn.Module): skip_create_weights_in_init=config.skip_create_weights_in_init, allreduce_strategy=config.allreduce_strategy, force_dynamic_quantization=config.force_dynamic_quantization) + self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE], [self.hidden_size]) diff --git a/tensorrt_llm/_torch/modules/embedding.py b/tensorrt_llm/_torch/modules/embedding.py index 7fee9b515d..3be7a652f6 100644 --- a/tensorrt_llm/_torch/modules/embedding.py +++ b/tensorrt_llm/_torch/modules/embedding.py @@ -138,25 +138,23 @@ def 
pre_comm_embedding_ops( padding_size: int, ): # Generate the mask for the input if needed. - if tp_size > 1: - if tp_mode == TensorParallelMode.COLUMN: - input_, input_mask = get_masked_input_and_mask( - input_, - vocab_start_index, - vocab_end_index, - ) + if tp_mode == TensorParallelMode.COLUMN: + input_, input_mask = get_masked_input_and_mask( + input_, + vocab_start_index, + vocab_end_index, + ) # Get the embeddings. output = F.embedding(input_, weight) # Mask or pad the output if needed. - if tp_size > 1: - if tp_mode == TensorParallelMode.COLUMN: - output.masked_fill_(input_mask, 0) - elif tp_mode == TensorParallelMode.ROW: - if gather_output: - if tp_rank == tp_size - 1 and padding_size > 0: - output = F.pad(output, (0, padding_size)) + if tp_mode == TensorParallelMode.COLUMN: + output.masked_fill_(input_mask, 0) + elif tp_mode == TensorParallelMode.ROW: + if gather_output: + if tp_rank == tp_size - 1 and padding_size > 0: + output = F.pad(output, (0, padding_size)) return output @@ -205,12 +203,16 @@ class Embedding(LMHead): self.vocab_end_index = num_embeddings def forward(self, input): - # Run the ops before all_reduce/all_gather. - # We use torch.compile() to fuse the tiny pointwise ops before all_reduce/all_gather for Embedding module. - embedding_ops_func = torch.compile( - pre_comm_embedding_ops, - options={"max-autotune": True}, - disable=not self.enable_torch_compile_for_embedding) + if self.tp_size > 1: + # Run the ops before all_reduce/all_gather. + # We use torch.compile() to fuse the tiny pointwise ops before all_reduce/all_gather for Embedding module. + embedding_ops_func = torch.compile( + pre_comm_embedding_ops, + options={"max-autotune": True}, + disable=not self.enable_torch_compile_for_embedding) + else: + # Skip torch.compile when TP size is 1 to avoid unnecessary host overhead + embedding_ops_func = pre_comm_embedding_ops output = embedding_ops_func(input, self.weight, self.tp_size, self.tp_rank, self.tp_mode, self.vocab_start_index, diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py index a74d8f2e73..4e18ae8c24 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py @@ -3,6 +3,8 @@ from typing import Dict, List, Optional, Union import torch from torch import nn +from tensorrt_llm._utils import get_sm_version + from ...model_config import ModelConfig from ...utils import Fp4QuantizedTensor, next_positive_power_of_2 from .interface import MoE, MoEWeightLoadingMode @@ -78,6 +80,11 @@ class TRTLLMGenFusedMoE(MoE): swiglu_limit=swiglu_limit, ) + sm_version = get_sm_version() + if sm_version >= 120: + raise NotImplementedError( + "TRTLLMGenFusedMoE does not support SM120 and above.") + assert not self.smart_router, "Smart router is not supported in TRTLLMGenFusedMoE." 
self.num_slots = self.num_experts diff --git a/tensorrt_llm/_torch/modules/gated_mlp.py b/tensorrt_llm/_torch/modules/gated_mlp.py index f177c41885..cf381ea2c2 100644 --- a/tensorrt_llm/_torch/modules/gated_mlp.py +++ b/tensorrt_llm/_torch/modules/gated_mlp.py @@ -5,6 +5,7 @@ import torch import torch.nn.functional as F from torch import nn +from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping from ..distributed import AllReduceParams @@ -29,6 +30,7 @@ class GatedMLP(nn.Module): reduce_output: bool = True, layer_idx: Optional[int] = None, use_cute_dsl_blockscaling_mm: bool = False): + super().__init__() self.layer_idx = layer_idx self.hidden_size = hidden_size @@ -98,12 +100,21 @@ class GatedMLP(nn.Module): [LoraModuleType.MLP_GATE_UP], [2 * self.intermediate_size // mapping.tp_size]) - def _apply_activation(self, x): + def _apply_activation(self, x, *, has_lora: bool = False): if self.activation == F.silu: if self.down_proj.has_fp8_qdq or self.down_proj.has_w4a8_nvfp4_fp8: - return swiglu(x, - quant_scale=self.down_proj.input_scale, - quant_type=torch.float8_e4m3fn) + if has_lora: + # NOTE: This is a WAR, since LoRA grouped_gemm does not support FP8 yet. + # TODO: Remove this path when LoRA grouped_gemm supports FP8 + # see: cpp/tensorrt_llm/thop/loraOp.cpp::lora_grouped_gemm + logger.warning( + f"GatedMLP._apply_activation: LoRA path active; forcing non-FP8 activation dtype bf16/fp16, layer_idx={self.layer_idx}" + ) + return swiglu(x) + else: + return swiglu(x, + quant_scale=self.down_proj.input_scale, + quant_type=torch.float8_e4m3fn) else: return swiglu(x) elif callable(self.activation): @@ -155,7 +166,7 @@ class GatedMLP(nn.Module): if h1_lora is not None: h1 = h1 + h1_lora - h2 = self._apply_activation(h1) + h2 = self._apply_activation(h1, has_lora=True) output = self.down_proj(h2, all_reduce_params=final_all_reduce_params, lora_params=lora_params, diff --git a/tensorrt_llm/_torch/pyexecutor/config.py b/tensorrt_llm/_torch/pyexecutor/config.py index 3be1e5558f..7f46c521b6 100644 --- a/tensorrt_llm/_torch/pyexecutor/config.py +++ b/tensorrt_llm/_torch/pyexecutor/config.py @@ -5,7 +5,6 @@ from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import \ BaseCheckpointLoader from tensorrt_llm.bindings.executor import ExecutorConfig -from ...builder import BuildConfig from ...llmapi.llm_args import LoadFormat, SamplerType from ...logger import logger from ...mapping import Mapping @@ -119,7 +118,6 @@ EXETENDED_EXECUTOR_CONFIG_FIELDS = [ 'backend', 'pytorch_backend_config', 'max_seq_len', - 'tokens_per_block', 'mapping', 'hf_model_dir', 'mm_encoder_only', @@ -131,7 +129,6 @@ def update_executor_config( backend: Optional[str] = None, pytorch_backend_config: Optional[PyTorchConfig] = None, mapping: Optional[Mapping] = None, - build_config: Optional[BuildConfig] = None, speculative_config: Optional["DecodingBaseConfig"] = None, hf_model_dir: Optional[str] = None, max_input_len: Optional[int] = None, @@ -156,10 +153,6 @@ def update_executor_config( logger.info(f"{executor_config.pytorch_backend_config}") - build_config = build_config or BuildConfig() - # TODO: move to pure-Python KvCacheConfig, and remove dependency on build_config. 
- executor_config.tokens_per_block = executor_config.tokens_per_block or build_config.plugin_config.tokens_per_block - executor_config.hf_model_dir = hf_model_dir if max_input_len is not None: diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 2bddb6fd58..c6915fbd66 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -1,6 +1,8 @@ +from collections.abc import Generator from copy import deepcopy from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union +from itertools import pairwise +from typing import Any, Dict, List, Optional, TypeAlias, Union import torch @@ -424,7 +426,10 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest): self.child_requests.append(py_request) -def convert_wordlist(word_list) -> List[List[int]]: +StopWordList: TypeAlias = list[list[int]] + + +def convert_wordlist(word_list) -> StopWordList: """Converts a wordlist from format: [[word_0 token_0, word_0 token_1, ...], [word_1 token_0, ...], ...]] @@ -461,6 +466,16 @@ def convert_wordlist(word_list) -> List[List[int]]: return [tokens, offsets] +def produce_stop_words( + py_stop_words_list: StopWordList) -> Generator[list[int], None, None]: + """yield stop sequences from the output of `convert_wordlist` above.""" + stop_words_list, prefix_sum = py_stop_words_list + for start, end in pairwise((0, *prefix_sum)): # first element: prepend 0 + if end == -1: # -1 is a sentinel value in convert_wordlist + break + yield stop_words_list[start:end] + + def executor_request_to_llm_request( req_id: int, executor_request: ExecutorRequest, diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index e665af55d5..2829fcb18f 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -43,6 +43,7 @@ from ..metadata import KVCacheParams from ..model_config import ModelConfig, MoeLoadBalancerConfig from ..models import AutoModelForCausalLM from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader +from ..models.modeling_multimodal_utils import filter_mm_token_from_input_ids from ..models.modeling_utils import (DecoderModelForCausalLM, MetaInitMode, timing) from ..modules.fused_moe.moe_load_balancer import ( @@ -1259,6 +1260,16 @@ class PyTorchModelEngine(ModelEngine): return total_num_tokens, False, attn_all_rank_num_tokens + def _prepare_multimodal_indices(self, input_ids: list[int]): + input_ids = torch.tensor(input_ids, dtype=torch.int, device="cpu") + vocab_size = self.model.config.vocab_size + # TODO: unify naming of mm_token_ids across models + mm_token_ids = getattr(self.model, "mm_token_ids", None) + + text_token_indices, mm_token_indices = filter_mm_token_from_input_ids( + input_ids, vocab_size=vocab_size, mm_token_ids=mm_token_ids) + return text_token_indices, mm_token_indices + def _prepare_tp_inputs( self, scheduled_requests: ScheduledRequests, @@ -1335,6 +1346,14 @@ class PyTorchModelEngine(ModelEngine): request.py_batch_idx = request.py_seq_slot + if len(multimodal_params_list) > 0: + # discard the text token indices as it only includes context tokens at this moment + print( + f"len multimodal_params_list: {len(multimodal_params_list)} from model_engine" + ) + _, mm_token_indices = self._prepare_multimodal_indices(input_ids) + else: + mm_token_indices = None num_ctx_requests = len(scheduled_requests.context_requests) 
num_ctx_tokens = len(input_ids) @@ -1698,6 +1717,14 @@ class PyTorchModelEngine(ModelEngine): spec_metadata.all_rank_num_tokens = spec_all_rank_num_tokens spec_metadata.all_rank_num_seqs = all_rank_num_seqs + if mm_token_indices is not None: + mask = torch.ones(total_num_tokens, dtype=torch.bool) + mask[mm_token_indices] = False + inputs['mm_token_indices'] = mm_token_indices.pin_memory().to( + "cuda", non_blocking=True) + inputs['text_token_indices'] = torch.where(mask)[0].pin_memory().to( + "cuda", non_blocking=True) + num_generation_tokens = len(generation_requests) + len( extend_requests) + sum(draft_lens) self.iter_states['num_ctx_requests'] = num_ctx_requests diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 4070f210e1..8d7f2f3234 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -61,6 +61,9 @@ PROFILE_RECORD_GC_ENV_VAR_NAME = "TLLM_PROFILE_RECORD_GC" # Set to a path to save detailed tracing of PyTorch operations. PROFILE_TRACE_ENV_VAR_NAME = "TLLM_TORCH_PROFILE_TRACE" +# Unique tag base to avoid collisions with token/logits comms +TERMINATION_COMM_TAG_BASE = 20000 + @functools.cache def _load_iteration_indexes(env_var: str): @@ -208,6 +211,7 @@ class PyExecutor: self.kv_cache_manager = self.resource_manager.resource_managers.get( ResourceManagerType.KV_CACHE_MANAGER) self.enable_kv_cache_events = self.kv_cache_manager is not None and self.kv_cache_manager.event_buffer_max_size > 0 + self.enable_kv_cache_reuse = self.kv_cache_manager is not None and self.kv_cache_manager.enable_block_reuse self.max_input_len = max_input_len # _executor_loop private data @@ -259,6 +263,13 @@ class PyExecutor: self.gather_all_responses = False self.kv_cache_transceiver = kv_cache_transceiver + + # Initialize disagg PP termination handler if needed + self._disagg_pp_termination_handler = None + if self.dist.pp_size > 1 and self.enable_kv_cache_reuse and self.kv_cache_transceiver: + self._disagg_pp_termination_handler = DisaggPPTerminationHandler( + self.num_micro_batches, self.dist) + if self.dist.pp_size > 1: self.event_loop = self._executor_loop_pp else: @@ -718,6 +729,14 @@ class PyExecutor: batch_state.sample_state.scheduled_requests), req_stats) def _executor_loop_cleanup(self): + + for h in self.send_handles: + if h is not None: + h.wait() + + if self._disagg_pp_termination_handler is not None: + self._disagg_pp_termination_handler.cleanup() + with self.response_cv: self.is_shutdown = True self.response_cv.notify_all() @@ -826,6 +845,7 @@ class PyExecutor: sample_state = self._sample_async( scheduled_batch, batch_outputs) + assert sample_state is not None, "Sampling failed" self._update_request_states(scheduled_batch) if self.enable_iter_perf_stats: @@ -905,6 +925,12 @@ class PyExecutor: if self.kv_cache_transceiver and self.ctx_in_transmission_requests: self._terminate_ctx_finished_requests() + if self._disagg_pp_termination_handler is not None: + requests_to_terminate = self._disagg_pp_termination_handler.sync( + prev_microbatch_id) + for req in requests_to_terminate: + self._do_terminate_request(req) + # march forward in microbatch slots microbatch_id = (microbatch_id + 1) % self.num_micro_batches @@ -1696,9 +1722,13 @@ class PyExecutor: self._enqueue_responses(error_responses.items()) def _terminate_request(self, request: LlmRequest): - if self.kv_connector_manager is None: - self.resource_manager.free_resources(request) + if self._disagg_pp_termination_handler is not 
None: + self._disagg_pp_termination_handler.terminate(request) else: + self._do_terminate_request(request) + + def _do_terminate_request(self, request: LlmRequest): + if self.kv_connector_manager is not None: # Only call request_finished on the connector if the request has already been added to the kv cache manager. try: cache_block_ids = self.kv_cache_manager.get_cache_indices( @@ -1711,6 +1741,8 @@ class PyExecutor: if not self.kv_connector_manager.request_finished( request, cache_block_ids): self.resource_manager.free_resources(request) + else: + self.resource_manager.free_resources(request) @nvtx_range("_handle_canceled_requests") def _handle_canceled_requests(self): @@ -1919,3 +1951,104 @@ class PyExecutor: """Remove reqids of current requests from self.inflight_req_ids.""" for req in scheduled_requests.all_requests(): self.inflight_req_ids.erase(req.request_id) + + +class DisaggPPTerminationHandler: + """Handles termination synchronization across pipeline parallel ranks under disaggregated serving. + + We require synchronization when terminating requests in disaggregated PP when + KV cache reuse is enabled. All PP ranks need to reach consensus before freeing + resources to avoid a NCCL hang. + """ + + def __init__(self, num_micro_batches: int, dist): + self.dist = dist + # Request termination synchronization across PP ranks + # {request_id: {'ready_to_terminate': set{ranks}, 'terminated': {ranks}}} + self.pending_termination = {} + self.termination_handles = [None] * num_micro_batches + # Local map from request_id -> local LlmRequest awaiting consensus termination + self.local_termination = {} + + def terminate(self, request: LlmRequest) -> bool: + req_key = request.py_request_id + self.local_termination[req_key] = request + state = self.pending_termination.get(req_key, None) + if state is None: + state = {'ready_to_terminate': set(), 'terminated': set()} + self.pending_termination[req_key] = state + if self.dist.rank not in state['ready_to_terminate']: + state['ready_to_terminate'].add(self.dist.rank) + return False + + def sync(self, microbatch_id: int) -> List[LlmRequest]: + """Ring-communicate pending termination state and apply local terminations upon consensus. + + Each rank sends its current pending_termination snapshot to the next PP rank + and receives the previous rank's snapshot. After merging, apply any terminations + that have reached consensus (i.e., all PP ranks are ready). 
+ """ + snapshot = { + req_id: { + 'ready_to_terminate': state.get('ready_to_terminate', set()), + 'terminated': state.get('terminated', set()), + } + for req_id, state in self.pending_termination.items() + } + + if self.termination_handles[microbatch_id] is not None: + self.termination_handles[microbatch_id].wait() + + term_tag = TERMINATION_COMM_TAG_BASE + microbatch_id + self.termination_handles[microbatch_id] = self.dist.isend_object( + snapshot, + dest=self.dist.next_pp_rank, + tag=term_tag, + ) + remote_state = self.dist.recv_object( + src=self.dist.prev_pp_rank, + tag=term_tag, + ) + logger.debug( + f"received remote state for microbatch {microbatch_id}, prev pp rank: {self.dist.prev_pp_rank} state {remote_state}" + ) + + if remote_state: + for req_id, state in remote_state.items(): + local = self.pending_termination.get(req_id) + if local is None: + self.pending_termination[req_id] = { + 'ready_to_terminate': state.get('ready_to_terminate', + set()), + 'terminated': state.get('terminated', set()), + } + else: + for key in ('ready_to_terminate', 'terminated'): + for r in state.get(key, []): + if r not in local[key]: + local[key].add(r) + + requests_to_terminate = [] + to_delete = [] + for req_id, state in self.pending_termination.items(): + ready = state.get('ready_to_terminate', set()) + done = state.get('terminated', set()) + # If all PP ranks are ready to terminate the request, we can free the resources + if len(ready) >= self.dist.pp_size and self.dist.rank not in done: + local_req = self.local_termination.get(req_id) + if local_req is not None: + requests_to_terminate.append(local_req) + done.add(self.dist.rank) + if len(done) >= self.dist.pp_size: + to_delete.append(req_id) + if req_id in self.local_termination: + self.local_termination.pop(req_id, None) + for req_id in to_delete: + self.pending_termination.pop(req_id, None) + + return requests_to_terminate + + def cleanup(self): + for h in self.termination_handles: + if h is not None: + h.wait() diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index e6d19a9df4..9a4d933e08 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -1,12 +1,16 @@ from abc import ABC, abstractmethod from collections.abc import Iterable from dataclasses import dataclass -from typing import List, Literal, Optional +from functools import cached_property +from typing import List, Literal, Optional, TypeAlias +import numpy as np import torch from tensorrt_llm._torch.pyexecutor.make_decoding_batch_input_output import \ MakeDecodingBatchInputOutput +from tensorrt_llm._torch.pyexecutor.sampler_utils import ( + BEAM_0, SINGLE_BEAM_WIDTH, handle_stop_single_beam) from tensorrt_llm._utils import nvtx_range, torch_dtype_to_binding from tensorrt_llm.bindings import (CudaStream, DataType, ModelConfig, WorldConfig, make_sampling_config) @@ -355,21 +359,55 @@ def int_tensor(shape: tuple[int, ...], device: str = 'cuda') -> torch.Tensor: return torch.empty(shape, dtype=torch.int, device=device) +class TorchStore: + + def __init__(self, *, max_draft_len: int, max_num_sequences: int, + max_beam_width: int): + self.max_draft_len = max_draft_len + self.max_num_sequences = max_num_sequences + self.max_beam_width = max_beam_width + self.max_tokens = max_draft_len + 1 + assert max_beam_width == SINGLE_BEAM_WIDTH, "TorchSampler only supports beam_width = 1" + self.new_tokens = int_tensor( + (self.max_tokens, max_num_sequences, max_beam_width)) + """Shape: See cpp 
DecoderState.getAllNewTokens()""" + self.finish_reasons = int_tensor(self.new_tokens.shape) + + # Helper tensors for finish_reasons: + self._finish_reasons_nonzero_static_buffer = torch.empty( + (self.max_tokens * max_num_sequences, 2), + device='cuda', + dtype=torch.int64) + """Preallocate buffer needed for torch.nonzero_static(..., out=finish_reasons_nonzero_static_buffer), see `def _write_reason`""" + self._reason_tensors = { + reason: + torch.tensor(reason.value, + dtype=self.finish_reasons.dtype, + device="cuda") + for reason in [ + FinishReason.NOT_FINISHED, FinishReason.END_ID, + FinishReason.STOP_WORDS, FinishReason.LENGTH, + FinishReason.TIMED_OUT, FinishReason.CANCELLED + ] # `in FinishReason` clashes with PyBind11: `TypeError: 'pybind11_type' object is not iterable` + } + + +@dataclass(kw_only=True) +class SampleStateTensorsHostTorch(SampleStateTensors): + finish_reasons: torch.Tensor + + +@dataclass(kw_only=True) +class SampleStateTorch(SampleState): + host: SampleStateTensorsHostTorch + + class TorchSampler(Sampler): - BEAM = 0 - MAX_BEAM_WIDTH = BEAM + 1 + SampleState = SampleStateTorch def is_generation_model(self) -> bool: return True - @dataclass(frozen=True, kw_only=True) - class Store: - new_tokens: torch.Tensor - """Shape: See cpp DecoderState.getAllNewTokens()""" - - def create_store(self) -> Store: - return self.Store(new_tokens=int_tensor(self.NEW_TOKENS_SHAPE)) - @dataclass(frozen=True, kw_only=True) class Args: max_seq_len: int @@ -381,17 +419,16 @@ class TorchSampler(Sampler): def __init__(self, args: Args): self.max_seq_len = args.max_seq_len self.enable_mixed_sampler = args.enable_mixed_sampler - self.max_tokens = args.max_draft_len + 1 - assert args.max_beam_width == self.MAX_BEAM_WIDTH, "TorchSampler only supports beam_width = 1" - self.max_num_sequences = args.max_num_sequences - self.NEW_TOKENS_SHAPE = (self.max_tokens, self.max_num_sequences, - self.MAX_BEAM_WIDTH) # AutoDeploy build creates the sampler in inference mode, # which would disallow in-place mutating of new_tokens. # So, we temporarily exit inference mode. with torch.inference_mode(False): - self.store = self.create_store() + self.store = TorchStore(max_draft_len=args.max_draft_len, + max_num_sequences=args.max_num_sequences, + max_beam_width=args.max_beam_width) + self.max_num_sequences = args.max_num_sequences + self.max_tokens = self.store.max_tokens # Initialize seed for multi-GPU consistency self._global_seed = 42 @@ -412,50 +449,7 @@ class TorchSampler(Sampler): self._generator.manual_seed(self._global_seed) return self._generator - def _meet_max_token_stop_criteria(self, request: LlmRequest): - num_tokens = request.get_num_tokens(self.BEAM) - return (num_tokens - request.py_orig_prompt_len - >= request.py_max_new_tokens) or (num_tokens - >= self.max_seq_len) - - @staticmethod - def _meet_stop_token_criteria(request: LlmRequest): - if request.py_stop_words_list: - assert isinstance( - request.py_stop_words_list, - list), "request.py_stop_words_list should be a list" - stop_words_list, prefix_sum = request.py_stop_words_list - tokens = request.get_tokens(0) - offset = 0 - for i, offset_end in enumerate(prefix_sum): - if i > 0: - offset = prefix_sum[i - 1] - stop_word = stop_words_list[offset:offset_end] - if len(stop_word) > len(tokens): - continue - if tokens[-len(stop_word):] == stop_word: - return True - return False - - def _handle_stop_criteria(self, request: LlmRequest, - new_token: int) -> bool: - """Handle stop criteria and set appropriate finish reasons and state. 
- Returns True if generation should stop.""" - if new_token == request.py_end_id: - request.finish_by(FinishReason.END_ID, self.BEAM) - return True - - if self._meet_max_token_stop_criteria(request): - request.finish_by(FinishReason.LENGTH, self.BEAM) - return True - - if self._meet_stop_token_criteria(request): - request.finish_by(FinishReason.STOP_WORDS, self.BEAM) - return True - - return False - - def handle_logprobs(self, request: LlmRequest, state: SampleState, *, + def handle_logprobs(self, request: LlmRequest, state: SampleStateTorch, *, beam: int, count: int): current_slice = slice(0, count), request.py_seq_slot, beam if request.py_return_log_probs: @@ -469,10 +463,26 @@ class TorchSampler(Sampler): assert beam == 0, "The following call relies on beam_width to be 1 - hence the list with a single element" request.py_result.append_log_probs([token_log_probs]) - def _process_draft_tokens_greedy(self, request: LlmRequest, - new_tokens: torch.Tensor) -> int: - new_token = add_token(request, new_tokens, beam=self.BEAM) - stop = self._handle_stop_criteria(request, new_token) + FinishReasons: TypeAlias = list[list[int]] + """`(num_seq_slots, num_steps)`""" + + @classmethod + def finish_if_reason(cls, request: LlmRequest, + finish_reasons: FinishReasons, *, step: int) -> bool: + reason = FinishReason(finish_reasons[request.py_seq_slot][step]) + valid_reasons = { + FinishReason.END_ID, FinishReason.LENGTH, FinishReason.STOP_WORDS + } + if reason in valid_reasons: + request.finish_by(reason, BEAM_0) + return True + return False + + def _process_draft_tokens_greedy(self, request: LlmRequest, *, + new_tokens: torch.Tensor, + finish_reasons: FinishReasons) -> int: + new_token = add_token(request, new_tokens, beam=BEAM_0) + stop = self.finish_if_reason(request, finish_reasons, step=0) if stop or get_draft_token_length(request) == 0: return 0 num_accepted = 0 @@ -485,14 +495,17 @@ class TorchSampler(Sampler): num_accepted += 1 new_token = add_token(request, new_tokens, - beam=self.BEAM, + beam=BEAM_0, step=num_accepted) - if self._handle_stop_criteria(request, new_token): + if self.finish_if_reason(request, finish_reasons, + step=num_accepted): break return num_accepted def _process_draft_tokens_rejection_sampling( self, request: LlmRequest, new_tokens: torch.Tensor) -> int: + """We cannot use finish_if_reason in _process_draft_tokens_rejection_sampling because it *writes to new_tokens*, + rendering the finish reason calculation in sample_async stale (incorrect) for this batch""" sampling_strategy = request_strategy(request) generator = self.get_generator(request.py_draft_logits.device) _, draft_probs = sample(sampling_strategy, @@ -503,7 +516,6 @@ class TorchSampler(Sampler): generator, request.py_draft_tokens) sample_last = True - stop = False if rejected_indices.numel() == 0: num_initially_accepted = get_draft_token_length(request) sample_last = False @@ -512,59 +524,62 @@ class TorchSampler(Sampler): num_accepted = num_initially_accepted for i in range(num_accepted): new_token = request.py_draft_tokens[i] - new_tokens[i, request.seq_slot, self.BEAM] = new_token - request.add_new_token(new_token, self.BEAM) - stop = self._handle_stop_criteria(request, new_token) - if stop: + new_tokens[i, request.seq_slot, BEAM_0] = new_token + request.add_new_token(new_token, BEAM_0) + if handle_stop_single_beam(request, + new_token, + max_seq_len=self.max_seq_len): num_accepted = i + 1 return num_accepted if sample_last: new_token = sample_rejected(draft_probs, target_probs, generator, num_accepted) - 
new_tokens[num_accepted, request.seq_slot, self.BEAM] = new_token - request.add_new_token(new_token, self.BEAM) - stop = self._handle_stop_criteria(request, new_token) + new_tokens[num_accepted, request.seq_slot, BEAM_0] = new_token + request.add_new_token(new_token, BEAM_0) + handle_stop_single_beam(request, + new_token, + max_seq_len=self.max_seq_len) else: new_token = add_token(request, new_tokens, - beam=self.BEAM, + beam=BEAM_0, step=num_accepted) - stop = self._handle_stop_criteria(request, new_token) + handle_stop_single_beam(request, + new_token, + max_seq_len=self.max_seq_len) return num_accepted - def process_draft_tokens(self, request: LlmRequest, - new_tokens: torch.Tensor) -> int: - if request.py_draft_logits is None: - return self._process_draft_tokens_greedy(request, new_tokens) - else: - return self._process_draft_tokens_rejection_sampling( - request, new_tokens) - - def update_requests(self, state: SampleState) -> None: - assert isinstance(state, SampleState) + def update_requests(self, state: SampleStateTorch) -> None: + assert isinstance(state, SampleStateTorch) if state.sampler_event: state.sampler_event.synchronize() new_tokens = state.host.new_tokens + finish_reasons = state.host.finish_reasons[:, :, BEAM_0].T.tolist() for req in state.scheduled_requests.context_requests: if req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0: continue - new_token = add_token(req, new_tokens, beam=self.BEAM) - self._handle_stop_criteria(req, new_token) - self.handle_logprobs(req, state, beam=self.BEAM, count=1) + add_token(req, new_tokens, beam=BEAM_0) + self.finish_if_reason(req, finish_reasons, step=0) + self.handle_logprobs(req, state, beam=BEAM_0, count=1) req.py_decoding_iter += 1 for req in state.scheduled_requests.generation_requests: if req.state == LlmRequestState.GENERATION_COMPLETE: continue processed = 1 - num_accepted = self.process_draft_tokens(req, new_tokens) + if req.py_draft_logits is None: + num_accepted = self._process_draft_tokens_greedy( + req, new_tokens=new_tokens, finish_reasons=finish_reasons) + else: + num_accepted = self._process_draft_tokens_rejection_sampling( + req, new_tokens) if get_draft_token_length(req) > 0: req.py_num_accepted_draft_tokens = num_accepted req.py_rewind_len = req.py_draft_pages_allocated - num_accepted processed += num_accepted - self.handle_logprobs(req, state, beam=self.BEAM, count=processed) + self.handle_logprobs(req, state, beam=BEAM_0, count=processed) req.py_decoding_iter += 1 def log_probs_host(self, scheduled_requests: ScheduledRequests): @@ -572,29 +587,50 @@ class TorchSampler(Sampler): if any(req.py_return_log_probs for req in scheduled_requests.all_requests()): return torch.empty( - (self.max_num_sequences, self.MAX_BEAM_WIDTH, self.max_tokens), + (self.max_num_sequences, SINGLE_BEAM_WIDTH, self.max_tokens), device="cpu", pin_memory=True) return None - def sample_async(self, scheduled_requests: ScheduledRequests, - model_outputs: dict[str, torch.Tensor], - num_context_logits_prefix_sum: list[int]) -> SampleState: + def sample_async( + self, scheduled_requests: ScheduledRequests, + model_outputs: dict[str, torch.Tensor], + num_context_logits_prefix_sum: list[int]) -> SampleStateTorch: + requests = scheduled_requests.all_requests() new_tokens = self.store.new_tokens + finish_reasons = self.store.finish_reasons log_probs_host = self.log_probs_host(scheduled_requests) + seq_slots_host = torch.tensor( + [r.py_seq_slot for r in requests], + dtype=torch.int64, # for index_fill_ + 
pin_memory=True) + seq_slots = seq_slots_host.to(device="cuda", non_blocking=True) self._process_requests(scheduled_requests, model_outputs, new_tokens, num_context_logits_prefix_sum, + seq_slots=seq_slots, + seq_slots_host=seq_slots_host, log_probs_host=log_probs_host) + self._write_finish_reasons(requests, + finish_reasons=finish_reasons, + seq_slots=seq_slots, + new_tokens=new_tokens) + new_tokens_host = new_tokens.to(device="cpu", non_blocking=True) + finish_reasons_host = finish_reasons.to(device="cpu", non_blocking=True) sampler_event = torch.cuda.Event() sampler_event.record() - return SampleState(scheduled_requests=scheduled_requests, - device=SampleStateTensors(new_tokens=new_tokens), - host=SampleStateTensors(new_tokens=new_tokens_host, - log_probs=log_probs_host), - sampler_event=sampler_event) + return SampleStateTorch( + scheduled_requests=scheduled_requests, + device=SampleStateTensors(new_tokens=new_tokens), + host=SampleStateTensorsHostTorch( + new_tokens=new_tokens_host, + log_probs=log_probs_host, + finish_reasons=finish_reasons_host, + ), + sampler_event=sampler_event, + ) @staticmethod def append_eagle3(tokens: torch.Tensor, model_outputs): @@ -654,15 +690,202 @@ class TorchSampler(Sampler): return logits + @staticmethod + def _longest_stop_word_len(requests: Iterable[LlmRequest]) -> int: + max_stop_word_len = 0 + for req in requests: + _, cumsum = req.py_stop_words_list + if -1 in cumsum: + cumsum = cumsum[:cumsum.index(-1)] + request_max_stop_word_len = np.max(np.diff(cumsum, prepend=0), + initial=0) + max_stop_word_len = max(max_stop_word_len, + request_max_stop_word_len) + return max_stop_word_len + + @staticmethod + def _requests_with_stop_words( + requests: list[LlmRequest]) -> list[LlmRequest]: + return [ + r for r in requests if (r.py_stop_words_list is not None + and len(r.py_stop_words_list[0]) > 0) + ] + + def _write_reason(self, finish_reasons: torch.Tensor, reason: FinishReason, + *, where: torch.Tensor, seq_slots: torch.Tensor) -> None: + """Avoid GPU<->CPU syncs via: + ### `nonzero_static` [REF-A], see: https://ianbarber.blog/2024/12/18/nonzero_static-in-pytorch/. + - `nonzero` syncs (frontend needs result size). + - `nonzero_static` pads with dummy entries (`fill_value`), written into a prealloc buffer (max_num_sequences, 2). + - Need to drop padding, but `buffer[buffer!=fill_value]`, `buffer[:count_nonzero]`, `buffer[:sum]` all sync. + + ### Hack: + 1. Use `fill_value=0`, so padding is `[..., [0,0], [0,0]]`. + 2. Write blindly to `finish_reasons` [REF-B]. Only `[seq_slot[0],0]` might have wrong values written to it, because of the padding entries. + 3. Save `[seq_slot[0],0]` in `before_write` [REF-C], restore if `where[0][0]` is `False` [REF-D]. 
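A toy CPU-only illustration of the padding trick described above (an editorial sketch, not part of this patch; it assumes a PyTorch build where ``torch.nonzero_static`` is available, 2.2+, and uses made-up shapes):

    # Editorial sketch of the fill_value=0 padding trick; shapes are made up.
    import torch

    where = torch.tensor([[False, True], [False, False]])  # (max_tokens, num_requests)
    seq_slots = torch.tensor([3, 1])                        # request index -> sequence slot
    finish_reasons = torch.zeros(2, 4, dtype=torch.int)     # (max_tokens, max_num_sequences)
    REASON = 5

    before_write = finish_reasons[0, seq_slots[0]].clone()  # save the only cell padding can touch
    buffer = torch.empty(where.numel(), 2, dtype=torch.int64)
    torch.nonzero_static(where, size=buffer.shape[0], fill_value=0, out=buffer)
    r, c = buffer[:, 0], buffer[:, 1]
    finish_reasons[r, seq_slots[c]] = REASON                # blind write; padding rows hit [0, seq_slots[0]]
    restored = torch.where(~where[0, 0], before_write,
                           torch.tensor(REASON, dtype=finish_reasons.dtype))
    finish_reasons[0, seq_slots[0]] = restored              # undo the padding write if it was spurious
    print(finish_reasons)                                   # only step 0, slot 1 carries REASON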
+ """ + assert seq_slots.is_cuda and where.is_cuda + assert seq_slots.shape[0] == where.shape[1] + first_slot = seq_slots[0].unsqueeze(0) + before_write = finish_reasons[0][:].index_select( + 0, first_slot).squeeze() # REF-C + reason_tensor = self.store._reason_tensors[reason] + buffer = self.store._finish_reasons_nonzero_static_buffer + size = buffer.shape[0] + torch.nonzero_static(where, size=size, fill_value=0, + out=buffer) # REF-A + r, c = buffer[:, 0], buffer[:, 1] + finish_reasons[r, seq_slots[c], BEAM_0] = reason_tensor # REF-B + + correct = torch.where(~where[0, 0], before_write, reason_tensor).view(1) + assert correct.is_cuda + finish_reasons[0, first_slot, BEAM_0] = correct # REF-D + + def _write_finish_reasons(self, requests: list[LlmRequest], *, + finish_reasons: torch.Tensor, + seq_slots: torch.Tensor, + new_tokens: torch.Tensor) -> None: + """later _write_reason overwrites earlier, in reverse precedence order""" + tokens = new_tokens[:, seq_slots, BEAM_0] + # we need to fill with NOT_FINISHED so we can differentiate between previous requests that had the same seq slot + finish_reasons.index_fill_(1, seq_slots, + FinishReason.NOT_FINISHED.value) + + if with_stop_words := self._requests_with_stop_words(requests): + stop_seq_slots = torch.tensor( + [r.py_seq_slot for r in with_stop_words], + pin_memory=True).to("cuda", non_blocking=True) + stop_tokens = new_tokens[:, stop_seq_slots, BEAM_0] + self._write_reason( + finish_reasons, + FinishReason.STOP_WORDS, + where=self._are_stop_words(with_stop_words, stop_tokens), + seq_slots=stop_seq_slots, + ) + + self._write_reason( + finish_reasons, + FinishReason.LENGTH, + where=self._are_max_length(requests), + seq_slots=seq_slots, + ) + + self._write_reason( + finish_reasons, + FinishReason.END_ID, + where=self._are_end_id(requests, tokens), + seq_slots=seq_slots, + ) + + def _are_end_id(self, requests: list[LlmRequest], + tokens: torch.Tensor) -> torch.Tensor: + end_ids_tensor = torch.tensor( + [([req.py_end_id if req.py_end_id is not None else -1] * + self.max_tokens) for req in requests], + pin_memory=True, + dtype=tokens.dtype).T.to(device="cuda", non_blocking=True) + return tokens == end_ids_tensor + + def _are_max_length(self, requests: list[LlmRequest]) -> torch.Tensor: + lengths_tensor = torch.tensor([[ + ((req.get_num_tokens(BEAM_0) + num_tokens) - req.py_orig_prompt_len) + for num_tokens in range(1, self.max_tokens + 1) + ] for req in requests]) + max_lengths_tensor = torch.tensor([ + ([min(req.py_max_new_tokens, self.max_seq_len)] * self.max_tokens) + for req in requests + ]) + return (lengths_tensor + >= max_lengths_tensor).T.pin_memory().to(device="cuda", + non_blocking=True) + + _PAD_ID = -1 + """Pad with negative, doesn't matter what""" + + @cached_property + def _pad_steps_mask(self): + square = torch.ones(self.max_tokens, self.max_tokens, dtype=torch.bool) + pad_id = torch.tensor(self._PAD_ID) + mask = torch.where(square.tril(), torch.tensor(1), pad_id) + mask.pin_memory() + return mask.to("cuda", non_blocking=True) + + def _padded_old_tokens(self, + requests: list[LlmRequest], + new_tokens: torch.Tensor, + pad_id: int = _PAD_ID) -> torch.Tensor: + # TODO: make sure only the lookback tokens are pulled into the list + longest = self._longest_stop_word_len(requests) + assert longest > 0, f"{longest=}, longest stop word length should be greater than 0, as this code path is only reached with requests with stop words" + lookback = longest - 1 + old_tokens = [] + for request in requests: + old = 
request.get_tokens(BEAM_0)[-lookback:] if lookback > 0 else [] + padded = [pad_id] * max(0, lookback - len(old)) + old + old_tokens.append([padded] * self.max_tokens) + old_tokens_tensor = torch.tensor(old_tokens, + pin_memory=True).to("cuda", + non_blocking=True) + assert old_tokens_tensor.shape == ( + len(requests), self.max_tokens, lookback + ), f"{old_tokens_tensor.shape} != ({len(requests)=}, {self.max_tokens=}, {lookback=})" + new_tokens = new_tokens.T.unsqueeze(1) * self._pad_steps_mask + ret = torch.cat((old_tokens_tensor, new_tokens), dim=-1) + assert ret.shape == ( + len(requests), self.max_tokens, lookback + self.max_tokens + ), f"{ret.shape} != ({len(requests)=}, {self.max_tokens=}, {lookback + self.max_tokens=})" + return ret + + def _are_stop_words(self, requests: list[LlmRequest], + tokens: torch.Tensor) -> torch.Tensor: + per_step = torch.zeros((self.max_tokens, len(requests)), + dtype=torch.bool, + pin_memory=True).to("cuda", non_blocking=True) + + padded_tokens = self._padded_old_tokens(requests, tokens) + + def request_stop_words(request: LlmRequest, new_tokens: torch.Tensor): + swl, ends = request.py_stop_words_list + if -1 in ends: + ends = ends[:ends.index(-1)] + lens = np.diff(ends, prepend=0) + lens_device = torch.tensor(list(lens), + pin_memory=True).to("cuda", + non_blocking=True) + max_len = np.max(lens) + + words = torch.zeros(len(lens), + max_len, + dtype=torch.int32, + pin_memory=True) + for step, (start, l) in enumerate(zip([0] + ends, lens)): + words[step, :l] = torch.tensor(swl[start:start + l], + dtype=torch.int32) + words_device = words.to("cuda", non_blocking=True) + + for step, step_seq in enumerate(new_tokens): + for word, L in zip(words_device, lens_device): + truncated_seq = step_seq[step_seq >= 0][-L:] + if torch.equal(truncated_seq, word[-L:]): + # We don't care about subsequent steps because we already found a stop word match + return step + return None + + for request_idx, request in enumerate(requests): + step = request_stop_words(request, padded_tokens[request_idx]) + if step is not None: + per_step[step][request_idx] = True + return per_step + def _process_requests(self, scheduled_requests: ScheduledRequests, model_outputs: dict[str, torch.Tensor], new_tokens: torch.Tensor, num_context_logits_prefix_sum: list[int], *, + seq_slots: torch.Tensor, + seq_slots_host: torch.Tensor, log_probs_host: torch.Tensor | None = None): - beam_width = self.MAX_BEAM_WIDTH - beam = self.BEAM # raw_logits should contain only the logits from the gen requests. # If return context logits is requested, fetch only the logits from gen requests. 
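The `_padded_old_tokens` / `_are_stop_words` pair above reduces stop-word detection to a suffix comparison per draft step: each step sees the last `longest_stop_word_len - 1` already-accepted tokens followed by the new tokens up to that step. A minimal host-side sketch of the same check on plain lists (editorial example; `find_stop_step` and its inputs are illustrative names, not part of this patch):

    # Editorial sketch: per-step stop-word suffix check on plain Python lists.
    def find_stop_step(old_tokens: list[int], new_tokens: list[int],
                       stop_words: list[list[int]]) -> int | None:
        """Return the first step whose generated suffix ends with a stop word, else None."""
        lookback = max(len(w) for w in stop_words) - 1
        window = old_tokens[-lookback:] if lookback > 0 else []
        for step, token in enumerate(new_tokens):
            window = window + [token]
            for word in stop_words:
                if len(word) <= len(window) and window[-len(word):] == word:
                    return step
        return None

    # Stop word [7, 8] completes at step 0: the first new token (8) extends the
    # previously generated 7 that sits in the lookback window.
    assert find_stop_step([5, 7], [8, 9], [[7, 8]]) == 0
    assert find_stop_step([5, 7], [9, 9], [[7, 8]]) is None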
@@ -687,16 +910,13 @@ class TorchSampler(Sampler): no_draft_tokens = len(requests) == sum_steps fast_path = not self.enable_mixed_sampler and no_draft_tokens and log_probs_host is None - seq_slots_host = torch.as_tensor([r.py_seq_slot for r in requests]) - seq_slots = seq_slots_host.to(device="cuda", non_blocking=True) - if fast_path: logits = raw_logits[:len(requests)] logits = self._apply_embedding_bias(logits, requests) next_tokens = torch.argmax(logits, dim=-1) self.append_eagle3(next_tokens, model_outputs) int_next_tokens = next_tokens.to(torch.int, non_blocking=True) - next_tokens = int_next_tokens.view(1, -1, beam_width) + next_tokens = int_next_tokens.view(1, -1, SINGLE_BEAM_WIDTH) new_tokens[:1].index_copy_(1, seq_slots, next_tokens) return @@ -737,17 +957,17 @@ class TorchSampler(Sampler): # Batched processing already applied bias, just use the results next_tokens = batched_next_tokens[input_slice] softmax = batched_softmax[input_slice] - current_slice = slice(0, steps), slot, beam + current_slice = slice(0, steps), slot, BEAM_0 new_tokens[current_slice] = next_tokens if request.py_draft_logits is not None: request.py_target_probs = softmax.clone() if log_probs_host is not None: - assert beam == 0, "The following call relies on beam_width to be 1 - hence the unsqueeze" + assert BEAM_0 == 0, "The following call relies on beam_width to be 1 - hence the unsqueeze" token_probs = torch.gather( softmax, dim=1, index=next_tokens.unsqueeze(1)).squeeze(-1) log_probs = torch.log(token_probs) - log_probs_host[slot, beam, :steps].copy_(log_probs, - non_blocking=True) + log_probs_host[slot, BEAM_0, :steps].copy_(log_probs, + non_blocking=True) offset += steps diff --git a/tensorrt_llm/_torch/pyexecutor/sampler_utils.py b/tensorrt_llm/_torch/pyexecutor/sampler_utils.py new file mode 100644 index 0000000000..df0bc388e2 --- /dev/null +++ b/tensorrt_llm/_torch/pyexecutor/sampler_utils.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
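The greedy fast path above takes a single batched argmax and scatters the step-0 tokens into the preallocated `new_tokens` store by sequence slot, with no per-request Python loop. A rough CPU sketch of that scatter with toy shapes (editorial example, not part of this patch):

    # Editorial sketch of the fast-path scatter; shapes and values are toy data.
    import torch

    max_tokens, max_num_sequences, beam_width = 3, 8, 1
    new_tokens = torch.zeros(max_tokens, max_num_sequences, beam_width, dtype=torch.int)

    logits = torch.randn(2, 32)                  # two scheduled requests, toy vocab of 32
    seq_slots = torch.tensor([5, 2])             # the slots owned by those requests
    next_tokens = torch.argmax(logits, dim=-1)   # shape (2,)
    next_tokens = next_tokens.to(torch.int).view(1, -1, beam_width)
    new_tokens[:1].index_copy_(1, seq_slots, next_tokens)

    print(new_tokens[0, seq_slots, 0])           # the two greedy token ids, in request order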
+ +from tensorrt_llm.bindings.executor import FinishReason + +from .llm_request import LlmRequest, produce_stop_words + +BEAM_0 = 0 +SINGLE_BEAM_WIDTH = 1 + + +def max_token_criteria_single_beam(request: LlmRequest, + max_seq_len: int) -> bool: + num_tokens = request.get_num_tokens(BEAM_0) + return (num_tokens - request.py_orig_prompt_len + >= request.py_max_new_tokens) or (num_tokens >= max_seq_len) + + +def stop_token_criteria(py_stop_words_list: list[list[int]] | None, + tokens: list[int]) -> bool: + if py_stop_words_list: + assert isinstance(py_stop_words_list, + list), "request.py_stop_words_list should be a list" + for stop_word in produce_stop_words(py_stop_words_list): + if len(stop_word) > len(tokens): + continue + if tokens[-len(stop_word):] == stop_word: + return True + return False + + +def handle_stop_single_beam(request: LlmRequest, new_token: int, *, + max_seq_len: int) -> bool: + """Handle stop criteria and set appropriate finish reasons and state. + Returns True if generation should stop.""" + if new_token == request.py_end_id: + request.finish_by(FinishReason.END_ID, BEAM_0) + return True + + if max_token_criteria_single_beam(request, max_seq_len): + request.finish_by(FinishReason.LENGTH, BEAM_0) + return True + + if stop_token_criteria(request.py_stop_words_list, + request.get_tokens(BEAM_0)): + request.finish_by(FinishReason.STOP_WORDS, BEAM_0) + return True + + return False diff --git a/tensorrt_llm/_torch/speculative/drafting_loops.py b/tensorrt_llm/_torch/speculative/drafting_loops.py index a54fb0cbfc..bf0c8e0f6d 100644 --- a/tensorrt_llm/_torch/speculative/drafting_loops.py +++ b/tensorrt_llm/_torch/speculative/drafting_loops.py @@ -20,9 +20,12 @@ from tensorrt_llm._torch.speculative.interface import SpecMetadata @contextmanager def save_metadata_state(attn_metadata: AttentionMetadata, spec_metadata: SpecMetadata) -> None: - attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda", - "kv_lens_cuda") + attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda") batch_size = attn_metadata.num_seqs + # Do not use prepare_for_spec_dec for this special field. + # TRTLLM attention uses views of this tensor internally and prepare_for_spec_dec + # creates a copy. If you write to the copy, TRTLLM attention won't see the updates. 
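The comment above is the crux of the `kv_lens_cuda` special case: `prepare_for_spec_dec` swaps the field for a copy, while the TRTLLM attention path keeps views of the original storage, so the drafting loop instead clones the slice up front and `copy_`s it back in the `finally` block. A small CPU illustration of why only in-place writes are visible through a view (editorial example, not from this patch):

    # Editorial sketch: views observe in-place writes, but not writes to a clone.
    import torch

    kv_lens = torch.tensor([4, 4, 4, 4])
    kernel_view = kv_lens[:2]               # stands in for the view held by the attention kernels

    detached = kv_lens.clone()
    detached[:2] += 1                       # mutating a copy: the kernel's view never changes
    assert kernel_view.tolist() == [4, 4]

    saved = kv_lens[:2].clone()             # save ...
    kv_lens[:2] += 1                        # ... temporary update made during drafting
    assert kernel_view.tolist() == [5, 5]   # visible through the view
    kv_lens[:2].copy_(saved)                # ... restore, as the finally block does
    assert kernel_view.tolist() == [4, 4]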
+ kv_lens = attn_metadata.kv_lens_cuda[:batch_size].clone() if attn_metadata.is_cuda_graph: assert spec_metadata.is_cuda_graph @@ -39,6 +42,8 @@ def save_metadata_state(attn_metadata: AttentionMetadata, yield finally: attn_metadata.restore_from_spec_dec() + attn_metadata.kv_lens_cuda[:batch_size].copy_(kv_lens) + if attn_metadata.is_cuda_graph: spec_metadata.num_tokens = num_tokens if isinstance(spec_metadata, Eagle3SpecMetadata): diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index 5b3c3c6452..7579809c78 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -10,8 +10,11 @@ from ..model_config import ModelConfig from ..pyexecutor.guided_decoder import CapturableGuidedDecoder from ..pyexecutor.llm_request import LlmRequest, LlmRequestState from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager -from ..pyexecutor.sampler import (SampleState, SampleStateTensors, TorchSampler, - add_token, int_tensor) +from ..pyexecutor.sampler import (Sampler, SampleState, SampleStateTensors, + TorchSampler, TorchStore, add_token, + int_tensor) +from ..pyexecutor.sampler_utils import (BEAM_0, SINGLE_BEAM_WIDTH, + handle_stop_single_beam) from ..pyexecutor.scheduler import ScheduledRequests from .interface import SpecMetadata @@ -206,40 +209,44 @@ class MTPSpecMetadata(SpecMetadata): self.slot_ids[:num_seqs].copy_(mtp_slot_ids, non_blocking=True) -class MTPSampler(TorchSampler): +class MTPStore(TorchStore): + + def __init__(self, *, max_draft_len: int, max_num_sequences: int, + max_beam_width: int): + super().__init__(max_draft_len=max_draft_len, + max_num_sequences=max_num_sequences, + max_beam_width=max_beam_width) + self.next_new_tokens = int_tensor( + (self.max_tokens, self.max_num_sequences, SINGLE_BEAM_WIDTH)) + self.next_draft_tokens = int_tensor( + (self.max_num_sequences, self.max_draft_len)) + self.new_tokens_lens = int_tensor((self.max_num_sequences, )) + + +class MTPSampler(Sampler): """ MTP sampler. 
""" SampleState = SampleStateMTP + def is_generation_model(self) -> bool: + return True + def __init__(self, args: TorchSampler.Args, *, nextn: int): self.mapping = None self.draft_len = nextn - super().__init__(args) - - @dataclass(frozen=True, kw_only=True) - class Store(TorchSampler.Store): - next_new_tokens: torch.Tensor - next_draft_tokens: torch.Tensor - new_tokens_lens: torch.Tensor - - def create_store(self) -> Store: - num_tokens, seq_slots, _ = self.NEW_TOKENS_SHAPE - draft_len = num_tokens - 1 - assert draft_len == self.draft_len - return self.Store( - new_tokens=int_tensor(self.NEW_TOKENS_SHAPE), - next_new_tokens=int_tensor(self.NEW_TOKENS_SHAPE), - next_draft_tokens=int_tensor((seq_slots, draft_len)), - new_tokens_lens=int_tensor((seq_slots, )), - ) + self.store = MTPStore(max_draft_len=nextn, + max_num_sequences=args.max_num_sequences, + max_beam_width=args.max_beam_width) + self.max_seq_len = args.max_seq_len def _request_common_handling(self, request: LlmRequest, next_draft_tokens: list[list[int]]): assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler" assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler" assert not request.py_return_log_probs, "return_log_probs not implemented for MTPSampler" + assert request.py_seq_slot is not None request.py_draft_tokens = next_draft_tokens[request.py_seq_slot] request.py_decoding_iter += 1 @@ -250,12 +257,12 @@ class MTPSampler(TorchSampler): new_tokens = state.host.new_tokens new_tokens_lens_list = state.host.new_tokens_lens.tolist() next_draft_tokens_list = state.host.next_draft_tokens.tolist() - beam_idx = self.BEAM + max_seq_len = self.max_seq_len for req in state.scheduled_requests.context_requests: if req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0: continue - new_token = add_token(req, new_tokens, beam=beam_idx) - self._handle_stop_criteria(req, new_token) + new_token = add_token(req, new_tokens, beam=BEAM_0) + handle_stop_single_beam(req, new_token, max_seq_len=max_seq_len) self._request_common_handling(req, next_draft_tokens_list) for req in state.scheduled_requests.generation_requests: @@ -263,8 +270,10 @@ class MTPSampler(TorchSampler): continue num_new_tokens = new_tokens_lens_list[req.py_seq_slot] for i in range(num_new_tokens): - new_token = add_token(req, new_tokens, beam=beam_idx, step=i) - if self._handle_stop_criteria(req, new_token): + new_token = add_token(req, new_tokens, beam=BEAM_0, step=i) + if handle_stop_single_beam(req, + new_token, + max_seq_len=max_seq_len): break req.py_num_accepted_draft_tokens = num_new_tokens - 1 req.py_rewind_len = self.draft_len - req.py_num_accepted_draft_tokens diff --git a/tensorrt_llm/bench/dataclasses/configuration.py b/tensorrt_llm/bench/dataclasses/configuration.py index a693333230..6d8e703ee4 100755 --- a/tensorrt_llm/bench/dataclasses/configuration.py +++ b/tensorrt_llm/bench/dataclasses/configuration.py @@ -84,8 +84,24 @@ class RuntimeConfig(BaseModel): backend_cache_config = llm_args.pop("kv_cache_config", {}) llm_args["kv_cache_config"] = backend_cache_config | kv_cache_config - return update_llm_args_with_extra_options(llm_args, - self.extra_llm_api_options) + updated_llm_args = update_llm_args_with_extra_options( + llm_args, self.extra_llm_api_options) + + if self.backend == "pytorch": + cuda_graph_config = updated_llm_args.pop( + "cuda_graph_config", llm_args["cuda_graph_config"]) + # Use runtime max_batch_size as 
cuda_graph_config.max_batch_size + # if both max_batch_size and batch_sizes are not set. + batch_sizes_set = cuda_graph_config.get("batch_sizes", + None) is not None + max_batch_size_set = cuda_graph_config.get("max_batch_size", + None) is not None + if not batch_sizes_set and not max_batch_size_set: + cuda_graph_config[ + "max_batch_size"] = self.settings_config.max_batch_size + updated_llm_args["cuda_graph_config"] = cuda_graph_config + + return updated_llm_args @model_validator(mode="after") def validate_full_config(self) -> RuntimeConfig: diff --git a/tensorrt_llm/executor/ipc.py b/tensorrt_llm/executor/ipc.py index 2a86f50d65..5d45ebe4c1 100644 --- a/tensorrt_llm/executor/ipc.py +++ b/tensorrt_llm/executor/ipc.py @@ -125,13 +125,36 @@ class ZeroMqQueue: # Send data without HMAC self.socket.send_pyobj(obj) - def put_noblock(self, obj: Any): + def put_noblock(self, + obj: Any, + *, + retry: int = 1, + wait_time: float = 0.001): + ''' + Put an object into the queue without blocking, and retry if the send fails. + NOTE: It won't raise any error if the send fails. + + Parameters: + obj (Any): The object to send. + retry (int): The number of times to retry sending the object. + wait_time (float): The time to wait before retrying. + ''' + + assert retry >= 0 and retry <= 10, "Retry must be between 0 and 10, adjust the wait_time if needed" + self.setup_lazily() with nvtx_range_debug("send", color="blue", category="IPC"): data = pickle.dumps(obj) # nosec B301 if self.use_hmac_encryption: data = self._sign_data(data) - self.socket.send(data, flags=zmq.NOBLOCK) + try: + self.socket.send(data, flags=zmq.NOBLOCK) + except zmq.Again: + if retry > 0: + time.sleep(wait_time) + self.put_noblock(obj, retry=retry - 1, wait_time=wait_time) + else: + logger.error(f"Failed to send object: {obj}") async def put_async(self, obj: Any): self.setup_lazily() diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py index ec561bb291..bf60f7edb6 100644 --- a/tensorrt_llm/executor/proxy.py +++ b/tensorrt_llm/executor/proxy.py @@ -351,7 +351,7 @@ class GenerationExecutorProxy(GenerationExecutor): # notify the workers to quit if all(not f.done() for f in self.mpi_futures): - self.request_queue.put_noblock(None) + self.request_queue.put_noblock(None, retry=4) def shutdown(self): if not self.workers_started: diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 62db888a44..58e18ccc94 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1041,6 +1041,9 @@ class KvCacheConfig(StrictBaseModel, PybindMirror): "The data type to use for the Mamba SSM cache. If set to 'auto', the data type will be inferred from the model config." 
) + tokens_per_block: int = Field(default=32, + description="The number of tokens per block.") + def _to_pybind(self): return _KvCacheConfig( enable_block_reuse=self.enable_block_reuse, @@ -1946,6 +1949,9 @@ class BaseLlmArgs(StrictBaseModel): from tensorrt_llm._torch.speculative import suggest_spec_config spec_config = suggest_spec_config(max_batch_size) + if self.kv_cache_config is not None: + executor_config.tokens_per_block = self.kv_cache_config.tokens_per_block + update_executor_config( executor_config, backend=self.backend, diff --git a/tensorrt_llm/llmapi/mpi_session.py b/tensorrt_llm/llmapi/mpi_session.py index 7a25fa57f3..f361b977b7 100644 --- a/tensorrt_llm/llmapi/mpi_session.py +++ b/tensorrt_llm/llmapi/mpi_session.py @@ -435,7 +435,7 @@ class RemoteMpiCommSessionServer(): f"RemoteMpiCommSessionServer received all results, sending to client\n", "green") try: - self.queue.put_noblock(self.results) + self.queue.put_noblock(self.results, retry=2) except zmq.ZMQError as e: # The client could be shutdown first. if e.errno == zmq.EAGAIN: diff --git a/tensorrt_llm/llmapi/utils.py b/tensorrt_llm/llmapi/utils.py index 6500084190..d08e539750 100644 --- a/tensorrt_llm/llmapi/utils.py +++ b/tensorrt_llm/llmapi/utils.py @@ -518,6 +518,8 @@ def generate_api_docs_as_docstring(model: Type[BaseModel], # Format the argument documentation with 12 spaces indent for args arg_line = f"{indent} {field_name} ({type_str}): " + if status := field_info.get("status", None): + arg_line += f":tag:`{status}` " if field_description: arg_line += field_description.split('\n')[0] # First line with type @@ -557,20 +559,21 @@ class ApiParamTagger: ''' def __call__(self, cls: Type[BaseModel]) -> None: - self.process_pydantic_model(cls) + """ The main entry point to tag the api doc. """ + self._process_pydantic_model(cls) - def process_pydantic_model(self, cls: Type[BaseModel]) -> None: + def _process_pydantic_model(self, cls: Type[BaseModel]) -> None: """Process the Pydantic model to add tags to the fields. """ for field_name, field_info in cls.model_fields.items(): if field_info.json_schema_extra and 'status' in field_info.json_schema_extra: status = field_info.json_schema_extra['status'] - self.amend_pydantic_field_description_with_tags( + self._amend_pydantic_field_description_with_tags( cls, [field_name], status) - def amend_pydantic_field_description_with_tags(self, cls: Type[BaseModel], - field_names: list[str], - tag: str) -> None: + def _amend_pydantic_field_description_with_tags(self, cls: Type[BaseModel], + field_names: list[str], + tag: str) -> None: """Amend the description of the fields with tags. e.g. 
:tag:`beta` or :tag:`prototype` Args: diff --git a/tensorrt_llm/logger.py b/tensorrt_llm/logger.py index 27b10165d1..99d9ddaa58 100644 --- a/tensorrt_llm/logger.py +++ b/tensorrt_llm/logger.py @@ -109,6 +109,7 @@ class Logger(metaclass=Singleton): self._func_wrapper(severity)(" ".join(parts)) def log_once(self, severity, *msg, key): + assert key is not None, "key is required for log_once" if key not in self._appeared_keys: self._appeared_keys.add(key) self.log(severity, *msg) diff --git a/tensorrt_llm/serve/harmony_adapter.py b/tensorrt_llm/serve/harmony_adapter.py index a46e7c5ed4..2949965d72 100644 --- a/tensorrt_llm/serve/harmony_adapter.py +++ b/tensorrt_llm/serve/harmony_adapter.py @@ -6,7 +6,7 @@ import re import time import traceback import uuid -from typing import Any, AsyncGenerator, Literal +from typing import Any, List, Literal from openai_harmony import (Author, Conversation, DeveloperContent, HarmonyEncodingName, HarmonyError, Message, @@ -14,15 +14,15 @@ from openai_harmony import (Author, Conversation, DeveloperContent, SystemContent, TextContent, ToolDescription, load_harmony_encoding) -from tensorrt_llm.llmapi import RequestOutput from tensorrt_llm.logger import logger # yapf: disable -from .openai_protocol import (ChatCompletionMessageParam, ChatCompletionRequest, +from .openai_protocol import (ChatCompletionMessageParam, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, ChatMessage, + ChatCompletionStreamResponse, + ChatCompletionToolsParam, ChatMessage, DeltaFunctionCall, DeltaMessage, DeltaToolCall, UsageInfo) @@ -1485,36 +1485,72 @@ class HarmonyAdapter: return True -async def handle_streaming_response( - harmony_adapter: HarmonyAdapter, - generator: RequestOutput, - request_id: str, - request: ChatCompletionRequest, -) -> AsyncGenerator[str, None]: - """Handle streaming response with harmony format.""" +_SERVE_HARMONY_ADAPTER: HarmonyAdapter = None + + +def get_harmony_adapter(): + global _SERVE_HARMONY_ADAPTER + if _SERVE_HARMONY_ADAPTER is None: + _SERVE_HARMONY_ADAPTER = HarmonyAdapter() + + return _SERVE_HARMONY_ADAPTER + + +def handle_streaming_response(tools: List[ChatCompletionToolsParam], + tool_choice: str, outputs: List, model: str, + request_id: str, done: bool, + num_prompt_tokens: int): first_iteration = True - async for res in generator: - output = res.outputs[0] + output = outputs[0] - # Convert tools to dictionary format for harmony adapter (standard pattern) - tools_dict = None - if request.tools: - tools_dict = [tool.model_dump() for tool in request.tools] + # Convert tools to dictionary format for harmony adapter (standard pattern) + tools_dict = None + harmony_adapter = get_harmony_adapter() + if tools: + tools_dict = [tool.model_dump() for tool in tools] - # Get tool_choice from request - if "none", don't pass tools to parser - tool_choice = getattr(request, 'tool_choice', None) - if tool_choice == "none": - tools_for_parser = None + # Get tool_choice from request - if "none", don't pass tools to parser + if tool_choice == "none": + tools_for_parser = None + else: + tools_for_parser = tools_dict + + # Create OpenAI streaming responses + try: + res = [] + if done: + # Clean up state + harmony_adapter.cleanup_stream_state(request_id) + + usage_info = _create_usage_info(num_prompt_tokens, outputs) + + # Send final message with finish_reason + final_response = ChatCompletionStreamResponse( + model=model, + choices=[ + ChatCompletionResponseStreamChoice( + index=0, + 
delta=DeltaMessage(), + finish_reason=output.finish_reason, + stop_reason=output.stop_reason) + ], + ) + + final_response_json = final_response.model_dump_json( + exclude_none=True) + final_usage_chunk = ChatCompletionStreamResponse(choices=[], + model=model, + usage=usage_info) + final_usage_json = final_usage_chunk.model_dump_json( + exclude_none=True) + res.append(f"data: {final_response_json}\n\n") + res.append(f"data: {final_usage_json}\n\n") else: - tools_for_parser = tools_dict - - # Create OpenAI streaming responses - try: responses = harmony_adapter.create_openai_streaming_response( request_id=request_id, tokens=output.token_ids_diff, available_tools=tools_for_parser, - model_name=request.model, + model_name=model, tool_choice=tool_choice) # Send first response after receiving the first output if first_iteration: @@ -1525,64 +1561,44 @@ async def handle_streaming_response( delta=first_delta) first_response = ChatCompletionStreamResponse( - model=request.model, + model=model, choices=[choice], ) response_json = first_response.model_dump_json( exclude_none=True) - yield f"data: {response_json}\n\n" + res.append(f"data: {response_json}\n\n") - for response in responses: - yield response + res.extend(responses) - except Exception as e: - logger.error(f"Failed to create OpenAI streaming response: {e}") - logger.debug(f"Streaming error details: {traceback.format_exc()}") - # Clean up state - harmony_adapter.cleanup_stream_state(request_id) - raise e + return res - # Clean up state - harmony_adapter.cleanup_stream_state(request_id) - - # Send final message with finish_reason - output = generator.outputs[0] - final_response = ChatCompletionStreamResponse( - model=request.model, - choices=[ - ChatCompletionResponseStreamChoice( - index=0, - delta=DeltaMessage(), - finish_reason=output.finish_reason, - stop_reason=output.stop_reason) - ]) - - yield f"data: {final_response.model_dump_json(exclude_unset=True)}\n\n" - yield "data: [DONE]\n\n" + except Exception as e: + logger.error(f"Failed to create OpenAI streaming response: {e}") + logger.debug(f"Streaming error details: {traceback.format_exc()}") + # Clean up state + harmony_adapter.cleanup_stream_state(request_id) + raise e -async def handle_non_streaming_response( - harmony_adapter: HarmonyAdapter, promise: RequestOutput, - request: ChatCompletionRequest) -> ChatCompletionResponse: +def handle_non_streaming_response(tools: List[ChatCompletionToolsParam], + tool_choice: str, outputs: List, model: str, + num_prompt_tokens: int): """Handle non-streaming response with harmony format.""" - # Get final result - await promise - # Parse harmony output to OpenAI format # Convert tools to dictionary format for harmony adapter (standard pattern) tools_dict = None - if request.tools: - tools_dict = [tool.model_dump() for tool in request.tools] + harmony_adapter = get_harmony_adapter() + if tools: + tools_dict = [tool.model_dump() for tool in tools] # Get tool_choice from request - if "none", don't pass tools to parser - tool_choice = getattr(request, 'tool_choice', None) if tool_choice == "none": tools_for_parser = None else: tools_for_parser = tools_dict - output = promise.outputs[0] + output = outputs[0] parsed_output = harmony_adapter.harmony_output_to_openai( output.token_ids, tools_for_parser, tool_choice) @@ -1597,11 +1613,11 @@ async def handle_non_streaming_response( output.finish_reason) # Create usage info from metrics (RequestOutput doesn't have usage in v1) - usage_info = _create_usage_info(promise) + usage_info = 
_create_usage_info(num_prompt_tokens, outputs) # Create response response = ChatCompletionResponse( - model=request.model, + model=model, choices=[ ChatCompletionResponseChoice( index=0, @@ -1613,7 +1629,6 @@ async def handle_non_streaming_response( # Optional: Log if harmony parsing failed (for debugging) if parsed_output.get('_harmony_parsing_failed'): logger.warning("⚠️ Harmony parsing fell back to raw text decoding") - logger.debug(f"request\n\n{request}") logger.debug(f"response\n\n{response}\n") return response @@ -1646,15 +1661,10 @@ def _determine_finish_reason(parsed_output: dict[str, Any], return reason -def _create_usage_info(final_res: RequestOutput) -> UsageInfo: +def _create_usage_info(num_prompt_tokens, outputs) -> UsageInfo: """Create usage info from RequestOutput following serving_chat.py pattern.""" - # Calculate prompt tokens from prompt_token_ids and encoder_prompt_token_ids - assert final_res.prompt_token_ids is not None - num_prompt_tokens = len(final_res.prompt_token_ids) - # Calculate completion tokens from all outputs - num_generated_tokens = sum( - len(output.token_ids) for output in final_res.outputs) + num_generated_tokens = sum(len(output.token_ids) for output in outputs) # Create usage info usage = UsageInfo(prompt_tokens=num_prompt_tokens, diff --git a/tensorrt_llm/serve/openai_disagg_server.py b/tensorrt_llm/serve/openai_disagg_server.py index 9726efd881..495724b292 100644 --- a/tensorrt_llm/serve/openai_disagg_server.py +++ b/tensorrt_llm/serve/openai_disagg_server.py @@ -322,6 +322,8 @@ class OpenAIDisaggServer: raise ValueError("Disagg server returned more than one choice. This is currently not supported in disaggregated server.") if choices[0].disaggregated_params is None: raise ValueError("Context server did not return disaggregated params") + if choices[0].disaggregated_params.ctx_request_id is None: + raise ValueError("Invalid disaggregated params in context phase response.") return ctx_response diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index c622aed63f..aaac4ba8cf 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -44,9 +44,10 @@ from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest, UsageInfo, to_llm_disaggregated_params) from tensorrt_llm.serve.postprocess_handlers import ( - ChatPostprocArgs, CompletionPostprocArgs, chat_response_post_processor, - chat_stream_post_processor, completion_response_post_processor, - completion_stream_post_processor) + ChatCompletionPostprocArgs, ChatPostprocArgs, CompletionPostprocArgs, + chat_harmony_post_processor, chat_harmony_streaming_post_processor, + chat_response_post_processor, chat_stream_post_processor, + completion_response_post_processor, completion_stream_post_processor) from tensorrt_llm.serve.responses_utils import ConversationHistoryStore from tensorrt_llm.serve.responses_utils import \ create_response as responses_api_create_response @@ -57,8 +58,7 @@ from tensorrt_llm.serve.responses_utils import \ from tensorrt_llm.version import __version__ as VERSION from .._utils import nvtx_mark, set_prometheus_multiproc_dir -from .harmony_adapter import (HarmonyAdapter, handle_non_streaming_response, - handle_streaming_response, +from .harmony_adapter import (HarmonyAdapter, get_harmony_adapter, maybe_transform_reasoning_effort) # yapf: enale @@ -117,7 +117,11 @@ class OpenAIServer: # gpt-oss self.harmony_adapter: HarmonyAdapter | None = None - self.use_harmony = self.model_config.model_type == "gpt_oss" + 
disable_harmony = os.getenv("DISABLE_HARMONY_ADAPTER", "0") == "1" + if disable_harmony: + self.use_harmony = False + else: + self.use_harmony = (self.model_config.model_type == "gpt_oss") @asynccontextmanager async def lifespan(app: FastAPI): @@ -704,11 +708,35 @@ class OpenAIServer: Chat Completion API with harmony format support. Supports both streaming and non-streaming modes. """ + + async def create_harmony_response( + promise: RequestOutput, postproc_params: PostprocParams) -> ChatCompletionResponse: + await promise.aresult() + if self.postproc_worker_enabled: + chat_response =promise.outputs[0]._postprocess_result + else: + post_processor, args = postproc_params.post_processor, postproc_params.postproc_args + chat_response = post_processor(promise, args) + + return chat_response + + async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams): + if not self.postproc_worker_enabled: + post_processor, args = postproc_params.post_processor, postproc_params.postproc_args + + async for res in promise: + pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args) + # await self._extract_metrics(res) + for pp_res in pp_results: + yield pp_res + + yield "data: [DONE]\n\n" + try: # Initialize HarmonyAdapter # NOTE: WAR for Disagg failure, may affect perf if no warmup if not self.harmony_adapter: - self.harmony_adapter = HarmonyAdapter() + self.harmony_adapter = get_harmony_adapter() # Convert Pydantic models to dictionaries for JSON serialization (standard pattern) tools_dict = None if request.tools: @@ -743,27 +771,37 @@ class OpenAIServer: vocab_size=self.tokenizer.tokenizer.vocab_size) sampling_params.detokenize = False # Harmony adapter handles detokenization + postproc_args = ChatCompletionPostprocArgs.from_request(request) + postproc_params = PostprocParams( + post_processor=chat_harmony_streaming_post_processor + if request.stream else chat_harmony_post_processor, + postproc_args=postproc_args, + ) + # Generate promise = self.llm.generate_async( inputs=harmony_tokens, sampling_params=sampling_params, + _postproc_params=postproc_params if self.postproc_worker_enabled else None, streaming=bool(request.stream), lora_request=request.lora_request, ) + postproc_args.request_id = promise.request_id + + if not self.postproc_worker_enabled: + postproc_args.num_prompt_tokens = len(promise.prompt_token_ids) + # Disconnect cancellation asyncio.create_task(self.await_disconnected(raw_request, promise)) # Handle streaming if request.stream: return StreamingResponse( - handle_streaming_response( - self.harmony_adapter, promise, - str(promise.request_id), request, - ), + content=create_streaming_generator(promise, postproc_params), media_type="text/event-stream" ) else: - response = await handle_non_streaming_response(self.harmony_adapter, promise, request) + response = await create_harmony_response(promise, postproc_params) return JSONResponse(response.model_dump()) except Exception as e: diff --git a/tensorrt_llm/serve/postprocess_handlers.py b/tensorrt_llm/serve/postprocess_handlers.py index 07db6e27a7..0fbcedb9da 100644 --- a/tensorrt_llm/serve/postprocess_handlers.py +++ b/tensorrt_llm/serve/postprocess_handlers.py @@ -9,6 +9,8 @@ from ..llmapi.reasoning_parser import (BaseReasoningParser, ReasoningParserFactory) from ..llmapi.tokenizer import TransformersTokenizer # yapf: disable +from .harmony_adapter import (handle_non_streaming_response, + handle_streaming_response) from .openai_protocol import 
(ChatCompletionLogProbs, ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam, @@ -24,7 +26,8 @@ from .openai_protocol import (ChatCompletionLogProbs, FunctionCall, StreamOptions, ToolCall, UsageInfo, to_disaggregated_params) -# yapf: enale +# yapf: enable + @dataclass(kw_only=True) class ChatPostprocArgs(PostprocArgs): @@ -57,8 +60,7 @@ class ChatPostprocArgs(PostprocArgs): ) -def create_logprobs(token_ids: List[int], - tokenizer: TransformersTokenizer, +def create_logprobs(token_ids: List[int], tokenizer: TransformersTokenizer, logprobs: List[float]) -> ChatCompletionLogProbs: assert len(token_ids) == len(logprobs), \ "token_ids and logprobs have different lengths" @@ -75,12 +77,14 @@ def create_logprobs(token_ids: List[int], return chat_logprobs -def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str, streaming: bool) -> Tuple[bool, str, str]: +def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str, + streaming: bool) -> Tuple[bool, str, str]: reasoning_parser = None if args.reasoning_parser is not None: if output_index not in args.reasoning_parser_dict: - args.reasoning_parser_dict[output_index] = ReasoningParserFactory.create_reasoning_parser( - args.reasoning_parser) + args.reasoning_parser_dict[ + output_index] = ReasoningParserFactory.create_reasoning_parser( + args.reasoning_parser) reasoning_parser = args.reasoning_parser_dict[output_index] in_reasoning = False @@ -97,7 +101,8 @@ def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str, @nvtx_range_debug("chat_stream_post_processor") -def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs) -> List[str]: +def chat_stream_post_processor(rsp: GenerationResultBase, + args: ChatPostprocArgs) -> List[str]: def yield_first_chat(num_tokens: int, idx: int, @@ -128,9 +133,13 @@ def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs include_continuous_usage = False if args.first_iteration: for i in range(args.num_choices): - res.append(f"data: {yield_first_chat(prompt_tokens, i, role=args.role)} \n\n") + res.append( + f"data: {yield_first_chat(prompt_tokens, i, role=args.role)} \n\n" + ) if args.echo and args.last_message_content: - res.append(f"data: {yield_first_chat(prompt_tokens, i, content=args.last_message_content)} \n\n") + res.append( + f"data: {yield_first_chat(prompt_tokens, i, content=args.last_message_content)} \n\n" + ) args.first_iteration = False for output in rsp.outputs: @@ -158,14 +167,18 @@ def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs delta_message = DeltaMessage( content=delta_text, reasoning_content=reasoning_delta_text) - choice = ChatCompletionResponseStreamChoice(index=i, - delta=delta_message, - finish_reason=None, - avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None)) + choice = ChatCompletionResponseStreamChoice( + index=i, + delta=delta_message, + finish_reason=None, + avg_decoded_tokens_per_iter=getattr(rsp, + 'avg_decoded_tokens_per_iter', + None)) if args.return_logprobs: logprobs = output.logprobs_diff token_ids = output.token_ids_diff - choice.logprobs = create_logprobs(token_ids, args.tokenizer, logprobs) + choice.logprobs = create_logprobs(token_ids, args.tokenizer, + logprobs) if output.finish_reason is not None: choice.finish_reason = output.finish_reason choice.stop_reason = output.stop_reason @@ -179,57 +192,62 @@ def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs 
res.append(f"data: {data}\n\n") if include_usage and rsp._done: - completion_tokens = sum(output.length - for output in rsp.outputs) + completion_tokens = sum(output.length for output in rsp.outputs) final_usage = UsageInfo( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, ) - final_usage_chunk = ChatCompletionStreamResponse( - choices=[], model=args.model, usage=final_usage) + final_usage_chunk = ChatCompletionStreamResponse(choices=[], + model=args.model, + usage=final_usage) final_usage_data = final_usage_chunk.model_dump_json() res.append(f"data: {final_usage_data}\n\n") return res @nvtx_range_debug("chat_response_post_processor") -def chat_response_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs) -> ChatCompletionResponse: +def chat_response_post_processor( + rsp: GenerationResultBase, + args: ChatPostprocArgs) -> ChatCompletionResponse: choices: List[ChatCompletionResponseChoice] = [] role = args.role for output in rsp.outputs: _, text, reasoning_text = apply_reasoning_parser( args, output.index, output.text, False) - if args.tool_choice and isinstance( - args.tool_choice, - ChatCompletionNamedToolChoiceParam): + if args.tool_choice and isinstance(args.tool_choice, + ChatCompletionNamedToolChoiceParam): message = ChatMessage( role=role, content="", tool_calls=[ ToolCall(function=FunctionCall( - name=args.tool_choice.function.name, - arguments=text)) + name=args.tool_choice.function.name, arguments=text)) ]) else: if text is None: text = "" - message = ChatMessage( - role=role, content=text, reasoning_content=reasoning_text) - disaggregated_params = to_disaggregated_params(output.disaggregated_params) + message = ChatMessage(role=role, + content=text, + reasoning_content=reasoning_text) + disaggregated_params = to_disaggregated_params( + output.disaggregated_params) choice = ChatCompletionResponseChoice( index=output.index, message=message, finish_reason=output.finish_reason, stop_reason=output.stop_reason, disaggregated_params=disaggregated_params, - avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None), + avg_decoded_tokens_per_iter=getattr(rsp, + 'avg_decoded_tokens_per_iter', + None), ) if args.return_logprobs: - choice.logprobs = create_logprobs(output.token_ids, args.tokenizer, output.logprobs) + choice.logprobs = create_logprobs(output.token_ids, args.tokenizer, + output.logprobs) choices.append(choice) if args.echo and args.last_message_content: @@ -238,8 +256,7 @@ def chat_response_post_processor(rsp: GenerationResultBase, args: ChatPostprocAr choice.message.content = full_message num_prompt_tokens = args.num_prompt_tokens - num_generated_tokens = sum( - len(output.token_ids) for output in rsp.outputs) + num_generated_tokens = sum(len(output.token_ids) for output in rsp.outputs) usage = UsageInfo( prompt_tokens=num_prompt_tokens, completion_tokens=num_generated_tokens, @@ -275,7 +292,8 @@ class CompletionPostprocArgs(PostprocArgs): @nvtx_range_debug("completion_stream_post_processor") -def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args: CompletionPostprocArgs) -> List[str]: +def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, + args: CompletionPostprocArgs) -> List[str]: res: List[str] = [] prompt_tokens = args.num_prompt_tokens if stream_option := args.stream_options: @@ -293,9 +311,11 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args: index=args.prompt_idx * args.num_choices + 
output.index, text=delta_text if args.detokenize else "", token_ids=None if args.detokenize else output.token_ids_diff, - finish_reason = output.finish_reason, - stop_reason = output.stop_reason, - avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None), + finish_reason=output.finish_reason, + stop_reason=output.stop_reason, + avg_decoded_tokens_per_iter=getattr(rsp, + 'avg_decoded_tokens_per_iter', + None), ) chunk = CompletionStreamResponse(model=args.model, choices=[choice]) if include_continuous_usage: @@ -306,16 +326,16 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args: res.append(f"data: {data}\n\n") if include_usage and rsp._done: - completion_tokens = sum(output.length - for output in rsp.outputs) + completion_tokens = sum(output.length for output in rsp.outputs) final_usage = UsageInfo( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, ) - final_usage_chunk = ChatCompletionStreamResponse( - choices=[], model=args.model, usage=final_usage) + final_usage_chunk = ChatCompletionStreamResponse(choices=[], + model=args.model, + usage=final_usage) final_usage_data = final_usage_chunk.model_dump_json() res.append(f"data: {final_usage_data}\n\n") args.first_iteration = False @@ -323,7 +343,9 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args: @nvtx_range_debug("completion_response_post_processor") -def completion_response_post_processor(rsp: GenerationResult, args: CompletionPostprocArgs) -> CompletionResponse: +def completion_response_post_processor( + rsp: GenerationResult, + args: CompletionPostprocArgs) -> CompletionResponse: prompt_tokens = args.num_prompt_tokens completion_tokens = 0 choices = [] @@ -331,23 +353,75 @@ def completion_response_post_processor(rsp: GenerationResult, args: CompletionPo text = output.text if args.echo: text = args.prompt + text - disaggregated_params = to_disaggregated_params(output.disaggregated_params) + disaggregated_params = to_disaggregated_params( + output.disaggregated_params) choice = CompletionResponseChoice( text=text if args.detokenize else "", token_ids=None if args.detokenize else output.token_ids, index=args.prompt_idx * args.num_choices + output.index, disaggregated_params=disaggregated_params, - context_logits=None if rsp.context_logits is None else rsp.context_logits.tolist(), + context_logits=None + if rsp.context_logits is None else rsp.context_logits.tolist(), stop_reason=output.stop_reason, finish_reason=output.finish_reason, - avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None), + avg_decoded_tokens_per_iter=getattr(rsp, + 'avg_decoded_tokens_per_iter', + None), ) completion_tokens += output.length choices.append(choice) usage = UsageInfo(prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=completion_tokens + prompt_tokens) - response = CompletionResponse(choices=choices, model=args.model, usage=usage) + completion_tokens=completion_tokens, + total_tokens=completion_tokens + prompt_tokens) + response = CompletionResponse(choices=choices, + model=args.model, + usage=usage) + return response + + +@dataclass(kw_only=True) +class ChatCompletionPostprocArgs(PostprocArgs): + model: str + tools: Optional[List[ChatCompletionToolsParam]] + tool_choice: Optional[Union[Literal["none", "auto"], + ChatCompletionNamedToolChoiceParam]] + request_id: Optional[int] = None + + @classmethod + def from_request(cls, request: ChatCompletionRequest): + 
return cls( + model=request.model, + tools=request.tools, + tool_choice=request.tool_choice, + ) + + +@nvtx_range_debug("chat_harmony_post_processor") +def chat_harmony_post_processor( + rsp: GenerationResult, + args: ChatCompletionPostprocArgs) -> ChatCompletionResponse: + response = handle_non_streaming_response( + tools=args.tools, + tool_choice=args.tool_choice, + outputs=rsp.outputs, + model=args.model, + num_prompt_tokens=args.num_prompt_tokens, + ) + return response + + +@nvtx_range_debug("chat_harmony_streaming_post_processor") +def chat_harmony_streaming_post_processor( + rsp: GenerationResult, args: ChatCompletionPostprocArgs) -> List[str]: + response = handle_streaming_response( + tools=args.tools, + tool_choice=args.tool_choice, + outputs=rsp.outputs, + model=args.model, + request_id=args.request_id, + done=rsp._done, + num_prompt_tokens=args.num_prompt_tokens, + ) return response diff --git a/tests/integration/defs/accuracy/references/gpqa_diamond.yaml b/tests/integration/defs/accuracy/references/gpqa_diamond.yaml index f729cef1bd..dde3b53876 100644 --- a/tests/integration/defs/accuracy/references/gpqa_diamond.yaml +++ b/tests/integration/defs/accuracy/references/gpqa_diamond.yaml @@ -8,6 +8,9 @@ meta-llama/Llama-3.3-70B-Instruct: accuracy: 48.03 - quant_algo: FP8 accuracy: 48.03 + - quant_algo: FP8 + kv_cache_quant_algo: FP8 + accuracy: 48.03 deepseek-ai/DeepSeek-R1: - quant_algo: NVFP4 accuracy: 70.45 diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index 973765b2f0..6a70628f90 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -602,9 +602,9 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): @parametrize_with_ids("gen_tp", [1, 2]) @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"]) def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset): - if ctx_pp * gen_tp * 2 > get_device_count(): + if ctx_pp + gen_tp > get_device_count(): pytest.skip( - f"Not enough devices for ctx_pp={ctx_pp}*gen_tp={gen_tp} test") + f"Not enough devices for ctx_pp={ctx_pp}+gen_tp={gen_tp} test") return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1, gen_tp, 1, 1, [get_accuracy_task(testset)]) diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index dfadc5f05d..8b7c723b22 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -1565,6 +1565,11 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): @parametrize_with_ids("moe_backend", ["CUTLASS", "TRTLLM"]) def test_nvfp4(self, fp8kv, attention_dp, cuda_graph, overlap_scheduler, torch_compile, mtp_nextn, moe_backend): + if moe_backend == "TRTLLM" and (get_sm_version() == 120 + or get_sm_version() == 121): + pytest.skip( + "MOE TRTLLM backend does not support SM version 120 or 121") + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75) torch_compile_config = TorchCompileConfig( enable_fullgraph=True, @@ -1613,8 +1618,10 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): torch_compile, mtp_nextn, moe_backend): if torch_compile and pp_size > 1: pytest.skip("PP with torch.compile is not supported yet.") - if moe_backend == "TRTLLM" and get_sm_version() == 120: - pytest.skip("MOE TRTLLM backend does not support SM version 120") + if moe_backend == "TRTLLM" and 
(get_sm_version() == 120 + or get_sm_version() == 121): + pytest.skip( + "MOE TRTLLM backend does not support SM version 120 or 121") kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75) # Picewise Cuda Graph cannot be enabled for nvfp4 attention dp. torch_compile_config = TorchCompileConfig( @@ -1885,6 +1892,11 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness): def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv, attention_dp, cuda_graph, overlap_scheduler, max_batch_size, moe_backend): + if moe_backend == "TRTLLM" and (get_sm_version() == 120 + or get_sm_version() == 121): + pytest.skip( + "MOE TRTLLM backend does not support SM version 120 or 121") + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.70) pytorch_config = dict( disable_overlap_scheduler=not overlap_scheduler, @@ -2509,6 +2521,11 @@ class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness): torch_compile, ): + if moe_backend == "TRTLLM" and (get_sm_version() == 120 + or get_sm_version() == 121): + pytest.skip( + "MOE TRTLLM backend does not support SM version 120 or 121") + torch_compile_config = TorchCompileConfig( enable_fullgraph=True, enable_piecewise_cuda_graph=cuda_graph, @@ -2700,6 +2717,11 @@ class TestQwen3_235B_A22B(LlmapiAccuracyTestHarness): def test_nvfp4(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph, overlap_scheduler, moe_backend, eagle3): + if moe_backend == "TRTLLM" and (get_sm_version() == 120 + or get_sm_version() == 121): + pytest.skip( + "MOE TRTLLM backend does not support SM version 120 or 121") + pytorch_config = dict( disable_overlap_scheduler=not overlap_scheduler, cuda_graph_config=CudaGraphConfig() if cuda_graph else None, @@ -2779,7 +2801,7 @@ class TestPhi4MiniInstruct(LlmapiAccuracyTestHarness): MODEL_PATH = f"{llm_models_root()}/Phi-4-mini-instruct" def test_auto_dtype(self): - with LLM(self.MODEL_PATH) as llm: + with LLM(self.MODEL_PATH, max_seq_len=4096) as llm: task = CnnDailymail(self.MODEL_NAME) task.evaluate(llm) task = MMLU(self.MODEL_NAME) @@ -3046,10 +3068,8 @@ class TestGPTOSS(LlmapiAccuracyTestHarness): class TestEXAONE4(LlmapiAccuracyTestHarness): MODEL_NAME = "LGAI-EXAONE/EXAONE-4.0-32B" - kv_cache_config = KvCacheConfig( - enable_block_reuse=False, - enable_partial_reuse=False, - max_attention_window=[4096, 4096, 4096, 131072]) + kv_cache_config = KvCacheConfig(enable_block_reuse=False, + enable_partial_reuse=False) def test_auto_dtype(self): model_path = f"{llm_models_root()}/EXAONE-4.0-32B" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_diff_max_tokens.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_diff_max_tokens.yaml index 5276de524a..26d1f6b6c1 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_diff_max_tokens.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_diff_max_tokens.yaml @@ -8,7 +8,7 @@ disable_overlap_scheduler: True context_servers: num_instances: 1 max_num_tokens: 512 - max_batch_size: 256 + max_batch_size: 64 cache_transceiver_config: backend: DEFAULT urls: @@ -16,7 +16,7 @@ context_servers: generation_servers: num_instances: 1 max_num_tokens: 256 - max_batch_size: 128 + max_batch_size: 32 cache_transceiver_config: backend: DEFAULT urls: diff --git a/tests/integration/defs/disaggregated/test_disaggregated.py b/tests/integration/defs/disaggregated/test_disaggregated.py index 46c393ab48..89719b395d 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated.py +++ 
b/tests/integration/defs/disaggregated/test_disaggregated.py @@ -1302,7 +1302,7 @@ def get_config_for_benchmark(model_root, backend): "num_instances": 1, "max_batch_size": 2, "max_num_tokens": 384, - "max_seq_len": 320, + "max_seq_len": 384, "tensor_parallel_size": 1, "pipeline_parallel_size": 1, "disable_overlap_scheduler": True, @@ -1318,7 +1318,7 @@ def get_config_for_benchmark(model_root, backend): "pipeline_parallel_size": 1, "max_batch_size": 2, "max_num_tokens": 384, - "max_seq_len": 320, + "max_seq_len": 384, "cache_transceiver_config": { "backend": backend, "max_tokens_in_buffer": 512, diff --git a/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py b/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py index 93611de040..b49b4afb7c 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py +++ b/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py @@ -36,6 +36,42 @@ MODEL_PATHS = { } +def mpi_publish_name(): + port_name = None + try: + port_name = MPI.Open_port() + MPI.Publish_name('my_port', port_name) + except MPI.Exception as e: + print(f"Error publishing port name: {e}") + raise e + except Exception as e: + print(f"Unexpected error publishing port name: {e}") + raise e + + return port_name + + +def mpi_initialize_intercomm(port_name): + intercomm = None + try: + intercomm = MPI.COMM_SELF.Accept(port_name) + except MPI.Exception as e: + print(f"Error accepting intercomm: {e}", flush=True) + raise + except Exception as e: + print(f"Unexpected error accepting intercomm: {e}", flush=True) + raise + return intercomm + + +def mpi_send_termination_request(intercomm): + if intercomm is not None: + # Send termination requests + intercomm.send(None, dest=0, tag=MPI_REQUEST) + intercomm.send(None, dest=1, tag=MPI_REQUEST) + print("Sent termination requests to the workers.") + + def model_path(model_name): llm_models_root = os.environ["LLM_MODELS_ROOT"] for name, path in MODEL_PATHS.items(): @@ -48,8 +84,15 @@ async def run_worker(kv_cache_config, cache_transceiver_config, pytorch_config, model_name, rank): assert isinstance(pytorch_config, dict) print(f"Running worker {rank}") - port_name = MPI.Lookup_name('my_port') - intercomm = MPI.COMM_WORLD.Connect(port_name) + try: + port_name = MPI.Lookup_name('my_port') + intercomm = MPI.COMM_WORLD.Connect(port_name) + except MPI.Exception as e: + print(f"Error looking up or connecting to port: {e}") + raise e + except Exception as e: + print(f"Unexpected error looking up or connecting to port: {e}") + raise e session = MPI.COMM_WORLD.Split(color=rank, key=0) set_mpi_comm(session) @@ -139,8 +182,7 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt, zip(kv_cache_configs, cache_transceiver_configs, worker_pytorch_configs, model_names, ranks)) - port_name = MPI.Open_port() - MPI.Publish_name('my_port', port_name) + port_name = mpi_publish_name() with MPIPoolExecutor(max_workers=2, env={"UCX_TLS": "^ib"}) as executor: futures = [] @@ -152,9 +194,10 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt, print(f"Error in worker {worker_arg}: {e}") raise e + intercomm = None try: - print("Launched all the workers.") - intercomm = MPI.COMM_SELF.Accept(port_name) + print("Launched all the workers.", flush=True) + intercomm = mpi_initialize_intercomm(port_name) for _ in range(2): intercomm.recv(tag=MPI_READY) @@ -187,14 +230,15 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt, output = responses[0] assert
output[0].text == expected_output assert output[0].token_ids == expected_output_ids - + except Exception as e: + print(f"Exception encountered: {e}", flush=True) + raise e finally: - # Send termination requests - intercomm.send(None, dest=0, tag=MPI_REQUEST) - intercomm.send(None, dest=1, tag=MPI_REQUEST) - print("Sent termination requests to the workers.") + print("Sending termination request", flush=True) + mpi_send_termination_request(intercomm) # Wait for all futures to complete + print("Waiting for all workers to terminate.", flush=True) for future in futures: future.result() print("All workers terminated.") @@ -282,8 +326,7 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph, zip(kv_cache_configs, cache_transceiver_configs, worker_pytorch_configs, model_names, ranks)) - port_name = MPI.Open_port() - MPI.Publish_name('my_port', port_name) + port_name = mpi_publish_name() prompt = "European Union is a political and economic union of 27 countries. The European Union is headquartered in Brussels, Belgium. The first president of the European Union was Jean-Claude Juncker. The current president is Ursula von der Leyen. The European Union is a major economic and political entity." @@ -297,9 +340,10 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph, print(f"Error in worker {worker_arg}: {e}") raise e + intercomm = None try: print("Launched all the workers.") - intercomm = MPI.COMM_SELF.Accept(port_name) + intercomm = mpi_initialize_intercomm(port_name) for _ in range(2): intercomm.recv(tag=MPI_READY) @@ -334,11 +378,11 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph, intercomm.send(requests, dest=1, tag=MPI_REQUEST) output = intercomm.recv(source=1, tag=MPI_RESULT) + except MPI.Exception as e: + print(f"MPI Error: {e}") + raise e finally: - # Send termination requests - intercomm.send(None, dest=0, tag=MPI_REQUEST) - intercomm.send(None, dest=1, tag=MPI_REQUEST) - print("Sent termination requests to the workers.") + mpi_send_termination_request(intercomm) # Wait for all futures to complete for future in futures: @@ -387,8 +431,7 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path, zip(kv_cache_configs, cache_transceiver_configs, worker_pytorch_configs, model_names, ranks)) - port_name = MPI.Open_port() - MPI.Publish_name('my_port', port_name) + port_name = mpi_publish_name() prompt = "What is the capital of Germany?"
@@ -402,9 +445,10 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path, print(f"Error in worker {worker_arg}: {e}") raise e + intercomm = None try: print("Launched all the workers.") - intercomm = MPI.COMM_SELF.Accept(port_name) + intercomm = mpi_initialize_intercomm(port_name) for _ in range(2): intercomm.recv(tag=MPI_READY) @@ -438,11 +482,11 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path, intercomm.send(requests, dest=1, tag=MPI_REQUEST) output = intercomm.recv(source=1, tag=MPI_RESULT) + except MPI.Exception as e: + print(f"MPI Error: {e}") + raise e finally: - # Send termination requests - intercomm.send(None, dest=0, tag=MPI_REQUEST) - intercomm.send(None, dest=1, tag=MPI_REQUEST) - print("Sent termination requests to the workers.") + mpi_send_termination_request(intercomm) # Wait for all futures to complete for future in futures: diff --git a/tests/integration/defs/local_venv.py b/tests/integration/defs/local_venv.py index 4e72ad8ecb..ad5cb66d2a 100644 --- a/tests/integration/defs/local_venv.py +++ b/tests/integration/defs/local_venv.py @@ -23,7 +23,7 @@ class PythonVenvRunnerImpl(PythonRunnerInterface): venv_dir (str): Path to the virtualenv root directory, or None if this is an externally-built virtualenv venv_bin (str): Path to the Python executable to use when running tests - workspace (str): Path to the TURTLE workspace + workspace (str): Path to the test workspace """ def __init__(self, pip_opts, venv_dir, venv_bin, workspace): diff --git a/tests/integration/defs/perf/gpu_clock_lock.py b/tests/integration/defs/perf/gpu_clock_lock.py index 61c86b89b9..5687356584 100644 --- a/tests/integration/defs/perf/gpu_clock_lock.py +++ b/tests/integration/defs/perf/gpu_clock_lock.py @@ -189,7 +189,7 @@ class GPUClockLock: # Initialize thread self._thread = threading.Thread( target=self._monitoring_thread, - name="TURTLE - GPUMonitor", + name="LLM Test - GPUMonitor", kwargs={"interval_ms": self._interval_ms}) self._thread.daemon = True self._thread.start() diff --git a/tests/integration/defs/perf/pytorch_model_config.py b/tests/integration/defs/perf/pytorch_model_config.py index ab0e0bf08d..49915d3b47 100644 --- a/tests/integration/defs/perf/pytorch_model_config.py +++ b/tests/integration/defs/perf/pytorch_model_config.py @@ -181,10 +181,19 @@ def get_model_yaml_config(model_label: str, # lora-specific change for pytorch if 'pytorch' in model_label and 'loras' in model_label: + # Derive the requested number of adapters from model_label (segment like "loras:X") + lora_count = 1 + for part in model_label.split('-'): + if part.startswith('loras:'): + lora_count = max(1, int(part.split(':', 1)[1])) + break + lora_config = { 'lora_config': { 'lora_dir': lora_dirs if lora_dirs is not None else [], - 'max_lora_rank': 64 + 'max_lora_rank': 64, + 'max_loras': lora_count, + 'max_cpu_loras': lora_count, } } if 'phi_4_multimodal_instruct' in model_label: diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index 0def4787de..21bf49b363 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -1755,7 +1755,7 @@ def parse_output(text): for item in text_lists: item = item.replace(os.linesep, "") while True: - match = re.search(r"(Generated text: \'(.*?)\')", item, + match = re.search(r'Generated text: ([\'"])(.*?)\1', item, re.MULTILINE) if match is None: break @@ -2299,7 +2299,8 @@ def test_ptp_quickstart_advanced_mixed_precision(llm_root, llm_venv):
marks=pytest.mark.skip_less_device_memory(80000)), pytest.param("gemma-3-27b-it", "gemma/gemma-3-27b-it", - marks=pytest.mark.skip_less_device_memory(80000)), + marks=(skip_post_blackwell, + pytest.mark.skip_less_device_memory(80000))), ]) def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path, modality, use_cuda_graph): @@ -2407,9 +2408,9 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path, }, "gemma-3-27b-it": { "image": [ - ["dramatic", "turbulent", "waves", "ocean", "overcast"], - ["half", "dome", "yosemite", "landmark", "rounded"], - ["flowing", "traffic", "vehicles", "road", "Changi"], + ["natural", "turbulent", "dramatic", "scene", "wave"], + ["image", "famous", "rock", "granite", "landmark"], + ["traffic", "moderate", "heavy", "flowing", "cars"], ], }, } @@ -2600,9 +2601,10 @@ def test_ptp_quickstart_multimodal_phi4mm(llm_root, llm_venv, modality): @pytest.mark.skip_less_device(2) @pytest.mark.skip_less_device_memory(80000) @pytest.mark.parametrize("model_name,model_path", [ - ("gemma-3-27b-it", "gemma/gemma-3-27b-it"), ("mistral-small-3.1-24b-instruct", "Mistral-Small-3.1-24B-Instruct-2503"), ("Phi-4-multimodal-instruct", "multimodals/Phi-4-multimodal-instruct"), + pytest.param( + "gemma-3-27b-it", "gemma/gemma-3-27b-it", marks=skip_post_blackwell), ]) def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name, model_path): @@ -2645,8 +2647,8 @@ def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name, }, "Phi-4-multimodal-instruct": { "image": [ - ["image", "depicts", "mountain", "half", "rock"], - ["road", "car", "lane", "traffic", "bus"], + ["object", "mountain", "weather", "clear", "clouds"], + ["traffic", "road", "vehicles", "cars", "bus"], ], }, } @@ -2674,6 +2676,8 @@ def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name, cmd.append("--image_format=pil") cmd.append("--attention_backend=FLASHINFER") cmd.append("--disable_kv_cache_reuse") + cmd.append("--kv_cache_fraction=0.5") + cmd.append("--max_seq_len=1024") elif model_name == "Phi-4-multimodal-instruct": # Set max_seq_len to 4096 to use short rope factor. cmd.append("--max_seq_len=4096") @@ -2702,9 +2706,10 @@ def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name, @pytest.mark.skip_less_device_memory(80000) @pytest.mark.parametrize("model_name,model_path", [ - ("gemma-3-27b-it", "gemma/gemma-3-27b-it"), ("mistral-small-3.1-24b-instruct", "Mistral-Small-3.1-24B-Instruct-2503"), ("Phi-4-multimodal-instruct", "multimodals/Phi-4-multimodal-instruct"), + pytest.param( + "gemma-3-27b-it", "gemma/gemma-3-27b-it", marks=skip_post_blackwell), ]) def test_ptp_quickstart_multimodal_multiturn(llm_root, llm_venv, model_name, model_path): @@ -2770,6 +2775,9 @@ def test_ptp_quickstart_multimodal_multiturn(llm_root, llm_venv, model_name, cmd.append("--image_format=pil") cmd.append("--attention_backend=FLASHINFER") cmd.append("--disable_kv_cache_reuse") + cmd.append("--kv_cache_fraction=0.5") + cmd.append("--max_seq_len=1024") + elif model_name == "Phi-4-multimodal-instruct": # Set max_seq_len to 4096 to use short rope factor. 
cmd.append("--max_seq_len=4096") diff --git a/tests/integration/defs/triton_server/conftest.py b/tests/integration/defs/triton_server/conftest.py index d66bc0f09d..a277a30d90 100644 --- a/tests/integration/defs/triton_server/conftest.py +++ b/tests/integration/defs/triton_server/conftest.py @@ -69,11 +69,6 @@ def test_case_name(request): return request.node.nodeid -@pytest.fixture(scope="session") -def output_dir(request): - return request.config._trt_config["output_dir"] - - @pytest.fixture(scope="session") def llm_backend_root(): llm_root = os.environ.get("LLM_ROOT", find_repo_root()) @@ -655,10 +650,10 @@ def install_root_requirements(llm_backend_root): @pytest.fixture(scope="session") def output_dir(request): - if USE_TURTLE: - return request.config._trt_config["output_dir"] - else: - return request.config.getoption("--output-dir") + output = request.config.getoption("--output-dir") + if output: + os.makedirs(str(output), exist_ok=True) + return output def deselect_by_regex(regexp, items, test_prefix, config): diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 138e7e2376..ba1420e64e 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -39,6 +39,13 @@ l0_a10: - test_e2e.py::test_openai_chat_example[pytorch] TIMEOUT (90) - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-] - test_e2e.py::test_trtllm_bench_invalid_token_pytorch[TinyLlama-1.1B-Chat-v1.0-TinyLlama-1.1B-Chat-v1.0] + # llmapi + - unittest/llmapi/test_llm_utils.py + - unittest/llmapi/test_gc_utils.py + - unittest/llmapi/test_reasoning_parser.py + - unittest/llmapi/test_serialization.py + - unittest/llmapi/test_utils.py + - unittest/llmapi/test_llm_args.py - condition: ranges: system_gpu_count: @@ -114,12 +121,6 @@ l0_a10: - unittest/bindings - unittest/test_model_runner_cpp.py - unittest/llmapi/test_build_cache.py - - unittest/llmapi/test_llm_utils.py - - unittest/llmapi/test_gc_utils.py - - unittest/llmapi/test_reasoning_parser.py - - unittest/llmapi/test_serialization.py - - unittest/llmapi/test_utils.py - - unittest/llmapi/test_llm_args.py - accuracy/test_cli_flow.py::TestGpt2::test_auto_dtype # 1 min - accuracy/test_cli_flow.py::TestGpt2::test_beam_search # 1 min - accuracy/test_cli_flow.py::TestGpt2::test_beam_search_large # 6 mins diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 339b12fba2..eb597d25a0 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -106,7 +106,7 @@ l0_h100: - test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B] - test_e2e.py::test_openai_chat_harmony - test_e2e.py::test_openai_responses - - test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True] + - test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True] TIMEOUT (90) # ------------- AutoDeploy tests --------------- - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype - condition: @@ -227,7 +227,7 @@ l0_h100: - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=0] - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_fp8_prequantized - accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_fp8_prequantized - - accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype + - 
accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype TIMEOUT (90) - accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype - accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_fp8 - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=False] diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 359b1fbd22..18968eb4a0 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -252,11 +252,9 @@ examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-re examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:1-pp:1-float16-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity] SKIP (https://nvbugs/5421989) examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:1-pp:1-float16-RobertaForSequenceClassification-bert/twitter-roberta-base-emotion] SKIP (https://nvbugs/5421989) examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5431132) -accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5434320) accuracy/test_llm_api.py::TestLlama3_2_1B::test_int4_awq_int8_kv_cache SKIP (https://nvbugs/5433541) accuracy/test_llm_api.py::TestLlama3_2_1B::test_fp8_pp2 SKIP (https://nvbugs/5433541) accuracy/test_llm_api_pytorch.py::TestPhi4MiniInstruct::test_auto_dtype SKIP (https://nvbugs/5433545) -accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5434320) examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-9b-it] SKIP (https://nvbugs/5434451) examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-27b-it] SKIP (https://nvbugs/5434451) examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-3-1b-it] SKIP (https://nvbugs/5434451) @@ -278,7 +276,6 @@ triton_server/test_triton.py::test_gpt_ib_ptuning[gpt-ib-ptuning] SKIP (https:// triton_server/test_triton.py::test_mistral_ib_mm[mistral-ib-mm] SKIP (https://nvbugs/5371343) triton_server/test_triton.py::test_t5_ib[t5-ib] SKIP (https://nvbugs/5456482) triton_server/test_triton_llm.py::test_gpt_speculative_decoding_bls[False-False-1---False-True-True-0-128-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap-0.2-guaranteed_no_evict---1-1-1-False-ensemble] SKIP (https://nvbugs/5456485) -accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5434320) accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_eagle3] SKIP (https://nvbugs/5437384) llmapi/test_llm_examples.py::test_llmapi_speculative_decoding_mtp SKIP (https://nvbugs/5461796) accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_gather_generation_logits_cuda_graph SKIP (https://nvbugs/5365525) @@ -287,6 +284,12 @@ examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-3-small-128k-instruct] SKI examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-3.5-mini-instruct] SKIP (https://nvbugs/5465143) examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-4-mini-instruct] SKIP (https://nvbugs/5465143) 
examples/test_llama.py::test_llm_llama_v1_2gpu_summary[llama-7b-nb:4-enable_auto_parallel] SKIP (https://nvbugs/5465173) +test_e2e.py::test_ptp_quickstart_multimodal[llava-v1.6-mistral-7b-llava-v1.6-mistral-7b-hf-image-False] SKIP (https://nvbugs/5444095) +full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen1.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837) +full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837) +full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_vl_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5359696) +full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837) +accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl SKIP (https://nvbugs/5413362) disaggregated/test_disaggregated.py::test_disaggregated_diff_max_tokens[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5451272) disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_single_gpu_mtp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5465642) examples/test_multimodal.py::test_llm_multimodal_general[Mistral-Small-3.1-24B-Instruct-2503-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5431146) @@ -340,17 +343,6 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_ep accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False] SKIP (https://nvbugs/5488118) accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True] SKIP (https://nvbugs/5488118) accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5488118) -full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen1.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837) -full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837) -full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_vl_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5359696) -full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837) -accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl SKIP (https://nvbugs/5413362) -accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8] SKIP (https://nvbugs/5455140) -full:L40S/accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=False] SKIP (https://nvbugs/5347051) -full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] SKIP (https://nvbugs/5471106) -full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp2pp2] SKIP (https://nvbugs/5471108) -test_e2e.py::test_multi_nodes_eval[llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8-tp8pp2-mmlu] SKIP (https://nvbugs/5473781) -accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl SKIP (https://nvbugs/5413362) 
stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-GUARANTEED_NO_EVICT-pytorch-stress-test] SKIP (https://nvbugs/5474169) test_e2e.py::test_trtllm_bench_iteration_log[TRT-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B] SKIP (https://nvbugs/5448523) cpp/test_unit_tests.py::test_unit_tests[kernels-80] SKIP (https://nvbugs/5504078) diff --git a/tests/unittest/_torch/multi_gpu_modeling/test_deepseek.py b/tests/unittest/_torch/multi_gpu_modeling/test_deepseek.py index 5a38f0d078..8b6ac42cd2 100644 --- a/tests/unittest/_torch/multi_gpu_modeling/test_deepseek.py +++ b/tests/unittest/_torch/multi_gpu_modeling/test_deepseek.py @@ -17,6 +17,7 @@ def similar(a, b, threshold=0.9): return SequenceMatcher(None, a, b).ratio() >= threshold +@pytest.mark.skip(reason="https://nvbugs/5470782") @pytest.mark.parametrize("model_name", ["DeepSeek-V3-Lite"], ids=["deepseekv3_lite"]) @pytest.mark.parametrize("backend", ["TRTLLM"], ids=["trtllm"]) diff --git a/tests/unittest/_torch/multimodal/test_fuse_input_embeds.py b/tests/unittest/_torch/multimodal/test_fuse_input_embeds.py new file mode 100644 index 0000000000..2fc85ef075 --- /dev/null +++ b/tests/unittest/_torch/multimodal/test_fuse_input_embeds.py @@ -0,0 +1,205 @@ +import pytest +import torch + +from tensorrt_llm._torch.models.modeling_multimodal_utils import ( + filter_mm_token_from_input_ids, fuse_input_embeds) +from tensorrt_llm._torch.modules.embedding import Embedding + + +def make_embedding(num_embeddings: int = 100, + hidden_size: int = 16, + device: str = "cpu") -> Embedding: + torch.manual_seed(0) + emb = Embedding(num_embeddings=num_embeddings, embedding_dim=hidden_size) + emb.weight.data.normal_(mean=0.0, std=0.02) + return emb.to(device) + + +@pytest.mark.parametrize("device", ["cpu"] + + (["cuda"] if torch.cuda.is_available() else [])) +def test_filter_mm_token_from_input_ids_oov(device): + vocab_size = 10 + # input_ids contains text (< vocab) and OOV mm tokens (>= vocab) + input_ids = torch.tensor([1, 2, 11, 3, 12, 9, 15], + dtype=torch.long, + device=device) + text_idx, mm_idx = filter_mm_token_from_input_ids(input_ids, + vocab_size=vocab_size, + mm_token_ids=None) + + torch.testing.assert_close(text_idx.cpu(), + torch.tensor([0, 1, 3, 5], dtype=torch.long)) + torch.testing.assert_close(mm_idx.cpu(), + torch.tensor([2, 4, 6], dtype=torch.long)) + + +@pytest.mark.parametrize("device", ["cpu"] + + (["cuda"] if torch.cuda.is_available() else [])) +def test_filter_mm_token_from_input_ids_explicit_ids(device): + vocab_size = 100 + # All ids are < vocab; mm tokens are explicitly specified + input_ids = torch.tensor([1, 2, 55, 3, 77, 9, 88], + dtype=torch.long, + device=device) + mm_token_ids = torch.tensor([55, 77, 88], dtype=torch.long) + text_idx, mm_idx = filter_mm_token_from_input_ids(input_ids, + vocab_size=vocab_size, + mm_token_ids=mm_token_ids) + + torch.testing.assert_close(text_idx.cpu(), + torch.tensor([0, 1, 3, 5], dtype=torch.long)) + torch.testing.assert_close(mm_idx.cpu(), + torch.tensor([2, 4, 6], dtype=torch.long)) + + # Even with some ids > vocab, mm indices should still only match the given mm_token_ids + input_ids = torch.tensor([1, 2, 55, 3, 77, 9, 88, 101, 102, 103], + dtype=torch.long, + device=device) + _, mm_idx = filter_mm_token_from_input_ids(input_ids, + vocab_size=vocab_size, + mm_token_ids=mm_token_ids) + torch.testing.assert_close(mm_idx.cpu(), + torch.tensor([2, 4, 6], dtype=torch.long)) + + +@pytest.mark.parametrize("device", ["cpu"] + + 
(["cuda"] if torch.cuda.is_available() else [])) +def test_fuse_input_embeds_empty_mm_returns_ids(device): + emb = make_embedding(num_embeddings=20, hidden_size=8, device=device) + input_ids = torch.tensor([1, 2, 3, 4], dtype=torch.long, device=device) + + out_ids, out_embeds = fuse_input_embeds(emb, + input_ids, + mm_embeds=[], + mm_token_ids=None) + + # No mm embeddings => passthrough ids, no embeds fused + assert out_embeds is None + torch.testing.assert_close(out_ids, input_ids) + + +@pytest.mark.parametrize("device", ["cpu"] + + (["cuda"] if torch.cuda.is_available() else [])) +def test_fuse_input_embeds_mismatch_raises(device): + emb = make_embedding(num_embeddings=50, hidden_size=8, device=device) + + # Mix text (< vocab) and mm (>= vocab) tokens. Here vocab_size == 50 + input_ids = torch.tensor([1, 51, 2, 52, 3, 53], + dtype=torch.long, + device=device) + + # Identify indices first to drive the lower-level CUDA fuse directly for mismatch + text_idx, mm_idx = filter_mm_token_from_input_ids( + input_ids, vocab_size=emb.num_embeddings) + + # Provide the wrong number of mm embeddings (e.g., one short) + hidden = 8 + true_mm_count = mm_idx.shape[0] + wrong_mm = torch.randn(true_mm_count - 1, hidden, device=device) + + with pytest.raises(ValueError, match="Multimodal token count mismatch"): + fuse_input_embeds(emb, + input_ids, [wrong_mm], + mm_token_ids=None, + text_token_indices=text_idx, + mm_token_indices=mm_idx) + + +@pytest.mark.parametrize("device", ["cpu"] + + (["cuda"] if torch.cuda.is_available() else [])) +def test_fuse_input_embeds_success_oov_path(device): + hidden = 8 + emb = make_embedding(num_embeddings=40, hidden_size=hidden, device=device) + + # input ids: mix of text (<40) and mm (>=40) + input_ids = torch.tensor([0, 1, 41, 2, 42, 3, 43, 4], + dtype=torch.long, + device=device) + + # Build mm embeddings to match number of OOV positions + text_idx, mm_idx = filter_mm_token_from_input_ids( + input_ids, vocab_size=emb.num_embeddings) + mm_emb = torch.randn(mm_idx.shape[0], hidden, device=device) + + # kwargs path to produce fused embeddings + out_ids, out_embeds = fuse_input_embeds(emb, + input_ids, + mm_embeds=[mm_emb], + mm_token_ids=None, + text_token_indices=text_idx, + mm_token_indices=mm_idx) + # integrated filtering path to produce fused embeddings (not kwargs path) + out_ids_v2, out_embeds_v2 = fuse_input_embeds(emb, + input_ids, + mm_embeds=[mm_emb], + mm_token_ids=None) + + assert out_ids is None + assert out_embeds is not None + assert out_embeds.shape == (input_ids.numel(), hidden) + + # Validate that text positions equal embedding lookup, and mm positions equal provided mm_emb + text_idx, mm_idx2 = filter_mm_token_from_input_ids( + input_ids, vocab_size=emb.num_embeddings) + torch.testing.assert_close(mm_idx2, mm_idx) + + ref_text = emb(input_ids[text_idx]) + torch.testing.assert_close(out_embeds[text_idx], ref_text) + torch.testing.assert_close( + out_embeds[mm_idx], + mm_emb.to(dtype=out_embeds.dtype, device=out_embeds.device)) + torch.testing.assert_close(out_embeds_v2, out_embeds) + torch.testing.assert_close(out_ids_v2, out_ids) + + +@pytest.mark.parametrize("device", ["cpu"] + + (["cuda"] if torch.cuda.is_available() else [])) +def test_fuse_input_embeds_kwargs_precedence_over_sentinel_and_ids(device): + """ + Ensure that when kwargs provide precomputed indices, they take precedence + over both OOV-sentinel filtering and explicit mm_token_ids. 
+ """ + hidden = 8 + vocab_size = 40 + emb = make_embedding(num_embeddings=vocab_size, + hidden_size=hidden, + device=device) + + # Use vocab_size+1 as OOV sentinel + oov_sentinel = vocab_size + 1 + input_ids = torch.tensor([0, oov_sentinel, 1, oov_sentinel, 2], + dtype=torch.long, + device=device) + + # Precompute correct indices (kwargs path) + text_idx, mm_idx = filter_mm_token_from_input_ids(input_ids, + vocab_size=vocab_size, + mm_token_ids=None) + mm_emb = torch.randn(mm_idx.shape[0], hidden, device=device) + + # Provide a deliberately incorrect mm_token_ids to ensure it is ignored + bad_mm_token_ids = torch.tensor( + [0], dtype=torch.long, + device=device) # would misclassify index 0 as mm if used + + out_ids, out_embeds = fuse_input_embeds( + emb, + input_ids, + mm_embeds=[mm_emb], + mm_token_ids= + bad_mm_token_ids, # should be ignored because indices are provided + text_token_indices=text_idx, + mm_token_indices=mm_idx, + ) + + # Validate outputs + assert out_ids is None + assert out_embeds is not None + assert out_embeds.shape == (input_ids.numel(), hidden) + + ref_text = emb(input_ids[text_idx]) + torch.testing.assert_close(out_embeds[text_idx], ref_text) + torch.testing.assert_close( + out_embeds[mm_idx], + mm_emb.to(dtype=out_embeds.dtype, device=out_embeds.device), + ) diff --git a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py index 8de0ac8642..bf69917ef2 100644 --- a/tests/unittest/_torch/speculative/test_eagle3.py +++ b/tests/unittest/_torch/speculative/test_eagle3.py @@ -16,7 +16,6 @@ from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig, sys.path.append(os.path.join(os.path.dirname(__file__), '..')) -@pytest.mark.skip(reason="https://nvbugs/5461761") @pytest.mark.parametrize( "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill", [ @@ -27,7 +26,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '..')) [False, "TRTLLM", False, True, True, False], [True, "TRTLLM", False, True, True, False], [True, "TRTLLM", True, False, True, True], - [True, "TRTLLM", True, False, False, True], + # TODO: nvbugs/5461761 + # [True, "TRTLLM", True, False, False, True], ]) @pytest.mark.high_cuda_memory def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str, diff --git a/tests/unittest/_torch/test_torch_sampler.py b/tests/unittest/_torch/test_torch_sampler.py new file mode 100644 index 0000000000..aede4f130d --- /dev/null +++ b/tests/unittest/_torch/test_torch_sampler.py @@ -0,0 +1,189 @@ +import random + +import pytest +import torch + +from tensorrt_llm._torch.pyexecutor.llm_request import convert_wordlist +from tensorrt_llm._torch.pyexecutor.sampler import (BEAM_0, LlmRequest, + TorchSampler) +from tensorrt_llm._torch.pyexecutor.sampler_utils import produce_stop_words +from tensorrt_llm.bindings import SamplingConfig +from tensorrt_llm.bindings.executor import FinishReason + +MAX_NUM_SEQUENCES = 128 +NOT_FINISHED = FinishReason.NOT_FINISHED +STOP_WORDS = FinishReason.STOP_WORDS +END_ID = FinishReason.END_ID +LENGTH = FinishReason.LENGTH + + +class RequestCase: + MAX_NEW_TOKENS = 10 + seq_slots = random.sample(range(MAX_NUM_SEQUENCES), MAX_NUM_SEQUENCES) + + def __init__(self, + *, + prompt: list[int], + new_tokens: list[int], + finish_reasons: list[FinishReason], + max_new_tokens: int = MAX_NEW_TOKENS, + end_id: int = None, + stop_words_list: list[list[int]] = None): + seq_slot = self.seq_slots.pop() # random seq slot in MAX_NUM_SEQUENCES + 
self.prompt = prompt + self.request = LlmRequest( + request_id=seq_slot, + seq_slot=seq_slot, + input_tokens=prompt, + max_new_tokens=max_new_tokens, + stop_words_list=convert_wordlist(stop_words_list) + if stop_words_list is not None else None, + end_id=end_id, + sampling_config=SamplingConfig(), + is_streaming=False, + ) + assert len(new_tokens) == len(finish_reasons) + self.new_tokens = new_tokens + self.finish_reasons = finish_reasons + + def __repr__(self): + return f"RequestCase({self.prompt=}, {self.new_tokens=}, {self.finish_reasons=}, {self.request.max_new_tokens=}, {self.request.end_id=}, {self.request.stop_words_list=})" + + @staticmethod + def setup(requests: list["RequestCase"]): + max_tokens = set(len(req.new_tokens) for req in requests) + assert len(max_tokens) == 1 + max_draft_len = max_tokens.pop() - 1 + sampler_args = TorchSampler.Args( + max_seq_len=20, + max_draft_len=max_draft_len, + # Fill with many more max requests than below, so we can test that write_finish_reasons uses seq_slots correctly + max_num_sequences=MAX_NUM_SEQUENCES, + max_beam_width=1, + enable_mixed_sampler=False) + sampler = TorchSampler(args=sampler_args) + + # fill with garbage value so we can observe that finish reasons are filled with NOT_FINISHED before we write to them. + sampler.store.finish_reasons.fill_(205) + seq_slots = torch.tensor([req.request.py_seq_slot for req in requests], + device="cuda", + dtype=torch.int64) + new_tokens = torch.tensor([req.new_tokens for req in requests], + dtype=torch.int32, + device="cuda").T + sampler.store.new_tokens[:, seq_slots, BEAM_0] = new_tokens + + def run(): + sampler._write_finish_reasons( + [req.request for req in requests], + finish_reasons=sampler.store.finish_reasons, + new_tokens=sampler.store.new_tokens, + seq_slots=seq_slots) + + reasons = sampler.store.finish_reasons[:, seq_slots, + BEAM_0].T.tolist() + + for actual, request in zip(reasons, requests, strict=True): + expected = request.finish_reasons + msg = f"actual={[FinishReason(reason) for reason in actual]} != expected={expected}\nFor {request}" + assert actual == [reason.value for reason in expected], msg + + return run, sampler + + +def test_write_finish_reasons(): + """We don't really care about the finish reason past the first infraction, because we're not going to use it, although in some instance it is written anyway.""" + run, _ = RequestCase.setup([ + RequestCase( + prompt=[13, 14], + new_tokens=[60, 61, 62], + # We pre-fill the finish reasons with NOT_FINISHED. 
+ finish_reasons=[NOT_FINISHED, NOT_FINISHED, NOT_FINISHED], + ), + RequestCase( + prompt=[7, 8, 6], + stop_words_list=[[12, 13]], + new_tokens=[12, 13, 60], + finish_reasons=[NOT_FINISHED, STOP_WORDS, NOT_FINISHED], + ), + RequestCase( + prompt=[1, 2, 3, 4], + end_id=99, + new_tokens=[55, 99, 58], + finish_reasons=[NOT_FINISHED, END_ID, NOT_FINISHED], + ), + RequestCase( + prompt=[4, 5, 6], + max_new_tokens=2, + new_tokens=[56, 57, 59], + # The LENGTH check happens to not have an early exit + finish_reasons=[NOT_FINISHED, LENGTH, LENGTH]), + RequestCase( + prompt=[1, 12], + stop_words_list=[[12, 13], [14, 15]], + new_tokens=[13, 14, 15], + # We have an early exit specifically for stop words + finish_reasons=[STOP_WORDS, NOT_FINISHED, NOT_FINISHED], + ), + RequestCase( + prompt=[1], + max_new_tokens=2, + end_id=99, + stop_words_list=[[1, 12]], + new_tokens=[12, 99, 63], + # Different infractions are written to different places as we don't have an early exit between infractions + finish_reasons=[STOP_WORDS, END_ID, LENGTH], + ), + RequestCase( + prompt=[1, 12, 56, 67, 68, 234, 678], + stop_words_list=[[12, 56, 67, 68, 234, 678, 129, 182]], + new_tokens=[129, 182, 600], + # Notice the offending stop sequence is concatenated, as we look back + finish_reasons=[NOT_FINISHED, STOP_WORDS, NOT_FINISHED], + ), + RequestCase( + prompt=[1, 12], + end_id=99, + max_new_tokens=1, + stop_words_list=[[1, 12, 99]], + new_tokens=[99, 100, 101], + # The latest infraction check overrides the earlier infraction checks, hence the first finish_reason is END_ID + finish_reasons=[END_ID, LENGTH, LENGTH], + ), + ]) + run() + + +def test_are_stop_words_isnt_called_when_no_stop_words(): + """We don't want to call are_stop_words when there are no stop words because it's expensive""" + + def stop_words_that_raises(*args, **kwargs): + raise AssertionError + + run_with_stop_words, sampler = RequestCase.setup([ + RequestCase(prompt=[1], + stop_words_list=[[1]], + new_tokens=[4], + finish_reasons=[NOT_FINISHED]) + ]) + sampler._are_stop_words = stop_words_that_raises + with pytest.raises(AssertionError): + run_with_stop_words() + + run_without_stop_words, sampler = RequestCase.setup([ + RequestCase(prompt=[1], new_tokens=[4], finish_reasons=[NOT_FINISHED]) + ]) + sampler._are_stop_words = stop_words_that_raises + _ = run_without_stop_words() + + +def test_produce_stop_words(): + for original in [ + [[]], + [[1]], + [[1, 2, 3]], + [[1], [2, 3]], + [[1, 2, 3], [4, 5]], + [[10], [20], [30, 40], [50]], + ]: + assert original == list(produce_stop_words(convert_wordlist(original))) diff --git a/tests/unittest/llmapi/apps/_test_openai_chat_harmony.py b/tests/unittest/llmapi/apps/_test_openai_chat_harmony.py index 0204a04acf..ba6c7d5337 100644 --- a/tests/unittest/llmapi/apps/_test_openai_chat_harmony.py +++ b/tests/unittest/llmapi/apps/_test_openai_chat_harmony.py @@ -147,6 +147,10 @@ async def test_streaming(client: openai.AsyncOpenAI, model: str): collected_chunks = [] collected_messages = [] async for chunk in response: + # Last streaming response will only contain usage info + if len(chunk.choices) <= 0: + continue + collected_chunks.append(chunk) collected_messages.append(chunk.choices[0].delta) @@ -198,6 +202,10 @@ async def test_streaming_tool_call(client: openai.AsyncOpenAI, model: str): reasoning_chunks: list[str] = [] tool_arg_chunks: list[str] = [] async for chunk in response: + # Last streaming response will only contain usage info + if len(chunk.choices) <= 0: + continue + delta = chunk.choices[0].delta if
hasattr(delta, "tool_calls") and delta.tool_calls: function = delta.tool_calls[0].function diff --git a/tests/unittest/llmapi/apps/_test_openai_multi_chat.py b/tests/unittest/llmapi/apps/_test_openai_multi_chat.py index 9ed9a654c5..1265b58bd9 100644 --- a/tests/unittest/llmapi/apps/_test_openai_multi_chat.py +++ b/tests/unittest/llmapi/apps/_test_openai_multi_chat.py @@ -65,7 +65,10 @@ def engine_from_fp8_quantization(model_name): @pytest.fixture(scope="module") def server(model_name: str, engine_from_fp8_quantization: str): model_path = get_model_path(model_name) - args = ["--tp_size", "2", "--tokenizer", model_path] + args = [ + "--tp_size", "2", "--tokenizer", model_path, "--backend", "trt", + "--max_num_tokens", "20480", "--max_batch_size", "128" + ] with RemoteOpenAIServer(engine_from_fp8_quantization, args) as remote_server: yield remote_server diff --git a/tests/unittest/llmapi/test_utils.py b/tests/unittest/llmapi/test_utils.py index fc5876cdb1..5488f7c7ba 100644 --- a/tests/unittest/llmapi/test_utils.py +++ b/tests/unittest/llmapi/test_utils.py @@ -1,4 +1,6 @@ -from tensorrt_llm.llmapi.utils import ApiStatusRegistry +from tensorrt_llm.llmapi import LlmArgs +from tensorrt_llm.llmapi.utils import (ApiStatusRegistry, + generate_api_docs_as_docstring) def test_api_status_registry(): @@ -24,3 +26,9 @@ def test_api_status_registry(): pass assert ApiStatusRegistry.get_api_status(App._my_method) == "beta" + + +def test_generate_api_docs_as_docstring(): + doc = generate_api_docs_as_docstring(LlmArgs) + assert ":tag:`beta`" in doc, "the label is not generated" + print(doc)