Merge remote-tracking branch 'origin/main' into feat/b300_cu13

Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
Xiwen Yu 2025-09-06 23:58:04 +08:00
commit 322db710dc
33 changed files with 1138 additions and 595 deletions

View File

@ -12,7 +12,7 @@ fi
PARSED_CMAKE_VERSION=$(echo $CMAKE_VERSION | sed 's/\.[0-9]*$//')
CMAKE_FILE_NAME="cmake-${CMAKE_VERSION}-linux-${ARCH}"
RELEASE_URL_CMAKE=${GITHUB_URL}/Kitware/CMake/releases/download/v${CMAKE_VERSION}/${CMAKE_FILE_NAME}.tar.gz
wget --no-verbose --timeout=180 --tries=3 ${RELEASE_URL_CMAKE} -P /tmp
wget --retry-connrefused --timeout=180 --tries=10 --continue ${RELEASE_URL_CMAKE} -P /tmp
tar -xf /tmp/${CMAKE_FILE_NAME}.tar.gz -C /usr/local/
ln -s /usr/local/${CMAKE_FILE_NAME} /usr/local/cmake

View File

@ -16,7 +16,7 @@ reinstall_rockylinux_cuda() {
dnf -y install epel-release
dnf remove -y "cuda*" "*cublas*" "*cufft*" "*cufile*" "*curand*" "*cusolver*" "*cusparse*" "*gds-tools*" "*npp*" "*nvjpeg*" "nsight*" "*nvvm*"
rm -rf /usr/local/cuda-${OLD_CUDA_VER}
wget -q https://developer.download.nvidia.com/compute/cuda/${CUDA_VER_SHORT}/local_installers/cuda_${CUDA_VER}_linux.run
wget --retry-connrefused --timeout=180 --tries=10 --continue https://developer.download.nvidia.com/compute/cuda/${CUDA_VER_SHORT}/local_installers/cuda_${CUDA_VER}_linux.run
sh cuda_${CUDA_VER}_linux.run --silent --override --toolkit
rm -f cuda_${CUDA_VER}_linux.run
}

View File

@ -96,7 +96,7 @@ install_rockylinux_requirements() {
"cuda-toolkit-config-common-${CUDA_RUNTIME}.noarch" \
"libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.${ARCH1}" \
"libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.${ARCH1}"; do
wget -q --timeout=180 --tries=3 "https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/${ARCH3}/${pkg}.rpm"
wget --retry-connrefused --timeout=180 --tries=10 --continue "https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/${ARCH3}/${pkg}.rpm"
done
# Remove old packages
@ -138,7 +138,7 @@ install_tensorrt() {
if [ "$ARCH" = "x86_64" ];then
curl -L --insecure --connect-timeout 600 --max-time 3600 --retry 3 -o /tmp/TensorRT.tar "${RELEASE_URL_TRT}"
else
wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
wget --retry-connrefused --timeout=180 --tries=10 --continue ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
fi
tar -xf /tmp/TensorRT.tar -C /usr/local/
mv /usr/local/TensorRT-* /usr/local/tensorrt

View File

@ -5,7 +5,7 @@ set -ex
install_boost() {
# Install boost version >= 1.78 for boost::span
# Current libboost-dev apt packages are < 1.78, so install from tar.gz
wget -O /tmp/boost.tar.gz --timeout=180 --tries=3 https://archives.boost.io/release/1.80.0/source/boost_1_80_0.tar.gz \
wget --retry-connrefused --timeout=180 --tries=10 --continue -O /tmp/boost.tar.gz https://archives.boost.io/release/1.80.0/source/boost_1_80_0.tar.gz \
&& tar xzf /tmp/boost.tar.gz -C /tmp \
&& mv /tmp/boost_1_80_0/boost /usr/include/boost \
&& rm -rf /tmp/boost_1_80_0 /tmp/boost.tar.gz

View File

@ -2,3 +2,5 @@
!.vscode
benchmark_results.json
*.png
# ignore config files that users might put here for debugging
*.yaml

View File

@ -26,6 +26,9 @@ from tensorrt_llm.sampling_params import SamplingParams
# Global torch config, set the torch compile cache to fit up to llama 405B
torch._dynamo.config.cache_size_limit = 20
# simple string, TRT-LLM style text-only prompt or full-scale HF message template
PromptInput = Union[str, Dict, List[Dict]]
class PromptConfig(BaseModel):
"""Prompt configuration.
@ -35,13 +38,27 @@ class PromptConfig(BaseModel):
"""
batch_size: int = Field(default=2, description="Number of queries")
queries: Union[str, List[str]] = Field(
queries: Union[PromptInput, List[PromptInput]] = Field(
default_factory=lambda: [
# OPTION 1: simple text prompt
"How big is the universe? ",
"In simple words and in a single sentence, explain the concept of gravity: ",
"How to fix slicing in golf? ",
"Where is the capital of Iceland? ",
]
# OPTION 2: wrapped text prompt for TRT-LLM
{"prompt": "In simple words and a single sentence, explain the concept of gravity: "},
# OPTION 3: a full-scale HF message template (this one works for text-only models!)
# Learn more about chat templates: https://huggingface.co/docs/transformers/en/chat_templating
# and multi-modal templates: https://huggingface.co/docs/transformers/en/chat_templating_multimodal
[
{
"role": "user",
"content": "How to fix slicing in golf?",
}
],
# More prompts...
{"prompt": "Where is the capital of Iceland? "},
],
description="Example queries to prompt the model with. We support plain strings, TRT-LLM "
"text-only queries via the 'prompt' key, and full-scale HF message templates processed via "
"apply_chat_template.",
)
sp_kwargs: Dict[str, Any] = Field(
default_factory=lambda: {"max_tokens": 100, "top_k": 200, "temperature": 1.0},
@ -55,10 +72,28 @@ class PromptConfig(BaseModel):
NOTE (lucaslie): has to be done with model_post_init to ensure it's always run. field
validators are only run if a value is provided.
"""
queries = [self.queries] if isinstance(self.queries, str) else self.queries
queries = self.queries if isinstance(self.queries, list) else [self.queries]
batch_size = self.batch_size
queries = queries * (batch_size // len(queries) + 1)
self.queries = queries[:batch_size]
queries = queries[:batch_size]
# now let's standardize the queries for the LLM api to understand them
queries_processed = []
for query in queries:
if isinstance(query, str):
queries_processed.append({"prompt": query})
elif isinstance(query, dict):
queries_processed.append(query)
elif isinstance(query, list):
queries_processed.append(
{
"prompt": "Fake prompt. Check out messages field for the HF chat template.",
"messages": query, # contains the actual HF chat template
}
)
else:
raise ValueError(f"Invalid query type: {type(query)}")
self.queries = queries_processed
@field_validator("sp_kwargs", mode="after")
@classmethod
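
Taken together, the widened PromptInput type and the model_post_init logic above mean a query list can freely mix the three prompt styles and still reach the LLM API as normalized dicts. A minimal standalone sketch of that normalization (a helper written for illustration, not part of the diff):

from typing import Dict, List, Union

PromptInput = Union[str, Dict, List[Dict]]

def normalize_queries(queries: List[PromptInput], batch_size: int) -> List[Dict]:
    """Repeat queries to fill the batch, then standardize each entry like model_post_init."""
    queries = (queries * (batch_size // len(queries) + 1))[:batch_size]
    normalized = []
    for query in queries:
        if isinstance(query, str):
            # OPTION 1: plain text prompt
            normalized.append({"prompt": query})
        elif isinstance(query, dict):
            # OPTION 2: TRT-LLM style {"prompt": ...} dict
            normalized.append(query)
        elif isinstance(query, list):
            # OPTION 3: HF chat template (list of role/content messages)
            normalized.append({
                "prompt": "Fake prompt. Check out messages field for the HF chat template.",
                "messages": query,
            })
        else:
            raise ValueError(f"Invalid query type: {type(query)}")
    return normalized

# e.g. normalize_queries(["How big is the universe? ",
#                         [{"role": "user", "content": "How to fix slicing in golf?"}]],
#                        batch_size=4)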

View File

@ -121,7 +121,7 @@ REQUIRED_NO_DRIVER_TYPES = ["dgx-h100", "dgx-h200", "gh200"]
ENABLE_NGC_DEVEL_IMAGE_TEST = params.enableNgcDevelImageTest ?: false
ENABLE_NGC_RELEASE_IMAGE_TEST = params.enableNgcReleaseImageTest ?: false
COMMON_SSH_OPTIONS = "-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null"
COMMON_SSH_OPTIONS = "-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o ServerAliveInterval=60 -o ServerAliveCountMax=5"
def uploadResults(def pipeline, SlurmCluster cluster, String nodeName, String stageName){
withCredentials([usernamePassword(credentialsId: 'svc_tensorrt', usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) {
@ -323,7 +323,7 @@ def runLLMTestlistOnSlurm(pipeline, platform, testList, config=VANILLA_CONFIG, p
Utils.exec(pipeline, script: "chmod +x ${jenkinsSetupPath}", returnStdout: true)
Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${jenkinsSetupPath} ${remote.user}@${remote.host}:~/bloom/scripts/${nodeName}-slurm_jenkins_agent_setup.sh", numRetries: 3,)
Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${jenkinsSetupPath} ${remote.user}@${remote.host}:~/bloom/scripts/${nodeName}-slurm_jenkins_agent_setup.sh", numRetries: 3)
Utils.exec(pipeline, script: "cat ${jenkinsSetupPath}")
@ -334,7 +334,8 @@ def runLLMTestlistOnSlurm(pipeline, platform, testList, config=VANILLA_CONFIG, p
remote,
"\"${SlurmConfig.generateCommand(cluster, partition, nodeSecret, nodeName, Jenkins.instance.rootUrl)}\""
),
returnStdout: true
returnStdout: true,
numRetries: 3
)
def jobIDs = slurmSubmitOutput
@ -498,7 +499,7 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL
stage('Prepare Testing') {
// Create Job Workspace folder in Frontend Node
Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' ssh ${COMMON_SSH_OPTIONS} ${remote.user}@${remote.host} 'mkdir -p ${jobWorkspace}'", numRetries: 3,)
Utils.exec(pipeline, script: Utils.sshUserCmd(remote, "\"mkdir -p ${jobWorkspace}\""), numRetries: 3)
// Download and Unzip Tar File
trtllm_utils.llmExecStepWithRetry(pipeline, script: "cd ${llmPath} && wget -nv ${llmTarfile}")
@ -508,12 +509,12 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL
def scriptRunLocalPath = "${llmSrcLocal}/jenkins/scripts/slurm_run.sh"
Utils.exec(pipeline, script: "chmod +x ${scriptRunLocalPath}", returnStdout: true)
Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${scriptRunLocalPath} ${remote.user}@${remote.host}:${scriptRunNode}", numRetries: 3,)
Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${scriptRunLocalPath} ${remote.user}@${remote.host}:${scriptRunNode}", numRetries: 3)
Utils.exec(pipeline, script: "cat ${scriptRunLocalPath}")
// Upload waives.txt to Frontend node
def waivesListLocalPath = "${llmSrcLocal}/tests/integration/test_lists/waives.txt"
Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${waivesListLocalPath} ${remote.user}@${remote.host}:${waivesListPathNode}", numRetries: 3,)
Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${waivesListLocalPath} ${remote.user}@${remote.host}:${waivesListPathNode}", numRetries: 3)
// Generate Test List and Upload to Frontend Node
def makoArgs = getMakoArgsFromStageName(stageName, true)
@ -522,7 +523,7 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL
// if the line cannot be split by "=", just ignore that line.
def makoOptsJson = transformMakoArgsToJson(["Mako options:"] + makoArgs)
def testListPath = renderTestDB(testList, llmSrcLocal, stageName, makoOptsJson)
Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${testListPath} ${remote.user}@${remote.host}:${testListPathNode}", numRetries: 3,)
Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${testListPath} ${remote.user}@${remote.host}:${testListPathNode}", numRetries: 3)
// Generate Multi Node Job Launch Script
def container = LLM_DOCKER_IMAGE.replace("urm.nvidia.com/", "urm.nvidia.com#")
@ -566,7 +567,7 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL
""".stripIndent()
pipeline.writeFile(file: scriptLaunchDestPath, text: scriptContent)
Utils.exec(pipeline, script: "chmod +x ${scriptLaunchDestPath}", returnStdout: true)
Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${scriptLaunchDestPath} ${remote.user}@${remote.host}:${scriptLaunch}", numRetries: 3,)
Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${scriptLaunchDestPath} ${remote.user}@${remote.host}:${scriptLaunch}", numRetries: 3)
Utils.exec(pipeline, script: "cat ${scriptLaunchDestPath}")
}
@ -577,7 +578,8 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL
script: Utils.sshUserCmd(
remote,
"\"bash ${scriptLaunch}\""
)
),
numRetries: 3
)
}

View File

@ -10,15 +10,18 @@ and operates on a purely functional paradigm that is compatible with the torch c
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, fields
from typing import Dict, List, Literal, Optional, Protocol, Sequence, Tuple, Type, Union
from dataclasses import dataclass
from typing import Callable, Dict, List, Literal, Optional, Protocol, Sequence, Tuple, Type, Union
import torch
from torch._ops import OpOverloadPacket
from torch.export import Dim
from torch.fx import Node
from tensorrt_llm._utils import nvtx_range
from ...._utils import nvtx_range
DynamicShape = Dict[int, Dim] # indicating the dynamic shape in tensor dimension
DynamicShapeCallback = Callable[[], DynamicShape]
@dataclass
@ -28,15 +31,14 @@ class CacheConfig:
dtype: Optional[torch.dtype] = None
@dataclass
class SequenceInfo:
"""A dataclass to hold information about how the sequence is laid out and stored in cache.
"""An interface to hold information about how the sequence is laid out and stored in cache.
We assume the sequence + cache is laid out in the following way. Also note that we differentiate
between arguments that are originally part of the model/graph and arguments that are needed for
the attention operator when we switch to cached+flattened attention.
# ORIGINAL MODEL ARGUMENTS #####################################################################
### ORIGINAL MODEL ARGUMENTS ###################################################################
- input_ids: [id_0, ..., id_{s_total-1}]
flattened sequence of [b, 1] or [1, s_total]. We use [b, 1] to denote generate-only batches.
- position_ids: [pos_0, ..., pos_{s_total-1}]
@ -46,7 +48,17 @@ class SequenceInfo:
NOTE: ``input_ids`` and ``position_ids`` are initially expected to be of shape [b, seq_len]
before we switch to cached+flattened attention.
# EXTRA ARGUMENTS NEEDED FOR ATTENTION OPERATORS FOR FLATTENED SEQUENCES + CACHES ##############
### EXTRA ARGUMENTS PROVIDED TO THE INTERFACE ##################################################
These are extra arguments that can be provided to the interface; they are stored as follows:
- _extra_args: dictionary of extra arguments with currently active values.
- _extra_none_inputs: dictionary of none inputs to the extra arguments.
NOTE: we assume that extra arguments are *optional* arguments to the model. However, we
cannot represent them via `None` since fx graphs require a fixed input type. Instead,
we require a special placeholder tensor to represent the `None` input.
- _extra_dynamic_shapes_callbacks: dictionary of callbacks to initialize the dynamic shapes of
the extra arguments.
### CACHE ARGUMENTS NEEDED FOR ATTENTION OPERATORS FOR FLATTENED SEQUENCES + CACHES ############
- seq_len: [s_0, s_1, ..., s_{b-1}] such that s_total = sum(s_i)
Describes how long each sequence is. For example,
input_ids[:s_0] will correspond to sequence 0 in the batch and input_ids[s_0:s_1] will
@ -73,168 +85,238 @@ class SequenceInfo:
"""
## USE TO INITIALIZE DATA CLASS ###############################################################
# max_seq_len corresponds to the maximum number of tokens in any sequence. It includes the tokens in the
# input sequence and the tokens generated by the model.
max_seq_len: int = 1
# max_batch_size corresponds to the maximum number of sequences (or requests) that the model can process.
max_batch_size: int = 1
# page_size is the granularity with which the cache pages are allocated for a paged kv cache.
# For an unpaged cache, the page size should be set to max_seq_len.
# Also note that two sequences in a batch can not share a page.
page_size: int = 0
# max_num_tokens is the maximum number of tokens that the model can process across all sequences in the batch.
# If a batch is composed of context-only requests of input sequence length ISL,
# then the maximum number of sequences possible in the batch is min (max_batch_size, max_num_tokens // ISL).
# Similarly, if a batch is composed of generate-only requests,
# then the maximum number of sequences possible in the batch is min (max_batch_size, max_num_tokens).
max_num_tokens: Optional[int] = None
# device is the device on which the sequence info is stored.
device: str = "cuda"
def __init__(
self,
max_seq_len: int = 1,
max_batch_size: int = 1,
page_size: int = 0,
max_num_tokens: Optional[int] = None,
):
"""Initialize the SequenceInfo object.
## [UPDATE WITH CARE] TENSOR FIELDS THAT WILL BE PASSED TO PREPARE_METADATA OP #################
# input_ids MUST ALWAYS BE THE FIRST FIELD
input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int))
position_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.long))
Args:
max_seq_len: corresponds to the maximum sequence length of the input sequence. It
includes the tokens in the input sequence and the tokens generated by the model.
max_batch_size: corresponds to the maximum number of sequences (or requests) that the
model can process.
page_size: corresponds to the page size of the cache. For an unpaged cache, the page
size should be set to max_seq_len. Also note that two sequences in a batch can not
share a page.
max_num_tokens: corresponds to the maximum number of tokens that the model can process
across all sequences in the batch. If a batch is composed of context-only requests
of input sequence length ISL, then the maximum number of sequences possible in the
batch is min (max_batch_size, max_num_tokens // ISL). Similarly, if a batch is
composed of generate-only requests, then the maximum number of sequences possible in
the batch is min (max_batch_size, max_num_tokens).
seq_len: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int))
input_pos: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int))
cache_loc: torch.Tensor = field(default_factory=lambda: torch.arange(1, dtype=torch.int))
pages_per_seq: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int))
################################################################################################
## PRIVATE FIELDS ##############################################################################
_sequence_lengths: List[int] = field(default_factory=list)
_num_pages: int = 1
def __post_init__(self):
if self.page_size < 1:
self.page_size = self.max_seq_len
Returns:
None
"""
# set up basic attributes
self.max_seq_len = max_seq_len
self.max_batch_size = max_batch_size
self.page_size = page_size if page_size > 0 else max_seq_len
# NOTE (lucaslie): WAR to address issue when using flashinfer attention with
# (max_batch_size, max_seq_len) input in trtllm runtime.
# see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
self.max_seq_len_adjusted = self.max_seq_len + 1
max_seq_len_adjusted = self.max_seq_len + 1
if max_num_tokens is None or max_num_tokens < 1:
self.max_num_tokens = self.max_batch_size * max_seq_len_adjusted
else:
self.max_num_tokens = max_num_tokens
if self.max_num_tokens is None or self.max_num_tokens < 1:
self.max_num_tokens = self.max_batch_size * self.max_seq_len_adjusted
# if the provided max_num_tokens is less than the max_batch_size * max_seq_len,
# we use the provided max_num_tokens to calculate the number of pages
total_tokens = min(self.max_num_tokens, self.max_batch_size * self.max_seq_len_adjusted)
total_tokens = min(self.max_num_tokens, self.max_batch_size * max_seq_len_adjusted)
# Num pages can not be less than max_batch_size.
self._num_pages = max(
self.max_batch_size,
(total_tokens) // self.page_size + (total_tokens % self.page_size > 0),
)
# Ensure that the device is set before initializing the tensors.
self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device)
# Consumers of the sequence info args require input_ids and position_ids to be truncated.
# We maintain a full version of the input_ids and position_ids to avoid overheads of tensor
# creation in every forward pass.
self.input_ids_full = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
self.position_ids_full = torch.zeros(
self.max_num_tokens, dtype=torch.long, device=self.device
)
self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int, device=self.device)
self.input_pos = torch.empty_like(self.seq_len, device=self.device)
# Allocated host tensors for sequence lengths and input positions so that
# position_ids calculation can be done on host.
self.seq_len_host = torch.empty(self.max_batch_size, dtype=torch.int)
self.input_pos_host = torch.empty_like(self.seq_len_host)
self.cache_loc = torch.empty(self.num_pages, dtype=torch.int, device=self.device)
self.pages_per_seq = torch.empty_like(self.seq_len, device=self.device)
self.previous_batch_indices_cuda = torch.empty(
self.max_num_tokens, dtype=torch.long, device=self.device
)
assert self.num_pages >= self.max_batch_size, (
"num_pages must be greater than max_batch_size"
)
# dynamic shape descriptors for tensor args
self._dynamic_shapes: Optional[Tuple[Dict[str, Dim]]] = None
# keep a list-like object of sequence lengths for simplicity as well
self._sequence_lengths = [0] * self.max_batch_size
# sanity check
assert self.num_pages >= self.max_batch_size, "num_pages can't be less than max_batch_size"
# indicator if extra args are activated that are needed for cached attention backends
self._is_cached_attn = False
# total number of tokens in the current batch
self.num_tokens: int = 0
# container for dynamic shapes
self._dynamic_shapes: Optional[Dict[str, DynamicShape]] = None
# call reset once to initialize the tensors
# TENSOR FIELDS ############################################################################
self._args_device: Dict[str, torch.Tensor] = {
# TENSOR FIELDS FOR UNCACHED ATTENTION
"input_ids": torch.ones(self.max_num_tokens, dtype=torch.int),
"position_ids": torch.zeros(self.max_num_tokens, dtype=torch.long),
# TENSOR FIELDS FOR CACHED ATTENTION
"seq_len": torch.empty(self.max_batch_size, dtype=torch.int),
"input_pos": torch.empty(self.max_batch_size, dtype=torch.int),
"cache_loc": torch.empty(self.num_pages, dtype=torch.int),
"pages_per_seq": torch.empty(self.max_batch_size, dtype=torch.int),
# OTHER FIELDS WHERE WE NEED EFFICIENT HOST<>DEVICE TRANSFER
"_gather_idx": torch.empty(self.max_num_tokens, dtype=torch.int),
}
self._args_host: Dict[str, List[int]] = {
k: v.tolist() for k, v in self._args_device.items()
}
# NOTE: order of keys is relevant here!
self._uncached_arg_names = ("input_ids", "position_ids")
self._cached_arg_names = ("seq_len", "input_pos", "cache_loc", "pages_per_seq")
############################################################################################
# EXTRA TENSOR FIELDS ######################################################################
self._extra_args: Dict[str, torch.Tensor] = {}
self._extra_none_inputs: Dict[str, torch.Tensor] = {}
self._extra_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None
self._extra_dynamic_shapes_callbacks: Dict[str, DynamicShapeCallback] = {}
############################################################################################
# call reset once to set a consistent initial state
self.reset()
@property
def device(self) -> torch.device:
return self._args_device["input_ids"].device
def _shape_for_forward(self, tnsr: torch.Tensor) -> torch.Tensor:
"""Shape the tensor for the forward pass based on the current attention mode.
Args:
tnsr: The tensor to shape, assumed to be of shape [batch_size*seq_len, ...]
Returns:
The shaped tensor flattened or unflattened based on the current attention mode.
"""
# check if we are still running uncached attention, in which case we also still
# operate on unflattened tensors with an explicit [batch_size, seq_len, ...] shape
# generate-only batches are also formatted like this (i.e. [b, 1])
if not self._is_cached_attn or self.is_generate:
bs = len(self.seq_len)
sl = self.seq_len[0]
# use [1,total_len] shape to indicate non-generate-only batch for cached attention
else:
bs, sl = 1, self.total_num_tokens
# truncate to total tokens now, reshape, and return
return tnsr[: self.total_num_tokens].view(bs, sl, *tnsr.shape[1:])
def _named_args(
self, include_extra_args: bool = True, include_cached_args: bool = True
) -> Dict[str, torch.Tensor]:
# start with uncached args and shape them along the way
args = {k: self._shape_for_forward(self._args_device[k]) for k in self._uncached_arg_names}
# check other args to include
if include_extra_args:
args.update(self._extra_args)
if include_cached_args:
args.update({k: self._args_device[k] for k in self._cached_arg_names})
return args
@property
def named_args(self) -> Dict[str, torch.Tensor]:
"""Return a dictionary of named arguments.
These arguments contain all arguments that are managed by this interface and are required
to run a model's forward pass including all extra arguments.
Cached arguments are only included if the attention mode is cached to reflect that after
switching to cached attention, the cached arguments are required for a forward pass.
"""
return self._named_args(include_extra_args=True, include_cached_args=self._is_cached_attn)
@property
def named_standard_args(self) -> Dict[str, torch.Tensor]:
"""Return a dictionary of named standard arguments.
We define standard arguments as the arguments that are part of the model's forward function
by default (i.e., without the extra arguments).
Just like ``named_args``, this property includes cached attention arguments if the
attention mode is cached.
"""
return self._named_args(include_extra_args=False, include_cached_args=self._is_cached_attn)
@property
def args(self) -> Tuple[torch.Tensor, ...]:
args = []
for f in fields(self):
val = getattr(self, f.name)
if isinstance(val, torch.Tensor):
args.append(val)
if len(args) >= self._num_uncached_attn_args and not self._is_cached_attn:
break
return tuple(args)
"""Return a tuple of arguments."""
return tuple(self.named_args.values())
@property
def _num_uncached_attn_args(self) -> int:
"""Return the number of original graph arguments expected by the model.
This is 2 because we have input_ids and position_ids as the original graph arguments.
def const_args_for_prepare_metadata(self) -> Tuple:
"""Return a tuple of extra (const, non-tensor) arguments for the prepare_metadata op.
The ``prepare_metadata`` interface expects the following arguments:
1. ``named_standard_args`` as nodes, i.e., as input-dependent tensors.
2. ``const_args_for_prepare_metadata`` as constants that can directly be passed in as args
to the corresponding ``prepare_metadata`` node/op.
This interface handles the constant arguments part and can be used by compiler passes like
``insert_cached_attention`` to extract the constant arguments and add them to the
``prepare_metadata`` node/op.
"""
return 2
return (self.page_size,)
@property
def _cached_attn_arg_names(self) -> List[str]:
"""Return extra arg names for the prepare_metadata op beyond input_ids and position_ids.
These extra args are needed once we switch from regular attention to inserting cached
attention ops in the model.
"""
return [f.name for f in fields(self) if isinstance(getattr(self, f.name), torch.Tensor)][
self._num_uncached_attn_args :
]
@property
def dynamic_shapes(self) -> Tuple[Dict[str, Dim]]:
def named_dynamic_shapes(self) -> Dict[str, Dict[str, Dim]]:
"""Return dynamic shapes of sequence info tensors.
NOTE: will be lazily initialized since the Dim object is not picklable for multi-processing.
"""
# lazy initialization of dynamic shapes with Dim objects
if self._dynamic_shapes is None:
# set up shape for input_ids and position_ids
dynamic_shapes = ({}, {})
# set up shape for uncached args (same for all, i.e., batch_size and seq_len)
bs_seq_len_shape: DynamicShape = {}
if self.max_batch_size > 1:
dynamic_shapes[0][0] = Dim("batch_size", max=self.max_batch_size)
dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len_adjusted)
# set up shape for position_ids (same as input_ids)
dynamic_shapes[1].update(dynamic_shapes[0])
# set up shape for extra args
if self._is_cached_attn:
dynamic_shapes += ({},) * len(self._cached_attn_arg_names)
self._dynamic_shapes = dynamic_shapes
return self._dynamic_shapes
bs_seq_len_shape[0] = Dim("batch_size", max=self.max_batch_size)
bs_seq_len_shape[1] = Dim("seq_len", max=self.max_seq_len)
self._dynamic_shapes = {k: bs_seq_len_shape for k in self._uncached_arg_names}
# cached args are static
self._dynamic_shapes.update({k: {} for k in self._cached_arg_names})
for k, callback in self._extra_dynamic_shapes_callbacks.items():
if k not in self._dynamic_shapes:
self._dynamic_shapes[k] = callback()
# return dynamic shapes according to currently active named_args with consistent order
return {k: self._dynamic_shapes[k] for k in self.named_args.keys()}
@property
def dynamic_shapes(self) -> Tuple[DynamicShape, ...]:
"""Return dynamic shapes of sequence info tensors."""
return tuple(self.named_dynamic_shapes.values())
@property
def seq_len(self) -> List[int]:
return self._args_host["seq_len"].copy()
@property
def input_pos(self) -> List[int]:
return self._args_host["input_pos"].copy()
@property
def cache_loc(self) -> List[int]:
return self._args_host["cache_loc"].copy()
@property
def pages_per_seq(self) -> List[int]:
return self._args_host["pages_per_seq"].copy()
@property
def num_sequences(self) -> int:
return len(self._sequence_lengths)
return len(self.seq_len)
@property
def sequence_lengths(self) -> List[int]:
return self._sequence_lengths
@property
def input_positions(self) -> List[int]:
return self.input_pos_host[: self.num_sequences].tolist()
def total_num_tokens(self) -> int:
return sum(self.seq_len)
@property
def is_generate(self) -> bool:
return all(sl == 1 for sl in self.sequence_lengths)
return all(sl == 1 for sl in self.seq_len)
@property
def num_pages(self) -> int:
@ -244,7 +326,7 @@ class SequenceInfo:
def num_pages(self, value):
self._num_pages = value
# update the cache_loc tensor
self.cache_loc.resize_(value)
self._args_device["cache_loc"].resize_(value)
@property
def is_paged(self) -> bool:
@ -253,12 +335,52 @@ class SequenceInfo:
@property
def page_assignments(self) -> List[List[int]]:
"""Return the page assignments for each sequence."""
pages_per_seq = self.pages_per_seq[: self.num_sequences].tolist()
return self._get_page_assignments(self.cache_loc, self.pages_per_seq)
@staticmethod
def _get_page_assignments(
cache_locations: List[int], pages_per_sequence: List[int]
) -> List[List[int]]:
"""Get nested page assignments from cache locations and pages per sequence as list of lists.
Args:
cache_locations: A flat list of cache locations for each sequence ordered by sequence.
pages_per_sequence: A list of number of pages per sequence.
Returns:
A list of page assignments for each sequence ordered by sequence.
For example:
cache_locations: [0, 4, 2]
pages_per_sequence: [2, 1]
--> returns [[0, 4], [2]]
"""
return [
c_loc_one_seq.tolist()
for c_loc_one_seq in torch.split(self.cache_loc[: sum(pages_per_seq)], pages_per_seq)
for c_loc_one_seq in torch.split(torch.tensor(cache_locations), pages_per_sequence)
]
@staticmethod
def _get_cache_locations_and_pages_per_sequence(
page_assignments: List[List[int]],
) -> Tuple[List[int], List[int]]:
"""Get cache locations and pages per sequence from nested page assignments (lists of lists).
Args:
page_assignments: A list of page assignments for each sequence ordered by sequence.
Returns:
A tuple of:
cache_locations: A flat list of cache locations for each sequence ordered by sequence.
pages_per_sequence: A list of number of pages per sequence.
Example:
page_assignments: [[0, 4], [2]]
--> returns ([0, 4, 2], [2, 1])
"""
cache_loc_flat = [p_idx for pages in page_assignments for p_idx in pages]
pages_per_seq = [len(p) for p in page_assignments]
return cache_loc_flat, pages_per_seq
@classmethod
def _get_sanitized_seq_len(cls, input_ids: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
"""Sanitize sequence lengths.
@ -331,26 +453,55 @@ class SequenceInfo:
"""
assert not self._is_cached_attn, "Cached+flattened attention already activated"
self._is_cached_attn = True
return self._cached_attn_arg_names
return list(self._cached_arg_names)
def to(self, *args, **kwargs) -> None:
for f in fields(self):
val = getattr(self, f.name)
if isinstance(val, torch.Tensor):
setattr(self, f.name, val.to(*args, **kwargs))
def _move_dict(d: Dict[str, torch.Tensor]) -> None:
for k, v in d.items():
d[k] = v.to(*args, **kwargs)
def sync(self, other: "SequenceInfo") -> None:
for f in fields(self):
val = getattr(self, f.name)
val_other = getattr(other, f.name)
if f.name in ["input_ids", "position_ids"]:
setattr(self, f.name, val_other.to(self.device))
elif f.name == "_sequence_lengths":
self._sequence_lengths = val_other
elif isinstance(val, torch.Tensor):
val[: len(val_other)] = val_other.to(self.device)
else:
assert val == val_other, f"Field {f.name} mismatch: {val} != {val_other}."
_move_dict(self._args_device)
_move_dict(self._extra_args)
_move_dict(self._extra_none_inputs)
def set_example_sequence(
self,
input_ids: Sequence[Sequence[int]] = None,
position_ids: Optional[torch.Tensor] = None,
**extra_args,
) -> None:
"""Set an example sequence useful for testing and export purposes without cache history."""
# use a best guess default for input_ids if not provided
if input_ids is None:
bs, seq_len = min(2, self.max_batch_size), min(4, self.max_seq_len)
input_ids = torch.ones(bs, seq_len, dtype=torch.int).tolist()
# figure out page assignments
pages_per_seq = [
len(ids_one_seq) // self.page_size + (len(ids_one_seq) % self.page_size > 0)
for ids_one_seq in input_ids
]
cache_loc = list(range(sum(pages_per_seq)))
page_assignments = self._get_page_assignments(cache_loc, pages_per_seq)
self.nest_sequences(
input_ids,
position_ids, # will be auto-inferred if None
input_pos=0, # no cache history
page_assignments=page_assignments, # vanilla page assignments
**extra_args,
)
def set_max_num_tokens_sample(self) -> None:
"""Set an example sequence with max_num_tokens."""
# TODO (lucaslie): understand what this implies for extra arguments
seq_len = self.max_num_tokens // self.max_batch_size
input_ids = torch.ones(self.max_batch_size, seq_len, dtype=torch.int).tolist()
self.set_example_sequence(input_ids)
def set_generate_only_batch(self) -> None:
"""Set an example sequence for generate-only batch."""
self.set_example_sequence([[1]] * self.max_batch_size)
def reset(self) -> None:
"""Reset the sequence information.
@ -358,205 +509,173 @@ class SequenceInfo:
After reset the sequence information should correspond to a "generate-only" batch of
sequences (b, s==1) without cache history.
"""
# reset input_pos
self.input_pos.zero_()
self.input_pos_host.zero_()
self.set_generate_only_batch()
# set a dummy sequence corresponding to a generate-only batch (will also reset position_ids)
self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True)
@staticmethod
def _flatten(nested_seqs: Sequence[Sequence[int]]) -> List[int]:
return [
val
for lst in nested_seqs
for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst)
]
# reset cache information
self.cache_loc[:] = torch.arange(self.num_pages, dtype=torch.int, device=self.device)
self.pages_per_seq.fill_(1)
# let's also reset the input_ids and position_ids tensors to their max shapes (max_num_tokens)
self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device)
def set_example_sequence(self) -> None:
"""Set an example sequence useful for testing and export purposes."""
self.reset()
bs, seq_len = min(2, self.max_batch_size), min(4, self.max_seq_len)
input_ids = torch.ones(
bs,
seq_len,
dtype=torch.int,
device=self.device,
)
self.nest_sequences(input_ids, allow_realloc=True)
# unflatten if we are not yet using cached+flattened attention
if not self._is_cached_attn:
self.input_ids = self.input_ids.view(bs, seq_len)
self.position_ids = self.position_ids.view(bs, seq_len)
def _set_max_num_tokens_sample(self) -> None:
"""Set an example sequence with max_num_tokens."""
self.reset()
seq_len = self.max_num_tokens // self.max_batch_size
input_ids = torch.ones(
self.max_batch_size,
seq_len,
dtype=torch.int,
device=self.device,
)
self.pages_per_seq.fill_(seq_len // self.page_size)
self.nest_sequences(input_ids, allow_realloc=True)
def set_generate_only_batch(self) -> None:
"""Set an example sequence for generate-only batch.
NOTE: this batch is already formatted as [b, 1] in both original and in the cached attention
mode. So we don't need to do anything mode-specific here.
"""
self.reset()
self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True)
def maybe_reshape_for_generate(self, tensor: torch.Tensor) -> torch.Tensor:
# use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len]
if self.is_generate:
return tensor.view(-1, 1, *tensor.shape[1:])
else:
return tensor.view(1, -1, *tensor.shape[1:])
@nvtx_range("ad_update_position_ids")
def _update_position_ids(self, allow_realloc: bool = False) -> None:
# set new position_ids from input_pos and seq_len
# Make sure this is done on host to avoid host-device copies.
with nvtx_range("prepare_list"):
# Optimize for the common case where all seq_len values are 1 (generation mode)
if torch.all(self.seq_len_host == 1):
# Fast path: when all seq_len are 1, position_ids is just input_pos_host
position_ids_host = (
self.input_pos_host[: self.num_tokens].to(dtype=torch.long).pin_memory()
)
else:
# General case - can probably be optimized too, but overall impact will be minor.
position_ids_list = []
for in_pos, seq_len in zip(self.input_pos_host, self.seq_len_host):
position_ids_list.extend(range(in_pos, in_pos + seq_len))
position_ids_host = torch.tensor(
position_ids_list, dtype=torch.long, pin_memory=True
)
with nvtx_range("copy_to_device"):
if allow_realloc:
# Create a new position_ids tensor on the device
self.position_ids = position_ids_host.to(self.device).clone()
else:
self.position_ids_full = self.position_ids_full.flatten()
self.position_ids_full[: self.num_tokens].copy_(
position_ids_host, non_blocking=True
)
with nvtx_range("maybe_reshape"):
self.position_ids = self.maybe_reshape_for_generate(
self.position_ids if allow_realloc else self.position_ids_full[: self.num_tokens]
)
@nvtx_range("ad_update_sequence_lengths")
def _update_sequence_lengths(self, sequence_lengths: List[int]) -> None:
self._sequence_lengths = sequence_lengths
self.num_tokens = sum(self._sequence_lengths)
self.seq_len.zero_()
self.seq_len_host = torch.tensor(self._sequence_lengths, pin_memory=True)
self.seq_len[: len(self._sequence_lengths)].copy_(self.seq_len_host, non_blocking=True)
def update_input_ids_with_new_tokens(
self, new_tokens: torch.Tensor, previous_batch_indices: List[int]
def _store_arg(
self,
name: str,
tnsr_like: List[int | float],
reset: bool = False,
) -> None:
"""Update the input_ids with new tokens.
"""Store the argument on the host and copy to the device in a non-blocking fashion.
This function will update the input_ids with new tokens and previous batch indices.
Args:
name: Name of the argument to store.
tnsr_like: List of values to store.
reset: Whether to reset the full tensor on the device to 0 before writing to it.
"""
# 1) flatten once
original_shape = self.input_ids.shape
flat = self.input_ids.flatten()
with nvtx_range(f"ad_store_seq_info_arg_{name}"):
tnsr_device = self._args_device[name]
# copy indices to the GPU
host_idx = torch.tensor(previous_batch_indices, dtype=torch.int, pin_memory=True)
idx = self.previous_batch_indices_cuda[: len(previous_batch_indices)]
idx.copy_(host_idx, non_blocking=True)
# store list object on the host
self._args_host[name] = tnsr_like.copy()
# gather the exact values you want to write
src = new_tokens[0, idx, 0]
# pin the memory on the host
tnsr_host = torch.tensor(tnsr_like, dtype=tnsr_device.dtype, pin_memory=True)
# inplace fill every slot where flat == -1 with src, in order
flat.masked_scatter_(flat == -1, src)
# reset/copy to the device in a non-blocking fashion
if reset:
tnsr_device.zero_()
tnsr_device[: len(tnsr_like)].copy_(tnsr_host, non_blocking=True)
# 4) reshape back
self.input_ids = flat.view(original_shape)
def _store_extra_arg(
self, name: str, tnsr_like: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]]
) -> None:
with nvtx_range(f"ad_store_extra_arg_{name}"):
if tnsr_like is not None:
if not isinstance(tnsr_like, torch.Tensor):
if len(tnsr_like) > 1:
tnsr_like = torch.cat(tnsr_like)
else:
tnsr_like = tnsr_like[0]
self._extra_args[name] = tnsr_like.to(self.device, non_blocking=True)
else:
self._extra_args[name] = self._extra_none_inputs[name]
@nvtx_range("ad_nest_sequences")
def nest_sequences(
self, input_ids: Sequence[Sequence[int]], allow_realloc: bool = False
self,
input_ids: Sequence[Sequence[int]],
position_ids: Optional[Sequence[Sequence[int]]] = None,
input_pos: Optional[Union[Sequence[int], int]] = None,
page_assignments: Optional[Sequence[Sequence[int]]] = None,
**extra_args: Dict[str, Union[torch.Tensor, Sequence[torch.Tensor]]],
) -> None:
"""Create and store a flattened list of input_ids from the provided list of sequences.
"""Create and store sequence information for the next forward pass.
When allow_realloc is True, the input_ids will be reallocated on the device.
This i/f will also update any relevant sequence information.
Args:
input_ids: List of sequences of input_ids.
position_ids: List of sequences of position_ids for each token.
input_pos: Absolute starting position in the cache for each sequence.
page_assignments: List of sequences of page assignments for each sequence.
extra_args: Extra arguments to be stored in the interface.
This i/f will ensure that all sequence info args are updated accordingly.
"""
# set new sequence lengths
self._update_sequence_lengths([len(ids) for ids in input_ids])
### UPDATE METADATA ########################################################################
# update metadata first since it's useful for other updates to have up-to-date information
# We'll preserve the dtype of the input_ids tensor if it is a tensor, otherwise we'll use int
dtype = input_ids.dtype if isinstance(input_ids, torch.Tensor) else torch.int
# set new input_ids as new tensor from flattened input_ids
ids_list = [
val
for lst in input_ids
for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst)
]
input_ids_host = torch.tensor(ids_list, dtype=dtype, pin_memory=True)
# set new sequence lengths --> resetting the remaining entries to zero is important to help
# us discern the actual number of sequences in the batch.
self._store_arg("seq_len", [len(ids) for ids in input_ids], reset=True)
if allow_realloc:
self.input_ids = input_ids_host.to(self.device).clone()
else:
self.input_ids_full = self.input_ids_full.flatten()
self.input_ids_full[: self.num_tokens].copy_(input_ids_host, non_blocking=True)
# check for updated input_pos (i.e. cache start position)
if input_pos is not None:
self._store_arg(
"input_pos",
[input_pos] * self.num_sequences if isinstance(input_pos, int) else input_pos,
)
self.input_ids = self.maybe_reshape_for_generate(
self.input_ids if allow_realloc else self.input_ids_full[: self.num_tokens]
)
# update position_ids
self._update_position_ids(allow_realloc=allow_realloc)
# check for updated page_assignments
if page_assignments is not None:
cache_loc, pages_per_seq = self._get_cache_locations_and_pages_per_sequence(
page_assignments
)
self._store_arg("cache_loc", cache_loc)
self._store_arg("pages_per_seq", pages_per_seq)
### UPDATE MAIN INPUTS #####################################################################
# set new input_ids and make sure to flatten it
self._store_arg("input_ids", self._flatten(input_ids))
# set new position_ids and make sure to flatten it
if position_ids is None:
position_ids = [
[num for num in range(in_pos, in_pos + seq_len)]
for in_pos, seq_len in zip(self.input_pos, self.seq_len)
]
self._store_arg("position_ids", self._flatten(position_ids))
### UPDATE EXTRA INPUTS ####################################################################
# go through all extra tensor arguments and update them
for name in self._extra_none_inputs.keys():
self._store_extra_arg(name, extra_args.pop(name, None))
assert not extra_args, f"Extra arguments {extra_args.keys()} not found"
@nvtx_range("ad_rescatter_input_ids")
def rescatter_input_ids(
self, ungathered_input_ids: torch.Tensor, gather_idx: List[int], scatter_ref: int
):
"""Re-scatter the provided ungathered input ids into the input_ids tensor.
Args:
ungathered_input_ids: The input ids on the device from which to gather.
gather_idx: The list of indices to gather from the ungathered_input_ids.
scatter_ref: The reference index to scatter to in input_ids via masked scatter.
Returns:
None
This function will assume that we are in a generate-only batch.
"""
# store the new gather indices
self._store_arg("_gather_idx", gather_idx)
# gather the provided input ids in a streaming fashion
gather_ids_device = self._args_device["_gather_idx"][: len(gather_idx)]
packed_input_ids = ungathered_input_ids[gather_ids_device]
# re-scatter the provided input ids into the input_ids tensor
input_ids_device = self._args_device["input_ids"]
input_ids_device.masked_scatter_(input_ids_device == scatter_ref, packed_input_ids)
@nvtx_range("ad_unnest_sequences")
def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0)
return list(torch.split(t_squeezed, self.sequence_lengths))
return list(torch.split(t_squeezed, self.seq_len))
@nvtx_range("ad_update_pos")
def update_pos(self, seq_len: Union[torch.Tensor, List[int], int], reset: bool = False) -> None:
"""Update the starting position for each sequence in the cache.
def add_extra_arg(
self,
name: str,
none_input: torch.Tensor,
dynamic_shape_callback: Optional[DynamicShapeCallback] = None,
) -> None:
"""Add an extra argument to the sequence info object.
If ``reset=True``, ``input_pos`` will be reset to zero before updating.
Args:
name: The name of the extra argument.
none_input: None input value of the extra argument.
dynamic_shape_callback: The callback to get the dynamic shape of the extra argument.
Note that the extra argument is expected to be a tensor.
"""
if not isinstance(seq_len, torch.Tensor):
seq_len = torch.tensor(seq_len, dtype=torch.int, pin_memory=True)
bs = len(seq_len) if seq_len.dim() > 0 else self.max_batch_size
assert name not in self._named_args().keys(), f"Extra argument {name} already exists"
if reset:
self.input_pos_host[:bs].copy_(seq_len, non_blocking=True)
self._extra_args[name] = none_input.to(self.device)
self._extra_none_inputs[name] = none_input.to(self.device)
if dynamic_shape_callback is None:
self._extra_dynamic_shapes_callbacks[name] = lambda: {}
else:
self.input_pos_host[:bs] += seq_len
# update position_ids
self._update_position_ids()
self.input_pos[:bs].copy_(self.input_pos_host[:bs], non_blocking=True)
@nvtx_range("ad_assign_cache_loc")
def assign_cache_loc(self, page_assignments: Sequence[Sequence[int]]) -> None:
"""Set the cache location and pages_per_seq tensors from page assignments."""
cache_loc_flat = torch.tensor(
[p_idx for pages in page_assignments for p_idx in pages],
dtype=torch.int,
pin_memory=True,
)
self.cache_loc[: len(cache_loc_flat)].copy_(cache_loc_flat, non_blocking=True)
pages_per_seq = torch.tensor(
[len(p) for p in page_assignments], dtype=torch.int, pin_memory=True
)
self.pages_per_seq[: len(pages_per_seq)].copy_(pages_per_seq, non_blocking=True)
self._extra_dynamic_shapes_callbacks[name] = dynamic_shape_callback
Constant = Union[int, float, str, None]
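
As a quick orientation to the reworked interface above: a caller now prepares a batch exclusively through nest_sequences and then reads the device tensors back via named_args / named_dynamic_shapes. A hedged sketch of one context (prefill) step, assuming si is an already constructed SequenceInfo with enough cache pages available:

def run_prefill_step(si):
    """Feed a two-sequence context batch into an existing SequenceInfo instance (sketch)."""
    input_ids = [[1, 2, 3, 4, 5], [6, 7, 8]]
    # assign consecutive, non-shared cache pages to each sequence
    page_assignments, next_page = [], 0
    for ids in input_ids:
        n_pages = (len(ids) + si.page_size - 1) // si.page_size
        page_assignments.append(list(range(next_page, next_page + n_pages)))
        next_page += n_pages
    si.nest_sequences(
        input_ids,
        input_pos=0,                        # no cache history yet
        page_assignments=page_assignments,  # position_ids are auto-inferred when omitted
    )
    return si.named_args, si.named_dynamic_shapes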

View File

@ -63,6 +63,7 @@ def scaled_dot_product_attention(
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
) -> torch.Tensor:
"""A carbon copy of torch.nn.functional.scaled_dot_product_attention as custom op.
@ -78,12 +79,13 @@ def scaled_dot_product_attention(
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
@scaled_dot_product_attention.register_fake
def scaled_dot_product_attention_fake(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False
):
"""Fake implementation of scaled_dot_product_attention."""
return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
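
The added enable_gqa flag simply forwards to the same-named argument of torch.nn.functional.scaled_dot_product_attention (available in recent PyTorch releases), which lets the key/value tensors carry fewer heads than the query. A small shape sketch against the stock functional API rather than the registered custom op:

import torch
import torch.nn.functional as F

B, S, D = 2, 16, 64
n_q_heads, n_kv_heads = 8, 2  # grouped-query attention: 8 query heads share 2 KV heads

q = torch.randn(B, n_q_heads, S, D)
k = torch.randn(B, n_kv_heads, S, D)
v = torch.randn(B, n_kv_heads, S, D)

# with enable_gqa=True the KV heads are repeated internally to match the query heads
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
assert out.shape == (B, n_q_heads, S, D)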

View File

@ -18,7 +18,7 @@ from ..transformations._graph import (
)
from ..utils.logger import ad_logger
from ..utils.node_utils import is_op
from .interface import ExportPatchRegistry, apply_export_patches
from .interface import apply_export_patches
try:
from modelopt.torch.quantization.utils import export_torch_mode as torch_export_context
@ -229,20 +229,9 @@ def torch_export_to_gm(
patch_list: Optional list of patch names to apply with default settings.
Cannot be used together with patch_configs.
"""
# Validate that both patch_configs and patch_list are not provided simultaneously
if patch_configs is not None and patch_list is not None:
raise ValueError("Cannot specify both patch_configs and patch_list. Use only one.")
# Handle patch configuration
if patch_list is not None:
# Convert patch_list to patch_configs format
patch_configs = {patch_name: {} for patch_name in patch_list}
elif patch_configs is None:
# Default patch configurations - apply all registered patches with default settings
patch_configs = {patch_name: {} for patch_name in ExportPatchRegistry.list_patches()}
# run export with patches and lifted to meta
with apply_export_patches(patch_configs), lift_to_meta(model) as state_dict:
with apply_export_patches(patch_configs, patch_list), lift_to_meta(model) as state_dict:
# clean up args, kwargs and move to correct device
args, kwargs = tree_to((args, kwargs or {}), device="meta")

View File

@ -5,7 +5,7 @@ This module defines the base classes and interfaces for all export patches.
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Type, Union, final
from typing import Any, Callable, Dict, List, Optional, Type, Union, final
from pydantic import BaseModel, Field
@ -183,6 +183,8 @@ class ExportPatchRegistry:
@classmethod
def get(cls, name: str) -> Type[BaseExportPatch]:
"""Get a patch class by name."""
if not cls.has(name):
raise ValueError(f"Unknown patch: {name}")
return cls._registry[name]
@classmethod
@ -212,20 +214,29 @@ class ExportPatchRegistry:
@contextmanager
def apply_export_patches(patch_configs: Dict[str, Union[ExportPatchConfig, Dict[str, Any]]]):
def apply_export_patches(
patch_configs: Optional[Dict[str, Union[ExportPatchConfig, Dict[str, Any]]]] = None,
patch_list: Optional[List[str]] = None,
):
"""Context manager to apply multiple patches.
Args:
patch_configs: Dict mapping patch names to their configurations.
patch_list: Optional list of patch names to apply with default settings. Cannot be
used together with patch_configs.
"""
patches = []
# Validate that both patch_configs and patch_list are not provided simultaneously
if patch_configs is not None and patch_list is not None:
raise ValueError("Cannot specify both patch_configs and patch_list. Use only one.")
# Handle patch configuration
if patch_list is not None:
# Convert patch_list to patch_configs format
patch_configs = {patch_name: {} for patch_name in patch_list}
elif patch_configs is None:
# Default patch configurations - apply all registered patches with default settings
patch_configs = {patch_name: {} for patch_name in ExportPatchRegistry.list_patches()}
# Create patch instances
for name, config in patch_configs.items():
if not ExportPatchRegistry.has(name):
raise ValueError(f"Unknown patch: {name}")
patch = ExportPatchRegistry.create_patch(name, config)
patches.append(patch)
patches = [ExportPatchRegistry.create_patch(k, conf) for k, conf in patch_configs.items()]
# Apply patches using nested context managers
if not patches:
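
Since the validation and defaulting now live in apply_export_patches itself, callers can pass either patch_configs or patch_list, or nothing at all. A short usage sketch; the patch names are made up for illustration:

# apply only the named patches with default settings (names are hypothetical)
with apply_export_patches(patch_list=["patch_a", "patch_b"]):
    ...

# equivalent explicit form via patch_configs (an empty dict means default settings per patch)
with apply_export_patches(patch_configs={"patch_a": {}, "patch_b": {}}):
    ...

# no arguments: every registered patch is applied with default settings
with apply_export_patches():
    ...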

View File

@ -1,19 +1,92 @@
import types
from typing import List, Optional
from typing import Any, Dict, List, Optional, Tuple
from ...executor.result import CompletionOutput
from ...inputs.registry import create_input_processor
from ...inputs.registry import DefaultInputProcessor, ExtraProcessedInputs
from ...llmapi.llm import RequestOutput, _TorchLLM
from ...llmapi.tokenizer import TokenizerBase, tokenizer_factory
from ...llmapi.tokenizer import TokenizerBase, TransformersTokenizer, tokenizer_factory
from ...sampling_params import SamplingParams
from .distributed import common as dist_ad
from .llm_args import LlmArgs
from .models.factory import ModelFactory
from .shim.demollm import DemoGenerationExecutor
class ADInputProcessor(DefaultInputProcessor):
"""Input processor for AutoDeploy backend.
This is a wrapper that either supports standard TRT-LLM text-only input processing or uses HF's
chat template system to process multimodal message inputs.
"""
def __init__(self, tokenizer: Optional[TokenizerBase], processor: Optional[Any] = None):
super().__init__(model_path=None, model_config=None, tokenizer=tokenizer)
# NOTE: HF's tokenizer/processor that has the apply_chat_template method
self.processor = processor or getattr(tokenizer, "tokenizer", None)
def __call__(
self, inputs: Dict[str, Any], sampling_params: SamplingParams
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
if self.processor is None:
raise ValueError("processor is required to tokenize inputs")
# construct kwargs to reflect DefaultInputProcessor
kwargs = {
"add_special_tokens": sampling_params.add_special_tokens,
}
if sampling_params.truncate_prompt_tokens is not None:
kwargs.update({
"truncation": True,
"max_length": sampling_params.truncate_prompt_tokens,
})
# check for messages field and if yes, use the apply_chat_template method
if "messages" in inputs:
# TODO: we don't really need this but it makes for a good sanity check. Consider
# removing this in the future if we need to speed things up.
prompt = self.processor.apply_chat_template(
inputs["messages"],
add_generation_prompt=True,
tokenize=False,
)
inputs["prompt"] = prompt
all_args = self.processor.apply_chat_template(
inputs["messages"],
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=False, # there shouldn't be a need for padding ever...
return_attention_mask=False,
**kwargs,
)
# TODO: is there a more reliable way to avoid the attention_mask here?
all_args.pop("attention_mask", None)
# TODO: can we avoid the extra tolist() here eventually?
token_ids = all_args.pop("input_ids")
assert token_ids.shape[0] == 1, "messages should be unbatched at this point."
if all_args:
extra_processed_inputs = {"multimodal_data": all_args}
else:
extra_processed_inputs = None
return token_ids[0].tolist(), extra_processed_inputs
else:
token_ids = self.tokenizer.encode(inputs["prompt"], **kwargs)
return token_ids, None
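
End to end, this processor is what lets an AutoDeploy LLM accept the "messages" field produced by the PromptConfig normalization earlier in this diff. A hedged sketch of that flow; the import path and model name below are assumptions, not taken from the diff:

from tensorrt_llm._torch.auto_deploy.llm import LLM  # assumed module path
from tensorrt_llm.sampling_params import SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")  # hypothetical checkpoint
query = {
    "prompt": "Fake prompt. Check out messages field for the HF chat template.",
    "messages": [{"role": "user", "content": "How to fix slicing in golf?"}],
}
# ADInputProcessor routes the "messages" entry through apply_chat_template before tokenization
out = llm.generate(query, SamplingParams(max_tokens=64))
print(out.outputs[0].text)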
class LLM(_TorchLLM):
"""LLM class is the main class for running an LLM model using AutoDeploy backend."""
args: LlmArgs
_factory: ModelFactory
@property
def factory(self) -> ModelFactory:
if not getattr(self, "_factory", None):
self._factory = self.args.create_factory()
return self._factory
def __init__(self, *args, **kwargs):
kwargs["backend"] = "_autodeploy"
@ -23,16 +96,18 @@ class LLM(_TorchLLM):
if self.args.skip_tokenizer_init:
return None
factory = self.args.create_factory()
return tokenizer_factory(factory.init_tokenizer())
return tokenizer_factory(self.factory.init_tokenizer())
def _validate_args_for_torch_backend(self, kwargs: dict) -> None:
"""We don't need to validate args for AutoDeploy backend for now."""
pass
def _create_input_processor(self) -> ADInputProcessor:
return ADInputProcessor(self.tokenizer, self.factory.init_processor())
def _prefetch_model(self):
"""Prefetch the model for the LLM."""
self.args.create_factory().prefetch_checkpoint()
self.factory.prefetch_checkpoint()
def _build_model(self):
"""Build the model for the LLM.
@ -47,6 +122,11 @@ class LLM(_TorchLLM):
# _autodeploy backend.
super()._build_model()
# now correct input processor
assert isinstance(self.input_processor, DefaultInputProcessor)
assert self.tokenizer is None or isinstance(self.tokenizer, TransformersTokenizer)
self.input_processor = self._create_input_processor()
class DemoLLM(LLM):
"""A simple LLM class to demo the LLM interface while debugging the e2e workflow.
@ -63,7 +143,7 @@ class DemoLLM(LLM):
# prefetch model and load tokenizer
self._prefetch_model()
self._tokenizer = self._try_load_tokenizer()
self.input_processor = create_input_processor(None, self.tokenizer)
self.input_processor = self._create_input_processor()
# construct demo executor + engine
self._executor = DemoGenerationExecutor(

View File

@ -57,7 +57,7 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
description="The path to the model checkpoint or the model name from the Hugging Face Hub."
)
model_factory: Literal["AutoModelForCausalLM", "AutoModelForImageTextToText"] = Field(
model_factory: str = Field(
default="AutoModelForCausalLM",
description="The model factory to use for loading the model.",
)

View File

@ -3,13 +3,13 @@
import copy
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable, Dict, Optional, Type
from typing import Any, Callable, Dict, Optional, Tuple, Type
import torch
import torch.nn as nn
from torch._prims_common import DeviceLikeType
from ..custom_ops.attention_interface import CacheConfig
from ..custom_ops.attention_interface import CacheConfig, DynamicShapeCallback
from ..utils.logger import ad_logger
@ -25,7 +25,8 @@ class ModelFactory(ABC):
NOTE: we make the assumption that the model can be prompted with a set of input_ids and
position_ids of shape (batch_size, seq_len) and will return a tensor of shape
(batch_size, seq_len, embedding_size).
(batch_size, seq_len, embedding_size) by default. Individual factories have the ability to
define additional optional inputs and their (dynamic) shapes.
"""
def __init__(
@ -74,7 +75,7 @@ class ModelFactory(ABC):
.. code-block:: python
def forward(
self, input_ids: torch.Tensor, position_ids: torch.Tensor
self, input_ids: torch.Tensor, position_ids: torch.Tensor, *extra_args: torch.Tensor
) -> Sequence[torch.Tensor]: ...
``logits`` are assumed to be the first output of the model, i.e.,
@ -87,6 +88,9 @@ class ModelFactory(ABC):
input_ids.shape == (batch_size, seq_len)
position_ids.shape == (batch_size, seq_len)
logits.shape == (batch_size, seq_len, vocab_size)
We allow for additional arguments to be passed to the model's forward function as defined by
the factory.
"""
# make sure model architecture is pre-fetched (no weights needed at this point)
skip_loading_weights = self.skip_loading_weights
@ -127,6 +131,15 @@ class ModelFactory(ABC):
"""
return None
def init_processor(self) -> Optional[Any]:
"""Initialize the (multi-modal) processor for the model.
Returns:
The initialized processor for the model. If the processor is not available, then this
method should return None.
"""
return None
def prefetch_checkpoint(self, force: bool = False):
"""Try or skip prefetching the checkpoint for the model and tokenizer.
@ -220,6 +233,35 @@ class ModelFactory(ABC):
device: The device to load the model on.
"""
def get_example_inputs(self) -> Dict[str, torch.Tensor]:
"""Return a dictionary of example inputs for the model.
This function can be overridden by a factory when it requires a specific example input
in order to run through export.
Returns:
A dictionary of example inputs for the model where the key corresponds to the argument
name and the value corresponds to the example input.
"""
return {}
def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, Optional[DynamicShapeCallback]]]:
"""Return a dictionary of extra model inputs that behave like optional forward arguments.
Returns:
A dictionary of extra inputs for the model where the key corresponds to the argument
name and the value corresponds to a tuple of (none_input, dynamic_shape_callback):
- `none_input`: A placeholder tensor that represents the equivalent of a `None` input.
A literal `None` is not supported since the input must be a tensor, so this
placeholder stands in for it. We assume that the "optional" behavior of these
arguments can be expressed via the placeholder tensor and an appropriate
check within the forward function using ``torch.cond``.
- `dynamic_shape_callback`: A function that returns the dynamic shape of the extra
input. Simply set to `None` if the extra input is not dynamic.
"""
return {}
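# Illustrative sketch of how a concrete factory might wire the two hooks above together.
# `_SketchEncoderFactory` and the `audio_features` argument are hypothetical, the tensor
# sizes are arbitrary, and the remaining factory methods are omitted for brevity. Example
# inputs only need to be shape/dtype-representative for export, and each extra input
# ships a zero-sized "none" placeholder plus an optional dynamic-shape callback (here
# `None`, i.e. the extra input is treated as static).
class _SketchEncoderFactory(ModelFactory):
    def get_example_inputs(self) -> Dict[str, torch.Tensor]:
        # values do not matter for export, only shapes and dtypes do
        return {"input_ids": torch.ones(2, 8, dtype=torch.long)}

    def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, Optional[DynamicShapeCallback]]]:
        # zero-sized placeholder standing in for "no audio features provided"
        none_audio_features = torch.zeros(0, 80)
        return {"audio_features": (none_audio_features, None)}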
class ModelFactoryRegistry:
_registry: Dict[str, Type[ModelFactory]] = {}

View File

@ -3,7 +3,7 @@
import os
import types
from contextlib import contextmanager, nullcontext
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -11,11 +11,13 @@ from accelerate import init_empty_weights, load_checkpoint_in_model
from accelerate.utils import modeling
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.utils import HFValidationError, filter_repo_objects, validate_repo_id
from PIL import Image
from torch._prims_common import DeviceLikeType
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoProcessor,
AutoTokenizer,
PretrainedConfig,
)
@ -26,7 +28,7 @@ from transformers.utils import (
WEIGHTS_NAME,
)
from ..custom_ops.attention_interface import CacheConfig
from ..custom_ops.attention_interface import CacheConfig, Dim, DynamicShapeCallback
from ..utils._config import deep_merge_dicts
from ..utils.logger import ad_logger
from .factory import ModelFactory, ModelFactoryRegistry, ShardingConfigSource
@ -101,10 +103,6 @@ class AutoModelForCausalLMFactory(ModelFactory):
def autoconfig_from_pretrained(self):
return AutoConfig.from_pretrained
@property
def autotokenizer_from_pretrained(self):
return AutoTokenizer.from_pretrained
# TODO (@lucaslie): Do we ever want to switch to from_pretrained?
@property
def automodel_from_config(self):
@ -114,7 +112,9 @@ class AutoModelForCausalLMFactory(ModelFactory):
def _simple_forward(model: nn.Module, input_ids: torch.Tensor, position_ids: torch.Tensor):
"""A simple forward pass for the model to functionalize the args.
This follows the standard function signature as expected by factory.py.
This follows the standard function signature as expected by factory.py. We do _not_ use the
bound model.forward method directly to create the patch. Instead, we look up forward via the
type of the model so that the patch stays composable with other forward patches.
"""
return type(model).forward(model, input_ids=input_ids, position_ids=position_ids)
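# Illustrative toy (hypothetical classes, unrelated to HF models) of why the dispatch goes
# through type(model): a captured bound method keeps pointing at the original
# implementation, while a class-level lookup picks up any forward patch applied later.
class _Toy:
    def forward(self, x):
        return x + 1

_toy = _Toy()
_bound = _toy.forward                      # freezes the original implementation
_Toy.forward = lambda self, x: x + 100     # e.g. another export patch applied afterwards

assert _bound(1) == 2                      # stale: bypasses the later patch
assert type(_toy).forward(_toy, 1) == 101  # composable: sees the later patch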
@ -158,11 +158,13 @@ class AutoModelForCausalLMFactory(ModelFactory):
with (init_empty_weights if device == "meta" else nullcontext)():
model = self.automodel_from_config(model_config, trust_remote_code=True)
# post-init --> this must be called explicitly for HF models the way we initialize them since
# this "gets lost" with the init_empty_weights context manager.
if hasattr(model, "post_init"):
model.post_init()
if device == "meta":
# post-init --> this must be called explicitly for HF models the way we initialize them
# since this "gets lost" with the init_empty_weights context manager.
if hasattr(model, "post_init"):
model.post_init()
else:
model.to(device)
# if present, initialize sharding config. We need head_dim for colwise sharding.
self._set_sharding_config(model.config)
@ -211,7 +213,7 @@ class AutoModelForCausalLMFactory(ModelFactory):
"""Initialize the tokenizer—either a custom name or the model's default."""
if self.tokenizer is None:
return None
return self.autotokenizer_from_pretrained(self.tokenizer, **self.tokenizer_kwargs)
return AutoTokenizer.from_pretrained(self.tokenizer, **self.tokenizer_kwargs)
@staticmethod
def _get_ignore_patterns(repo_id: str, skip_prefetch_weights: bool) -> List[str]:
@ -375,3 +377,102 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
@property
def automodel_from_config(self):
return AutoModelForImageTextToText.from_config
def init_tokenizer(self) -> Optional[Any]:
"""Initialize the tokenizer—either a custom name or the model's default."""
processor = self.init_processor()
if processor is None:
return None
return processor.tokenizer
def init_processor(self) -> Optional[Any]:
"""Initialize the processor for the model."""
if self.tokenizer is None:
return None
return AutoProcessor.from_pretrained(self.tokenizer, **self.tokenizer_kwargs)
@staticmethod
def _simple_forward(
model: nn.Module,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
pixel_values: torch.Tensor,
):
"""A simple forward pass for the model to functionalize the args.
This follows the standard function signature as expected by factory.py. We do _not_ use the
bound model.forward method directly to create the patch. Instead, we look up forward via the
type of the model so that the patch stays composable with other forward patches.
"""
return type(model).forward(
model,
input_ids=input_ids,
position_ids=position_ids,
pixel_values=pixel_values,
)
def get_example_inputs(self) -> Dict[str, torch.Tensor]:
"""Return a dictionary of example inputs for the model."""
def _prep_seq(text, img1, img2):
return [
{
"role": "user",
"content": [
{"type": "image", "image": img1},
{"type": "image", "image": img2},
{"type": "text", "text": text},
],
}
]
# Create a batch of conversations (batch_size = 2)
batch_messages = [
_prep_seq(
"Describe what you see in the two images and their differences.",
Image.new("RGB", (16, 16), color=(128, 128, 128)),
Image.new("RGB", (16, 16), color=(64, 64, 64)),
),
_prep_seq(
"What are the main differences between these two images?",
Image.new("RGB", (16, 16), color=(255, 0, 0)),
Image.new("RGB", (16, 16), color=(0, 255, 0)),
),
]
processor = AutoProcessor.from_pretrained(self.tokenizer, **self.tokenizer_kwargs)
inputs = processor.apply_chat_template(
batch_messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
return_attention_mask=False,
)
return {
"input_ids": inputs["input_ids"],
"pixel_values": inputs["pixel_values"],
}
def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, Optional[DynamicShapeCallback]]]:
"""Return a dictionary of extra inputs for the model.
Returns:
A dictionary of extra inputs for the model where the key corresponds to the argument
name and the value corresponds to a tuple of (none_input, dynamic_shape_callback).
The dynamic shape callback is a function that returns the dynamic shape of the extra
input. Simply set to `None` if the extra input is not dynamic.
"""
def _get_dynamic_shape():
return {
# TODO (lucaslie): how to set default values for dynamic shapes?
0: Dim("img_batch_size", max=10),
2: Dim("img_height", min=32, max=2048),
3: Dim("img_width", min=32, max=2048),
}
none_pixel_values = torch.zeros(0, 3, 336, 336)
return {"pixel_values": (none_pixel_values, _get_dynamic_shape)}

View File

@ -1,15 +1,13 @@
"""A patch to handle vision branch in Llama4ForConditionalGeneration."""
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from _model_test_utils import _hf_model_dir_or_hub_id
from PIL import Image
from transformers import AutoConfig, AutoProcessor, Llama4ForConditionalGeneration
from transformers.models.llama4.modeling_llama4 import Llama4CausalLMOutputWithPast
from utils.llm_data import llm_models_root
from transformers import Llama4ForConditionalGeneration
from transformers.models.llama4.modeling_llama4 import Llama4CausalLMOutputWithPast, Llama4TextMoe
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device
from ...export.interface import BaseExportPatch, ExportPatchRegistry
# Copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4/modeling_llama4.py#L1651
@ -76,30 +74,34 @@ def _forward_with_cond(
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=None,
)
original_inputs_embeds_shape = inputs_embeds.shape
vision_flat = image_features.view(-1, image_features.size(-1))
projected_vision_flat = self.multi_modal_projector(vision_flat)
projected_vision_flat = self.multi_modal_projector(vision_flat).to(
inputs_embeds.device, inputs_embeds.dtype
)
# NOTE: get_placeholder_mask is not supported by torch.export due to numel check ###########
# special_image_mask = self.get_placeholder_mask(
# input_ids, inputs_embeds=inputs_embeds, image_features=projected_vision_flat
# )
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(
self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device
)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
final_mask = special_image_mask.to(inputs_embeds.device)
inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))
# n_image_tokens = special_image_mask.sum()
special_image_mask = (
special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
)
### END OF get_placeholder_mask ############################################################
final_mask_1d = final_mask[..., 0].reshape(-1)
# num_tokens_to_fill = final_mask_1d.sum()
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, projected_vision_flat)
# This condition statement breaks torch.export:
# TODO: sanity check on the inputs for this
# if num_tokens_to_fill != projected_vision_flat.size(0):
# raise ValueError(
# f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
# f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
# )
expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
inputs_embeds.masked_scatter_(expanded_mask, projected_vision_flat)
return inputs_embeds.view(original_inputs_embeds_shape)
return inputs_embeds
def _no_vision_branch(inputs_embeds, pixel_values, input_ids):
return inputs_embeds
@ -132,7 +134,10 @@ def _forward_with_cond(
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][
shift_attention_mask.to(logits.device) != 0
@ -141,6 +146,7 @@ def _forward_with_cond(
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
@ -161,81 +167,65 @@ def _forward_with_cond(
)
def test_build_run_llama4_vlm():
atol = 1e-3
rtol = 1e-3
@ExportPatchRegistry.register("hf_llama4_vision")
class Llama4VisionPatch(BaseExportPatch):
"""Patch for Llama4ForConditionalGeneration to make it compatible with torch.export.
model_id = _hf_model_dir_or_hub_id(
f"{llm_models_root()}/Llama-4-Scout-17B-16E-Instruct",
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
)
processor = AutoProcessor.from_pretrained(model_id)
This patch replaces the forward method of Llama4ForConditionalGeneration with
a version that uses torch.cond to handle the optional vision branch.
"""
config = AutoConfig.from_pretrained(model_id)
config.text_config.num_hidden_layers = 2
config.text_config.intermediate_size = 64
config.text_config.intermediate_size_mlp = 128
config.vision_config.num_hidden_layers = 2
# The returned cache <class 'transformers.cache_utils.HybridChunkedCache'> breaks torch.export
config.text_config.use_cache = False
model = Llama4ForConditionalGeneration(config).eval().to("cuda").bfloat16()
img1 = Image.new("RGB", (16, 16), color=(128, 128, 128))
img2 = Image.new("RGB", (16, 16), color=(64, 64, 64))
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": img1},
{"type": "image", "image": img2},
{"type": "text", "text": "What's the difference?"},
],
},
]
inputs = (
processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
def _apply_patch(self):
"""Apply the Llama4 vision patch."""
# Store original forward method
self.original_values["Llama4ForConditionalGeneration.forward"] = (
Llama4ForConditionalGeneration.forward
)
.to(model.device)
.to(torch.bfloat16)
)
with torch.inference_mode():
# the original model queried with text-only
out_text_only = model(inputs["input_ids"], None, inputs["attention_mask"])
# Apply patch by replacing the forward method
Llama4ForConditionalGeneration.forward = _forward_with_cond
Llama4ForConditionalGeneration.forward = _forward_with_cond
def _revert_patch(self):
"""Revert the Llama4 vision patch."""
# Restore original forward method
Llama4ForConditionalGeneration.forward = self.original_values[
"Llama4ForConditionalGeneration.forward"
]
with torch.inference_mode():
out_real = model(inputs["input_ids"], inputs["pixel_values"], inputs["attention_mask"])
out_dummy = model(
inputs["input_ids"], torch.zeros_like(inputs["pixel_values"]), inputs["attention_mask"]
)
torch.testing.assert_close(out_dummy.logits, out_text_only.logits, rtol=rtol, atol=atol)
gm = torch_export_to_gm(
model,
(inputs["input_ids"], inputs["pixel_values"], inputs["attention_mask"]),
kwargs={},
)
move_to_device(gm, model.device)
def _moe_forward_with_transpose(self, hidden_states):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_scores, router_logits = self.router(hidden_states)
routed_in = hidden_states.repeat(router_scores.shape[1], 1)
with torch.inference_mode():
out_real_gm = gm(inputs["input_ids"], inputs["pixel_values"], inputs["attention_mask"])
torch.testing.assert_close(out_real.logits, out_real_gm.logits, rtol=rtol, atol=atol)
out_dummy_gm = gm(
inputs["input_ids"], torch.zeros_like(inputs["pixel_values"]), inputs["attention_mask"]
)
torch.testing.assert_close(out_dummy.logits, out_dummy_gm.logits, rtol=rtol, atol=atol)
torch.testing.assert_close(out_dummy_gm.logits, out_text_only.logits, rtol=rtol, atol=atol)
# BUG IN ORIGINAL CODE
# routed_in = routed_in * router_scores.reshape(-1, 1)
# END OF BUG IN ORIGINAL CODE
assert not torch.allclose(out_real.logits, out_dummy.logits, rtol=rtol, atol=atol), (
"Expected outputs to differ between text only input and text+image input"
)
# PATCH STARTED
routed_in = routed_in * router_scores.transpose(0, 1).reshape(-1, 1)
# PATCH ENDED
routed_out = self.experts(routed_in)
out = self.shared_expert(hidden_states)
out.add_(routed_out.reshape(router_scores.shape[1], -1, routed_out.shape[-1]).sum(dim=0))
return out, router_logits
# TODO: remove this patch once https://github.com/huggingface/transformers/pull/40609 is merged,
# gets released, and TRT-LLM updates to the relevant transformers version
@ExportPatchRegistry.register("hf_llama4_moe")
class Llama4MoEPatch(BaseExportPatch):
"""Patch for Llama4 MoE routing to fix its current accuracy issue."""
def _apply_patch(self):
"""Apply the Llama4 MoE routing patch."""
# Store original forward method
self.original_values["Llama4TextMoe.forward"] = Llama4TextMoe.forward
# Apply patch by replacing the forward method
Llama4TextMoe.forward = _moe_forward_with_transpose
def _revert_patch(self):
"""Revert the Llama4 MoE routing patch."""
Llama4TextMoe.forward = self.original_values["Llama4TextMoe.forward"]
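# A small numeric check of the transpose (standalone sketch; router_scores is assumed to be
# laid out as [num_tokens, num_experts], matching the forward above). The repeated
# hidden_states are ordered expert-major, i.e. [e0t0, e0t1, e0t2, e1t0, ...], so the scores
# must be flattened expert-major as well before the elementwise multiply.
router_scores = torch.tensor([[1, 2], [3, 4], [5, 6]])  # 3 tokens x 2 experts

expert_major = router_scores.transpose(0, 1).reshape(-1)  # matches the repeat() layout
token_major = router_scores.reshape(-1)                   # what the unpatched code used

assert expert_major.tolist() == [1, 3, 5, 2, 4, 6]
assert token_major.tolist() == [1, 2, 3, 4, 5, 6]  # mis-pairs scores with tokens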

View File

@ -1,6 +1,6 @@
from itertools import chain
from collections import defaultdict
from types import SimpleNamespace
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple
import torch
from torch._prims_common import DeviceLikeType
@ -105,13 +105,19 @@ class ADEngine(ModelEngine):
max_batch_size=max_batch_size,
page_size=attn_page_size,
max_num_tokens=max_num_tokens,
device=device,
)
factory = ad_config.create_factory()
# pass in extra arguments defined by the model factory
for name, (none_input, dynamic_shape_callback) in factory.get_extra_inputs().items():
seq_info.add_extra_arg(name, none_input, dynamic_shape_callback)
# TODO (lucaslie): consider how we move args around InferenceOptimizer.__init__,
# ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm.
# construct inference optimizer
build_and_optimize = InferenceOptimizer(
factory=ad_config.create_factory(), ad_config=ad_config
)
build_and_optimize = InferenceOptimizer(factory=factory, ad_config=ad_config)
# construct engine
return cls(build_and_optimize, seq_info, device, max_beam_width)
@ -176,7 +182,11 @@ class ADEngine(ModelEngine):
input_pos: List[int] = []
last_logit_only: List[bool] = []
page_assignments: List[List[int]] = []
previous_batch_indices: List[int] = []
flat_gather_idx: List[int] = []
extra_args: Dict[str, List[torch.Tensor]] = defaultdict(list)
dummy_token = -1
# look at context requests first
for request in context_requests:
# store input ids and pos of first token in sequence
@ -186,6 +196,15 @@ class ADEngine(ModelEngine):
request.py_batch_idx = request.seq_slot
last_logit_only.append(True)
# get cache indices
cache_indices = kv_cache_manager.get_cache_indices(request)
page_assignments.append(cache_indices)
# store extra arguments
if request.py_multimodal_data is not None:
for k, v in request.py_multimodal_data.items():
extra_args[k].append(v)
# look at generate requests next
# TODO: we should also handle extend requests (for speculative decoding) here
for request in gen_requests:
@ -194,30 +213,34 @@ class ADEngine(ModelEngine):
input_ids.append([request.get_token(0, request.get_num_tokens(0) - 1)])
input_pos.append(request.max_beam_num_tokens - 1)
else:
# insert a dummy token to indicate the new tokens
input_ids.append([-1])
previous_batch_indices.append(request.py_batch_idx)
input_ids.append([dummy_token])
input_pos.append(request.max_beam_num_tokens)
flat_gather_idx.append(request.py_batch_idx)
request.py_batch_idx = request.seq_slot
# return all logits
last_logit_only.append(False)
# extract cache information for all requests
for request in chain(context_requests, gen_requests):
# get cache indices
cache_indices = kv_cache_manager.get_cache_indices(request)
page_assignments.append(cache_indices)
# update the sequence info object now
si = self.cache_seq_interface.info
si.update_pos(input_pos, reset=True)
si.assign_cache_loc(page_assignments)
si.nest_sequences(input_ids)
self.cache_seq_interface.info.nest_sequences(
input_ids,
input_pos=input_pos,
page_assignments=page_assignments,
**extra_args,
)
# scatter the new tokens into the input_ids tensor if provided
if new_tokens is not None:
si.update_input_ids_with_new_tokens(new_tokens, previous_batch_indices)
self.cache_seq_interface.info.rescatter_input_ids(
ungathered_input_ids=new_tokens.flatten(), # ensure it's flattened
gather_idx=flat_gather_idx,
scatter_ref=dummy_token,
)
return last_logit_only
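# A minimal sketch (assumed semantics, not the actual SequenceInfo implementation) of what
# rescatter_input_ids is expected to do with the dummy tokens inserted above: every slot
# holding `scatter_ref` is overwritten by the new token gathered at the matching
# `gather_idx` entry (the request's batch index from the previous step).
def _rescatter_sketch(flat_input_ids, ungathered_new_tokens, gather_idx, scatter_ref=-1):
    gathered = ungathered_new_tokens[torch.tensor(gather_idx, dtype=torch.long)]
    out = flat_input_ids.clone()
    out[flat_input_ids == scatter_ref] = gathered
    return out

# e.g. one context request plus two overlap-scheduled generate requests that were batch
# indices 2 and 0 in the previous step:
# _rescatter_sketch(torch.tensor([7, -1, -1]), torch.tensor([11, 12, 13]), [2, 0])
# -> tensor([ 7, 13, 11])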
@nvtx_range("ad_compute_logits")

View File

@ -1,8 +1,9 @@
"""A demo LLM api to for debugging and testing purposes of e2e workflows."""
import gc
from collections import defaultdict
from queue import Empty
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.multiprocessing as mp
@ -10,6 +11,7 @@ import torch.multiprocessing as mp
from ....executor import GenerationExecutor
from ....executor.request import GenerationRequest
from ....executor.result import CompletionOutput, GenerationResult
from ....inputs.multimodal import MultimodalParams
from ....sampling_params import SamplingParams
from ...pyexecutor.sampler import greedy_search_sampling_batch, top_k_sampling_batch
from ..distributed import common as dist_ad
@ -34,8 +36,11 @@ class DemoEngine(ADEngine):
self.queue = mp.Queue()
@torch.inference_mode()
def __call__(self, requests: GenerationRequest) -> mp.Queue:
def __call__(
self, requests: GenerationRequest, multimodal_params: Optional[MultimodalParams]
) -> mp.Queue:
"""Generate tokens and put the results in a queue and return the queue."""
requests.multimodal_params = multimodal_params
output = self.generate_tokens_batched([requests])[0]
self.queue.put(output)
return self.queue
@ -45,7 +50,7 @@ class DemoEngine(ADEngine):
self.queue.close()
self.queue.join_thread()
def _assign_pages(self) -> List[List[int]]:
def _assign_pages(self, total_lens: List[int]) -> List[List[int]]:
"""A simple heuristic to assign pages based on current sequence info.
In a nutshell, we will look at the following information to update the page assignments:
@ -67,7 +72,6 @@ class DemoEngine(ADEngine):
unassigned page if needed.
"""
si = self.cache_seq_interface.info
total_lens = [s_l + i_p for s_l, i_p in zip(si.sequence_lengths, si.input_positions)]
page_assignments = si.page_assignments
free_pages = set(range(si.num_pages)) - {i for pages in page_assignments for i in pages}
@ -76,7 +80,7 @@ class DemoEngine(ADEngine):
extra_tokens = t_l - len(pages) * si.page_size
num_extra_pages = (extra_tokens // si.page_size) + (extra_tokens > 0)
updated_assignments.append(pages + [free_pages.pop() for _ in range(num_extra_pages)])
si.assign_cache_loc(updated_assignments)
return updated_assignments
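# Worked example of the heuristic above, assuming page_size == 64 for illustration: a
# sequence that now needs `total_len` tokens and already owns `len(pages)` pages gets
# roughly ceil(extra_tokens / page_size) new pages; the `(extra // size) + (extra > 0)`
# form over-allocates by exactly one page whenever the overflow is a positive multiple of
# the page size.
page_size = 64
for total_len, owned_pages in [(130, 1), (64, 1), (128, 1)]:
    extra_tokens = total_len - owned_pages * page_size
    num_extra_pages = (extra_tokens // page_size) + (extra_tokens > 0)
    print(total_len, owned_pages, "->", num_extra_pages)  # 2, 0, 2 extra pages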
def generate_tokens_batched(
self, requests: List[GenerationRequest]
@ -91,10 +95,27 @@ class DemoEngine(ADEngine):
)
assert sampling_params.best_of == 1, "Best-of is not supported."
# set up sequence info object
# set up sequence info object for decode phase
sequence_info = self.cache_seq_interface.info
input_ids = []
total_lens = []
extra_args: Dict[str, List[torch.Tensor]] = defaultdict(list)
for request in requests:
total_lens.append(len(request.prompt_token_ids))
input_ids.append(request.prompt_token_ids)
if request.multimodal_params is not None:
for k, v in request.multimodal_params.multimodal_data.items():
extra_args[k].append(v)
sequence_info.reset()
sequence_info.nest_sequences([r.prompt_token_ids for r in requests])
sequence_info.nest_sequences(
input_ids=input_ids,
input_pos=0,
page_assignments=self._assign_pages(total_lens),
**extra_args,
)
# setup objects we want to track for the output
batch_size = sequence_info.num_sequences
@ -105,18 +126,21 @@ class DemoEngine(ADEngine):
context_logits: Optional[List[torch.Tensor]] = None
def _generate_single_step(idx: int):
# assign pages
self._assign_pages()
# get the logits and then last token logits in each sequence ([b, 1, vocab_size])
logits = self._compute_logits()
logits_last = torch.stack([l_one_seq[-1] for l_one_seq in logits]).float().unsqueeze(1)
token_ids, _ = self._decode_tokens(logits_last, sampling_params) # [b,1]
# update sequence info accordingly for next step
sequence_info.update_pos(sequence_info.sequence_lengths)
sequence_info.nest_sequences(token_ids)
# update sequence info accordingly for next step (generate phase)
input_pos_next = sequence_info.input_pos
seq_lens_current = sequence_info.seq_len
input_pos_next = [ip + sl for ip, sl in zip(input_pos_next, seq_lens_current)]
total_lens_next = [ip + len(t_ids) for ip, t_ids in zip(input_pos_next, token_ids)]
sequence_info.nest_sequences(
token_ids,
input_pos=input_pos_next,
page_assignments=self._assign_pages(total_lens_next),
)
# nest new tokens and run stop check
for b, (new_tokens_b, new_id) in enumerate(zip(new_tokens, token_ids)):
@ -255,6 +279,7 @@ class DemoGenerationExecutor(GenerationExecutor):
def _unpack(inputs) -> GenerationRequest:
args, kwargs = inputs # unpack the inputs
request: GenerationRequest = args[0]
request.multimodal_params: Optional[MultimodalParams] = args[1]
return request
engine = DemoEngine.build_from_config(**engine_kwargs)
@ -309,8 +334,11 @@ class DemoGenerationExecutor(GenerationExecutor):
request.set_id(client_id)
# submit request to our demo engine and store results
# NOTE: when returning from this function, the request.multimodal_params reference will be
# cleared immediately. So we pass it in explicitly to keep a reference alive even when
# requests get submitted asynchronously.
result = GenerationResult(request)
result.queue = self.engine_executor(request)
result.queue = self.engine_executor(request, request.multimodal_params)
return result

View File

@ -63,7 +63,7 @@ class ExportToGM(BaseTransform):
model = gm.get_submodule("factory_model")
# set the example sequence
cm.info.set_example_sequence()
cm.info.set_example_sequence(**factory.get_example_inputs())
# export the model to a graph module
gm = torch_export_to_gm(

View File

@ -44,9 +44,6 @@ class UpdateInOutNodes(BaseTransform):
# loop through nodes to get input, output, and get_attr nodes
input_nodes, output_nodes = get_all_input_output_nodes(gm.graph)
# we only expect one input node
assert len(input_nodes) == 2, "Expected exactly two input nodes (input_ids, position_ids)."
# NOTE: for now, we wanna make sure we *only* return the final output and no hidden states.
# Later on, we can revisit how to support returning hidden states.
assert len(output_nodes) == 1, "Expected exactly one output node!"
@ -117,16 +114,17 @@ class InsertCachedAttention(BaseTransform):
# retrieve input nodes
input_nodes, _ = get_all_input_output_nodes(gm.graph)
input_nodes_mapping = {n.target: n for n in input_nodes}
# filtered and sorted for SequenceInfo arguments + constants (input_ids, position_ids, etc.)
inputs_from_info = [input_nodes_mapping[k] for k in cm.info.named_standard_args.keys()]
constants_from_info = cm.info.const_args_for_prepare_metadata
# insert metadata computation and extract each argument as a node
get_metadata, num_metadata = attn_descriptor.get_prepare_metadata_op()
with graph.inserting_before(input_nodes[-1].next):
ret_node = graph.call_function(
get_metadata,
args=(
*input_nodes,
cm.info.page_size,
),
get_metadata, args=(*inputs_from_info, *constants_from_info)
)
metadata_nodes = [
graph.call_function(operator.getitem, args=(ret_node, idx))
@ -244,7 +242,7 @@ class ResizeKVCache(BaseTransform):
try:
# Let's run a forward pass to get the memory usage
cm.info._set_max_num_tokens_sample()
cm.info.set_max_num_tokens_sample()
free_mem_pre, _ = _get_mem_info_in_mb()
ad_logger.info(f"Free memory before forward pass (MB): {free_mem_pre}")

View File

@ -18,8 +18,7 @@ from .._utils import (KVCacheEventSerializer, global_mpi_rank, global_mpi_size,
mpi_comm, mpi_rank, nvtx_range_debug)
from ..bindings import executor as tllm
from ..builder import ConfigEncoder, Engine, EngineConfig
from ..llmapi.llm_args import (BaseLlmArgs, KvCacheConnectorConfig,
PybindMirror, TorchLlmArgs)
from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig, PybindMirror
from ..llmapi.mpi_session import set_mpi_session_cpp
from ..llmapi.tokenizer import TokenizerBase
from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
@ -86,7 +85,9 @@ class GenerationExecutorWorker(GenerationExecutor):
self._await_response_helper = AwaitResponseHelper(
self) # TODO: make it weakref
self._executor_config = executor_config
self._is_pytorch_backend = llm_args is not None and llm_args.backend == "pytorch"
self._is_pytorch_backend = llm_args is not None and llm_args.backend in [
"pytorch", "_autodeploy"
]
self.llm_args = llm_args
if not self._is_pytorch_backend and kv_connector_config is not None:
@ -468,7 +469,6 @@ class GenerationExecutorWorker(GenerationExecutor):
)
if self._is_pytorch_backend:
assert isinstance(self.llm_args, TorchLlmArgs)
if not self.llm_args.disable_overlap_scheduler:
is_disaggregated = self.engine.kv_cache_transceiver is not None
if is_disaggregated and (

View File

@ -2317,6 +2317,7 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
MODEL_NAME = "Qwen3/Qwen3-8B"
@skip_pre_hopper
@pytest.mark.skip(reason="https://nvbugs/5505402")
@pytest.mark.parametrize(
"tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler",
[(1, 1, 1, False, True, True)],

View File

@ -37,6 +37,7 @@ l0_b200:
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass]
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm]
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-triton]
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency] # nvbugs 5505402
- disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0] # nvbugs 5300551
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B]
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-FP8-llama-3.1-model/Llama-3.1-8B-Instruct-FP8]

View File

@ -86,33 +86,12 @@ l0_dgx_h100:
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype1]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.W4A8_CUSTOM-dtype0]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.W4A8_CUSTOM-dtype1]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True]

View File

@ -50,6 +50,27 @@ l0_dgx_h200:
stage: post_merge
backend: pytorch
tests:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=0-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]

View File

@ -1,5 +1,5 @@
import copy
from typing import Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Sequence
import numpy as np
import torch
@ -39,19 +39,45 @@ class FakeFactory(ModelFactory):
class SequenceEmbeddingInfo(SequenceInfo):
hidden_size: int
dtype: torch.dtype
"""A sequence info object for testing that replaces the input_ids with an embedding tensor.
def set_example_sequence(self) -> None:
super().set_example_sequence()
# set input ids to a 3D tensor (actually input embeddings)
self.input_ids = torch.rand(
*self.input_ids.shape,
This is useful to run tests without the tokenizer in the loop.
"""
def _add_hidden_dim(self, input_ids: Sequence[Sequence[Any]]) -> torch.Tensor:
return torch.rand(
*input_ids.shape,
self.hidden_size,
device=self.input_ids.device,
device=self.device,
dtype=self.dtype,
)
def __init__(self, *args, hidden_size: int, dtype: torch.dtype, **kwargs):
self._initialized = False
super().__init__(*args, **kwargs)
# overwrite input_ids with an embedding tensor and run reset again
self.hidden_size = hidden_size
self.dtype = dtype
self._args_device["input_ids"] = self._add_hidden_dim(self._args_device["input_ids"])
self._args_host["input_ids"] = self._args_device["input_ids"].cpu()
self._initialized = True
self.reset()
def nest_sequences(self, input_ids: Sequence[Sequence[Any]], *args, **kwargs) -> None:
# convert input_ids to an embedding tensor if needed
if not (isinstance(input_ids, torch.Tensor) and input_ids.ndim == 3) and self._initialized:
# first convert to a list of tensors
input_embeds = [
torch.tensor(ids, device=self.device, dtype=self.dtype) for ids in input_ids
]
# then add the hidden dimension to every tensor
input_embeds = [self._add_hidden_dim(ids) for ids in input_embeds]
else:
input_embeds = input_ids
super().nest_sequences(input_embeds, *args, **kwargs)
def count_parameters(model: torch.nn.Module):
for n, p in model.named_parameters():

View File

@ -400,9 +400,6 @@ _SMALL_MODEL_CONFIGS = {
},
"vision_config": {
"num_hidden_layers": 1,
"hidden_size": 64,
"intermediate_size": 64,
"num_attention_heads": 2,
},
},
},

View File

@ -0,0 +1,108 @@
import torch
from _model_test_utils import get_small_model_config
from build_and_run_ad import ExperimentConfig
from PIL import Image
from tensorrt_llm._torch.auto_deploy import LlmArgs
from tensorrt_llm._torch.auto_deploy.export import apply_export_patches, torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device
def test_build_run_llama4_vlm():
atol = 1e-3
rtol = 1e-3
experiment_config = get_small_model_config("meta-llama/Llama-4-Scout-17B-16E-Instruct")
experiment_config["args"]["model_kwargs"]["_attn_implementation"] = "eager"
experiment_config = ExperimentConfig(**experiment_config)
llm_args: LlmArgs = experiment_config.args
factory = llm_args.create_factory()
model = factory.build_model("cuda")
processor = factory.init_processor()
img1 = Image.new("RGB", (16, 16), color=(128, 128, 128))
img2 = Image.new("RGB", (16, 16), color=(64, 64, 64))
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": img1},
{"type": "image", "image": img2},
{
"type": "text",
"text": "Describe what you see in the two images and their differences.",
},
],
},
]
inputs = (
processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
.to(model.device)
.to(torch.bfloat16)
)
# get relevant inputs
input_ids = inputs["input_ids"]
position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).repeat(
input_ids.shape[0], 1
)
pixel_values = inputs["pixel_values"]
def _run_with_and_without_image(model, use_patch=True):
with apply_export_patches(patch_list=["hf_llama4_vision"] if use_patch else []):
with torch.inference_mode():
out_no_images = model(
input_ids,
position_ids,
torch.zeros_like(pixel_values) if use_patch else None,
)
out_with_images = model(
input_ids,
position_ids,
pixel_values,
)
return {"no_images": out_no_images.logits, "with_images": out_with_images.logits}
# Get output pre-patch
out_original = _run_with_and_without_image(model, use_patch=False)
# Get output post-patch
outputs_for_comparison = {}
outputs_for_comparison["model_with_patch"] = _run_with_and_without_image(model)
# Export to GM
gm = torch_export_to_gm(
model,
args=(input_ids, position_ids, pixel_values),
patch_list=[
"transformers_sdpa_mask",
"autocast_noop",
"torch_where",
"tensor_meta_device",
"sdpa_kernel_noop",
"sdpa",
"hf_llama4_vision",
],
)
move_to_device(gm, model.device)
# Get the output post export
outputs_for_comparison["gm"] = _run_with_and_without_image(gm)
# Run comparisons to out_original with no patch now...
for comp, outs in outputs_for_comparison.items():
torch.testing.assert_close(
outs,
out_original,
rtol=rtol,
atol=atol,
msg=lambda m: f"Comparison for {comp} failed:\n{m}",
)

View File

@ -71,7 +71,6 @@ def test_engine(engine_cls: Type[ADEngine], attn_backend: str, attn_page_size: i
input_ids = [torch.tensor([0, 1, 2], device=device)]
sequence_info.reset()
sequence_info.nest_sequences(input_ids)
engine.cache_seq_interface.info.sync(sequence_info)
logits = engine._compute_logits()
logits = torch.stack(logits)
assert logits is not None, "Logits are None"
@ -106,7 +105,6 @@ def test_demo_engine_sampling(attn_page_size: int):
input_ids = [torch.tensor([1, 2, 3, 4], device=device)]
sequence_info.reset()
sequence_info.nest_sequences(input_ids)
engine.cache_seq_interface.info.sync(sequence_info)
logits = engine._compute_logits()
logits = torch.stack(logits)

View File

@ -92,7 +92,7 @@ def test_config_params():
@patch("tensorrt_llm._torch.auto_deploy.llm.DemoGenerationExecutor")
@patch("tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface.SequenceInfo")
@patch("tensorrt_llm._torch.auto_deploy.shim.demollm.dist_ad.initialize_or_skip")
@patch("tensorrt_llm._torch.auto_deploy.llm.create_input_processor")
@patch("tensorrt_llm._torch.auto_deploy.llm.LLM._create_input_processor")
@patch("tensorrt_llm._torch.auto_deploy.llm.LLM._build_model")
def test_config_flow(
mock_build_model,
@ -147,13 +147,6 @@ def test_config_flow(
pass
def test_invalid_model_factory():
"""Test behavior with invalid model factory."""
# Pydantic validates Literal types at runtime, so this should raise ValidationError
with pytest.raises(Exception): # Could be ValidationError or ValueError
LlmArgs(model="test-model", model_factory="InvalidFactory")
@pytest.mark.parametrize(
"parallel_field,invalid_value",
[

View File

@ -68,7 +68,7 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
get_small_model_config(
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
attn_backend="flashinfer",
compile_backend="torch-opt",
compile_backend="torch-simple",
),
get_small_model_config(
"deepseek-ai/DeepSeek-V3",

View File

@ -54,17 +54,19 @@ class GQAWithSdpa(GQA):
self.num_key_value_groups = None
@torch.no_grad()
def forward(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(
self, input_ids: torch.Tensor, position_ids: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Forward pass with input tokens and optional position ids.
position_ids parameter added to match expected interface in kvcache.py
"""
b, s, _ = x.shape
b, s, _ = input_ids.shape
# Project input to q, k, v representations
q = self.q_proj(x) # [b, s, n*h_d]
k = self.k_proj(x) # [b, s, n_kv*h_d]
v = self.v_proj(x) # [b, s, n_kv*h_d]
q = self.q_proj(input_ids) # [b, s, n*h_d]
k = self.k_proj(input_ids) # [b, s, n_kv*h_d]
v = self.v_proj(input_ids) # [b, s, n_kv*h_d]
# Reshape to [b, s, n, h_d]
q = q.view(b, s, self.num_heads, self.head_dim)
@ -126,9 +128,9 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config):
ci = SequenceEmbeddingInfo(
max_seq_len=max_position_embeddings,
max_batch_size=batch_size,
hidden_size=hidden_size,
dtype=dtype,
)
ci.hidden_size = hidden_size
ci.dtype = dtype
cm = CachedSequenceInterface(sequence_info=ci, device="cuda")
# Create the model with SDPA and wrap it in a fake factory
@ -180,9 +182,9 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config):
cm.initialize_caches()
# Helper function to call the model with proper sequence nesting
def _call_and_unnest(x):
def _call_and_unnest(x, input_pos):
# Use nest_sequences to properly set input_ids and automatically update position_ids
cm.info.nest_sequences(x, allow_realloc=True)
cm.info.nest_sequences(x, input_pos=input_pos)
# Use the cm.args as is - it already contains the correct position_ids
y = gm(*cm.args)
@ -192,31 +194,25 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config):
# Test 1: Regular inference (all tokens at once)
cm.info.reset()
y_no_cache = _call_and_unnest(x)
y_no_cache = _call_and_unnest(x, 0)
assert all_close(y_model, y_no_cache, atol=atol, rtol=rtol)
# Test 2: Autoregressive inference with KV cache
cm.info.reset()
y_with_cache = torch.empty_like(y_model)
for i in range(x.shape[1]):
for i_p in range(x.shape[1]):
# Just pass the current token
y_with_cache[:, i : i + 1] = _call_and_unnest(x[:, i : i + 1])
# Update position for next token
cm.info.update_pos(1) # This automatically updates position_ids too
y_with_cache[:, i_p : i_p + 1] = _call_and_unnest(x[:, i_p : i_p + 1], i_p)
assert all_close(y_model, y_with_cache, atol=atol, rtol=rtol)
# Test 3: Cache continuation after random tokens
cm.info.update_pos(-num_reset_steps) # Rewind position
for i in range(num_random_steps):
_call_and_unnest(torch.rand_like(x[:, :1]))
cm.info.update_pos(1)
for i_p in range(x.shape[1] - num_reset_steps, x.shape[1] - num_reset_steps + num_random_steps):
_call_and_unnest(torch.rand_like(x[:, :1]), i_p)
# Continue inference from previous context
cm.info.reset()
cm.info.update_pos(x.shape[1] - num_reset_steps)
for i in range(x.shape[1] - num_reset_steps, x.shape[1]):
y_with_cache[:, i : i + 1] = _call_and_unnest(x[:, i : i + 1])
cm.info.update_pos(1)
for i_p in range(x.shape[1] - num_reset_steps, x.shape[1]):
y_with_cache[:, i_p : i_p + 1] = _call_and_unnest(x[:, i_p : i_p + 1], i_p)
assert all_close(y_model, y_with_cache, atol=atol, rtol=rtol)
# Test 4: Exportability of the transformed model

View File

@ -16,6 +16,7 @@ from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
@pytest.mark.skip(reason="https://nvbugs/5461761")
@pytest.mark.parametrize(
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill",
[