Merge remote-tracking branch 'origin/main' into feat/b300_cu13

Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
Xiwen Yu 2025-09-08 15:14:54 +08:00
commit fdaf4e2985
70 changed files with 1869 additions and 543 deletions

View File

@ -47,7 +47,7 @@ public:
bool operator()(DecoderInputBuffers& inputBuffers, bool replicateLogitsPostProcessor,
runtime::WorldConfig const& worldConfig, CudaStreamPtr const& stream,
std::optional<LogitsPostProcessorBatched> logitsPostProcessorBatched = std::nullopt) const;
std::optional<LogitsPostProcessorBatched> const& logitsPostProcessorBatched = std::nullopt) const;
};
} // namespace tensorrt_llm::batch_manager

View File

@ -204,6 +204,34 @@ private:
}
}
void sendResponse(std::vector<size_t> const& blockHashes, std::map<RequestIdType, Response>::iterator it)
{
auto reqId = mCurrentRequest.value();
auto count = --mRemainSendCount[reqId];
TLLM_CHECK(count >= 0);
if (count == 0)
{
mRemainSendCount.erase(reqId);
// TODO(zhengd): pass the hashes directly instead of update llmRequest
auto llmRequest = it->second.mRequest;
llmRequest->setRequestedBlockHashes(std::move(blockHashes));
if (common::getEnvParallelCacheSend())
{
// TODO: Use a thread pool and check for thread safety.
std::thread(&DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second))
.detach();
}
else
{
DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
}
removeResponse(it);
}
mCurrentRequest = std::nullopt;
}
void response() noexcept
{
try
@ -237,40 +265,22 @@ private:
auto it = getCurrentResponse();
if (it != mReadyResponses.end())
{
auto reqId = mCurrentRequest.value();
auto count = --mRemainSendCount[reqId];
TLLM_CHECK(count >= 0);
if (count == 0)
{
mRemainSendCount.erase(reqId);
// TODO(zhengd): pass the hashes directly instead of update llmRequest
auto llmRequest = it->second.mRequest;
llmRequest->setRequestedBlockHashes(std::move(blockHashes));
if (common::getEnvParallelCacheSend())
{
// TODO: Use a thread pool and check for thread safety.
std::thread(
&DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second))
.detach();
}
else
{
DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
}
removeResponse(it);
}
mCurrentRequest = std::nullopt;
sendResponse(blockHashes, it);
}
else
{
TLLM_CHECK_WITH_INFO(!mCurrentRequest.has_value(),
"This executor does not have a prepared KV cache for request ID: %zu, and the "
"mReadyResponses size is: %zu. mpi rank :%d ",
mCurrentRequest.value(), mReadyResponses.size(), mpi::MpiComm::world().getRank());
std::unique_lock lk(mCondMutex);
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
auto it = getCurrentResponse();
while (it == mReadyResponses.end())
{
std::unique_lock lk(mCondMutex);
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
if (mTerminate)
{
break;
}
it = getCurrentResponse();
}
sendResponse(blockHashes, it);
}
}
}

View File

@ -34,7 +34,7 @@ using SizeType32 = tensorrt_llm::runtime::SizeType32;
bool LogitsPostProcessor::operator()(DecoderInputBuffers& inputBuffers, bool replicateLogitsPostProcessor,
tr::WorldConfig const& worldConfig, CudaStreamPtr const& stream,
std::optional<LogitsPostProcessorBatched> logitsPostProcessorBatched) const
std::optional<LogitsPostProcessorBatched> const& logitsPostProcessorBatched) const
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(LogitsPostProcessor);

View File

@ -201,56 +201,60 @@ Metrics Endpoint
.. note::
This endpoint is at beta maturity.
The metrics endpoint for the default PyTorch backend are in beta and are not as comprehensive as those for the TensorRT backend.
The statistics for the PyTorch backend are beta and not as comprehensive as those for the TensorRT backend.
Some fields, such as CPU memory usage, are not yet available for the PyTorch backend.
Some fields, such as CPU memory usage, are not available for the PyTorch backend.
Enabling ``enable_iter_perf_stats`` in the PyTorch backend can slightly impact performance, depending on the serving configuration.
Enabling ``enable_iter_perf_stats`` in the PyTorch backend can impact performance slightly, depending on the serving configuration.
The ``/metrics`` endpoint provides runtime iteration statistics such as GPU memory usage and KV cache details.
The ``/metrics`` endpoint provides runtime-iteration statistics such as GPU memory use and inflight-batching details.
For the TensorRT backend, these statistics are enabled by default.
However, for the PyTorch backend, you must explicitly enable iteration statistics logging by setting the `enable_iter_perf_stats` field in a YAML configuration file as shown in the following example:
For the default PyTorch backend, iteration statistics logging is enabled by setting the ``enable_iter_perf_stats`` field in a YAML file:
.. code-block:: yaml
# extra-llm-api-config.yml
pytorch_backend_config:
enable_iter_perf_stats: true
# extra_llm_config.yaml
enable_iter_perf_stats: true
Then start the server and specify the ``--extra_llm_api_options`` argument with the path to the YAML file as shown in the following example:
Start the server and specify the ``--extra_llm_api_options`` argument with the path to the YAML file:
.. code-block:: bash
trtllm-serve <model> \
--extra_llm_api_options <path-to-extra-llm-api-config.yml> \
[--tp_size <tp> --pp_size <pp> --ep_size <ep> --host <host> --port <port>]
trtllm-serve "TinyLlama/TinyLlama-1.1B-Chat-v1.0" --extra_llm_api_options extra_llm_config.yaml
After at least one inference request is sent to the server, you can fetch the runtime-iteration statistics by polling the `/metrics` endpoint:
After sending at least one inference request to the server, you can fetch runtime iteration statistics by polling the ``/metrics`` endpoint.
Since the statistics are stored in an internal queue and removed once retrieved, it's recommended to poll the endpoint shortly after each request and store the results if needed.
.. code-block:: bash
curl -X GET http://<host>:<port>/metrics
curl -X GET http://localhost:8000/metrics
*Example Output*
Example output:
.. code-block:: json
[
{
"gpuMemUsage": 56401920000,
"inflightBatchingStats": {
[
{
"gpuMemUsage": 76665782272,
"iter": 154,
"iterLatencyMS": 7.00688362121582,
"kvCacheStats": {
"allocNewBlocks": 3126,
"allocTotalBlocks": 3126,
"cacheHitRate": 0.00128,
"freeNumBlocks": 101253,
"maxNumBlocks": 101256,
"missedBlocks": 3121,
"reusedBlocks": 4,
"tokensPerBlock": 32,
"usedNumBlocks": 3
},
"numActiveRequests": 1
...
},
"iter": 1,
"iterLatencyMS": 16.505143404006958,
"kvCacheStats": {
...
},
"newActiveRequestsQueueLatencyMS": 0.0007503032684326172
}
]
}
]
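Because the statistics queue is drained on each read, a short polling loop right after sending requests is the simplest way to capture every iteration. A minimal sketch (assuming the server from the example above on ``localhost:8000`` and the third-party ``requests`` package; the helper name is illustrative):
.. code-block:: python

    import time

    import requests  # third-party HTTP client, assumed to be installed

    def poll_metrics(base_url="http://localhost:8000", polls=5, interval_s=0.5):
        """Drain the /metrics queue a few times and return all iteration stats."""
        stats = []
        for _ in range(polls):
            resp = requests.get(f"{base_url}/metrics", timeout=5)
            resp.raise_for_status()
            stats.extend(resp.json())  # each poll returns a JSON list of iteration stats
            time.sleep(interval_s)
        return stats

    if __name__ == "__main__":
        for it in poll_metrics():
            print(it.get("iter"), it.get("iterLatencyMS"), it.get("gpuMemUsage"))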
Syntax
------

View File

@ -234,6 +234,28 @@ TODO: Use Chat Compeletions API / Responses API as the example after the PR is m
We use OpenAI's official evaluation tool to test the model's accuracy. For more information, see [gpt-oss-eval](https://github.com/openai/gpt-oss/tree/main/gpt_oss/evals).
With the added support for the Chat Completions and Responses APIs in `trtllm-serve`, `gpt_oss.evals` works directly without any modifications.
You need to set `enable_attention_dp`, `tp_size`, `ep_size`, `max_batch_size`, and `max_num_tokens` when launching the TensorRT-LLM server, and set `reasoning-effort` when launching the evaluation in gpt-oss. Below are some reference configurations for accuracy evaluation on B200.
| **reasoning-effort** | **parallel configuration** | **max_batch_size** | **max_num_tokens** |
|:--------------------:|:--------------------------:|:------------------:|:------------------:|
| low/medium | DEP8 / DEP4 | 128 | 32768 |
| high | DEP8 / DEP4 | 2 | 133120 |
| low/medium | TP8 / TP4 | 1024 | 32768 |
| high | TP8 / TP4 | 720 | 133120 |
Below is an example command for evaluating the accuracy of gpt-oss-120b with low and medium reasoning-effort on GPQA and AIME2025.
```shell
# execute this command in gpt-oss
python -m gpt_oss.evals \
--sampler chat_completions \
--eval gpqa,aime25 \
--model gpt-oss-120b \
--reasoning-effort low,medium
```
## Benchmarking Performance
To benchmark the performance of your TensorRT-LLM server, you can leverage the built-in `benchmark_serving.py` script. To do this, first create a wrapper `bench.sh` script.

View File

@ -1,7 +1,7 @@
# LLM API with TensorRT Engine
A simple inference example with TinyLlama using the LLM API:
```{literalinclude} ../../examples/llm-api/_tensorrt_engine/quickstart_example.py
```{literalinclude} ../../../examples/llm-api/_tensorrt_engine/quickstart_example.py
:language: python
:linenos:
```

View File

@ -1,11 +1,17 @@
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm import BuildConfig, SamplingParams
from tensorrt_llm._tensorrt_engine import LLM # NOTE the change
def main():
build_config = BuildConfig()
build_config.max_batch_size = 256
build_config.max_num_tokens = 1024
# Model could accept HF model name, a path to local HF model,
# or TensorRT Model Optimizer's quantized checkpoints like nvidia/Llama-3.1-8B-Instruct-FP8 on HF.
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
build_config=build_config)
# Sample prompts.
prompts = [

View File

@ -76,6 +76,7 @@ srun -l \
# This is optional
cat > /tmp/pytorch_extra_args.txt << EOF
cuda_graph_config: null
print_iter_log: true
enable_attention_dp: false
EOF

View File

@ -94,6 +94,26 @@ def createKubernetesPodConfig(type, arch = "amd64", build_wheel = false)
"""
}
if (arch == "amd64") {
// For x86_64, we block some nodes to avoid unstable network access.
selectors += """
affinity:
nodeAffinity:
requiredDuringSchedulingIgnoredDuringExecution:
nodeSelectorTerms:
- matchExpressions:
- key: "kubernetes.io/hostname"
operator: NotIn
values:
- "sc-ipp-blossom-prod-k8w-105"
- "sc-ipp-blossom-prod-k8w-114"
- "sc-ipp-blossom-prod-k8w-115"
- "sc-ipp-blossom-prod-k8w-121"
- "sc-ipp-blossom-prod-k8w-123"
- "sc-ipp-blossom-prod-k8w-124"
"""
}
def archSuffix = arch == "arm64" ? "arm" : "amd"
def jnlpImage = "urm.nvidia.com/sw-ipp-blossom-sre-docker-local/lambda/custom_jnlp_images_${archSuffix}_linux:jdk17"

View File

@ -365,13 +365,24 @@ def runLLMTestlistOnSlurm(pipeline, platform, testList, config=VANILLA_CONFIG, p
}
stage('Checking if the Node is Online') {
def counter = 0
// We submit the Slurm job with 5 hours timeout, and the K8S pod will be evicted after 22 hours.
// Let's use 15 hours to check if the node is online, and with 2 hours buffer.
while (!CloudManager.isNodeOnline(nodeName) && counter < 90) {
// Wait 10 minutes to check status of the node again
sleep(time: 10, unit: 'MINUTES')
counter++
withCredentials([usernamePassword(credentialsId: 'svc_tensorrt', usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) {
def remote = [
ip : cluster.ip,
host : cluster.host,
user : "${pipeline.USERNAME}",
passwd : "${pipeline.PASSWORD}",
allowAnyHosts: true,
]
def counter = 0
// We submit the Slurm job with 5 hours timeout, and the K8S pod will be evicted after 22 hours.
// Let's use 15 hours to check if the node is online, and with 2 hours buffer.
while (!CloudManager.isNodeOnline(nodeName) && counter < 90) {
// Wait 10 minutes to check status of the node again
sleep(time: 10, unit: 'MINUTES')
// Avoid the node being stuck in the held state.
Utils.exec(pipeline, Utils.sshUserCmd(remote, "\"scontrol release ${slurmJobID} || true\""))
counter++
}
}
if (CloudManager.isNodeOnline(nodeName)) {
@ -1157,7 +1168,7 @@ def transformMakoArgsToJson(optList) {
def getMakoOpts(getMakoScript, makoArgs=[]) {
// We want to save a map for the Mako opts
def turtleOutput = ""
def makoOutput = ""
// Echo the command
// NOTE: We redirect stderr to stdout so that we can capture
@ -1181,17 +1192,17 @@ def getMakoOpts(getMakoScript, makoArgs=[]) {
// Capture the mako output, add timeout in case any hang
timeout(time: 30, unit: 'MINUTES'){
turtleOutput = sh(label: "Capture Mako Parameters", script: listMakoCmd, returnStdout: true)
makoOutput = sh(label: "Capture Mako Parameters", script: listMakoCmd, returnStdout: true)
}
}
// Validate output
assert turtleOutput: "Mako opts not found - could not construct test db test list."
assert makoOutput: "Mako opts not found - could not construct test db test list."
// Split each line of turtle output into a list
def turtleOutList = turtleOutput.split("\n")
// Split each line of mako output into a list
def outputList = makoOutput.split("\n")
def makoOptsJson = transformMakoArgsToJson(turtleOutList)
def makoOptsJson = transformMakoArgsToJson(outputList)
return makoOptsJson
}
@ -1827,7 +1838,7 @@ def runLLMBuild(pipeline, cpu_arch, reinstall_dependencies=false, wheel_path="",
if (env.alternativeTRT) {
trtllm_utils.replaceWithAlternativeTRT(env.alternativeTRT, cpver)
}
buildArgs = "--clean"
buildArgs = "--clean --nixl_root /opt/nvidia/nvda_nixl"
if (cpu_arch == AARCH64_TRIPLE) {
buildArgs += " -a '90-real;100-real;103-real;120-real'"
}
@ -2062,9 +2073,9 @@ def launchTestJobs(pipeline, testFilter)
"DGX_H200-4_GPUs-TensorRT-Post-Merge-1": ["dgx-h200-x4", "l0_dgx_h200", 1, 3, 4],
"DGX_H200-4_GPUs-TensorRT-Post-Merge-2": ["dgx-h200-x4", "l0_dgx_h200", 2, 3, 4],
"DGX_H200-4_GPUs-TensorRT-Post-Merge-3": ["dgx-h200-x4", "l0_dgx_h200", 3, 3, 4],
"RTXPro6000-Pytorch-Post-Merge-1": ["rtx-pro-6000", "l0_rtx_pro_6000", 1, 1],
"RTXPro6000-4_GPUs-Pytorch-Post-Merge-1": ["rtx-pro-6000-x4", "l0_rtx_pro_6000", 1, 2, 4],
"RTXPro6000-4_GPUs-Pytorch-Post-Merge-2": ["rtx-pro-6000-x4", "l0_rtx_pro_6000", 2, 2, 4],
//"RTXPro6000-Pytorch-Post-Merge-1": ["rtx-pro-6000", "l0_rtx_pro_6000", 1, 1],
//"RTXPro6000-4_GPUs-Pytorch-Post-Merge-1": ["rtx-pro-6000-x4", "l0_rtx_pro_6000", 1, 2, 4],
//"RTXPro6000-4_GPUs-Pytorch-Post-Merge-2": ["rtx-pro-6000-x4", "l0_rtx_pro_6000", 2, 2, 4],
]
parallelJobs = x86TestConfigs.collectEntries{key, values -> [key, [createKubernetesPodConfig(key.contains("-CU12-") ? LLM_DOCKER_IMAGE_12_9 : LLM_DOCKER_IMAGE, values[0], "amd64", values[4] ?: 1, key.contains("Perf")), {

View File

@ -170,7 +170,8 @@ class FlashInferAttentionMetadata(AttentionMetadata):
def create_cuda_graph_metadata(self,
max_batch_size: int,
sub_cross_metadata: bool = False,
max_draft_tokens: int = 0) -> Self:
max_draft_tokens: int = 0,
buffers=None) -> Self:
metadata = super().create_cuda_graph_metadata(max_batch_size,
sub_cross_metadata,
max_draft_tokens)

View File

@ -140,6 +140,7 @@ class AttentionMetadata:
# This buffer is currently only used for TrtllmAttentionMetadata.
cache_indirection: Optional[torch.Tensor] = None
cuda_graph_buffers: dict[str, list[torch.Tensor]] = None
_saved_tensors: Dict[str, torch.Tensor] = field(init=False,
default_factory=dict)
@ -288,7 +289,8 @@ class AttentionMetadata:
def create_cuda_graph_metadata(self,
max_batch_size: int,
sub_cross_metadata: bool = False,
max_draft_tokens: int = 0) -> Self:
max_draft_tokens: int = 0,
buffers=None) -> Self:
"""
Creates metadata for CUDA graph execution.
CUDA graphs require to use pre-allocated buffers for all tensors in fields.
@ -300,6 +302,7 @@ class AttentionMetadata:
cuda_graph_metadata = copy.copy(self)
cuda_graph_metadata.is_cuda_graph = True
cuda_graph_metadata.cuda_graph_buffers = buffers
if self.has_cross_sub_metadata:
cuda_graph_metadata.cross = cuda_graph_metadata.cross.create_cuda_graph_metadata(
max_batch_size, True)

View File

@ -600,13 +600,65 @@ class TrtllmAttentionMetadata(AttentionMetadata):
def __post_init__(self) -> None:
super().__post_init__()
self._post_init_with_buffers(self.cuda_graph_buffers)
def _post_init_with_buffers(self, buffers) -> None:
# Set a default value, as max_num_sequences is not always set.
if self.max_num_sequences is None:
self.max_num_sequences = self.max_num_requests
self.prompt_lens_cuda = torch.empty(
def get_empty(tensor_shape: list[int], dtype: torch.dtype,
cache_name: str) -> torch.Tensor:
"""
Finds a compatible, reusable buffer from a cache or creates a new one.
This function searches for a pre-allocated tensor (buffer) that can be
reused for an operation involving a tensor with the shape of `tensor_shape`.
The compatibility rule is: the buffer's total number of elements must be >= that of `tensor_shape`.
If a compatible buffer is found, it is returned immediately. Otherwise, a new
buffer is allocated on the 'cuda' device with the given 'tensor_shape' and 'dtype'.
Args:
tensor_shape: The required shape.
dtype: The required dtype.
cache_name: The key for the specific list of buffers to search in.
Returns:
An existing compatible buffer or a newly created one.
"""
if buffers is not None:
# Safely get the list of candidates. Defaults to an empty list if key is missing.
candidate_buffers = buffers.get(cache_name, [])
numel_like = math.prod(tensor_shape)
for buffer in candidate_buffers:
numel_buffer = buffer.numel()
# buffer just needs to be large enough.
if numel_buffer >= numel_like:
return buffer[0:numel_like].view(
tensor_shape) # Found a fit, return immediately.
# If we get here, no suitable buffer was found in the cache. Create a new one.
new_buffer = torch.zeros(tensor_shape, device='cuda', dtype=dtype)
if buffers is not None:
buffers.setdefault(cache_name, []).append(new_buffer)
return new_buffer
def get_empty_like(like_tensor: torch.Tensor,
cache_name: str) -> torch.Tensor:
return get_empty(
like_tensor.shape,
cache_name=cache_name,
dtype=like_tensor.dtype,
)
self.prompt_lens_cuda = get_empty(
(self.max_num_sequences, ),
device='cuda',
cache_name="prompt_lens_cuda",
dtype=torch.int,
)
self.prompt_lens_cpu = torch.empty_like(
@ -614,7 +666,10 @@ class TrtllmAttentionMetadata(AttentionMetadata):
device='cpu',
pin_memory=True,
)
self.kv_lens_cuda = torch.empty_like(self.prompt_lens_cuda)
self.kv_lens_cuda = get_empty_like(
self.prompt_lens_cuda,
cache_name="kv_lens_cuda",
)
self.kv_lens = torch.empty_like(self.kv_lens_cuda,
device='cpu',
pin_memory=True)
@ -629,13 +684,13 @@ class TrtllmAttentionMetadata(AttentionMetadata):
dtype=torch.int8,
)
if self.kv_cache_manager is not None:
self.kv_cache_block_offsets = torch.empty(
self.kv_cache_block_offsets = get_empty(
[
self.kv_cache_manager.num_pools, self.max_num_sequences, 2,
self.kv_cache_manager.max_blocks_per_seq
],
cache_name="kv_cache_block_offsets",
dtype=torch.int32,
device='cuda',
)
self.host_kv_cache_block_offsets = torch.empty_like(
self.kv_cache_block_offsets,
@ -645,27 +700,27 @@ class TrtllmAttentionMetadata(AttentionMetadata):
self.block_ids_per_seq = None
self.kv_block_ids_per_seq = None
if self.enable_flash_mla:
self.block_ids_per_seq = torch.zeros(
self.block_ids_per_seq = get_empty(
[
self.kv_cache_manager.max_batch_size,
self.kv_cache_manager.max_blocks_per_seq
],
cache_name="block_ids_per_seq",
dtype=torch.int32,
device='cuda',
)
self.kv_block_ids_per_seq = torch.zeros(
self.kv_block_ids_per_seq = get_empty(
[
self.kv_cache_manager.max_batch_size,
self.kv_cache_manager.max_blocks_per_seq
],
cache_name="kv_block_ids_per_seq",
dtype=torch.int32,
device='cuda',
)
if self.enable_context_mla_with_cached_kv:
# for kv cache reuse/chunked context in MLA
self.ctx_cached_token_indptr = torch.zeros(
self.ctx_cached_token_indptr = get_empty(
(self.max_num_requests + 1, ),
device='cuda',
cache_name="ctx_cached_token_indptr",
dtype=torch.int64,
)
self.host_ctx_cached_token_indptr = torch.zeros_like(
@ -673,9 +728,9 @@ class TrtllmAttentionMetadata(AttentionMetadata):
device='cpu',
pin_memory=True,
)
self.ctx_uncached_token_indptr = torch.zeros(
self.ctx_uncached_token_indptr = get_empty(
(self.max_num_requests + 1, ),
device='cuda',
cache_name="ctx_uncached_token_indptr",
dtype=torch.int64,
)
self.host_ctx_uncached_token_indptr = torch.zeros_like(
@ -684,9 +739,9 @@ class TrtllmAttentionMetadata(AttentionMetadata):
pin_memory=True,
)
# context full seqlens include cached tokens and uncached tokens
self.ctx_kv_indptr = torch.zeros(
self.ctx_kv_indptr = get_empty(
(self.max_num_requests + 1, ),
device='cuda',
cache_name="ctx_kv_indptr",
dtype=torch.int64,
)
self.host_ctx_kv_indptr = torch.zeros_like(
@ -1165,7 +1220,7 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
host_kv_cache_pool_pointers=metadata.host_kv_cache_pool_pointers,
host_kv_cache_pool_mapping=metadata.host_kv_cache_pool_mapping,
block_ids_per_seq=metadata.block_ids_per_seq,
workspace=metadata.workspace,
workspace=None,
cache_indirection=metadata.cache_indirection,
kv_scale_orig_quant=self.kv_scale_orig_quant,
kv_scale_quant_orig=self.kv_scale_quant_orig,

View File

@ -371,7 +371,7 @@ class AutoTuner:
if not is_cache_hit:
logger.warning_once(
f"[AutoTunner] Using the fallback tactic, due to cache miss on input shapes={input_shapes}",
key=(custom_op))
key=custom_op)
return (best_runner, best_tactic)

View File

@ -210,15 +210,9 @@ class PiecewiseRunner(object):
runtime_input_addresses = [
i.data_ptr() for i in args if isinstance(i, torch.Tensor)
]
runtime_output_addresses = [
i.data_ptr() for i in output if isinstance(i, torch.Tensor)
]
assert (entry.input_addresses == runtime_input_addresses
), f"{entry.input_addresses} vs\n {runtime_input_addresses}"
assert (
entry.output_addresses == runtime_output_addresses
), f"{entry.output_addresses} vs\n {runtime_output_addresses}"
entry.cuda_graph.replay()

View File

@ -494,7 +494,8 @@ class ModelConfig(Generic[TConfig]):
architectures = self.pretrained_config.architectures
if len(architectures
) == 1 and architectures[0] == "DeciLMForCausalLM":
mlp_hidden_size = self._infer_nemotron_ffn_mult()
mlp_hidden_size = self._infer_nemotron_ffn_mult(
) // self.mapping.tp_size
else:
raise ValueError(
f"Inferring mlp hidden size for model architecture: {architectures} isn't supported yet"

View File

@ -263,7 +263,9 @@ class Gemma3VLM(PreTrainedModel):
embedding_layer=self.llm.model.embed_tokens,
input_ids=input_ids,
mm_embeds=mm_embeds,
mm_token_ids=self.image_token_ids)
mm_token_ids=self.image_token_ids,
**kwargs,
)
logits = self.llm.forward(
attn_metadata=attn_metadata,
input_ids=input_ids,
@ -284,3 +286,7 @@ class Gemma3VLM(PreTrainedModel):
attn_metadata=attn_metadata)[-1]
image_features = self.mm_projector(image_features)
return image_features
@property
def mm_token_ids(self):
return self.image_token_ids

View File

@ -1052,7 +1052,8 @@ class HCXVisionForCausalLM(PreTrainedModel):
]
input_ids, input_embeds = fuse_input_embeds(self.llm.model.embed_tokens,
input_ids, mm_embeds)
input_ids, mm_embeds,
**kwargs)
output_prob = self.llm.forward(
attn_metadata=attn_metadata,
input_ids=input_ids,

View File

@ -1280,7 +1280,8 @@ class Llama4ForConditionalGeneration(SpecDecOneEngineForCausalLM[Llama4Model,
]
input_ids, inputs_embeds = fuse_input_embeds(self.model.embed_tokens,
input_ids, mm_embeds)
input_ids, mm_embeds,
**kwargs)
return super().forward(attn_metadata,
input_ids,
position_ids,

View File

@ -302,7 +302,8 @@ class LlavaNextVisionModel(nn.Module):
logger.warning_once(
"Image feature shape does not line up with the provided patch size. "
"You may be using the `default` vision_feature_select_strategy with a"
" visual encoder that does not have CLS.")
" visual encoder that does not have CLS.",
key="llava_next_vision_model_pack_image_features")
image_feature = image_feature.view(num_patch_height,
num_patch_width, height,
@ -474,7 +475,7 @@ class LlavaNextModel(PreTrainedModel):
for multimodal_param in multimodal_params
]
input_ids, inputs_embeds = fuse_input_embeds(
self.llm.model.embed_tokens, input_ids, mm_embeds)
self.llm.model.embed_tokens, input_ids, mm_embeds, **kwargs)
logits = self.llm.forward(attn_metadata, input_ids, position_ids,
inputs_embeds, return_context_logits)

View File

@ -408,6 +408,7 @@ class Mistral3VLM(PreTrainedModel):
input_ids=input_ids,
mm_embeds=mm_embeds,
mm_token_ids=self._image_token_ids,
**kwargs,
)
return self.llm.forward(
@ -501,6 +502,10 @@ class Mistral3VLM(PreTrainedModel):
]
return torch.cat(pixel_values), batched_image_sizes
@property
def mm_token_ids(self):
return self._image_token_ids
# Original implementation:
# https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/mistral3/modeling_mistral3.py#L66

View File

@ -105,48 +105,84 @@ def find_uncached_mm_embeds(
return sliced_mm_embeds
def filter_mm_token_from_input_ids(
input_ids: torch.IntTensor,
vocab_size: int,
mm_token_ids: Optional[torch.IntTensor] = None,
) -> Tuple[torch.IntTensor, torch.IntTensor]:
"""
Filter multimodal tokens from input_ids.
Args:
input_ids: shape [text_total_length + mm_total_length].
vocab_size: size of the model's vocabulary
mm_token_ids: possible token ids for multimodal tokens, if known. If not known and set to None, it is assumed that the multimodal tokens are out-of-vocabulary tokens i.e. the `input_ids` contains tokens >= vocab_size that represent the multimodal tokens.
Note:
This function involves host-device synchronization due to torch.where() (= torch.nonzero) requiring
host allocation. The output indices reside on the same device as input_ids.
Returns:
text_token_indices: indices of text tokens in the input_ids
mm_token_indices: indices of multimodal tokens in the input_ids
"""
if mm_token_ids is None:
# NOTE:
# If mm_token_ids is None, it is assumed that the multimodal
# tokens are out-of-vocab tokens i.e. the `input_ids` contains
# tokens >= vocab_size that represent the multimodal tokens.
# Since mm_token_ids can be unbounded in this case,
# using torch.isin() may not be performant.
# This provides a more performant alternative while keeping
# the flexibility of still specifying all possible mm_token_ids,
# if the user wants to.
mm_token_mask = input_ids >= vocab_size
text_token_mask = input_ids < vocab_size
else:
mm_token_ids = mm_token_ids.to(input_ids.device, dtype=input_ids.dtype)
mm_token_mask = torch.isin(input_ids, mm_token_ids)
text_token_mask = ~mm_token_mask
# NOTE: torch.where() enforces a host sync
text_token_indices = torch.where(text_token_mask)[0]
mm_token_indices = torch.where(mm_token_mask)[0]
return text_token_indices, mm_token_indices
def fuse_input_embeds(
embedding_layer: Embedding,
input_ids: torch.IntTensor,
mm_embeds: List[torch.Tensor],
mm_token_ids: Optional[torch.IntTensor] = None,
text_token_indices: Optional[torch.IntTensor] = None,
mm_token_indices: Optional[torch.IntTensor] = None,
**kwargs,
) -> Tuple[Optional[torch.FloatTensor], Optional[torch.FloatTensor]]:
"""
Fuse text and multimodal embeddings. input_ids is [text_total_length + mm_total_length] and mm_embed is [mm_total_length, hidden_dim]. We just need to fuse them into [text_total_length + mm_total_length, hidden_dim] by slice-and-assign to the corresponding entries.
Args:
embedding_layer: embedding layer of the model.
input_ids: shape [text_total_length + mm_total_length], flattened from List[(text_length1 + mm_total_length1), ..., (text_lengthi + mm_total_lengthi)]. For LLM model, the requests are inflight batched together, but the input_ids are flattened with padding removed. By the slice condition < vocab_size, we can easily separate text / multimodal tokens and naturally batched the LLM embedding lookup
mm_embed: List[(mm_total_length1, hidden_dim), ..., (mm_total_lengthi, hidden_dim)].
mm_token_ids: possible token ids for multimodal tokens, if known. If not known and set to None, it is assumed that the multimodal tokens are out-of-vocabulary tokens i.e. the `input_ids` contains tokens >= vocab_size that represent the multimodal tokens.
mm_embeds: List[(mm_total_length1, hidden_dim), ..., (mm_total_lengthi, hidden_dim)].
mm_token_ids: possible token ids for multimodal tokens, if known. If not known and set to None, it is assumed that the multimodal tokens are out-of-vocabulary tokens.
Returns:
- If (1) JIT test run, (2) non-multimodal run, i.e. all text-only requests, either context or generation phase (3) multimodal run, all requests in generation phase --> there is no multimodal data, return only the input_ids
- If (4) multimodal run, mixed batch of context and generation requests, each context request has a multimodal feature --> return only the fused input_embeds of shape [total length, hidden_dim]. For text tokens, LLM embedding layer has already run.
Note:
- Precedence: If kwargs provide indices (text_token_indices and mm_token_indices), those are used. If either of them is missing, fall back to the filtering method. Sentinel-/OOV-based filtering (e.g., tokens >= vocab_size) is used only when neither the index tensors nor mm_token_ids are provided.
- This function may involve host-device synchronization if indices are not provided and filtering is performed. See filter_mm_token_from_input_ids for details.
"""
if len(mm_embeds) == 0:
return input_ids, None
mm_embed = torch.cat(mm_embeds, dim=0)
if mm_token_ids is None:
# NOTE:
# If mm_token_ids is None, it is assumed that the multimodal
# tokens are out-of-vocab tokens i.e. the `input_ids` contains
# tokens >= vocab_size that represent the multimodal tokens.
# Since mm_token_ids is be unbounded in this case,
# using torch.isin() may not be performant.
# This provides a more performant alternative while keeping
# the flexibility of still specifying all possible mm_token_ids,
# if the user wants to.
vocab_size = embedding_layer.num_embeddings
mm_token_mask = input_ids >= vocab_size
text_token_mask = input_ids < vocab_size
else:
mm_token_ids = mm_token_ids.to(input_ids.device)
mm_token_mask = torch.isin(input_ids, mm_token_ids)
text_token_mask = ~mm_token_mask
text_token_indices = torch.where(text_token_mask)[0]
mm_token_indices = torch.where(mm_token_mask)[0]
if len(mm_token_indices) != mm_embed.shape[0]:
# TODO: support the case where only one index tensor is provided, the other is derived as the complement (try to avoid implicit host-device synchronization)
if text_token_indices is None or mm_token_indices is None:
# NOTE: This function involves host-device synchronization due to torch.where() used in filter_mm_token_from_input_ids.
text_token_indices, mm_token_indices = filter_mm_token_from_input_ids(
input_ids,
vocab_size=embedding_layer.num_embeddings,
mm_token_ids=mm_token_ids)
if mm_token_indices.shape[0] != mm_embed.shape[0]:
raise ValueError(
f"Multimodal token count mismatch: found {len(mm_token_indices)} image tokens in input_ids "
f"but received {mm_embed.shape[0]} image embeddings. "
@ -159,8 +195,7 @@ def fuse_input_embeds(
device=text_embed.device,
dtype=text_embed.dtype)
input_embeds[text_token_indices, :] = text_embed.to(
dtype=input_embeds.dtype, device=input_embeds.device)
input_embeds[text_token_indices, :] = text_embed
input_embeds[mm_token_indices, :] = mm_embed.to(dtype=input_embeds.dtype,
device=input_embeds.device)
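The split-and-scatter pattern that `filter_mm_token_from_input_ids` and `fuse_input_embeds` implement can be illustrated in isolation. A self-contained sketch in plain torch (CPU tensors and made-up sizes; not the repo's API, which additionally handles dtype/device casts and the error paths above):
```python
import torch

vocab_size, hidden = 32000, 64
embedding = torch.nn.Embedding(vocab_size, hidden)

# Two image placeholder tokens, encoded as out-of-vocabulary ids (>= vocab_size).
input_ids = torch.tensor([1, 5, vocab_size, vocab_size, 7])
mm_embed = torch.randn(2, hidden)  # one precomputed embedding row per placeholder

with torch.inference_mode():
    # Single host sync (torch.where), as in filter_mm_token_from_input_ids.
    mm_mask = input_ids >= vocab_size
    text_idx = torch.where(~mm_mask)[0]
    mm_idx = torch.where(mm_mask)[0]

    # Fuse: embed the text tokens, then scatter the multimodal rows into place.
    fused = torch.empty(input_ids.numel(), hidden, dtype=mm_embed.dtype)
    fused[text_idx] = embedding(input_ids[text_idx])
    fused[mm_idx] = mm_embed

assert fused.shape == (input_ids.numel(), hidden)
```
In the real path, the indices can be computed once on the host (see `_prepare_multimodal_indices` in the model engine diff further below) and passed to `fuse_input_embeds` via `text_token_indices`/`mm_token_indices`, so the synchronization happens outside the per-model forward calls.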

View File

@ -594,6 +594,7 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
input_ids,
mm_embedding,
mm_token_ids=self.mm_token_ids,
**kwargs,
)
output_prob = self.llm.forward(

View File

@ -612,7 +612,8 @@ class Qwen2VLModelBase(PreTrainedModel):
'mrope_position_deltas']
input_ids, input_embeds = fuse_input_embeds(self.llm.model.embed_tokens,
input_ids, mm_embeds)
input_ids, mm_embeds,
**kwargs)
output_prob = self.llm.forward(
attn_metadata=attn_metadata,
input_ids=input_ids,

View File

@ -48,6 +48,9 @@ class Qwen3Attention(QKNormRoPEAttention):
rope=RopeParams.from_config(config),
)
# Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712)
# TODO: Consider adding disable_deep_gemm support to QKNormRoPEAttention if accuracy still remains
super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
@ -85,6 +88,7 @@ class Qwen3DecoderLayer(DecoderLayer):
dtype=config.torch_dtype,
config=model_config,
)
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)

View File

@ -1186,7 +1186,7 @@ class VilaModel(PreTrainedModel):
]
input_ids, inputs_embeds = fuse_input_embeds(
self.llm.model.embed_tokens, input_ids, mm_embeds)
self.llm.model.embed_tokens, input_ids, mm_embeds, **kwargs)
logits = self.llm.forward(attn_metadata=attn_metadata,
input_ids=input_ids,
position_ids=position_ids,

View File

@ -216,6 +216,7 @@ class Attention(nn.Module):
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
[self.hidden_size])

View File

@ -138,25 +138,23 @@ def pre_comm_embedding_ops(
padding_size: int,
):
# Generate the mask for the input if needed.
if tp_size > 1:
if tp_mode == TensorParallelMode.COLUMN:
input_, input_mask = get_masked_input_and_mask(
input_,
vocab_start_index,
vocab_end_index,
)
if tp_mode == TensorParallelMode.COLUMN:
input_, input_mask = get_masked_input_and_mask(
input_,
vocab_start_index,
vocab_end_index,
)
# Get the embeddings.
output = F.embedding(input_, weight)
# Mask or pad the output if needed.
if tp_size > 1:
if tp_mode == TensorParallelMode.COLUMN:
output.masked_fill_(input_mask, 0)
elif tp_mode == TensorParallelMode.ROW:
if gather_output:
if tp_rank == tp_size - 1 and padding_size > 0:
output = F.pad(output, (0, padding_size))
if tp_mode == TensorParallelMode.COLUMN:
output.masked_fill_(input_mask, 0)
elif tp_mode == TensorParallelMode.ROW:
if gather_output:
if tp_rank == tp_size - 1 and padding_size > 0:
output = F.pad(output, (0, padding_size))
return output
@ -205,12 +203,16 @@ class Embedding(LMHead):
self.vocab_end_index = num_embeddings
def forward(self, input):
# Run the ops before all_reduce/all_gather.
# We use torch.compile() to fuse the tiny pointwise ops before all_reduce/all_gather for Embedding module.
embedding_ops_func = torch.compile(
pre_comm_embedding_ops,
options={"max-autotune": True},
disable=not self.enable_torch_compile_for_embedding)
if self.tp_size > 1:
# Run the ops before all_reduce/all_gather.
# We use torch.compile() to fuse the tiny pointwise ops before all_reduce/all_gather for Embedding module.
embedding_ops_func = torch.compile(
pre_comm_embedding_ops,
options={"max-autotune": True},
disable=not self.enable_torch_compile_for_embedding)
else:
# Skip torch.compile when TP size is 1 to avoid unnecessary host overhead
embedding_ops_func = pre_comm_embedding_ops
output = embedding_ops_func(input, self.weight, self.tp_size,
self.tp_rank, self.tp_mode,
self.vocab_start_index,

View File

@ -3,6 +3,8 @@ from typing import Dict, List, Optional, Union
import torch
from torch import nn
from tensorrt_llm._utils import get_sm_version
from ...model_config import ModelConfig
from ...utils import Fp4QuantizedTensor, next_positive_power_of_2
from .interface import MoE, MoEWeightLoadingMode
@ -78,6 +80,11 @@ class TRTLLMGenFusedMoE(MoE):
swiglu_limit=swiglu_limit,
)
sm_version = get_sm_version()
if sm_version >= 120:
raise NotImplementedError(
"TRTLLMGenFusedMoE does not support SM120 and above.")
assert not self.smart_router, "Smart router is not supported in TRTLLMGenFusedMoE."
self.num_slots = self.num_experts

View File

@ -5,6 +5,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from ..distributed import AllReduceParams
@ -29,6 +30,7 @@ class GatedMLP(nn.Module):
reduce_output: bool = True,
layer_idx: Optional[int] = None,
use_cute_dsl_blockscaling_mm: bool = False):
super().__init__()
self.layer_idx = layer_idx
self.hidden_size = hidden_size
@ -98,12 +100,21 @@ class GatedMLP(nn.Module):
[LoraModuleType.MLP_GATE_UP],
[2 * self.intermediate_size // mapping.tp_size])
def _apply_activation(self, x):
def _apply_activation(self, x, *, has_lora: bool = False):
if self.activation == F.silu:
if self.down_proj.has_fp8_qdq or self.down_proj.has_w4a8_nvfp4_fp8:
return swiglu(x,
quant_scale=self.down_proj.input_scale,
quant_type=torch.float8_e4m3fn)
if has_lora:
# NOTE: This is a WAR, since LoRA grouped_gemm does not support FP8 yet.
# TODO: Remove this path when LoRA grouped_gemm supports FP8
# see: cpp/tensorrt_llm/thop/loraOp.cpp::lora_grouped_gemm
logger.warning(
f"GatedMLP._apply_activation: LoRA path active; forcing non-FP8 activation dtype bf16/fp16, layer_idx={self.layer_idx}"
)
return swiglu(x)
else:
return swiglu(x,
quant_scale=self.down_proj.input_scale,
quant_type=torch.float8_e4m3fn)
else:
return swiglu(x)
elif callable(self.activation):
@ -155,7 +166,7 @@ class GatedMLP(nn.Module):
if h1_lora is not None:
h1 = h1 + h1_lora
h2 = self._apply_activation(h1)
h2 = self._apply_activation(h1, has_lora=True)
output = self.down_proj(h2,
all_reduce_params=final_all_reduce_params,
lora_params=lora_params,

View File

@ -5,7 +5,6 @@ from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import \
BaseCheckpointLoader
from tensorrt_llm.bindings.executor import ExecutorConfig
from ...builder import BuildConfig
from ...llmapi.llm_args import LoadFormat, SamplerType
from ...logger import logger
from ...mapping import Mapping
@ -119,7 +118,6 @@ EXETENDED_EXECUTOR_CONFIG_FIELDS = [
'backend',
'pytorch_backend_config',
'max_seq_len',
'tokens_per_block',
'mapping',
'hf_model_dir',
'mm_encoder_only',
@ -131,7 +129,6 @@ def update_executor_config(
backend: Optional[str] = None,
pytorch_backend_config: Optional[PyTorchConfig] = None,
mapping: Optional[Mapping] = None,
build_config: Optional[BuildConfig] = None,
speculative_config: Optional["DecodingBaseConfig"] = None,
hf_model_dir: Optional[str] = None,
max_input_len: Optional[int] = None,
@ -156,10 +153,6 @@ def update_executor_config(
logger.info(f"{executor_config.pytorch_backend_config}")
build_config = build_config or BuildConfig()
# TODO: move to pure-Python KvCacheConfig, and remove dependency on build_config.
executor_config.tokens_per_block = executor_config.tokens_per_block or build_config.plugin_config.tokens_per_block
executor_config.hf_model_dir = hf_model_dir
if max_input_len is not None:

View File

@ -1,6 +1,8 @@
from collections.abc import Generator
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from itertools import pairwise
from typing import Any, Dict, List, Optional, TypeAlias, Union
import torch
@ -424,7 +426,10 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
self.child_requests.append(py_request)
def convert_wordlist(word_list) -> List[List[int]]:
StopWordList: TypeAlias = list[list[int]]
def convert_wordlist(word_list) -> StopWordList:
"""Converts a wordlist from format:
[[word_0 token_0, word_0 token_1, ...], [word_1 token_0, ...], ...]]
@ -461,6 +466,16 @@ def convert_wordlist(word_list) -> List[List[int]]:
return [tokens, offsets]
def produce_stop_words(
py_stop_words_list: StopWordList) -> Generator[list[int], None, None]:
"""yield stop sequences from the output of `convert_wordlist` above."""
stop_words_list, prefix_sum = py_stop_words_list
for start, end in pairwise((0, *prefix_sum)): # first element: prepend 0
if end == -1: # -1 is a sentinel value in convert_wordlist
break
yield stop_words_list[start:end]
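For instance, the stop words `[[10, 11], [12]]` encode as a flat token list plus end offsets padded with the `-1` sentinel. A small self-contained sketch of the decoding that `produce_stop_words` performs (example values are made up):
```python
from itertools import pairwise  # Python 3.10+

tokens = [10, 11, 12]  # flattened stop words [[10, 11], [12]]
ends = [2, 3, -1]      # exclusive end offsets, padded with the -1 sentinel

def decode(stop_words_list, prefix_sum):
    for start, end in pairwise((0, *prefix_sum)):  # prepend 0 as the first start
        if end == -1:  # sentinel: no more stop words
            break
        yield stop_words_list[start:end]

assert list(decode(tokens, ends)) == [[10, 11], [12]]
```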
def executor_request_to_llm_request(
req_id: int,
executor_request: ExecutorRequest,

View File

@ -43,6 +43,7 @@ from ..metadata import KVCacheParams
from ..model_config import ModelConfig, MoeLoadBalancerConfig
from ..models import AutoModelForCausalLM
from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader
from ..models.modeling_multimodal_utils import filter_mm_token_from_input_ids
from ..models.modeling_utils import (DecoderModelForCausalLM, MetaInitMode,
timing)
from ..modules.fused_moe.moe_load_balancer import (
@ -1259,6 +1260,16 @@ class PyTorchModelEngine(ModelEngine):
return total_num_tokens, False, attn_all_rank_num_tokens
def _prepare_multimodal_indices(self, input_ids: list[int]):
input_ids = torch.tensor(input_ids, dtype=torch.int, device="cpu")
vocab_size = self.model.config.vocab_size
# TODO: unify naming of mm_token_ids across models
mm_token_ids = getattr(self.model, "mm_token_ids", None)
text_token_indices, mm_token_indices = filter_mm_token_from_input_ids(
input_ids, vocab_size=vocab_size, mm_token_ids=mm_token_ids)
return text_token_indices, mm_token_indices
def _prepare_tp_inputs(
self,
scheduled_requests: ScheduledRequests,
@ -1335,6 +1346,14 @@ class PyTorchModelEngine(ModelEngine):
request.py_batch_idx = request.py_seq_slot
if len(multimodal_params_list) > 0:
# discard the text token indices as they only include context tokens at this moment
_, mm_token_indices = self._prepare_multimodal_indices(input_ids)
else:
mm_token_indices = None
num_ctx_requests = len(scheduled_requests.context_requests)
num_ctx_tokens = len(input_ids)
@ -1698,6 +1717,14 @@ class PyTorchModelEngine(ModelEngine):
spec_metadata.all_rank_num_tokens = spec_all_rank_num_tokens
spec_metadata.all_rank_num_seqs = all_rank_num_seqs
if mm_token_indices is not None:
mask = torch.ones(total_num_tokens, dtype=torch.bool)
mask[mm_token_indices] = False
inputs['mm_token_indices'] = mm_token_indices.pin_memory().to(
"cuda", non_blocking=True)
inputs['text_token_indices'] = torch.where(mask)[0].pin_memory().to(
"cuda", non_blocking=True)
num_generation_tokens = len(generation_requests) + len(
extend_requests) + sum(draft_lens)
self.iter_states['num_ctx_requests'] = num_ctx_requests

View File

@ -61,6 +61,9 @@ PROFILE_RECORD_GC_ENV_VAR_NAME = "TLLM_PROFILE_RECORD_GC"
# Set to a path to save detailed tracing of PyTorch operations.
PROFILE_TRACE_ENV_VAR_NAME = "TLLM_TORCH_PROFILE_TRACE"
# Unique tag base to avoid collisions with token/logits comms
TERMINATION_COMM_TAG_BASE = 20000
@functools.cache
def _load_iteration_indexes(env_var: str):
@ -208,6 +211,7 @@ class PyExecutor:
self.kv_cache_manager = self.resource_manager.resource_managers.get(
ResourceManagerType.KV_CACHE_MANAGER)
self.enable_kv_cache_events = self.kv_cache_manager is not None and self.kv_cache_manager.event_buffer_max_size > 0
self.enable_kv_cache_reuse = self.kv_cache_manager is not None and self.kv_cache_manager.enable_block_reuse
self.max_input_len = max_input_len
# _executor_loop private data
@ -259,6 +263,13 @@ class PyExecutor:
self.gather_all_responses = False
self.kv_cache_transceiver = kv_cache_transceiver
# Initialize disagg PP termination handler if needed
self._disagg_pp_termination_handler = None
if self.dist.pp_size > 1 and self.enable_kv_cache_reuse and self.kv_cache_transceiver:
self._disagg_pp_termination_handler = DisaggPPTerminationHandler(
self.num_micro_batches, self.dist)
if self.dist.pp_size > 1:
self.event_loop = self._executor_loop_pp
else:
@ -718,6 +729,14 @@ class PyExecutor:
batch_state.sample_state.scheduled_requests), req_stats)
def _executor_loop_cleanup(self):
for h in self.send_handles:
if h is not None:
h.wait()
if self._disagg_pp_termination_handler is not None:
self._disagg_pp_termination_handler.cleanup()
with self.response_cv:
self.is_shutdown = True
self.response_cv.notify_all()
@ -826,6 +845,7 @@ class PyExecutor:
sample_state = self._sample_async(
scheduled_batch, batch_outputs)
assert sample_state is not None, "Sampling failed"
self._update_request_states(scheduled_batch)
if self.enable_iter_perf_stats:
@ -905,6 +925,12 @@ class PyExecutor:
if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
self._terminate_ctx_finished_requests()
if self._disagg_pp_termination_handler is not None:
requests_to_terminate = self._disagg_pp_termination_handler.sync(
prev_microbatch_id)
for req in requests_to_terminate:
self._do_terminate_request(req)
# march forward in microbatch slots
microbatch_id = (microbatch_id + 1) % self.num_micro_batches
@ -1696,9 +1722,13 @@ class PyExecutor:
self._enqueue_responses(error_responses.items())
def _terminate_request(self, request: LlmRequest):
if self.kv_connector_manager is None:
self.resource_manager.free_resources(request)
if self._disagg_pp_termination_handler is not None:
self._disagg_pp_termination_handler.terminate(request)
else:
self._do_terminate_request(request)
def _do_terminate_request(self, request: LlmRequest):
if self.kv_connector_manager is not None:
# Only call request_finished on the connector if the request has already been added to the kv cache manager.
try:
cache_block_ids = self.kv_cache_manager.get_cache_indices(
@ -1711,6 +1741,8 @@ class PyExecutor:
if not self.kv_connector_manager.request_finished(
request, cache_block_ids):
self.resource_manager.free_resources(request)
else:
self.resource_manager.free_resources(request)
@nvtx_range("_handle_canceled_requests")
def _handle_canceled_requests(self):
@ -1919,3 +1951,104 @@ class PyExecutor:
"""Remove reqids of current requests from self.inflight_req_ids."""
for req in scheduled_requests.all_requests():
self.inflight_req_ids.erase(req.request_id)
class DisaggPPTerminationHandler:
"""Handles termination synchronization across pipeline parallel ranks under disaggregated serving.
We require synchronization when terminating requests in disaggregated PP when
KV cache reuse is enabled. All PP ranks need to reach consensus before freeing
resources to avoid a NCCL hang.
"""
def __init__(self, num_micro_batches: int, dist):
self.dist = dist
# Request termination synchronization across PP ranks
# {request_id: {'ready_to_terminate': set{ranks}, 'terminated': {ranks}}}
self.pending_termination = {}
self.termination_handles = [None] * num_micro_batches
# Local map from request_id -> local LlmRequest awaiting consensus termination
self.local_termination = {}
def terminate(self, request: LlmRequest) -> bool:
req_key = request.py_request_id
self.local_termination[req_key] = request
state = self.pending_termination.get(req_key, None)
if state is None:
state = {'ready_to_terminate': set(), 'terminated': set()}
self.pending_termination[req_key] = state
if self.dist.rank not in state['ready_to_terminate']:
state['ready_to_terminate'].add(self.dist.rank)
return False
def sync(self, microbatch_id: int) -> List[LlmRequest]:
"""Ring-communicate pending termination state and apply local terminations upon consensus.
Each rank sends its current pending_termination snapshot to the next PP rank
and receives the previous rank's snapshot. After merging, apply any terminations
that have reached consensus (i.e., all PP ranks are ready).
"""
snapshot = {
req_id: {
'ready_to_terminate': state.get('ready_to_terminate', set()),
'terminated': state.get('terminated', set()),
}
for req_id, state in self.pending_termination.items()
}
if self.termination_handles[microbatch_id] is not None:
self.termination_handles[microbatch_id].wait()
term_tag = TERMINATION_COMM_TAG_BASE + microbatch_id
self.termination_handles[microbatch_id] = self.dist.isend_object(
snapshot,
dest=self.dist.next_pp_rank,
tag=term_tag,
)
remote_state = self.dist.recv_object(
src=self.dist.prev_pp_rank,
tag=term_tag,
)
logger.debug(
f"received remote state for microbatch {microbatch_id}, prev pp rank: {self.dist.prev_pp_rank} state {remote_state}"
)
if remote_state:
for req_id, state in remote_state.items():
local = self.pending_termination.get(req_id)
if local is None:
self.pending_termination[req_id] = {
'ready_to_terminate': state.get('ready_to_terminate',
set()),
'terminated': state.get('terminated', set()),
}
else:
for key in ('ready_to_terminate', 'terminated'):
for r in state.get(key, []):
if r not in local[key]:
local[key].add(r)
requests_to_terminate = []
to_delete = []
for req_id, state in self.pending_termination.items():
ready = state.get('ready_to_terminate', set())
done = state.get('terminated', set())
# If all PP ranks are ready to terminate the request, we can free the resources
if len(ready) >= self.dist.pp_size and self.dist.rank not in done:
local_req = self.local_termination.get(req_id)
if local_req is not None:
requests_to_terminate.append(local_req)
done.add(self.dist.rank)
if len(done) >= self.dist.pp_size:
to_delete.append(req_id)
if req_id in self.local_termination:
self.local_termination.pop(req_id, None)
for req_id in to_delete:
self.pending_termination.pop(req_id, None)
return requests_to_terminate
def cleanup(self):
for h in self.termination_handles:
if h is not None:
h.wait()

View File

@ -1,12 +1,16 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass
from typing import List, Literal, Optional
from functools import cached_property
from typing import List, Literal, Optional, TypeAlias
import numpy as np
import torch
from tensorrt_llm._torch.pyexecutor.make_decoding_batch_input_output import \
MakeDecodingBatchInputOutput
from tensorrt_llm._torch.pyexecutor.sampler_utils import (
BEAM_0, SINGLE_BEAM_WIDTH, handle_stop_single_beam)
from tensorrt_llm._utils import nvtx_range, torch_dtype_to_binding
from tensorrt_llm.bindings import (CudaStream, DataType, ModelConfig,
WorldConfig, make_sampling_config)
@ -355,21 +359,55 @@ def int_tensor(shape: tuple[int, ...], device: str = 'cuda') -> torch.Tensor:
return torch.empty(shape, dtype=torch.int, device=device)
class TorchStore:
def __init__(self, *, max_draft_len: int, max_num_sequences: int,
max_beam_width: int):
self.max_draft_len = max_draft_len
self.max_num_sequences = max_num_sequences
self.max_beam_width = max_beam_width
self.max_tokens = max_draft_len + 1
assert max_beam_width == SINGLE_BEAM_WIDTH, "TorchSampler only supports beam_width = 1"
self.new_tokens = int_tensor(
(self.max_tokens, max_num_sequences, max_beam_width))
"""Shape: See cpp DecoderState.getAllNewTokens()"""
self.finish_reasons = int_tensor(self.new_tokens.shape)
# Helper tensors for finish_reasons:
self._finish_reasons_nonzero_static_buffer = torch.empty(
(self.max_tokens * max_num_sequences, 2),
device='cuda',
dtype=torch.int64)
"""Preallocate buffer needed for torch.nonzero_static(..., out=finish_reasons_nonzero_static_buffer), see `def _write_reason`"""
self._reason_tensors = {
reason:
torch.tensor(reason.value,
dtype=self.finish_reasons.dtype,
device="cuda")
for reason in [
FinishReason.NOT_FINISHED, FinishReason.END_ID,
FinishReason.STOP_WORDS, FinishReason.LENGTH,
FinishReason.TIMED_OUT, FinishReason.CANCELLED
] # `in FinishReason` clashes with PyBind11: `TypeError: 'pybind11_type' object is not iterable`
}
@dataclass(kw_only=True)
class SampleStateTensorsHostTorch(SampleStateTensors):
finish_reasons: torch.Tensor
@dataclass(kw_only=True)
class SampleStateTorch(SampleState):
host: SampleStateTensorsHostTorch
class TorchSampler(Sampler):
BEAM = 0
MAX_BEAM_WIDTH = BEAM + 1
SampleState = SampleStateTorch
def is_generation_model(self) -> bool:
return True
@dataclass(frozen=True, kw_only=True)
class Store:
new_tokens: torch.Tensor
"""Shape: See cpp DecoderState.getAllNewTokens()"""
def create_store(self) -> Store:
return self.Store(new_tokens=int_tensor(self.NEW_TOKENS_SHAPE))
@dataclass(frozen=True, kw_only=True)
class Args:
max_seq_len: int
@ -381,17 +419,16 @@ class TorchSampler(Sampler):
def __init__(self, args: Args):
self.max_seq_len = args.max_seq_len
self.enable_mixed_sampler = args.enable_mixed_sampler
self.max_tokens = args.max_draft_len + 1
assert args.max_beam_width == self.MAX_BEAM_WIDTH, "TorchSampler only supports beam_width = 1"
self.max_num_sequences = args.max_num_sequences
self.NEW_TOKENS_SHAPE = (self.max_tokens, self.max_num_sequences,
self.MAX_BEAM_WIDTH)
# AutoDeploy build creates the sampler in inference mode,
# which would disallow in-place mutating of new_tokens.
# So, we temporarily exit inference mode.
with torch.inference_mode(False):
self.store = self.create_store()
self.store = TorchStore(max_draft_len=args.max_draft_len,
max_num_sequences=args.max_num_sequences,
max_beam_width=args.max_beam_width)
self.max_num_sequences = args.max_num_sequences
self.max_tokens = self.store.max_tokens
# Initialize seed for multi-GPU consistency
self._global_seed = 42
@ -412,50 +449,7 @@ class TorchSampler(Sampler):
self._generator.manual_seed(self._global_seed)
return self._generator
def _meet_max_token_stop_criteria(self, request: LlmRequest):
num_tokens = request.get_num_tokens(self.BEAM)
return (num_tokens - request.py_orig_prompt_len
>= request.py_max_new_tokens) or (num_tokens
>= self.max_seq_len)
@staticmethod
def _meet_stop_token_criteria(request: LlmRequest):
if request.py_stop_words_list:
assert isinstance(
request.py_stop_words_list,
list), "request.py_stop_words_list should be a list"
stop_words_list, prefix_sum = request.py_stop_words_list
tokens = request.get_tokens(0)
offset = 0
for i, offset_end in enumerate(prefix_sum):
if i > 0:
offset = prefix_sum[i - 1]
stop_word = stop_words_list[offset:offset_end]
if len(stop_word) > len(tokens):
continue
if tokens[-len(stop_word):] == stop_word:
return True
return False
def _handle_stop_criteria(self, request: LlmRequest,
new_token: int) -> bool:
"""Handle stop criteria and set appropriate finish reasons and state.
Returns True if generation should stop."""
if new_token == request.py_end_id:
request.finish_by(FinishReason.END_ID, self.BEAM)
return True
if self._meet_max_token_stop_criteria(request):
request.finish_by(FinishReason.LENGTH, self.BEAM)
return True
if self._meet_stop_token_criteria(request):
request.finish_by(FinishReason.STOP_WORDS, self.BEAM)
return True
return False
def handle_logprobs(self, request: LlmRequest, state: SampleState, *,
def handle_logprobs(self, request: LlmRequest, state: SampleStateTorch, *,
beam: int, count: int):
current_slice = slice(0, count), request.py_seq_slot, beam
if request.py_return_log_probs:
@ -469,10 +463,26 @@ class TorchSampler(Sampler):
assert beam == 0, "The following call relies on beam_width to be 1 - hence the list with a single element"
request.py_result.append_log_probs([token_log_probs])
def _process_draft_tokens_greedy(self, request: LlmRequest,
new_tokens: torch.Tensor) -> int:
new_token = add_token(request, new_tokens, beam=self.BEAM)
stop = self._handle_stop_criteria(request, new_token)
FinishReasons: TypeAlias = list[list[int]]
"""`(num_seq_slots, num_steps)`"""
@classmethod
def finish_if_reason(cls, request: LlmRequest,
finish_reasons: FinishReasons, *, step: int) -> bool:
reason = FinishReason(finish_reasons[request.py_seq_slot][step])
valid_reasons = {
FinishReason.END_ID, FinishReason.LENGTH, FinishReason.STOP_WORDS
}
if reason in valid_reasons:
request.finish_by(reason, BEAM_0)
return True
return False
def _process_draft_tokens_greedy(self, request: LlmRequest, *,
new_tokens: torch.Tensor,
finish_reasons: FinishReasons) -> int:
new_token = add_token(request, new_tokens, beam=BEAM_0)
stop = self.finish_if_reason(request, finish_reasons, step=0)
if stop or get_draft_token_length(request) == 0:
return 0
num_accepted = 0
@ -485,14 +495,17 @@ class TorchSampler(Sampler):
num_accepted += 1
new_token = add_token(request,
new_tokens,
beam=self.BEAM,
beam=BEAM_0,
step=num_accepted)
if self._handle_stop_criteria(request, new_token):
if self.finish_if_reason(request, finish_reasons,
step=num_accepted):
break
return num_accepted
def _process_draft_tokens_rejection_sampling(
self, request: LlmRequest, new_tokens: torch.Tensor) -> int:
"""We cannot use finish_if_reason in _process_draft_tokens_rejection_sampling because it *writes to new_tokens*,
rendering the finish reason calculation in sample_async stale (incorrect) for this batch"""
sampling_strategy = request_strategy(request)
generator = self.get_generator(request.py_draft_logits.device)
_, draft_probs = sample(sampling_strategy,
@ -503,7 +516,6 @@ class TorchSampler(Sampler):
generator,
request.py_draft_tokens)
sample_last = True
stop = False
if rejected_indices.numel() == 0:
num_initially_accepted = get_draft_token_length(request)
sample_last = False
@ -512,59 +524,62 @@ class TorchSampler(Sampler):
num_accepted = num_initially_accepted
for i in range(num_accepted):
new_token = request.py_draft_tokens[i]
new_tokens[i, request.seq_slot, self.BEAM] = new_token
request.add_new_token(new_token, self.BEAM)
stop = self._handle_stop_criteria(request, new_token)
if stop:
new_tokens[i, request.seq_slot, BEAM_0] = new_token
request.add_new_token(new_token, BEAM_0)
if handle_stop_single_beam(request,
new_token,
max_seq_len=self.max_seq_len):
num_accepted = i + 1
return num_accepted
if sample_last:
new_token = sample_rejected(draft_probs, target_probs, generator,
num_accepted)
new_tokens[num_accepted, request.seq_slot, self.BEAM] = new_token
request.add_new_token(new_token, self.BEAM)
stop = self._handle_stop_criteria(request, new_token)
new_tokens[num_accepted, request.seq_slot, BEAM_0] = new_token
request.add_new_token(new_token, BEAM_0)
handle_stop_single_beam(request,
new_token,
max_seq_len=self.max_seq_len)
else:
new_token = add_token(request,
new_tokens,
beam=self.BEAM,
beam=BEAM_0,
step=num_accepted)
stop = self._handle_stop_criteria(request, new_token)
handle_stop_single_beam(request,
new_token,
max_seq_len=self.max_seq_len)
return num_accepted
def process_draft_tokens(self, request: LlmRequest,
new_tokens: torch.Tensor) -> int:
if request.py_draft_logits is None:
return self._process_draft_tokens_greedy(request, new_tokens)
else:
return self._process_draft_tokens_rejection_sampling(
request, new_tokens)
def update_requests(self, state: SampleState) -> None:
assert isinstance(state, SampleState)
def update_requests(self, state: SampleStateTorch) -> None:
assert isinstance(state, SampleStateTorch)
if state.sampler_event:
state.sampler_event.synchronize()
new_tokens = state.host.new_tokens
finish_reasons = state.host.finish_reasons[:, :, BEAM_0].T.tolist()
for req in state.scheduled_requests.context_requests:
if req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0:
continue
new_token = add_token(req, new_tokens, beam=self.BEAM)
self._handle_stop_criteria(req, new_token)
self.handle_logprobs(req, state, beam=self.BEAM, count=1)
add_token(req, new_tokens, beam=BEAM_0)
self.finish_if_reason(req, finish_reasons, step=0)
self.handle_logprobs(req, state, beam=BEAM_0, count=1)
req.py_decoding_iter += 1
for req in state.scheduled_requests.generation_requests:
if req.state == LlmRequestState.GENERATION_COMPLETE:
continue
processed = 1
num_accepted = self.process_draft_tokens(req, new_tokens)
if req.py_draft_logits is None:
num_accepted = self._process_draft_tokens_greedy(
req, new_tokens=new_tokens, finish_reasons=finish_reasons)
else:
num_accepted = self._process_draft_tokens_rejection_sampling(
req, new_tokens)
if get_draft_token_length(req) > 0:
req.py_num_accepted_draft_tokens = num_accepted
req.py_rewind_len = req.py_draft_pages_allocated - num_accepted
processed += num_accepted
self.handle_logprobs(req, state, beam=self.BEAM, count=processed)
self.handle_logprobs(req, state, beam=BEAM_0, count=processed)
req.py_decoding_iter += 1
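A minimal shape sketch for the transpose above, assuming the host tensor is laid out (num_steps, num_seq_slots, beam) like new_tokens: selecting beam 0 and transposing yields the (num_seq_slots, num_steps) layout that finish_if_reason indexes as finish_reasons[seq_slot][step].

import torch

num_steps, num_seq_slots = 2, 3
finish_reasons_host = torch.zeros(num_steps, num_seq_slots, 1, dtype=torch.int32)
per_slot = finish_reasons_host[:, :, 0].T.tolist()
assert len(per_slot) == num_seq_slots and len(per_slot[0]) == num_steps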
def log_probs_host(self, scheduled_requests: ScheduledRequests):
@ -572,29 +587,50 @@ class TorchSampler(Sampler):
if any(req.py_return_log_probs
for req in scheduled_requests.all_requests()):
return torch.empty(
(self.max_num_sequences, self.MAX_BEAM_WIDTH, self.max_tokens),
(self.max_num_sequences, SINGLE_BEAM_WIDTH, self.max_tokens),
device="cpu",
pin_memory=True)
return None
def sample_async(self, scheduled_requests: ScheduledRequests,
model_outputs: dict[str, torch.Tensor],
num_context_logits_prefix_sum: list[int]) -> SampleState:
def sample_async(
self, scheduled_requests: ScheduledRequests,
model_outputs: dict[str, torch.Tensor],
num_context_logits_prefix_sum: list[int]) -> SampleStateTorch:
requests = scheduled_requests.all_requests()
new_tokens = self.store.new_tokens
finish_reasons = self.store.finish_reasons
log_probs_host = self.log_probs_host(scheduled_requests)
seq_slots_host = torch.tensor(
[r.py_seq_slot for r in requests],
dtype=torch.int64, # for index_fill_
pin_memory=True)
seq_slots = seq_slots_host.to(device="cuda", non_blocking=True)
self._process_requests(scheduled_requests,
model_outputs,
new_tokens,
num_context_logits_prefix_sum,
seq_slots=seq_slots,
seq_slots_host=seq_slots_host,
log_probs_host=log_probs_host)
self._write_finish_reasons(requests,
finish_reasons=finish_reasons,
seq_slots=seq_slots,
new_tokens=new_tokens)
new_tokens_host = new_tokens.to(device="cpu", non_blocking=True)
finish_reasons_host = finish_reasons.to(device="cpu", non_blocking=True)
sampler_event = torch.cuda.Event()
sampler_event.record()
return SampleState(scheduled_requests=scheduled_requests,
device=SampleStateTensors(new_tokens=new_tokens),
host=SampleStateTensors(new_tokens=new_tokens_host,
log_probs=log_probs_host),
sampler_event=sampler_event)
return SampleStateTorch(
scheduled_requests=scheduled_requests,
device=SampleStateTensors(new_tokens=new_tokens),
host=SampleStateTensorsHostTorch(
new_tokens=new_tokens_host,
log_probs=log_probs_host,
finish_reasons=finish_reasons_host,
),
sampler_event=sampler_event,
)
@staticmethod
def append_eagle3(tokens: torch.Tensor, model_outputs):
@ -654,15 +690,202 @@ class TorchSampler(Sampler):
return logits
@staticmethod
def _longest_stop_word_len(requests: Iterable[LlmRequest]) -> int:
max_stop_word_len = 0
for req in requests:
_, cumsum = req.py_stop_words_list
if -1 in cumsum:
cumsum = cumsum[:cumsum.index(-1)]
request_max_stop_word_len = np.max(np.diff(cumsum, prepend=0),
initial=0)
max_stop_word_len = max(max_stop_word_len,
request_max_stop_word_len)
return max_stop_word_len
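A hedged sketch of the (tokens, cumsum) encoding this helper appears to consume: a flat token list plus cumulative end offsets padded with -1, so np.diff(cumsum, prepend=0) recovers the per-word lengths (the exact producer is not shown in this hunk).

import numpy as np

# Hypothetical stop words [[7, 8], [9], [10, 11, 12]] in flattened form.
tokens = [7, 8, 9, 10, 11, 12]
cumsum = [2, 3, 6, -1, -1]  # trailing -1 entries are padding
cumsum = cumsum[:cumsum.index(-1)]
lengths = np.diff(cumsum, prepend=0)       # array([2, 1, 3])
longest = int(np.max(lengths, initial=0))  # 3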
@staticmethod
def _requests_with_stop_words(
requests: list[LlmRequest]) -> list[LlmRequest]:
return [
r for r in requests if (r.py_stop_words_list is not None
and len(r.py_stop_words_list[0]) > 0)
]
def _write_reason(self, finish_reasons: torch.Tensor, reason: FinishReason,
*, where: torch.Tensor, seq_slots: torch.Tensor) -> None:
"""Avoid GPU<->CPU syncs via:
### `nonzero_static` [REF-A], see: https://ianbarber.blog/2024/12/18/nonzero_static-in-pytorch/.
- `nonzero` syncs (frontend needs result size).
- `nonzero_static` pads with dummy entries (`fill_value`), written into a preallocated buffer of shape (max_num_sequences, 2).
- Need to drop padding, but `buffer[buffer!=fill_value]`, `buffer[:count_nonzero]`, `buffer[:sum]` all sync.
### Hack:
1. Use `fill_value=0`, so padding is `[..., [0,0], [0,0]]`.
2. Write blindly to `finish_reasons` [REF-B]. Only `[seq_slot[0],0]` might have wrong values written to it, because of the padding entries.
3. Save `[seq_slot[0],0]` in `before_write` [REF-C], restore if `where[0][0]` is `False` [REF-D].
"""
assert seq_slots.is_cuda and where.is_cuda
assert seq_slots.shape[0] == where.shape[1]
first_slot = seq_slots[0].unsqueeze(0)
before_write = finish_reasons[0][:].index_select(
0, first_slot).squeeze() # REF-C
reason_tensor = self.store._reason_tensors[reason]
buffer = self.store._finish_reasons_nonzero_static_buffer
size = buffer.shape[0]
torch.nonzero_static(where, size=size, fill_value=0,
out=buffer) # REF-A
r, c = buffer[:, 0], buffer[:, 1]
finish_reasons[r, seq_slots[c], BEAM_0] = reason_tensor # REF-B
correct = torch.where(~where[0, 0], before_write, reason_tensor).view(1)
assert correct.is_cuda
finish_reasons[0, first_slot, BEAM_0] = correct # REF-D
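A minimal sketch of the padding behavior the hack above compensates for (CPU tensors, illustrative values): with fill_value=0 every padding row aliases index [0, 0], which is why only that one slot has to be saved and restored.

import torch

where = torch.tensor([[False, True],
                      [False, False]])  # (num_steps, num_requests)
buffer = torch.empty((4, 2), dtype=torch.int64)
torch.nonzero_static(where, size=4, fill_value=0, out=buffer)
# buffer == [[0, 1], [0, 0], [0, 0], [0, 0]]; only the first row is a real
# hit, the padding rows all point at [0, 0], so a blind scatter also touches
# that position and must be undone unless where[0, 0] was actually True.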
def _write_finish_reasons(self, requests: list[LlmRequest], *,
finish_reasons: torch.Tensor,
seq_slots: torch.Tensor,
new_tokens: torch.Tensor) -> None:
"""Later _write_reason calls overwrite earlier ones, so reasons are written in reverse precedence order."""
tokens = new_tokens[:, seq_slots, BEAM_0]
# we need to fill with NOT_FINISHED so we can differentiate between previous requests that had the same seq slot
finish_reasons.index_fill_(1, seq_slots,
FinishReason.NOT_FINISHED.value)
if with_stop_words := self._requests_with_stop_words(requests):
stop_seq_slots = torch.tensor(
[r.py_seq_slot for r in with_stop_words],
pin_memory=True).to("cuda", non_blocking=True)
stop_tokens = new_tokens[:, stop_seq_slots, BEAM_0]
self._write_reason(
finish_reasons,
FinishReason.STOP_WORDS,
where=self._are_stop_words(with_stop_words, stop_tokens),
seq_slots=stop_seq_slots,
)
self._write_reason(
finish_reasons,
FinishReason.LENGTH,
where=self._are_max_length(requests),
seq_slots=seq_slots,
)
self._write_reason(
finish_reasons,
FinishReason.END_ID,
where=self._are_end_id(requests, tokens),
seq_slots=seq_slots,
)
def _are_end_id(self, requests: list[LlmRequest],
tokens: torch.Tensor) -> torch.Tensor:
end_ids_tensor = torch.tensor(
[([req.py_end_id if req.py_end_id is not None else -1] *
self.max_tokens) for req in requests],
pin_memory=True,
dtype=tokens.dtype).T.to(device="cuda", non_blocking=True)
return tokens == end_ids_tensor
def _are_max_length(self, requests: list[LlmRequest]) -> torch.Tensor:
lengths_tensor = torch.tensor([[
((req.get_num_tokens(BEAM_0) + num_tokens) - req.py_orig_prompt_len)
for num_tokens in range(1, self.max_tokens + 1)
] for req in requests])
max_lengths_tensor = torch.tensor([
([min(req.py_max_new_tokens, self.max_seq_len)] * self.max_tokens)
for req in requests
])
return (lengths_tensor
>= max_lengths_tensor).T.pin_memory().to(device="cuda",
non_blocking=True)
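A worked example of the comparison above with hypothetical numbers: a request that already holds 12 tokens for a 10-token prompt and has max_new_tokens=4 crosses the limit at the second lookahead step when max_tokens (decode steps per iteration) is 3.

num_tokens, orig_prompt_len = 12, 10
max_new_tokens, max_seq_len, max_steps = 4, 2048, 3

lengths = [(num_tokens + step) - orig_prompt_len
           for step in range(1, max_steps + 1)]   # [3, 4, 5]
limit = min(max_new_tokens, max_seq_len)          # 4
hits = [length >= limit for length in lengths]    # [False, True, True]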
_PAD_ID = -1
"""Pad with a negative value; the exact value does not matter."""
@cached_property
def _pad_steps_mask(self):
square = torch.ones(self.max_tokens, self.max_tokens, dtype=torch.bool)
pad_id = torch.tensor(self._PAD_ID)
mask = torch.where(square.tril(), torch.tensor(1), pad_id)
mask.pin_memory()
return mask.to("cuda", non_blocking=True)
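What the cached mask evaluates to for an assumed max_tokens of 3, and how it is used one method below: row `step` keeps the first step + 1 new tokens and flips the rest negative so the stop-word check can drop them with seq[seq >= 0].

import torch

max_tokens, pad_id = 3, -1
square = torch.ones(max_tokens, max_tokens, dtype=torch.bool)
mask = torch.where(square.tril(), torch.tensor(1), torch.tensor(pad_id))
# mask == [[1, -1, -1],
#          [1,  1, -1],
#          [1,  1,  1]]
new_tokens = torch.tensor([[5, 6, 7]])   # one request, three decode steps
padded = new_tokens.unsqueeze(1) * mask  # shape (1, 3, 3)
# padded[0, 0] == [5, -6, -7], padded[0, 1] == [5, 6, -7], padded[0, 2] == [5, 6, 7]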
def _padded_old_tokens(self,
requests: list[LlmRequest],
new_tokens: torch.Tensor,
pad_id: int = _PAD_ID) -> torch.Tensor:
# TODO: make sure only the lookback tokens are pulled into the list
longest = self._longest_stop_word_len(requests)
assert longest > 0, f"{longest=}, longest stop word length should be greater than 0, as this code path is only reached with requests with stop words"
lookback = longest - 1
old_tokens = []
for request in requests:
old = request.get_tokens(BEAM_0)[-lookback:] if lookback > 0 else []
padded = [pad_id] * max(0, lookback - len(old)) + old
old_tokens.append([padded] * self.max_tokens)
old_tokens_tensor = torch.tensor(old_tokens,
pin_memory=True).to("cuda",
non_blocking=True)
assert old_tokens_tensor.shape == (
len(requests), self.max_tokens, lookback
), f"{old_tokens_tensor.shape} != ({len(requests)=}, {self.max_tokens=}, {lookback=})"
new_tokens = new_tokens.T.unsqueeze(1) * self._pad_steps_mask
ret = torch.cat((old_tokens_tensor, new_tokens), dim=-1)
assert ret.shape == (
len(requests), self.max_tokens, lookback + self.max_tokens
), f"{ret.shape} != ({len(requests)=}, {self.max_tokens=}, {lookback + self.max_tokens=})"
return ret
def _are_stop_words(self, requests: list[LlmRequest],
tokens: torch.Tensor) -> torch.Tensor:
per_step = torch.zeros((self.max_tokens, len(requests)),
dtype=torch.bool,
pin_memory=True).to("cuda", non_blocking=True)
padded_tokens = self._padded_old_tokens(requests, tokens)
def request_stop_words(request: LlmRequest, new_tokens: torch.Tensor):
swl, ends = request.py_stop_words_list
if -1 in ends:
ends = ends[:ends.index(-1)]
lens = np.diff(ends, prepend=0)
lens_device = torch.tensor(list(lens),
pin_memory=True).to("cuda",
non_blocking=True)
max_len = np.max(lens)
words = torch.zeros(len(lens),
max_len,
dtype=torch.int32,
pin_memory=True)
for step, (start, l) in enumerate(zip([0] + ends, lens)):
words[step, :l] = torch.tensor(swl[start:start + l],
dtype=torch.int32)
words_device = words.to("cuda", non_blocking=True)
for step, step_seq in enumerate(new_tokens):
for word, L in zip(words_device, lens_device):
truncated_seq = step_seq[step_seq >= 0][-L:]
if torch.equal(truncated_seq, word[-L:]):
# We don't care about subsequent steps because we already found a stop word match
return step
return None
for request_idx, request in enumerate(requests):
step = request_stop_words(request, padded_tokens[request_idx])
if step is not None:
per_step[step][request_idx] = True
return per_step
def _process_requests(self,
scheduled_requests: ScheduledRequests,
model_outputs: dict[str, torch.Tensor],
new_tokens: torch.Tensor,
num_context_logits_prefix_sum: list[int],
*,
seq_slots: torch.Tensor,
seq_slots_host: torch.Tensor,
log_probs_host: torch.Tensor | None = None):
beam_width = self.MAX_BEAM_WIDTH
beam = self.BEAM
# raw_logits should contain only the logits from the generation requests.
# If returning context logits was requested, fetch just the generation-request logits.
@ -687,16 +910,13 @@ class TorchSampler(Sampler):
no_draft_tokens = len(requests) == sum_steps
fast_path = not self.enable_mixed_sampler and no_draft_tokens and log_probs_host is None
seq_slots_host = torch.as_tensor([r.py_seq_slot for r in requests])
seq_slots = seq_slots_host.to(device="cuda", non_blocking=True)
if fast_path:
logits = raw_logits[:len(requests)]
logits = self._apply_embedding_bias(logits, requests)
next_tokens = torch.argmax(logits, dim=-1)
self.append_eagle3(next_tokens, model_outputs)
int_next_tokens = next_tokens.to(torch.int, non_blocking=True)
next_tokens = int_next_tokens.view(1, -1, beam_width)
next_tokens = int_next_tokens.view(1, -1, SINGLE_BEAM_WIDTH)
new_tokens[:1].index_copy_(1, seq_slots, next_tokens)
return
@ -737,17 +957,17 @@ class TorchSampler(Sampler):
# Batched processing already applied bias, just use the results
next_tokens = batched_next_tokens[input_slice]
softmax = batched_softmax[input_slice]
current_slice = slice(0, steps), slot, beam
current_slice = slice(0, steps), slot, BEAM_0
new_tokens[current_slice] = next_tokens
if request.py_draft_logits is not None:
request.py_target_probs = softmax.clone()
if log_probs_host is not None:
assert beam == 0, "The following call relies on beam_width to be 1 - hence the unsqueeze"
assert BEAM_0 == 0, "The following call relies on beam_width to be 1 - hence the unsqueeze"
token_probs = torch.gather(
softmax, dim=1, index=next_tokens.unsqueeze(1)).squeeze(-1)
log_probs = torch.log(token_probs)
log_probs_host[slot, beam, :steps].copy_(log_probs,
non_blocking=True)
log_probs_host[slot, BEAM_0, :steps].copy_(log_probs,
non_blocking=True)
offset += steps

View File

@ -0,0 +1,61 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tensorrt_llm.bindings.executor import FinishReason
from .llm_request import LlmRequest, produce_stop_words
BEAM_0 = 0
SINGLE_BEAM_WIDTH = 1
def max_token_criteria_single_beam(request: LlmRequest,
max_seq_len: int) -> bool:
num_tokens = request.get_num_tokens(BEAM_0)
return (num_tokens - request.py_orig_prompt_len
>= request.py_max_new_tokens) or (num_tokens >= max_seq_len)
def stop_token_criteria(py_stop_words_list: list[list[int]] | None,
tokens: list[int]) -> bool:
if py_stop_words_list:
assert isinstance(py_stop_words_list,
list), "request.py_stop_words_list should be a list"
for stop_word in produce_stop_words(py_stop_words_list):
if len(stop_word) > len(tokens):
continue
if tokens[-len(stop_word):] == stop_word:
return True
return False
def handle_stop_single_beam(request: LlmRequest, new_token: int, *,
max_seq_len: int) -> bool:
"""Handle stop criteria and set appropriate finish reasons and state.
Returns True if generation should stop."""
if new_token == request.py_end_id:
request.finish_by(FinishReason.END_ID, BEAM_0)
return True
if max_token_criteria_single_beam(request, max_seq_len):
request.finish_by(FinishReason.LENGTH, BEAM_0)
return True
if stop_token_criteria(request.py_stop_words_list,
request.get_tokens(BEAM_0)):
request.finish_by(FinishReason.STOP_WORDS, BEAM_0)
return True
return False

View File

@ -20,9 +20,12 @@ from tensorrt_llm._torch.speculative.interface import SpecMetadata
@contextmanager
def save_metadata_state(attn_metadata: AttentionMetadata,
spec_metadata: SpecMetadata) -> None:
attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda",
"kv_lens_cuda")
attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda")
batch_size = attn_metadata.num_seqs
# Do not use prepare_for_spec_dec for this special field.
# TRTLLM attention uses views of this tensor internally and prepare_for_spec_dec
# creates a copy. If you write to the copy, TRTLLM attention won't see the updates.
kv_lens = attn_metadata.kv_lens_cuda[:batch_size].clone()
if attn_metadata.is_cuda_graph:
assert spec_metadata.is_cuda_graph
@ -39,6 +42,8 @@ def save_metadata_state(attn_metadata: AttentionMetadata,
yield
finally:
attn_metadata.restore_from_spec_dec()
attn_metadata.kv_lens_cuda[:batch_size].copy_(kv_lens)
if attn_metadata.is_cuda_graph:
spec_metadata.num_tokens = num_tokens
if isinstance(spec_metadata, Eagle3SpecMetadata):

View File

@ -10,8 +10,11 @@ from ..model_config import ModelConfig
from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
from ..pyexecutor.sampler import (SampleState, SampleStateTensors, TorchSampler,
add_token, int_tensor)
from ..pyexecutor.sampler import (Sampler, SampleState, SampleStateTensors,
TorchSampler, TorchStore, add_token,
int_tensor)
from ..pyexecutor.sampler_utils import (BEAM_0, SINGLE_BEAM_WIDTH,
handle_stop_single_beam)
from ..pyexecutor.scheduler import ScheduledRequests
from .interface import SpecMetadata
@ -206,40 +209,44 @@ class MTPSpecMetadata(SpecMetadata):
self.slot_ids[:num_seqs].copy_(mtp_slot_ids, non_blocking=True)
class MTPSampler(TorchSampler):
class MTPStore(TorchStore):
def __init__(self, *, max_draft_len: int, max_num_sequences: int,
max_beam_width: int):
super().__init__(max_draft_len=max_draft_len,
max_num_sequences=max_num_sequences,
max_beam_width=max_beam_width)
self.next_new_tokens = int_tensor(
(self.max_tokens, self.max_num_sequences, SINGLE_BEAM_WIDTH))
self.next_draft_tokens = int_tensor(
(self.max_num_sequences, self.max_draft_len))
self.new_tokens_lens = int_tensor((self.max_num_sequences, ))
class MTPSampler(Sampler):
"""
MTP sampler.
"""
SampleState = SampleStateMTP
def is_generation_model(self) -> bool:
return True
def __init__(self, args: TorchSampler.Args, *, nextn: int):
self.mapping = None
self.draft_len = nextn
super().__init__(args)
@dataclass(frozen=True, kw_only=True)
class Store(TorchSampler.Store):
next_new_tokens: torch.Tensor
next_draft_tokens: torch.Tensor
new_tokens_lens: torch.Tensor
def create_store(self) -> Store:
num_tokens, seq_slots, _ = self.NEW_TOKENS_SHAPE
draft_len = num_tokens - 1
assert draft_len == self.draft_len
return self.Store(
new_tokens=int_tensor(self.NEW_TOKENS_SHAPE),
next_new_tokens=int_tensor(self.NEW_TOKENS_SHAPE),
next_draft_tokens=int_tensor((seq_slots, draft_len)),
new_tokens_lens=int_tensor((seq_slots, )),
)
self.store = MTPStore(max_draft_len=nextn,
max_num_sequences=args.max_num_sequences,
max_beam_width=args.max_beam_width)
self.max_seq_len = args.max_seq_len
def _request_common_handling(self, request: LlmRequest,
next_draft_tokens: list[list[int]]):
assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler"
assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler"
assert not request.py_return_log_probs, "return_log_probs not implemented for MTPSampler"
assert request.py_seq_slot is not None
request.py_draft_tokens = next_draft_tokens[request.py_seq_slot]
request.py_decoding_iter += 1
@ -250,12 +257,12 @@ class MTPSampler(TorchSampler):
new_tokens = state.host.new_tokens
new_tokens_lens_list = state.host.new_tokens_lens.tolist()
next_draft_tokens_list = state.host.next_draft_tokens.tolist()
beam_idx = self.BEAM
max_seq_len = self.max_seq_len
for req in state.scheduled_requests.context_requests:
if req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0:
continue
new_token = add_token(req, new_tokens, beam=beam_idx)
self._handle_stop_criteria(req, new_token)
new_token = add_token(req, new_tokens, beam=BEAM_0)
handle_stop_single_beam(req, new_token, max_seq_len=max_seq_len)
self._request_common_handling(req, next_draft_tokens_list)
for req in state.scheduled_requests.generation_requests:
@ -263,8 +270,10 @@ class MTPSampler(TorchSampler):
continue
num_new_tokens = new_tokens_lens_list[req.py_seq_slot]
for i in range(num_new_tokens):
new_token = add_token(req, new_tokens, beam=beam_idx, step=i)
if self._handle_stop_criteria(req, new_token):
new_token = add_token(req, new_tokens, beam=BEAM_0, step=i)
if handle_stop_single_beam(req,
new_token,
max_seq_len=max_seq_len):
break
req.py_num_accepted_draft_tokens = num_new_tokens - 1
req.py_rewind_len = self.draft_len - req.py_num_accepted_draft_tokens

View File

@ -84,8 +84,24 @@ class RuntimeConfig(BaseModel):
backend_cache_config = llm_args.pop("kv_cache_config", {})
llm_args["kv_cache_config"] = backend_cache_config | kv_cache_config
return update_llm_args_with_extra_options(llm_args,
self.extra_llm_api_options)
updated_llm_args = update_llm_args_with_extra_options(
llm_args, self.extra_llm_api_options)
if self.backend == "pytorch":
cuda_graph_config = updated_llm_args.pop(
"cuda_graph_config", llm_args["cuda_graph_config"])
# Use runtime max_batch_size as cuda_graph_config.max_batch_size
# if both max_batch_size and batch_sizes are not set.
batch_sizes_set = cuda_graph_config.get("batch_sizes",
None) is not None
max_batch_size_set = cuda_graph_config.get("max_batch_size",
None) is not None
if not batch_sizes_set and not max_batch_size_set:
cuda_graph_config[
"max_batch_size"] = self.settings_config.max_batch_size
updated_llm_args["cuda_graph_config"] = cuda_graph_config
return updated_llm_args
@model_validator(mode="after")
def validate_full_config(self) -> RuntimeConfig:

View File

@ -125,13 +125,36 @@ class ZeroMqQueue:
# Send data without HMAC
self.socket.send_pyobj(obj)
def put_noblock(self, obj: Any):
def put_noblock(self,
obj: Any,
*,
retry: int = 1,
wait_time: float = 0.001):
'''
Put an object into the queue without blocking, and retry if the send fails.
NOTE: It won't raise any error if the send fails.
Parameters:
obj (Any): The object to send.
retry (int): The number of times to retry sending the object.
wait_time (float): The time to wait before retrying.
'''
assert retry >= 0 and retry <= 10, "Retry must be between 0 and 10, adjust the wait_time if needed"
self.setup_lazily()
with nvtx_range_debug("send", color="blue", category="IPC"):
data = pickle.dumps(obj) # nosec B301
if self.use_hmac_encryption:
data = self._sign_data(data)
self.socket.send(data, flags=zmq.NOBLOCK)
try:
self.socket.send(data, flags=zmq.NOBLOCK)
except zmq.Again:
if retry > 0:
time.sleep(wait_time)
self.put_noblock(obj, retry=retry - 1, wait_time=wait_time)
else:
logger.error(f"Failed to send object: {obj}")
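For comparison, an equivalent iterative form of the retry loop added above; this is a standalone sketch over a plain pyzmq socket, not the project's API.

import time
import zmq

def send_noblock_with_retry(socket: zmq.Socket, data: bytes,
                            retry: int = 1, wait_time: float = 0.001) -> bool:
    """Try a non-blocking send, back off briefly on zmq.Again, and give up
    quietly after `retry` additional attempts."""
    for _ in range(retry + 1):
        try:
            socket.send(data, flags=zmq.NOBLOCK)
            return True
        except zmq.Again:
            time.sleep(wait_time)
    return False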
async def put_async(self, obj: Any):
self.setup_lazily()

View File

@ -351,7 +351,7 @@ class GenerationExecutorProxy(GenerationExecutor):
# notify the workers to quit
if all(not f.done() for f in self.mpi_futures):
self.request_queue.put_noblock(None)
self.request_queue.put_noblock(None, retry=4)
def shutdown(self):
if not self.workers_started:

View File

@ -1041,6 +1041,9 @@ class KvCacheConfig(StrictBaseModel, PybindMirror):
"The data type to use for the Mamba SSM cache. If set to 'auto', the data type will be inferred from the model config."
)
tokens_per_block: int = Field(default=32,
description="The number of tokens per block.")
def _to_pybind(self):
return _KvCacheConfig(
enable_block_reuse=self.enable_block_reuse,
@ -1946,6 +1949,9 @@ class BaseLlmArgs(StrictBaseModel):
from tensorrt_llm._torch.speculative import suggest_spec_config
spec_config = suggest_spec_config(max_batch_size)
if self.kv_cache_config is not None:
executor_config.tokens_per_block = self.kv_cache_config.tokens_per_block
update_executor_config(
executor_config,
backend=self.backend,

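A hedged usage sketch of the two hunks above: tokens_per_block is now a first-class KvCacheConfig field and is forwarded onto the executor config (import path assumed from the LLM API).

from tensorrt_llm.llmapi import KvCacheConfig  # import path assumed

kv_cache_config = KvCacheConfig(tokens_per_block=64, enable_block_reuse=True)
# When passed to the LLM API, the value above ends up on
# executor_config.tokens_per_block via the hunk in BaseLlmArgs.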
View File

@ -435,7 +435,7 @@ class RemoteMpiCommSessionServer():
f"RemoteMpiCommSessionServer received all results, sending to client\n",
"green")
try:
self.queue.put_noblock(self.results)
self.queue.put_noblock(self.results, retry=2)
except zmq.ZMQError as e:
# The client could be shutdown first.
if e.errno == zmq.EAGAIN:

View File

@ -518,6 +518,8 @@ def generate_api_docs_as_docstring(model: Type[BaseModel],
# Format the argument documentation with 12 spaces indent for args
arg_line = f"{indent} {field_name} ({type_str}): "
if status := field_info.get("status", None):
arg_line += f":tag:`{status}` "
if field_description:
arg_line += field_description.split('\n')[0] # First line with type
@ -557,20 +559,21 @@ class ApiParamTagger:
'''
def __call__(self, cls: Type[BaseModel]) -> None:
self.process_pydantic_model(cls)
""" The main entry point to tag the api doc. """
self._process_pydantic_model(cls)
def process_pydantic_model(self, cls: Type[BaseModel]) -> None:
def _process_pydantic_model(self, cls: Type[BaseModel]) -> None:
"""Process the Pydantic model to add tags to the fields.
"""
for field_name, field_info in cls.model_fields.items():
if field_info.json_schema_extra and 'status' in field_info.json_schema_extra:
status = field_info.json_schema_extra['status']
self.amend_pydantic_field_description_with_tags(
self._amend_pydantic_field_description_with_tags(
cls, [field_name], status)
def amend_pydantic_field_description_with_tags(self, cls: Type[BaseModel],
field_names: list[str],
tag: str) -> None:
def _amend_pydantic_field_description_with_tags(self, cls: Type[BaseModel],
field_names: list[str],
tag: str) -> None:
"""Amend the description of the fields with tags.
e.g. :tag:`beta` or :tag:`prototype`
Args:

View File

@ -109,6 +109,7 @@ class Logger(metaclass=Singleton):
self._func_wrapper(severity)(" ".join(parts))
def log_once(self, severity, *msg, key):
assert key is not None, "key is required for log_once"
if key not in self._appeared_keys:
self._appeared_keys.add(key)
self.log(severity, *msg)

View File

@ -6,7 +6,7 @@ import re
import time
import traceback
import uuid
from typing import Any, AsyncGenerator, Literal
from typing import Any, List, Literal
from openai_harmony import (Author, Conversation, DeveloperContent,
HarmonyEncodingName, HarmonyError, Message,
@ -14,15 +14,15 @@ from openai_harmony import (Author, Conversation, DeveloperContent,
SystemContent, TextContent, ToolDescription,
load_harmony_encoding)
from tensorrt_llm.llmapi import RequestOutput
from tensorrt_llm.logger import logger
# yapf: disable
from .openai_protocol import (ChatCompletionMessageParam, ChatCompletionRequest,
from .openai_protocol import (ChatCompletionMessageParam,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage,
ChatCompletionStreamResponse,
ChatCompletionToolsParam, ChatMessage,
DeltaFunctionCall, DeltaMessage, DeltaToolCall,
UsageInfo)
@ -1485,36 +1485,72 @@ class HarmonyAdapter:
return True
async def handle_streaming_response(
harmony_adapter: HarmonyAdapter,
generator: RequestOutput,
request_id: str,
request: ChatCompletionRequest,
) -> AsyncGenerator[str, None]:
"""Handle streaming response with harmony format."""
_SERVE_HARMONY_ADAPTER: HarmonyAdapter = None
def get_harmony_adapter():
global _SERVE_HARMONY_ADAPTER
if _SERVE_HARMONY_ADAPTER is None:
_SERVE_HARMONY_ADAPTER = HarmonyAdapter()
return _SERVE_HARMONY_ADAPTER
def handle_streaming_response(tools: List[ChatCompletionToolsParam],
tool_choice: str, outputs: List, model: str,
request_id: str, done: bool,
num_prompt_tokens: int):
first_iteration = True
async for res in generator:
output = res.outputs[0]
output = outputs[0]
# Convert tools to dictionary format for harmony adapter (standard pattern)
tools_dict = None
if request.tools:
tools_dict = [tool.model_dump() for tool in request.tools]
# Convert tools to dictionary format for harmony adapter (standard pattern)
tools_dict = None
harmony_adapter = get_harmony_adapter()
if tools:
tools_dict = [tool.model_dump() for tool in tools]
# Get tool_choice from request - if "none", don't pass tools to parser
tool_choice = getattr(request, 'tool_choice', None)
if tool_choice == "none":
tools_for_parser = None
# Get tool_choice from request - if "none", don't pass tools to parser
if tool_choice == "none":
tools_for_parser = None
else:
tools_for_parser = tools_dict
# Create OpenAI streaming responses
try:
res = []
if done:
# Clean up state
harmony_adapter.cleanup_stream_state(request_id)
usage_info = _create_usage_info(num_prompt_tokens, outputs)
# Send final message with finish_reason
final_response = ChatCompletionStreamResponse(
model=model,
choices=[
ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)
],
)
final_response_json = final_response.model_dump_json(
exclude_none=True)
final_usage_chunk = ChatCompletionStreamResponse(choices=[],
model=model,
usage=usage_info)
final_usage_json = final_usage_chunk.model_dump_json(
exclude_none=True)
res.append(f"data: {final_response_json}\n\n")
res.append(f"data: {final_usage_json}\n\n")
else:
tools_for_parser = tools_dict
# Create OpenAI streaming responses
try:
responses = harmony_adapter.create_openai_streaming_response(
request_id=request_id,
tokens=output.token_ids_diff,
available_tools=tools_for_parser,
model_name=request.model,
model_name=model,
tool_choice=tool_choice)
# Send first response after receiving the first output
if first_iteration:
@ -1525,64 +1561,44 @@ async def handle_streaming_response(
delta=first_delta)
first_response = ChatCompletionStreamResponse(
model=request.model,
model=model,
choices=[choice],
)
response_json = first_response.model_dump_json(
exclude_none=True)
yield f"data: {response_json}\n\n"
res.append(f"data: {response_json}\n\n")
for response in responses:
yield response
res.extend(responses)
except Exception as e:
logger.error(f"Failed to create OpenAI streaming response: {e}")
logger.debug(f"Streaming error details: {traceback.format_exc()}")
# Clean up state
harmony_adapter.cleanup_stream_state(request_id)
raise e
return res
# Clean up state
harmony_adapter.cleanup_stream_state(request_id)
# Send final message with finish_reason
output = generator.outputs[0]
final_response = ChatCompletionStreamResponse(
model=request.model,
choices=[
ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)
])
yield f"data: {final_response.model_dump_json(exclude_unset=True)}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
logger.error(f"Failed to create OpenAI streaming response: {e}")
logger.debug(f"Streaming error details: {traceback.format_exc()}")
# Clean up state
harmony_adapter.cleanup_stream_state(request_id)
raise e
async def handle_non_streaming_response(
harmony_adapter: HarmonyAdapter, promise: RequestOutput,
request: ChatCompletionRequest) -> ChatCompletionResponse:
def handle_non_streaming_response(tools: List[ChatCompletionToolsParam],
tool_choice: str, outputs: List, model: str,
num_prompt_tokens: int):
"""Handle non-streaming response with harmony format."""
# Get final result
await promise
# Parse harmony output to OpenAI format
# Convert tools to dictionary format for harmony adapter (standard pattern)
tools_dict = None
if request.tools:
tools_dict = [tool.model_dump() for tool in request.tools]
harmony_adapter = get_harmony_adapter()
if tools:
tools_dict = [tool.model_dump() for tool in tools]
# Get tool_choice from request - if "none", don't pass tools to parser
tool_choice = getattr(request, 'tool_choice', None)
if tool_choice == "none":
tools_for_parser = None
else:
tools_for_parser = tools_dict
output = promise.outputs[0]
output = outputs[0]
parsed_output = harmony_adapter.harmony_output_to_openai(
output.token_ids, tools_for_parser, tool_choice)
@ -1597,11 +1613,11 @@ async def handle_non_streaming_response(
output.finish_reason)
# Create usage info from metrics (RequestOutput doesn't have usage in v1)
usage_info = _create_usage_info(promise)
usage_info = _create_usage_info(num_prompt_tokens, outputs)
# Create response
response = ChatCompletionResponse(
model=request.model,
model=model,
choices=[
ChatCompletionResponseChoice(
index=0,
@ -1613,7 +1629,6 @@ async def handle_non_streaming_response(
# Optional: Log if harmony parsing failed (for debugging)
if parsed_output.get('_harmony_parsing_failed'):
logger.warning("⚠️ Harmony parsing fell back to raw text decoding")
logger.debug(f"request\n\n{request}")
logger.debug(f"response\n\n{response}\n")
return response
@ -1646,15 +1661,10 @@ def _determine_finish_reason(parsed_output: dict[str, Any],
return reason
def _create_usage_info(final_res: RequestOutput) -> UsageInfo:
def _create_usage_info(num_prompt_tokens: int, outputs: List) -> UsageInfo:
"""Create usage info from the prompt token count and outputs, following the serving_chat.py pattern."""
# Calculate prompt tokens from prompt_token_ids and encoder_prompt_token_ids
assert final_res.prompt_token_ids is not None
num_prompt_tokens = len(final_res.prompt_token_ids)
# Calculate completion tokens from all outputs
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
num_generated_tokens = sum(len(output.token_ids) for output in outputs)
# Create usage info
usage = UsageInfo(prompt_tokens=num_prompt_tokens,

View File

@ -322,6 +322,8 @@ class OpenAIDisaggServer:
raise ValueError("Disagg server returned more than one choice. This is currently not supported in disaggregated server.")
if choices[0].disaggregated_params is None:
raise ValueError("Context server did not return disaggregated params")
if choices[0].disaggregated_params.ctx_request_id is None:
raise ValueError("Invalid disaggregated params in context phase response.")
return ctx_response

View File

@ -44,9 +44,10 @@ from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
UsageInfo,
to_llm_disaggregated_params)
from tensorrt_llm.serve.postprocess_handlers import (
ChatPostprocArgs, CompletionPostprocArgs, chat_response_post_processor,
chat_stream_post_processor, completion_response_post_processor,
completion_stream_post_processor)
ChatCompletionPostprocArgs, ChatPostprocArgs, CompletionPostprocArgs,
chat_harmony_post_processor, chat_harmony_streaming_post_processor,
chat_response_post_processor, chat_stream_post_processor,
completion_response_post_processor, completion_stream_post_processor)
from tensorrt_llm.serve.responses_utils import ConversationHistoryStore
from tensorrt_llm.serve.responses_utils import \
create_response as responses_api_create_response
@ -57,8 +58,7 @@ from tensorrt_llm.serve.responses_utils import \
from tensorrt_llm.version import __version__ as VERSION
from .._utils import nvtx_mark, set_prometheus_multiproc_dir
from .harmony_adapter import (HarmonyAdapter, handle_non_streaming_response,
handle_streaming_response,
from .harmony_adapter import (HarmonyAdapter, get_harmony_adapter,
maybe_transform_reasoning_effort)
# yapf: enable
@ -117,7 +117,11 @@ class OpenAIServer:
# gpt-oss
self.harmony_adapter: HarmonyAdapter | None = None
self.use_harmony = self.model_config.model_type == "gpt_oss"
disable_harmony = os.getenv("DISABLE_HARMONY_ADAPTER", "0") == "1"
if disable_harmony:
self.use_harmony = False
else:
self.use_harmony = (self.model_config.model_type == "gpt_oss")
@asynccontextmanager
async def lifespan(app: FastAPI):
@ -704,11 +708,35 @@ class OpenAIServer:
Chat Completion API with harmony format support.
Supports both streaming and non-streaming modes.
"""
async def create_harmony_response(
promise: RequestOutput, postproc_params: PostprocParams) -> ChatCompletionResponse:
await promise.aresult()
if self.postproc_worker_enabled:
chat_response = promise.outputs[0]._postprocess_result
else:
post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
chat_response = post_processor(promise, args)
return chat_response
async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams):
if not self.postproc_worker_enabled:
post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
async for res in promise:
pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
# await self._extract_metrics(res)
for pp_res in pp_results:
yield pp_res
yield "data: [DONE]\n\n"
try:
# Initialize HarmonyAdapter
# NOTE: WAR for Disagg failure, may affect perf if no warmup
if not self.harmony_adapter:
self.harmony_adapter = HarmonyAdapter()
self.harmony_adapter = get_harmony_adapter()
# Convert Pydantic models to dictionaries for JSON serialization (standard pattern)
tools_dict = None
if request.tools:
@ -743,27 +771,37 @@ class OpenAIServer:
vocab_size=self.tokenizer.tokenizer.vocab_size)
sampling_params.detokenize = False # Harmony adapter handles detokenization
postproc_args = ChatCompletionPostprocArgs.from_request(request)
postproc_params = PostprocParams(
post_processor=chat_harmony_streaming_post_processor
if request.stream else chat_harmony_post_processor,
postproc_args=postproc_args,
)
# Generate
promise = self.llm.generate_async(
inputs=harmony_tokens,
sampling_params=sampling_params,
_postproc_params=postproc_params if self.postproc_worker_enabled else None,
streaming=bool(request.stream),
lora_request=request.lora_request,
)
postproc_args.request_id = promise.request_id
if not self.postproc_worker_enabled:
postproc_args.num_prompt_tokens = len(promise.prompt_token_ids)
# Disconnect cancellation
asyncio.create_task(self.await_disconnected(raw_request, promise))
# Handle streaming
if request.stream:
return StreamingResponse(
handle_streaming_response(
self.harmony_adapter, promise,
str(promise.request_id), request,
),
content=create_streaming_generator(promise, postproc_params),
media_type="text/event-stream"
)
else:
response = await handle_non_streaming_response(self.harmony_adapter, promise, request)
response = await create_harmony_response(promise, postproc_params)
return JSONResponse(response.model_dump())
except Exception as e:

View File

@ -9,6 +9,8 @@ from ..llmapi.reasoning_parser import (BaseReasoningParser,
ReasoningParserFactory)
from ..llmapi.tokenizer import TransformersTokenizer
# yapf: disable
from .harmony_adapter import (handle_non_streaming_response,
handle_streaming_response)
from .openai_protocol import (ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
ChatCompletionNamedToolChoiceParam,
@ -24,7 +26,8 @@ from .openai_protocol import (ChatCompletionLogProbs,
FunctionCall, StreamOptions, ToolCall, UsageInfo,
to_disaggregated_params)
# yapf: enale
# yapf: enable
@dataclass(kw_only=True)
class ChatPostprocArgs(PostprocArgs):
@ -57,8 +60,7 @@ class ChatPostprocArgs(PostprocArgs):
)
def create_logprobs(token_ids: List[int],
tokenizer: TransformersTokenizer,
def create_logprobs(token_ids: List[int], tokenizer: TransformersTokenizer,
logprobs: List[float]) -> ChatCompletionLogProbs:
assert len(token_ids) == len(logprobs), \
"token_ids and logprobs have different lengths"
@ -75,12 +77,14 @@ def create_logprobs(token_ids: List[int],
return chat_logprobs
def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str, streaming: bool) -> Tuple[bool, str, str]:
def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
streaming: bool) -> Tuple[bool, str, str]:
reasoning_parser = None
if args.reasoning_parser is not None:
if output_index not in args.reasoning_parser_dict:
args.reasoning_parser_dict[output_index] = ReasoningParserFactory.create_reasoning_parser(
args.reasoning_parser)
args.reasoning_parser_dict[
output_index] = ReasoningParserFactory.create_reasoning_parser(
args.reasoning_parser)
reasoning_parser = args.reasoning_parser_dict[output_index]
in_reasoning = False
@ -97,7 +101,8 @@ def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
@nvtx_range_debug("chat_stream_post_processor")
def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs) -> List[str]:
def chat_stream_post_processor(rsp: GenerationResultBase,
args: ChatPostprocArgs) -> List[str]:
def yield_first_chat(num_tokens: int,
idx: int,
@ -128,9 +133,13 @@ def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs
include_continuous_usage = False
if args.first_iteration:
for i in range(args.num_choices):
res.append(f"data: {yield_first_chat(prompt_tokens, i, role=args.role)} \n\n")
res.append(
f"data: {yield_first_chat(prompt_tokens, i, role=args.role)} \n\n"
)
if args.echo and args.last_message_content:
res.append(f"data: {yield_first_chat(prompt_tokens, i, content=args.last_message_content)} \n\n")
res.append(
f"data: {yield_first_chat(prompt_tokens, i, content=args.last_message_content)} \n\n"
)
args.first_iteration = False
for output in rsp.outputs:
@ -158,14 +167,18 @@ def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs
delta_message = DeltaMessage(
content=delta_text, reasoning_content=reasoning_delta_text)
choice = ChatCompletionResponseStreamChoice(index=i,
delta=delta_message,
finish_reason=None,
avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None))
choice = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
finish_reason=None,
avg_decoded_tokens_per_iter=getattr(rsp,
'avg_decoded_tokens_per_iter',
None))
if args.return_logprobs:
logprobs = output.logprobs_diff
token_ids = output.token_ids_diff
choice.logprobs = create_logprobs(token_ids, args.tokenizer, logprobs)
choice.logprobs = create_logprobs(token_ids, args.tokenizer,
logprobs)
if output.finish_reason is not None:
choice.finish_reason = output.finish_reason
choice.stop_reason = output.stop_reason
@ -179,57 +192,62 @@ def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs
res.append(f"data: {data}\n\n")
if include_usage and rsp._done:
completion_tokens = sum(output.length
for output in rsp.outputs)
completion_tokens = sum(output.length for output in rsp.outputs)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
final_usage_chunk = ChatCompletionStreamResponse(
choices=[], model=args.model, usage=final_usage)
final_usage_chunk = ChatCompletionStreamResponse(choices=[],
model=args.model,
usage=final_usage)
final_usage_data = final_usage_chunk.model_dump_json()
res.append(f"data: {final_usage_data}\n\n")
return res
@nvtx_range_debug("chat_response_post_processor")
def chat_response_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs) -> ChatCompletionResponse:
def chat_response_post_processor(
rsp: GenerationResultBase,
args: ChatPostprocArgs) -> ChatCompletionResponse:
choices: List[ChatCompletionResponseChoice] = []
role = args.role
for output in rsp.outputs:
_, text, reasoning_text = apply_reasoning_parser(
args, output.index, output.text, False)
if args.tool_choice and isinstance(
args.tool_choice,
ChatCompletionNamedToolChoiceParam):
if args.tool_choice and isinstance(args.tool_choice,
ChatCompletionNamedToolChoiceParam):
message = ChatMessage(
role=role,
content="",
tool_calls=[
ToolCall(function=FunctionCall(
name=args.tool_choice.function.name,
arguments=text))
name=args.tool_choice.function.name, arguments=text))
])
else:
if text is None:
text = ""
message = ChatMessage(
role=role, content=text, reasoning_content=reasoning_text)
disaggregated_params = to_disaggregated_params(output.disaggregated_params)
message = ChatMessage(role=role,
content=text,
reasoning_content=reasoning_text)
disaggregated_params = to_disaggregated_params(
output.disaggregated_params)
choice = ChatCompletionResponseChoice(
index=output.index,
message=message,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
disaggregated_params=disaggregated_params,
avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None),
avg_decoded_tokens_per_iter=getattr(rsp,
'avg_decoded_tokens_per_iter',
None),
)
if args.return_logprobs:
choice.logprobs = create_logprobs(output.token_ids, args.tokenizer, output.logprobs)
choice.logprobs = create_logprobs(output.token_ids, args.tokenizer,
output.logprobs)
choices.append(choice)
if args.echo and args.last_message_content:
@ -238,8 +256,7 @@ def chat_response_post_processor(rsp: GenerationResultBase, args: ChatPostprocAr
choice.message.content = full_message
num_prompt_tokens = args.num_prompt_tokens
num_generated_tokens = sum(
len(output.token_ids) for output in rsp.outputs)
num_generated_tokens = sum(len(output.token_ids) for output in rsp.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
@ -275,7 +292,8 @@ class CompletionPostprocArgs(PostprocArgs):
@nvtx_range_debug("completion_stream_post_processor")
def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args: CompletionPostprocArgs) -> List[str]:
def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase,
args: CompletionPostprocArgs) -> List[str]:
res: List[str] = []
prompt_tokens = args.num_prompt_tokens
if stream_option := args.stream_options:
@ -293,9 +311,11 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args:
index=args.prompt_idx * args.num_choices + output.index,
text=delta_text if args.detokenize else "",
token_ids=None if args.detokenize else output.token_ids_diff,
finish_reason = output.finish_reason,
stop_reason = output.stop_reason,
avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None),
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
avg_decoded_tokens_per_iter=getattr(rsp,
'avg_decoded_tokens_per_iter',
None),
)
chunk = CompletionStreamResponse(model=args.model, choices=[choice])
if include_continuous_usage:
@ -306,16 +326,16 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args:
res.append(f"data: {data}\n\n")
if include_usage and rsp._done:
completion_tokens = sum(output.length
for output in rsp.outputs)
completion_tokens = sum(output.length for output in rsp.outputs)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
final_usage_chunk = ChatCompletionStreamResponse(
choices=[], model=args.model, usage=final_usage)
final_usage_chunk = ChatCompletionStreamResponse(choices=[],
model=args.model,
usage=final_usage)
final_usage_data = final_usage_chunk.model_dump_json()
res.append(f"data: {final_usage_data}\n\n")
args.first_iteration = False
@ -323,7 +343,9 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args:
@nvtx_range_debug("completion_response_post_processor")
def completion_response_post_processor(rsp: GenerationResult, args: CompletionPostprocArgs) -> CompletionResponse:
def completion_response_post_processor(
rsp: GenerationResult,
args: CompletionPostprocArgs) -> CompletionResponse:
prompt_tokens = args.num_prompt_tokens
completion_tokens = 0
choices = []
@ -331,23 +353,75 @@ def completion_response_post_processor(rsp: GenerationResult, args: CompletionPo
text = output.text
if args.echo:
text = args.prompt + text
disaggregated_params = to_disaggregated_params(output.disaggregated_params)
disaggregated_params = to_disaggregated_params(
output.disaggregated_params)
choice = CompletionResponseChoice(
text=text if args.detokenize else "",
token_ids=None if args.detokenize else output.token_ids,
index=args.prompt_idx * args.num_choices + output.index,
disaggregated_params=disaggregated_params,
context_logits=None if rsp.context_logits is None else rsp.context_logits.tolist(),
context_logits=None
if rsp.context_logits is None else rsp.context_logits.tolist(),
stop_reason=output.stop_reason,
finish_reason=output.finish_reason,
avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None),
avg_decoded_tokens_per_iter=getattr(rsp,
'avg_decoded_tokens_per_iter',
None),
)
completion_tokens += output.length
choices.append(choice)
usage = UsageInfo(prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=completion_tokens + prompt_tokens)
response = CompletionResponse(choices=choices, model=args.model, usage=usage)
completion_tokens=completion_tokens,
total_tokens=completion_tokens + prompt_tokens)
response = CompletionResponse(choices=choices,
model=args.model,
usage=usage)
return response
@dataclass(kw_only=True)
class ChatCompletionPostprocArgs(PostprocArgs):
model: str
tools: Optional[List[ChatCompletionToolsParam]]
tool_choice: Optional[Union[Literal["none", "auto"],
ChatCompletionNamedToolChoiceParam]]
request_id: Optional[int] = None
@classmethod
def from_request(cls, request: ChatCompletionRequest):
return cls(
model=request.model,
tools=request.tools,
tool_choice=request.tool_choice,
)
@nvtx_range_debug("chat_harmony_post_processor")
def chat_harmony_post_processor(
rsp: GenerationResult,
args: ChatCompletionPostprocArgs) -> ChatCompletionResponse:
response = handle_non_streaming_response(
tools=args.tools,
tool_choice=args.tool_choice,
outputs=rsp.outputs,
model=args.model,
num_prompt_tokens=args.num_prompt_tokens,
)
return response
@nvtx_range_debug("chat_harmony_streaming_post_processor")
def chat_harmony_streaming_post_processor(
rsp: GenerationResult, args: ChatCompletionPostprocArgs) -> List[str]:
response = handle_streaming_response(
tools=args.tools,
tool_choice=args.tool_choice,
outputs=rsp.outputs,
model=args.model,
request_id=args.request_id,
done=rsp._done,
num_prompt_tokens=args.num_prompt_tokens,
)
return response

View File

@ -8,6 +8,9 @@ meta-llama/Llama-3.3-70B-Instruct:
accuracy: 48.03
- quant_algo: FP8
accuracy: 48.03
- quant_algo: FP8
kv_cache_quant_algo: FP8
accuracy: 48.03
deepseek-ai/DeepSeek-R1:
- quant_algo: NVFP4
accuracy: 70.45

View File

@ -602,9 +602,9 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
@parametrize_with_ids("gen_tp", [1, 2])
@pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset):
if ctx_pp * gen_tp * 2 > get_device_count():
if ctx_pp + gen_tp > get_device_count():
pytest.skip(
f"Not enough devices for ctx_pp={ctx_pp}*gen_tp={gen_tp} test")
f"Not enough devices for ctx_pp={ctx_pp}+gen_tp={gen_tp} test")
return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1,
gen_tp, 1, 1, [get_accuracy_task(testset)])

View File

@ -1565,6 +1565,11 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
@parametrize_with_ids("moe_backend", ["CUTLASS", "TRTLLM"])
def test_nvfp4(self, fp8kv, attention_dp, cuda_graph, overlap_scheduler,
torch_compile, mtp_nextn, moe_backend):
if moe_backend == "TRTLLM" and (get_sm_version() == 120
or get_sm_version() == 121):
pytest.skip(
"MOE TRTLLM backend does not support SM version 120 or 121")
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
@ -1613,8 +1618,10 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
torch_compile, mtp_nextn, moe_backend):
if torch_compile and pp_size > 1:
pytest.skip("PP with torch.compile is not supported yet.")
if moe_backend == "TRTLLM" and get_sm_version() == 120:
pytest.skip("MOE TRTLLM backend does not support SM version 120")
if moe_backend == "TRTLLM" and (get_sm_version() == 120
or get_sm_version() == 121):
pytest.skip(
"MOE TRTLLM backend does not support SM version 120 or 121")
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
# Picewise Cuda Graph cannot be enabled for nvfp4 attention dp.
torch_compile_config = TorchCompileConfig(
@ -1885,6 +1892,11 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
attention_dp, cuda_graph, overlap_scheduler,
max_batch_size, moe_backend):
if moe_backend == "TRTLLM" and (get_sm_version() == 120
or get_sm_version() == 121):
pytest.skip(
"MOE TRTLLM backend does not support SM version 120 or 121")
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.70)
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
@ -2509,6 +2521,11 @@ class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
torch_compile,
):
if moe_backend == "TRTLLM" and (get_sm_version() == 120
or get_sm_version() == 121):
pytest.skip(
"MOE TRTLLM backend does not support SM version 120 or 121")
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
enable_piecewise_cuda_graph=cuda_graph,
@ -2700,6 +2717,11 @@ class TestQwen3_235B_A22B(LlmapiAccuracyTestHarness):
def test_nvfp4(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
overlap_scheduler, moe_backend, eagle3):
if moe_backend == "TRTLLM" and (get_sm_version() == 120
or get_sm_version() == 121):
pytest.skip(
"MOE TRTLLM backend does not support SM version 120 or 121")
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
@ -2779,7 +2801,7 @@ class TestPhi4MiniInstruct(LlmapiAccuracyTestHarness):
MODEL_PATH = f"{llm_models_root()}/Phi-4-mini-instruct"
def test_auto_dtype(self):
with LLM(self.MODEL_PATH) as llm:
with LLM(self.MODEL_PATH, max_seq_len=4096) as llm:
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
@ -3046,10 +3068,8 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
class TestEXAONE4(LlmapiAccuracyTestHarness):
MODEL_NAME = "LGAI-EXAONE/EXAONE-4.0-32B"
kv_cache_config = KvCacheConfig(
enable_block_reuse=False,
enable_partial_reuse=False,
max_attention_window=[4096, 4096, 4096, 131072])
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
enable_partial_reuse=False)
def test_auto_dtype(self):
model_path = f"{llm_models_root()}/EXAONE-4.0-32B"

View File

@ -8,7 +8,7 @@ disable_overlap_scheduler: True
context_servers:
num_instances: 1
max_num_tokens: 512
max_batch_size: 256
max_batch_size: 64
cache_transceiver_config:
backend: DEFAULT
urls:
@ -16,7 +16,7 @@ context_servers:
generation_servers:
num_instances: 1
max_num_tokens: 256
max_batch_size: 128
max_batch_size: 32
cache_transceiver_config:
backend: DEFAULT
urls:

View File

@ -1302,7 +1302,7 @@ def get_config_for_benchmark(model_root, backend):
"num_instances": 1,
"max_batch_size": 2,
"max_num_tokens": 384,
"max_seq_len": 320,
"max_seq_len": 384,
"tensor_parallel_size": 1,
"pipeline_parallel_size": 1,
"disable_overlap_scheduler": True,
@ -1318,7 +1318,7 @@ def get_config_for_benchmark(model_root, backend):
"pipeline_parallel_size": 1,
"max_batch_size": 2,
"max_num_tokens": 384,
"max_seq_len": 320,
"max_seq_len": 384,
"cache_transceiver_config": {
"backend": backend,
"max_tokens_in_buffer": 512,

View File

@ -36,6 +36,42 @@ MODEL_PATHS = {
}
def mpi_publish_name():
port_name = None
try:
port_name = MPI.Open_port()
MPI.Publish_name('my_port', port_name)
except MPI.Exception as e:
print(f"Error publishing port name: {e}")
raise e
except Exception as e:
print(f"Unexpected error publishing port name: {e}")
raise e
return port_name
def mpi_initialize_intercomm(port_name):
intercomm = None
try:
intercomm = MPI.COMM_SELF.Accept(port_name)
except MPI.Exception as e:
print(f"Error accepting intercomm: {e}", flush=True)
raise
except Exception as e:
print(f"Unexpected error accepting intercomm: {e}", flush=True)
raise
return intercomm
def mpi_send_termination_request(intercomm):
if intercomm is not None:
# Send termination requests
intercomm.send(None, dest=0, tag=MPI_REQUEST)
intercomm.send(None, dest=1, tag=MPI_REQUEST)
print("Sent termination requests to the workers.")
def model_path(model_name):
llm_models_root = os.environ["LLM_MODELS_ROOT"]
for name, path in MODEL_PATHS.items():
@ -48,8 +84,15 @@ async def run_worker(kv_cache_config, cache_transceiver_config, pytorch_config,
model_name, rank):
assert isinstance(pytorch_config, dict)
print(f"Running worker {rank}")
port_name = MPI.Lookup_name('my_port')
intercomm = MPI.COMM_WORLD.Connect(port_name)
try:
port_name = MPI.Lookup_name('my_port')
intercomm = MPI.COMM_WORLD.Connect(port_name)
except MPI.Exception as e:
print(f"Error looking up or connecting to port: {e}")
raise e
except Exception as e:
print(f"Unexpected error looking up or connecting to port: {e}")
raise e
session = MPI.COMM_WORLD.Split(color=rank, key=0)
set_mpi_comm(session)
@ -139,8 +182,7 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt,
zip(kv_cache_configs, cache_transceiver_configs, worker_pytorch_configs,
model_names, ranks))
port_name = MPI.Open_port()
MPI.Publish_name('my_port', port_name)
port_name = mpi_publish_name()
with MPIPoolExecutor(max_workers=2, env={"UCX_TLS": "^ib"}) as executor:
futures = []
@ -152,9 +194,10 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt,
print(f"Error in worker {worker_arg}: {e}")
raise e
intercomm = None
try:
print("Launched all the workers.")
intercomm = MPI.COMM_SELF.Accept(port_name)
print("Launched all the workers.", flush=True)
intercomm = mpi_initialize_intercomm(port_name)
for _ in range(2):
intercomm.recv(tag=MPI_READY)
@ -187,14 +230,15 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt,
output = responses[0]
assert output[0].text == expected_output
assert output[0].token_ids == expected_output_ids
except Exception as e:
print(f"Exception encountered: {e}", flush=True)
raise e
finally:
# Send termination requests
intercomm.send(None, dest=0, tag=MPI_REQUEST)
intercomm.send(None, dest=1, tag=MPI_REQUEST)
print("Sent termination requests to the workers.")
print("Sending termination request", flush=True)
mpi_send_termination_request(intercomm)
# Wait for all futures to complete
print("Waiting for all workers to terminate. ", flush=True)
for future in futures:
future.result()
print("All workers terminated.")
@ -282,8 +326,7 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
zip(kv_cache_configs, cache_transceiver_configs, worker_pytorch_configs,
model_names, ranks))
port_name = MPI.Open_port()
MPI.Publish_name('my_port', port_name)
port_name = mpi_publish_name()
prompt = "European Union is a political and economic union of 27 countries. The European Union is headquartered in Brussels, Belgium. The first president of the European Union was Jean-Claude Juncker. The current president is Ursula von der Leyen. The European Union is a major economic and political entity."
@ -297,9 +340,10 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
print(f"Error in worker {worker_arg}: {e}")
raise e
intercomm = None
try:
print("Launched all the workers.")
intercomm = MPI.COMM_SELF.Accept(port_name)
intercomm = mpi_initialize_intercomm(port_name)
for _ in range(2):
intercomm.recv(tag=MPI_READY)
@ -334,11 +378,11 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
intercomm.send(requests, dest=1, tag=MPI_REQUEST)
output = intercomm.recv(source=1, tag=MPI_RESULT)
except MPI.Exception as e:
print(f"MPI Error")
raise e
finally:
# Send termination requests
intercomm.send(None, dest=0, tag=MPI_REQUEST)
intercomm.send(None, dest=1, tag=MPI_REQUEST)
print("Sent termination requests to the workers.")
mpi_send_termination_request(intercomm)
# Wait for all futures to complete
for future in futures:
@ -387,8 +431,7 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
zip(kv_cache_configs, cache_transceiver_configs, worker_pytorch_configs,
model_names, ranks))
port_name = MPI.Open_port()
MPI.Publish_name('my_port', port_name)
port_name = mpi_publish_name()
prompt = "What is the capital of Germany?"
@ -402,9 +445,10 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
print(f"Error in worker {worker_arg}: {e}")
raise e
intercomm = None
try:
print("Launched all the workers.")
intercomm = MPI.COMM_SELF.Accept(port_name)
intercomm = mpi_initialize_intercomm(port_name)
for _ in range(2):
intercomm.recv(tag=MPI_READY)
@ -438,11 +482,11 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
intercomm.send(requests, dest=1, tag=MPI_REQUEST)
output = intercomm.recv(source=1, tag=MPI_RESULT)
except MPI.Exception as e:
print(f"MPI Error")
raise e
finally:
# Send termination requests
intercomm.send(None, dest=0, tag=MPI_REQUEST)
intercomm.send(None, dest=1, tag=MPI_REQUEST)
print("Sent termination requests to the workers.")
mpi_send_termination_request(intercomm)
# Wait for all futures to complete
for future in futures:

View File

@ -23,7 +23,7 @@ class PythonVenvRunnerImpl(PythonRunnerInterface):
venv_dir (str): Path to the virtualenv root directory, or None if this is
an externally-built virtualenv
venv_bin (str): Path to the Python executable to use when running tests
workspace (str): Path to the TURTLE workspace
workspace (str): Path to the test workspace
"""
def __init__(self, pip_opts, venv_dir, venv_bin, workspace):

View File

@ -189,7 +189,7 @@ class GPUClockLock:
# Initialize thread
self._thread = threading.Thread(
target=self._monitoring_thread,
name="TURTLE - GPUMonitor",
name="LLM Test - GPUMonitor",
kwargs={"interval_ms": self._interval_ms})
self._thread.daemon = True
self._thread.start()

View File

@ -181,10 +181,19 @@ def get_model_yaml_config(model_label: str,
# lora-specific change for pytorch
if 'pytorch' in model_label and 'loras' in model_label:
# Derive the requested number of adapters from model_label (segment like "loras:X")
lora_count = 1
for part in model_label.split('-'):
if part.startswith('loras:'):
lora_count = max(1, int(part.split(':', 1)[1]))
break
lora_config = {
'lora_config': {
'lora_dir': lora_dirs if lora_dirs is not None else [],
'max_lora_rank': 64
'max_lora_rank': 64,
'max_loras': lora_count,
'max_cpu_loras': lora_count,
}
}
if 'phi_4_multimodal_instruct' in model_label:
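A quick worked example of the "loras:X" parsing above, under hypothetical model labels (the label strings here are made up for illustration only):

def parse_lora_count(model_label: str) -> int:
    # Mirrors the loop above: pick up the "loras:X" segment, default to 1.
    lora_count = 1
    for part in model_label.split('-'):
        if part.startswith('loras:'):
            lora_count = max(1, int(part.split(':', 1)[1]))
            break
    return lora_count

assert parse_lora_count("llama-v3-8b-pytorch-loras:8-bfloat16") == 8
assert parse_lora_count("llama-v3-8b-pytorch-bfloat16") == 1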

View File

@ -1755,7 +1755,7 @@ def parse_output(text):
for item in text_lists:
item = item.replace(os.linesep, "")
while True:
match = re.search(r"(Generated text: \'(.*?)\')", item,
match = re.search(r'Generated text: ([\'"])(.*?)\1', item,
re.MULTILINE)
if match is None:
break
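To see why the pattern was changed: the back-reference \1 lets it capture generated text wrapped in either quote style. A small self-contained check (the sample log lines are made up):

import re

pattern = r'Generated text: ([\'"])(.*?)\1'
assert re.search(pattern, "Generated text: 'hello world'").group(2) == "hello world"
assert re.search(pattern, 'Generated text: "it\'s fine"').group(2) == "it's fine"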
@ -2299,7 +2299,8 @@ def test_ptp_quickstart_advanced_mixed_precision(llm_root, llm_venv):
marks=pytest.mark.skip_less_device_memory(80000)),
pytest.param("gemma-3-27b-it",
"gemma/gemma-3-27b-it",
marks=pytest.mark.skip_less_device_memory(80000)),
marks=(skip_post_blackwell,
pytest.mark.skip_less_device_memory(80000))),
])
def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
modality, use_cuda_graph):
@ -2407,9 +2408,9 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
},
"gemma-3-27b-it": {
"image": [
["dramatic", "turbulent", "waves", "ocean", "overcast"],
["half", "dome", "yosemite", "landmark", "rounded"],
["flowing", "traffic", "vehicles", "road", "Changi"],
["natural", "turbulent", "dramatic", "scene", "wave"],
["image", "famous", "rock", "granite", "landmark"],
["traffic", "moderate", "heavy", "flowing", "cars"],
],
},
}
@ -2600,9 +2601,10 @@ def test_ptp_quickstart_multimodal_phi4mm(llm_root, llm_venv, modality):
@pytest.mark.skip_less_device(2)
@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.parametrize("model_name,model_path", [
("gemma-3-27b-it", "gemma/gemma-3-27b-it"),
("mistral-small-3.1-24b-instruct", "Mistral-Small-3.1-24B-Instruct-2503"),
("Phi-4-multimodal-instruct", "multimodals/Phi-4-multimodal-instruct"),
pytest.param(
"gemma-3-27b-it", "gemma/gemma-3-27b-it", marks=skip_post_blackwell),
])
def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name,
model_path):
@ -2645,8 +2647,8 @@ def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name,
},
"Phi-4-multimodal-instruct": {
"image": [
["image", "depicts", "mountain", "half", "rock"],
["road", "car", "lane", "traffic", "bus"],
["object", "mountain", "weather", "clear", "clouds"],
["traffic", "road", "vehicles", "cars", "bus"],
],
},
}
@ -2674,6 +2676,8 @@ def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name,
cmd.append("--image_format=pil")
cmd.append("--attention_backend=FLASHINFER")
cmd.append("--disable_kv_cache_reuse")
cmd.append("--kv_cache_fraction=0.5")
cmd.append("--max_seq_len=1024")
elif model_name == "Phi-4-multimodal-instruct":
# Set max_seq_len to 4096 to use short rope factor.
cmd.append("--max_seq_len=4096")
@ -2702,9 +2706,10 @@ def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name,
@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.parametrize("model_name,model_path", [
("gemma-3-27b-it", "gemma/gemma-3-27b-it"),
("mistral-small-3.1-24b-instruct", "Mistral-Small-3.1-24B-Instruct-2503"),
("Phi-4-multimodal-instruct", "multimodals/Phi-4-multimodal-instruct"),
pytest.param(
"gemma-3-27b-it", "gemma/gemma-3-27b-it", marks=skip_post_blackwell),
])
def test_ptp_quickstart_multimodal_multiturn(llm_root, llm_venv, model_name,
model_path):
@ -2770,6 +2775,9 @@ def test_ptp_quickstart_multimodal_multiturn(llm_root, llm_venv, model_name,
cmd.append("--image_format=pil")
cmd.append("--attention_backend=FLASHINFER")
cmd.append("--disable_kv_cache_reuse")
cmd.append("--kv_cache_fraction=0.5")
cmd.append("--max_seq_len=1024")
elif model_name == "Phi-4-multimodal-instruct":
# Set max_seq_len to 4096 to use short rope factor.
cmd.append("--max_seq_len=4096")

View File

@ -69,11 +69,6 @@ def test_case_name(request):
return request.node.nodeid
@pytest.fixture(scope="session")
def output_dir(request):
return request.config._trt_config["output_dir"]
@pytest.fixture(scope="session")
def llm_backend_root():
llm_root = os.environ.get("LLM_ROOT", find_repo_root())
@ -655,10 +650,10 @@ def install_root_requirements(llm_backend_root):
@pytest.fixture(scope="session")
def output_dir(request):
if USE_TURTLE:
return request.config._trt_config["output_dir"]
else:
return request.config.getoption("--output-dir")
output = request.config.getoption("--output-dir")
if output:
os.makedirs(str(output), exist_ok=True)
return output
def deselect_by_regex(regexp, items, test_prefix, config):

View File

@ -39,6 +39,13 @@ l0_a10:
- test_e2e.py::test_openai_chat_example[pytorch] TIMEOUT (90)
- test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-]
- test_e2e.py::test_trtllm_bench_invalid_token_pytorch[TinyLlama-1.1B-Chat-v1.0-TinyLlama-1.1B-Chat-v1.0]
# llmapi
- unittest/llmapi/test_llm_utils.py
- unittest/llmapi/test_gc_utils.py
- unittest/llmapi/test_reasoning_parser.py
- unittest/llmapi/test_serialization.py
- unittest/llmapi/test_utils.py
- unittest/llmapi/test_llm_args.py
- condition:
ranges:
system_gpu_count:
@ -114,12 +121,6 @@ l0_a10:
- unittest/bindings
- unittest/test_model_runner_cpp.py
- unittest/llmapi/test_build_cache.py
- unittest/llmapi/test_llm_utils.py
- unittest/llmapi/test_gc_utils.py
- unittest/llmapi/test_reasoning_parser.py
- unittest/llmapi/test_serialization.py
- unittest/llmapi/test_utils.py
- unittest/llmapi/test_llm_args.py
- accuracy/test_cli_flow.py::TestGpt2::test_auto_dtype # 1 min
- accuracy/test_cli_flow.py::TestGpt2::test_beam_search # 1 min
- accuracy/test_cli_flow.py::TestGpt2::test_beam_search_large # 6 mins

View File

@ -106,7 +106,7 @@ l0_h100:
- test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]
- test_e2e.py::test_openai_chat_harmony
- test_e2e.py::test_openai_responses
- test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
- test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True] TIMEOUT (90)
# ------------- AutoDeploy tests ---------------
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype
- condition:
@ -227,7 +227,7 @@ l0_h100:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=0]
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_fp8_prequantized
- accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_fp8_prequantized
- accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype
- accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype TIMEOUT (90)
- accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype
- accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_fp8
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=False]

View File

@ -252,11 +252,9 @@ examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-re
examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:1-pp:1-float16-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity] SKIP (https://nvbugs/5421989)
examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:1-pp:1-float16-RobertaForSequenceClassification-bert/twitter-roberta-base-emotion] SKIP (https://nvbugs/5421989)
examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5431132)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5434320)
accuracy/test_llm_api.py::TestLlama3_2_1B::test_int4_awq_int8_kv_cache SKIP (https://nvbugs/5433541)
accuracy/test_llm_api.py::TestLlama3_2_1B::test_fp8_pp2 SKIP (https://nvbugs/5433541)
accuracy/test_llm_api_pytorch.py::TestPhi4MiniInstruct::test_auto_dtype SKIP (https://nvbugs/5433545)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5434320)
examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-9b-it] SKIP (https://nvbugs/5434451)
examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-27b-it] SKIP (https://nvbugs/5434451)
examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-3-1b-it] SKIP (https://nvbugs/5434451)
@ -278,7 +276,6 @@ triton_server/test_triton.py::test_gpt_ib_ptuning[gpt-ib-ptuning] SKIP (https://
triton_server/test_triton.py::test_mistral_ib_mm[mistral-ib-mm] SKIP (https://nvbugs/5371343)
triton_server/test_triton.py::test_t5_ib[t5-ib] SKIP (https://nvbugs/5456482)
triton_server/test_triton_llm.py::test_gpt_speculative_decoding_bls[False-False-1---False-True-True-0-128-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap-0.2-guaranteed_no_evict---1-1-1-False-ensemble] SKIP (https://nvbugs/5456485)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5434320)
accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_eagle3] SKIP (https://nvbugs/5437384)
llmapi/test_llm_examples.py::test_llmapi_speculative_decoding_mtp SKIP (https://nvbugs/5461796)
accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_gather_generation_logits_cuda_graph SKIP (https://nvbugs/5365525)
@ -287,6 +284,12 @@ examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-3-small-128k-instruct] SKI
examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-3.5-mini-instruct] SKIP (https://nvbugs/5465143)
examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-4-mini-instruct] SKIP (https://nvbugs/5465143)
examples/test_llama.py::test_llm_llama_v1_2gpu_summary[llama-7b-nb:4-enable_auto_parallel] SKIP (https://nvbugs/5465173)
test_e2e.py::test_ptp_quickstart_multimodal[llava-v1.6-mistral-7b-llava-v1.6-mistral-7b-hf-image-False] SKIP (https://nvbugs/5444095)
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen1.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_vl_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5359696)
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl SKIP (https://nvbugs/5413362)
disaggregated/test_disaggregated.py::test_disaggregated_diff_max_tokens[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5451272)
disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_single_gpu_mtp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5465642)
examples/test_multimodal.py::test_llm_multimodal_general[Mistral-Small-3.1-24B-Instruct-2503-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5431146)
@ -340,17 +343,6 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_ep
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False] SKIP (https://nvbugs/5488118)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True] SKIP (https://nvbugs/5488118)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5488118)
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen1.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_vl_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5359696)
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl SKIP (https://nvbugs/5413362)
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8] SKIP (https://nvbugs/5455140)
full:L40S/accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=False] SKIP (https://nvbugs/5347051)
full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] SKIP (https://nvbugs/5471106)
full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp2pp2] SKIP (https://nvbugs/5471108)
test_e2e.py::test_multi_nodes_eval[llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8-tp8pp2-mmlu] SKIP (https://nvbugs/5473781)
accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl SKIP (https://nvbugs/5413362)
stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-GUARANTEED_NO_EVICT-pytorch-stress-test] SKIP (https://nvbugs/5474169)
test_e2e.py::test_trtllm_bench_iteration_log[TRT-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B] SKIP (https://nvbugs/5448523)
cpp/test_unit_tests.py::test_unit_tests[kernels-80] SKIP (https://nvbugs/5504078)

View File

@ -17,6 +17,7 @@ def similar(a, b, threshold=0.9):
return SequenceMatcher(None, a, b).ratio() >= threshold
@pytest.mark.skip(reason="https://nvbugs/5470782")
@pytest.mark.parametrize("model_name", ["DeepSeek-V3-Lite"],
ids=["deepseekv3_lite"])
@pytest.mark.parametrize("backend", ["TRTLLM"], ids=["trtllm"])

View File

@ -0,0 +1,205 @@
import pytest
import torch
from tensorrt_llm._torch.models.modeling_multimodal_utils import (
filter_mm_token_from_input_ids, fuse_input_embeds)
from tensorrt_llm._torch.modules.embedding import Embedding
def make_embedding(num_embeddings: int = 100,
hidden_size: int = 16,
device: str = "cpu") -> Embedding:
torch.manual_seed(0)
emb = Embedding(num_embeddings=num_embeddings, embedding_dim=hidden_size)
emb.weight.data.normal_(mean=0.0, std=0.02)
return emb.to(device)
@pytest.mark.parametrize("device", ["cpu"] +
(["cuda"] if torch.cuda.is_available() else []))
def test_filter_mm_token_from_input_ids_oov(device):
vocab_size = 10
# input_ids contains text (< vocab) and OOV mm tokens (>= vocab)
input_ids = torch.tensor([1, 2, 11, 3, 12, 9, 15],
dtype=torch.long,
device=device)
text_idx, mm_idx = filter_mm_token_from_input_ids(input_ids,
vocab_size=vocab_size,
mm_token_ids=None)
torch.testing.assert_close(text_idx.cpu(),
torch.tensor([0, 1, 3, 5], dtype=torch.long))
torch.testing.assert_close(mm_idx.cpu(),
torch.tensor([2, 4, 6], dtype=torch.long))
@pytest.mark.parametrize("device", ["cpu"] +
(["cuda"] if torch.cuda.is_available() else []))
def test_filter_mm_token_from_input_ids_explicit_ids(device):
vocab_size = 100
# All ids are < vocab; mm tokens are explicitly specified
input_ids = torch.tensor([1, 2, 55, 3, 77, 9, 88],
dtype=torch.long,
device=device)
mm_token_ids = torch.tensor([55, 77, 88], dtype=torch.long)
text_idx, mm_idx = filter_mm_token_from_input_ids(input_ids,
vocab_size=vocab_size,
mm_token_ids=mm_token_ids)
torch.testing.assert_close(text_idx.cpu(),
torch.tensor([0, 1, 3, 5], dtype=torch.long))
torch.testing.assert_close(mm_idx.cpu(),
torch.tensor([2, 4, 6], dtype=torch.long))
# Even with some ids > vocab, mm indices should still only match the given mm_token_ids
input_ids = torch.tensor([1, 2, 55, 3, 77, 9, 88, 101, 102, 103],
dtype=torch.long,
device=device)
_, mm_idx = filter_mm_token_from_input_ids(input_ids,
vocab_size=vocab_size,
mm_token_ids=mm_token_ids)
torch.testing.assert_close(mm_idx.cpu(),
torch.tensor([2, 4, 6], dtype=torch.long))
@pytest.mark.parametrize("device", ["cpu"] +
(["cuda"] if torch.cuda.is_available() else []))
def test_fuse_input_embeds_empty_mm_returns_ids(device):
emb = make_embedding(num_embeddings=20, hidden_size=8, device=device)
input_ids = torch.tensor([1, 2, 3, 4], dtype=torch.long, device=device)
out_ids, out_embeds = fuse_input_embeds(emb,
input_ids,
mm_embeds=[],
mm_token_ids=None)
# No mm embeddings => passthrough ids, no embeds fused
assert out_embeds is None
torch.testing.assert_close(out_ids, input_ids)
@pytest.mark.parametrize("device", ["cpu"] +
(["cuda"] if torch.cuda.is_available() else []))
def test_fuse_input_embeds_mismatch_raises(device):
emb = make_embedding(num_embeddings=50, hidden_size=8, device=device)
# Mix text (< vocab) and mm (>= vocab) tokens. Here vocab_size == 50
input_ids = torch.tensor([1, 51, 2, 52, 3, 53],
dtype=torch.long,
device=device)
# Identify indices first to drive the lower-level CUDA fuse directly for mismatch
text_idx, mm_idx = filter_mm_token_from_input_ids(
input_ids, vocab_size=emb.num_embeddings)
# Provide the wrong number of mm embeddings (e.g., one short)
hidden = 8
true_mm_count = mm_idx.shape[0]
wrong_mm = torch.randn(true_mm_count - 1, hidden, device=device)
with pytest.raises(ValueError, match="Multimodal token count mismatch"):
fuse_input_embeds(emb,
input_ids, [wrong_mm],
mm_token_ids=None,
text_token_indices=text_idx,
mm_token_indices=mm_idx)
@pytest.mark.parametrize("device", ["cpu"] +
(["cuda"] if torch.cuda.is_available() else []))
def test_fuse_input_embeds_success_oov_path(device):
hidden = 8
emb = make_embedding(num_embeddings=40, hidden_size=hidden, device=device)
# input ids: mix of text (<40) and mm (>=40)
input_ids = torch.tensor([0, 1, 41, 2, 42, 3, 43, 4],
dtype=torch.long,
device=device)
# Build mm embeddings to match number of OOV positions
text_idx, mm_idx = filter_mm_token_from_input_ids(
input_ids, vocab_size=emb.num_embeddings)
mm_emb = torch.randn(mm_idx.shape[0], hidden, device=device)
# kwargs path to produce fused embeddings
out_ids, out_embeds = fuse_input_embeds(emb,
input_ids,
mm_embeds=[mm_emb],
mm_token_ids=None,
text_token_indices=text_idx,
mm_token_indices=mm_idx)
# integrated filtering path to produce fused embeddings (not kwargs path)
out_ids_v2, out_embeds_v2 = fuse_input_embeds(emb,
input_ids,
mm_embeds=[mm_emb],
mm_token_ids=None)
assert out_ids is None
assert out_embeds is not None
assert out_embeds.shape == (input_ids.numel(), hidden)
# Validate that text positions equal embedding lookup, and mm positions equal provided mm_emb
text_idx, mm_idx2 = filter_mm_token_from_input_ids(
input_ids, vocab_size=emb.num_embeddings)
torch.testing.assert_close(mm_idx2, mm_idx)
ref_text = emb(input_ids[text_idx])
torch.testing.assert_close(out_embeds[text_idx], ref_text)
torch.testing.assert_close(
out_embeds[mm_idx],
mm_emb.to(dtype=out_embeds.dtype, device=out_embeds.device))
torch.testing.assert_close(out_embeds_v2, out_embeds)
torch.testing.assert_close(out_ids_v2, out_ids)
@pytest.mark.parametrize("device", ["cpu"] +
(["cuda"] if torch.cuda.is_available() else []))
def test_fuse_input_embeds_kwargs_precedence_over_sentinel_and_ids(device):
"""
Ensure that when kwargs provide precomputed indices, they take precedence
over both OOV-sentinel filtering and explicit mm_token_ids.
"""
hidden = 8
vocab_size = 40
emb = make_embedding(num_embeddings=vocab_size,
hidden_size=hidden,
device=device)
# Use vocab_size+1 as OOV sentinel
oov_sentinel = vocab_size + 1
input_ids = torch.tensor([0, oov_sentinel, 1, oov_sentinel, 2],
dtype=torch.long,
device=device)
# Precompute correct indices (kwargs path)
text_idx, mm_idx = filter_mm_token_from_input_ids(input_ids,
vocab_size=vocab_size,
mm_token_ids=None)
mm_emb = torch.randn(mm_idx.shape[0], hidden, device=device)
# Provide a deliberately incorrect mm_token_ids to ensure it is ignored
bad_mm_token_ids = torch.tensor(
[0], dtype=torch.long,
device=device) # would misclassify index 0 as mm if used
out_ids, out_embeds = fuse_input_embeds(
emb,
input_ids,
mm_embeds=[mm_emb],
mm_token_ids=bad_mm_token_ids,  # should be ignored because indices are provided
text_token_indices=text_idx,
mm_token_indices=mm_idx,
)
# Validate outputs
assert out_ids is None
assert out_embeds is not None
assert out_embeds.shape == (input_ids.numel(), hidden)
ref_text = emb(input_ids[text_idx])
torch.testing.assert_close(out_embeds[text_idx], ref_text)
torch.testing.assert_close(
out_embeds[mm_idx],
mm_emb.to(dtype=out_embeds.dtype, device=out_embeds.device),
)
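These tests lean on one convention: when mm_token_ids is not supplied, any id greater than or equal to the vocabulary size is treated as a multimodal placeholder. A minimal sketch of that index split in plain torch (not the library helper itself), matching the expectations asserted above:

import torch

def split_text_and_mm(input_ids: torch.Tensor, vocab_size: int):
    # Ids >= vocab_size are multimodal placeholders; the rest are text.
    mm_mask = input_ids >= vocab_size
    text_idx = torch.nonzero(~mm_mask, as_tuple=True)[0]
    mm_idx = torch.nonzero(mm_mask, as_tuple=True)[0]
    return text_idx, mm_idx

text_idx, mm_idx = split_text_and_mm(torch.tensor([1, 2, 11, 3, 12]), vocab_size=10)
# text_idx -> tensor([0, 1, 3]); mm_idx -> tensor([2, 4])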

View File

@ -16,7 +16,6 @@ from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
@pytest.mark.skip(reason="https://nvbugs/5461761")
@pytest.mark.parametrize(
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill",
[
@ -27,7 +26,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
[False, "TRTLLM", False, True, True, False],
[True, "TRTLLM", False, True, True, False],
[True, "TRTLLM", True, False, True, True],
[True, "TRTLLM", True, False, False, True],
# TODO: nvbugs/5461761
# [True, "TRTLLM", True, False, False, True],
])
@pytest.mark.high_cuda_memory
def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,

View File

@ -0,0 +1,189 @@
import random
import pytest
import torch
from tensorrt_llm._torch.pyexecutor.llm_request import convert_wordlist
from tensorrt_llm._torch.pyexecutor.sampler import (BEAM_0, LlmRequest,
TorchSampler)
from tensorrt_llm._torch.pyexecutor.sampler_utils import produce_stop_words
from tensorrt_llm.bindings import SamplingConfig
from tensorrt_llm.bindings.executor import FinishReason
MAX_NUM_SEQUENCES = 128
NOT_FINISHED = FinishReason.NOT_FINISHED
STOP_WORDS = FinishReason.STOP_WORDS
END_ID = FinishReason.END_ID
LENGTH = FinishReason.LENGTH
class RequestCase:
MAX_NEW_TOKENS = 10
seq_slots = random.sample(range(MAX_NUM_SEQUENCES), MAX_NUM_SEQUENCES)
def __init__(self,
*,
prompt: list[int],
new_tokens: list[int],
finish_reasons: list[FinishReason],
max_new_tokens: int = MAX_NEW_TOKENS,
end_id: int | None = None,
stop_words_list: list[list[int]] | None = None):
seq_slot = self.seq_slots.pop() # random seq slot in MAX_NUM_SEQUENCES
self.prompt = prompt
self.request = LlmRequest(
request_id=seq_slot,
seq_slot=seq_slot,
input_tokens=prompt,
max_new_tokens=max_new_tokens,
stop_words_list=convert_wordlist(stop_words_list)
if stop_words_list is not None else None,
end_id=end_id,
sampling_config=SamplingConfig(),
is_streaming=False,
)
assert len(new_tokens) == len(finish_reasons)
self.new_tokens = new_tokens
self.finish_reasons = finish_reasons
def __repr__(self):
return f"RequestCase({self.prompt=}, {self.new_tokens=}, {self.finish_reasons=}, {self.request.max_new_tokens=}, {self.request.end_id=}, {self.request.stop_words_list=})"
@staticmethod
def setup(requests: list["RequestCase"]):
max_tokens = set(len(req.new_tokens) for req in requests)
assert len(max_tokens) == 1
max_draft_len = max_tokens.pop() - 1
sampler_args = TorchSampler.Args(
max_seq_len=20,
max_draft_len=max_draft_len,
# Fill with many more max requests than below, so we can test that write_finish_reasons uses seq_slots correctly
max_num_sequences=MAX_NUM_SEQUENCES,
max_beam_width=1,
enable_mixed_sampler=False)
sampler = TorchSampler(args=sampler_args)
# fill with garbage value so we can observe that finish reasons are filled with NOT_FINISHED before we write to them.
sampler.store.finish_reasons.fill_(205)
seq_slots = torch.tensor([req.request.py_seq_slot for req in requests],
device="cuda",
dtype=torch.int64)
new_tokens = torch.tensor([req.new_tokens for req in requests],
dtype=torch.int32,
device="cuda").T
sampler.store.new_tokens[:, seq_slots, BEAM_0] = new_tokens
def run():
sampler._write_finish_reasons(
[req.request for req in requests],
finish_reasons=sampler.store.finish_reasons,
new_tokens=sampler.store.new_tokens,
seq_slots=seq_slots)
reasons = sampler.store.finish_reasons[:, seq_slots,
BEAM_0].T.tolist()
for actual, request in zip(reasons, requests, strict=True):
expected = request.finish_reasons
msg = f"actual={[FinishReason(reason) for reason in actual]} != expected={expected}\nFor {request}"
assert actual == [reason.value for reason in expected], msg
return run, sampler
def test_write_finish_reasons():
"""We don't really care about the finish reason past the first infraction, because we're not going to use it, although in some instance it is written anyway."""
run, _ = RequestCase.setup([
RequestCase(
prompt=[13, 14],
new_tokens=[60, 61, 62],
# We pre-fill the finish reasons with NOT_FINISHED.
finish_reasons=[NOT_FINISHED, NOT_FINISHED, NOT_FINISHED],
),
RequestCase(
prompt=[7, 8, 6],
stop_words_list=[[12, 13]],
new_tokens=[12, 13, 60],
finish_reasons=[NOT_FINISHED, STOP_WORDS, NOT_FINISHED],
),
RequestCase(
prompt=[1, 2, 3, 4],
end_id=99,
new_tokens=[55, 99, 58],
finish_reasons=[NOT_FINISHED, END_ID, NOT_FINISHED],
),
RequestCase(
prompt=[4, 5, 6],
max_new_tokens=2,
new_tokens=[56, 57, 59],
# The LENGTH check happens to not have an early exit
finish_reasons=[NOT_FINISHED, LENGTH, LENGTH]),
RequestCase(
prompt=[1, 12],
stop_words_list=[[12, 13], [14, 15]],
new_tokens=[13, 14, 15],
# We have an early exit specifically for stop words
finish_reasons=[STOP_WORDS, NOT_FINISHED, NOT_FINISHED],
),
RequestCase(
prompt=[1],
max_new_tokens=2,
end_id=99,
stop_words_list=[[1, 12]],
new_tokens=[12, 99, 63],
# Different infractions are written to different places as we don't have an early exit between infractions
finish_reasons=[STOP_WORDS, END_ID, LENGTH],
),
RequestCase(
prompt=[1, 12, 56, 67, 68, 234, 678],
stop_words_list=[[12, 56, 67, 68, 234, 678, 129, 182]],
new_tokens=[129, 182, 600],
# Notice the offending stop sequence is concatenated, as we look back
finish_reasons=[NOT_FINISHED, STOP_WORDS, NOT_FINISHED],
),
RequestCase(
prompt=[1, 12],
end_id=99,
max_new_tokens=1,
stop_words_list=[[1, 12, 99]],
new_tokens=[99, 100, 101],
# The latest infraction check overrides the earlier infraction checks, hence the first finish_reason is END_ID
finish_reasons=[END_ID, LENGTH, LENGTH],
),
])
run()
def test_are_stop_words_isnt_called_when_no_stop_words():
"""We don't want to call are_stop_words when there are no stop words because it's expensive"""
def stop_words_that_raises(*args, **kwargs):
raise AssertionError
run_with_stop_words, sampler = RequestCase.setup([
RequestCase(prompt=[1],
stop_words_list=[[1]],
new_tokens=[4],
finish_reasons=[NOT_FINISHED])
])
sampler._are_stop_words = stop_words_that_raises
with pytest.raises(AssertionError):
run_with_stop_words()
run_without_stop_words, sampler = RequestCase.setup([
RequestCase(prompt=[1], new_tokens=[4], finish_reasons=[NOT_FINISHED])
])
sampler._are_stop_words = stop_words_that_raises
_ = run_without_stop_words()
def test_produce_stop_words():
for original in [
[[]],
[[1]],
[[1, 2, 3]],
[[1], [2, 3]],
[[1, 2, 3], [4, 5]],
[[10], [20], [30, 40], [50]],
]:
assert original == list(produce_stop_words(convert_wordlist(original)))

View File

@ -147,6 +147,10 @@ async def test_streaming(client: openai.AsyncOpenAI, model: str):
collected_chunks = []
collected_messages = []
async for chunk in response:
# The last streaming response only contains usage info
if len(chunk.choices) <= 0:
continue
collected_chunks.append(chunk)
collected_messages.append(chunk.choices[0].delta)
@ -198,6 +202,10 @@ async def test_streaming_tool_call(client: openai.AsyncOpenAI, model: str):
reasoning_chunks: list[str] = []
tool_arg_chunks: list[str] = []
async for chunk in response:
# The last streaming response only contains usage info
if len(chunk.choices) <= 0:
continue
delta = chunk.choices[0].delta
if hasattr(delta, "tool_calls") and delta.tool_calls:
function = delta.tool_calls[0].function
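The guards above exist because, when usage reporting is enabled for a streamed completion, the final chunk carries only usage statistics and an empty choices list. A hedged sketch of the client side, assuming the standard OpenAI stream_options flag rather than anything test-specific:

import openai

async def stream_with_usage(client: openai.AsyncOpenAI, model: str) -> None:
    response = await client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": "hello"}],
        stream=True,
        stream_options={"include_usage": True},  # assumption: usage enabled this way
    )
    async for chunk in response:
        if not chunk.choices:
            # Terminal usage-only chunk: no delta to read, just usage stats.
            print(chunk.usage)
            continue
        print(chunk.choices[0].delta.content or "", end="")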

View File

@ -65,7 +65,10 @@ def engine_from_fp8_quantization(model_name):
@pytest.fixture(scope="module")
def server(model_name: str, engine_from_fp8_quantization: str):
model_path = get_model_path(model_name)
args = ["--tp_size", "2", "--tokenizer", model_path]
args = [
"--tp_size", "2", "--tokenizer", model_path, "--backend", "trt",
"--max_num_tokens", "20480", "--max_batch_size", "128"
]
with RemoteOpenAIServer(engine_from_fp8_quantization,
args) as remote_server:
yield remote_server

View File

@ -1,4 +1,6 @@
from tensorrt_llm.llmapi.utils import ApiStatusRegistry
from tensorrt_llm.llmapi import LlmArgs
from tensorrt_llm.llmapi.utils import (ApiStatusRegistry,
generate_api_docs_as_docstring)
def test_api_status_registry():
@ -24,3 +26,9 @@ def test_api_status_registry():
pass
assert ApiStatusRegistry.get_api_status(App._my_method) == "beta"
def test_generate_api_docs_as_docstring():
doc = generate_api_docs_as_docstring(LlmArgs)
assert ":tag:`beta`" in doc, "the label is not generated"
print(doc)