Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

Commit fdaf4e2985: Merge remote-tracking branch 'origin/main' into feat/b300_cu13
Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
@@ -47,7 +47,7 @@ public:
bool operator()(DecoderInputBuffers& inputBuffers, bool replicateLogitsPostProcessor,
runtime::WorldConfig const& worldConfig, CudaStreamPtr const& stream,
std::optional<LogitsPostProcessorBatched> logitsPostProcessorBatched = std::nullopt) const;
std::optional<LogitsPostProcessorBatched> const& logitsPostProcessorBatched = std::nullopt) const;
};

} // namespace tensorrt_llm::batch_manager

@@ -204,6 +204,34 @@ private:
}
}

void sendResponse(std::vector<size_t> const& blockHashes, std::map<RequestIdType, Response>::iterator it)
{
auto reqId = mCurrentRequest.value();
auto count = --mRemainSendCount[reqId];
TLLM_CHECK(count >= 0);
if (count == 0)
{
mRemainSendCount.erase(reqId);

// TODO(zhengd): pass the hashes directly instead of update llmRequest
auto llmRequest = it->second.mRequest;
llmRequest->setRequestedBlockHashes(std::move(blockHashes));

if (common::getEnvParallelCacheSend())
{
// TODO: Use a thread pool and check for thread safety.
std::thread(&DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second))
.detach();
}
else
{
DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
}
removeResponse(it);
}
mCurrentRequest = std::nullopt;
}

void response() noexcept
{
try

@@ -237,40 +265,22 @@ private:
auto it = getCurrentResponse();
if (it != mReadyResponses.end())
{
auto reqId = mCurrentRequest.value();
auto count = --mRemainSendCount[reqId];
TLLM_CHECK(count >= 0);
if (count == 0)
{
mRemainSendCount.erase(reqId);

// TODO(zhengd): pass the hashes directly instead of update llmRequest
auto llmRequest = it->second.mRequest;
llmRequest->setRequestedBlockHashes(std::move(blockHashes));

if (common::getEnvParallelCacheSend())
{
// TODO: Use a thread pool and check for thread safety.
std::thread(
&DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second))
.detach();
}
else
{
DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
}
removeResponse(it);
}
mCurrentRequest = std::nullopt;
sendResponse(blockHashes, it);
}
else
{
TLLM_CHECK_WITH_INFO(!mCurrentRequest.has_value(),
"This executor does not have a prepared KV cache for request ID: %zu, and the "
"mReadyResponses size is: %zu. mpi rank :%d ",
mCurrentRequest.value(), mReadyResponses.size(), mpi::MpiComm::world().getRank());
std::unique_lock lk(mCondMutex);
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
auto it = getCurrentResponse();
while (it == mReadyResponses.end())
{
std::unique_lock lk(mCondMutex);
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
if (mTerminate)
{
break;
}
it = getCurrentResponse();
}
sendResponse(blockHashes, it);
}
}
}

@@ -34,7 +34,7 @@ using SizeType32 = tensorrt_llm::runtime::SizeType32;
bool LogitsPostProcessor::operator()(DecoderInputBuffers& inputBuffers, bool replicateLogitsPostProcessor,
tr::WorldConfig const& worldConfig, CudaStreamPtr const& stream,
std::optional<LogitsPostProcessorBatched> logitsPostProcessorBatched) const
std::optional<LogitsPostProcessorBatched> const& logitsPostProcessorBatched) const
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(LogitsPostProcessor);

@@ -201,56 +201,60 @@ Metrics Endpoint

.. note::

This endpoint is beta maturity.
The metrics endpoint for the default PyTorch backend is in beta and is not as comprehensive as the one for the TensorRT backend.
The statistics for the PyTorch backend are beta and not as comprehensive as those for the TensorRT backend.
Some fields, such as CPU memory usage, are not yet available for the PyTorch backend.
Some fields, such as CPU memory usage, are not available for the PyTorch backend.
Enabling ``enable_iter_perf_stats`` in the PyTorch backend can slightly impact performance, depending on the serving configuration.
Enabling ``enable_iter_perf_stats`` in the PyTorch backend can impact performance slightly, depending on the serving configuration.
The ``/metrics`` endpoint provides runtime iteration statistics such as GPU memory usage and KV cache details.
The ``/metrics`` endpoint provides runtime-iteration statistics such as GPU memory use and inflight-batching details.
For the TensorRT backend, these statistics are enabled by default.
However, for the PyTorch backend, you must explicitly enable iteration statistics logging by setting the `enable_iter_perf_stats` field in a YAML configuration file as shown in the following example:
For the default PyTorch backend, iteration statistics logging is enabled by setting the ``enable_iter_perf_stats`` field in a YAML file:

.. code-block:: yaml

# extra-llm-api-config.yml
pytorch_backend_config:
enable_iter_perf_stats: true
# extra_llm_config.yaml
enable_iter_perf_stats: true

Then start the server and specify the ``--extra_llm_api_options`` argument with the path to the YAML file as shown in the following example:
Start the server and specify the ``--extra_llm_api_options`` argument with the path to the YAML file:

.. code-block:: bash

trtllm-serve <model> \
--extra_llm_api_options <path-to-extra-llm-api-config.yml> \
[--tp_size <tp> --pp_size <pp> --ep_size <ep> --host <host> --port <port>]
trtllm-serve "TinyLlama/TinyLlama-1.1B-Chat-v1.0" --extra_llm_api_options extra_llm_config.yaml

After at least one inference request is sent to the server, you can fetch the runtime-iteration statistics by polling the `/metrics` endpoint:
After sending at least one inference request to the server, you can fetch runtime iteration statistics by polling the ``/metrics`` endpoint.
Since the statistics are stored in an internal queue and removed once retrieved, it's recommended to poll the endpoint shortly after each request and store the results if needed.

.. code-block:: bash

curl -X GET http://<host>:<port>/metrics
curl -X GET http://localhost:8000/metrics

*Example Output*
Example output:

.. code-block:: json

[
{
"gpuMemUsage": 56401920000,
"inflightBatchingStats": {
[
{
"gpuMemUsage": 76665782272,
"iter": 154,
"iterLatencyMS": 7.00688362121582,
"kvCacheStats": {
"allocNewBlocks": 3126,
"allocTotalBlocks": 3126,
"cacheHitRate": 0.00128,
"freeNumBlocks": 101253,
"maxNumBlocks": 101256,
"missedBlocks": 3121,
"reusedBlocks": 4,
"tokensPerBlock": 32,
"usedNumBlocks": 3
},
"numActiveRequests": 1
...
},
"iter": 1,
"iterLatencyMS": 16.505143404006958,
"kvCacheStats": {
...
},
"newActiveRequestsQueueLatencyMS": 0.0007503032684326172
}
]
}
]
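
Because each record is removed from the server-side queue once retrieved, it helps to collect the responses client side as you go. The following is a minimal polling sketch, assuming the ``localhost:8000`` endpoint used above:

.. code-block:: python

    import json
    import time
    import urllib.request

    METRICS_URL = "http://localhost:8000/metrics"

    collected = []
    for _ in range(10):
        # Each statistics record is returned only once, so keep a local copy.
        with urllib.request.urlopen(METRICS_URL) as resp:
            stats = json.load(resp)
        collected.extend(stats)
        for entry in stats:
            kv = entry.get("kvCacheStats", {})
            print(entry.get("iter"), entry.get("iterLatencyMS"), kv.get("usedNumBlocks"))
        time.sleep(1.0)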

Syntax
------

@@ -234,6 +234,28 @@ TODO: Use Chat Compeletions API / Responses API as the example after the PR is m
We use OpenAI's official evaluation tool to test the model's accuracy. For more information, see [gpt-oss-eval](https://github.com/openai/gpt-oss/tree/main/gpt_oss/evals).
With the added support for the Chat Completions and Responses APIs in `trtllm-serve`, `gpt_oss.evals` works directly without any modifications.

You need to set `enable_attention_dp`, `tp_size`, `ep_size`, `max_batch_size` and `max_num_tokens` when launching the TensorRT-LLM server, and set `reasoning-effort` when launching the evaluation in gpt-oss. Below are some reference configurations for accuracy evaluation on B200, followed by an illustrative launch command.

| **reasoning-effort** | **parallel configuration** | **max_batch_size** | **max_num_tokens** |
|:--------------------:|:--------------------------:|:------------------:|:------------------:|
| low/medium | DEP8 / DEP4 | 128 | 32768 |
| high | DEP8 / DEP4 | 2 | 133120 |
| low/medium | TP8 / TP4 | 1024 | 32768 |
| high | TP8 / TP4 | 720 | 133120 |
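
For example, a low/medium DEP4 row from the table might be served along the following lines. This is a sketch only; the checkpoint name, file names, and exact flag spellings are assumptions to adapt to your setup:

```shell
# Enable attention data parallelism through the extra LLM API options file.
cat > dep4_config.yaml << EOF
enable_attention_dp: true
EOF

trtllm-serve openai/gpt-oss-120b \
    --tp_size 4 --ep_size 4 \
    --max_batch_size 128 --max_num_tokens 32768 \
    --extra_llm_api_options dep4_config.yaml
```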

Below is an example command for evaluating the accuracy of gpt-oss-120b with low and medium reasoning-effort on GPQA and AIME2025.

```shell
# execute this command in gpt-oss
python -m gpt_oss.evals \
--sampler chat_completions \
--eval gpqa,aime25 \
--model gpt-oss-120b \
--reasoning-effort low,medium
```

## Benchmarking Performance

To benchmark the performance of your TensorRT-LLM server, you can leverage the built-in `benchmark_serving.py` script. To do this, first create a wrapper `bench.sh` script.

@@ -1,7 +1,7 @@
# LLM API with TensorRT Engine
A simple inference example with TinyLlama using the LLM API:

```{literalinclude} ../../examples/llm-api/_tensorrt_engine/quickstart_example.py
```{literalinclude} ../../../examples/llm-api/_tensorrt_engine/quickstart_example.py
:language: python
:linenos:
```

@@ -1,11 +1,17 @@
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm import BuildConfig, SamplingParams
from tensorrt_llm._tensorrt_engine import LLM # NOTE the change

def main():

build_config = BuildConfig()
build_config.max_batch_size = 256
build_config.max_num_tokens = 1024

# Model could accept HF model name, a path to local HF model,
# or TensorRT Model Optimizer's quantized checkpoints like nvidia/Llama-3.1-8B-Instruct-FP8 on HF.
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
build_config=build_config)

# Sample prompts.
prompts = [

@@ -76,6 +76,7 @@ srun -l \

# This is optional
cat > /tmp/pytorch_extra_args.txt << EOF
cuda_graph_config: null
print_iter_log: true
enable_attention_dp: false
EOF

@@ -94,6 +94,26 @@ def createKubernetesPodConfig(type, arch = "amd64", build_wheel = false)
"""
}

if (arch == "amd64") {
// For x86_64, we block some nodes to avoid unstable network access.
selectors += """
affinity:
nodeAffinity:
requiredDuringSchedulingIgnoredDuringExecution:
nodeSelectorTerms:
- matchExpressions:
- key: "kubernetes.io/hostname"
operator: NotIn
values:
- "sc-ipp-blossom-prod-k8w-105"
- "sc-ipp-blossom-prod-k8w-114"
- "sc-ipp-blossom-prod-k8w-115"
- "sc-ipp-blossom-prod-k8w-121"
- "sc-ipp-blossom-prod-k8w-123"
- "sc-ipp-blossom-prod-k8w-124"
"""
}

def archSuffix = arch == "arm64" ? "arm" : "amd"
def jnlpImage = "urm.nvidia.com/sw-ipp-blossom-sre-docker-local/lambda/custom_jnlp_images_${archSuffix}_linux:jdk17"

@@ -365,13 +365,24 @@ def runLLMTestlistOnSlurm(pipeline, platform, testList, config=VANILLA_CONFIG, p
}

stage('Checking if the Node is Online') {
def counter = 0
// We submit the Slurm job with 5 hours timeout, and the K8S pod will be evicted after 22 hours.
// Let's use 15 hours to check if the node is online, and with 2 hours buffer.
while (!CloudManager.isNodeOnline(nodeName) && counter < 90) {
// Wait 10 minutes to check status of the node again
sleep(time: 10, unit: 'MINUTES')
counter++
withCredentials([usernamePassword(credentialsId: 'svc_tensorrt', usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) {
def remote = [
ip : cluster.ip,
host : cluster.host,
user : "${pipeline.USERNAME}",
passwd : "${pipeline.PASSWORD}",
allowAnyHosts: true,
]
def counter = 0
// We submit the Slurm job with 5 hours timeout, and the K8S pod will be evicted after 22 hours.
// Let's use 15 hours to check if the node is online, and with 2 hours buffer.
while (!CloudManager.isNodeOnline(nodeName) && counter < 90) {
// Wait 10 minutes to check status of the node again
sleep(time: 10, unit: 'MINUTES')
// Avoid the node being stuck in the held state.
Utils.exec(pipeline, Utils.sshUserCmd(remote, "\"scontrol release ${slurmJobID} || true\""))
counter++
}
}

if (CloudManager.isNodeOnline(nodeName)) {

@@ -1157,7 +1168,7 @@ def transformMakoArgsToJson(optList) {

def getMakoOpts(getMakoScript, makoArgs=[]) {
// We want to save a map for the Mako opts
def turtleOutput = ""
def makoOutput = ""

// Echo the command
// NOTE: We redirect stderr to stdout so that we can capture

@@ -1181,17 +1192,17 @@ def getMakoOpts(getMakoScript, makoArgs=[]) {

// Capture the mako output, add timeout in case any hang
timeout(time: 30, unit: 'MINUTES'){
turtleOutput = sh(label: "Capture Mako Parameters", script: listMakoCmd, returnStdout: true)
makoOutput = sh(label: "Capture Mako Parameters", script: listMakoCmd, returnStdout: true)
}
}

// Validate output
assert turtleOutput: "Mako opts not found - could not construct test db test list."
assert makoOutput: "Mako opts not found - could not construct test db test list."

// Split each line of turtle output into a list
def turtleOutList = turtleOutput.split("\n")
// Split each line of mako output into a list
def outputList = makoOutput.split("\n")

def makoOptsJson = transformMakoArgsToJson(turtleOutList)
def makoOptsJson = transformMakoArgsToJson(outputList)

return makoOptsJson
}

@@ -1827,7 +1838,7 @@ def runLLMBuild(pipeline, cpu_arch, reinstall_dependencies=false, wheel_path="",
if (env.alternativeTRT) {
trtllm_utils.replaceWithAlternativeTRT(env.alternativeTRT, cpver)
}
buildArgs = "--clean"
buildArgs = "--clean --nixl_root /opt/nvidia/nvda_nixl"
if (cpu_arch == AARCH64_TRIPLE) {
buildArgs += " -a '90-real;100-real;103-real;120-real'"
}

@@ -2062,9 +2073,9 @@ def launchTestJobs(pipeline, testFilter)
"DGX_H200-4_GPUs-TensorRT-Post-Merge-1": ["dgx-h200-x4", "l0_dgx_h200", 1, 3, 4],
"DGX_H200-4_GPUs-TensorRT-Post-Merge-2": ["dgx-h200-x4", "l0_dgx_h200", 2, 3, 4],
"DGX_H200-4_GPUs-TensorRT-Post-Merge-3": ["dgx-h200-x4", "l0_dgx_h200", 3, 3, 4],
"RTXPro6000-Pytorch-Post-Merge-1": ["rtx-pro-6000", "l0_rtx_pro_6000", 1, 1],
"RTXPro6000-4_GPUs-Pytorch-Post-Merge-1": ["rtx-pro-6000-x4", "l0_rtx_pro_6000", 1, 2, 4],
"RTXPro6000-4_GPUs-Pytorch-Post-Merge-2": ["rtx-pro-6000-x4", "l0_rtx_pro_6000", 2, 2, 4],
//"RTXPro6000-Pytorch-Post-Merge-1": ["rtx-pro-6000", "l0_rtx_pro_6000", 1, 1],
//"RTXPro6000-4_GPUs-Pytorch-Post-Merge-1": ["rtx-pro-6000-x4", "l0_rtx_pro_6000", 1, 2, 4],
//"RTXPro6000-4_GPUs-Pytorch-Post-Merge-2": ["rtx-pro-6000-x4", "l0_rtx_pro_6000", 2, 2, 4],
]

parallelJobs = x86TestConfigs.collectEntries{key, values -> [key, [createKubernetesPodConfig(key.contains("-CU12-") ? LLM_DOCKER_IMAGE_12_9 : LLM_DOCKER_IMAGE, values[0], "amd64", values[4] ?: 1, key.contains("Perf")), {

@@ -170,7 +170,8 @@ class FlashInferAttentionMetadata(AttentionMetadata):
def create_cuda_graph_metadata(self,
max_batch_size: int,
sub_cross_metadata: bool = False,
max_draft_tokens: int = 0) -> Self:
max_draft_tokens: int = 0,
buffers=None) -> Self:
metadata = super().create_cuda_graph_metadata(max_batch_size,
sub_cross_metadata,
max_draft_tokens)

@@ -140,6 +140,7 @@ class AttentionMetadata:

# This buffer is currently only used for TrtllmAttentionMetadata.
cache_indirection: Optional[torch.Tensor] = None
cuda_graph_buffers: dict[str, list[torch.Tensor]] = None

_saved_tensors: Dict[str, torch.Tensor] = field(init=False,
default_factory=dict)

@@ -288,7 +289,8 @@ class AttentionMetadata:
def create_cuda_graph_metadata(self,
max_batch_size: int,
sub_cross_metadata: bool = False,
max_draft_tokens: int = 0) -> Self:
max_draft_tokens: int = 0,
buffers=None) -> Self:
"""
Creates metadata for CUDA graph execution.
CUDA graphs require to use pre-allocated buffers for all tensors in fields.

@@ -300,6 +302,7 @@ class AttentionMetadata:

cuda_graph_metadata = copy.copy(self)
cuda_graph_metadata.is_cuda_graph = True
cuda_graph_metadata.cuda_graph_buffers = buffers
if self.has_cross_sub_metadata:
cuda_graph_metadata.cross = cuda_graph_metadata.cross.create_cuda_graph_metadata(
max_batch_size, True)

@@ -600,13 +600,65 @@ class TrtllmAttentionMetadata(AttentionMetadata):

def __post_init__(self) -> None:
super().__post_init__()
self._post_init_with_buffers(self.cuda_graph_buffers)

def _post_init_with_buffers(self, buffers) -> None:

# Set a default value, as max_num_sequences is not always set.
if self.max_num_sequences is None:
self.max_num_sequences = self.max_num_requests

self.prompt_lens_cuda = torch.empty(
def get_empty(tensor_shape: list[int], dtype: torch.dtype,
cache_name: str) -> torch.Tensor:
"""
Finds a compatible, reusable buffer from a cache or creates a new one.

This function searches for a pre-allocated tensor (buffer) that can be
reused for an operation involving a tensor with the shape of `tensor_shape`.

The compatibility rule is that the buffer's total number of elements must be >= tensor_shape's.

If a compatible buffer is found, it's returned immediately. Otherwise, a new
buffer is allocated on the 'cuda' device with the given 'tensor_shape' and 'dtype'.

Args:
tensor_shape: The required shape.
dtype: The required dtype.
cache_name: The key for the specific list of buffers to search in.

Returns:
An existing compatible buffer or a newly created one.
"""
if buffers is not None:
# Safely get the list of candidates. Defaults to an empty list if key is missing.
candidate_buffers = buffers.get(cache_name, [])
numel_like = math.prod(tensor_shape)

for buffer in candidate_buffers:
numel_buffer = buffer.numel()

# buffer just needs to be large enough.
if numel_buffer >= numel_like:
return buffer[0:numel_like].view(
tensor_shape)  # Found a fit, return immediately.

# If we get here, no suitable buffer was found in the cache. Create a new one.
new_buffer = torch.zeros(tensor_shape, device='cuda', dtype=dtype)
if buffers is not None:
buffers.setdefault(cache_name, []).append(new_buffer)
return new_buffer

def get_empty_like(like_tensor: torch.Tensor,
cache_name: str) -> torch.Tensor:
return get_empty(
like_tensor.shape,
cache_name=cache_name,
dtype=like_tensor.dtype,
)
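
# A minimal usage sketch with illustrative values: with the enclosing `buffers`
# dict empty, a first call
#     a = get_empty([32], dtype=torch.int, cache_name="example")
# allocates a new 32-element CUDA buffer and caches it under "example"; a later
#     b = get_empty([16], dtype=torch.int, cache_name="example")
# finds that buffer large enough and returns a view of its first 16 elements
# instead of allocating again.
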
self.prompt_lens_cuda = get_empty(
(self.max_num_sequences, ),
device='cuda',
cache_name="prompt_lens_cuda",
dtype=torch.int,
)
self.prompt_lens_cpu = torch.empty_like(

@@ -614,7 +666,10 @@ class TrtllmAttentionMetadata(AttentionMetadata):
device='cpu',
pin_memory=True,
)
self.kv_lens_cuda = torch.empty_like(self.prompt_lens_cuda)
self.kv_lens_cuda = get_empty_like(
self.prompt_lens_cuda,
cache_name="kv_lens_cuda",
)
self.kv_lens = torch.empty_like(self.kv_lens_cuda,
device='cpu',
pin_memory=True)

@@ -629,13 +684,13 @@ class TrtllmAttentionMetadata(AttentionMetadata):
dtype=torch.int8,
)
if self.kv_cache_manager is not None:
self.kv_cache_block_offsets = torch.empty(
self.kv_cache_block_offsets = get_empty(
[
self.kv_cache_manager.num_pools, self.max_num_sequences, 2,
self.kv_cache_manager.max_blocks_per_seq
],
cache_name="kv_cache_block_offsets",
dtype=torch.int32,
device='cuda',
)
self.host_kv_cache_block_offsets = torch.empty_like(
self.kv_cache_block_offsets,

@@ -645,27 +700,27 @@ class TrtllmAttentionMetadata(AttentionMetadata):
self.block_ids_per_seq = None
self.kv_block_ids_per_seq = None
if self.enable_flash_mla:
self.block_ids_per_seq = torch.zeros(
self.block_ids_per_seq = get_empty(
[
self.kv_cache_manager.max_batch_size,
self.kv_cache_manager.max_blocks_per_seq
],
cache_name="block_ids_per_seq",
dtype=torch.int32,
device='cuda',
)
self.kv_block_ids_per_seq = torch.zeros(
self.kv_block_ids_per_seq = get_empty(
[
self.kv_cache_manager.max_batch_size,
self.kv_cache_manager.max_blocks_per_seq
],
cache_name="kv_block_ids_per_seq",
dtype=torch.int32,
device='cuda',
)
if self.enable_context_mla_with_cached_kv:
# for kv cache reuse/chunked context in MLA
self.ctx_cached_token_indptr = torch.zeros(
self.ctx_cached_token_indptr = get_empty(
(self.max_num_requests + 1, ),
device='cuda',
cache_name="ctx_cached_token_indptr",
dtype=torch.int64,
)
self.host_ctx_cached_token_indptr = torch.zeros_like(

@@ -673,9 +728,9 @@ class TrtllmAttentionMetadata(AttentionMetadata):
device='cpu',
pin_memory=True,
)
self.ctx_uncached_token_indptr = torch.zeros(
self.ctx_uncached_token_indptr = get_empty(
(self.max_num_requests + 1, ),
device='cuda',
cache_name="ctx_uncached_token_indptr",
dtype=torch.int64,
)
self.host_ctx_uncached_token_indptr = torch.zeros_like(

@@ -684,9 +739,9 @@ class TrtllmAttentionMetadata(AttentionMetadata):
pin_memory=True,
)
# context full seqlens include cached tokens and uncached tokens
self.ctx_kv_indptr = torch.zeros(
self.ctx_kv_indptr = get_empty(
(self.max_num_requests + 1, ),
device='cuda',
cache_name="ctx_kv_indptr",
dtype=torch.int64,
)
self.host_ctx_kv_indptr = torch.zeros_like(

@@ -1165,7 +1220,7 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
host_kv_cache_pool_pointers=metadata.host_kv_cache_pool_pointers,
host_kv_cache_pool_mapping=metadata.host_kv_cache_pool_mapping,
block_ids_per_seq=metadata.block_ids_per_seq,
workspace=metadata.workspace,
workspace=None,
cache_indirection=metadata.cache_indirection,
kv_scale_orig_quant=self.kv_scale_orig_quant,
kv_scale_quant_orig=self.kv_scale_quant_orig,

@@ -371,7 +371,7 @@ class AutoTuner:
if not is_cache_hit:
logger.warning_once(
f"[AutoTunner] Using the fallback tactic, due to cache miss on input shapes={input_shapes}",
key=(custom_op))
key=custom_op)

return (best_runner, best_tactic)

@@ -210,15 +210,9 @@ class PiecewiseRunner(object):
runtime_input_addresses = [
i.data_ptr() for i in args if isinstance(i, torch.Tensor)
]
runtime_output_addresses = [
i.data_ptr() for i in output if isinstance(i, torch.Tensor)
]

assert (entry.input_addresses == runtime_input_addresses
), f"{entry.input_addresses} vs\n {runtime_input_addresses}"
assert (
entry.output_addresses == runtime_output_addresses
), f"{entry.output_addresses} vs\n {runtime_output_addresses}"

entry.cuda_graph.replay()

@@ -494,7 +494,8 @@ class ModelConfig(Generic[TConfig]):
architectures = self.pretrained_config.architectures
if len(architectures
) == 1 and architectures[0] == "DeciLMForCausalLM":
mlp_hidden_size = self._infer_nemotron_ffn_mult()
mlp_hidden_size = self._infer_nemotron_ffn_mult(
) // self.mapping.tp_size
else:
raise ValueError(
f"Inferring mlp hidden size for model architecture: {architectures} isn't supported yet"

@@ -263,7 +263,9 @@ class Gemma3VLM(PreTrainedModel):
embedding_layer=self.llm.model.embed_tokens,
input_ids=input_ids,
mm_embeds=mm_embeds,
mm_token_ids=self.image_token_ids)
mm_token_ids=self.image_token_ids,
**kwargs,
)
logits = self.llm.forward(
attn_metadata=attn_metadata,
input_ids=input_ids,

@@ -284,3 +286,7 @@ class Gemma3VLM(PreTrainedModel):
attn_metadata=attn_metadata)[-1]
image_features = self.mm_projector(image_features)
return image_features

@property
def mm_token_ids(self):
return self.image_token_ids

@@ -1052,7 +1052,8 @@ class HCXVisionForCausalLM(PreTrainedModel):
]

input_ids, input_embeds = fuse_input_embeds(self.llm.model.embed_tokens,
input_ids, mm_embeds)
input_ids, mm_embeds,
**kwargs)
output_prob = self.llm.forward(
attn_metadata=attn_metadata,
input_ids=input_ids,

@@ -1280,7 +1280,8 @@ class Llama4ForConditionalGeneration(SpecDecOneEngineForCausalLM[Llama4Model,
]

input_ids, inputs_embeds = fuse_input_embeds(self.model.embed_tokens,
input_ids, mm_embeds)
input_ids, mm_embeds,
**kwargs)
return super().forward(attn_metadata,
input_ids,
position_ids,

@@ -302,7 +302,8 @@ class LlavaNextVisionModel(nn.Module):
logger.warning_once(
"Image feature shape does not line up with the provided patch size. "
"You may be using the `default` vision_feature_select_strategy with a"
" visual encoder that does not have CLS.")
" visual encoder that does not have CLS.",
key="llava_next_vision_model_pack_image_features")

image_feature = image_feature.view(num_patch_height,
num_patch_width, height,

@@ -474,7 +475,7 @@ class LlavaNextModel(PreTrainedModel):
for multimodal_param in multimodal_params
]
input_ids, inputs_embeds = fuse_input_embeds(
self.llm.model.embed_tokens, input_ids, mm_embeds)
self.llm.model.embed_tokens, input_ids, mm_embeds, **kwargs)

logits = self.llm.forward(attn_metadata, input_ids, position_ids,
inputs_embeds, return_context_logits)

@@ -408,6 +408,7 @@ class Mistral3VLM(PreTrainedModel):
input_ids=input_ids,
mm_embeds=mm_embeds,
mm_token_ids=self._image_token_ids,
**kwargs,
)

return self.llm.forward(

@@ -501,6 +502,10 @@ class Mistral3VLM(PreTrainedModel):
]
return torch.cat(pixel_values), batched_image_sizes

@property
def mm_token_ids(self):
return self._image_token_ids

# Original implementation:
# https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/mistral3/modeling_mistral3.py#L66

@@ -105,48 +105,84 @@ def find_uncached_mm_embeds(
return sliced_mm_embeds

def filter_mm_token_from_input_ids(
input_ids: torch.IntTensor,
vocab_size: int,
mm_token_ids: Optional[torch.IntTensor] = None,
) -> Tuple[torch.IntTensor, torch.IntTensor]:
"""
Filter multimodal tokens from input_ids.
Args:
input_ids: shape [text_total_length + mm_total_length].
vocab_size: size of the model's vocabulary
mm_token_ids: possible token ids for multimodal tokens, if known. If not known and set to None, it is assumed that the multimodal tokens are out-of-vocabulary tokens i.e. the `input_ids` contains tokens >= vocab_size that represent the multimodal tokens.
Note:
This function involves host-device synchronization due to torch.where() (= torch.nonzero) requiring
host allocation. The output indices reside on the same device as input_ids.
Returns:
text_token_indices: indices of text tokens in the input_ids
mm_token_indices: indices of multimodal tokens in the input_ids
"""
if mm_token_ids is None:
# NOTE:
# If mm_token_ids is None, it is assumed that the multimodal
# tokens are out-of-vocab tokens i.e. the `input_ids` contains
# tokens >= vocab_size that represent the multimodal tokens.
# Since mm_token_ids can be unbounded in this case,
# using torch.isin() may not be performant.
# This provides a more performant alternative while keeping
# the flexibility of still specifying all possible mm_token_ids,
# if the user wants to.
mm_token_mask = input_ids >= vocab_size
text_token_mask = input_ids < vocab_size
else:
mm_token_ids = mm_token_ids.to(input_ids.device, dtype=input_ids.dtype)
mm_token_mask = torch.isin(input_ids, mm_token_ids)
text_token_mask = ~mm_token_mask
# NOTE: torch.where() enforces a host sync
text_token_indices = torch.where(text_token_mask)[0]
mm_token_indices = torch.where(mm_token_mask)[0]
return text_token_indices, mm_token_indices
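
# A minimal usage sketch with illustrative values: assuming a vocabulary of
# size 100 where ids >= 100 mark multimodal tokens,
#     ids = torch.tensor([5, 7, 100, 101, 9])
#     text_idx, mm_idx = filter_mm_token_from_input_ids(ids, vocab_size=100)
# returns text_idx == tensor([0, 1, 4]) and mm_idx == tensor([2, 3]).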

def fuse_input_embeds(
embedding_layer: Embedding,
input_ids: torch.IntTensor,
mm_embeds: List[torch.Tensor],
mm_token_ids: Optional[torch.IntTensor] = None,
text_token_indices: Optional[torch.IntTensor] = None,
mm_token_indices: Optional[torch.IntTensor] = None,
**kwargs,
) -> Tuple[Optional[torch.FloatTensor], Optional[torch.FloatTensor]]:
"""
Fuse text and multimodal embeddings. input_ids is [text_total_length + mm_total_length] and mm_embed is [mm_total_length, hidden_dim]. We just need to fuse them into [text_total_length + mm_total_length, hidden_dim] by slice-and-assign to the corresponding entries.

Args:
embedding_layer: embedding layer of the model.
input_ids: shape [text_total_length + mm_total_length], flattened from List[(text_length1 + mm_total_length1), ..., (text_lengthi + mm_total_lengthi)]. For LLM model, the requests are inflight batched together, but the input_ids are flattened with padding removed. By the slice condition < vocab_size, we can easily separate text / multimodal tokens and naturally batched the LLM embedding lookup
mm_embed: List[(mm_total_length1, hidden_dim), ..., (mm_total_lengthi, hidden_dim)].
mm_token_ids: possible token ids for multimodal tokens, if known. If not known and set to None, it is assumed that the multimodal tokens are out-of-vocabulary tokens i.e. the `input_ids` contains tokens >= vocab_size that represent the multimodal tokens.
mm_embeds: List[(mm_total_length1, hidden_dim), ..., (mm_total_lengthi, hidden_dim)].
mm_token_ids: possible token ids for multimodal tokens, if known. If not known and set to None, it is assumed that the multimodal tokens are out-of-vocabulary tokens.
Returns:
- If (1) JIT test run, (2) non-multimodal run, i.e. all text-only requests, either context or generation phase (3) multimodal run, all requests in generation phase --> there is no multimodal data, return only the input_ids
- If (4) multimodal run, mixed batch of context and generation requests, each context request has a multimodal feature --> return only the fused input_embeds of shape [total length, hidden_dim]. For text tokens, LLM embedding layer has already run.
Note:
- Precedence: If kwargs provide indices (text_token_indices and mm_token_indices), those are used. If either of them is not provided, fall back to the filtering method. Sentinel-/OOV-based filtering (e.g., tokens >= vocab_size) is used only when neither an index tensor nor mm_token_ids is provided.
- This function may involve host-device synchronization if indices are not provided and filtering is performed. See filter_mm_token_from_input_ids for details.
"""
if len(mm_embeds) == 0:
return input_ids, None

mm_embed = torch.cat(mm_embeds, dim=0)

if mm_token_ids is None:
# NOTE:
# If mm_token_ids is None, it is assumed that the multimodal
# tokens are out-of-vocab tokens i.e. the `input_ids` contains
# tokens >= vocab_size that represent the multimodal tokens.
# Since mm_token_ids is be unbounded in this case,
# using torch.isin() may not be performant.
# This provides a more performant alternative while keeping
# the flexibility of still specifying all possible mm_token_ids,
# if the user wants to.
vocab_size = embedding_layer.num_embeddings
mm_token_mask = input_ids >= vocab_size
text_token_mask = input_ids < vocab_size
else:
mm_token_ids = mm_token_ids.to(input_ids.device)
mm_token_mask = torch.isin(input_ids, mm_token_ids)
text_token_mask = ~mm_token_mask
text_token_indices = torch.where(text_token_mask)[0]
mm_token_indices = torch.where(mm_token_mask)[0]
if len(mm_token_indices) != mm_embed.shape[0]:
# TODO: support the case where only one index tensor is provided, the other is derived as the complement (try to avoid implicit host-device synchronization)
if text_token_indices is None or mm_token_indices is None:
# NOTE: This function involves host-device synchronization due to torch.where() used in filter_mm_token_from_input_ids.
text_token_indices, mm_token_indices = filter_mm_token_from_input_ids(
input_ids,
vocab_size=embedding_layer.num_embeddings,
mm_token_ids=mm_token_ids)

if mm_token_indices.shape[0] != mm_embed.shape[0]:
raise ValueError(
f"Multimodal token count mismatch: found {len(mm_token_indices)} image tokens in input_ids "
f"but received {mm_embed.shape[0]} image embeddings. "

@@ -159,8 +195,7 @@ def fuse_input_embeds(
device=text_embed.device,
dtype=text_embed.dtype)

input_embeds[text_token_indices, :] = text_embed.to(
dtype=input_embeds.dtype, device=input_embeds.device)
input_embeds[text_token_indices, :] = text_embed
input_embeds[mm_token_indices, :] = mm_embed.to(dtype=input_embeds.dtype,
device=input_embeds.device)

@@ -594,6 +594,7 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
input_ids,
mm_embedding,
mm_token_ids=self.mm_token_ids,
**kwargs,
)

output_prob = self.llm.forward(

@@ -612,7 +612,8 @@ class Qwen2VLModelBase(PreTrainedModel):
'mrope_position_deltas']

input_ids, input_embeds = fuse_input_embeds(self.llm.model.embed_tokens,
input_ids, mm_embeds)
input_ids, mm_embeds,
**kwargs)
output_prob = self.llm.forward(
attn_metadata=attn_metadata,
input_ids=input_ids,

@@ -48,6 +48,9 @@ class Qwen3Attention(QKNormRoPEAttention):
rope=RopeParams.from_config(config),
)

# Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712)
# TODO: Consider adding disable_deep_gemm support to QKNormRoPEAttention if accuracy still remains

super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,

@@ -85,6 +88,7 @@ class Qwen3DecoderLayer(DecoderLayer):
dtype=config.torch_dtype,
config=model_config,
)

self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)

@@ -1186,7 +1186,7 @@ class VilaModel(PreTrainedModel):
]

input_ids, inputs_embeds = fuse_input_embeds(
self.llm.model.embed_tokens, input_ids, mm_embeds)
self.llm.model.embed_tokens, input_ids, mm_embeds, **kwargs)
logits = self.llm.forward(attn_metadata=attn_metadata,
input_ids=input_ids,
position_ids=position_ids,

@@ -216,6 +216,7 @@ class Attention(nn.Module):
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)

self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
[self.hidden_size])

@@ -138,25 +138,23 @@ def pre_comm_embedding_ops(
padding_size: int,
):
# Generate the mask for the input if needed.
if tp_size > 1:
if tp_mode == TensorParallelMode.COLUMN:
input_, input_mask = get_masked_input_and_mask(
input_,
vocab_start_index,
vocab_end_index,
)
if tp_mode == TensorParallelMode.COLUMN:
input_, input_mask = get_masked_input_and_mask(
input_,
vocab_start_index,
vocab_end_index,
)

# Get the embeddings.
output = F.embedding(input_, weight)

# Mask or pad the output if needed.
if tp_size > 1:
if tp_mode == TensorParallelMode.COLUMN:
output.masked_fill_(input_mask, 0)
elif tp_mode == TensorParallelMode.ROW:
if gather_output:
if tp_rank == tp_size - 1 and padding_size > 0:
output = F.pad(output, (0, padding_size))
if tp_mode == TensorParallelMode.COLUMN:
output.masked_fill_(input_mask, 0)
elif tp_mode == TensorParallelMode.ROW:
if gather_output:
if tp_rank == tp_size - 1 and padding_size > 0:
output = F.pad(output, (0, padding_size))

return output

@@ -205,12 +203,16 @@ class Embedding(LMHead):
self.vocab_end_index = num_embeddings

def forward(self, input):
# Run the ops before all_reduce/all_gather.
# We use torch.compile() to fuse the tiny pointwise ops before all_reduce/all_gather for Embedding module.
embedding_ops_func = torch.compile(
pre_comm_embedding_ops,
options={"max-autotune": True},
disable=not self.enable_torch_compile_for_embedding)
if self.tp_size > 1:
# Run the ops before all_reduce/all_gather.
# We use torch.compile() to fuse the tiny pointwise ops before all_reduce/all_gather for Embedding module.
embedding_ops_func = torch.compile(
pre_comm_embedding_ops,
options={"max-autotune": True},
disable=not self.enable_torch_compile_for_embedding)
else:
# Skip torch.compile when TP size is 1 to avoid unnecessary host overhead
embedding_ops_func = pre_comm_embedding_ops
output = embedding_ops_func(input, self.weight, self.tp_size,
self.tp_rank, self.tp_mode,
self.vocab_start_index,

@@ -3,6 +3,8 @@ from typing import Dict, List, Optional, Union
import torch
from torch import nn

from tensorrt_llm._utils import get_sm_version

from ...model_config import ModelConfig
from ...utils import Fp4QuantizedTensor, next_positive_power_of_2
from .interface import MoE, MoEWeightLoadingMode

@@ -78,6 +80,11 @@ class TRTLLMGenFusedMoE(MoE):
swiglu_limit=swiglu_limit,
)

sm_version = get_sm_version()
if sm_version >= 120:
raise NotImplementedError(
"TRTLLMGenFusedMoE does not support SM120 and above.")

assert not self.smart_router, "Smart router is not supported in TRTLLMGenFusedMoE."

self.num_slots = self.num_experts

@@ -5,6 +5,7 @@ import torch
import torch.nn.functional as F
from torch import nn

from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping

from ..distributed import AllReduceParams

@@ -29,6 +30,7 @@ class GatedMLP(nn.Module):
reduce_output: bool = True,
layer_idx: Optional[int] = None,
use_cute_dsl_blockscaling_mm: bool = False):

super().__init__()
self.layer_idx = layer_idx
self.hidden_size = hidden_size

@@ -98,12 +100,21 @@ class GatedMLP(nn.Module):
[LoraModuleType.MLP_GATE_UP],
[2 * self.intermediate_size // mapping.tp_size])

def _apply_activation(self, x):
def _apply_activation(self, x, *, has_lora: bool = False):
if self.activation == F.silu:
if self.down_proj.has_fp8_qdq or self.down_proj.has_w4a8_nvfp4_fp8:
return swiglu(x,
quant_scale=self.down_proj.input_scale,
quant_type=torch.float8_e4m3fn)
if has_lora:
# NOTE: This is a WAR, since LoRA grouped_gemm does not support FP8 yet.
# TODO: Remove this path when LoRA grouped_gemm supports FP8
# see: cpp/tensorrt_llm/thop/loraOp.cpp::lora_grouped_gemm
logger.warning(
f"GatedMLP._apply_activation: LoRA path active; forcing non-FP8 activation dtype bf16/fp16, layer_idx={self.layer_idx}"
)
return swiglu(x)
else:
return swiglu(x,
quant_scale=self.down_proj.input_scale,
quant_type=torch.float8_e4m3fn)
else:
return swiglu(x)
elif callable(self.activation):

@@ -155,7 +166,7 @@ class GatedMLP(nn.Module):
if h1_lora is not None:
h1 = h1 + h1_lora

h2 = self._apply_activation(h1)
h2 = self._apply_activation(h1, has_lora=True)
output = self.down_proj(h2,
all_reduce_params=final_all_reduce_params,
lora_params=lora_params,

@@ -5,7 +5,6 @@ from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import \
BaseCheckpointLoader
from tensorrt_llm.bindings.executor import ExecutorConfig

from ...builder import BuildConfig
from ...llmapi.llm_args import LoadFormat, SamplerType
from ...logger import logger
from ...mapping import Mapping

@@ -119,7 +118,6 @@ EXETENDED_EXECUTOR_CONFIG_FIELDS = [
'backend',
'pytorch_backend_config',
'max_seq_len',
'tokens_per_block',
'mapping',
'hf_model_dir',
'mm_encoder_only',

@@ -131,7 +129,6 @@ def update_executor_config(
backend: Optional[str] = None,
pytorch_backend_config: Optional[PyTorchConfig] = None,
mapping: Optional[Mapping] = None,
build_config: Optional[BuildConfig] = None,
speculative_config: Optional["DecodingBaseConfig"] = None,
hf_model_dir: Optional[str] = None,
max_input_len: Optional[int] = None,

@@ -156,10 +153,6 @@ def update_executor_config(

logger.info(f"{executor_config.pytorch_backend_config}")

build_config = build_config or BuildConfig()
# TODO: move to pure-Python KvCacheConfig, and remove dependency on build_config.
executor_config.tokens_per_block = executor_config.tokens_per_block or build_config.plugin_config.tokens_per_block

executor_config.hf_model_dir = hf_model_dir

if max_input_len is not None:

@@ -1,6 +1,8 @@
from collections.abc import Generator
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from itertools import pairwise
from typing import Any, Dict, List, Optional, TypeAlias, Union

import torch

@@ -424,7 +426,10 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
self.child_requests.append(py_request)

def convert_wordlist(word_list) -> List[List[int]]:
StopWordList: TypeAlias = list[list[int]]

def convert_wordlist(word_list) -> StopWordList:
"""Converts a wordlist from format:

[[word_0 token_0, word_0 token_1, ...], [word_1 token_0, ...], ...]]

@@ -461,6 +466,16 @@ def convert_wordlist(word_list) -> List[List[int]]:
return [tokens, offsets]

def produce_stop_words(
py_stop_words_list: StopWordList) -> Generator[list[int], None, None]:
"""yield stop sequences from the output of `convert_wordlist` above."""
stop_words_list, prefix_sum = py_stop_words_list
for start, end in pairwise((0, *prefix_sum)):  # first element: prepend 0
if end == -1:  # -1 is a sentinel value in convert_wordlist
break
yield stop_words_list[start:end]
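
# A minimal round-trip sketch with illustrative values, assuming convert_wordlist
# pads the offsets with the -1 sentinel: for stop words [[1, 2], [3]],
# convert_wordlist produces roughly [[1, 2, 3], [2, 3, -1]] (flat token stream
# plus prefix sums), and
#     list(produce_stop_words([[1, 2, 3], [2, 3, -1]]))
# recovers [[1, 2], [3]], the original stop sequences.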
|
||||
|
||||
|
||||
def executor_request_to_llm_request(
|
||||
req_id: int,
|
||||
executor_request: ExecutorRequest,
|
||||
|
||||
@ -43,6 +43,7 @@ from ..metadata import KVCacheParams
|
||||
from ..model_config import ModelConfig, MoeLoadBalancerConfig
|
||||
from ..models import AutoModelForCausalLM
|
||||
from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader
|
||||
from ..models.modeling_multimodal_utils import filter_mm_token_from_input_ids
|
||||
from ..models.modeling_utils import (DecoderModelForCausalLM, MetaInitMode,
|
||||
timing)
|
||||
from ..modules.fused_moe.moe_load_balancer import (
|
||||
@ -1259,6 +1260,16 @@ class PyTorchModelEngine(ModelEngine):
|
||||
|
||||
return total_num_tokens, False, attn_all_rank_num_tokens
|
||||
|
||||
def _prepare_multimodal_indices(self, input_ids: list[int]):
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int, device="cpu")
|
||||
vocab_size = self.model.config.vocab_size
|
||||
# TODO: unify naming of mm_token_ids across models
|
||||
mm_token_ids = getattr(self.model, "mm_token_ids", None)
|
||||
|
||||
text_token_indices, mm_token_indices = filter_mm_token_from_input_ids(
|
||||
input_ids, vocab_size=vocab_size, mm_token_ids=mm_token_ids)
|
||||
return text_token_indices, mm_token_indices
|
||||
|
||||
def _prepare_tp_inputs(
|
||||
self,
|
||||
scheduled_requests: ScheduledRequests,
|
||||
@ -1335,6 +1346,14 @@ class PyTorchModelEngine(ModelEngine):
|
||||
|
||||
request.py_batch_idx = request.py_seq_slot
|
||||
|
||||
if len(multimodal_params_list) > 0:
|
||||
# discard the text token indices as it only includes context tokens at this moment
|
||||
print(
|
||||
f"len multimodal_params_list: {len(multimodal_params_list)} from model_engine"
|
||||
)
|
||||
_, mm_token_indices = self._prepare_multimodal_indices(input_ids)
|
||||
else:
|
||||
mm_token_indices = None
|
||||
num_ctx_requests = len(scheduled_requests.context_requests)
|
||||
num_ctx_tokens = len(input_ids)
|
||||
|
||||
@ -1698,6 +1717,14 @@ class PyTorchModelEngine(ModelEngine):
|
||||
spec_metadata.all_rank_num_tokens = spec_all_rank_num_tokens
|
||||
spec_metadata.all_rank_num_seqs = all_rank_num_seqs
|
||||
|
||||
if mm_token_indices is not None:
|
||||
mask = torch.ones(total_num_tokens, dtype=torch.bool)
|
||||
mask[mm_token_indices] = False
|
||||
inputs['mm_token_indices'] = mm_token_indices.pin_memory().to(
|
||||
"cuda", non_blocking=True)
|
||||
inputs['text_token_indices'] = torch.where(mask)[0].pin_memory().to(
|
||||
"cuda", non_blocking=True)
|
||||
|
||||
num_generation_tokens = len(generation_requests) + len(
|
||||
extend_requests) + sum(draft_lens)
|
||||
self.iter_states['num_ctx_requests'] = num_ctx_requests
|
||||
|
||||
@ -61,6 +61,9 @@ PROFILE_RECORD_GC_ENV_VAR_NAME = "TLLM_PROFILE_RECORD_GC"
|
||||
# Set to a path to save detailed tracing of PyTorch operations.
|
||||
PROFILE_TRACE_ENV_VAR_NAME = "TLLM_TORCH_PROFILE_TRACE"
|
||||
|
||||
# Unique tag base to avoid collisions with token/logits comms
|
||||
TERMINATION_COMM_TAG_BASE = 20000
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _load_iteration_indexes(env_var: str):
|
||||
@ -208,6 +211,7 @@ class PyExecutor:
|
||||
self.kv_cache_manager = self.resource_manager.resource_managers.get(
|
||||
ResourceManagerType.KV_CACHE_MANAGER)
|
||||
self.enable_kv_cache_events = self.kv_cache_manager is not None and self.kv_cache_manager.event_buffer_max_size > 0
|
||||
self.enable_kv_cache_reuse = self.kv_cache_manager is not None and self.kv_cache_manager.enable_block_reuse
|
||||
|
||||
self.max_input_len = max_input_len
|
||||
# _executor_loop private data
|
||||
@ -259,6 +263,13 @@ class PyExecutor:
|
||||
self.gather_all_responses = False
|
||||
|
||||
self.kv_cache_transceiver = kv_cache_transceiver
|
||||
|
||||
# Initialize disagg PP termination handler if needed
|
||||
self._disagg_pp_termination_handler = None
|
||||
if self.dist.pp_size > 1 and self.enable_kv_cache_reuse and self.kv_cache_transceiver:
|
||||
self._disagg_pp_termination_handler = DisaggPPTerminationHandler(
|
||||
self.num_micro_batches, self.dist)
|
||||
|
||||
if self.dist.pp_size > 1:
|
||||
self.event_loop = self._executor_loop_pp
|
||||
else:
|
||||
@ -718,6 +729,14 @@ class PyExecutor:
|
||||
batch_state.sample_state.scheduled_requests), req_stats)
|
||||
|
||||
def _executor_loop_cleanup(self):
|
||||
|
||||
for h in self.send_handles:
|
||||
if h is not None:
|
||||
h.wait()
|
||||
|
||||
if self._disagg_pp_termination_handler is not None:
|
||||
self._disagg_pp_termination_handler.cleanup()
|
||||
|
||||
with self.response_cv:
|
||||
self.is_shutdown = True
|
||||
self.response_cv.notify_all()
|
||||
@ -826,6 +845,7 @@ class PyExecutor:
|
||||
|
||||
sample_state = self._sample_async(
|
||||
scheduled_batch, batch_outputs)
|
||||
assert sample_state is not None, "Sampling failed"
|
||||
self._update_request_states(scheduled_batch)
|
||||
|
||||
if self.enable_iter_perf_stats:
|
||||
@ -905,6 +925,12 @@ class PyExecutor:
|
||||
if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
|
||||
self._terminate_ctx_finished_requests()
|
||||
|
||||
if self._disagg_pp_termination_handler is not None:
|
||||
requests_to_terminate = self._disagg_pp_termination_handler.sync(
|
||||
prev_microbatch_id)
|
||||
for req in requests_to_terminate:
|
||||
self._do_terminate_request(req)
|
||||
|
||||
# march forward in microbatch slots
|
||||
microbatch_id = (microbatch_id + 1) % self.num_micro_batches
|
||||
|
||||
@ -1696,9 +1722,13 @@ class PyExecutor:
|
||||
self._enqueue_responses(error_responses.items())
|
||||
|
||||
def _terminate_request(self, request: LlmRequest):
|
||||
if self.kv_connector_manager is None:
|
||||
self.resource_manager.free_resources(request)
|
||||
if self._disagg_pp_termination_handler is not None:
|
||||
self._disagg_pp_termination_handler.terminate(request)
|
||||
else:
|
||||
self._do_terminate_request(request)
|
||||
|
||||
def _do_terminate_request(self, request: LlmRequest):
|
||||
if self.kv_connector_manager is not None:
|
||||
# Only call request_finished on the connector if the request has already been added to the kv cache manager.
|
||||
try:
|
||||
cache_block_ids = self.kv_cache_manager.get_cache_indices(
|
||||
@ -1711,6 +1741,8 @@ class PyExecutor:
|
||||
if not self.kv_connector_manager.request_finished(
|
||||
request, cache_block_ids):
|
||||
self.resource_manager.free_resources(request)
|
||||
else:
|
||||
self.resource_manager.free_resources(request)
|
||||
|
||||
@nvtx_range("_handle_canceled_requests")
|
||||
def _handle_canceled_requests(self):
|
||||
@ -1919,3 +1951,104 @@ class PyExecutor:
|
||||
"""Remove reqids of current requests from self.inflight_req_ids."""
|
||||
for req in scheduled_requests.all_requests():
|
||||
self.inflight_req_ids.erase(req.request_id)
|
||||
|
||||
|
||||
class DisaggPPTerminationHandler:
|
||||
"""Handles termination synchronization across pipeline parallel ranks under disaggregated serving.
|
||||
|
||||
We require synchronization when terminating requests in disaggregated PP when
|
||||
KV cache reuse is enabled. All PP ranks need to reach consensus before freeing
|
||||
resources to avoid a NCCL hang.
|
||||
"""
|
||||
|
||||
def __init__(self, num_micro_batches: int, dist):
|
||||
self.dist = dist
|
||||
# Request termination synchronization across PP ranks
|
||||
# {request_id: {'ready_to_terminate': set{ranks}, 'terminated': {ranks}}}
|
||||
self.pending_termination = {}
|
||||
self.termination_handles = [None] * num_micro_batches
|
||||
# Local map from request_id -> local LlmRequest awaiting consensus termination
|
||||
self.local_termination = {}
|
||||
|
||||
def terminate(self, request: LlmRequest) -> bool:
|
||||
req_key = request.py_request_id
|
||||
self.local_termination[req_key] = request
|
||||
state = self.pending_termination.get(req_key, None)
|
||||
if state is None:
|
||||
state = {'ready_to_terminate': set(), 'terminated': set()}
|
||||
self.pending_termination[req_key] = state
|
||||
if self.dist.rank not in state['ready_to_terminate']:
|
||||
state['ready_to_terminate'].add(self.dist.rank)
|
||||
return False
|
||||
|
||||
def sync(self, microbatch_id: int) -> List[LlmRequest]:
|
||||
"""Ring-communicate pending termination state and apply local terminations upon consensus.
|
||||
|
||||
Each rank sends its current pending_termination snapshot to the next PP rank
|
||||
and receives the previous rank's snapshot. After merging, apply any terminations
|
||||
that have reached consensus (i.e., all PP ranks are ready).
|
||||
"""
|
||||
snapshot = {
|
||||
req_id: {
|
||||
'ready_to_terminate': state.get('ready_to_terminate', set()),
|
||||
'terminated': state.get('terminated', set()),
|
||||
}
|
||||
for req_id, state in self.pending_termination.items()
|
||||
}
|
||||
|
||||
if self.termination_handles[microbatch_id] is not None:
|
||||
self.termination_handles[microbatch_id].wait()
|
||||
|
||||
term_tag = TERMINATION_COMM_TAG_BASE + microbatch_id
|
||||
self.termination_handles[microbatch_id] = self.dist.isend_object(
|
||||
snapshot,
|
||||
dest=self.dist.next_pp_rank,
|
||||
tag=term_tag,
|
||||
)
|
||||
remote_state = self.dist.recv_object(
|
||||
src=self.dist.prev_pp_rank,
|
||||
tag=term_tag,
|
||||
)
|
||||
logger.debug(
|
||||
f"received remote state for microbatch {microbatch_id}, prev pp rank: {self.dist.prev_pp_rank} state {remote_state}"
|
||||
)
|
||||
|
||||
if remote_state:
|
||||
for req_id, state in remote_state.items():
|
||||
local = self.pending_termination.get(req_id)
|
||||
if local is None:
|
||||
self.pending_termination[req_id] = {
|
||||
'ready_to_terminate': state.get('ready_to_terminate',
|
||||
set()),
|
||||
'terminated': state.get('terminated', set()),
|
||||
}
|
||||
else:
|
||||
for key in ('ready_to_terminate', 'terminated'):
|
||||
for r in state.get(key, []):
|
||||
if r not in local[key]:
|
||||
local[key].add(r)
|
||||
|
||||
requests_to_terminate = []
|
||||
to_delete = []
|
||||
for req_id, state in self.pending_termination.items():
|
||||
ready = state.get('ready_to_terminate', set())
|
||||
done = state.get('terminated', set())
|
||||
# If all PP ranks are ready to terminate the request, we can free the resources
|
||||
if len(ready) >= self.dist.pp_size and self.dist.rank not in done:
|
||||
local_req = self.local_termination.get(req_id)
|
||||
if local_req is not None:
|
||||
requests_to_terminate.append(local_req)
|
||||
done.add(self.dist.rank)
|
||||
if len(done) >= self.dist.pp_size:
|
||||
to_delete.append(req_id)
|
||||
if req_id in self.local_termination:
|
||||
self.local_termination.pop(req_id, None)
|
||||
for req_id in to_delete:
|
||||
self.pending_termination.pop(req_id, None)
|
||||
|
||||
return requests_to_terminate
|
||||
|
||||
def cleanup(self):
|
||||
for h in self.termination_handles:
|
||||
if h is not None:
|
||||
h.wait()
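To make the consensus bookkeeping above concrete, here is a minimal single-process sketch with the ring communication stubbed out (the two-rank setup and the request id are made up for illustration; the dict layout matches `pending_termination`):

# Sketch only: the isend/recv snapshot exchange is replaced by a direct merge.
pp_size = 2
pending = {7: {'ready_to_terminate': set(), 'terminated': set()}}

# Both PP ranks have marked request 7 as ready (normally merged via the snapshots).
pending[7]['ready_to_terminate'].update({0, 1})

for rank in range(pp_size):
    state = pending[7]
    if len(state['ready_to_terminate']) >= pp_size and rank not in state['terminated']:
        # Consensus reached: this rank may now free the request's resources.
        state['terminated'].add(rank)

assert pending[7]['terminated'] == {0, 1}  # every rank terminated the request exactly once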
|
||||
|
||||
@ -1,12 +1,16 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Literal, Optional
|
||||
from functools import cached_property
|
||||
from typing import List, Literal, Optional, TypeAlias
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._torch.pyexecutor.make_decoding_batch_input_output import \
|
||||
MakeDecodingBatchInputOutput
|
||||
from tensorrt_llm._torch.pyexecutor.sampler_utils import (
|
||||
BEAM_0, SINGLE_BEAM_WIDTH, handle_stop_single_beam)
|
||||
from tensorrt_llm._utils import nvtx_range, torch_dtype_to_binding
|
||||
from tensorrt_llm.bindings import (CudaStream, DataType, ModelConfig,
|
||||
WorldConfig, make_sampling_config)
|
||||
@ -355,21 +359,55 @@ def int_tensor(shape: tuple[int, ...], device: str = 'cuda') -> torch.Tensor:
|
||||
return torch.empty(shape, dtype=torch.int, device=device)
|
||||
|
||||
|
||||
class TorchStore:
|
||||
|
||||
def __init__(self, *, max_draft_len: int, max_num_sequences: int,
|
||||
max_beam_width: int):
|
||||
self.max_draft_len = max_draft_len
|
||||
self.max_num_sequences = max_num_sequences
|
||||
self.max_beam_width = max_beam_width
|
||||
self.max_tokens = max_draft_len + 1
|
||||
assert max_beam_width == SINGLE_BEAM_WIDTH, "TorchSampler only supports beam_width = 1"
|
||||
self.new_tokens = int_tensor(
|
||||
(self.max_tokens, max_num_sequences, max_beam_width))
|
||||
"""Shape: See cpp DecoderState.getAllNewTokens()"""
|
||||
self.finish_reasons = int_tensor(self.new_tokens.shape)
|
||||
|
||||
# Helper tensors for finish_reasons:
|
||||
self._finish_reasons_nonzero_static_buffer = torch.empty(
|
||||
(self.max_tokens * max_num_sequences, 2),
|
||||
device='cuda',
|
||||
dtype=torch.int64)
|
||||
"""Preallocate buffer needed for torch.nonzero_static(..., out=finish_reasons_nonzero_static_buffer), see `def _write_reason`"""
|
||||
self._reason_tensors = {
|
||||
reason:
|
||||
torch.tensor(reason.value,
|
||||
dtype=self.finish_reasons.dtype,
|
||||
device="cuda")
|
||||
for reason in [
|
||||
FinishReason.NOT_FINISHED, FinishReason.END_ID,
|
||||
FinishReason.STOP_WORDS, FinishReason.LENGTH,
|
||||
FinishReason.TIMED_OUT, FinishReason.CANCELLED
|
||||
] # `in FinishReason` clashes with PyBind11: `TypeError: 'pybind11_type' object is not iterable`
|
||||
}
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class SampleStateTensorsHostTorch(SampleStateTensors):
|
||||
finish_reasons: torch.Tensor
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class SampleStateTorch(SampleState):
|
||||
host: SampleStateTensorsHostTorch
|
||||
|
||||
|
||||
class TorchSampler(Sampler):
|
||||
BEAM = 0
|
||||
MAX_BEAM_WIDTH = BEAM + 1
|
||||
SampleState = SampleStateTorch
|
||||
|
||||
def is_generation_model(self) -> bool:
|
||||
return True
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Store:
|
||||
new_tokens: torch.Tensor
|
||||
"""Shape: See cpp DecoderState.getAllNewTokens()"""
|
||||
|
||||
def create_store(self) -> Store:
|
||||
return self.Store(new_tokens=int_tensor(self.NEW_TOKENS_SHAPE))
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Args:
|
||||
max_seq_len: int
|
||||
@ -381,17 +419,16 @@ class TorchSampler(Sampler):
|
||||
def __init__(self, args: Args):
|
||||
self.max_seq_len = args.max_seq_len
|
||||
self.enable_mixed_sampler = args.enable_mixed_sampler
|
||||
self.max_tokens = args.max_draft_len + 1
|
||||
assert args.max_beam_width == self.MAX_BEAM_WIDTH, "TorchSampler only supports beam_width = 1"
|
||||
self.max_num_sequences = args.max_num_sequences
|
||||
|
||||
self.NEW_TOKENS_SHAPE = (self.max_tokens, self.max_num_sequences,
|
||||
self.MAX_BEAM_WIDTH)
|
||||
# AutoDeploy build creates the sampler in inference mode,
|
||||
# which would disallow in-place mutating of new_tokens.
|
||||
# So, we temporarily exit inference mode.
|
||||
with torch.inference_mode(False):
|
||||
self.store = self.create_store()
|
||||
self.store = TorchStore(max_draft_len=args.max_draft_len,
|
||||
max_num_sequences=args.max_num_sequences,
|
||||
max_beam_width=args.max_beam_width)
|
||||
self.max_num_sequences = args.max_num_sequences
|
||||
self.max_tokens = self.store.max_tokens
|
||||
|
||||
# Initialize seed for multi-GPU consistency
|
||||
self._global_seed = 42
|
||||
@ -412,50 +449,7 @@ class TorchSampler(Sampler):
|
||||
self._generator.manual_seed(self._global_seed)
|
||||
return self._generator
|
||||
|
||||
def _meet_max_token_stop_criteria(self, request: LlmRequest):
|
||||
num_tokens = request.get_num_tokens(self.BEAM)
|
||||
return (num_tokens - request.py_orig_prompt_len
|
||||
>= request.py_max_new_tokens) or (num_tokens
|
||||
>= self.max_seq_len)
|
||||
|
||||
@staticmethod
|
||||
def _meet_stop_token_criteria(request: LlmRequest):
|
||||
if request.py_stop_words_list:
|
||||
assert isinstance(
|
||||
request.py_stop_words_list,
|
||||
list), "request.py_stop_words_list should be a list"
|
||||
stop_words_list, prefix_sum = request.py_stop_words_list
|
||||
tokens = request.get_tokens(0)
|
||||
offset = 0
|
||||
for i, offset_end in enumerate(prefix_sum):
|
||||
if i > 0:
|
||||
offset = prefix_sum[i - 1]
|
||||
stop_word = stop_words_list[offset:offset_end]
|
||||
if len(stop_word) > len(tokens):
|
||||
continue
|
||||
if tokens[-len(stop_word):] == stop_word:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _handle_stop_criteria(self, request: LlmRequest,
|
||||
new_token: int) -> bool:
|
||||
"""Handle stop criteria and set appropriate finish reasons and state.
|
||||
Returns True if generation should stop."""
|
||||
if new_token == request.py_end_id:
|
||||
request.finish_by(FinishReason.END_ID, self.BEAM)
|
||||
return True
|
||||
|
||||
if self._meet_max_token_stop_criteria(request):
|
||||
request.finish_by(FinishReason.LENGTH, self.BEAM)
|
||||
return True
|
||||
|
||||
if self._meet_stop_token_criteria(request):
|
||||
request.finish_by(FinishReason.STOP_WORDS, self.BEAM)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def handle_logprobs(self, request: LlmRequest, state: SampleState, *,
|
||||
def handle_logprobs(self, request: LlmRequest, state: SampleStateTorch, *,
|
||||
beam: int, count: int):
|
||||
current_slice = slice(0, count), request.py_seq_slot, beam
|
||||
if request.py_return_log_probs:
|
||||
@ -469,10 +463,26 @@ class TorchSampler(Sampler):
|
||||
assert beam == 0, "The following call relies on beam_width to be 1 - hence the list with a single element"
|
||||
request.py_result.append_log_probs([token_log_probs])
|
||||
|
||||
def _process_draft_tokens_greedy(self, request: LlmRequest,
|
||||
new_tokens: torch.Tensor) -> int:
|
||||
new_token = add_token(request, new_tokens, beam=self.BEAM)
|
||||
stop = self._handle_stop_criteria(request, new_token)
|
||||
FinishReasons: TypeAlias = list[list[int]]
|
||||
"""`(num_seq_slots, num_steps)`"""
|
||||
|
||||
@classmethod
|
||||
def finish_if_reason(cls, request: LlmRequest,
|
||||
finish_reasons: FinishReasons, *, step: int) -> bool:
|
||||
reason = FinishReason(finish_reasons[request.py_seq_slot][step])
|
||||
valid_reasons = {
|
||||
FinishReason.END_ID, FinishReason.LENGTH, FinishReason.STOP_WORDS
|
||||
}
|
||||
if reason in valid_reasons:
|
||||
request.finish_by(reason, BEAM_0)
|
||||
return True
|
||||
return False
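For reference, a tiny illustration of the host-side layout `finish_if_reason` indexes into (the values are made up; `FinishReason` is the real binding imported in sampler_utils.py below):

from tensorrt_llm.bindings.executor import FinishReason

# finish_reasons[seq_slot][step], as produced in update_requests below.
finish_reasons = [[FinishReason.NOT_FINISHED.value, FinishReason.END_ID.value]]

# The request in seq slot 0 hit its end id at draft step 1.
assert FinishReason(finish_reasons[0][1]) == FinishReason.END_ID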
|
||||
|
||||
def _process_draft_tokens_greedy(self, request: LlmRequest, *,
|
||||
new_tokens: torch.Tensor,
|
||||
finish_reasons: FinishReasons) -> int:
|
||||
new_token = add_token(request, new_tokens, beam=BEAM_0)
|
||||
stop = self.finish_if_reason(request, finish_reasons, step=0)
|
||||
if stop or get_draft_token_length(request) == 0:
|
||||
return 0
|
||||
num_accepted = 0
|
||||
@ -485,14 +495,17 @@ class TorchSampler(Sampler):
|
||||
num_accepted += 1
|
||||
new_token = add_token(request,
|
||||
new_tokens,
|
||||
beam=self.BEAM,
|
||||
beam=BEAM_0,
|
||||
step=num_accepted)
|
||||
if self._handle_stop_criteria(request, new_token):
|
||||
if self.finish_if_reason(request, finish_reasons,
|
||||
step=num_accepted):
|
||||
break
|
||||
return num_accepted
|
||||
|
||||
def _process_draft_tokens_rejection_sampling(
|
||||
self, request: LlmRequest, new_tokens: torch.Tensor) -> int:
|
||||
"""We cannot use finish_if_reason in _process_draft_tokens_rejection_sampling because it *writes to new_tokens*,
|
||||
rendering the finish reason calculation in sample_async stale (incorrect) for this batch"""
|
||||
sampling_strategy = request_strategy(request)
|
||||
generator = self.get_generator(request.py_draft_logits.device)
|
||||
_, draft_probs = sample(sampling_strategy,
|
||||
@ -503,7 +516,6 @@ class TorchSampler(Sampler):
|
||||
generator,
|
||||
request.py_draft_tokens)
|
||||
sample_last = True
|
||||
stop = False
|
||||
if rejected_indices.numel() == 0:
|
||||
num_initially_accepted = get_draft_token_length(request)
|
||||
sample_last = False
|
||||
@ -512,59 +524,62 @@ class TorchSampler(Sampler):
|
||||
num_accepted = num_initially_accepted
|
||||
for i in range(num_accepted):
|
||||
new_token = request.py_draft_tokens[i]
|
||||
new_tokens[i, request.seq_slot, self.BEAM] = new_token
|
||||
request.add_new_token(new_token, self.BEAM)
|
||||
stop = self._handle_stop_criteria(request, new_token)
|
||||
if stop:
|
||||
new_tokens[i, request.seq_slot, BEAM_0] = new_token
|
||||
request.add_new_token(new_token, BEAM_0)
|
||||
if handle_stop_single_beam(request,
|
||||
new_token,
|
||||
max_seq_len=self.max_seq_len):
|
||||
num_accepted = i + 1
|
||||
return num_accepted
|
||||
if sample_last:
|
||||
new_token = sample_rejected(draft_probs, target_probs, generator,
|
||||
num_accepted)
|
||||
new_tokens[num_accepted, request.seq_slot, self.BEAM] = new_token
|
||||
request.add_new_token(new_token, self.BEAM)
|
||||
stop = self._handle_stop_criteria(request, new_token)
|
||||
new_tokens[num_accepted, request.seq_slot, BEAM_0] = new_token
|
||||
request.add_new_token(new_token, BEAM_0)
|
||||
handle_stop_single_beam(request,
|
||||
new_token,
|
||||
max_seq_len=self.max_seq_len)
|
||||
else:
|
||||
new_token = add_token(request,
|
||||
new_tokens,
|
||||
beam=self.BEAM,
|
||||
beam=BEAM_0,
|
||||
step=num_accepted)
|
||||
stop = self._handle_stop_criteria(request, new_token)
|
||||
handle_stop_single_beam(request,
|
||||
new_token,
|
||||
max_seq_len=self.max_seq_len)
|
||||
|
||||
return num_accepted
|
||||
|
||||
def process_draft_tokens(self, request: LlmRequest,
|
||||
new_tokens: torch.Tensor) -> int:
|
||||
if request.py_draft_logits is None:
|
||||
return self._process_draft_tokens_greedy(request, new_tokens)
|
||||
else:
|
||||
return self._process_draft_tokens_rejection_sampling(
|
||||
request, new_tokens)
|
||||
|
||||
def update_requests(self, state: SampleState) -> None:
|
||||
assert isinstance(state, SampleState)
|
||||
def update_requests(self, state: SampleStateTorch) -> None:
|
||||
assert isinstance(state, SampleStateTorch)
|
||||
if state.sampler_event:
|
||||
state.sampler_event.synchronize()
|
||||
new_tokens = state.host.new_tokens
|
||||
finish_reasons = state.host.finish_reasons[:, :, BEAM_0].T.tolist()
|
||||
|
||||
for req in state.scheduled_requests.context_requests:
|
||||
if req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0:
|
||||
continue
|
||||
new_token = add_token(req, new_tokens, beam=self.BEAM)
|
||||
self._handle_stop_criteria(req, new_token)
|
||||
self.handle_logprobs(req, state, beam=self.BEAM, count=1)
|
||||
add_token(req, new_tokens, beam=BEAM_0)
|
||||
self.finish_if_reason(req, finish_reasons, step=0)
|
||||
self.handle_logprobs(req, state, beam=BEAM_0, count=1)
|
||||
req.py_decoding_iter += 1
|
||||
|
||||
for req in state.scheduled_requests.generation_requests:
|
||||
if req.state == LlmRequestState.GENERATION_COMPLETE:
|
||||
continue
|
||||
processed = 1
|
||||
num_accepted = self.process_draft_tokens(req, new_tokens)
|
||||
if req.py_draft_logits is None:
|
||||
num_accepted = self._process_draft_tokens_greedy(
|
||||
req, new_tokens=new_tokens, finish_reasons=finish_reasons)
|
||||
else:
|
||||
num_accepted = self._process_draft_tokens_rejection_sampling(
|
||||
req, new_tokens)
|
||||
if get_draft_token_length(req) > 0:
|
||||
req.py_num_accepted_draft_tokens = num_accepted
|
||||
req.py_rewind_len = req.py_draft_pages_allocated - num_accepted
|
||||
processed += num_accepted
|
||||
self.handle_logprobs(req, state, beam=self.BEAM, count=processed)
|
||||
self.handle_logprobs(req, state, beam=BEAM_0, count=processed)
|
||||
req.py_decoding_iter += 1
|
||||
|
||||
def log_probs_host(self, scheduled_requests: ScheduledRequests):
|
||||
@ -572,29 +587,50 @@ class TorchSampler(Sampler):
|
||||
if any(req.py_return_log_probs
|
||||
for req in scheduled_requests.all_requests()):
|
||||
return torch.empty(
|
||||
(self.max_num_sequences, self.MAX_BEAM_WIDTH, self.max_tokens),
|
||||
(self.max_num_sequences, SINGLE_BEAM_WIDTH, self.max_tokens),
|
||||
device="cpu",
|
||||
pin_memory=True)
|
||||
return None
|
||||
|
||||
def sample_async(self, scheduled_requests: ScheduledRequests,
|
||||
model_outputs: dict[str, torch.Tensor],
|
||||
num_context_logits_prefix_sum: list[int]) -> SampleState:
|
||||
def sample_async(
|
||||
self, scheduled_requests: ScheduledRequests,
|
||||
model_outputs: dict[str, torch.Tensor],
|
||||
num_context_logits_prefix_sum: list[int]) -> SampleStateTorch:
|
||||
requests = scheduled_requests.all_requests()
|
||||
new_tokens = self.store.new_tokens
|
||||
finish_reasons = self.store.finish_reasons
|
||||
log_probs_host = self.log_probs_host(scheduled_requests)
|
||||
seq_slots_host = torch.tensor(
|
||||
[r.py_seq_slot for r in requests],
|
||||
dtype=torch.int64, # for index_fill_
|
||||
pin_memory=True)
|
||||
seq_slots = seq_slots_host.to(device="cuda", non_blocking=True)
|
||||
self._process_requests(scheduled_requests,
|
||||
model_outputs,
|
||||
new_tokens,
|
||||
num_context_logits_prefix_sum,
|
||||
seq_slots=seq_slots,
|
||||
seq_slots_host=seq_slots_host,
|
||||
log_probs_host=log_probs_host)
|
||||
self._write_finish_reasons(requests,
|
||||
finish_reasons=finish_reasons,
|
||||
seq_slots=seq_slots,
|
||||
new_tokens=new_tokens)
|
||||
|
||||
new_tokens_host = new_tokens.to(device="cpu", non_blocking=True)
|
||||
finish_reasons_host = finish_reasons.to(device="cpu", non_blocking=True)
|
||||
sampler_event = torch.cuda.Event()
|
||||
sampler_event.record()
|
||||
return SampleState(scheduled_requests=scheduled_requests,
|
||||
device=SampleStateTensors(new_tokens=new_tokens),
|
||||
host=SampleStateTensors(new_tokens=new_tokens_host,
|
||||
log_probs=log_probs_host),
|
||||
sampler_event=sampler_event)
|
||||
return SampleStateTorch(
|
||||
scheduled_requests=scheduled_requests,
|
||||
device=SampleStateTensors(new_tokens=new_tokens),
|
||||
host=SampleStateTensorsHostTorch(
|
||||
new_tokens=new_tokens_host,
|
||||
log_probs=log_probs_host,
|
||||
finish_reasons=finish_reasons_host,
|
||||
),
|
||||
sampler_event=sampler_event,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def append_eagle3(tokens: torch.Tensor, model_outputs):
|
||||
@ -654,15 +690,202 @@ class TorchSampler(Sampler):
|
||||
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def _longest_stop_word_len(requests: Iterable[LlmRequest]) -> int:
|
||||
max_stop_word_len = 0
|
||||
for req in requests:
|
||||
_, cumsum = req.py_stop_words_list
|
||||
if -1 in cumsum:
|
||||
cumsum = cumsum[:cumsum.index(-1)]
|
||||
request_max_stop_word_len = np.max(np.diff(cumsum, prepend=0),
|
||||
initial=0)
|
||||
max_stop_word_len = max(max_stop_word_len,
|
||||
request_max_stop_word_len)
|
||||
return max_stop_word_len
|
||||
|
||||
@staticmethod
|
||||
def _requests_with_stop_words(
|
||||
requests: list[LlmRequest]) -> list[LlmRequest]:
|
||||
return [
|
||||
r for r in requests if (r.py_stop_words_list is not None
|
||||
and len(r.py_stop_words_list[0]) > 0)
|
||||
]
|
||||
|
||||
def _write_reason(self, finish_reasons: torch.Tensor, reason: FinishReason,
|
||||
*, where: torch.Tensor, seq_slots: torch.Tensor) -> None:
|
||||
"""Avoid GPU<->CPU syncs via:
|
||||
### `nonzero_static` [REF-A], see: https://ianbarber.blog/2024/12/18/nonzero_static-in-pytorch/.
|
||||
- `nonzero` syncs (frontend needs result size).
|
||||
- `nonzero_static` pads with dummy entries (`fill_value`), written into a preallocated buffer of shape (max_tokens * max_num_sequences, 2).
|
||||
- Need to drop padding, but `buffer[buffer!=fill_value]`, `buffer[:count_nonzero]`, `buffer[:sum]` all sync.
|
||||
|
||||
### Hack:
|
||||
1. Use `fill_value=0`, so padding is `[..., [0,0], [0,0]]`.
|
||||
2. Write blindly to `finish_reasons` [REF-B]. Only `[seq_slot[0],0]` might have wrong values written to it, because of the padding entries.
|
||||
3. Save `[seq_slot[0],0]` in `before_write` [REF-C], restore if `where[0][0]` is `False` [REF-D].
|
||||
"""
|
||||
assert seq_slots.is_cuda and where.is_cuda
|
||||
assert seq_slots.shape[0] == where.shape[1]
|
||||
first_slot = seq_slots[0].unsqueeze(0)
|
||||
before_write = finish_reasons[0][:].index_select(
|
||||
0, first_slot).squeeze() # REF-C
|
||||
reason_tensor = self.store._reason_tensors[reason]
|
||||
buffer = self.store._finish_reasons_nonzero_static_buffer
|
||||
size = buffer.shape[0]
|
||||
torch.nonzero_static(where, size=size, fill_value=0,
|
||||
out=buffer) # REF-A
|
||||
r, c = buffer[:, 0], buffer[:, 1]
|
||||
finish_reasons[r, seq_slots[c], BEAM_0] = reason_tensor # REF-B
|
||||
|
||||
correct = torch.where(~where[0, 0], before_write, reason_tensor).view(1)
|
||||
assert correct.is_cuda
|
||||
finish_reasons[0, first_slot, BEAM_0] = correct # REF-D
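A toy CPU illustration of the padding behaviour described above (sizes are made up; depending on the PyTorch version, `nonzero_static` may be limited to certain device types):

import torch

# 2 steps x 3 requests; only (step 1, request 2) has finished.
where = torch.tensor([[False, False, False],
                      [False, False, True]])
buffer = torch.nonzero_static(where, size=where.numel(), fill_value=0)
# buffer == [[1, 2], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]
# Every padding row points at (0, 0), so a blind scatter also touches
# (step 0, first seq slot); REF-C snapshots that cell and REF-D restores it
# whenever where[0, 0] is actually False.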
|
||||
|
||||
def _write_finish_reasons(self, requests: list[LlmRequest], *,
|
||||
finish_reasons: torch.Tensor,
|
||||
seq_slots: torch.Tensor,
|
||||
new_tokens: torch.Tensor) -> None:
|
||||
"""later _write_reason overwrites earlier, in reverse precedence order"""
|
||||
tokens = new_tokens[:, seq_slots, BEAM_0]
|
||||
# Fill with NOT_FINISHED so stale values from previous requests that used the same seq slot are not misread
|
||||
finish_reasons.index_fill_(1, seq_slots,
|
||||
FinishReason.NOT_FINISHED.value)
|
||||
|
||||
if with_stop_words := self._requests_with_stop_words(requests):
|
||||
stop_seq_slots = torch.tensor(
|
||||
[r.py_seq_slot for r in with_stop_words],
|
||||
pin_memory=True).to("cuda", non_blocking=True)
|
||||
stop_tokens = new_tokens[:, stop_seq_slots, BEAM_0]
|
||||
self._write_reason(
|
||||
finish_reasons,
|
||||
FinishReason.STOP_WORDS,
|
||||
where=self._are_stop_words(with_stop_words, stop_tokens),
|
||||
seq_slots=stop_seq_slots,
|
||||
)
|
||||
|
||||
self._write_reason(
|
||||
finish_reasons,
|
||||
FinishReason.LENGTH,
|
||||
where=self._are_max_length(requests),
|
||||
seq_slots=seq_slots,
|
||||
)
|
||||
|
||||
self._write_reason(
|
||||
finish_reasons,
|
||||
FinishReason.END_ID,
|
||||
where=self._are_end_id(requests, tokens),
|
||||
seq_slots=seq_slots,
|
||||
)
|
||||
|
||||
def _are_end_id(self, requests: list[LlmRequest],
|
||||
tokens: torch.Tensor) -> torch.Tensor:
|
||||
end_ids_tensor = torch.tensor(
|
||||
[([req.py_end_id if req.py_end_id is not None else -1] *
|
||||
self.max_tokens) for req in requests],
|
||||
pin_memory=True,
|
||||
dtype=tokens.dtype).T.to(device="cuda", non_blocking=True)
|
||||
return tokens == end_ids_tensor
|
||||
|
||||
def _are_max_length(self, requests: list[LlmRequest]) -> torch.Tensor:
|
||||
lengths_tensor = torch.tensor([[
|
||||
((req.get_num_tokens(BEAM_0) + num_tokens) - req.py_orig_prompt_len)
|
||||
for num_tokens in range(1, self.max_tokens + 1)
|
||||
] for req in requests])
|
||||
max_lengths_tensor = torch.tensor([
|
||||
([min(req.py_max_new_tokens, self.max_seq_len)] * self.max_tokens)
|
||||
for req in requests
|
||||
])
|
||||
return (lengths_tensor
|
||||
>= max_lengths_tensor).T.pin_memory().to(device="cuda",
|
||||
non_blocking=True)
|
||||
|
||||
_PAD_ID = -1
|
||||
"""Pad with negative, doesn't matter what"""
|
||||
|
||||
@cached_property
|
||||
def _pad_steps_mask(self):
|
||||
square = torch.ones(self.max_tokens, self.max_tokens, dtype=torch.bool)
|
||||
pad_id = torch.tensor(self._PAD_ID)
|
||||
mask = torch.where(square.tril(), torch.tensor(1), pad_id)
|
||||
mask.pin_memory()
|
||||
return mask.to("cuda", non_blocking=True)
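A small CPU sketch of what the mask above looks like for max_tokens = 3 (toy size, no CUDA transfer):

import torch

PAD_ID = -1
max_tokens = 3
square = torch.ones(max_tokens, max_tokens, dtype=torch.bool)
mask = torch.where(square.tril(), torch.tensor(1), torch.tensor(PAD_ID))
# mask == [[ 1, -1, -1],
#          [ 1,  1, -1],
#          [ 1,  1,  1]]
# Multiplying the per-step token rows by this mask keeps tokens 0..k for step k
# and flips the sign of not-yet-generated positions, which are later dropped
# by the `step_seq >= 0` filter in _are_stop_words.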
|
||||
|
||||
def _padded_old_tokens(self,
|
||||
requests: list[LlmRequest],
|
||||
new_tokens: torch.Tensor,
|
||||
pad_id: int = _PAD_ID) -> torch.Tensor:
|
||||
# TODO: make sure only the lookback tokens are pulled into the list
|
||||
longest = self._longest_stop_word_len(requests)
|
||||
assert longest > 0, f"{longest=}, longest stop word length should be greater than 0, as this code path is only reached with requests with stop words"
|
||||
lookback = longest - 1
|
||||
old_tokens = []
|
||||
for request in requests:
|
||||
old = request.get_tokens(BEAM_0)[-lookback:] if lookback > 0 else []
|
||||
padded = [pad_id] * max(0, lookback - len(old)) + old
|
||||
old_tokens.append([padded] * self.max_tokens)
|
||||
old_tokens_tensor = torch.tensor(old_tokens,
|
||||
pin_memory=True).to("cuda",
|
||||
non_blocking=True)
|
||||
assert old_tokens_tensor.shape == (
|
||||
len(requests), self.max_tokens, lookback
|
||||
), f"{old_tokens_tensor.shape} != ({len(requests)=}, {self.max_tokens=}, {lookback=})"
|
||||
new_tokens = new_tokens.T.unsqueeze(1) * self._pad_steps_mask
|
||||
ret = torch.cat((old_tokens_tensor, new_tokens), dim=-1)
|
||||
assert ret.shape == (
|
||||
len(requests), self.max_tokens, lookback + self.max_tokens
|
||||
), f"{ret.shape} != ({len(requests)=}, {self.max_tokens=}, {lookback + self.max_tokens=})"
|
||||
return ret
|
||||
|
||||
def _are_stop_words(self, requests: list[LlmRequest],
|
||||
tokens: torch.Tensor) -> torch.Tensor:
|
||||
per_step = torch.zeros((self.max_tokens, len(requests)),
|
||||
dtype=torch.bool,
|
||||
pin_memory=True).to("cuda", non_blocking=True)
|
||||
|
||||
padded_tokens = self._padded_old_tokens(requests, tokens)
|
||||
|
||||
def request_stop_words(request: LlmRequest, new_tokens: torch.Tensor):
|
||||
swl, ends = request.py_stop_words_list
|
||||
if -1 in ends:
|
||||
ends = ends[:ends.index(-1)]
|
||||
lens = np.diff(ends, prepend=0)
|
||||
lens_device = torch.tensor(list(lens),
|
||||
pin_memory=True).to("cuda",
|
||||
non_blocking=True)
|
||||
max_len = np.max(lens)
|
||||
|
||||
words = torch.zeros(len(lens),
|
||||
max_len,
|
||||
dtype=torch.int32,
|
||||
pin_memory=True)
|
||||
for step, (start, l) in enumerate(zip([0] + ends, lens)):
|
||||
words[step, :l] = torch.tensor(swl[start:start + l],
|
||||
dtype=torch.int32)
|
||||
words_device = words.to("cuda", non_blocking=True)
|
||||
|
||||
for step, step_seq in enumerate(new_tokens):
|
||||
for word, L in zip(words_device, lens_device):
|
||||
truncated_seq = step_seq[step_seq >= 0][-L:]
|
||||
if torch.equal(truncated_seq, word[-L:]):
|
||||
# We don't care about subsequent steps because we already found a stop word match
|
||||
return step
|
||||
return None
|
||||
|
||||
for request_idx, request in enumerate(requests):
|
||||
step = request_stop_words(request, padded_tokens[request_idx])
|
||||
if step is not None:
|
||||
per_step[step][request_idx] = True
|
||||
return per_step
|
||||
|
||||
def _process_requests(self,
|
||||
scheduled_requests: ScheduledRequests,
|
||||
model_outputs: dict[str, torch.Tensor],
|
||||
new_tokens: torch.Tensor,
|
||||
num_context_logits_prefix_sum: list[int],
|
||||
*,
|
||||
seq_slots: torch.Tensor,
|
||||
seq_slots_host: torch.Tensor,
|
||||
log_probs_host: torch.Tensor | None = None):
|
||||
beam_width = self.MAX_BEAM_WIDTH
|
||||
beam = self.BEAM
|
||||
|
||||
# raw_logits should contain only the logits from the gen requests.
|
||||
# If return context logits is requested, fetch only the logits from gen requests.
|
||||
@ -687,16 +910,13 @@ class TorchSampler(Sampler):
|
||||
no_draft_tokens = len(requests) == sum_steps
|
||||
fast_path = not self.enable_mixed_sampler and no_draft_tokens and log_probs_host is None
|
||||
|
||||
seq_slots_host = torch.as_tensor([r.py_seq_slot for r in requests])
|
||||
seq_slots = seq_slots_host.to(device="cuda", non_blocking=True)
|
||||
|
||||
if fast_path:
|
||||
logits = raw_logits[:len(requests)]
|
||||
logits = self._apply_embedding_bias(logits, requests)
|
||||
next_tokens = torch.argmax(logits, dim=-1)
|
||||
self.append_eagle3(next_tokens, model_outputs)
|
||||
int_next_tokens = next_tokens.to(torch.int, non_blocking=True)
|
||||
next_tokens = int_next_tokens.view(1, -1, beam_width)
|
||||
next_tokens = int_next_tokens.view(1, -1, SINGLE_BEAM_WIDTH)
|
||||
new_tokens[:1].index_copy_(1, seq_slots, next_tokens)
|
||||
return
|
||||
|
||||
@ -737,17 +957,17 @@ class TorchSampler(Sampler):
|
||||
# Batched processing already applied bias, just use the results
|
||||
next_tokens = batched_next_tokens[input_slice]
|
||||
softmax = batched_softmax[input_slice]
|
||||
current_slice = slice(0, steps), slot, beam
|
||||
current_slice = slice(0, steps), slot, BEAM_0
|
||||
new_tokens[current_slice] = next_tokens
|
||||
if request.py_draft_logits is not None:
|
||||
request.py_target_probs = softmax.clone()
|
||||
if log_probs_host is not None:
|
||||
assert beam == 0, "The following call relies on beam_width to be 1 - hence the unsqueeze"
|
||||
assert BEAM_0 == 0, "The following call relies on beam_width to be 1 - hence the unsqueeze"
|
||||
token_probs = torch.gather(
|
||||
softmax, dim=1, index=next_tokens.unsqueeze(1)).squeeze(-1)
|
||||
log_probs = torch.log(token_probs)
|
||||
log_probs_host[slot, beam, :steps].copy_(log_probs,
|
||||
non_blocking=True)
|
||||
log_probs_host[slot, BEAM_0, :steps].copy_(log_probs,
|
||||
non_blocking=True)
|
||||
offset += steps
|
||||
|
||||
|
||||
|
||||
61 tensorrt_llm/_torch/pyexecutor/sampler_utils.py Normal file
@ -0,0 +1,61 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from tensorrt_llm.bindings.executor import FinishReason
|
||||
|
||||
from .llm_request import LlmRequest, produce_stop_words
|
||||
|
||||
BEAM_0 = 0
|
||||
SINGLE_BEAM_WIDTH = 1
|
||||
|
||||
|
||||
def max_token_criteria_single_beam(request: LlmRequest,
|
||||
max_seq_len: int) -> bool:
|
||||
num_tokens = request.get_num_tokens(BEAM_0)
|
||||
return (num_tokens - request.py_orig_prompt_len
|
||||
>= request.py_max_new_tokens) or (num_tokens >= max_seq_len)
|
||||
|
||||
|
||||
def stop_token_criteria(py_stop_words_list: list[list[int]] | None,
|
||||
tokens: list[int]) -> bool:
|
||||
if py_stop_words_list:
|
||||
assert isinstance(py_stop_words_list,
|
||||
list), "request.py_stop_words_list should be a list"
|
||||
for stop_word in produce_stop_words(py_stop_words_list):
|
||||
if len(stop_word) > len(tokens):
|
||||
continue
|
||||
if tokens[-len(stop_word):] == stop_word:
|
||||
return True
|
||||
return False
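A hedged example of the flattened stop-word layout these helpers consume, inferred from how `py_stop_words_list` is unpacked here and in `_longest_stop_word_len` (exact producer semantics may differ):

# Two stop words, [42] and [7, 8], flattened into (tokens, prefix_sum).
flat_stop_words = ([42, 7, 8], [1, 3])

def matches_any_stop_word(tokens: list[int]) -> bool:
    words, ends = flat_stop_words
    start = 0
    for end in ends:
        if end < 0:  # optional -1 padding terminates the offsets
            break
        stop_word = words[start:end]
        start = end
        if len(stop_word) <= len(tokens) and tokens[-len(stop_word):] == stop_word:
            return True
    return False

assert matches_any_stop_word([5, 7, 8]) is True   # ends with [7, 8]
assert matches_any_stop_word([5, 7]) is False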
|
||||
|
||||
|
||||
def handle_stop_single_beam(request: LlmRequest, new_token: int, *,
|
||||
max_seq_len: int) -> bool:
|
||||
"""Handle stop criteria and set appropriate finish reasons and state.
|
||||
Returns True if generation should stop."""
|
||||
if new_token == request.py_end_id:
|
||||
request.finish_by(FinishReason.END_ID, BEAM_0)
|
||||
return True
|
||||
|
||||
if max_token_criteria_single_beam(request, max_seq_len):
|
||||
request.finish_by(FinishReason.LENGTH, BEAM_0)
|
||||
return True
|
||||
|
||||
if stop_token_criteria(request.py_stop_words_list,
|
||||
request.get_tokens(BEAM_0)):
|
||||
request.finish_by(FinishReason.STOP_WORDS, BEAM_0)
|
||||
return True
|
||||
|
||||
return False
|
||||
@ -20,9 +20,12 @@ from tensorrt_llm._torch.speculative.interface import SpecMetadata
|
||||
@contextmanager
|
||||
def save_metadata_state(attn_metadata: AttentionMetadata,
|
||||
spec_metadata: SpecMetadata) -> None:
|
||||
attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda",
|
||||
"kv_lens_cuda")
|
||||
attn_metadata.prepare_for_spec_dec("_seq_lens", "_seq_lens_cuda")
|
||||
batch_size = attn_metadata.num_seqs
|
||||
# Do not use prepare_for_spec_dec for this special field.
|
||||
# TRTLLM attention uses views of this tensor internally and prepare_for_spec_dec
|
||||
# creates a copy. If you write to the copy, TRTLLM attention won't see the updates.
|
||||
kv_lens = attn_metadata.kv_lens_cuda[:batch_size].clone()
|
||||
|
||||
if attn_metadata.is_cuda_graph:
|
||||
assert spec_metadata.is_cuda_graph
|
||||
@ -39,6 +42,8 @@ def save_metadata_state(attn_metadata: AttentionMetadata,
|
||||
yield
|
||||
finally:
|
||||
attn_metadata.restore_from_spec_dec()
|
||||
attn_metadata.kv_lens_cuda[:batch_size].copy_(kv_lens)
|
||||
|
||||
if attn_metadata.is_cuda_graph:
|
||||
spec_metadata.num_tokens = num_tokens
|
||||
if isinstance(spec_metadata, Eagle3SpecMetadata):
|
||||
|
||||
@ -10,8 +10,11 @@ from ..model_config import ModelConfig
|
||||
from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
|
||||
from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
|
||||
from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
|
||||
from ..pyexecutor.sampler import (SampleState, SampleStateTensors, TorchSampler,
|
||||
add_token, int_tensor)
|
||||
from ..pyexecutor.sampler import (Sampler, SampleState, SampleStateTensors,
|
||||
TorchSampler, TorchStore, add_token,
|
||||
int_tensor)
|
||||
from ..pyexecutor.sampler_utils import (BEAM_0, SINGLE_BEAM_WIDTH,
|
||||
handle_stop_single_beam)
|
||||
from ..pyexecutor.scheduler import ScheduledRequests
|
||||
from .interface import SpecMetadata
|
||||
|
||||
@ -206,40 +209,44 @@ class MTPSpecMetadata(SpecMetadata):
|
||||
self.slot_ids[:num_seqs].copy_(mtp_slot_ids, non_blocking=True)
|
||||
|
||||
|
||||
class MTPSampler(TorchSampler):
|
||||
class MTPStore(TorchStore):
|
||||
|
||||
def __init__(self, *, max_draft_len: int, max_num_sequences: int,
|
||||
max_beam_width: int):
|
||||
super().__init__(max_draft_len=max_draft_len,
|
||||
max_num_sequences=max_num_sequences,
|
||||
max_beam_width=max_beam_width)
|
||||
self.next_new_tokens = int_tensor(
|
||||
(self.max_tokens, self.max_num_sequences, SINGLE_BEAM_WIDTH))
|
||||
self.next_draft_tokens = int_tensor(
|
||||
(self.max_num_sequences, self.max_draft_len))
|
||||
self.new_tokens_lens = int_tensor((self.max_num_sequences, ))
|
||||
|
||||
|
||||
class MTPSampler(Sampler):
|
||||
"""
|
||||
MTP sampler.
|
||||
"""
|
||||
|
||||
SampleState = SampleStateMTP
|
||||
|
||||
def is_generation_model(self) -> bool:
|
||||
return True
|
||||
|
||||
def __init__(self, args: TorchSampler.Args, *, nextn: int):
|
||||
self.mapping = None
|
||||
self.draft_len = nextn
|
||||
super().__init__(args)
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Store(TorchSampler.Store):
|
||||
next_new_tokens: torch.Tensor
|
||||
next_draft_tokens: torch.Tensor
|
||||
new_tokens_lens: torch.Tensor
|
||||
|
||||
def create_store(self) -> Store:
|
||||
num_tokens, seq_slots, _ = self.NEW_TOKENS_SHAPE
|
||||
draft_len = num_tokens - 1
|
||||
assert draft_len == self.draft_len
|
||||
return self.Store(
|
||||
new_tokens=int_tensor(self.NEW_TOKENS_SHAPE),
|
||||
next_new_tokens=int_tensor(self.NEW_TOKENS_SHAPE),
|
||||
next_draft_tokens=int_tensor((seq_slots, draft_len)),
|
||||
new_tokens_lens=int_tensor((seq_slots, )),
|
||||
)
|
||||
self.store = MTPStore(max_draft_len=nextn,
|
||||
max_num_sequences=args.max_num_sequences,
|
||||
max_beam_width=args.max_beam_width)
|
||||
self.max_seq_len = args.max_seq_len
|
||||
|
||||
def _request_common_handling(self, request: LlmRequest,
|
||||
next_draft_tokens: list[list[int]]):
|
||||
assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler"
|
||||
assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler"
|
||||
assert not request.py_return_log_probs, "return_log_probs not implemented for MTPSampler"
|
||||
assert request.py_seq_slot is not None
|
||||
request.py_draft_tokens = next_draft_tokens[request.py_seq_slot]
|
||||
request.py_decoding_iter += 1
|
||||
|
||||
@ -250,12 +257,12 @@ class MTPSampler(TorchSampler):
|
||||
new_tokens = state.host.new_tokens
|
||||
new_tokens_lens_list = state.host.new_tokens_lens.tolist()
|
||||
next_draft_tokens_list = state.host.next_draft_tokens.tolist()
|
||||
beam_idx = self.BEAM
|
||||
max_seq_len = self.max_seq_len
|
||||
for req in state.scheduled_requests.context_requests:
|
||||
if req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0:
|
||||
continue
|
||||
new_token = add_token(req, new_tokens, beam=beam_idx)
|
||||
self._handle_stop_criteria(req, new_token)
|
||||
new_token = add_token(req, new_tokens, beam=BEAM_0)
|
||||
handle_stop_single_beam(req, new_token, max_seq_len=max_seq_len)
|
||||
self._request_common_handling(req, next_draft_tokens_list)
|
||||
|
||||
for req in state.scheduled_requests.generation_requests:
|
||||
@ -263,8 +270,10 @@ class MTPSampler(TorchSampler):
|
||||
continue
|
||||
num_new_tokens = new_tokens_lens_list[req.py_seq_slot]
|
||||
for i in range(num_new_tokens):
|
||||
new_token = add_token(req, new_tokens, beam=beam_idx, step=i)
|
||||
if self._handle_stop_criteria(req, new_token):
|
||||
new_token = add_token(req, new_tokens, beam=BEAM_0, step=i)
|
||||
if handle_stop_single_beam(req,
|
||||
new_token,
|
||||
max_seq_len=max_seq_len):
|
||||
break
|
||||
req.py_num_accepted_draft_tokens = num_new_tokens - 1
|
||||
req.py_rewind_len = self.draft_len - req.py_num_accepted_draft_tokens
|
||||
|
||||
@ -84,8 +84,24 @@ class RuntimeConfig(BaseModel):
|
||||
backend_cache_config = llm_args.pop("kv_cache_config", {})
|
||||
llm_args["kv_cache_config"] = backend_cache_config | kv_cache_config
|
||||
|
||||
return update_llm_args_with_extra_options(llm_args,
|
||||
self.extra_llm_api_options)
|
||||
updated_llm_args = update_llm_args_with_extra_options(
|
||||
llm_args, self.extra_llm_api_options)
|
||||
|
||||
if self.backend == "pytorch":
|
||||
cuda_graph_config = updated_llm_args.pop(
|
||||
"cuda_graph_config", llm_args["cuda_graph_config"])
|
||||
# Use runtime max_batch_size as cuda_graph_config.max_batch_size
|
||||
# if both max_batch_size and batch_sizes are not set.
|
||||
batch_sizes_set = cuda_graph_config.get("batch_sizes",
|
||||
None) is not None
|
||||
max_batch_size_set = cuda_graph_config.get("max_batch_size",
|
||||
None) is not None
|
||||
if not batch_sizes_set and not max_batch_size_set:
|
||||
cuda_graph_config[
|
||||
"max_batch_size"] = self.settings_config.max_batch_size
|
||||
updated_llm_args["cuda_graph_config"] = cuda_graph_config
|
||||
|
||||
return updated_llm_args
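Illustration of the fallback (values are hypothetical): when the user sets neither `batch_sizes` nor `max_batch_size` in `cuda_graph_config`, the runtime `max_batch_size` is propagated.

cuda_graph_config = {"enable_padding": True}   # neither batch_sizes nor max_batch_size set
runtime_max_batch_size = 256

if cuda_graph_config.get("batch_sizes") is None and \
        cuda_graph_config.get("max_batch_size") is None:
    cuda_graph_config["max_batch_size"] = runtime_max_batch_size

assert cuda_graph_config == {"enable_padding": True, "max_batch_size": 256}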
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_full_config(self) -> RuntimeConfig:
|
||||
|
||||
@ -125,13 +125,36 @@ class ZeroMqQueue:
|
||||
# Send data without HMAC
|
||||
self.socket.send_pyobj(obj)
|
||||
|
||||
def put_noblock(self, obj: Any):
|
||||
def put_noblock(self,
|
||||
obj: Any,
|
||||
*,
|
||||
retry: int = 1,
|
||||
wait_time: float = 0.001):
|
||||
'''
|
||||
Put an object into the queue without blocking, and retry if the send fails.
|
||||
NOTE: It won't raise an error if the send ultimately fails; the failure is only logged.
|
||||
|
||||
Parameters:
|
||||
obj (Any): The object to send.
|
||||
retry (int): The number of times to retry sending the object.
|
||||
wait_time (float): The time to wait before retrying.
|
||||
'''
|
||||
|
||||
assert retry >= 0 and retry <= 10, "Retry must be between 0 and 10, adjust the wait_time if needed"
|
||||
|
||||
self.setup_lazily()
|
||||
with nvtx_range_debug("send", color="blue", category="IPC"):
|
||||
data = pickle.dumps(obj) # nosec B301
|
||||
if self.use_hmac_encryption:
|
||||
data = self._sign_data(data)
|
||||
self.socket.send(data, flags=zmq.NOBLOCK)
|
||||
try:
|
||||
self.socket.send(data, flags=zmq.NOBLOCK)
|
||||
except zmq.Again:
|
||||
if retry > 0:
|
||||
time.sleep(wait_time)
|
||||
self.put_noblock(obj, retry=retry - 1, wait_time=wait_time)
|
||||
else:
|
||||
logger.error(f"Failed to send object: {obj}")
|
||||
|
||||
async def put_async(self, obj: Any):
|
||||
self.setup_lazily()
|
||||
|
||||
@ -351,7 +351,7 @@ class GenerationExecutorProxy(GenerationExecutor):
|
||||
|
||||
# notify the workers to quit
|
||||
if all(not f.done() for f in self.mpi_futures):
|
||||
self.request_queue.put_noblock(None)
|
||||
self.request_queue.put_noblock(None, retry=4)
|
||||
|
||||
def shutdown(self):
|
||||
if not self.workers_started:
|
||||
|
||||
@ -1041,6 +1041,9 @@ class KvCacheConfig(StrictBaseModel, PybindMirror):
|
||||
"The data type to use for the Mamba SSM cache. If set to 'auto', the data type will be inferred from the model config."
|
||||
)
|
||||
|
||||
tokens_per_block: int = Field(default=32,
|
||||
description="The number of tokens per block.")
|
||||
|
||||
def _to_pybind(self):
|
||||
return _KvCacheConfig(
|
||||
enable_block_reuse=self.enable_block_reuse,
|
||||
@ -1946,6 +1949,9 @@ class BaseLlmArgs(StrictBaseModel):
|
||||
from tensorrt_llm._torch.speculative import suggest_spec_config
|
||||
spec_config = suggest_spec_config(max_batch_size)
|
||||
|
||||
if self.kv_cache_config is not None:
|
||||
executor_config.tokens_per_block = self.kv_cache_config.tokens_per_block
|
||||
|
||||
update_executor_config(
|
||||
executor_config,
|
||||
backend=self.backend,
|
||||
|
||||
@ -435,7 +435,7 @@ class RemoteMpiCommSessionServer():
|
||||
f"RemoteMpiCommSessionServer received all results, sending to client\n",
|
||||
"green")
|
||||
try:
|
||||
self.queue.put_noblock(self.results)
|
||||
self.queue.put_noblock(self.results, retry=2)
|
||||
except zmq.ZMQError as e:
|
||||
# The client could be shutdown first.
|
||||
if e.errno == zmq.EAGAIN:
|
||||
|
||||
@ -518,6 +518,8 @@ def generate_api_docs_as_docstring(model: Type[BaseModel],
|
||||
|
||||
# Format the argument documentation with 12 spaces indent for args
|
||||
arg_line = f"{indent} {field_name} ({type_str}): "
|
||||
if status := field_info.get("status", None):
|
||||
arg_line += f":tag:`{status}` "
|
||||
if field_description:
|
||||
arg_line += field_description.split('\n')[0] # First line with type
|
||||
|
||||
@ -557,20 +559,21 @@ class ApiParamTagger:
|
||||
'''
|
||||
|
||||
def __call__(self, cls: Type[BaseModel]) -> None:
|
||||
self.process_pydantic_model(cls)
|
||||
""" The main entry point to tag the api doc. """
|
||||
self._process_pydantic_model(cls)
|
||||
|
||||
def process_pydantic_model(self, cls: Type[BaseModel]) -> None:
|
||||
def _process_pydantic_model(self, cls: Type[BaseModel]) -> None:
|
||||
"""Process the Pydantic model to add tags to the fields.
|
||||
"""
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
if field_info.json_schema_extra and 'status' in field_info.json_schema_extra:
|
||||
status = field_info.json_schema_extra['status']
|
||||
self.amend_pydantic_field_description_with_tags(
|
||||
self._amend_pydantic_field_description_with_tags(
|
||||
cls, [field_name], status)
|
||||
|
||||
def amend_pydantic_field_description_with_tags(self, cls: Type[BaseModel],
|
||||
field_names: list[str],
|
||||
tag: str) -> None:
|
||||
def _amend_pydantic_field_description_with_tags(self, cls: Type[BaseModel],
|
||||
field_names: list[str],
|
||||
tag: str) -> None:
|
||||
"""Amend the description of the fields with tags.
|
||||
e.g. :tag:`beta` or :tag:`prototype`
|
||||
Args:
|
||||
|
||||
@ -109,6 +109,7 @@ class Logger(metaclass=Singleton):
|
||||
self._func_wrapper(severity)(" ".join(parts))
|
||||
|
||||
def log_once(self, severity, *msg, key):
|
||||
assert key is not None, "key is required for log_once"
|
||||
if key not in self._appeared_keys:
|
||||
self._appeared_keys.add(key)
|
||||
self.log(severity, *msg)
|
||||
|
||||
@ -6,7 +6,7 @@ import re
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Literal
|
||||
from typing import Any, List, Literal
|
||||
|
||||
from openai_harmony import (Author, Conversation, DeveloperContent,
|
||||
HarmonyEncodingName, HarmonyError, Message,
|
||||
@ -14,15 +14,15 @@ from openai_harmony import (Author, Conversation, DeveloperContent,
|
||||
SystemContent, TextContent, ToolDescription,
|
||||
load_harmony_encoding)
|
||||
|
||||
from tensorrt_llm.llmapi import RequestOutput
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
# yapf: disable
|
||||
from .openai_protocol import (ChatCompletionMessageParam, ChatCompletionRequest,
|
||||
from .openai_protocol import (ChatCompletionMessageParam,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse, ChatMessage,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatCompletionToolsParam, ChatMessage,
|
||||
DeltaFunctionCall, DeltaMessage, DeltaToolCall,
|
||||
UsageInfo)
|
||||
|
||||
@ -1485,36 +1485,72 @@ class HarmonyAdapter:
|
||||
return True
|
||||
|
||||
|
||||
async def handle_streaming_response(
|
||||
harmony_adapter: HarmonyAdapter,
|
||||
generator: RequestOutput,
|
||||
request_id: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Handle streaming response with harmony format."""
|
||||
_SERVE_HARMONY_ADAPTER: HarmonyAdapter = None
|
||||
|
||||
|
||||
def get_harmony_adapter():
|
||||
global _SERVE_HARMONY_ADAPTER
|
||||
if _SERVE_HARMONY_ADAPTER is None:
|
||||
_SERVE_HARMONY_ADAPTER = HarmonyAdapter()
|
||||
|
||||
return _SERVE_HARMONY_ADAPTER
|
||||
|
||||
|
||||
def handle_streaming_response(tools: List[ChatCompletionToolsParam],
|
||||
tool_choice: str, outputs: List, model: str,
|
||||
request_id: str, done: bool,
|
||||
num_prompt_tokens: int):
|
||||
first_iteration = True
|
||||
async for res in generator:
|
||||
output = res.outputs[0]
|
||||
output = outputs[0]
|
||||
|
||||
# Convert tools to dictionary format for harmony adapter (standard pattern)
|
||||
tools_dict = None
|
||||
if request.tools:
|
||||
tools_dict = [tool.model_dump() for tool in request.tools]
|
||||
# Convert tools to dictionary format for harmony adapter (standard pattern)
|
||||
tools_dict = None
|
||||
harmony_adapter = get_harmony_adapter()
|
||||
if tools:
|
||||
tools_dict = [tool.model_dump() for tool in tools]
|
||||
|
||||
# Get tool_choice from request - if "none", don't pass tools to parser
|
||||
tool_choice = getattr(request, 'tool_choice', None)
|
||||
if tool_choice == "none":
|
||||
tools_for_parser = None
|
||||
# Get tool_choice from request - if "none", don't pass tools to parser
|
||||
if tool_choice == "none":
|
||||
tools_for_parser = None
|
||||
else:
|
||||
tools_for_parser = tools_dict
|
||||
|
||||
# Create OpenAI streaming responses
|
||||
try:
|
||||
res = []
|
||||
if done:
|
||||
# Clean up state
|
||||
harmony_adapter.cleanup_stream_state(request_id)
|
||||
|
||||
usage_info = _create_usage_info(num_prompt_tokens, outputs)
|
||||
|
||||
# Send final message with finish_reason
|
||||
final_response = ChatCompletionStreamResponse(
|
||||
model=model,
|
||||
choices=[
|
||||
ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(),
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason)
|
||||
],
|
||||
)
|
||||
|
||||
final_response_json = final_response.model_dump_json(
|
||||
exclude_none=True)
|
||||
final_usage_chunk = ChatCompletionStreamResponse(choices=[],
|
||||
model=model,
|
||||
usage=usage_info)
|
||||
final_usage_json = final_usage_chunk.model_dump_json(
|
||||
exclude_none=True)
|
||||
res.append(f"data: {final_response_json}\n\n")
|
||||
res.append(f"data: {final_usage_json}\n\n")
|
||||
else:
|
||||
tools_for_parser = tools_dict
|
||||
|
||||
# Create OpenAI streaming responses
|
||||
try:
|
||||
responses = harmony_adapter.create_openai_streaming_response(
|
||||
request_id=request_id,
|
||||
tokens=output.token_ids_diff,
|
||||
available_tools=tools_for_parser,
|
||||
model_name=request.model,
|
||||
model_name=model,
|
||||
tool_choice=tool_choice)
|
||||
# Send first response after receiving the first output
|
||||
if first_iteration:
|
||||
@ -1525,64 +1561,44 @@ async def handle_streaming_response(
|
||||
delta=first_delta)
|
||||
|
||||
first_response = ChatCompletionStreamResponse(
|
||||
model=request.model,
|
||||
model=model,
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
response_json = first_response.model_dump_json(
|
||||
exclude_none=True)
|
||||
yield f"data: {response_json}\n\n"
|
||||
res.append(f"data: {response_json}\n\n")
|
||||
|
||||
for response in responses:
|
||||
yield response
|
||||
res.extend(responses)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create OpenAI streaming response: {e}")
|
||||
logger.debug(f"Streaming error details: {traceback.format_exc()}")
|
||||
# Clean up state
|
||||
harmony_adapter.cleanup_stream_state(request_id)
|
||||
raise e
|
||||
return res
|
||||
|
||||
# Clean up state
|
||||
harmony_adapter.cleanup_stream_state(request_id)
|
||||
|
||||
# Send final message with finish_reason
|
||||
output = generator.outputs[0]
|
||||
final_response = ChatCompletionStreamResponse(
|
||||
model=request.model,
|
||||
choices=[
|
||||
ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(),
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason)
|
||||
])
|
||||
|
||||
yield f"data: {final_response.model_dump_json(exclude_unset=True)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create OpenAI streaming response: {e}")
|
||||
logger.debug(f"Streaming error details: {traceback.format_exc()}")
|
||||
# Clean up state
|
||||
harmony_adapter.cleanup_stream_state(request_id)
|
||||
raise e
|
||||
|
||||
|
||||
async def handle_non_streaming_response(
|
||||
harmony_adapter: HarmonyAdapter, promise: RequestOutput,
|
||||
request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
def handle_non_streaming_response(tools: List[ChatCompletionToolsParam],
|
||||
tool_choice: str, outputs: List, model: str,
|
||||
num_prompt_tokens: int):
|
||||
"""Handle non-streaming response with harmony format."""
|
||||
# Get final result
|
||||
await promise
|
||||
|
||||
# Parse harmony output to OpenAI format
|
||||
# Convert tools to dictionary format for harmony adapter (standard pattern)
|
||||
tools_dict = None
|
||||
if request.tools:
|
||||
tools_dict = [tool.model_dump() for tool in request.tools]
|
||||
harmony_adapter = get_harmony_adapter()
|
||||
if tools:
|
||||
tools_dict = [tool.model_dump() for tool in tools]
|
||||
|
||||
# Get tool_choice from request - if "none", don't pass tools to parser
|
||||
tool_choice = getattr(request, 'tool_choice', None)
|
||||
if tool_choice == "none":
|
||||
tools_for_parser = None
|
||||
else:
|
||||
tools_for_parser = tools_dict
|
||||
|
||||
output = promise.outputs[0]
|
||||
output = outputs[0]
|
||||
parsed_output = harmony_adapter.harmony_output_to_openai(
|
||||
output.token_ids, tools_for_parser, tool_choice)
|
||||
|
||||
@ -1597,11 +1613,11 @@ async def handle_non_streaming_response(
|
||||
output.finish_reason)
|
||||
|
||||
# Create usage info from metrics (RequestOutput doesn't have usage in v1)
|
||||
usage_info = _create_usage_info(promise)
|
||||
usage_info = _create_usage_info(num_prompt_tokens, outputs)
|
||||
|
||||
# Create response
|
||||
response = ChatCompletionResponse(
|
||||
model=request.model,
|
||||
model=model,
|
||||
choices=[
|
||||
ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
@ -1613,7 +1629,6 @@ async def handle_non_streaming_response(
|
||||
# Optional: Log if harmony parsing failed (for debugging)
|
||||
if parsed_output.get('_harmony_parsing_failed'):
|
||||
logger.warning("⚠️ Harmony parsing fell back to raw text decoding")
|
||||
logger.debug(f"request\n\n{request}")
|
||||
logger.debug(f"response\n\n{response}\n")
|
||||
|
||||
return response
|
||||
@ -1646,15 +1661,10 @@ def _determine_finish_reason(parsed_output: dict[str, Any],
|
||||
return reason
|
||||
|
||||
|
||||
def _create_usage_info(final_res: RequestOutput) -> UsageInfo:
|
||||
def _create_usage_info(num_prompt_tokens, outputs) -> UsageInfo:
|
||||
"""Create usage info from RequestOutput following serving_chat.py pattern."""
|
||||
# Calculate prompt tokens from prompt_token_ids and encoder_prompt_token_ids
|
||||
assert final_res.prompt_token_ids is not None
|
||||
num_prompt_tokens = len(final_res.prompt_token_ids)
|
||||
|
||||
# Calculate completion tokens from all outputs
|
||||
num_generated_tokens = sum(
|
||||
len(output.token_ids) for output in final_res.outputs)
|
||||
num_generated_tokens = sum(len(output.token_ids) for output in outputs)
|
||||
|
||||
# Create usage info
|
||||
usage = UsageInfo(prompt_tokens=num_prompt_tokens,
|
||||
|
||||
@ -322,6 +322,8 @@ class OpenAIDisaggServer:
|
||||
raise ValueError("Disagg server returned more than one choice. This is currently not supported in disaggregated server.")
|
||||
if choices[0].disaggregated_params is None:
|
||||
raise ValueError("Context server did not return disaggregated params")
|
||||
if choices[0].disaggregated_params.ctx_request_id is None:
|
||||
raise ValueError("Invalid disaggregated params in context phase response.")
|
||||
|
||||
return ctx_response
|
||||
|
||||
|
||||
@ -44,9 +44,10 @@ from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
|
||||
UsageInfo,
|
||||
to_llm_disaggregated_params)
|
||||
from tensorrt_llm.serve.postprocess_handlers import (
|
||||
ChatPostprocArgs, CompletionPostprocArgs, chat_response_post_processor,
|
||||
chat_stream_post_processor, completion_response_post_processor,
|
||||
completion_stream_post_processor)
|
||||
ChatCompletionPostprocArgs, ChatPostprocArgs, CompletionPostprocArgs,
|
||||
chat_harmony_post_processor, chat_harmony_streaming_post_processor,
|
||||
chat_response_post_processor, chat_stream_post_processor,
|
||||
completion_response_post_processor, completion_stream_post_processor)
|
||||
from tensorrt_llm.serve.responses_utils import ConversationHistoryStore
|
||||
from tensorrt_llm.serve.responses_utils import \
|
||||
create_response as responses_api_create_response
|
||||
@ -57,8 +58,7 @@ from tensorrt_llm.serve.responses_utils import \
|
||||
from tensorrt_llm.version import __version__ as VERSION
|
||||
|
||||
from .._utils import nvtx_mark, set_prometheus_multiproc_dir
|
||||
from .harmony_adapter import (HarmonyAdapter, handle_non_streaming_response,
|
||||
handle_streaming_response,
|
||||
from .harmony_adapter import (HarmonyAdapter, get_harmony_adapter,
|
||||
maybe_transform_reasoning_effort)
|
||||
|
||||
# yapf: enable
|
||||
@ -117,7 +117,11 @@ class OpenAIServer:
|
||||
|
||||
# gpt-oss
|
||||
self.harmony_adapter: HarmonyAdapter | None = None
|
||||
self.use_harmony = self.model_config.model_type == "gpt_oss"
|
||||
disable_harmony = os.getenv("DISABLE_HARMONY_ADAPTER", "0") == "1"
|
||||
if disable_harmony:
|
||||
self.use_harmony = False
|
||||
else:
|
||||
self.use_harmony = (self.model_config.model_type == "gpt_oss")
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
@ -704,11 +708,35 @@ class OpenAIServer:
|
||||
Chat Completion API with harmony format support.
|
||||
Supports both streaming and non-streaming modes.
|
||||
"""
|
||||
|
||||
async def create_harmony_response(
|
||||
promise: RequestOutput, postproc_params: PostprocParams) -> ChatCompletionResponse:
|
||||
await promise.aresult()
|
||||
if self.postproc_worker_enabled:
|
||||
chat_response = promise.outputs[0]._postprocess_result
|
||||
else:
|
||||
post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
|
||||
chat_response = post_processor(promise, args)
|
||||
|
||||
return chat_response
|
||||
|
||||
async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams):
|
||||
if not self.postproc_worker_enabled:
|
||||
post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
|
||||
|
||||
async for res in promise:
|
||||
pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
|
||||
# await self._extract_metrics(res)
|
||||
for pp_res in pp_results:
|
||||
yield pp_res
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
try:
|
||||
# Initialize HarmonyAdapter
|
||||
# NOTE: WAR for Disagg failure, may affect perf if no warmup
|
||||
if not self.harmony_adapter:
|
||||
self.harmony_adapter = HarmonyAdapter()
|
||||
self.harmony_adapter = get_harmony_adapter()
|
||||
# Convert Pydantic models to dictionaries for JSON serialization (standard pattern)
|
||||
tools_dict = None
|
||||
if request.tools:
|
||||
@ -743,27 +771,37 @@ class OpenAIServer:
|
||||
vocab_size=self.tokenizer.tokenizer.vocab_size)
|
||||
sampling_params.detokenize = False # Harmony adapter handles detokenization
|
||||
|
||||
postproc_args = ChatCompletionPostprocArgs.from_request(request)
|
||||
postproc_params = PostprocParams(
|
||||
post_processor=chat_harmony_streaming_post_processor
|
||||
if request.stream else chat_harmony_post_processor,
|
||||
postproc_args=postproc_args,
|
||||
)
|
||||
|
||||
# Generate
|
||||
promise = self.llm.generate_async(
|
||||
inputs=harmony_tokens,
|
||||
sampling_params=sampling_params,
|
||||
_postproc_params=postproc_params if self.postproc_worker_enabled else None,
|
||||
streaming=bool(request.stream),
|
||||
lora_request=request.lora_request,
|
||||
)
|
||||
postproc_args.request_id = promise.request_id
|
||||
|
||||
if not self.postproc_worker_enabled:
|
||||
postproc_args.num_prompt_tokens = len(promise.prompt_token_ids)
|
||||
|
||||
# Disconnect cancellation
|
||||
asyncio.create_task(self.await_disconnected(raw_request, promise))
|
||||
|
||||
# Handle streaming
|
||||
if request.stream:
|
||||
return StreamingResponse(
|
||||
handle_streaming_response(
|
||||
self.harmony_adapter, promise,
|
||||
str(promise.request_id), request,
|
||||
),
|
||||
content=create_streaming_generator(promise, postproc_params),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
else:
|
||||
response = await handle_non_streaming_response(self.harmony_adapter, promise, request)
|
||||
response = await create_harmony_response(promise, postproc_params)
|
||||
return JSONResponse(response.model_dump())
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@ -9,6 +9,8 @@ from ..llmapi.reasoning_parser import (BaseReasoningParser,
|
||||
ReasoningParserFactory)
|
||||
from ..llmapi.tokenizer import TransformersTokenizer
|
||||
# yapf: disable
|
||||
from .harmony_adapter import (handle_non_streaming_response,
|
||||
handle_streaming_response)
|
||||
from .openai_protocol import (ChatCompletionLogProbs,
|
||||
ChatCompletionLogProbsContent,
|
||||
ChatCompletionNamedToolChoiceParam,
|
||||
@ -24,7 +26,8 @@ from .openai_protocol import (ChatCompletionLogProbs,
|
||||
FunctionCall, StreamOptions, ToolCall, UsageInfo,
|
||||
to_disaggregated_params)
|
||||
|
||||
# yapf: enale
|
||||
# yapf: enable
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class ChatPostprocArgs(PostprocArgs):
|
||||
@ -57,8 +60,7 @@ class ChatPostprocArgs(PostprocArgs):
|
||||
)
|
||||
|
||||
|
||||
def create_logprobs(token_ids: List[int],
|
||||
tokenizer: TransformersTokenizer,
|
||||
def create_logprobs(token_ids: List[int], tokenizer: TransformersTokenizer,
|
||||
logprobs: List[float]) -> ChatCompletionLogProbs:
|
||||
assert len(token_ids) == len(logprobs), \
|
||||
"token_ids and logprobs have different lengths"
|
||||
@ -75,12 +77,14 @@ def create_logprobs(token_ids: List[int],
|
||||
return chat_logprobs
|
||||
|
||||
|
||||
def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str, streaming: bool) -> Tuple[bool, str, str]:
|
||||
def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
|
||||
streaming: bool) -> Tuple[bool, str, str]:
|
||||
reasoning_parser = None
|
||||
if args.reasoning_parser is not None:
|
||||
if output_index not in args.reasoning_parser_dict:
|
||||
args.reasoning_parser_dict[output_index] = ReasoningParserFactory.create_reasoning_parser(
|
||||
args.reasoning_parser)
|
||||
args.reasoning_parser_dict[
|
||||
output_index] = ReasoningParserFactory.create_reasoning_parser(
|
||||
args.reasoning_parser)
|
||||
reasoning_parser = args.reasoning_parser_dict[output_index]
|
||||
|
||||
in_reasoning = False
|
||||
@ -97,7 +101,8 @@ def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
|
||||
|
||||
|
||||
@nvtx_range_debug("chat_stream_post_processor")
|
||||
def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs) -> List[str]:
|
||||
def chat_stream_post_processor(rsp: GenerationResultBase,
|
||||
args: ChatPostprocArgs) -> List[str]:
|
||||
|
||||
def yield_first_chat(num_tokens: int,
|
||||
idx: int,
|
||||
@ -128,9 +133,13 @@ def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs
|
||||
include_continuous_usage = False
|
||||
if args.first_iteration:
|
||||
for i in range(args.num_choices):
|
||||
res.append(f"data: {yield_first_chat(prompt_tokens, i, role=args.role)} \n\n")
|
||||
res.append(
|
||||
f"data: {yield_first_chat(prompt_tokens, i, role=args.role)} \n\n"
|
||||
)
|
||||
if args.echo and args.last_message_content:
|
||||
res.append(f"data: {yield_first_chat(prompt_tokens, i, content=args.last_message_content)} \n\n")
|
||||
res.append(
|
||||
f"data: {yield_first_chat(prompt_tokens, i, content=args.last_message_content)} \n\n"
|
||||
)
|
||||
args.first_iteration = False
|
||||
|
||||
for output in rsp.outputs:
|
||||
@ -158,14 +167,18 @@ def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs
|
||||
delta_message = DeltaMessage(
|
||||
content=delta_text, reasoning_content=reasoning_delta_text)
|
||||
|
||||
choice = ChatCompletionResponseStreamChoice(index=i,
|
||||
delta=delta_message,
|
||||
finish_reason=None,
|
||||
avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None))
|
||||
choice = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=delta_message,
|
||||
finish_reason=None,
|
||||
avg_decoded_tokens_per_iter=getattr(rsp,
|
||||
'avg_decoded_tokens_per_iter',
|
||||
None))
|
||||
if args.return_logprobs:
|
||||
logprobs = output.logprobs_diff
|
||||
token_ids = output.token_ids_diff
|
||||
choice.logprobs = create_logprobs(token_ids, args.tokenizer, logprobs)
|
||||
choice.logprobs = create_logprobs(token_ids, args.tokenizer,
|
||||
logprobs)
|
||||
if output.finish_reason is not None:
|
||||
choice.finish_reason = output.finish_reason
|
||||
choice.stop_reason = output.stop_reason
|
||||
@ -179,57 +192,62 @@ def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs
|
||||
res.append(f"data: {data}\n\n")
|
||||
|
||||
if include_usage and rsp._done:
|
||||
completion_tokens = sum(output.length
|
||||
for output in rsp.outputs)
|
||||
completion_tokens = sum(output.length for output in rsp.outputs)
|
||||
final_usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
final_usage_chunk = ChatCompletionStreamResponse(
|
||||
choices=[], model=args.model, usage=final_usage)
|
||||
final_usage_chunk = ChatCompletionStreamResponse(choices=[],
|
||||
model=args.model,
|
||||
usage=final_usage)
|
||||
final_usage_data = final_usage_chunk.model_dump_json()
|
||||
res.append(f"data: {final_usage_data}\n\n")
|
||||
return res
|
||||
|
||||
|
||||
@nvtx_range_debug("chat_response_post_processor")
|
||||
def chat_response_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs) -> ChatCompletionResponse:
|
||||
def chat_response_post_processor(
|
||||
rsp: GenerationResultBase,
|
||||
args: ChatPostprocArgs) -> ChatCompletionResponse:
|
||||
choices: List[ChatCompletionResponseChoice] = []
|
||||
role = args.role
|
||||
for output in rsp.outputs:
|
||||
_, text, reasoning_text = apply_reasoning_parser(
|
||||
args, output.index, output.text, False)
|
||||
|
||||
if args.tool_choice and isinstance(
|
||||
args.tool_choice,
|
||||
ChatCompletionNamedToolChoiceParam):
|
||||
if args.tool_choice and isinstance(args.tool_choice,
|
||||
ChatCompletionNamedToolChoiceParam):
|
||||
message = ChatMessage(
|
||||
role=role,
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCall(function=FunctionCall(
|
||||
name=args.tool_choice.function.name,
|
||||
arguments=text))
|
||||
name=args.tool_choice.function.name, arguments=text))
|
||||
])
|
||||
else:
|
||||
if text is None:
|
||||
text = ""
|
||||
message = ChatMessage(
|
||||
role=role, content=text, reasoning_content=reasoning_text)
|
||||
disaggregated_params = to_disaggregated_params(output.disaggregated_params)
|
||||
message = ChatMessage(role=role,
|
||||
content=text,
|
||||
reasoning_content=reasoning_text)
|
||||
disaggregated_params = to_disaggregated_params(
|
||||
output.disaggregated_params)
|
||||
choice = ChatCompletionResponseChoice(
|
||||
index=output.index,
|
||||
message=message,
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason,
|
||||
disaggregated_params=disaggregated_params,
|
||||
avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None),
|
||||
avg_decoded_tokens_per_iter=getattr(rsp,
|
||||
'avg_decoded_tokens_per_iter',
|
||||
None),
|
||||
)
|
||||
|
||||
if args.return_logprobs:
|
||||
choice.logprobs = create_logprobs(output.token_ids, args.tokenizer, output.logprobs)
|
||||
choice.logprobs = create_logprobs(output.token_ids, args.tokenizer,
|
||||
output.logprobs)
|
||||
choices.append(choice)
|
||||
|
||||
if args.echo and args.last_message_content:
|
||||
@ -238,8 +256,7 @@ def chat_response_post_processor(rsp: GenerationResultBase, args: ChatPostprocAr
|
||||
choice.message.content = full_message
|
||||
|
||||
num_prompt_tokens = args.num_prompt_tokens
|
||||
num_generated_tokens = sum(
|
||||
len(output.token_ids) for output in rsp.outputs)
|
||||
num_generated_tokens = sum(len(output.token_ids) for output in rsp.outputs)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
@ -275,7 +292,8 @@ class CompletionPostprocArgs(PostprocArgs):
|
||||
|
||||
|
||||
@nvtx_range_debug("completion_stream_post_processor")
|
||||
def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args: CompletionPostprocArgs) -> List[str]:
|
||||
def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase,
|
||||
args: CompletionPostprocArgs) -> List[str]:
|
||||
res: List[str] = []
|
||||
prompt_tokens = args.num_prompt_tokens
|
||||
if stream_option := args.stream_options:
|
||||
@ -293,9 +311,11 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args:
|
||||
index=args.prompt_idx * args.num_choices + output.index,
|
||||
text=delta_text if args.detokenize else "",
|
||||
token_ids=None if args.detokenize else output.token_ids_diff,
|
||||
finish_reason = output.finish_reason,
|
||||
stop_reason = output.stop_reason,
|
||||
avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None),
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason,
|
||||
avg_decoded_tokens_per_iter=getattr(rsp,
|
||||
'avg_decoded_tokens_per_iter',
|
||||
None),
|
||||
)
|
||||
chunk = CompletionStreamResponse(model=args.model, choices=[choice])
|
||||
if include_continuous_usage:
|
||||
@ -306,16 +326,16 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args:
|
||||
res.append(f"data: {data}\n\n")
|
||||
|
||||
if include_usage and rsp._done:
|
||||
completion_tokens = sum(output.length
|
||||
for output in rsp.outputs)
|
||||
completion_tokens = sum(output.length for output in rsp.outputs)
|
||||
final_usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
final_usage_chunk = ChatCompletionStreamResponse(
|
||||
choices=[], model=args.model, usage=final_usage)
|
||||
final_usage_chunk = ChatCompletionStreamResponse(choices=[],
|
||||
model=args.model,
|
||||
usage=final_usage)
|
||||
final_usage_data = final_usage_chunk.model_dump_json()
|
||||
res.append(f"data: {final_usage_data}\n\n")
|
||||
args.first_iteration = False
|
||||
@ -323,7 +343,9 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args:
|
||||
|
||||
|
||||
@nvtx_range_debug("completion_response_post_processor")
|
||||
def completion_response_post_processor(rsp: GenerationResult, args: CompletionPostprocArgs) -> CompletionResponse:
|
||||
def completion_response_post_processor(
|
||||
rsp: GenerationResult,
|
||||
args: CompletionPostprocArgs) -> CompletionResponse:
|
||||
prompt_tokens = args.num_prompt_tokens
|
||||
completion_tokens = 0
|
||||
choices = []
|
||||
@ -331,23 +353,75 @@ def completion_response_post_processor(rsp: GenerationResult, args: CompletionPo
|
||||
text = output.text
|
||||
if args.echo:
|
||||
text = args.prompt + text
|
||||
disaggregated_params = to_disaggregated_params(output.disaggregated_params)
|
||||
disaggregated_params = to_disaggregated_params(
|
||||
output.disaggregated_params)
|
||||
choice = CompletionResponseChoice(
|
||||
text=text if args.detokenize else "",
|
||||
token_ids=None if args.detokenize else output.token_ids,
|
||||
index=args.prompt_idx * args.num_choices + output.index,
|
||||
disaggregated_params=disaggregated_params,
|
||||
context_logits=None if rsp.context_logits is None else rsp.context_logits.tolist(),
|
||||
context_logits=None
|
||||
if rsp.context_logits is None else rsp.context_logits.tolist(),
|
||||
stop_reason=output.stop_reason,
|
||||
finish_reason=output.finish_reason,
|
||||
avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None),
|
||||
avg_decoded_tokens_per_iter=getattr(rsp,
|
||||
'avg_decoded_tokens_per_iter',
|
||||
None),
|
||||
)
|
||||
|
||||
completion_tokens += output.length
|
||||
choices.append(choice)
|
||||
|
||||
usage = UsageInfo(prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=completion_tokens + prompt_tokens)
|
||||
response = CompletionResponse(choices=choices, model=args.model, usage=usage)
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=completion_tokens + prompt_tokens)
|
||||
response = CompletionResponse(choices=choices,
|
||||
model=args.model,
|
||||
usage=usage)
|
||||
return response
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class ChatCompletionPostprocArgs(PostprocArgs):
|
||||
model: str
|
||||
tools: Optional[List[ChatCompletionToolsParam]]
|
||||
tool_choice: Optional[Union[Literal["none", "auto"],
|
||||
ChatCompletionNamedToolChoiceParam]]
|
||||
request_id: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def from_request(cls, request: ChatCompletionRequest):
|
||||
return cls(
|
||||
model=request.model,
|
||||
tools=request.tools,
|
||||
tool_choice=request.tool_choice,
|
||||
)
|
||||
|
||||
|
||||
@nvtx_range_debug("chat_harmony_post_processor")
|
||||
def chat_harmony_post_processor(
|
||||
rsp: GenerationResult,
|
||||
args: ChatCompletionPostprocArgs) -> ChatCompletionResponse:
|
||||
response = handle_non_streaming_response(
|
||||
tools=args.tools,
|
||||
tool_choice=args.tool_choice,
|
||||
outputs=rsp.outputs,
|
||||
model=args.model,
|
||||
num_prompt_tokens=args.num_prompt_tokens,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@nvtx_range_debug("chat_harmony_streaming_post_processor")
|
||||
def chat_harmony_streaming_post_processor(
|
||||
rsp: GenerationResult, args: ChatCompletionPostprocArgs) -> List[str]:
|
||||
response = handle_streaming_response(
|
||||
tools=args.tools,
|
||||
tool_choice=args.tool_choice,
|
||||
outputs=rsp.outputs,
|
||||
model=args.model,
|
||||
request_id=args.request_id,
|
||||
done=rsp._done,
|
||||
num_prompt_tokens=args.num_prompt_tokens,
|
||||
)
|
||||
return response
|
||||
|
||||
@ -8,6 +8,9 @@ meta-llama/Llama-3.3-70B-Instruct:
|
||||
accuracy: 48.03
|
||||
- quant_algo: FP8
|
||||
accuracy: 48.03
|
||||
- quant_algo: FP8
|
||||
kv_cache_quant_algo: FP8
|
||||
accuracy: 48.03
|
||||
deepseek-ai/DeepSeek-R1:
|
||||
- quant_algo: NVFP4
|
||||
accuracy: 70.45
|
||||
|
||||
@ -602,9 +602,9 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
|
||||
@parametrize_with_ids("gen_tp", [1, 2])
|
||||
@pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
|
||||
def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset):
|
||||
if ctx_pp * gen_tp * 2 > get_device_count():
|
||||
if ctx_pp + gen_tp > get_device_count():
|
||||
pytest.skip(
|
||||
f"Not enough devices for ctx_pp={ctx_pp}*gen_tp={gen_tp} test")
|
||||
f"Not enough devices for ctx_pp={ctx_pp}+gen_tp={gen_tp} test")
|
||||
return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1,
|
||||
gen_tp, 1, 1, [get_accuracy_task(testset)])
|
||||
|
||||
|
||||
@ -1565,6 +1565,11 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
|
||||
@parametrize_with_ids("moe_backend", ["CUTLASS", "TRTLLM"])
|
||||
def test_nvfp4(self, fp8kv, attention_dp, cuda_graph, overlap_scheduler,
|
||||
torch_compile, mtp_nextn, moe_backend):
|
||||
if moe_backend == "TRTLLM" and (get_sm_version() == 120
|
||||
or get_sm_version() == 121):
|
||||
pytest.skip(
|
||||
"MOE TRTLLM backend does not support SM version 120 or 121")
|
||||
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
|
||||
torch_compile_config = TorchCompileConfig(
|
||||
enable_fullgraph=True,
|
||||
@ -1613,8 +1618,10 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
|
||||
torch_compile, mtp_nextn, moe_backend):
|
||||
if torch_compile and pp_size > 1:
|
||||
pytest.skip("PP with torch.compile is not supported yet.")
|
||||
if moe_backend == "TRTLLM" and get_sm_version() == 120:
|
||||
pytest.skip("MOE TRTLLM backend does not support SM version 120")
|
||||
if moe_backend == "TRTLLM" and (get_sm_version() == 120
|
||||
or get_sm_version() == 121):
|
||||
pytest.skip(
|
||||
"MOE TRTLLM backend does not support SM version 120 or 121")
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
|
||||
# Picewise Cuda Graph cannot be enabled for nvfp4 attention dp.
|
||||
torch_compile_config = TorchCompileConfig(
|
||||
@ -1885,6 +1892,11 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
|
||||
def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
|
||||
attention_dp, cuda_graph, overlap_scheduler,
|
||||
max_batch_size, moe_backend):
|
||||
if moe_backend == "TRTLLM" and (get_sm_version() == 120
|
||||
or get_sm_version() == 121):
|
||||
pytest.skip(
|
||||
"MOE TRTLLM backend does not support SM version 120 or 121")
|
||||
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.70)
|
||||
pytorch_config = dict(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
@ -2509,6 +2521,11 @@ class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
|
||||
torch_compile,
|
||||
):
|
||||
|
||||
if moe_backend == "TRTLLM" and (get_sm_version() == 120
|
||||
or get_sm_version() == 121):
|
||||
pytest.skip(
|
||||
"MOE TRTLLM backend does not support SM version 120 or 121")
|
||||
|
||||
torch_compile_config = TorchCompileConfig(
|
||||
enable_fullgraph=True,
|
||||
enable_piecewise_cuda_graph=cuda_graph,
|
||||
@ -2700,6 +2717,11 @@ class TestQwen3_235B_A22B(LlmapiAccuracyTestHarness):
|
||||
def test_nvfp4(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
|
||||
overlap_scheduler, moe_backend, eagle3):
|
||||
|
||||
if moe_backend == "TRTLLM" and (get_sm_version() == 120
|
||||
or get_sm_version() == 121):
|
||||
pytest.skip(
|
||||
"MOE TRTLLM backend does not support SM version 120 or 121")
|
||||
|
||||
pytorch_config = dict(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
|
||||
@ -2779,7 +2801,7 @@ class TestPhi4MiniInstruct(LlmapiAccuracyTestHarness):
|
||||
MODEL_PATH = f"{llm_models_root()}/Phi-4-mini-instruct"
|
||||
|
||||
def test_auto_dtype(self):
|
||||
with LLM(self.MODEL_PATH) as llm:
|
||||
with LLM(self.MODEL_PATH, max_seq_len=4096) as llm:
|
||||
task = CnnDailymail(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
task = MMLU(self.MODEL_NAME)
|
||||
@ -3046,10 +3068,8 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
|
||||
|
||||
class TestEXAONE4(LlmapiAccuracyTestHarness):
|
||||
MODEL_NAME = "LGAI-EXAONE/EXAONE-4.0-32B"
|
||||
kv_cache_config = KvCacheConfig(
|
||||
enable_block_reuse=False,
|
||||
enable_partial_reuse=False,
|
||||
max_attention_window=[4096, 4096, 4096, 131072])
|
||||
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
|
||||
enable_partial_reuse=False)
|
||||
|
||||
def test_auto_dtype(self):
|
||||
model_path = f"{llm_models_root()}/EXAONE-4.0-32B"
|
||||
|
||||
@ -8,7 +8,7 @@ disable_overlap_scheduler: True
|
||||
context_servers:
|
||||
num_instances: 1
|
||||
max_num_tokens: 512
|
||||
max_batch_size: 256
|
||||
max_batch_size: 64
|
||||
cache_transceiver_config:
|
||||
backend: DEFAULT
|
||||
urls:
|
||||
@ -16,7 +16,7 @@ context_servers:
|
||||
generation_servers:
|
||||
num_instances: 1
|
||||
max_num_tokens: 256
|
||||
max_batch_size: 128
|
||||
max_batch_size: 32
|
||||
cache_transceiver_config:
|
||||
backend: DEFAULT
|
||||
urls:
|
||||
|
||||
@ -1302,7 +1302,7 @@ def get_config_for_benchmark(model_root, backend):
|
||||
"num_instances": 1,
|
||||
"max_batch_size": 2,
|
||||
"max_num_tokens": 384,
|
||||
"max_seq_len": 320,
|
||||
"max_seq_len": 384,
|
||||
"tensor_parallel_size": 1,
|
||||
"pipeline_parallel_size": 1,
|
||||
"disable_overlap_scheduler": True,
|
||||
@ -1318,7 +1318,7 @@ def get_config_for_benchmark(model_root, backend):
|
||||
"pipeline_parallel_size": 1,
|
||||
"max_batch_size": 2,
|
||||
"max_num_tokens": 384,
|
||||
"max_seq_len": 320,
|
||||
"max_seq_len": 384,
|
||||
"cache_transceiver_config": {
|
||||
"backend": backend,
|
||||
"max_tokens_in_buffer": 512,
|
||||
|
||||
@ -36,6 +36,42 @@ MODEL_PATHS = {
|
||||
}
|
||||
|
||||
|
||||
def mpi_publish_name():
|
||||
port_name = None
|
||||
try:
|
||||
port_name = MPI.Open_port()
|
||||
MPI.Publish_name('my_port', port_name)
|
||||
except MPI.Exception as e:
|
||||
print(f"Error publishing port name: {e}")
|
||||
raise e
|
||||
except Exception as e:
|
||||
print(f"Unexpected error publishing port name: {e}")
|
||||
raise e
|
||||
|
||||
return port_name
|
||||
|
||||
|
||||
def mpi_initialize_intercomm(port_name):
|
||||
intercomm = None
|
||||
try:
|
||||
intercomm = MPI.COMM_SELF.Accept(port_name)
|
||||
except MPI.Exception as e:
|
||||
print(f"Error accepting intercomm: {e}", flush=True)
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"Unexpected error accepting intercomm: {e}", flush=True)
|
||||
raise
|
||||
return intercomm
|
||||
|
||||
|
||||
def mpi_send_termination_request(intercomm):
|
||||
if intercomm is not None:
|
||||
# Send termination requests
|
||||
intercomm.send(None, dest=0, tag=MPI_REQUEST)
|
||||
intercomm.send(None, dest=1, tag=MPI_REQUEST)
|
||||
print("Sent termination requests to the workers.")
|
||||
|
||||
|
||||
def model_path(model_name):
|
||||
llm_models_root = os.environ["LLM_MODELS_ROOT"]
|
||||
for name, path in MODEL_PATHS.items():
|
||||
@ -48,8 +84,15 @@ async def run_worker(kv_cache_config, cache_transceiver_config, pytorch_config,
|
||||
model_name, rank):
|
||||
assert isinstance(pytorch_config, dict)
|
||||
print(f"Running worker {rank}")
|
||||
port_name = MPI.Lookup_name('my_port')
|
||||
intercomm = MPI.COMM_WORLD.Connect(port_name)
|
||||
try:
|
||||
port_name = MPI.Lookup_name('my_port')
|
||||
intercomm = MPI.COMM_WORLD.Connect(port_name)
|
||||
except MPI.Exception as e:
|
||||
print(f"Error publishing port name: {e}")
|
||||
raise e
|
||||
except Exception as e:
|
||||
print(f"Unexpected error publishing port name: {e}")
|
||||
raise e
|
||||
|
||||
session = MPI.COMM_WORLD.Split(color=rank, key=0)
|
||||
set_mpi_comm(session)
|
||||
@ -139,8 +182,7 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt,
|
||||
zip(kv_cache_configs, cache_transceiver_configs, worker_pytorch_configs,
|
||||
model_names, ranks))
|
||||
|
||||
port_name = MPI.Open_port()
|
||||
MPI.Publish_name('my_port', port_name)
|
||||
port_name = mpi_publish_name()
|
||||
|
||||
with MPIPoolExecutor(max_workers=2, env={"UCX_TLS": "^ib"}) as executor:
|
||||
futures = []
|
||||
@ -152,9 +194,10 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt,
|
||||
print(f"Error in worker {worker_arg}: {e}")
|
||||
raise e
|
||||
|
||||
intercomm = None
|
||||
try:
|
||||
print("Launched all the workers.")
|
||||
intercomm = MPI.COMM_SELF.Accept(port_name)
|
||||
print("Launched all the workers.", flush=True)
|
||||
intercomm = mpi_initialize_intercomm(port_name)
|
||||
|
||||
for _ in range(2):
|
||||
intercomm.recv(tag=MPI_READY)
|
||||
@ -187,14 +230,15 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt,
|
||||
output = responses[0]
|
||||
assert output[0].text == expected_output
|
||||
assert output[0].token_ids == expected_output_ids
|
||||
|
||||
except Exception as e:
|
||||
print(f"Exception encountered: {e}", flush=True)
|
||||
raise e
|
||||
finally:
|
||||
# Send termination requests
|
||||
intercomm.send(None, dest=0, tag=MPI_REQUEST)
|
||||
intercomm.send(None, dest=1, tag=MPI_REQUEST)
|
||||
print("Sent termination requests to the workers.")
|
||||
print("Sending termination request", flush=True)
|
||||
mpi_send_termination_request(intercomm)
|
||||
|
||||
# Wait for all futures to complete
|
||||
print("Waiting for all workers to terminate. ", flush=True)
|
||||
for future in futures:
|
||||
future.result()
|
||||
print("All workers terminated.")
|
||||
@ -282,8 +326,7 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
|
||||
zip(kv_cache_configs, cache_transceiver_configs, worker_pytorch_configs,
|
||||
model_names, ranks))
|
||||
|
||||
port_name = MPI.Open_port()
|
||||
MPI.Publish_name('my_port', port_name)
|
||||
port_name = mpi_publish_name()
|
||||
|
||||
prompt = "European Union is a political and economic union of 27 countries. The European Union is headquartered in Brussels, Belgium. The first president of the European Union was Jean-Claude Juncker. The current president is Ursula von der Leyen. The European Union is a major economic and political entity."
|
||||
|
||||
@ -297,9 +340,10 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
|
||||
print(f"Error in worker {worker_arg}: {e}")
|
||||
raise e
|
||||
|
||||
intercomm = None
|
||||
try:
|
||||
print("Launched all the workers.")
|
||||
intercomm = MPI.COMM_SELF.Accept(port_name)
|
||||
intercomm = mpi_initialize_intercomm(port_name)
|
||||
|
||||
for _ in range(2):
|
||||
intercomm.recv(tag=MPI_READY)
|
||||
@ -334,11 +378,11 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
|
||||
intercomm.send(requests, dest=1, tag=MPI_REQUEST)
|
||||
output = intercomm.recv(source=1, tag=MPI_RESULT)
|
||||
|
||||
except MPI.Exception as e:
|
||||
print(f"MPI Error")
|
||||
raise e
|
||||
finally:
|
||||
# Send termination requests
|
||||
intercomm.send(None, dest=0, tag=MPI_REQUEST)
|
||||
intercomm.send(None, dest=1, tag=MPI_REQUEST)
|
||||
print("Sent termination requests to the workers.")
|
||||
mpi_send_termination_request(intercomm)
|
||||
|
||||
# Wait for all futures to complete
|
||||
for future in futures:
|
||||
@ -387,8 +431,7 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
|
||||
zip(kv_cache_configs, cache_transceiver_configs, worker_pytorch_configs,
|
||||
model_names, ranks))
|
||||
|
||||
port_name = MPI.Open_port()
|
||||
MPI.Publish_name('my_port', port_name)
|
||||
port_name = mpi_publish_name()
|
||||
|
||||
prompt = "What is the capital of Germany?"
|
||||
|
||||
@ -402,9 +445,10 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
|
||||
print(f"Error in worker {worker_arg}: {e}")
|
||||
raise e
|
||||
|
||||
intercomm = None
|
||||
try:
|
||||
print("Launched all the workers.")
|
||||
intercomm = MPI.COMM_SELF.Accept(port_name)
|
||||
intercomm = mpi_initialize_intercomm(port_name)
|
||||
|
||||
for _ in range(2):
|
||||
intercomm.recv(tag=MPI_READY)
|
||||
@ -438,11 +482,11 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
|
||||
intercomm.send(requests, dest=1, tag=MPI_REQUEST)
|
||||
output = intercomm.recv(source=1, tag=MPI_RESULT)
|
||||
|
||||
except MPI.Exception as e:
|
||||
print(f"MPI Error")
|
||||
raise e
|
||||
finally:
|
||||
# Send termination requests
|
||||
intercomm.send(None, dest=0, tag=MPI_REQUEST)
|
||||
intercomm.send(None, dest=1, tag=MPI_REQUEST)
|
||||
print("Sent termination requests to the workers.")
|
||||
mpi_send_termination_request(intercomm)
|
||||
|
||||
# Wait for all futures to complete
|
||||
for future in futures:
|
||||
|
||||
@ -23,7 +23,7 @@ class PythonVenvRunnerImpl(PythonRunnerInterface):
|
||||
venv_dir (str): Path to the virtualenv root directory, or None if this is
|
||||
an externally-built virtualenv
|
||||
venv_bin (str): Path to the Python executable to use when running tests
|
||||
workspace (str): Path to the TURTLE workspace
|
||||
workspace (str): Path to the test workspace
|
||||
"""
|
||||
|
||||
def __init__(self, pip_opts, venv_dir, venv_bin, workspace):
|
||||
|
||||
@ -189,7 +189,7 @@ class GPUClockLock:
|
||||
# Initialize thread
|
||||
self._thread = threading.Thread(
|
||||
target=self._monitoring_thread,
|
||||
name="TURTLE - GPUMonitor",
|
||||
name="LLM Test - GPUMonitor",
|
||||
kwargs={"interval_ms": self._interval_ms})
|
||||
self._thread.daemon = True
|
||||
self._thread.start()
|
||||
|
||||
@ -181,10 +181,19 @@ def get_model_yaml_config(model_label: str,
|
||||
|
||||
# lora-specific change for pytorch
|
||||
if 'pytorch' in model_label and 'loras' in model_label:
|
||||
# Derive the requested number of adapters from model_label (segment like "loras:X")
|
||||
lora_count = 1
|
||||
for part in model_label.split('-'):
|
||||
if part.startswith('loras:'):
|
||||
lora_count = max(1, int(part.split(':', 1)[1]))
|
||||
break
|
||||
|
||||
lora_config = {
|
||||
'lora_config': {
|
||||
'lora_dir': lora_dirs if lora_dirs is not None else [],
|
||||
'max_lora_rank': 64
|
||||
'max_lora_rank': 64,
|
||||
'max_loras': lora_count,
|
||||
'max_cpu_loras': lora_count,
|
||||
}
|
||||
}
|
||||
if 'phi_4_multimodal_instruct' in model_label:
|
||||
|
||||
@ -1755,7 +1755,7 @@ def parse_output(text):
|
||||
for item in text_lists:
|
||||
item = item.replace(os.linesep, "")
|
||||
while True:
|
||||
match = re.search(r"(Generated text: \'(.*?)\')", item,
|
||||
match = re.search(r'Generated text: ([\'"])(.*?)\1', item,
|
||||
re.MULTILINE)
|
||||
if match is None:
|
||||
break
|
||||
@ -2299,7 +2299,8 @@ def test_ptp_quickstart_advanced_mixed_precision(llm_root, llm_venv):
|
||||
marks=pytest.mark.skip_less_device_memory(80000)),
|
||||
pytest.param("gemma-3-27b-it",
|
||||
"gemma/gemma-3-27b-it",
|
||||
marks=pytest.mark.skip_less_device_memory(80000)),
|
||||
marks=(skip_post_blackwell,
|
||||
pytest.mark.skip_less_device_memory(80000))),
|
||||
])
|
||||
def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
|
||||
modality, use_cuda_graph):
|
||||
@ -2407,9 +2408,9 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
|
||||
},
|
||||
"gemma-3-27b-it": {
|
||||
"image": [
|
||||
["dramatic", "turbulent", "waves", "ocean", "overcast"],
|
||||
["half", "dome", "yosemite", "landmark", "rounded"],
|
||||
["flowing", "traffic", "vehicles", "road", "Changi"],
|
||||
["natural", "turbulent", "dramatic", "scene", "wave"],
|
||||
["image", "famous", "rock", "granite", "landmark"],
|
||||
["traffic", "moderate", "heavy", "flowing", "cars"],
|
||||
],
|
||||
},
|
||||
}
|
||||
@ -2600,9 +2601,10 @@ def test_ptp_quickstart_multimodal_phi4mm(llm_root, llm_venv, modality):
|
||||
@pytest.mark.skip_less_device(2)
|
||||
@pytest.mark.skip_less_device_memory(80000)
|
||||
@pytest.mark.parametrize("model_name,model_path", [
|
||||
("gemma-3-27b-it", "gemma/gemma-3-27b-it"),
|
||||
("mistral-small-3.1-24b-instruct", "Mistral-Small-3.1-24B-Instruct-2503"),
|
||||
("Phi-4-multimodal-instruct", "multimodals/Phi-4-multimodal-instruct"),
|
||||
pytest.param(
|
||||
"gemma-3-27b-it", "gemma/gemma-3-27b-it", marks=skip_post_blackwell),
|
||||
])
|
||||
def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name,
|
||||
model_path):
|
||||
@ -2645,8 +2647,8 @@ def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name,
|
||||
},
|
||||
"Phi-4-multimodal-instruct": {
|
||||
"image": [
|
||||
["image", "depicts", "mountain", "half", "rock"],
|
||||
["road", "car", "lane", "traffic", "bus"],
|
||||
["object", "mountain", "weather", "clear", "clouds"],
|
||||
["traffic", "road", "vehicles", "cars", "bus"],
|
||||
],
|
||||
},
|
||||
}
|
||||
@ -2674,6 +2676,8 @@ def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name,
|
||||
cmd.append("--image_format=pil")
|
||||
cmd.append("--attention_backend=FLASHINFER")
|
||||
cmd.append("--disable_kv_cache_reuse")
|
||||
cmd.append("--kv_cache_fraction=0.5")
|
||||
cmd.append("--max_seq_len=1024")
|
||||
elif model_name == "Phi-4-multimodal-instruct":
|
||||
# Set max_seq_len to 4096 to use short rope factor.
|
||||
cmd.append("--max_seq_len=4096")
|
||||
@ -2702,9 +2706,10 @@ def test_ptp_quickstart_multimodal_2gpu(llm_root, llm_venv, model_name,
|
||||
|
||||
@pytest.mark.skip_less_device_memory(80000)
|
||||
@pytest.mark.parametrize("model_name,model_path", [
|
||||
("gemma-3-27b-it", "gemma/gemma-3-27b-it"),
|
||||
("mistral-small-3.1-24b-instruct", "Mistral-Small-3.1-24B-Instruct-2503"),
|
||||
("Phi-4-multimodal-instruct", "multimodals/Phi-4-multimodal-instruct"),
|
||||
pytest.param(
|
||||
"gemma-3-27b-it", "gemma/gemma-3-27b-it", marks=skip_post_blackwell),
|
||||
])
|
||||
def test_ptp_quickstart_multimodal_multiturn(llm_root, llm_venv, model_name,
|
||||
model_path):
|
||||
@ -2770,6 +2775,9 @@ def test_ptp_quickstart_multimodal_multiturn(llm_root, llm_venv, model_name,
|
||||
cmd.append("--image_format=pil")
|
||||
cmd.append("--attention_backend=FLASHINFER")
|
||||
cmd.append("--disable_kv_cache_reuse")
|
||||
cmd.append("--kv_cache_fraction=0.5")
|
||||
cmd.append("--max_seq_len=1024")
|
||||
|
||||
elif model_name == "Phi-4-multimodal-instruct":
|
||||
# Set max_seq_len to 4096 to use short rope factor.
|
||||
cmd.append("--max_seq_len=4096")
|
||||
|
||||
@ -69,11 +69,6 @@ def test_case_name(request):
|
||||
return request.node.nodeid
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def output_dir(request):
|
||||
return request.config._trt_config["output_dir"]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llm_backend_root():
|
||||
llm_root = os.environ.get("LLM_ROOT", find_repo_root())
|
||||
@ -655,10 +650,10 @@ def install_root_requirements(llm_backend_root):
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def output_dir(request):
|
||||
if USE_TURTLE:
|
||||
return request.config._trt_config["output_dir"]
|
||||
else:
|
||||
return request.config.getoption("--output-dir")
|
||||
output = request.config.getoption("--output-dir")
|
||||
if output:
|
||||
os.makedirs(str(output), exist_ok=True)
|
||||
return output
|
||||
|
||||
|
||||
def deselect_by_regex(regexp, items, test_prefix, config):
|
||||
|
||||
@ -39,6 +39,13 @@ l0_a10:
|
||||
- test_e2e.py::test_openai_chat_example[pytorch] TIMEOUT (90)
|
||||
- test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-]
|
||||
- test_e2e.py::test_trtllm_bench_invalid_token_pytorch[TinyLlama-1.1B-Chat-v1.0-TinyLlama-1.1B-Chat-v1.0]
|
||||
# llmapi
|
||||
- unittest/llmapi/test_llm_utils.py
|
||||
- unittest/llmapi/test_gc_utils.py
|
||||
- unittest/llmapi/test_reasoning_parser.py
|
||||
- unittest/llmapi/test_serialization.py
|
||||
- unittest/llmapi/test_utils.py
|
||||
- unittest/llmapi/test_llm_args.py
|
||||
- condition:
|
||||
ranges:
|
||||
system_gpu_count:
|
||||
@ -114,12 +121,6 @@ l0_a10:
|
||||
- unittest/bindings
|
||||
- unittest/test_model_runner_cpp.py
|
||||
- unittest/llmapi/test_build_cache.py
|
||||
- unittest/llmapi/test_llm_utils.py
|
||||
- unittest/llmapi/test_gc_utils.py
|
||||
- unittest/llmapi/test_reasoning_parser.py
|
||||
- unittest/llmapi/test_serialization.py
|
||||
- unittest/llmapi/test_utils.py
|
||||
- unittest/llmapi/test_llm_args.py
|
||||
- accuracy/test_cli_flow.py::TestGpt2::test_auto_dtype # 1 min
|
||||
- accuracy/test_cli_flow.py::TestGpt2::test_beam_search # 1 min
|
||||
- accuracy/test_cli_flow.py::TestGpt2::test_beam_search_large # 6 mins
|
||||
|
||||
@ -106,7 +106,7 @@ l0_h100:
|
||||
- test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]
|
||||
- test_e2e.py::test_openai_chat_harmony
|
||||
- test_e2e.py::test_openai_responses
|
||||
- test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
|
||||
- test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True] TIMEOUT (90)
|
||||
# ------------- AutoDeploy tests ---------------
|
||||
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype
|
||||
- condition:
|
||||
@ -227,7 +227,7 @@ l0_h100:
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=0]
|
||||
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_fp8_prequantized
|
||||
- accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_fp8_prequantized
|
||||
- accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype
|
||||
- accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype TIMEOUT (90)
|
||||
- accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype
|
||||
- accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_fp8
|
||||
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=False]
|
||||
|
||||
@ -252,11 +252,9 @@ examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-re
|
||||
examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:1-pp:1-float16-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity] SKIP (https://nvbugs/5421989)
|
||||
examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:1-pp:1-float16-RobertaForSequenceClassification-bert/twitter-roberta-base-emotion] SKIP (https://nvbugs/5421989)
|
||||
examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5431132)
|
||||
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5434320)
|
||||
accuracy/test_llm_api.py::TestLlama3_2_1B::test_int4_awq_int8_kv_cache SKIP (https://nvbugs/5433541)
|
||||
accuracy/test_llm_api.py::TestLlama3_2_1B::test_fp8_pp2 SKIP (https://nvbugs/5433541)
|
||||
accuracy/test_llm_api_pytorch.py::TestPhi4MiniInstruct::test_auto_dtype SKIP (https://nvbugs/5433545)
|
||||
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5434320)
|
||||
examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-9b-it] SKIP (https://nvbugs/5434451)
|
||||
examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-27b-it] SKIP (https://nvbugs/5434451)
|
||||
examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-3-1b-it] SKIP (https://nvbugs/5434451)
|
||||
@ -278,7 +276,6 @@ triton_server/test_triton.py::test_gpt_ib_ptuning[gpt-ib-ptuning] SKIP (https://
|
||||
triton_server/test_triton.py::test_mistral_ib_mm[mistral-ib-mm] SKIP (https://nvbugs/5371343)
|
||||
triton_server/test_triton.py::test_t5_ib[t5-ib] SKIP (https://nvbugs/5456482)
|
||||
triton_server/test_triton_llm.py::test_gpt_speculative_decoding_bls[False-False-1---False-True-True-0-128-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap-0.2-guaranteed_no_evict---1-1-1-False-ensemble] SKIP (https://nvbugs/5456485)
|
||||
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5434320)
|
||||
accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_eagle3] SKIP (https://nvbugs/5437384)
|
||||
llmapi/test_llm_examples.py::test_llmapi_speculative_decoding_mtp SKIP (https://nvbugs/5461796)
|
||||
accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_gather_generation_logits_cuda_graph SKIP (https://nvbugs/5365525)
|
||||
@ -287,6 +284,12 @@ examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-3-small-128k-instruct] SKI
|
||||
examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-3.5-mini-instruct] SKIP (https://nvbugs/5465143)
|
||||
examples/test_phi.py::test_phi_fp8_with_bf16_lora[Phi-4-mini-instruct] SKIP (https://nvbugs/5465143)
|
||||
examples/test_llama.py::test_llm_llama_v1_2gpu_summary[llama-7b-nb:4-enable_auto_parallel] SKIP (https://nvbugs/5465173)
|
||||
test_e2e.py::test_ptp_quickstart_multimodal[llava-v1.6-mistral-7b-llava-v1.6-mistral-7b-hf-image-False] SKIP (https://nvbugs/5444095)
|
||||
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen1.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
|
||||
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
|
||||
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_vl_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5359696)
|
||||
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
|
||||
accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl SKIP (https://nvbugs/5413362)
|
||||
disaggregated/test_disaggregated.py::test_disaggregated_diff_max_tokens[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5451272)
|
||||
disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_single_gpu_mtp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5465642)
|
||||
examples/test_multimodal.py::test_llm_multimodal_general[Mistral-Small-3.1-24B-Instruct-2503-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5431146)
|
||||
@ -340,17 +343,6 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_ep
|
||||
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False] SKIP (https://nvbugs/5488118)
|
||||
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True] SKIP (https://nvbugs/5488118)
|
||||
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5488118)
|
||||
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen1.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
|
||||
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
|
||||
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_vl_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5359696)
|
||||
full:GB200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
|
||||
accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl SKIP (https://nvbugs/5413362)
|
||||
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8] SKIP (https://nvbugs/5455140)
|
||||
full:L40S/accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=False] SKIP (https://nvbugs/5347051)
|
||||
full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] SKIP (https://nvbugs/5471106)
|
||||
full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp2pp2] SKIP (https://nvbugs/5471108)
|
||||
test_e2e.py::test_multi_nodes_eval[llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8-tp8pp2-mmlu] SKIP (https://nvbugs/5473781)
|
||||
accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl SKIP (https://nvbugs/5413362)
|
||||
stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-GUARANTEED_NO_EVICT-pytorch-stress-test] SKIP (https://nvbugs/5474169)
|
||||
test_e2e.py::test_trtllm_bench_iteration_log[TRT-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B] SKIP (https://nvbugs/5448523)
|
||||
cpp/test_unit_tests.py::test_unit_tests[kernels-80] SKIP (https://nvbugs/5504078)
|
||||
|
||||
@ -17,6 +17,7 @@ def similar(a, b, threshold=0.9):
|
||||
return SequenceMatcher(None, a, b).ratio() >= threshold
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="https://nvbugs/5470782")
|
||||
@pytest.mark.parametrize("model_name", ["DeepSeek-V3-Lite"],
|
||||
ids=["deepseekv3_lite"])
|
||||
@pytest.mark.parametrize("backend", ["TRTLLM"], ids=["trtllm"])
|
||||
|
||||
205
tests/unittest/_torch/multimodal/test_fuse_input_embeds.py
Normal file
205
tests/unittest/_torch/multimodal/test_fuse_input_embeds.py
Normal file
@ -0,0 +1,205 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._torch.models.modeling_multimodal_utils import (
|
||||
filter_mm_token_from_input_ids, fuse_input_embeds)
|
||||
from tensorrt_llm._torch.modules.embedding import Embedding
|
||||
|
||||
|
||||
def make_embedding(num_embeddings: int = 100,
|
||||
hidden_size: int = 16,
|
||||
device: str = "cpu") -> Embedding:
|
||||
torch.manual_seed(0)
|
||||
emb = Embedding(num_embeddings=num_embeddings, embedding_dim=hidden_size)
|
||||
emb.weight.data.normal_(mean=0.0, std=0.02)
|
||||
return emb.to(device)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cpu"] +
|
||||
(["cuda"] if torch.cuda.is_available() else []))
|
||||
def test_filter_mm_token_from_input_ids_oov(device):
|
||||
vocab_size = 10
|
||||
# input_ids contains text (< vocab) and OOV mm tokens (>= vocab)
|
||||
input_ids = torch.tensor([1, 2, 11, 3, 12, 9, 15],
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
text_idx, mm_idx = filter_mm_token_from_input_ids(input_ids,
|
||||
vocab_size=vocab_size,
|
||||
mm_token_ids=None)
|
||||
|
||||
torch.testing.assert_close(text_idx.cpu(),
|
||||
torch.tensor([0, 1, 3, 5], dtype=torch.long))
|
||||
torch.testing.assert_close(mm_idx.cpu(),
|
||||
torch.tensor([2, 4, 6], dtype=torch.long))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cpu"] +
|
||||
(["cuda"] if torch.cuda.is_available() else []))
|
||||
def test_filter_mm_token_from_input_ids_explicit_ids(device):
|
||||
vocab_size = 100
|
||||
# All ids are < vocab; mm tokens are explicitly specified
|
||||
input_ids = torch.tensor([1, 2, 55, 3, 77, 9, 88],
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
mm_token_ids = torch.tensor([55, 77, 88], dtype=torch.long)
|
||||
text_idx, mm_idx = filter_mm_token_from_input_ids(input_ids,
|
||||
vocab_size=vocab_size,
|
||||
mm_token_ids=mm_token_ids)
|
||||
|
||||
torch.testing.assert_close(text_idx.cpu(),
|
||||
torch.tensor([0, 1, 3, 5], dtype=torch.long))
|
||||
torch.testing.assert_close(mm_idx.cpu(),
|
||||
torch.tensor([2, 4, 6], dtype=torch.long))
|
||||
|
||||
# Even with some ids > vocab, mm indices should still only match the given mm_token_ids
|
||||
input_ids = torch.tensor([1, 2, 55, 3, 77, 9, 88, 101, 102, 103],
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
_, mm_idx = filter_mm_token_from_input_ids(input_ids,
|
||||
vocab_size=vocab_size,
|
||||
mm_token_ids=mm_token_ids)
|
||||
torch.testing.assert_close(mm_idx.cpu(),
|
||||
torch.tensor([2, 4, 6], dtype=torch.long))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cpu"] +
|
||||
(["cuda"] if torch.cuda.is_available() else []))
|
||||
def test_fuse_input_embeds_empty_mm_returns_ids(device):
|
||||
emb = make_embedding(num_embeddings=20, hidden_size=8, device=device)
|
||||
input_ids = torch.tensor([1, 2, 3, 4], dtype=torch.long, device=device)
|
||||
|
||||
out_ids, out_embeds = fuse_input_embeds(emb,
|
||||
input_ids,
|
||||
mm_embeds=[],
|
||||
mm_token_ids=None)
|
||||
|
||||
# No mm embeddings => passthrough ids, no embeds fused
|
||||
assert out_embeds is None
|
||||
torch.testing.assert_close(out_ids, input_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cpu"] +
|
||||
(["cuda"] if torch.cuda.is_available() else []))
|
||||
def test_fuse_input_embeds_mismatch_raises(device):
|
||||
emb = make_embedding(num_embeddings=50, hidden_size=8, device=device)
|
||||
|
||||
# Mix text (< vocab) and mm (>= vocab) tokens. Here vocab_size == 50
|
||||
input_ids = torch.tensor([1, 51, 2, 52, 3, 53],
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
# Identify indices first to drive the lower-level CUDA fuse directly for mismatch
|
||||
text_idx, mm_idx = filter_mm_token_from_input_ids(
|
||||
input_ids, vocab_size=emb.num_embeddings)
|
||||
|
||||
# Provide the wrong number of mm embeddings (e.g., one short)
|
||||
hidden = 8
|
||||
true_mm_count = mm_idx.shape[0]
|
||||
wrong_mm = torch.randn(true_mm_count - 1, hidden, device=device)
|
||||
|
||||
with pytest.raises(ValueError, match="Multimodal token count mismatch"):
|
||||
fuse_input_embeds(emb,
|
||||
input_ids, [wrong_mm],
|
||||
mm_token_ids=None,
|
||||
text_token_indices=text_idx,
|
||||
mm_token_indices=mm_idx)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cpu"] +
|
||||
(["cuda"] if torch.cuda.is_available() else []))
|
||||
def test_fuse_input_embeds_success_oov_path(device):
|
||||
hidden = 8
|
||||
emb = make_embedding(num_embeddings=40, hidden_size=hidden, device=device)
|
||||
|
||||
# input ids: mix of text (<40) and mm (>=40)
|
||||
input_ids = torch.tensor([0, 1, 41, 2, 42, 3, 43, 4],
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
# Build mm embeddings to match number of OOV positions
|
||||
text_idx, mm_idx = filter_mm_token_from_input_ids(
|
||||
input_ids, vocab_size=emb.num_embeddings)
|
||||
mm_emb = torch.randn(mm_idx.shape[0], hidden, device=device)
|
||||
|
||||
# kwargs path to produce fused embeddings
|
||||
out_ids, out_embeds = fuse_input_embeds(emb,
|
||||
input_ids,
|
||||
mm_embeds=[mm_emb],
|
||||
mm_token_ids=None,
|
||||
text_token_indices=text_idx,
|
||||
mm_token_indices=mm_idx)
|
||||
# integrated filtering path to produce fused embeddings (not kwargs path)
|
||||
out_ids_v2, out_embeds_v2 = fuse_input_embeds(emb,
|
||||
input_ids,
|
||||
mm_embeds=[mm_emb],
|
||||
mm_token_ids=None)
|
||||
|
||||
assert out_ids is None
|
||||
assert out_embeds is not None
|
||||
assert out_embeds.shape == (input_ids.numel(), hidden)
|
||||
|
||||
# Validate that text positions equal embedding lookup, and mm positions equal provided mm_emb
|
||||
text_idx, mm_idx2 = filter_mm_token_from_input_ids(
|
||||
input_ids, vocab_size=emb.num_embeddings)
|
||||
torch.testing.assert_close(mm_idx2, mm_idx)
|
||||
|
||||
ref_text = emb(input_ids[text_idx])
|
||||
torch.testing.assert_close(out_embeds[text_idx], ref_text)
|
||||
torch.testing.assert_close(
|
||||
out_embeds[mm_idx],
|
||||
mm_emb.to(dtype=out_embeds.dtype, device=out_embeds.device))
|
||||
torch.testing.assert_close(out_embeds_v2, out_embeds)
|
||||
torch.testing.assert_close(out_ids_v2, out_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cpu"] +
|
||||
(["cuda"] if torch.cuda.is_available() else []))
|
||||
def test_fuse_input_embeds_kwargs_precedence_over_sentinel_and_ids(device):
|
||||
"""
|
||||
Ensure that when kwargs provide precomputed indices, they take precedence
|
||||
over both OOV-sentinel filtering and explicit mm_token_ids.
|
||||
"""
|
||||
hidden = 8
|
||||
vocab_size = 40
|
||||
emb = make_embedding(num_embeddings=vocab_size,
|
||||
hidden_size=hidden,
|
||||
device=device)
|
||||
|
||||
# Use vocab_size+1 as OOV sentinel
|
||||
oov_sentinel = vocab_size + 1
|
||||
input_ids = torch.tensor([0, oov_sentinel, 1, oov_sentinel, 2],
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
# Precompute correct indices (kwargs path)
|
||||
text_idx, mm_idx = filter_mm_token_from_input_ids(input_ids,
|
||||
vocab_size=vocab_size,
|
||||
mm_token_ids=None)
|
||||
mm_emb = torch.randn(mm_idx.shape[0], hidden, device=device)
|
||||
|
||||
# Provide a deliberately incorrect mm_token_ids to ensure it is ignored
|
||||
bad_mm_token_ids = torch.tensor(
|
||||
[0], dtype=torch.long,
|
||||
device=device) # would misclassify index 0 as mm if used
|
||||
|
||||
out_ids, out_embeds = fuse_input_embeds(
|
||||
emb,
|
||||
input_ids,
|
||||
mm_embeds=[mm_emb],
|
||||
mm_token_ids=
|
||||
bad_mm_token_ids, # should be ignored because indices are provided
|
||||
text_token_indices=text_idx,
|
||||
mm_token_indices=mm_idx,
|
||||
)
|
||||
|
||||
# Validate outputs
|
||||
assert out_ids is None
|
||||
assert out_embeds is not None
|
||||
assert out_embeds.shape == (input_ids.numel(), hidden)
|
||||
|
||||
ref_text = emb(input_ids[text_idx])
|
||||
torch.testing.assert_close(out_embeds[text_idx], ref_text)
|
||||
torch.testing.assert_close(
|
||||
out_embeds[mm_idx],
|
||||
mm_emb.to(dtype=out_embeds.dtype, device=out_embeds.device),
|
||||
)
|
||||
@ -16,7 +16,6 @@ from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="https://nvbugs/5461761")
|
||||
@pytest.mark.parametrize(
|
||||
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill",
|
||||
[
|
||||
@ -27,7 +26,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
||||
[False, "TRTLLM", False, True, True, False],
|
||||
[True, "TRTLLM", False, True, True, False],
|
||||
[True, "TRTLLM", True, False, True, True],
|
||||
[True, "TRTLLM", True, False, False, True],
|
||||
# TODO: nvbugs/5461761
|
||||
# [True, "TRTLLM", True, False, False, True],
|
||||
])
|
||||
@pytest.mark.high_cuda_memory
|
||||
def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
|
||||
|
||||
189
tests/unittest/_torch/test_torch_sampler.py
Normal file
189
tests/unittest/_torch/test_torch_sampler.py
Normal file
@ -0,0 +1,189 @@
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._torch.pyexecutor.llm_request import convert_wordlist
|
||||
from tensorrt_llm._torch.pyexecutor.sampler import (BEAM_0, LlmRequest,
|
||||
TorchSampler)
|
||||
from tensorrt_llm._torch.pyexecutor.sampler_utils import produce_stop_words
|
||||
from tensorrt_llm.bindings import SamplingConfig
|
||||
from tensorrt_llm.bindings.executor import FinishReason
|
||||
|
||||
MAX_NUM_SEQUENCES = 128
|
||||
NOT_FINISHED = FinishReason.NOT_FINISHED
|
||||
STOP_WORDS = FinishReason.STOP_WORDS
|
||||
END_ID = FinishReason.END_ID
|
||||
LENGTH = FinishReason.LENGTH
|
||||
|
||||
|
||||
class RequestCase:
|
||||
MAX_NEW_TOKENS = 10
|
||||
seq_slots = random.sample(range(MAX_NUM_SEQUENCES), MAX_NUM_SEQUENCES)
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
prompt: list[int],
|
||||
new_tokens: list[int],
|
||||
finish_reasons: list[FinishReason],
|
||||
max_new_tokens: int = MAX_NEW_TOKENS,
|
||||
end_id: int = None,
|
||||
stop_words_list: list[list[int]] = None):
|
||||
seq_slot = self.seq_slots.pop() # random seq slot in MAX_NUM_SEQUENCES
|
||||
self.prompt = prompt
|
||||
self.request = LlmRequest(
|
||||
request_id=seq_slot,
|
||||
seq_slot=seq_slot,
|
||||
input_tokens=prompt,
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop_words_list=convert_wordlist(stop_words_list)
|
||||
if stop_words_list is not None else None,
|
||||
end_id=end_id,
|
||||
sampling_config=SamplingConfig(),
|
||||
is_streaming=False,
|
||||
)
|
||||
assert len(new_tokens) == len(finish_reasons)
|
||||
self.new_tokens = new_tokens
|
||||
self.finish_reasons = finish_reasons
|
||||
|
||||
def __repr__(self):
|
||||
        return f"RequestCase({self.prompt=}, {self.new_tokens=}, {self.finish_reasons=}, {self.request.max_new_tokens=}, {self.request.end_id=}, {self.request.stop_words_list=})"

    @staticmethod
    def setup(requests: list["RequestCase"]):
        max_tokens = set(len(req.new_tokens) for req in requests)
        assert len(max_tokens) == 1
        max_draft_len = max_tokens.pop() - 1
        sampler_args = TorchSampler.Args(
            max_seq_len=20,
            max_draft_len=max_draft_len,
            # Allow many more sequences than the requests below need, so we can test that write_finish_reasons uses seq_slots correctly
            max_num_sequences=MAX_NUM_SEQUENCES,
            max_beam_width=1,
            enable_mixed_sampler=False)
        sampler = TorchSampler(args=sampler_args)

        # Fill with a garbage value so we can observe that finish reasons are set to NOT_FINISHED before we write to them.
        sampler.store.finish_reasons.fill_(205)
        seq_slots = torch.tensor([req.request.py_seq_slot for req in requests],
                                 device="cuda",
                                 dtype=torch.int64)
        new_tokens = torch.tensor([req.new_tokens for req in requests],
                                  dtype=torch.int32,
                                  device="cuda").T
        sampler.store.new_tokens[:, seq_slots, BEAM_0] = new_tokens

        def run():
            sampler._write_finish_reasons(
                [req.request for req in requests],
                finish_reasons=sampler.store.finish_reasons,
                new_tokens=sampler.store.new_tokens,
                seq_slots=seq_slots)

            reasons = sampler.store.finish_reasons[:, seq_slots,
                                                   BEAM_0].T.tolist()

            for actual, request in zip(reasons, requests, strict=True):
                expected = request.finish_reasons
                msg = f"actual={[FinishReason(reason) for reason in actual]} != expected={expected}\nFor {request}"
                assert actual == [reason.value for reason in expected], msg

        return run, sampler


def test_write_finish_reasons():
    """We don't really care about the finish reason past the first infraction, because we're not going to use it, although in some instances it is written anyway."""
    run, _ = RequestCase.setup([
        RequestCase(
            prompt=[13, 14],
            new_tokens=[60, 61, 62],
            # We pre-fill the finish reasons with NOT_FINISHED.
            finish_reasons=[NOT_FINISHED, NOT_FINISHED, NOT_FINISHED],
        ),
        RequestCase(
            prompt=[7, 8, 6],
            stop_words_list=[[12, 13]],
            new_tokens=[12, 13, 60],
            finish_reasons=[NOT_FINISHED, STOP_WORDS, NOT_FINISHED],
        ),
        RequestCase(
            prompt=[1, 2, 3, 4],
            end_id=99,
            new_tokens=[55, 99, 58],
            finish_reasons=[NOT_FINISHED, END_ID, NOT_FINISHED],
        ),
        RequestCase(
            prompt=[4, 5, 6],
            max_new_tokens=2,
            new_tokens=[56, 57, 59],
            # The LENGTH check happens not to have an early exit
            finish_reasons=[NOT_FINISHED, LENGTH, LENGTH]),
        RequestCase(
            prompt=[1, 12],
            stop_words_list=[[12, 13], [14, 15]],
            new_tokens=[13, 14, 15],
            # We have an early exit specifically for stop words
            finish_reasons=[STOP_WORDS, NOT_FINISHED, NOT_FINISHED],
        ),
        RequestCase(
            prompt=[1],
            max_new_tokens=2,
            end_id=99,
            stop_words_list=[[1, 12]],
            new_tokens=[12, 99, 63],
            # Different infractions are written to different places, as we don't have an early exit between infractions
            finish_reasons=[STOP_WORDS, END_ID, LENGTH],
        ),
        RequestCase(
            prompt=[1, 12, 56, 67, 68, 234, 678],
            stop_words_list=[[12, 56, 67, 68, 234, 678, 129, 182]],
            new_tokens=[129, 182, 600],
            # Notice the offending stop sequence spans the concatenated prompt and new tokens, because we look back
            finish_reasons=[NOT_FINISHED, STOP_WORDS, NOT_FINISHED],
        ),
        RequestCase(
            prompt=[1, 12],
            end_id=99,
            max_new_tokens=1,
            stop_words_list=[[1, 12, 99]],
            new_tokens=[99, 100, 101],
            # The latest infraction check overrides the earlier infraction checks, hence the first finish_reason is END_ID
            finish_reasons=[END_ID, LENGTH, LENGTH],
        ),
    ])
    run()


def test_are_stop_words_isnt_called_when_no_stop_words():
    """We don't want to call _are_stop_words when there are no stop words, because it's expensive"""

    def stop_words_that_raises(*args, **kwargs):
        raise AssertionError

    run_with_stop_words, sampler = RequestCase.setup([
        RequestCase(prompt=[1],
                    stop_words_list=[[1]],
                    new_tokens=[4],
                    finish_reasons=[NOT_FINISHED])
    ])
    sampler._are_stop_words = stop_words_that_raises
    with pytest.raises(AssertionError):
        run_with_stop_words()

    run_without_stop_words, sampler = RequestCase.setup([
        RequestCase(prompt=[1], new_tokens=[4], finish_reasons=[NOT_FINISHED])
    ])
    sampler._are_stop_words = stop_words_that_raises
    _ = run_without_stop_words()


def test_produce_stop_words():
    for original in [
        [[]],
        [[1]],
        [[1, 2, 3]],
        [[1], [2, 3]],
        [[1, 2, 3], [4, 5]],
        [[10], [20], [30, 40], [50]],
    ]:
        assert original == list(produce_stop_words(convert_wordlist(original)))
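# A minimal sketch (not from the diff above) of the round trip that
# test_produce_stop_words exercises, assuming the packed [2, N] stop-words
# layout (row 0: concatenated tokens, row 1: exclusive end offsets, padded
# with -1). pack_wordlist/unpack_wordlist are hypothetical stand-ins; the real
# convert_wordlist/produce_stop_words live in tensorrt_llm and may differ in
# details.
def pack_wordlist(words: list[list[int]]) -> list[list[int]]:
    tokens: list[int] = []
    offsets: list[int] = []
    end = 0
    for word in words:
        tokens.extend(word)
        end += len(word)
        offsets.append(end)
    # Pad the offsets row with -1 up to the length of the tokens row (assumed convention).
    offsets.extend([-1] * (len(tokens) - len(offsets)))
    return [tokens, offsets]


def unpack_wordlist(packed: list[list[int]]):
    tokens, offsets = packed
    start = 0
    for end in offsets:
        if end < 0:
            break  # padding reached, no more words
        yield tokens[start:end]
        start = end


assert list(unpack_wordlist(pack_wordlist([[1, 2, 3], [4, 5]]))) == [[1, 2, 3], [4, 5]]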
@@ -147,6 +147,10 @@ async def test_streaming(client: openai.AsyncOpenAI, model: str):
    collected_chunks = []
    collected_messages = []
    async for chunk in response:
        # The last streaming response only contains usage info
        if len(chunk.choices) <= 0:
            continue

        collected_chunks.append(chunk)
        collected_messages.append(chunk.choices[0].delta)
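# An assumed illustration (not part of the hunk above): how collected streamed
# deltas are typically reassembled into the full reply, skipping deltas that
# carry no content. SimpleNamespace stands in for the openai delta objects.
from types import SimpleNamespace

deltas = [SimpleNamespace(content="Hello"), SimpleNamespace(content=None),
          SimpleNamespace(content=" world")]
full_reply = "".join(d.content or "" for d in deltas)
assert full_reply == "Hello world"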
@@ -198,6 +202,10 @@ async def test_streaming_tool_call(client: openai.AsyncOpenAI, model: str):
    reasoning_chunks: list[str] = []
    tool_arg_chunks: list[str] = []
    async for chunk in response:
        # The last streaming response only contains usage info
        if len(chunk.choices) <= 0:
            continue

        delta = chunk.choices[0].delta
        if hasattr(delta, "tool_calls") and delta.tool_calls:
            function = delta.tool_calls[0].function
@@ -65,7 +65,10 @@ def engine_from_fp8_quantization(model_name):
@pytest.fixture(scope="module")
def server(model_name: str, engine_from_fp8_quantization: str):
    model_path = get_model_path(model_name)
    args = ["--tp_size", "2", "--tokenizer", model_path]
    args = [
        "--tp_size", "2", "--tokenizer", model_path, "--backend", "trt",
        "--max_num_tokens", "20480", "--max_batch_size", "128"
    ]
    with RemoteOpenAIServer(engine_from_fp8_quantization,
                            args) as remote_server:
        yield remote_server
@@ -1,4 +1,6 @@
from tensorrt_llm.llmapi.utils import ApiStatusRegistry
from tensorrt_llm.llmapi import LlmArgs
from tensorrt_llm.llmapi.utils import (ApiStatusRegistry,
                                       generate_api_docs_as_docstring)


def test_api_status_registry():
@@ -24,3 +26,9 @@ def test_api_status_registry():
        pass

    assert ApiStatusRegistry.get_api_status(App._my_method) == "beta"


def test_generate_api_docs_as_docstring():
    doc = generate_api_docs_as_docstring(LlmArgs)
    assert ":tag:`beta`" in doc, "the label is not generated"
    print(doc)