diff --git a/docker/common/install_cmake.sh b/docker/common/install_cmake.sh index 6272cf29dc..4a3598a609 100644 --- a/docker/common/install_cmake.sh +++ b/docker/common/install_cmake.sh @@ -12,7 +12,7 @@ fi PARSED_CMAKE_VERSION=$(echo $CMAKE_VERSION | sed 's/\.[0-9]*$//') CMAKE_FILE_NAME="cmake-${CMAKE_VERSION}-linux-${ARCH}" RELEASE_URL_CMAKE=${GITHUB_URL}/Kitware/CMake/releases/download/v${CMAKE_VERSION}/${CMAKE_FILE_NAME}.tar.gz -wget --no-verbose --timeout=180 --tries=3 ${RELEASE_URL_CMAKE} -P /tmp +wget --retry-connrefused --timeout=180 --tries=10 --continue ${RELEASE_URL_CMAKE} -P /tmp tar -xf /tmp/${CMAKE_FILE_NAME}.tar.gz -C /usr/local/ ln -s /usr/local/${CMAKE_FILE_NAME} /usr/local/cmake diff --git a/docker/common/install_cuda_toolkit.sh b/docker/common/install_cuda_toolkit.sh index 5d3ce166d6..e5372bb7d5 100644 --- a/docker/common/install_cuda_toolkit.sh +++ b/docker/common/install_cuda_toolkit.sh @@ -16,7 +16,7 @@ reinstall_rockylinux_cuda() { dnf -y install epel-release dnf remove -y "cuda*" "*cublas*" "*cufft*" "*cufile*" "*curand*" "*cusolver*" "*cusparse*" "*gds-tools*" "*npp*" "*nvjpeg*" "nsight*" "*nvvm*" rm -rf /usr/local/cuda-${OLD_CUDA_VER} - wget -q https://developer.download.nvidia.com/compute/cuda/${CUDA_VER_SHORT}/local_installers/cuda_${CUDA_VER}_linux.run + wget --retry-connrefused --timeout=180 --tries=10 --continue https://developer.download.nvidia.com/compute/cuda/${CUDA_VER_SHORT}/local_installers/cuda_${CUDA_VER}_linux.run sh cuda_${CUDA_VER}_linux.run --silent --override --toolkit rm -f cuda_${CUDA_VER}_linux.run } diff --git a/docker/common/install_tensorrt.sh b/docker/common/install_tensorrt.sh index 1ee819bb52..3724060f8d 100644 --- a/docker/common/install_tensorrt.sh +++ b/docker/common/install_tensorrt.sh @@ -96,7 +96,7 @@ install_rockylinux_requirements() { "cuda-toolkit-config-common-${CUDA_RUNTIME}.noarch" \ "libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.${ARCH1}" \ "libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.${ARCH1}"; do - wget -q --timeout=180 --tries=3 "https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/${ARCH3}/${pkg}.rpm" + wget --retry-connrefused --timeout=180 --tries=10 --continue "https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/${ARCH3}/${pkg}.rpm" done # Remove old packages @@ -138,7 +138,7 @@ install_tensorrt() { if [ "$ARCH" = "x86_64" ];then curl -L --insecure --connect-timeout 600 --max-time 3600 --retry 3 -o /tmp/TensorRT.tar "${RELEASE_URL_TRT}" else - wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar + wget --retry-connrefused --timeout=180 --tries=10 --continue ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar fi tar -xf /tmp/TensorRT.tar -C /usr/local/ mv /usr/local/TensorRT-* /usr/local/tensorrt diff --git a/docker/common/install_triton.sh b/docker/common/install_triton.sh index 89b6aced03..f295554a0a 100644 --- a/docker/common/install_triton.sh +++ b/docker/common/install_triton.sh @@ -5,7 +5,7 @@ set -ex install_boost() { # Install boost version >= 1.78 for boost::span # Current libboost-dev apt packages are < 1.78, so install from tar.gz - wget -O /tmp/boost.tar.gz --timeout=180 --tries=3 https://archives.boost.io/release/1.80.0/source/boost_1_80_0.tar.gz \ + wget --retry-connrefused --timeout=180 --tries=10 --continue -O /tmp/boost.tar.gz https://archives.boost.io/release/1.80.0/source/boost_1_80_0.tar.gz \ && tar xzf /tmp/boost.tar.gz -C /tmp \ && mv /tmp/boost_1_80_0/boost /usr/include/boost \ && rm -rf /tmp/boost_1_80_0 /tmp/boost.tar.gz diff --git a/examples/auto_deploy/.gitignore 
b/examples/auto_deploy/.gitignore index 8e28b4431d..9836a37fc8 100644 --- a/examples/auto_deploy/.gitignore +++ b/examples/auto_deploy/.gitignore @@ -2,3 +2,5 @@ !.vscode benchmark_results.json *.png +# ignore config files that users might put here for debugging +*.yaml diff --git a/examples/auto_deploy/build_and_run_ad.py b/examples/auto_deploy/build_and_run_ad.py index 42a2f927dd..45ea1fe19d 100644 --- a/examples/auto_deploy/build_and_run_ad.py +++ b/examples/auto_deploy/build_and_run_ad.py @@ -26,6 +26,9 @@ from tensorrt_llm.sampling_params import SamplingParams # Global torch config, set the torch compile cache to fix up to llama 405B torch._dynamo.config.cache_size_limit = 20 +# simple string, TRT-LLM style text-only prompt or full-scale HF message template +PromptInput = Union[str, Dict, List[Dict]] + class PromptConfig(BaseModel): """Prompt configuration. @@ -35,13 +38,27 @@ class PromptConfig(BaseModel): """ batch_size: int = Field(default=2, description="Number of queries") - queries: Union[str, List[str]] = Field( + queries: Union[PromptInput, List[PromptInput]] = Field( default_factory=lambda: [ + # OPTION 1: simple text prompt "How big is the universe? ", - "In simple words and in a single sentence, explain the concept of gravity: ", - "How to fix slicing in golf? ", - "Where is the capital of Iceland? ", - ] + # OPTION 2: wrapped text prompt for TRT-LLM + {"prompt": "In simple words and a single sentence, explain the concept of gravity: "}, + # OPTION 3: a full-scale HF message template (this one works for text-only models!) + # Learn more about chat templates: https://huggingface.co/docs/transformers/en/chat_templating + # and multi-modal templates: https://huggingface.co/docs/transformers/en/chat_templating_multimodal + [ + { + "role": "user", + "content": "How to fix slicing in golf?", + } + ], + # More prompts... + {"prompt": "Where is the capital of Iceland? "}, + ], + description="Example queries to prompt the model with. We support both TRT-LLM text-only " + "queries via the 'prompt' key and full-scale HF message template called via " + "apply_chat_template.", ) sp_kwargs: Dict[str, Any] = Field( default_factory=lambda: {"max_tokens": 100, "top_k": 200, "temperature": 1.0}, @@ -55,10 +72,28 @@ class PromptConfig(BaseModel): NOTE (lucaslie): has to be done with model_post_init to ensure it's always run. field validators are only run if a value is provided. """ - queries = [self.queries] if isinstance(self.queries, str) else self.queries + queries = self.queries if isinstance(self.queries, list) else [self.queries] batch_size = self.batch_size queries = queries * (batch_size // len(queries) + 1) - self.queries = queries[:batch_size] + queries = queries[:batch_size] + + # now let's standardize the queries for the LLM api to understand them + queries_processed = [] + for query in queries: + if isinstance(query, str): + queries_processed.append({"prompt": query}) + elif isinstance(query, dict): + queries_processed.append(query) + elif isinstance(query, list): + queries_processed.append( + { + "prompt": "Fake prompt. 
Check out messages field for the HF chat template.", + "messages": query, # contains the actual HF chat template + } + ) + else: + raise ValueError(f"Invalid query type: {type(query)}") + self.queries = queries_processed @field_validator("sp_kwargs", mode="after") @classmethod diff --git a/jenkins/L0_Test.groovy b/jenkins/L0_Test.groovy index 6f034c6dab..624833545f 100644 --- a/jenkins/L0_Test.groovy +++ b/jenkins/L0_Test.groovy @@ -121,7 +121,7 @@ REQUIRED_NO_DRIVER_TYPES = ["dgx-h100", "dgx-h200", "gh200"] ENABLE_NGC_DEVEL_IMAGE_TEST = params.enableNgcDevelImageTest ?: false ENABLE_NGC_RELEASE_IMAGE_TEST = params.enableNgcReleaseImageTest ?: false -COMMON_SSH_OPTIONS = "-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null" +COMMON_SSH_OPTIONS = "-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o ServerAliveInterval=60 -o ServerAliveCountMax=5" def uploadResults(def pipeline, SlurmCluster cluster, String nodeName, String stageName){ withCredentials([usernamePassword(credentialsId: 'svc_tensorrt', usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) { @@ -323,7 +323,7 @@ def runLLMTestlistOnSlurm(pipeline, platform, testList, config=VANILLA_CONFIG, p Utils.exec(pipeline, script: "chmod +x ${jenkinsSetupPath}", returnStdout: true) - Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${jenkinsSetupPath} ${remote.user}@${remote.host}:~/bloom/scripts/${nodeName}-slurm_jenkins_agent_setup.sh", numRetries: 3,) + Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${jenkinsSetupPath} ${remote.user}@${remote.host}:~/bloom/scripts/${nodeName}-slurm_jenkins_agent_setup.sh", numRetries: 3) Utils.exec(pipeline, script: "cat ${jenkinsSetupPath}") @@ -334,7 +334,8 @@ def runLLMTestlistOnSlurm(pipeline, platform, testList, config=VANILLA_CONFIG, p remote, "\"${SlurmConfig.generateCommand(cluster, partition, nodeSecret, nodeName, Jenkins.instance.rootUrl)}\"" ), - returnStdout: true + returnStdout: true, + numRetries: 3 ) def jobIDs = slurmSubmitOutput @@ -498,7 +499,7 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL stage('Prepare Testing') { // Create Job Workspace folder in Frontend Node - Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' ssh ${COMMON_SSH_OPTIONS} ${remote.user}@${remote.host} 'mkdir -p ${jobWorkspace}'", numRetries: 3,) + Utils.exec(pipeline, script: Utils.sshUserCmd(remote, "\"mkdir -p ${jobWorkspace}\""), numRetries: 3) // Download and Unzip Tar File trtllm_utils.llmExecStepWithRetry(pipeline, script: "cd ${llmPath} && wget -nv ${llmTarfile}") @@ -508,12 +509,12 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL def scriptRunLocalPath = "${llmSrcLocal}/jenkins/scripts/slurm_run.sh" Utils.exec(pipeline, script: "chmod +x ${scriptRunLocalPath}", returnStdout: true) - Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${scriptRunLocalPath} ${remote.user}@${remote.host}:${scriptRunNode}", numRetries: 3,) + Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${scriptRunLocalPath} ${remote.user}@${remote.host}:${scriptRunNode}", numRetries: 3) Utils.exec(pipeline, script: "cat ${scriptRunLocalPath}") // Upload waives.txt to Frontend node def waivesListLocalPath = "${llmSrcLocal}/tests/integration/test_lists/waives.txt" - Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} 
${waivesListLocalPath} ${remote.user}@${remote.host}:${waivesListPathNode}", numRetries: 3,) + Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${waivesListLocalPath} ${remote.user}@${remote.host}:${waivesListPathNode}", numRetries: 3) // Generate Test List and Upload to Frontend Node def makoArgs = getMakoArgsFromStageName(stageName, true) @@ -522,7 +523,7 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL // if the line cannot be split by "=", just ignore that line. def makoOptsJson = transformMakoArgsToJson(["Mako options:"] + makoArgs) def testListPath = renderTestDB(testList, llmSrcLocal, stageName, makoOptsJson) - Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${testListPath} ${remote.user}@${remote.host}:${testListPathNode}", numRetries: 3,) + Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${testListPath} ${remote.user}@${remote.host}:${testListPathNode}", numRetries: 3) // Generate Multi Node Job Launch Script def container = LLM_DOCKER_IMAGE.replace("urm.nvidia.com/", "urm.nvidia.com#") @@ -566,7 +567,7 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL """.stripIndent() pipeline.writeFile(file: scriptLaunchDestPath, text: scriptContent) Utils.exec(pipeline, script: "chmod +x ${scriptLaunchDestPath}", returnStdout: true) - Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${scriptLaunchDestPath} ${remote.user}@${remote.host}:${scriptLaunch}", numRetries: 3,) + Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${scriptLaunchDestPath} ${remote.user}@${remote.host}:${scriptLaunch}", numRetries: 3) Utils.exec(pipeline, script: "cat ${scriptLaunchDestPath}") } @@ -577,7 +578,8 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL script: Utils.sshUserCmd( remote, "\"bash ${scriptLaunch}\"" - ) + ), + numRetries: 3 ) } diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index b6fc7b7693..d03a3329f1 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -10,15 +10,18 @@ and operates on a purely functional paradigm that is compatible with the torch c """ from abc import ABC, abstractmethod -from dataclasses import dataclass, field, fields -from typing import Dict, List, Literal, Optional, Protocol, Sequence, Tuple, Type, Union +from dataclasses import dataclass +from typing import Callable, Dict, List, Literal, Optional, Protocol, Sequence, Tuple, Type, Union import torch from torch._ops import OpOverloadPacket from torch.export import Dim from torch.fx import Node -from tensorrt_llm._utils import nvtx_range +from ...._utils import nvtx_range + +DynamicShape = Dict[int, Dim] # indicating the dynamic shape in tensor dimension +DynamicShapeCallback = Callable[[], DynamicShape] @dataclass @@ -28,15 +31,14 @@ class CacheConfig: dtype: Optional[torch.dtype] = None -@dataclass class SequenceInfo: - """A dataclass to hold information about how the sequence is laid out and stored in cache. + """An interface to hold information about how the sequence is laid out and stored in cache. We assume the sequence + cache is laid out in the following way. 
Also note that we differentiate between arguments that are originally part of the model/graph and arguments that are needed for the attention operator when we switch to cached+flattened attention. - # ORIGINAL MODEL ARGUMENTS ##################################################################### + ### ORIGINAL MODEL ARGUMENTS ################################################################### - input_ids: [id_0, ..., id_{s_total-1}] flattened sequence of [b, 1] or [1, s_total]. We use [b, 1] to denote generate-only batches. - position_ids: [pos_0, ..., pos_{s_total-1}] @@ -46,7 +48,17 @@ class SequenceInfo: NOTE: ``input_ids`` and ``position_ids`` are initially expected to be of shape [b, seq_len] before we switch to cached+flattened attention. - # EXTRA ARGUMENTS NEEDED FOR ATTENTION OPERATORS FOR FLATTENED SEQUENCES + CACHES ############## + ### EXTRA ARGUMENTS PROVIDED TO THE INTERFACE ################################################## + Those are extra arguments that can be provided to the interface and they are stored as follows: + - _extra_args: dictionary of extra arguments with currently active values. + - _extra_none_inputs: dictionary of none inputs to the extra arguments. + NOTE: we assume that extra arguments are *optional* arguments to the model. However, we + cannot represent them via `None` since fx graphs require a fixed input type. Instead, + we require a special placeholder tensor to represent the `None` input. + - _extra_dynamic_shapes_callbacks: dictionary of callbacks to initialize the dynamic shapes of + the extra arguments. + + ### CACHE ARGUMENTS NEEDED FOR ATTENTION OPERATORS FOR FLATTENED SEQUENCES + CACHES ############ - seq_len: [s_0, s_1, ..., s_{b-1}] such that s_total = sum(s_i) Describes how long each sequence is. For example, input_ids[:s_0] will correspond to sequence 0 in the batch and input_ids[s_0:s_1] will @@ -73,168 +85,238 @@ class SequenceInfo: """ - ## USE TO INITIALIZE DATA CLASS ############################################################### - # max_seq_len corresponds the maximum number of tokens in any sequence. It includes the tokens in the - # input sequence and the tokens generated by the model. - max_seq_len: int = 1 - # max_batch_size corresponds to the maximum number of sequences (or requests) that the model can process. - max_batch_size: int = 1 - # page_size is the granularity with which the cache pages are allocated for a paged kv cache. - # For an unpaged cache, the page size should be set to max_seq_len. - # Also note that two sequences in a batch can not share a page. - page_size: int = 0 - # max_num_tokens is the maximum number of tokens that the model can process across all sequences in the batch. - # If a batch is composed of context-only requests of input sequence length ISL, - # then the maximum number of sequences possible in the batch is min (max_batch_size, max_num_tokens // ISL). - # Similarly, if a batch is composed of generate-only requests, - # then the maximum number of sequences possible in the batch is min (max_batch_size, max_num_tokens). - max_num_tokens: Optional[int] = None - # device is the device on which the sequence info is stored. - device: str = "cuda" + def __init__( + self, + max_seq_len: int = 1, + max_batch_size: int = 1, + page_size: int = 0, + max_num_tokens: Optional[int] = None, + ): + """Initialize the SequenceInfo object. 
- ## [UPDATE WITH CARE] TENSOR FIELDS THAT WILL BE PASSED TO PREPARE_METADATA OP ################# - # input_ids MUST ALWAYS BE THE FIRST FIELD - input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int)) - position_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.long)) + Args: + max_seq_len: corresponds to the maximum sequence length of the input sequence. It + includes the tokens in the input sequence and the tokens generated by the model. + max_batch_size: corresponds to the maximum number of sequences (or requests) that the + model can process. + page_size: corresponds to the page size of the cache. For an unpaged cache, the page + size should be set to max_seq_len. Also note that two sequences in a batch can not + share a page. + max_num_tokens: corresponds to the maximum number of tokens that the model can process + across all sequences in the batch. If a batch is composed of context-only requests + of input sequence length ISL, then the maximum number of sequences possible in the + batch is min (max_batch_size, max_num_tokens // ISL). Similarly, if a batch is + composed of generate-only requests, then the maximum number of sequences possible in + the batch is min (max_batch_size, max_num_tokens). - seq_len: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int)) - input_pos: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int)) - cache_loc: torch.Tensor = field(default_factory=lambda: torch.arange(1, dtype=torch.int)) - pages_per_seq: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int)) - ################################################################################################ - - ## PRIVATE FIELDS ############################################################################## - _sequence_lengths: List[int] = field(default_factory=list) - _num_pages: int = 1 - - def __post_init__(self): - if self.page_size < 1: - self.page_size = self.max_seq_len + Returns: + None + """ + # set up basic attributes + self.max_seq_len = max_seq_len + self.max_batch_size = max_batch_size + self.page_size = page_size if page_size > 0 else max_seq_len # NOTE (lucaslie): WAR to address issue when using flashinfer attention with # (max_batch_size, max_seq_len) input in trtllm runtime. # see https://github.com/NVIDIA/TensorRT-LLM/issues/4504 - self.max_seq_len_adjusted = self.max_seq_len + 1 + max_seq_len_adjusted = self.max_seq_len + 1 + + if max_num_tokens is None or max_num_tokens < 1: + self.max_num_tokens = self.max_batch_size * max_seq_len_adjusted + else: + self.max_num_tokens = max_num_tokens - if self.max_num_tokens is None or self.max_num_tokens < 1: - self.max_num_tokens = self.max_batch_size * self.max_seq_len_adjusted # if the provided max_num_tokens is less than the max_batch_size * max_seq_len, # we use the provided max_num_tokens to calculate the number of pages - total_tokens = min(self.max_num_tokens, self.max_batch_size * self.max_seq_len_adjusted) + total_tokens = min(self.max_num_tokens, self.max_batch_size * max_seq_len_adjusted) # Num pages can not be less than max_batch_size. self._num_pages = max( self.max_batch_size, (total_tokens) // self.page_size + (total_tokens % self.page_size > 0), ) - # Ensure that the device is set before initializing the tensors. 
- self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device) - self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device) - - # Consumers of the sequence info args require input_ids and position_ids to be truncated. - # We maintain a full version of the input_ids and position_ids to avoid overheads of tensor - # creation in every forward pass. - self.input_ids_full = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device) - self.position_ids_full = torch.zeros( - self.max_num_tokens, dtype=torch.long, device=self.device - ) - - self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int, device=self.device) - self.input_pos = torch.empty_like(self.seq_len, device=self.device) - - # Allocated host tensors for sequence lengths and input positions so that - # position_ids calculation can be done on host. - self.seq_len_host = torch.empty(self.max_batch_size, dtype=torch.int) - self.input_pos_host = torch.empty_like(self.seq_len_host) - - self.cache_loc = torch.empty(self.num_pages, dtype=torch.int, device=self.device) - self.pages_per_seq = torch.empty_like(self.seq_len, device=self.device) - - self.previous_batch_indices_cuda = torch.empty( - self.max_num_tokens, dtype=torch.long, device=self.device - ) - assert self.num_pages >= self.max_batch_size, ( - "num_pages must be greater than max_batch_size" - ) - # dynamic shape descriptors for tensor args - self._dynamic_shapes: Optional[Tuple[Dict[str, Dim]]] = None - - # keep a list-like object of sequence lengths for simplicity as well - self._sequence_lengths = [0] * self.max_batch_size + # sanity check + assert self.num_pages >= self.max_batch_size, "num_pages can't be less than max_batch_size" # indicator if extra args are activated that are needed for cached attention backends self._is_cached_attn = False - # total number of tokens in the current batch - self.num_tokens: int = 0 + # container for dynamic shapes + self._dynamic_shapes: Optional[Dict[str, DynamicShape]] = None - # call reset once to initialize the tensors + # TENSOR FIELDS ############################################################################ + self._args_device: Dict[str, torch.Tensor] = { + # TENSOR FIELDS FOR UNCACHED ATTENTION + "input_ids": torch.ones(self.max_num_tokens, dtype=torch.int), + "position_ids": torch.zeros(self.max_num_tokens, dtype=torch.long), + # TENSOR FIELDS FOR CACHED ATTENTION + "seq_len": torch.empty(self.max_batch_size, dtype=torch.int), + "input_pos": torch.empty(self.max_batch_size, dtype=torch.int), + "cache_loc": torch.empty(self.num_pages, dtype=torch.int), + "pages_per_seq": torch.empty(self.max_batch_size, dtype=torch.int), + # OTHER FIELDS WHERE WE NEED EFFICIENT HOST<>DEVICE TRANSFER + "_gather_idx": torch.empty(self.max_num_tokens, dtype=torch.int), + } + self._args_host: Dict[str, List[int]] = { + k: v.tolist() for k, v in self._args_device.items() + } + # NOTE: order of keys is relevant here! 
+ self._uncached_arg_names = ("input_ids", "position_ids") + self._cached_arg_names = ("seq_len", "input_pos", "cache_loc", "pages_per_seq") + ############################################################################################ + + # EXTRA TENSOR FIELDS ###################################################################### + self._extra_args: Dict[str, torch.Tensor] = {} + self._extra_none_inputs: Dict[str, torch.Tensor] = {} + self._extra_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None + self._extra_dynamic_shapes_callbacks: Dict[str, DynamicShapeCallback] = {} + ############################################################################################ + + # call reset once to set a consistent initial state self.reset() + + @property + def device(self) -> torch.device: + return self._args_device["input_ids"].device + + def _shape_for_forward(self, tnsr: torch.Tensor) -> torch.Tensor: + """Shape the tensor for the forward pass based on the current attention mode. + + Args: + tnsr: The tensor to shape, assumed to be of shape [batch_size*seq_len, ...] + + Returns: + The shaped tensor, flattened or unflattened based on the current attention mode. + """ + # check if we are still running uncached attention, in which case we also still + # operate on unflattened tensors with explicit [batch_size, seq_len, ...] shape + # generate-only batches are also formatted like this (i.e. [b, 1]) + if not self._is_cached_attn or self.is_generate: + bs = len(self.seq_len) + sl = self.seq_len[0] + # use [1,total_len] shape to indicate non-generate-only batch for cached attention + else: + bs, sl = 1, self.total_num_tokens + + # truncate to total tokens now, reshape, and return + return tnsr[: self.total_num_tokens].view(bs, sl, *tnsr.shape[1:]) + + def _named_args( + self, include_extra_args: bool = True, include_cached_args: bool = True + ) -> Dict[str, torch.Tensor]: + # start with uncached args and shape them along the way + args = {k: self._shape_for_forward(self._args_device[k]) for k in self._uncached_arg_names} + + # check other args to include + if include_extra_args: + args.update(self._extra_args) + + if include_cached_args: + args.update({k: self._args_device[k] for k in self._cached_arg_names}) + + return args + + @property + def named_args(self) -> Dict[str, torch.Tensor]: + """Return a dictionary of named arguments. + + These arguments comprise all arguments managed by this interface that are required to run a + model's forward pass, including all extra arguments. + + Cached arguments are only included if the attention mode is cached to reflect that after + switching to cached attention, the cached arguments are required for a forward pass. + """ + return self._named_args(include_extra_args=True, include_cached_args=self._is_cached_attn) + + @property + def named_standard_args(self) -> Dict[str, torch.Tensor]: + """Return a dictionary of named standard arguments. + + We define standard arguments as the arguments that are part of the model's forward function + by default (i.e., without the extra arguments). + + Just like ``named_args``, this property includes cached attention arguments if the + attention mode is cached.
+ """ + return self._named_args(include_extra_args=False, include_cached_args=self._is_cached_attn) + @property def args(self) -> Tuple[torch.Tensor, ...]: - args = [] - for f in fields(self): - val = getattr(self, f.name) - if isinstance(val, torch.Tensor): - args.append(val) - if len(args) >= self._num_uncached_attn_args and not self._is_cached_attn: - break - - return tuple(args) + """Return a tuple of arguments.""" + return tuple(self.named_args.values()) @property - def _num_uncached_attn_args(self) -> int: - """Return the number of original graph arguments expected by the model. - This is 2 because we have input_ids and position_ids as the original graph arguments. + def const_args_for_prepare_metadata(self) -> Tuple: + """Return a tuple of extra (const, non-tensor) arguments for the prepare_metadata op. + + The ``prepare_metadata`` interface expects the following arguments: + + 1. ``named_standard_args`` as nodes, i.e., as input-dependent tensors. + 2. ``const_args_for_prepare_metadata`` as constants that can directly be passed in as args + to the corresponding ``prepare_metadata`` node/op. + + This interface handles the constant arguments part and can be used by compiler passes like + ``insert_cached_attention`` to extract the constant arguments and add them to the + ``prepare_metadata`` node/op. """ - return 2 + return (self.page_size,) @property - def _cached_attn_arg_names(self) -> List[str]: - """Return extra arg names for the prepare_metadata op beyond input_ids and position_ids. - - These extra args are needed once we switch from regular attention to inserting cached - attention ops in the model. - """ - return [f.name for f in fields(self) if isinstance(getattr(self, f.name), torch.Tensor)][ - self._num_uncached_attn_args : - ] - - @property - def dynamic_shapes(self) -> Tuple[Dict[str, Dim]]: + def named_dynamic_shapes(self) -> Dict[str, Dict[str, Dim]]: """Return dynamic shapes of sequence info tensors. NOTE: will be lazily initialized since the Dim object is not picklable for multi-processing.
""" + # lazy initialization of dynamic shapes with Dim objects if self._dynamic_shapes is None: - # set up shape for input_ids and position_ids - dynamic_shapes = ({}, {}) + # set up shape for uncached args (same for all, i.e., batch_size and seq_len) + bs_seq_len_shape: DynamicShape = {} if self.max_batch_size > 1: - dynamic_shapes[0][0] = Dim("batch_size", max=self.max_batch_size) - dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len_adjusted) - # set up shape for position_ids (same as input_ids) - dynamic_shapes[1].update(dynamic_shapes[0]) - # set up shape for extra args - if self._is_cached_attn: - dynamic_shapes += ({},) * len(self._cached_attn_arg_names) - self._dynamic_shapes = dynamic_shapes - return self._dynamic_shapes + bs_seq_len_shape[0] = Dim("batch_size", max=self.max_batch_size) + bs_seq_len_shape[1] = Dim("seq_len", max=self.max_seq_len) + self._dynamic_shapes = {k: bs_seq_len_shape for k in self._uncached_arg_names} + # cached args are static + self._dynamic_shapes.update({k: {} for k in self._cached_arg_names}) + + for k, callback in self._extra_dynamic_shapes_callbacks.items(): + if k not in self._dynamic_shapes: + self._dynamic_shapes[k] = callback() + + # return dynamic shapes according to currently active named_args with consistent order + return {k: self._dynamic_shapes[k] for k in self.named_args.keys()} + + @property + def dynamic_shapes(self) -> Tuple[DynamicShape, ...]: + """Return dynamic shapes of sequence info tensors.""" + return tuple(self.named_dynamic_shapes.values()) + + @property + def seq_len(self) -> List[int]: + return self._args_host["seq_len"].copy() + + @property + def input_pos(self) -> List[int]: + return self._args_host["input_pos"].copy() + + @property + def cache_loc(self) -> List[int]: + return self._args_host["cache_loc"].copy() + + @property + def pages_per_seq(self) -> List[int]: + return self._args_host["pages_per_seq"].copy() @property def num_sequences(self) -> int: - return len(self._sequence_lengths) + return len(self.seq_len) @property - def sequence_lengths(self) -> List[int]: - return self._sequence_lengths - - @property - def input_positions(self) -> List[int]: - return self.input_pos_host[: self.num_sequences].tolist() + def total_num_tokens(self) -> int: + return sum(self.seq_len) @property def is_generate(self) -> bool: - return all(sl == 1 for sl in self.sequence_lengths) + return all(sl == 1 for sl in self.seq_len) @property def num_pages(self) -> int: @@ -244,7 +326,7 @@ class SequenceInfo: def num_pages(self, value): self._num_pages = value # update the cache_loc tensor - self.cache_loc.resize_(value) + self._args_device["cache_loc"].resize_(value) @property def is_paged(self) -> bool: @@ -253,12 +335,52 @@ class SequenceInfo: @property def page_assignments(self) -> List[List[int]]: """Return the page assignments for each sequence.""" - pages_per_seq = self.pages_per_seq[: self.num_sequences].tolist() + return self._get_page_assignments(self.cache_loc, self.pages_per_seq) + + @staticmethod + def _get_page_assignments( + cache_locations: List[int], pages_per_sequence: List[int] + ) -> List[List[int]]: + """Get nested page assignments from cache locations and pages per sequence as list of lists. + + Args: + cache_locations: A flat list of cache locations for each sequence ordered by sequence. + pages_per_sequence: A list of number of pages per sequence. + + Returns: + A list of page assignments for each sequence ordered by sequence. 
+ For example: + cache_locations: [0, 4, 2] + pages_per_sequence: [2, 1] + --> returns [[0, 4], [2]] + """ return [ c_loc_one_seq.tolist() - for c_loc_one_seq in torch.split(self.cache_loc[: sum(pages_per_seq)], pages_per_seq) + for c_loc_one_seq in torch.split(torch.tensor(cache_locations), pages_per_sequence) ] + @staticmethod + def _get_cache_locations_and_pages_per_sequence( + page_assignments: List[List[int]], + ) -> Tuple[List[int], List[int]]: + """Get cache locations and pages per sequence from nested page assignments (lists of lists). + + Args: + page_assignments: A list of page assignments for each sequence ordered by sequence. + Returns: + A tuple of: + cache_locations: A flat list of cache locations for each sequence ordered by sequence. + pages_per_sequence: A list of number of pages per sequence. + + Example: + page_assignments: [[0, 4], [2]] + --> returns ([0, 4, 2], [2, 1]) + + """ + cache_loc_flat = [p_idx for pages in page_assignments for p_idx in pages] + pages_per_seq = [len(p) for p in page_assignments] + return cache_loc_flat, pages_per_seq + @classmethod def _get_sanitized_seq_len(cls, input_ids: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor: """Sanitize sequence lengths. @@ -331,26 +453,55 @@ class SequenceInfo: """ assert not self._is_cached_attn, "Cached+flattened attention already activated" self._is_cached_attn = True - return self._cached_attn_arg_names + return list(self._cached_arg_names) def to(self, *args, **kwargs) -> None: - for f in fields(self): - val = getattr(self, f.name) - if isinstance(val, torch.Tensor): - setattr(self, f.name, val.to(*args, **kwargs)) + def _move_dict(d: Dict[str, torch.Tensor]) -> None: + for k, v in d.items(): + d[k] = v.to(*args, **kwargs) - def sync(self, other: "SequenceInfo") -> None: - for f in fields(self): - val = getattr(self, f.name) - val_other = getattr(other, f.name) - if f.name in ["input_ids", "position_ids"]: - setattr(self, f.name, val_other.to(self.device)) - elif f.name == "_sequence_lengths": - self._sequence_lengths = val_other - elif isinstance(val, torch.Tensor): - val[: len(val_other)] = val_other.to(self.device) - else: - assert val == val_other, f"Field {f.name} mismatch: {val} != {val_other}." 
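# Illustrative sketch of the page-assignment round trip implemented by the two static helpers
# above (`_get_page_assignments` and `_get_cache_locations_and_pages_per_sequence`). This
# standalone version uses plain Python lists instead of torch tensors; the function names are
# hypothetical and the values mirror the docstring examples.
from typing import List, Tuple


def pages_from_flat(cache_locations: List[int], pages_per_sequence: List[int]) -> List[List[int]]:
    """Unflatten a flat list of cache locations into per-sequence page assignments."""
    nested, start = [], 0
    for num_pages in pages_per_sequence:
        nested.append(cache_locations[start : start + num_pages])
        start += num_pages
    return nested


def flat_from_pages(page_assignments: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Flatten per-sequence page assignments back into (cache_locations, pages_per_sequence)."""
    return [p for pages in page_assignments for p in pages], [len(p) for p in page_assignments]


assert pages_from_flat([0, 4, 2], [2, 1]) == [[0, 4], [2]]
assert flat_from_pages([[0, 4], [2]]) == ([0, 4, 2], [2, 1])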
+ _move_dict(self._args_device) + _move_dict(self._extra_args) + _move_dict(self._extra_none_inputs) + + def set_example_sequence( + self, + input_ids: Sequence[Sequence[int]] = None, + position_ids: Optional[torch.Tensor] = None, + **extra_args, + ) -> None: + """Set an example sequence useful for testing and export purposes without cache history.""" + # use a best guess default for input_ids if not provided + if input_ids is None: + bs, seq_len = min(2, self.max_batch_size), min(4, self.max_seq_len) + input_ids = torch.ones(bs, seq_len, dtype=torch.int).tolist() + + # figure out page assignments + pages_per_seq = [ + len(ids_one_seq) // self.page_size + (len(ids_one_seq) % self.page_size > 0) + for ids_one_seq in input_ids + ] + cache_loc = list(range(sum(pages_per_seq))) + page_assignments = self._get_page_assignments(cache_loc, pages_per_seq) + + self.nest_sequences( + input_ids, + position_ids, # will be auto-inferred if None + input_pos=0, # no cache history + page_assignments=page_assignments, # vanilla page assignments + **extra_args, + ) + + def set_max_num_tokens_sample(self) -> None: + """Set an example sequence with max_num_tokens.""" + # TODO (lucaslie): understand what this implies for extra arguments + seq_len = self.max_num_tokens // self.max_batch_size + input_ids = torch.ones(self.max_batch_size, seq_len, dtype=torch.int).tolist() + self.set_example_sequence(input_ids) + + def set_generate_only_batch(self) -> None: + """Set an example sequence for generate-only batch.""" + self.set_example_sequence([[1]] * self.max_batch_size) def reset(self) -> None: """Reset the sequence information. @@ -358,205 +509,173 @@ class SequenceInfo: After reset the sequence information should correspond to a "generate-only" batch of sequences (b, s==1) without cache history. 
""" - # reset input_pos - self.input_pos.zero_() - self.input_pos_host.zero_() + self.set_generate_only_batch() - # set a dummy sequence corresponding to a generate-only batch (will also reset position_ids) - self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True) + @staticmethod + def _flatten(nested_seqs: Sequence[Sequence[int]]) -> List[int]: + return [ + val + for lst in nested_seqs + for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst) + ] - # reset cache information - self.cache_loc[:] = torch.arange(self.num_pages, dtype=torch.int, device=self.device) - self.pages_per_seq.fill_(1) - - # let's also reset the input_ids and position_ids tensors to their max shapes (max_num_tokens) - self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device) - self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device) - - def set_example_sequence(self) -> None: - """Set an example sequence useful for testing and export purposes.""" - self.reset() - bs, seq_len = min(2, self.max_batch_size), min(4, self.max_seq_len) - input_ids = torch.ones( - bs, - seq_len, - dtype=torch.int, - device=self.device, - ) - self.nest_sequences(input_ids, allow_realloc=True) - - # unflatten if we are not yet using cached+flattened attention - if not self._is_cached_attn: - self.input_ids = self.input_ids.view(bs, seq_len) - self.position_ids = self.position_ids.view(bs, seq_len) - - def _set_max_num_tokens_sample(self) -> None: - """Set an example sequence with max_num_tokens.""" - self.reset() - seq_len = self.max_num_tokens // self.max_batch_size - input_ids = torch.ones( - self.max_batch_size, - seq_len, - dtype=torch.int, - device=self.device, - ) - self.pages_per_seq.fill_(seq_len // self.page_size) - self.nest_sequences(input_ids, allow_realloc=True) - - def set_generate_only_batch(self) -> None: - """Set an example sequence for generate-only batch. - - NOTE: this batch is already formatted as [b, 1] in both original and in the cached attention - mode. So we don't need to do anything mode-specific here. - """ - self.reset() - self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True) - - def maybe_reshape_for_generate(self, tensor: torch.Tensor) -> torch.Tensor: - # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len] - if self.is_generate: - return tensor.view(-1, 1, *tensor.shape[1:]) - else: - return tensor.view(1, -1, *tensor.shape[1:]) - - @nvtx_range("ad_update_position_ids") - def _update_position_ids(self, allow_realloc: bool = False) -> None: - # set new position_ids from input_pos and seq_len - # Make sure this is done on host to avoid host-device copies. - with nvtx_range("prepare_list"): - # Optimize for the common case where all seq_len values are 1 (generation mode) - if torch.all(self.seq_len_host == 1): - # Fast path: when all seq_len are 1, position_ids is just input_pos_host - position_ids_host = ( - self.input_pos_host[: self.num_tokens].to(dtype=torch.long).pin_memory() - ) - else: - # General case - can probably be optimized too, but overall impact will be minor. 
- position_ids_list = [] - for in_pos, seq_len in zip(self.input_pos_host, self.seq_len_host): - position_ids_list.extend(range(in_pos, in_pos + seq_len)) - position_ids_host = torch.tensor( - position_ids_list, dtype=torch.long, pin_memory=True - ) - with nvtx_range("copy_to_device"): - if allow_realloc: - # Create a new position_ids tensor on the device - self.position_ids = position_ids_host.to(self.device).clone() - else: - self.position_ids_full = self.position_ids_full.flatten() - self.position_ids_full[: self.num_tokens].copy_( - position_ids_host, non_blocking=True - ) - with nvtx_range("maybe_reshape"): - self.position_ids = self.maybe_reshape_for_generate( - self.position_ids if allow_realloc else self.position_ids_full[: self.num_tokens] - ) - - @nvtx_range("ad_update_sequence_lengths") - def _update_sequence_lengths(self, sequence_lengths: List[int]) -> None: - self._sequence_lengths = sequence_lengths - self.num_tokens = sum(self._sequence_lengths) - self.seq_len.zero_() - self.seq_len_host = torch.tensor(self._sequence_lengths, pin_memory=True) - self.seq_len[: len(self._sequence_lengths)].copy_(self.seq_len_host, non_blocking=True) - - def update_input_ids_with_new_tokens( - self, new_tokens: torch.Tensor, previous_batch_indices: List[int] + def _store_arg( + self, + name: str, + tnsr_like: List[int | float], + reset: bool = False, ) -> None: - """Update the input_ids with new tokens. + """Store the argument on the host and copy to the device in a non-blocking fashion. - This function will update the input_ids with new tokens and previous batch indices. + Args: + name: Name of the argument to store. + tnsr_like: List of values to store. + reset: Whether to reset the full tensor on the device to 0 before writing to it. """ - # 1) flatten once - original_shape = self.input_ids.shape - flat = self.input_ids.flatten() + with nvtx_range(f"ad_store_seq_info_arg_{name}"): + tnsr_device = self._args_device[name] - # copy indices to the GPU - host_idx = torch.tensor(previous_batch_indices, dtype=torch.int, pin_memory=True) - idx = self.previous_batch_indices_cuda[: len(previous_batch_indices)] - idx.copy_(host_idx, non_blocking=True) + # store list object on the host + self._args_host[name] = tnsr_like.copy() - # gather the exact values you want to write - src = new_tokens[0, idx, 0] + # pin the memory on the host + tnsr_host = torch.tensor(tnsr_like, dtype=tnsr_device.dtype, pin_memory=True) - # in‐place fill every slot where flat == -1 with src, in order - flat.masked_scatter_(flat == -1, src) + # reset/copy to the device in a non-blocking fashion + if reset: + tnsr_device.zero_() + tnsr_device[: len(tnsr_like)].copy_(tnsr_host, non_blocking=True) - # 4) reshape back - self.input_ids = flat.view(original_shape) + def _store_extra_arg( + self, name: str, tnsr_like: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] + ) -> None: + with nvtx_range(f"ad_store_extra_arg_{name}"): + if tnsr_like is not None: + if not isinstance(tnsr_like, torch.Tensor): + if len(tnsr_like) > 1: + tnsr_like = torch.cat(tnsr_like) + else: + tnsr_like = tnsr_like[0] + self._extra_args[name] = tnsr_like.to(self.device, non_blocking=True) + else: + self._extra_args[name] = self._extra_none_inputs[name] @nvtx_range("ad_nest_sequences") def nest_sequences( - self, input_ids: Sequence[Sequence[int]], allow_realloc: bool = False + self, + input_ids: Sequence[Sequence[int]], + position_ids: Optional[Sequence[Sequence[int]]] = None, + input_pos: Optional[Union[Sequence[int], int]] = None, + page_assignments: 
Optional[Sequence[Sequence[int]]] = None, + **extra_args: Dict[str, Union[torch.Tensor, Sequence[torch.Tensor]]], ) -> None: - """Create and store a flattened list of input_ids from the provided list of sequences. + """Create and store sequence information for the next forward pass. - When allow_realloc is True, the input_ids will be reallocated on the device. - This i/f will also update any relevant sequence information. + Args: + input_ids: List of sequences of input_ids. + position_ids: List of sequences of position_ids for each token. + input_pos: Absolute starting position in the cache for each sequence. + page_assignments: List of sequences of page assignments for each sequence. + extra_args: Extra arguments to be stored in the interface. + + This i/f will ensure that all sequence info args are updated accordingly. """ - # set new sequence lengths - self._update_sequence_lengths([len(ids) for ids in input_ids]) + ### UPDATE METADATA ######################################################################## + # update metadata first since it's useful for other updates to have up-to-date information - # We'll preserve the dtype of the input_ids tensor if it is a tensor, otherwise we'll use int - dtype = input_ids.dtype if isinstance(input_ids, torch.Tensor) else torch.int - # set new input_ids as new tensor from flattened input_ids - ids_list = [ - val - for lst in input_ids - for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst) - ] - input_ids_host = torch.tensor(ids_list, dtype=dtype, pin_memory=True) + # set new sequence lengths --> resetting the remaining entries to zero is important to help + # us discern the actual number of sequences in the batch. + self._store_arg("seq_len", [len(ids) for ids in input_ids], reset=True) - if allow_realloc: - self.input_ids = input_ids_host.to(self.device).clone() - else: - self.input_ids_full = self.input_ids_full.flatten() - self.input_ids_full[: self.num_tokens].copy_(input_ids_host, non_blocking=True) + # check for updated input_pos (i.e. 
cache start position) + if input_pos is not None: + self._store_arg( + "input_pos", + [input_pos] * self.num_sequences if isinstance(input_pos, int) else input_pos, + ) - self.input_ids = self.maybe_reshape_for_generate( - self.input_ids if allow_realloc else self.input_ids_full[: self.num_tokens] - ) - # update position_ids - self._update_position_ids(allow_realloc=allow_realloc) + # check for updated page_assignments + if page_assignments is not None: + cache_loc, pages_per_seq = self._get_cache_locations_and_pages_per_sequence( + page_assignments + ) + self._store_arg("cache_loc", cache_loc) + self._store_arg("pages_per_seq", pages_per_seq) + + ### UPDATE MAIN INPUTS ##################################################################### + # set new input_ids and make sure to flatten it + self._store_arg("input_ids", self._flatten(input_ids)) + + # set new position_ids and make sure to flatten it + if position_ids is None: + position_ids = [ + [num for num in range(in_pos, in_pos + seq_len)] + for in_pos, seq_len in zip(self.input_pos, self.seq_len) + ] + self._store_arg("position_ids", self._flatten(position_ids)) + + ### UPDATE EXTRA INPUTS #################################################################### + # go through all extra tensor arguments and update them + for name in self._extra_none_inputs.keys(): + self._store_extra_arg(name, extra_args.pop(name, None)) + assert not extra_args, f"Extra arguments {extra_args.keys()} not found" + + @nvtx_range("ad_rescatter_input_ids") + def rescatter_input_ids( + self, ungathered_input_ids: torch.Tensor, gather_idx: List[int], scatter_ref: int + ): + """Re-scatter the provided ungathered input ids into the input_ids tensor. + + Args: + ungathered_input_ids: The input ids on the device from which to gather. + gather_idx: The list of indices to gather from the ungathered_input_ids. + scatter_ref: The reference index to scatter to in input_ids via masked scatter. + + Returns: + None + + This function will assume that we are in a generate-only batch. + """ + # store the new gather indices + self._store_arg("_gather_idx", gather_idx) + + # gather the provided input ids in a streaming fashion + gather_ids_device = self._args_device["_gather_idx"][: len(gather_idx)] + packed_input_ids = ungathered_input_ids[gather_ids_device] + + # re-scatter the provided input ids into the input_ids tensor + input_ids_device = self._args_device["input_ids"] + input_ids_device.masked_scatter_(input_ids_device == scatter_ref, packed_input_ids) @nvtx_range("ad_unnest_sequences") def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]: t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0) - return list(torch.split(t_squeezed, self.sequence_lengths)) + return list(torch.split(t_squeezed, self.seq_len)) - @nvtx_range("ad_update_pos") - def update_pos(self, seq_len: Union[torch.Tensor, List[int], int], reset: bool = False) -> None: - """Update the starting position for each sequence in the cache. + def add_extra_arg( + self, + name: str, + none_input: torch.Tensor, + dynamic_shape_callback: Optional[DynamicShapeCallback] = None, + ) -> None: + """Add an extra argument to the sequence info object. - If ``reset=True`, ``input_pos`` will be reset to zero before updating. + Args: + name: The name of the extra argument. + none_input: None input value of the extra argument. + dynamic_shape_callback: The callback to get the dynamic shape of the extra argument. + + Note that the extra argument is expected to be a tensor. 
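Example (hypothetical multimodal argument; the name, shape, and ``Dim`` bound below are illustrative only):

        .. code-block:: python

            seq_info.add_extra_arg(
                "pixel_values",
                none_input=torch.zeros(1, 3, 336, 336),
                dynamic_shape_callback=lambda: {0: Dim("num_images", max=8)},
            )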
""" - if not isinstance(seq_len, torch.Tensor): - seq_len = torch.tensor(seq_len, dtype=torch.int, pin_memory=True) - bs = len(seq_len) if seq_len.dim() > 0 else self.max_batch_size + assert name not in self._named_args().keys(), f"Extra argument {name} already exists" - if reset: - self.input_pos_host[:bs].copy_(seq_len, non_blocking=True) + self._extra_args[name] = none_input.to(self.device) + self._extra_none_inputs[name] = none_input.to(self.device) + + if dynamic_shape_callback is None: + self._extra_dynamic_shapes_callbacks[name] = lambda: {} else: - self.input_pos_host[:bs] += seq_len - - # update position_ids - self._update_position_ids() - self.input_pos[:bs].copy_(self.input_pos_host[:bs], non_blocking=True) - - @nvtx_range("ad_assign_cache_loc") - def assign_cache_loc(self, page_assignments: Sequence[Sequence[int]]) -> None: - """Set the cache location and pages_per_seq tensors from page assignments.""" - cache_loc_flat = torch.tensor( - [p_idx for pages in page_assignments for p_idx in pages], - dtype=torch.int, - pin_memory=True, - ) - self.cache_loc[: len(cache_loc_flat)].copy_(cache_loc_flat, non_blocking=True) - - pages_per_seq = torch.tensor( - [len(p) for p in page_assignments], dtype=torch.int, pin_memory=True - ) - self.pages_per_seq[: len(pages_per_seq)].copy_(pages_per_seq, non_blocking=True) + self._extra_dynamic_shapes_callbacks[name] = dynamic_shape_callback Constant = Union[int, float, str, None] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py index 89dc59f635..b55bbe6bfd 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py @@ -63,6 +63,7 @@ def scaled_dot_product_attention( dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, + enable_gqa: bool = False, ) -> torch.Tensor: """A carbon copy of torch.nn.functional.scaled_dot_product_attention as custom op. @@ -78,12 +79,13 @@ def scaled_dot_product_attention( dropout_p=dropout_p, is_causal=is_causal, scale=scale, + enable_gqa=enable_gqa, ) @scaled_dot_product_attention.register_fake def scaled_dot_product_attention_fake( - query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False ): """Fake implementation of scaled_dot_product_attention.""" return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous() diff --git a/tensorrt_llm/_torch/auto_deploy/export/export.py b/tensorrt_llm/_torch/auto_deploy/export/export.py index 475017a284..9d9af3cf9e 100644 --- a/tensorrt_llm/_torch/auto_deploy/export/export.py +++ b/tensorrt_llm/_torch/auto_deploy/export/export.py @@ -18,7 +18,7 @@ from ..transformations._graph import ( ) from ..utils.logger import ad_logger from ..utils.node_utils import is_op -from .interface import ExportPatchRegistry, apply_export_patches +from .interface import apply_export_patches try: from modelopt.torch.quantization.utils import export_torch_mode as torch_export_context @@ -229,20 +229,9 @@ def torch_export_to_gm( patch_list: Optional list of patch names to apply with default settings. Cannot be used together with patch_configs. """ - # Validate that both patch_configs and patch_list are not provided simultaneously - if patch_configs is not None and patch_list is not None: - raise ValueError("Cannot specify both patch_configs and patch_list. 
Use only one.") - - # Handle patch configuration - if patch_list is not None: - # Convert patch_list to patch_configs format - patch_configs = {patch_name: {} for patch_name in patch_list} - elif patch_configs is None: - # Default patch configurations - apply all registered patches with default settings - patch_configs = {patch_name: {} for patch_name in ExportPatchRegistry.list_patches()} # run export with patches and lifted to meta - with apply_export_patches(patch_configs), lift_to_meta(model) as state_dict: + with apply_export_patches(patch_configs, patch_list), lift_to_meta(model) as state_dict: # clean up args, kwargs and move to correct device args, kwargs = tree_to((args, kwargs or {}), device="meta") diff --git a/tensorrt_llm/_torch/auto_deploy/export/interface.py b/tensorrt_llm/_torch/auto_deploy/export/interface.py index c97b056a00..db0cbbd94d 100644 --- a/tensorrt_llm/_torch/auto_deploy/export/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/export/interface.py @@ -5,7 +5,7 @@ This module defines the base classes and interfaces for all export patches. from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Any, Callable, Dict, List, Type, Union, final +from typing import Any, Callable, Dict, List, Optional, Type, Union, final from pydantic import BaseModel, Field @@ -183,6 +183,8 @@ class ExportPatchRegistry: @classmethod def get(cls, name: str) -> Type[BaseExportPatch]: """Get a patch class by name.""" + if not cls.has(name): + raise ValueError(f"Unknown patch: {name}") return cls._registry[name] @classmethod @@ -212,20 +214,29 @@ class ExportPatchRegistry: @contextmanager -def apply_export_patches(patch_configs: Dict[str, Union[ExportPatchConfig, Dict[str, Any]]]): +def apply_export_patches( + patch_configs: Optional[Dict[str, Union[ExportPatchConfig, Dict[str, Any]]]] = None, + patch_list: Optional[List[str]] = None, +): """Context manager to apply multiple patches. Args: patch_configs: Dict mapping patch names to their configurations. """ - patches = [] + # Validate that both patch_configs and patch_list are not provided simultaneously + if patch_configs is not None and patch_list is not None: + raise ValueError("Cannot specify both patch_configs and patch_list. 
Use only one.") + + # Handle patch configuration + if patch_list is not None: + # Convert patch_list to patch_configs format + patch_configs = {patch_name: {} for patch_name in patch_list} + elif patch_configs is None: + # Default patch configurations - apply all registered patches with default settings + patch_configs = {patch_name: {} for patch_name in ExportPatchRegistry.list_patches()} # Create patch instances - for name, config in patch_configs.items(): - if not ExportPatchRegistry.has(name): - raise ValueError(f"Unknown patch: {name}") - patch = ExportPatchRegistry.create_patch(name, config) - patches.append(patch) + patches = [ExportPatchRegistry.create_patch(k, conf) for k, conf in patch_configs.items()] # Apply patches using nested context managers if not patches: diff --git a/tensorrt_llm/_torch/auto_deploy/llm.py b/tensorrt_llm/_torch/auto_deploy/llm.py index 999a024fb3..0076489957 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm.py +++ b/tensorrt_llm/_torch/auto_deploy/llm.py @@ -1,19 +1,92 @@ import types -from typing import List, Optional +from typing import Any, Dict, List, Optional, Tuple from ...executor.result import CompletionOutput -from ...inputs.registry import create_input_processor +from ...inputs.registry import DefaultInputProcessor, ExtraProcessedInputs from ...llmapi.llm import RequestOutput, _TorchLLM -from ...llmapi.tokenizer import TokenizerBase, tokenizer_factory +from ...llmapi.tokenizer import TokenizerBase, TransformersTokenizer, tokenizer_factory +from ...sampling_params import SamplingParams from .distributed import common as dist_ad from .llm_args import LlmArgs +from .models.factory import ModelFactory from .shim.demollm import DemoGenerationExecutor +class ADInputProcessor(DefaultInputProcessor): + """Input processor for AutoDeploy backend. + + This is a wrapper to either support standard TRT-LLM text-only input processing or use HF's + message chat template system to process multimodal inputs. + """ + + def __init__(self, tokenizer: Optional[TokenizerBase], processor: Optional[Any] = None): + super().__init__(model_path=None, model_config=None, tokenizer=tokenizer) + # NOTE: HF's tokenizer/processor that has the apply_chat_template method + self.processor = processor or getattr(tokenizer, "tokenizer", None) + + def __call__( + self, inputs: Dict[str, Any], sampling_params: SamplingParams + ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: + if self.processor is None: + raise ValueError("processor is required to tokenize inputs") + + # construct kwargs to reflect DefaultInputProcessor + kwargs = { + "add_special_tokens": sampling_params.add_special_tokens, + } + if sampling_params.truncate_prompt_tokens is not None: + kwargs = { + "truncation": True, + "max_length": sampling_params.truncate_prompt_tokens, + } + # check for messages field and if yes, use the apply_chat_template method + if "messages" in inputs: + # TODO: we don't really need this but it makes for a good sanity check. Consider + # removing this in the future if we need to speed things up. + prompt = self.processor.apply_chat_template( + inputs["messages"], + add_generation_prompt=True, + tokenize=False, + ) + inputs["prompt"] = prompt + + all_args = self.processor.apply_chat_template( + inputs["messages"], + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding=False, # there shouldn't be a need for padding ever... + return_attention_mask=False, + **kwargs, + ) + # TODO: is there a more reliable way to avoid the attention_mask here? 
+ all_args.pop("attention_mask", None) + + # TODO: can we avoid the extra tolist() here eventually? + token_ids = all_args.pop("input_ids") + assert token_ids.shape[0] == 1, "messages should be unbatched at this point." + if all_args: + extra_processed_inputs = {"multimodal_data": all_args} + else: + extra_processed_inputs = None + return token_ids[0].tolist(), extra_processed_inputs + else: + token_ids = self.tokenizer.encode(inputs["prompt"], **kwargs) + return token_ids, None + + class LLM(_TorchLLM): """LLM class is the main class for running an LLM model using AutoDeploy backend.""" args: LlmArgs + _factory: ModelFactory + + @property + def factory(self) -> ModelFactory: + if not getattr(self, "_factory", None): + self._factory = self.args.create_factory() + return self._factory def __init__(self, *args, **kwargs): kwargs["backend"] = "_autodeploy" @@ -23,16 +96,18 @@ class LLM(_TorchLLM): if self.args.skip_tokenizer_init: return None - factory = self.args.create_factory() - return tokenizer_factory(factory.init_tokenizer()) + return tokenizer_factory(self.factory.init_tokenizer()) def _validate_args_for_torch_backend(self, kwargs: dict) -> None: """We don't need to validate args for AutoDeploy backend for now.""" pass + def _create_input_processor(self) -> ADInputProcessor: + return ADInputProcessor(self.tokenizer, self.factory.init_processor()) + def _prefetch_model(self): """Prefetch the model for the LLM.""" - self.args.create_factory().prefetch_checkpoint() + self.factory.prefetch_checkpoint() def _build_model(self): """Build the model for the LLM. @@ -47,6 +122,11 @@ class LLM(_TorchLLM): # _autodeploy backend. super()._build_model() + # now correct input processor + assert isinstance(self.input_processor, DefaultInputProcessor) + assert self.tokenizer is None or isinstance(self.tokenizer, TransformersTokenizer) + self.input_processor = self._create_input_processor() + class DemoLLM(LLM): """A simple LLM class to demo the LLM interface while debugging the e2e workflow. @@ -63,7 +143,7 @@ class DemoLLM(LLM): # prefetch model and load tokenizer self._prefetch_model() self._tokenizer = self._try_load_tokenizer() - self.input_processor = create_input_processor(None, self.tokenizer) + self.input_processor = self._create_input_processor() # construct demo executor + engine self._executor = DemoGenerationExecutor( diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index 2384d5c953..dbe5a857bd 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -57,7 +57,7 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings): description="The path to the model checkpoint or the model name from the Hugging Face Hub." 
) - model_factory: Literal["AutoModelForCausalLM", "AutoModelForImageTextToText"] = Field( + model_factory: str = Field( default="AutoModelForCausalLM", description="The model factory to use for loading the model.", ) diff --git a/tensorrt_llm/_torch/auto_deploy/models/factory.py b/tensorrt_llm/_torch/auto_deploy/models/factory.py index 4cf0a093ee..8e19ea0ed1 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/factory.py +++ b/tensorrt_llm/_torch/auto_deploy/models/factory.py @@ -3,13 +3,13 @@ import copy from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Callable, Dict, Optional, Type +from typing import Any, Callable, Dict, Optional, Tuple, Type import torch import torch.nn as nn from torch._prims_common import DeviceLikeType -from ..custom_ops.attention_interface import CacheConfig +from ..custom_ops.attention_interface import CacheConfig, DynamicShapeCallback from ..utils.logger import ad_logger @@ -25,7 +25,8 @@ class ModelFactory(ABC): NOTE: we make the assumption that the model can be prompted with a set of input_ids and position_ids of shape (batch_size, seq_len) and will return a tensor of shape - (batch_size, seq_len, embedding_size). + (batch_size, seq_len, embedding_size) by default. Individual factories have the ability to + define additional optional inputs and their (dynamic) shapes. """ def __init__( @@ -74,7 +75,7 @@ class ModelFactory(ABC): .. code-block:: python def forward( - self, input_ids: torch.Tensor, position_ids: torch.Tensor + self, input_ids: torch.Tensor, position_ids: torch.Tensor, *extra_args: torch.Tensor ) -> Sequence[torch.Tensor]: ... ``logits`` are assumed to be the first output of the model, i.e., @@ -87,6 +88,9 @@ class ModelFactory(ABC): input_ids.shape == (batch_size, seq_len) position_ids.shape == (batch_size, seq_len) logits.shape == (batch_size, seq_len, vocab_size) + + We allow for additional arguments to be passed to the model's forward function as defined by + the factory. """ # make sure model architecture is pre-fetched (no weights needed at this point) skip_loading_weights = self.skip_loading_weights @@ -127,6 +131,15 @@ class ModelFactory(ABC): """ return None + def init_processor(self) -> Optional[Any]: + """Initialize the (multi-modal) processor for the model. + + Returns: + The initialized processor for the model. If the processor is not available, then this + method should return None. + """ + return None + def prefetch_checkpoint(self, force: bool = False): """Try or skip prefetching the checkpoint for the model and tokenizer. @@ -220,6 +233,35 @@ class ModelFactory(ABC): device: The device to load the model on. """ + def get_example_inputs(self) -> Dict[str, torch.Tensor]: + """Return a dictionary of example inputs for the model. + + This function can be overwritten by a factory when it requires a specific example input to + in order to run through export. + + Returns: + A dictionary of example inputs for the model where the key corresponds to the argument + name and the value corresponds to the example input. + """ + return {} + + def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, Optional[DynamicShapeCallback]]]: + """Return a dictionary of extra model inputs that behave like optional forward arguments. 
+ + Returns: + A dictionary of extra inputs for the model where the key corresponds to the argument + name and the value corresponds to a tuple of (none_input, dynamic_shape_callback): + - `none_input`: The none input value of the extra input indicating the tensor + value corresponding to the equivalent of the None input. `None` is not supported + as we require the input to be a tensor. Hence, this none_input acts as a + placeholder for the None input. We assume that the "optional" behavior of these + arguments can be represented via a placeholder tensor and and an appropriate + check within the forward function using ``torch.cond``. + - `dynamic_shape_callback`: A function that returns the dynamic shape of the extra + input. Simply set to `None` if the extra input is not dynamic. + """ + return {} + class ModelFactoryRegistry: _registry: Dict[str, Type[ModelFactory]] = {} diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py index 0a35690c68..5c3942f082 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/hf.py +++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py @@ -3,7 +3,7 @@ import os import types from contextlib import contextmanager, nullcontext -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -11,11 +11,13 @@ from accelerate import init_empty_weights, load_checkpoint_in_model from accelerate.utils import modeling from huggingface_hub import HfApi, snapshot_download from huggingface_hub.utils import HFValidationError, filter_repo_objects, validate_repo_id +from PIL import Image from torch._prims_common import DeviceLikeType from transformers import ( AutoConfig, AutoModelForCausalLM, AutoModelForImageTextToText, + AutoProcessor, AutoTokenizer, PretrainedConfig, ) @@ -26,7 +28,7 @@ from transformers.utils import ( WEIGHTS_NAME, ) -from ..custom_ops.attention_interface import CacheConfig +from ..custom_ops.attention_interface import CacheConfig, Dim, DynamicShapeCallback from ..utils._config import deep_merge_dicts from ..utils.logger import ad_logger from .factory import ModelFactory, ModelFactoryRegistry, ShardingConfigSource @@ -101,10 +103,6 @@ class AutoModelForCausalLMFactory(ModelFactory): def autoconfig_from_pretrained(self): return AutoConfig.from_pretrained - @property - def autotokenizer_from_pretrained(self): - return AutoTokenizer.from_pretrained - # TODO (@lucaslie): Do we ever want to switch to from_pretrained? @property def automodel_from_config(self): @@ -114,7 +112,9 @@ class AutoModelForCausalLMFactory(ModelFactory): def _simple_forward(model: nn.Module, input_ids: torch.Tensor, position_ids: torch.Tensor): """A simple forward pass for the model to functionalize the args. - This follows the standard function signature as expected by factory.py. + This follows the standard function signature as expected by factory.py. We do _not_ use the + model.forward method directly to create the patch. Instead we use the type of the model to + get the forward method to keep the patch composable with other forward patches. 
""" return type(model).forward(model, input_ids=input_ids, position_ids=position_ids) @@ -158,11 +158,13 @@ class AutoModelForCausalLMFactory(ModelFactory): with (init_empty_weights if device == "meta" else nullcontext)(): model = self.automodel_from_config(model_config, trust_remote_code=True) - - # post-init --> this must be called explicitly for HF models the way we initialize them since - # this "gets lost" with the init_empty_weights context manager. - if hasattr(model, "post_init"): - model.post_init() + if device == "meta": + # post-init --> this must be called explicitly for HF models the way we initialize them + # since this "gets lost" with the init_empty_weights context manager. + if hasattr(model, "post_init"): + model.post_init() + else: + model.to(device) # if present, initialize sharding config. We need head_dim for colwise sharding. self._set_sharding_config(model.config) @@ -211,7 +213,7 @@ class AutoModelForCausalLMFactory(ModelFactory): """Initialize the tokenizer—either a custom name or the model's default.""" if self.tokenizer is None: return None - return self.autotokenizer_from_pretrained(self.tokenizer, **self.tokenizer_kwargs) + return AutoTokenizer.from_pretrained(self.tokenizer, **self.tokenizer_kwargs) @staticmethod def _get_ignore_patterns(repo_id: str, skip_prefetch_weights: bool) -> List[str]: @@ -375,3 +377,102 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory): @property def automodel_from_config(self): return AutoModelForImageTextToText.from_config + + def init_tokenizer(self) -> Optional[Any]: + """Initialize the tokenizer—either a custom name or the model's default.""" + processor = self.init_processor() + if processor is None: + return None + return processor.tokenizer + + def init_processor(self) -> Optional[Any]: + """Initialize the processor for the model.""" + if self.tokenizer is None: + return None + return AutoProcessor.from_pretrained(self.tokenizer, **self.tokenizer_kwargs) + + @staticmethod + def _simple_forward( + model: nn.Module, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + pixel_values: torch.Tensor, + ): + """A simple forward pass for the model to functionalize the args. + + This follows the standard function signature as expected by factory.py. We do _not_ use the + model.forward method directly to create the patch. Instead we use the type of the model to + get the forward method to keep the patch composable with other forward patches. 
+ """ + return type(model).forward( + model, + input_ids=input_ids, + position_ids=position_ids, + pixel_values=pixel_values, + ) + + def get_example_inputs(self) -> Dict[str, torch.Tensor]: + """Return a dictionary of example inputs for the model.""" + + def _prep_seq(text, img1, img2): + return [ + { + "role": "user", + "content": [ + {"type": "image", "image": img1}, + {"type": "image", "image": img2}, + {"type": "text", "text": text}, + ], + } + ] + + # Create a batch of conversations (batch_size = 2) + batch_messages = [ + _prep_seq( + "Describe what you see in the two images and their differences.", + Image.new("RGB", (16, 16), color=(128, 128, 128)), + Image.new("RGB", (16, 16), color=(64, 64, 64)), + ), + _prep_seq( + "What are the main differences between these two images?", + Image.new("RGB", (16, 16), color=(255, 0, 0)), + Image.new("RGB", (16, 16), color=(0, 255, 0)), + ), + ] + + processor = AutoProcessor.from_pretrained(self.tokenizer, **self.tokenizer_kwargs) + inputs = processor.apply_chat_template( + batch_messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding=True, + return_attention_mask=False, + ) + + return { + "input_ids": inputs["input_ids"], + "pixel_values": inputs["pixel_values"], + } + + def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, Optional[DynamicShapeCallback]]]: + """Return a dictionary of extra inputs for the model. + + Returns: + A dictionary of extra inputs for the model where the key corresponds to the argument + name and the value corresponds to a tuple of (example_input, dynamic_shape_callback). + The dynamic shape callback is a function that returns the dynamic shape of the extra + input. Simply set to `None` if the extra input is not dynamic. + """ + + def _get_dynamic_shape(): + return { + # TODO (lucaslie): how to set default values for dynamic shapes? 
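+                # keys index into the dims of `pixel_values` (batch, channels, height, width);
+                # dim 1 (channels) stays static, the remaining dims are exported as dynamic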
+ 0: Dim("img_batch_size", max=10), + 2: Dim("img_height", min=32, max=2048), + 3: Dim("img_width", min=32, max=2048), + } + + none_pixel_values = torch.zeros(0, 3, 336, 336) + return {"pixel_values": (none_pixel_values, _get_dynamic_shape)} diff --git a/tests/unittest/_torch/auto_deploy/integration/test_llama4_vlm_export.py b/tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py similarity index 55% rename from tests/unittest/_torch/auto_deploy/integration/test_llama4_vlm_export.py rename to tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py index 596b7ff50d..239cdf35af 100644 --- a/tests/unittest/_torch/auto_deploy/integration/test_llama4_vlm_export.py +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py @@ -1,15 +1,13 @@ +"""A patch to handle vision branch in Llama4ForConditionalGeneration.""" + from typing import List, Optional, Tuple, Union import torch import torch.nn as nn -from _model_test_utils import _hf_model_dir_or_hub_id -from PIL import Image -from transformers import AutoConfig, AutoProcessor, Llama4ForConditionalGeneration -from transformers.models.llama4.modeling_llama4 import Llama4CausalLMOutputWithPast -from utils.llm_data import llm_models_root +from transformers import Llama4ForConditionalGeneration +from transformers.models.llama4.modeling_llama4 import Llama4CausalLMOutputWithPast, Llama4TextMoe -from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm -from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device +from ...export.interface import BaseExportPatch, ExportPatchRegistry # Copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4/modeling_llama4.py#L1651 @@ -76,30 +74,34 @@ def _forward_with_cond( vision_feature_select_strategy=vision_feature_select_strategy, image_sizes=None, ) - original_inputs_embeds_shape = inputs_embeds.shape vision_flat = image_features.view(-1, image_features.size(-1)) - projected_vision_flat = self.multi_modal_projector(vision_flat) + projected_vision_flat = self.multi_modal_projector(vision_flat).to( + inputs_embeds.device, inputs_embeds.dtype + ) + # NOTE: get_placeholder_mask is not supported by torch.export due to numel check ########### + # special_image_mask = self.get_placeholder_mask( + # input_ids, inputs_embeds=inputs_embeds, image_features=projected_vision_flat + # ) + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor( + self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device + ) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) - final_mask = special_image_mask.to(inputs_embeds.device) - inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1)) + # n_image_tokens = special_image_mask.sum() + special_image_mask = ( + special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + ) + ### END OF get_placeholder_mask ############################################################ - final_mask_1d = final_mask[..., 0].reshape(-1) - # num_tokens_to_fill = final_mask_1d.sum() + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, projected_vision_flat) - # This condition statement breaks torch.export: - # TODO: sanity check on the inputs for this - # if num_tokens_to_fill != projected_vision_flat.size(0): - # raise ValueError( - # f"Mismatch: final_mask wants 
{num_tokens_to_fill} embeddings, " - # f"but multi_modal_projector returned {projected_vision_flat.size(0)}" - # ) - - expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1)) - inputs_embeds.masked_scatter_(expanded_mask, projected_vision_flat) - - return inputs_embeds.view(original_inputs_embeds_shape) + return inputs_embeds def _no_vision_branch(inputs_embeds, pixel_values, input_ids): return inputs_embeds @@ -132,7 +134,10 @@ def _forward_with_cond( loss = None if labels is not None: + # Shift so that tokens < n predict n if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) shift_logits = logits[..., :-1, :][ shift_attention_mask.to(logits.device) != 0 @@ -141,6 +146,7 @@ def _forward_with_cond( else: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens loss_fct = nn.CrossEntropyLoss() loss = loss_fct( shift_logits.view(-1, shift_logits.size(-1)), @@ -161,81 +167,65 @@ def _forward_with_cond( ) -def test_build_run_llama4_vlm(): - atol = 1e-3 - rtol = 1e-3 +@ExportPatchRegistry.register("hf_llama4_vision") +class Llama4VisionPatch(BaseExportPatch): + """Patch for Llama4ForConditionalGeneration to make it compatible with torch.export. - model_id = _hf_model_dir_or_hub_id( - f"{llm_models_root()}/Llama-4-Scout-17B-16E-Instruct", - "meta-llama/Llama-4-Scout-17B-16E-Instruct", - ) - processor = AutoProcessor.from_pretrained(model_id) + This patch replaces the forward method of Llama4ForConditionalGeneration with + a version that uses the torch.cond to handle the optional vision branch. 
+ """ - config = AutoConfig.from_pretrained(model_id) - config.text_config.num_hidden_layers = 2 - config.text_config.intermediate_size = 64 - config.text_config.intermediate_size_mlp = 128 - config.vision_config.num_hidden_layers = 2 - - # The returned cache breaks torch.export - config.text_config.use_cache = False - - model = Llama4ForConditionalGeneration(config).eval().to("cuda").bfloat16() - - img1 = Image.new("RGB", (16, 16), color=(128, 128, 128)) - img2 = Image.new("RGB", (16, 16), color=(64, 64, 64)) - messages = [ - { - "role": "user", - "content": [ - {"type": "image", "image": img1}, - {"type": "image", "image": img2}, - {"type": "text", "text": "What's the difference?"}, - ], - }, - ] - - inputs = ( - processor.apply_chat_template( - messages, - add_generation_prompt=True, - tokenize=True, - return_dict=True, - return_tensors="pt", + def _apply_patch(self): + """Apply the Llama4 vision patch.""" + # Store original forward method + self.original_values["Llama4ForConditionalGeneration.forward"] = ( + Llama4ForConditionalGeneration.forward ) - .to(model.device) - .to(torch.bfloat16) - ) - with torch.inference_mode(): - # the original model queried with text-only - out_text_only = model(inputs["input_ids"], None, inputs["attention_mask"]) + # Apply patch by replacing the forward method + Llama4ForConditionalGeneration.forward = _forward_with_cond - Llama4ForConditionalGeneration.forward = _forward_with_cond + def _revert_patch(self): + """Revert the Llama4 vision patch.""" + # Restore original forward method + Llama4ForConditionalGeneration.forward = self.original_values[ + "Llama4ForConditionalGeneration.forward" + ] - with torch.inference_mode(): - out_real = model(inputs["input_ids"], inputs["pixel_values"], inputs["attention_mask"]) - out_dummy = model( - inputs["input_ids"], torch.zeros_like(inputs["pixel_values"]), inputs["attention_mask"] - ) - torch.testing.assert_close(out_dummy.logits, out_text_only.logits, rtol=rtol, atol=atol) - gm = torch_export_to_gm( - model, - (inputs["input_ids"], inputs["pixel_values"], inputs["attention_mask"]), - kwargs={}, - ) - move_to_device(gm, model.device) +def _moe_forward_with_transpose(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_scores, router_logits = self.router(hidden_states) + routed_in = hidden_states.repeat(router_scores.shape[1], 1) - with torch.inference_mode(): - out_real_gm = gm(inputs["input_ids"], inputs["pixel_values"], inputs["attention_mask"]) - torch.testing.assert_close(out_real.logits, out_real_gm.logits, rtol=rtol, atol=atol) - out_dummy_gm = gm( - inputs["input_ids"], torch.zeros_like(inputs["pixel_values"]), inputs["attention_mask"] - ) - torch.testing.assert_close(out_dummy.logits, out_dummy_gm.logits, rtol=rtol, atol=atol) - torch.testing.assert_close(out_dummy_gm.logits, out_text_only.logits, rtol=rtol, atol=atol) + # BUG IN ORIGINAL CODE + # routed_in = routed_in * router_scores.reshape(-1, 1) + # END OF BUG IN ORIGINAL CODE - assert not torch.allclose(out_real.logits, out_dummy.logits, rtol=rtol, atol=atol), ( - "Expected outputs to differ between text only input and text+image input" - ) + # PATCH STARTED + routed_in = routed_in * router_scores.transpose(0, 1).reshape(-1, 1) + # PATCH ENDED + + routed_out = self.experts(routed_in) + out = self.shared_expert(hidden_states) + out.add_(routed_out.reshape(router_scores.shape[1], -1, routed_out.shape[-1]).sum(dim=0)) + return out, router_logits + + +# TODO: remove this patch once 
https://github.com/huggingface/transformers/pull/40609 is merged, +# gets released, and TRT-LLM updates to the relevant transformers version +@ExportPatchRegistry.register("hf_llama4_moe") +class Llama4MoEPatch(BaseExportPatch): + """Patch for Llama4 MoE routing to fix its current accuracy issue.""" + + def _apply_patch(self): + """Apply the Llama4 MoE routing patch.""" + # Store original forward method + self.original_values["Llama4TextMoe.forward"] = Llama4TextMoe.forward + + # Apply patch by replacing the forward method + Llama4TextMoe.forward = _moe_forward_with_transpose + + def _revert_patch(self): + """Revert the Llama4 MoE routing patch.""" + Llama4TextMoe.forward = self.original_values["Llama4TextMoe.forward"] diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index a9ea768e27..d7e82b9838 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -1,6 +1,6 @@ -from itertools import chain +from collections import defaultdict from types import SimpleNamespace -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import torch from torch._prims_common import DeviceLikeType @@ -105,13 +105,19 @@ class ADEngine(ModelEngine): max_batch_size=max_batch_size, page_size=attn_page_size, max_num_tokens=max_num_tokens, - device=device, ) + factory = ad_config.create_factory() + + # pass in extra arguments defined by the model factory + for name, (none_input, dynamic_shape_callback) in factory.get_extra_inputs().items(): + seq_info.add_extra_arg(name, none_input, dynamic_shape_callback) + + # TODO (lucaslie): consider how we move args around InferenceOptimizer.__init__, + # ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm. 
+ # construct inference optimizer - build_and_optimize = InferenceOptimizer( - factory=ad_config.create_factory(), ad_config=ad_config - ) + build_and_optimize = InferenceOptimizer(factory=factory, ad_config=ad_config) # construct engine return cls(build_and_optimize, seq_info, device, max_beam_width) @@ -176,7 +182,11 @@ class ADEngine(ModelEngine): input_pos: List[int] = [] last_logit_only: List[bool] = [] page_assignments: List[List[int]] = [] - previous_batch_indices: List[int] = [] + flat_gather_idx: List[int] = [] + extra_args: Dict[str, List[torch.Tensor]] = defaultdict(list) + + dummy_token = -1 + # look at context requests first for request in context_requests: # store input ids and pos of first token in sequence @@ -186,6 +196,15 @@ class ADEngine(ModelEngine): request.py_batch_idx = request.seq_slot last_logit_only.append(True) + # get cache indices + cache_indices = kv_cache_manager.get_cache_indices(request) + page_assignments.append(cache_indices) + + # store extra arguments + if request.py_multimodal_data is not None: + for k, v in request.py_multimodal_data.items(): + extra_args[k].append(v) + # look at generate requests next # TODO: we should also handle extend requests (for speculative decoding) here for request in gen_requests: @@ -194,30 +213,34 @@ class ADEngine(ModelEngine): input_ids.append([request.get_token(0, request.get_num_tokens(0) - 1)]) input_pos.append(request.max_beam_num_tokens - 1) else: - # insert a dummy token to indicate the new tokens - input_ids.append([-1]) - previous_batch_indices.append(request.py_batch_idx) + input_ids.append([dummy_token]) input_pos.append(request.max_beam_num_tokens) + flat_gather_idx.append(request.py_batch_idx) request.py_batch_idx = request.seq_slot # return all logits last_logit_only.append(False) - # extract cache information for all requests - for request in chain(context_requests, gen_requests): # get cache indices cache_indices = kv_cache_manager.get_cache_indices(request) page_assignments.append(cache_indices) # update the sequence info object now - si = self.cache_seq_interface.info - si.update_pos(input_pos, reset=True) - si.assign_cache_loc(page_assignments) - si.nest_sequences(input_ids) - + self.cache_seq_interface.info.nest_sequences( + input_ids, + input_pos=input_pos, + page_assignments=page_assignments, + **extra_args, + ) + # scatter the new tokens into the input_ids tensor if provided if new_tokens is not None: - si.update_input_ids_with_new_tokens(new_tokens, previous_batch_indices) + self.cache_seq_interface.info.rescatter_input_ids( + ungathered_input_ids=new_tokens.flatten(), # ensure it's flattened + gather_idx=flat_gather_idx, + scatter_ref=dummy_token, + ) + return last_logit_only @nvtx_range("ad_compute_logits") diff --git a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py index c29cb5fbd7..fb374f1e94 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py @@ -1,8 +1,9 @@ """A demo LLM api to for debugging and testing purposes of e2e workflows.""" import gc +from collections import defaultdict from queue import Empty -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import torch.multiprocessing as mp @@ -10,6 +11,7 @@ import torch.multiprocessing as mp from ....executor import GenerationExecutor from ....executor.request import GenerationRequest from ....executor.result import CompletionOutput, GenerationResult 
+from ....inputs.multimodal import MultimodalParams from ....sampling_params import SamplingParams from ...pyexecutor.sampler import greedy_search_sampling_batch, top_k_sampling_batch from ..distributed import common as dist_ad @@ -34,8 +36,11 @@ class DemoEngine(ADEngine): self.queue = mp.Queue() @torch.inference_mode() - def __call__(self, requests: GenerationRequest) -> mp.Queue: + def __call__( + self, requests: GenerationRequest, multimodal_params: Optional[MultimodalParams] + ) -> mp.Queue: """Generate tokens and put the results in a queue and return the queue.""" + requests.multimodal_params = multimodal_params output = self.generate_tokens_batched([requests])[0] self.queue.put(output) return self.queue @@ -45,7 +50,7 @@ class DemoEngine(ADEngine): self.queue.close() self.queue.join_thread() - def _assign_pages(self) -> List[List[int]]: + def _assign_pages(self, total_lens: List[int]) -> List[List[int]]: """A simple heuristic to assign pages based on current sequence info. In a nutshell, we will look at the following information to update the page assignments: @@ -67,7 +72,6 @@ class DemoEngine(ADEngine): unassigned page if needed. """ si = self.cache_seq_interface.info - total_lens = [s_l + i_p for s_l, i_p in zip(si.sequence_lengths, si.input_positions)] page_assignments = si.page_assignments free_pages = set(range(si.num_pages)) - {i for pages in page_assignments for i in pages} @@ -76,7 +80,7 @@ class DemoEngine(ADEngine): extra_tokens = t_l - len(pages) * si.page_size num_extra_pages = (extra_tokens // si.page_size) + (extra_tokens > 0) updated_assignments.append(pages + [free_pages.pop() for _ in range(num_extra_pages)]) - si.assign_cache_loc(updated_assignments) + return updated_assignments def generate_tokens_batched( self, requests: List[GenerationRequest] @@ -91,10 +95,27 @@ class DemoEngine(ADEngine): ) assert sampling_params.best_of == 1, "Best-of is not supported." 
- # set up sequence info object + # set up sequence info object for decode phase sequence_info = self.cache_seq_interface.info + + input_ids = [] + total_lens = [] + extra_args: Dict[str, List[torch.Tensor]] = defaultdict(list) + + for request in requests: + total_lens.append(len(request.prompt_token_ids)) + input_ids.append(request.prompt_token_ids) + if request.multimodal_params is not None: + for k, v in request.multimodal_params.multimodal_data.items(): + extra_args[k].append(v) + sequence_info.reset() - sequence_info.nest_sequences([r.prompt_token_ids for r in requests]) + sequence_info.nest_sequences( + input_ids=input_ids, + input_pos=0, + page_assignments=self._assign_pages(total_lens), + **extra_args, + ) # setup objects we want to track for the output batch_size = sequence_info.num_sequences @@ -105,18 +126,21 @@ class DemoEngine(ADEngine): context_logits: Optional[List[torch.Tensor]] = None def _generate_single_step(idx: int): - # assign pages - self._assign_pages() - - # get the logits and then last token logits in each sequence ([b, 1, vocab_size]) logits = self._compute_logits() logits_last = torch.stack([l_one_seq[-1] for l_one_seq in logits]).float().unsqueeze(1) token_ids, _ = self._decode_tokens(logits_last, sampling_params) # [b,1] - # update sequence info accordingly for next step - sequence_info.update_pos(sequence_info.sequence_lengths) - sequence_info.nest_sequences(token_ids) + # update sequence info accordingly for next step (generate phase) + input_pos_next = sequence_info.input_pos + seq_lens_current = sequence_info.seq_len + input_pos_next = [ip + sl for ip, sl in zip(input_pos_next, seq_lens_current)] + total_lens_next = [ip + len(t_ids) for ip, t_ids in zip(input_pos_next, token_ids)] + sequence_info.nest_sequences( + token_ids, + input_pos=input_pos_next, + page_assignments=self._assign_pages(total_lens_next), + ) # nest new tokens and run stop check for b, (new_tokens_b, new_id) in enumerate(zip(new_tokens, token_ids)): @@ -255,6 +279,7 @@ class DemoGenerationExecutor(GenerationExecutor): def _unpack(inputs) -> GenerationRequest: args, kwargs = inputs # unpack the inputs request: GenerationRequest = args[0] + request.multimodal_params: Optional[MultimodalParams] = args[1] return request engine = DemoEngine.build_from_config(**engine_kwargs) @@ -309,8 +334,11 @@ class DemoGenerationExecutor(GenerationExecutor): request.set_id(client_id) # submit request to our demo engine and store results + # NOTE: when returning from this function, the reference request.multimodal_params will + # be cleared immediately. So we pass it in explicitly to maintain a reference even when + # requests get submitted asynchronously. 
result = GenerationResult(request) - result.queue = self.engine_executor(request) + result.queue = self.engine_executor(request, request.multimodal_params) return result diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py index d07ab02c62..3d2d587adb 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py @@ -63,7 +63,7 @@ class ExportToGM(BaseTransform): model = gm.get_submodule("factory_model") # set the example sequence - cm.info.set_example_sequence() + cm.info.set_example_sequence(**factory.get_example_inputs()) # export the model to a graph module gm = torch_export_to_gm( diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py index 9ded1bea1c..a1177880c8 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py @@ -44,9 +44,6 @@ class UpdateInOutNodes(BaseTransform): # loop through nodes to get input, output, and get_attr nodes input_nodes, output_nodes = get_all_input_output_nodes(gm.graph) - # we only expect one input node - assert len(input_nodes) == 2, "Expected exactly two input nodes (input_ids, position_ids)." - # NOTE: for now, we wanna make sure we *only* return the final output and no hidden states. # Later on, we can revisit how to support returning hidden states. assert len(output_nodes) == 1, "Expected exactly one output node!" @@ -117,16 +114,17 @@ class InsertCachedAttention(BaseTransform): # retrieve input nodes input_nodes, _ = get_all_input_output_nodes(gm.graph) + input_nodes_mapping = {n.target: n for n in input_nodes} + + # filtered and sorted for SequenceInfo arguments + constants (input_ids, position_ids, etc.) 
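+        # (looking graph inputs up by name keeps this robust when a factory registers extra
+        #  inputs, e.g. pixel_values, beyond the usual input_ids/position_ids pair)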
+ inputs_from_info = [input_nodes_mapping[k] for k in cm.info.named_standard_args.keys()] + constants_from_info = cm.info.const_args_for_prepare_metadata # insert metadata computation and extract each argument as a node get_metadata, num_metadata = attn_descriptor.get_prepare_metadata_op() with graph.inserting_before(input_nodes[-1].next): ret_node = graph.call_function( - get_metadata, - args=( - *input_nodes, - cm.info.page_size, - ), + get_metadata, args=(*inputs_from_info, *constants_from_info) ) metadata_nodes = [ graph.call_function(operator.getitem, args=(ret_node, idx)) @@ -244,7 +242,7 @@ class ResizeKVCache(BaseTransform): try: # Let's run a forward pass to get the memory usage - cm.info._set_max_num_tokens_sample() + cm.info.set_max_num_tokens_sample() free_mem_pre, _ = _get_mem_info_in_mb() ad_logger.info(f"Free memory before forward pass (MB): {free_mem_pre}") diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index f4cd66d6f7..9463ed9547 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -18,8 +18,7 @@ from .._utils import (KVCacheEventSerializer, global_mpi_rank, global_mpi_size, mpi_comm, mpi_rank, nvtx_range_debug) from ..bindings import executor as tllm from ..builder import ConfigEncoder, Engine, EngineConfig -from ..llmapi.llm_args import (BaseLlmArgs, KvCacheConnectorConfig, - PybindMirror, TorchLlmArgs) +from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig, PybindMirror from ..llmapi.mpi_session import set_mpi_session_cpp from ..llmapi.tokenizer import TokenizerBase from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer @@ -86,7 +85,9 @@ class GenerationExecutorWorker(GenerationExecutor): self._await_response_helper = AwaitResponseHelper( self) # TODO: make it weakref self._executor_config = executor_config - self._is_pytorch_backend = llm_args is not None and llm_args.backend == "pytorch" + self._is_pytorch_backend = llm_args is not None and llm_args.backend in [ + "pytorch", "_autodeploy" + ] self.llm_args = llm_args if not self._is_pytorch_backend and kv_connector_config is not None: @@ -468,7 +469,6 @@ class GenerationExecutorWorker(GenerationExecutor): ) if self._is_pytorch_backend: - assert isinstance(self.llm_args, TorchLlmArgs) if not self.llm_args.disable_overlap_scheduler: is_disaggregated = self.engine.kv_cache_transceiver is not None if is_disaggregated and ( diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 46822f05bb..dfadc5f05d 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -2317,6 +2317,7 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness): MODEL_NAME = "Qwen3/Qwen3-8B" @skip_pre_hopper + @pytest.mark.skip(reason="https://nvbugs/5505402") @pytest.mark.parametrize( "tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler", [(1, 1, 1, False, True, True)], diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 7d53b90373..031548955b 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -37,6 +37,7 @@ l0_b200: - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass] - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm] - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-triton] + - 
accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency] # nvbugs 5505402 - disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0] # nvbugs 5300551 - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B] - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-FP8-llama-3.1-model/Llama-3.1-8B-Instruct-FP8] diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml index f0171fd2c8..7770bcd2ba 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml @@ -86,33 +86,12 @@ l0_dgx_h100: - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype1] - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.W4A8_CUSTOM-dtype0] - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.W4A8_CUSTOM-dtype1] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - - 
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True] diff --git a/tests/integration/test_lists/test-db/l0_dgx_h200.yml 
b/tests/integration/test_lists/test-db/l0_dgx_h200.yml index e4a9b0ecdd..36cd9b9696 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h200.yml @@ -50,6 +50,27 @@ l0_dgx_h200: stage: post_merge backend: pytorch tests: + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=True] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=True] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] + - 
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=0-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py b/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py index 283a3eb8f0..9548bed96e 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py +++ b/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py @@ -1,5 +1,5 @@ import copy -from typing import Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Sequence import numpy as np import torch @@ -39,19 +39,45 @@ class FakeFactory(ModelFactory): class SequenceEmbeddingInfo(SequenceInfo): - hidden_size: int - dtype: torch.dtype + """A sequence info object for testing that replaces the input_ids with an embedding tensor. - def set_example_sequence(self) -> None: - super().set_example_sequence() - # set input ids to a 3D tensor (actually input embeddings) - self.input_ids = torch.rand( - *self.input_ids.shape, + This is useful to run tests without the tokenizer in the loop. 
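+
+    Token id sequences passed to `nest_sequences` are converted on the fly into random
+    embedding tensors of shape `(seq_len, hidden_size)`, so graphs that consume embeddings
+    rather than token ids can be exercised directly.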
+ """ + + def _add_hidden_dim(self, input_ids: Sequence[Sequence[Any]]) -> torch.Tensor: + return torch.rand( + *input_ids.shape, self.hidden_size, - device=self.input_ids.device, + device=self.device, dtype=self.dtype, ) + def __init__(self, *args, hidden_size: int, dtype: torch.dtype, **kwargs): + self._initialized = False + super().__init__(*args, **kwargs) + + # overwrite input_ids with an embedding tensor and run reset again + self.hidden_size = hidden_size + self.dtype = dtype + self._args_device["input_ids"] = self._add_hidden_dim(self._args_device["input_ids"]) + self._args_host["input_ids"] = self._args_device["input_ids"].cpu() + self._initialized = True + self.reset() + + def nest_sequences(self, input_ids: Sequence[Sequence[Any]], *args, **kwargs) -> None: + # convert input_ids to an embedding tensor if needed + if not (isinstance(input_ids, torch.Tensor) and input_ids.ndim == 3) and self._initialized: + # first convert to a list of tensors + input_embeds = [ + torch.tensor(ids, device=self.device, dtype=self.dtype) for ids in input_ids + ] + # then add the hidden dimension to every tensor + input_embeds = [self._add_hidden_dim(ids) for ids in input_embeds] + else: + input_embeds = input_ids + + super().nest_sequences(input_embeds, *args, **kwargs) + def count_parameters(model: torch.nn.Module): for n, p in model.named_parameters(): diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py index e13891ee4a..aa90df29d3 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py +++ b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py @@ -400,9 +400,6 @@ _SMALL_MODEL_CONFIGS = { }, "vision_config": { "num_hidden_layers": 1, - "hidden_size": 64, - "intermediate_size": 64, - "num_attention_heads": 2, }, }, }, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_llama4_vlm_patch.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_llama4_vlm_patch.py new file mode 100644 index 0000000000..40f118b18d --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_llama4_vlm_patch.py @@ -0,0 +1,108 @@ +import torch +from _model_test_utils import get_small_model_config +from build_and_run_ad import ExperimentConfig +from PIL import Image + +from tensorrt_llm._torch.auto_deploy import LlmArgs +from tensorrt_llm._torch.auto_deploy.export import apply_export_patches, torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device + + +def test_build_run_llama4_vlm(): + atol = 1e-3 + rtol = 1e-3 + + experiment_config = get_small_model_config("meta-llama/Llama-4-Scout-17B-16E-Instruct") + experiment_config["args"]["model_kwargs"]["_attn_implementation"] = "eager" + experiment_config = ExperimentConfig(**experiment_config) + llm_args: LlmArgs = experiment_config.args + + factory = llm_args.create_factory() + model = factory.build_model("cuda") + processor = factory.init_processor() + + img1 = Image.new("RGB", (16, 16), color=(128, 128, 128)) + img2 = Image.new("RGB", (16, 16), color=(64, 64, 64)) + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": img1}, + {"type": "image", "image": img2}, + { + "type": "text", + "text": "Describe what you see in the two images and their differences.", + }, + ], + }, + ] + + inputs = ( + processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", 
+ ) + .to(model.device) + .to(torch.bfloat16) + ) + + # get relevant inputs + input_ids = inputs["input_ids"] + position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).repeat( + input_ids.shape[0], 1 + ) + pixel_values = inputs["pixel_values"] + + def _run_with_and_without_image(model, use_patch=True): + with apply_export_patches(patch_list=["hf_llama4_vision"] if use_patch else []): + with torch.inference_mode(): + out_no_images = model( + input_ids, + position_ids, + torch.zeros_like(pixel_values) if use_patch else None, + ) + out_with_images = model( + input_ids, + position_ids, + pixel_values, + ) + return {"no_images": out_no_images.logits, "with_images": out_with_images.logits} + + # Get output pre-patch + out_original = _run_with_and_without_image(model, use_patch=False) + + # Get output post-patch + outputs_for_comparison = {} + outputs_for_comparison["model_with_patch"] = _run_with_and_without_image(model) + + # Export to GM + gm = torch_export_to_gm( + model, + args=(input_ids, position_ids, pixel_values), + patch_list=[ + "transformers_sdpa_mask", + "autocast_noop", + "torch_where", + "tensor_meta_device", + "sdpa_kernel_noop", + "sdpa", + "hf_llama4_vision", + ], + ) + move_to_device(gm, model.device) + + # Get the output post export + outputs_for_comparison["gm"] = _run_with_and_without_image(gm) + + # Run comparisons to out_original with no patch now... + for comp, outs in outputs_for_comparison.items(): + torch.testing.assert_close( + outs, + out_original, + rtol=rtol, + atol=atol, + msg=lambda m: f"Comparison for {comp} failed:\n{m}", + ) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py index e9d7acd7dc..472bc71f1e 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py @@ -71,7 +71,6 @@ def test_engine(engine_cls: Type[ADEngine], attn_backend: str, attn_page_size: i input_ids = [torch.tensor([0, 1, 2], device=device)] sequence_info.reset() sequence_info.nest_sequences(input_ids) - engine.cache_seq_interface.info.sync(sequence_info) logits = engine._compute_logits() logits = torch.stack(logits) assert logits is not None, "Logits are None" @@ -106,7 +105,6 @@ def test_demo_engine_sampling(attn_page_size: int): input_ids = [torch.tensor([1, 2, 3, 4], device=device)] sequence_info.reset() sequence_info.nest_sequences(input_ids) - engine.cache_seq_interface.info.sync(sequence_info) logits = engine._compute_logits() logits = torch.stack(logits) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py index 6a4016234e..70c788d7b9 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py @@ -92,7 +92,7 @@ def test_config_params(): @patch("tensorrt_llm._torch.auto_deploy.llm.DemoGenerationExecutor") @patch("tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface.SequenceInfo") @patch("tensorrt_llm._torch.auto_deploy.shim.demollm.dist_ad.initialize_or_skip") -@patch("tensorrt_llm._torch.auto_deploy.llm.create_input_processor") +@patch("tensorrt_llm._torch.auto_deploy.llm.LLM._create_input_processor") @patch("tensorrt_llm._torch.auto_deploy.llm.LLM._build_model") def test_config_flow( mock_build_model, @@ -147,13 +147,6 @@ def test_config_flow( pass 
-def test_invalid_model_factory(): - """Test behavior with invalid model factory.""" - # Pydantic validates Literal types at runtime, so this should raise ValidationError - with pytest.raises(Exception): # Could be ValidationError or ValueError - LlmArgs(model="test-model", model_factory="InvalidFactory") - - @pytest.mark.parametrize( "parallel_field,invalid_value", [ diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py index 3d8b7e2ee1..273a50123d 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py @@ -68,7 +68,7 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs): get_small_model_config( "meta-llama/Llama-4-Scout-17B-16E-Instruct", attn_backend="flashinfer", - compile_backend="torch-opt", + compile_backend="torch-simple", ), get_small_model_config( "deepseek-ai/DeepSeek-V3", diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py index 9266027e11..e4865754e4 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py @@ -54,17 +54,19 @@ class GQAWithSdpa(GQA): self.num_key_value_groups = None @torch.no_grad() - def forward(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, input_ids: torch.Tensor, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: """ Forward pass with input tokens and optional position ids. 
position_ids parameter added to match expected interface in kvcache.py """ - b, s, _ = x.shape + b, s, _ = input_ids.shape # Project input to q, k, v representations - q = self.q_proj(x) # [b, s, n*h_d] - k = self.k_proj(x) # [b, s, n_kv*h_d] - v = self.v_proj(x) # [b, s, n_kv*h_d] + q = self.q_proj(input_ids) # [b, s, n*h_d] + k = self.k_proj(input_ids) # [b, s, n_kv*h_d] + v = self.v_proj(input_ids) # [b, s, n_kv*h_d] # Reshape to [b, s, n, h_d] q = q.view(b, s, self.num_heads, self.head_dim) @@ -126,9 +128,9 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config): ci = SequenceEmbeddingInfo( max_seq_len=max_position_embeddings, max_batch_size=batch_size, + hidden_size=hidden_size, + dtype=dtype, ) - ci.hidden_size = hidden_size - ci.dtype = dtype cm = CachedSequenceInterface(sequence_info=ci, device="cuda") # Create the model with SDPA and wrap it in a fake factory @@ -180,9 +182,9 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config): cm.initialize_caches() # Helper function to call the model with proper sequence nesting - def _call_and_unnest(x): + def _call_and_unnest(x, input_pos): # Use nest_sequences to properly set input_ids and automatically update position_ids - cm.info.nest_sequences(x, allow_realloc=True) + cm.info.nest_sequences(x, input_pos=input_pos) # Use the cm.args as is - it already contains the correct position_ids y = gm(*cm.args) @@ -192,31 +194,25 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config): # Test 1: Regular inference (all tokens at once) cm.info.reset() - y_no_cache = _call_and_unnest(x) + y_no_cache = _call_and_unnest(x, 0) assert all_close(y_model, y_no_cache, atol=atol, rtol=rtol) # Test 2: Autoregressive inference with KV cache cm.info.reset() y_with_cache = torch.empty_like(y_model) - for i in range(x.shape[1]): + for i_p in range(x.shape[1]): # Just pass the current token - y_with_cache[:, i : i + 1] = _call_and_unnest(x[:, i : i + 1]) - # Update position for next token - cm.info.update_pos(1) # This automatically updates position_ids too + y_with_cache[:, i_p : i_p + 1] = _call_and_unnest(x[:, i_p : i_p + 1], i_p) assert all_close(y_model, y_with_cache, atol=atol, rtol=rtol) # Test 3: Cache continuation after random tokens - cm.info.update_pos(-num_reset_steps) # Rewind position - for i in range(num_random_steps): - _call_and_unnest(torch.rand_like(x[:, :1])) - cm.info.update_pos(1) + for i_p in range(x.shape[1] - num_reset_steps, x.shape[1] - num_reset_steps + num_random_steps): + _call_and_unnest(torch.rand_like(x[:, :1]), i_p) # Continue inference from previous context cm.info.reset() - cm.info.update_pos(x.shape[1] - num_reset_steps) - for i in range(x.shape[1] - num_reset_steps, x.shape[1]): - y_with_cache[:, i : i + 1] = _call_and_unnest(x[:, i : i + 1]) - cm.info.update_pos(1) + for i_p in range(x.shape[1] - num_reset_steps, x.shape[1]): + y_with_cache[:, i_p : i_p + 1] = _call_and_unnest(x[:, i_p : i_p + 1], i_p) assert all_close(y_model, y_with_cache, atol=atol, rtol=rtol) # Test 4: Exportability of the transformed model diff --git a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py index 038cea3c4f..8de0ac8642 100644 --- a/tests/unittest/_torch/speculative/test_eagle3.py +++ b/tests/unittest/_torch/speculative/test_eagle3.py @@ -16,6 +16,7 @@ from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig, sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +@pytest.mark.skip(reason="https://nvbugs/5461761") @pytest.mark.parametrize( 
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill", [