Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

Commit 322db710dc: Merge remote-tracking branch 'origin/main' into feat/b300_cu13
Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
@@ -12,7 +12,7 @@ fi
PARSED_CMAKE_VERSION=$(echo $CMAKE_VERSION | sed 's/\.[0-9]*$//')
CMAKE_FILE_NAME="cmake-${CMAKE_VERSION}-linux-${ARCH}"
RELEASE_URL_CMAKE=${GITHUB_URL}/Kitware/CMake/releases/download/v${CMAKE_VERSION}/${CMAKE_FILE_NAME}.tar.gz
wget --no-verbose --timeout=180 --tries=3 ${RELEASE_URL_CMAKE} -P /tmp
wget --retry-connrefused --timeout=180 --tries=10 --continue ${RELEASE_URL_CMAKE} -P /tmp
tar -xf /tmp/${CMAKE_FILE_NAME}.tar.gz -C /usr/local/
ln -s /usr/local/${CMAKE_FILE_NAME} /usr/local/cmake

@@ -16,7 +16,7 @@ reinstall_rockylinux_cuda() {
dnf -y install epel-release
dnf remove -y "cuda*" "*cublas*" "*cufft*" "*cufile*" "*curand*" "*cusolver*" "*cusparse*" "*gds-tools*" "*npp*" "*nvjpeg*" "nsight*" "*nvvm*"
rm -rf /usr/local/cuda-${OLD_CUDA_VER}
wget -q https://developer.download.nvidia.com/compute/cuda/${CUDA_VER_SHORT}/local_installers/cuda_${CUDA_VER}_linux.run
wget --retry-connrefused --timeout=180 --tries=10 --continue https://developer.download.nvidia.com/compute/cuda/${CUDA_VER_SHORT}/local_installers/cuda_${CUDA_VER}_linux.run
sh cuda_${CUDA_VER}_linux.run --silent --override --toolkit
rm -f cuda_${CUDA_VER}_linux.run
}

@@ -96,7 +96,7 @@ install_rockylinux_requirements() {
"cuda-toolkit-config-common-${CUDA_RUNTIME}.noarch" \
"libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.${ARCH1}" \
"libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.${ARCH1}"; do
    wget -q --timeout=180 --tries=3 "https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/${ARCH3}/${pkg}.rpm"
    wget --retry-connrefused --timeout=180 --tries=10 --continue "https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/${ARCH3}/${pkg}.rpm"
done

# Remove old packages
@@ -138,7 +138,7 @@ install_tensorrt() {
if [ "$ARCH" = "x86_64" ];then
    curl -L --insecure --connect-timeout 600 --max-time 3600 --retry 3 -o /tmp/TensorRT.tar "${RELEASE_URL_TRT}"
else
    wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
    wget --retry-connrefused --timeout=180 --tries=10 --continue ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
fi
tar -xf /tmp/TensorRT.tar -C /usr/local/
mv /usr/local/TensorRT-* /usr/local/tensorrt

@@ -5,7 +5,7 @@ set -ex
install_boost() {
    # Install boost version >= 1.78 for boost::span
    # Current libboost-dev apt packages are < 1.78, so install from tar.gz
    wget -O /tmp/boost.tar.gz --timeout=180 --tries=3 https://archives.boost.io/release/1.80.0/source/boost_1_80_0.tar.gz \
    wget --retry-connrefused --timeout=180 --tries=10 --continue -O /tmp/boost.tar.gz https://archives.boost.io/release/1.80.0/source/boost_1_80_0.tar.gz \
        && tar xzf /tmp/boost.tar.gz -C /tmp \
        && mv /tmp/boost_1_80_0/boost /usr/include/boost \
        && rm -rf /tmp/boost_1_80_0 /tmp/boost.tar.gz

examples/auto_deploy/.gitignore (vendored, 2 changes)
@@ -2,3 +2,5 @@
!.vscode
benchmark_results.json
*.png
# ignore config files that users might put here for debugging
*.yaml

@@ -26,6 +26,9 @@ from tensorrt_llm.sampling_params import SamplingParams

# Global torch config, set the torch compile cache to fix up to llama 405B
torch._dynamo.config.cache_size_limit = 20

# simple string, TRT-LLM style text-only prompt or full-scale HF message template
PromptInput = Union[str, Dict, List[Dict]]


class PromptConfig(BaseModel):
    """Prompt configuration.

@@ -35,13 +38,27 @@ class PromptConfig(BaseModel):
    """

    batch_size: int = Field(default=2, description="Number of queries")
    queries: Union[str, List[str]] = Field(
    queries: Union[PromptInput, List[PromptInput]] = Field(
        default_factory=lambda: [
            # OPTION 1: simple text prompt
            "How big is the universe? ",
            "In simple words and in a single sentence, explain the concept of gravity: ",
            "How to fix slicing in golf? ",
            "Where is the capital of Iceland? ",
        ]
            # OPTION 2: wrapped text prompt for TRT-LLM
            {"prompt": "In simple words and a single sentence, explain the concept of gravity: "},
            # OPTION 3: a full-scale HF message template (this one works for text-only models!)
            # Learn more about chat templates: https://huggingface.co/docs/transformers/en/chat_templating
            # and multi-modal templates: https://huggingface.co/docs/transformers/en/chat_templating_multimodal
            [
                {
                    "role": "user",
                    "content": "How to fix slicing in golf?",
                }
            ],
            # More prompts...
            {"prompt": "Where is the capital of Iceland? "},
        ],
        description="Example queries to prompt the model with. We support both TRT-LLM text-only "
        "queries via the 'prompt' key and full-scale HF message template called via "
        "apply_chat_template.",
    )
    sp_kwargs: Dict[str, Any] = Field(
        default_factory=lambda: {"max_tokens": 100, "top_k": 200, "temperature": 1.0},
@@ -55,10 +72,28 @@ class PromptConfig(BaseModel):
        NOTE (lucaslie): has to be done with model_post_init to ensure it's always run. field
        validators are only run if a value is provided.
        """
        queries = [self.queries] if isinstance(self.queries, str) else self.queries
        queries = self.queries if isinstance(self.queries, list) else [self.queries]
        batch_size = self.batch_size
        queries = queries * (batch_size // len(queries) + 1)
        self.queries = queries[:batch_size]
        queries = queries[:batch_size]

        # now let's standardize the queries for the LLM api to understand them
        queries_processed = []
        for query in queries:
            if isinstance(query, str):
                queries_processed.append({"prompt": query})
            elif isinstance(query, dict):
                queries_processed.append(query)
            elif isinstance(query, list):
                queries_processed.append(
                    {
                        "prompt": "Fake prompt. Check out messages field for the HF chat template.",
                        "messages": query,  # contains the actual HF chat template
                    }
                )
            else:
                raise ValueError(f"Invalid query type: {type(query)}")
        self.queries = queries_processed

    @field_validator("sp_kwargs", mode="after")
    @classmethod

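The hunk above accepts three query forms (a plain string, a TRT-LLM style {"prompt": ...} dict, and an HF chat-template message list) and normalizes them in model_post_init. Below is a minimal standalone sketch of that normalization logic, written for illustration only; the helper name `standardize_query` is not part of the repository.

```python
from typing import Dict, List, Union

PromptInput = Union[str, Dict, List[Dict]]


def standardize_query(query: PromptInput) -> Dict:
    """Normalize one query into the dict form the LLM API consumes (illustrative sketch)."""
    if isinstance(query, str):
        # OPTION 1: plain text prompt
        return {"prompt": query}
    if isinstance(query, dict):
        # OPTION 2: already a TRT-LLM style {"prompt": ...} dict
        return query
    if isinstance(query, list):
        # OPTION 3: HF chat-template messages; keep a placeholder prompt alongside them
        return {"prompt": "placeholder; see messages", "messages": query}
    raise ValueError(f"Invalid query type: {type(query)}")


assert standardize_query("Hi")["prompt"] == "Hi"
assert "messages" in standardize_query([{"role": "user", "content": "Hi"}])
```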
@@ -121,7 +121,7 @@ REQUIRED_NO_DRIVER_TYPES = ["dgx-h100", "dgx-h200", "gh200"]
ENABLE_NGC_DEVEL_IMAGE_TEST = params.enableNgcDevelImageTest ?: false
ENABLE_NGC_RELEASE_IMAGE_TEST = params.enableNgcReleaseImageTest ?: false

COMMON_SSH_OPTIONS = "-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null"
COMMON_SSH_OPTIONS = "-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o ServerAliveInterval=60 -o ServerAliveCountMax=5"

def uploadResults(def pipeline, SlurmCluster cluster, String nodeName, String stageName){
withCredentials([usernamePassword(credentialsId: 'svc_tensorrt', usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) {
@@ -323,7 +323,7 @@ def runLLMTestlistOnSlurm(pipeline, platform, testList, config=VANILLA_CONFIG, p

Utils.exec(pipeline, script: "chmod +x ${jenkinsSetupPath}", returnStdout: true)

Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${jenkinsSetupPath} ${remote.user}@${remote.host}:~/bloom/scripts/${nodeName}-slurm_jenkins_agent_setup.sh", numRetries: 3,)
Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${jenkinsSetupPath} ${remote.user}@${remote.host}:~/bloom/scripts/${nodeName}-slurm_jenkins_agent_setup.sh", numRetries: 3)

Utils.exec(pipeline, script: "cat ${jenkinsSetupPath}")

@@ -334,7 +334,8 @@ def runLLMTestlistOnSlurm(pipeline, platform, testList, config=VANILLA_CONFIG, p
remote,
"\"${SlurmConfig.generateCommand(cluster, partition, nodeSecret, nodeName, Jenkins.instance.rootUrl)}\""
),
returnStdout: true
returnStdout: true,
numRetries: 3
)

def jobIDs = slurmSubmitOutput
@@ -498,7 +499,7 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL

stage('Prepare Testing') {
// Create Job Workspace folder in Frontend Node
Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' ssh ${COMMON_SSH_OPTIONS} ${remote.user}@${remote.host} 'mkdir -p ${jobWorkspace}'", numRetries: 3,)
Utils.exec(pipeline, script: Utils.sshUserCmd(remote, "\"mkdir -p ${jobWorkspace}\""), numRetries: 3)

// Download and Unzip Tar File
trtllm_utils.llmExecStepWithRetry(pipeline, script: "cd ${llmPath} && wget -nv ${llmTarfile}")
@@ -508,12 +509,12 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL
def scriptRunLocalPath = "${llmSrcLocal}/jenkins/scripts/slurm_run.sh"
Utils.exec(pipeline, script: "chmod +x ${scriptRunLocalPath}", returnStdout: true)

Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${scriptRunLocalPath} ${remote.user}@${remote.host}:${scriptRunNode}", numRetries: 3,)
Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${scriptRunLocalPath} ${remote.user}@${remote.host}:${scriptRunNode}", numRetries: 3)
Utils.exec(pipeline, script: "cat ${scriptRunLocalPath}")

// Upload waives.txt to Frontend node
def waivesListLocalPath = "${llmSrcLocal}/tests/integration/test_lists/waives.txt"
Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${waivesListLocalPath} ${remote.user}@${remote.host}:${waivesListPathNode}", numRetries: 3,)
Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${waivesListLocalPath} ${remote.user}@${remote.host}:${waivesListPathNode}", numRetries: 3)

// Generate Test List and Upload to Frontend Node
def makoArgs = getMakoArgsFromStageName(stageName, true)
@@ -522,7 +523,7 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL
// if the line cannot be split by "=", just ignore that line.
def makoOptsJson = transformMakoArgsToJson(["Mako options:"] + makoArgs)
def testListPath = renderTestDB(testList, llmSrcLocal, stageName, makoOptsJson)
Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${testListPath} ${remote.user}@${remote.host}:${testListPathNode}", numRetries: 3,)
Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${testListPath} ${remote.user}@${remote.host}:${testListPathNode}", numRetries: 3)

// Generate Multi Node Job Launch Script
def container = LLM_DOCKER_IMAGE.replace("urm.nvidia.com/", "urm.nvidia.com#")
@@ -566,7 +567,7 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL
""".stripIndent()
pipeline.writeFile(file: scriptLaunchDestPath, text: scriptContent)
Utils.exec(pipeline, script: "chmod +x ${scriptLaunchDestPath}", returnStdout: true)
Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${scriptLaunchDestPath} ${remote.user}@${remote.host}:${scriptLaunch}", numRetries: 3,)
Utils.exec(pipeline, script: "sshpass -p '${remote.passwd}' scp -r -p ${COMMON_SSH_OPTIONS} ${scriptLaunchDestPath} ${remote.user}@${remote.host}:${scriptLaunch}", numRetries: 3)
Utils.exec(pipeline, script: "cat ${scriptLaunchDestPath}")
}

@@ -577,7 +578,8 @@ def runLLMTestlistOnSlurm_MultiNodes(pipeline, platform, testList, config=VANILL
script: Utils.sshUserCmd(
remote,
"\"bash ${scriptLaunch}\""
)
),
numRetries: 3
)
}

@@ -10,15 +10,18 @@ and operates on a purely functional paradigm that is compatible with the torch c
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field, fields
from typing import Dict, List, Literal, Optional, Protocol, Sequence, Tuple, Type, Union
from dataclasses import dataclass
from typing import Callable, Dict, List, Literal, Optional, Protocol, Sequence, Tuple, Type, Union

import torch
from torch._ops import OpOverloadPacket
from torch.export import Dim
from torch.fx import Node

from tensorrt_llm._utils import nvtx_range
from ...._utils import nvtx_range

DynamicShape = Dict[int, Dim] # indicating the dynamic shape in tensor dimension
DynamicShapeCallback = Callable[[], DynamicShape]


@dataclass
@@ -28,15 +31,14 @@ class CacheConfig:
    dtype: Optional[torch.dtype] = None


@dataclass
class SequenceInfo:
    """A dataclass to hold information about how the sequence is laid out and stored in cache.
    """An interface to hold information about how the sequence is laid out and stored in cache.

    We assume the sequence + cache is laid out in the following way. Also note that we differentiate
    between arguments that are originally part of the model/graph and arguments that are needed for
    the attention operator when we switch to cached+flattened attention.

    # ORIGINAL MODEL ARGUMENTS #####################################################################
    ### ORIGINAL MODEL ARGUMENTS ###################################################################
    - input_ids: [id_0, ..., id_{s_total-1}]
      flattened sequence of [b, 1] or [1, s_total]. We use [b, 1] to denote generate-only batches.
    - position_ids: [pos_0, ..., pos_{s_total-1}]
@@ -46,7 +48,17 @@ class SequenceInfo:
    NOTE: ``input_ids`` and ``position_ids`` are initially expected to be of shape [b, seq_len]
    before we switch to cached+flattened attention.

    # EXTRA ARGUMENTS NEEDED FOR ATTENTION OPERATORS FOR FLATTENED SEQUENCES + CACHES ##############
    ### EXTRA ARGUMENTS PROVIDED TO THE INTERFACE ##################################################
    Those are extra arguments that can be provided to the interface and they are stored as follows:
    - _extra_args: dictionary of extra arguments with currently active values.
    - _extra_none_inputs: dictionary of none inputs to the extra arguments.
      NOTE: we assume that extra arguments are *optional* arguments to the model. However, we
      cannot represent them via `None` since fx graphs require a fixed input type. Instead,
      we require a special placeholder tensor to represent the `None` input.
    - _extra_dynamic_shapes_callbacks: dictionary of callbacks to initialize the dynamic shapes of
      the extra arguments.

    ### CACHE ARGUMENTS NEEDED FOR ATTENTION OPERATORS FOR FLATTENED SEQUENCES + CACHES ############
    - seq_len: [s_0, s_1, ..., s_{b-1}] such that s_total = sum(s_i)
      Describes how long each sequence is. For example,
      input_ids[:s_0] will correspond to sequence 0 in the batch and input_ids[s_0:s_1] will
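As a hedged illustration of the extra-argument mechanism described above: the only names taken from the interface are `add_extra_arg`, `none_input`, `dynamic_shape_callback`, and `nest_sequences`; the argument name, tensor shapes, and the `seq_info` object are invented for the example.

```python
import torch

# `seq_info` is assumed to be an existing SequenceInfo instance.
# fx export cannot take `None` as an input, so a placeholder tensor stands in for it.
none_placeholder = torch.zeros(1, dtype=torch.float32)

seq_info.add_extra_arg(
    "extra_feature",                    # hypothetical extra-argument name
    none_input=none_placeholder,        # used whenever the caller passes nothing
    dynamic_shape_callback=lambda: {},  # static shape in this example
)

# Later, the active value (or the placeholder) is picked up per forward pass:
seq_info.nest_sequences([[1, 2, 3]], extra_feature=torch.randn(3, 8))
```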
@@ -73,168 +85,238 @@ class SequenceInfo:
|
||||
|
||||
"""
|
||||
|
||||
## USE TO INITIALIZE DATA CLASS ###############################################################
|
||||
# max_seq_len corresponds to the maximum number of tokens in any sequence. It includes the tokens in the
|
||||
# input sequence and the tokens generated by the model.
|
||||
max_seq_len: int = 1
|
||||
# max_batch_size corresponds to the maximum number of sequences (or requests) that the model can process.
|
||||
max_batch_size: int = 1
|
||||
# page_size is the granularity with which the cache pages are allocated for a paged kv cache.
|
||||
# For an unpaged cache, the page size should be set to max_seq_len.
|
||||
# Also note that two sequences in a batch can not share a page.
|
||||
page_size: int = 0
|
||||
# max_num_tokens is the maximum number of tokens that the model can process across all sequences in the batch.
|
||||
# If a batch is composed of context-only requests of input sequence length ISL,
|
||||
# then the maximum number of sequences possible in the batch is min (max_batch_size, max_num_tokens // ISL).
|
||||
# Similarly, if a batch is composed of generate-only requests,
|
||||
# then the maximum number of sequences possible in the batch is min (max_batch_size, max_num_tokens).
|
||||
max_num_tokens: Optional[int] = None
|
||||
# device is the device on which the sequence info is stored.
|
||||
device: str = "cuda"
|
||||
def __init__(
|
||||
self,
|
||||
max_seq_len: int = 1,
|
||||
max_batch_size: int = 1,
|
||||
page_size: int = 0,
|
||||
max_num_tokens: Optional[int] = None,
|
||||
):
|
||||
"""Initialize the SequenceInfo object.
|
||||
|
||||
## [UPDATE WITH CARE] TENSOR FIELDS THAT WILL BE PASSED TO PREPARE_METADATA OP #################
|
||||
# input_ids MUST ALWAYS BE THE FIRST FIELD
|
||||
input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int))
|
||||
position_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.long))
|
||||
Args:
|
||||
max_seq_len: corresponds to the maximum sequence length of the input sequence. It
|
||||
includes the tokens in the input sequence and the tokens generated by the model.
|
||||
max_batch_size: corresponds to the maximum number of sequences (or requests) that the
|
||||
model can process.
|
||||
page_size: corresponds to the page size of the cache. For an unpaged cache, the page
|
||||
size should be set to max_seq_len. Also note that two sequences in a batch can not
|
||||
share a page.
|
||||
max_num_tokens: corresponds to the maximum number of tokens that the model can process
|
||||
across all sequences in the batch. If a batch is composed of context-only requests
|
||||
of input sequence length ISL, then the maximum number of sequences possible in the
|
||||
batch is min (max_batch_size, max_num_tokens // ISL). Similarly, if a batch is
|
||||
composed of generate-only requests, then the maximum number of sequences possible in
|
||||
the batch is min (max_batch_size, max_num_tokens).
|
||||
|
||||
seq_len: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int))
|
||||
input_pos: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int))
|
||||
cache_loc: torch.Tensor = field(default_factory=lambda: torch.arange(1, dtype=torch.int))
|
||||
pages_per_seq: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int))
|
||||
################################################################################################
|
||||
|
||||
## PRIVATE FIELDS ##############################################################################
|
||||
_sequence_lengths: List[int] = field(default_factory=list)
|
||||
_num_pages: int = 1
|
||||
|
||||
def __post_init__(self):
|
||||
if self.page_size < 1:
|
||||
self.page_size = self.max_seq_len
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# set up basic attributes
|
||||
self.max_seq_len = max_seq_len
|
||||
self.max_batch_size = max_batch_size
|
||||
self.page_size = page_size if page_size > 0 else max_seq_len
|
||||
|
||||
# NOTE (lucaslie): WAR to address issue when using flashinfer attention with
|
||||
# (max_batch_size, max_seq_len) input in trtllm runtime.
|
||||
# see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
|
||||
self.max_seq_len_adjusted = self.max_seq_len + 1
|
||||
max_seq_len_adjusted = self.max_seq_len + 1
|
||||
|
||||
if max_num_tokens is None or max_num_tokens < 1:
|
||||
self.max_num_tokens = self.max_batch_size * max_seq_len_adjusted
|
||||
else:
|
||||
self.max_num_tokens = max_num_tokens
|
||||
|
||||
if self.max_num_tokens is None or self.max_num_tokens < 1:
|
||||
self.max_num_tokens = self.max_batch_size * self.max_seq_len_adjusted
|
||||
# if the provided max_num_tokens is less than the max_batch_size * max_seq_len,
|
||||
# we use the provided max_num_tokens to calculate the number of pages
|
||||
total_tokens = min(self.max_num_tokens, self.max_batch_size * self.max_seq_len_adjusted)
|
||||
total_tokens = min(self.max_num_tokens, self.max_batch_size * max_seq_len_adjusted)
|
||||
# Num pages can not be less than max_batch_size.
|
||||
self._num_pages = max(
|
||||
self.max_batch_size,
|
||||
(total_tokens) // self.page_size + (total_tokens % self.page_size > 0),
|
||||
)
|
||||
# Ensure that the device is set before initializing the tensors.
|
||||
self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
|
||||
self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device)
|
||||
|
||||
# Consumers of the sequence info args require input_ids and position_ids to be truncated.
|
||||
# We maintain a full version of the input_ids and position_ids to avoid overheads of tensor
|
||||
# creation in every forward pass.
|
||||
self.input_ids_full = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
|
||||
self.position_ids_full = torch.zeros(
|
||||
self.max_num_tokens, dtype=torch.long, device=self.device
|
||||
)
|
||||
|
||||
self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int, device=self.device)
|
||||
self.input_pos = torch.empty_like(self.seq_len, device=self.device)
|
||||
|
||||
# Allocated host tensors for sequence lengths and input positions so that
|
||||
# position_ids calculation can be done on host.
|
||||
self.seq_len_host = torch.empty(self.max_batch_size, dtype=torch.int)
|
||||
self.input_pos_host = torch.empty_like(self.seq_len_host)
|
||||
|
||||
self.cache_loc = torch.empty(self.num_pages, dtype=torch.int, device=self.device)
|
||||
self.pages_per_seq = torch.empty_like(self.seq_len, device=self.device)
|
||||
|
||||
self.previous_batch_indices_cuda = torch.empty(
|
||||
self.max_num_tokens, dtype=torch.long, device=self.device
|
||||
)
|
||||
assert self.num_pages >= self.max_batch_size, (
|
||||
"num_pages must be greater than max_batch_size"
|
||||
)
|
||||
# dynamic shape descriptors for tensor args
|
||||
self._dynamic_shapes: Optional[Tuple[Dict[str, Dim]]] = None
|
||||
|
||||
# keep a list-like object of sequence lengths for simplicity as well
|
||||
self._sequence_lengths = [0] * self.max_batch_size
|
||||
# sanity check
|
||||
assert self.num_pages >= self.max_batch_size, "num_pages can't be less than max_batch_size"
|
||||
|
||||
# indicator if extra args are activated that are needed for cached attention backends
|
||||
self._is_cached_attn = False
|
||||
|
||||
# total number of tokens in the current batch
|
||||
self.num_tokens: int = 0
|
||||
# container for dynamic shapes
|
||||
self._dynamic_shapes: Optional[Dict[str, DynamicShape]] = None
|
||||
|
||||
# call reset once to initialize the tensors
|
||||
# TENSOR FIELDS ############################################################################
|
||||
self._args_device: Dict[str, torch.Tensor] = {
|
||||
# TENSOR FIELDS FOR UNCACHED ATTENTION
|
||||
"input_ids": torch.ones(self.max_num_tokens, dtype=torch.int),
|
||||
"position_ids": torch.zeros(self.max_num_tokens, dtype=torch.long),
|
||||
# TENSOR FIELDS FOR CACHED ATTENTION
|
||||
"seq_len": torch.empty(self.max_batch_size, dtype=torch.int),
|
||||
"input_pos": torch.empty(self.max_batch_size, dtype=torch.int),
|
||||
"cache_loc": torch.empty(self.num_pages, dtype=torch.int),
|
||||
"pages_per_seq": torch.empty(self.max_batch_size, dtype=torch.int),
|
||||
# OTHER FIELDS WHERE WE NEED EFFICIENT HOST<>DEVICE TRANSFER
|
||||
"_gather_idx": torch.empty(self.max_num_tokens, dtype=torch.int),
|
||||
}
|
||||
self._args_host: Dict[str, List[int]] = {
|
||||
k: v.tolist() for k, v in self._args_device.items()
|
||||
}
|
||||
# NOTE: order of keys is relevant here!
|
||||
self._uncached_arg_names = ("input_ids", "position_ids")
|
||||
self._cached_arg_names = ("seq_len", "input_pos", "cache_loc", "pages_per_seq")
|
||||
############################################################################################
|
||||
|
||||
# EXTRA TENSOR FIELDS ######################################################################
|
||||
self._extra_args: Dict[str, torch.Tensor] = {}
|
||||
self._extra_none_inputs: Dict[str, torch.Tensor] = {}
|
||||
self._extra_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None
|
||||
self._extra_dynamic_shapes_callbacks: Dict[str, DynamicShapeCallback] = {}
|
||||
############################################################################################
|
||||
|
||||
# call reset once to set a consistent initial state
|
||||
self.reset()
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self._args_device["input_ids"].device
|
||||
|
||||
def _shape_for_forward(self, tnsr: torch.Tensor) -> torch.Tensor:
|
||||
"""Shape the tensor for the forward pass based on the current attention mode.
|
||||
|
||||
Args:
|
||||
tnsr: The tensor to shape assumed to be in shape [batch_size*seq_len, ...]
|
||||
|
||||
Returns:
|
||||
The shaped tensor flattened or unflattened based on the current attention mode.
|
||||
"""
|
||||
# check if we are still running uncached attention, in which case we still
|
||||
# operate on unflattened tensors with explicit [batch_size, seq_len, ...] shape
|
||||
# generate-only batches are also formatted like this (i.e. [b, 1])
|
||||
if not self._is_cached_attn or self.is_generate:
|
||||
bs = len(self.seq_len)
|
||||
sl = self.seq_len[0]
|
||||
# use [1,total_len] shape to indicate non-generate-only batch for cached attention
|
||||
else:
|
||||
bs, sl = 1, self.total_num_tokens
|
||||
|
||||
# truncate to total tokens now, reshape, and return
|
||||
return tnsr[: self.total_num_tokens].view(bs, sl, *tnsr.shape[1:])
|
||||
|
||||
def _named_args(
|
||||
self, include_extra_args: bool = True, include_cached_args: bool = True
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
# start with uncached args and shape them along the way
|
||||
args = {k: self._shape_for_forward(self._args_device[k]) for k in self._uncached_arg_names}
|
||||
|
||||
# check other args to include
|
||||
if include_extra_args:
|
||||
args.update(self._extra_args)
|
||||
|
||||
if include_cached_args:
|
||||
args.update({k: self._args_device[k] for k in self._cached_arg_names})
|
||||
|
||||
return args
|
||||
|
||||
@property
|
||||
def named_args(self) -> Dict[str, torch.Tensor]:
|
||||
"""Return a dictionary of named arguments.
|
||||
|
||||
These arguments contain all arguments that are managed by this interface and are required
|
||||
to run a model's forward pass including all extra arguments.
|
||||
|
||||
Cached arguments are only included if the attention mode is cached to reflect that after
|
||||
switching to cached attention, the cached arguments are required for a forward pass.
|
||||
"""
|
||||
return self._named_args(include_extra_args=True, include_cached_args=self._is_cached_attn)
|
||||
|
||||
@property
|
||||
def named_standard_args(self) -> Dict[str, torch.Tensor]:
|
||||
"""Return a dictionary of named standard arguments.
|
||||
|
||||
We define standard arguments as the arguments that are part of the model's forward function
|
||||
by default (i.e., without the extra arguments).
|
||||
|
||||
Just like ``named_args``, this property includes cached attention arguments if the
|
||||
attention mode is cached.
|
||||
"""
|
||||
return self._named_args(include_extra_args=False, include_cached_args=self._is_cached_attn)
|
||||
|
||||
@property
|
||||
def args(self) -> Tuple[torch.Tensor, ...]:
|
||||
args = []
|
||||
for f in fields(self):
|
||||
val = getattr(self, f.name)
|
||||
if isinstance(val, torch.Tensor):
|
||||
args.append(val)
|
||||
if len(args) >= self._num_uncached_attn_args and not self._is_cached_attn:
|
||||
break
|
||||
|
||||
return tuple(args)
|
||||
"""Return a tuple of arguments."""
|
||||
return tuple(self.named_args.values())
|
||||
|
||||
@property
|
||||
def _num_uncached_attn_args(self) -> int:
|
||||
"""Return the number of original graph arguments expected by the model.
|
||||
This is 2 because we have input_ids and position_ids as the original graph arguments.
|
||||
def const_args_for_prepare_metadata(self) -> Tuple:
|
||||
"""Return a tuple of extra (const, non-tensor) arguments for the prepare_metadata op.
|
||||
|
||||
The ``prepare_metadata`` interface expects the following arguments:
|
||||
|
||||
1. ``named_standard_args`` as nodes, i.e., as input-dependent tensors.
|
||||
2. ``const_args_for_prepare_metadata`` as constants that can directly be passed in as args
|
||||
to the corresponding ``prepare_metadata`` node/op.
|
||||
|
||||
This interface handles the constant arguments part and can be used by compiler passes like
|
||||
``insert_cached_attention`` to extract the constant arguments and add them to the
|
||||
``prepare_metadata`` node/op.
|
||||
"""
|
||||
return 2
|
||||
return (self.page_size,)
|
||||
|
||||
@property
|
||||
def _cached_attn_arg_names(self) -> List[str]:
|
||||
"""Return extra arg names for the prepare_metadata op beyond input_ids and position_ids.
|
||||
|
||||
These extra args are needed once we switch from regular attention to inserting cached
|
||||
attention ops in the model.
|
||||
"""
|
||||
return [f.name for f in fields(self) if isinstance(getattr(self, f.name), torch.Tensor)][
|
||||
self._num_uncached_attn_args :
|
||||
]
|
||||
|
||||
@property
|
||||
def dynamic_shapes(self) -> Tuple[Dict[str, Dim]]:
|
||||
def named_dynamic_shapes(self) -> Dict[str, Dict[str, Dim]]:
|
||||
"""Return dynamic shapes of sequence info tensors.
|
||||
|
||||
NOTE: will be lazily initialized since the Dim object is not picklable for multi-processing.
|
||||
"""
|
||||
# lazy initialization of dynamic shapes with Dim objects
|
||||
if self._dynamic_shapes is None:
|
||||
# set up shape for input_ids and position_ids
|
||||
dynamic_shapes = ({}, {})
|
||||
# set up shape for uncached args (same for all, i.e., batch_size and seq_len)
|
||||
bs_seq_len_shape: DynamicShape = {}
|
||||
if self.max_batch_size > 1:
|
||||
dynamic_shapes[0][0] = Dim("batch_size", max=self.max_batch_size)
|
||||
dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len_adjusted)
|
||||
# set up shape for position_ids (same as input_ids)
|
||||
dynamic_shapes[1].update(dynamic_shapes[0])
|
||||
# set up shape for extra args
|
||||
if self._is_cached_attn:
|
||||
dynamic_shapes += ({},) * len(self._cached_attn_arg_names)
|
||||
self._dynamic_shapes = dynamic_shapes
|
||||
return self._dynamic_shapes
|
||||
bs_seq_len_shape[0] = Dim("batch_size", max=self.max_batch_size)
|
||||
bs_seq_len_shape[1] = Dim("seq_len", max=self.max_seq_len)
|
||||
self._dynamic_shapes = {k: bs_seq_len_shape for k in self._uncached_arg_names}
|
||||
# cached args are static
|
||||
self._dynamic_shapes.update({k: {} for k in self._cached_arg_names})
|
||||
|
||||
for k, callback in self._extra_dynamic_shapes_callbacks.items():
|
||||
if k not in self._dynamic_shapes:
|
||||
self._dynamic_shapes[k] = callback()
|
||||
|
||||
# return dynamic shapes according to currently active named_args with consistent order
|
||||
return {k: self._dynamic_shapes[k] for k in self.named_args.keys()}
|
||||
|
||||
@property
|
||||
def dynamic_shapes(self) -> Tuple[DynamicShape, ...]:
|
||||
"""Return dynamic shapes of sequence info tensors."""
|
||||
return tuple(self.named_dynamic_shapes.values())
|
||||
|
||||
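A hedged sketch of how the dynamic-shape descriptors returned above could be consumed by `torch.export`; the `model` object, the exact argument set, and the call site are assumptions for illustration only.

```python
import torch

# `seq_info` is an assumed SequenceInfo instance and `model` an assumed nn.Module whose
# forward takes the standard args (input_ids, position_ids) in this order.
ep = torch.export.export(
    model,
    args=tuple(seq_info.named_standard_args.values()),
    dynamic_shapes=seq_info.dynamic_shapes,  # tuple aligned with the args above
)
print(ep.graph_signature)
```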
@property
|
||||
def seq_len(self) -> List[int]:
|
||||
return self._args_host["seq_len"].copy()
|
||||
|
||||
@property
|
||||
def input_pos(self) -> List[int]:
|
||||
return self._args_host["input_pos"].copy()
|
||||
|
||||
@property
|
||||
def cache_loc(self) -> List[int]:
|
||||
return self._args_host["cache_loc"].copy()
|
||||
|
||||
@property
|
||||
def pages_per_seq(self) -> List[int]:
|
||||
return self._args_host["pages_per_seq"].copy()
|
||||
|
||||
@property
|
||||
def num_sequences(self) -> int:
|
||||
return len(self._sequence_lengths)
|
||||
return len(self.seq_len)
|
||||
|
||||
@property
|
||||
def sequence_lengths(self) -> List[int]:
|
||||
return self._sequence_lengths
|
||||
|
||||
@property
|
||||
def input_positions(self) -> List[int]:
|
||||
return self.input_pos_host[: self.num_sequences].tolist()
|
||||
def total_num_tokens(self) -> int:
|
||||
return sum(self.seq_len)
|
||||
|
||||
@property
|
||||
def is_generate(self) -> bool:
|
||||
return all(sl == 1 for sl in self.sequence_lengths)
|
||||
return all(sl == 1 for sl in self.seq_len)
|
||||
|
||||
@property
|
||||
def num_pages(self) -> int:
|
||||
@@ -244,7 +326,7 @@ class SequenceInfo:
|
||||
def num_pages(self, value):
|
||||
self._num_pages = value
|
||||
# update the cache_loc tensor
|
||||
self.cache_loc.resize_(value)
|
||||
self._args_device["cache_loc"].resize_(value)
|
||||
|
||||
@property
|
||||
def is_paged(self) -> bool:
|
||||
@@ -253,12 +335,52 @@ class SequenceInfo:
|
||||
@property
|
||||
def page_assignments(self) -> List[List[int]]:
|
||||
"""Return the page assignments for each sequence."""
|
||||
pages_per_seq = self.pages_per_seq[: self.num_sequences].tolist()
|
||||
return self._get_page_assignments(self.cache_loc, self.pages_per_seq)
|
||||
|
||||
@staticmethod
|
||||
def _get_page_assignments(
|
||||
cache_locations: List[int], pages_per_sequence: List[int]
|
||||
) -> List[List[int]]:
|
||||
"""Get nested page assignments from cache locations and pages per sequence as list of lists.
|
||||
|
||||
Args:
|
||||
cache_locations: A flat list of cache locations for each sequence ordered by sequence.
|
||||
pages_per_sequence: A list of number of pages per sequence.
|
||||
|
||||
Returns:
|
||||
A list of page assignments for each sequence ordered by sequence.
|
||||
For example:
|
||||
cache_locations: [0, 4, 2]
|
||||
pages_per_sequence: [2, 1]
|
||||
--> returns [[0, 4], [2]]
|
||||
"""
|
||||
return [
|
||||
c_loc_one_seq.tolist()
|
||||
for c_loc_one_seq in torch.split(self.cache_loc[: sum(pages_per_seq)], pages_per_seq)
|
||||
for c_loc_one_seq in torch.split(torch.tensor(cache_locations), pages_per_sequence)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _get_cache_locations_and_pages_per_sequence(
|
||||
page_assignments: List[List[int]],
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
"""Get cache locations and pages per sequence from nested page assignments (lists of lists).
|
||||
|
||||
Args:
|
||||
page_assignments: A list of page assignments for each sequence ordered by sequence.
|
||||
Returns:
|
||||
A tuple of:
|
||||
cache_locations: A flat list of cache locations for each sequence ordered by sequence.
|
||||
pages_per_sequence: A list of number of pages per sequence.
|
||||
|
||||
Example:
|
||||
page_assignments: [[0, 4], [2]]
|
||||
--> returns ([0, 4, 2], [2, 1])
|
||||
|
||||
"""
|
||||
cache_loc_flat = [p_idx for pages in page_assignments for p_idx in pages]
|
||||
pages_per_seq = [len(p) for p in page_assignments]
|
||||
return cache_loc_flat, pages_per_seq
|
||||
|
||||
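For clarity, here is a small self-contained sketch of the round trip the two helpers above implement (a pure-Python re-implementation for illustration, not the repository's code):

```python
from typing import List, Tuple


def to_page_assignments(cache_loc: List[int], pages_per_seq: List[int]) -> List[List[int]]:
    """Split a flat list of cache locations into one page list per sequence."""
    out, i = [], 0
    for n in pages_per_seq:
        out.append(cache_loc[i : i + n])
        i += n
    return out


def from_page_assignments(page_assignments: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Flatten nested page assignments back into (cache_locations, pages_per_sequence)."""
    return [p for pages in page_assignments for p in pages], [len(p) for p in page_assignments]


assert to_page_assignments([0, 4, 2], [2, 1]) == [[0, 4], [2]]
assert from_page_assignments([[0, 4], [2]]) == ([0, 4, 2], [2, 1])
```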
@classmethod
|
||||
def _get_sanitized_seq_len(cls, input_ids: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
|
||||
"""Sanitize sequence lengths.
|
||||
@ -331,26 +453,55 @@ class SequenceInfo:
|
||||
"""
|
||||
assert not self._is_cached_attn, "Cached+flattened attention already activated"
|
||||
self._is_cached_attn = True
|
||||
return self._cached_attn_arg_names
|
||||
return list(self._cached_arg_names)
|
||||
|
||||
def to(self, *args, **kwargs) -> None:
|
||||
for f in fields(self):
|
||||
val = getattr(self, f.name)
|
||||
if isinstance(val, torch.Tensor):
|
||||
setattr(self, f.name, val.to(*args, **kwargs))
|
||||
def _move_dict(d: Dict[str, torch.Tensor]) -> None:
|
||||
for k, v in d.items():
|
||||
d[k] = v.to(*args, **kwargs)
|
||||
|
||||
def sync(self, other: "SequenceInfo") -> None:
|
||||
for f in fields(self):
|
||||
val = getattr(self, f.name)
|
||||
val_other = getattr(other, f.name)
|
||||
if f.name in ["input_ids", "position_ids"]:
|
||||
setattr(self, f.name, val_other.to(self.device))
|
||||
elif f.name == "_sequence_lengths":
|
||||
self._sequence_lengths = val_other
|
||||
elif isinstance(val, torch.Tensor):
|
||||
val[: len(val_other)] = val_other.to(self.device)
|
||||
else:
|
||||
assert val == val_other, f"Field {f.name} mismatch: {val} != {val_other}."
|
||||
_move_dict(self._args_device)
|
||||
_move_dict(self._extra_args)
|
||||
_move_dict(self._extra_none_inputs)
|
||||
|
||||
def set_example_sequence(
|
||||
self,
|
||||
input_ids: Sequence[Sequence[int]] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
**extra_args,
|
||||
) -> None:
|
||||
"""Set an example sequence useful for testing and export purposes without cache history."""
|
||||
# use a best guess default for input_ids if not provided
|
||||
if input_ids is None:
|
||||
bs, seq_len = min(2, self.max_batch_size), min(4, self.max_seq_len)
|
||||
input_ids = torch.ones(bs, seq_len, dtype=torch.int).tolist()
|
||||
|
||||
# figure out page assignments
|
||||
pages_per_seq = [
|
||||
len(ids_one_seq) // self.page_size + (len(ids_one_seq) % self.page_size > 0)
|
||||
for ids_one_seq in input_ids
|
||||
]
|
||||
cache_loc = list(range(sum(pages_per_seq)))
|
||||
page_assignments = self._get_page_assignments(cache_loc, pages_per_seq)
|
||||
|
||||
self.nest_sequences(
|
||||
input_ids,
|
||||
position_ids, # will be auto-inferred if None
|
||||
input_pos=0, # no cache history
|
||||
page_assignments=page_assignments, # vanilla page assignments
|
||||
**extra_args,
|
||||
)
|
||||
|
||||
def set_max_num_tokens_sample(self) -> None:
|
||||
"""Set an example sequence with max_num_tokens."""
|
||||
# TODO (lucaslie): understand what this implies for extra arguments
|
||||
seq_len = self.max_num_tokens // self.max_batch_size
|
||||
input_ids = torch.ones(self.max_batch_size, seq_len, dtype=torch.int).tolist()
|
||||
self.set_example_sequence(input_ids)
|
||||
|
||||
def set_generate_only_batch(self) -> None:
|
||||
"""Set an example sequence for generate-only batch."""
|
||||
self.set_example_sequence([[1]] * self.max_batch_size)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the sequence information.
|
||||
@@ -358,205 +509,173 @@ class SequenceInfo:
|
||||
After reset the sequence information should correspond to a "generate-only" batch of
|
||||
sequences (b, s==1) without cache history.
|
||||
"""
|
||||
# reset input_pos
|
||||
self.input_pos.zero_()
|
||||
self.input_pos_host.zero_()
|
||||
self.set_generate_only_batch()
|
||||
|
||||
# set a dummy sequence corresponding to a generate-only batch (will also reset position_ids)
|
||||
self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True)
|
||||
@staticmethod
|
||||
def _flatten(nested_seqs: Sequence[Sequence[int]]) -> List[int]:
|
||||
return [
|
||||
val
|
||||
for lst in nested_seqs
|
||||
for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst)
|
||||
]
|
||||
|
||||
# reset cache information
|
||||
self.cache_loc[:] = torch.arange(self.num_pages, dtype=torch.int, device=self.device)
|
||||
self.pages_per_seq.fill_(1)
|
||||
|
||||
# let's also reset the input_ids and position_ids tensors to their max shapes (max_num_tokens)
|
||||
self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
|
||||
self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device)
|
||||
|
||||
def set_example_sequence(self) -> None:
|
||||
"""Set an example sequence useful for testing and export purposes."""
|
||||
self.reset()
|
||||
bs, seq_len = min(2, self.max_batch_size), min(4, self.max_seq_len)
|
||||
input_ids = torch.ones(
|
||||
bs,
|
||||
seq_len,
|
||||
dtype=torch.int,
|
||||
device=self.device,
|
||||
)
|
||||
self.nest_sequences(input_ids, allow_realloc=True)
|
||||
|
||||
# unflatten if we are not yet using cached+flattened attention
|
||||
if not self._is_cached_attn:
|
||||
self.input_ids = self.input_ids.view(bs, seq_len)
|
||||
self.position_ids = self.position_ids.view(bs, seq_len)
|
||||
|
||||
def _set_max_num_tokens_sample(self) -> None:
|
||||
"""Set an example sequence with max_num_tokens."""
|
||||
self.reset()
|
||||
seq_len = self.max_num_tokens // self.max_batch_size
|
||||
input_ids = torch.ones(
|
||||
self.max_batch_size,
|
||||
seq_len,
|
||||
dtype=torch.int,
|
||||
device=self.device,
|
||||
)
|
||||
self.pages_per_seq.fill_(seq_len // self.page_size)
|
||||
self.nest_sequences(input_ids, allow_realloc=True)
|
||||
|
||||
def set_generate_only_batch(self) -> None:
|
||||
"""Set an example sequence for generate-only batch.
|
||||
|
||||
NOTE: this batch is already formatted as [b, 1] in both original and in the cached attention
|
||||
mode. So we don't need to do anything mode-specific here.
|
||||
"""
|
||||
self.reset()
|
||||
self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True)
|
||||
|
||||
def maybe_reshape_for_generate(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
# use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len]
|
||||
if self.is_generate:
|
||||
return tensor.view(-1, 1, *tensor.shape[1:])
|
||||
else:
|
||||
return tensor.view(1, -1, *tensor.shape[1:])
|
||||
|
||||
@nvtx_range("ad_update_position_ids")
|
||||
def _update_position_ids(self, allow_realloc: bool = False) -> None:
|
||||
# set new position_ids from input_pos and seq_len
|
||||
# Make sure this is done on host to avoid host-device copies.
|
||||
with nvtx_range("prepare_list"):
|
||||
# Optimize for the common case where all seq_len values are 1 (generation mode)
|
||||
if torch.all(self.seq_len_host == 1):
|
||||
# Fast path: when all seq_len are 1, position_ids is just input_pos_host
|
||||
position_ids_host = (
|
||||
self.input_pos_host[: self.num_tokens].to(dtype=torch.long).pin_memory()
|
||||
)
|
||||
else:
|
||||
# General case - can probably be optimized too, but overall impact will be minor.
|
||||
position_ids_list = []
|
||||
for in_pos, seq_len in zip(self.input_pos_host, self.seq_len_host):
|
||||
position_ids_list.extend(range(in_pos, in_pos + seq_len))
|
||||
position_ids_host = torch.tensor(
|
||||
position_ids_list, dtype=torch.long, pin_memory=True
|
||||
)
|
||||
with nvtx_range("copy_to_device"):
|
||||
if allow_realloc:
|
||||
# Create a new position_ids tensor on the device
|
||||
self.position_ids = position_ids_host.to(self.device).clone()
|
||||
else:
|
||||
self.position_ids_full = self.position_ids_full.flatten()
|
||||
self.position_ids_full[: self.num_tokens].copy_(
|
||||
position_ids_host, non_blocking=True
|
||||
)
|
||||
with nvtx_range("maybe_reshape"):
|
||||
self.position_ids = self.maybe_reshape_for_generate(
|
||||
self.position_ids if allow_realloc else self.position_ids_full[: self.num_tokens]
|
||||
)
|
||||
|
||||
@nvtx_range("ad_update_sequence_lengths")
|
||||
def _update_sequence_lengths(self, sequence_lengths: List[int]) -> None:
|
||||
self._sequence_lengths = sequence_lengths
|
||||
self.num_tokens = sum(self._sequence_lengths)
|
||||
self.seq_len.zero_()
|
||||
self.seq_len_host = torch.tensor(self._sequence_lengths, pin_memory=True)
|
||||
self.seq_len[: len(self._sequence_lengths)].copy_(self.seq_len_host, non_blocking=True)
|
||||
|
||||
def update_input_ids_with_new_tokens(
|
||||
self, new_tokens: torch.Tensor, previous_batch_indices: List[int]
|
||||
def _store_arg(
|
||||
self,
|
||||
name: str,
|
||||
tnsr_like: List[int | float],
|
||||
reset: bool = False,
|
||||
) -> None:
|
||||
"""Update the input_ids with new tokens.
|
||||
"""Store the argument on the host and copy to the device in a non-blocking fashion.
|
||||
|
||||
This function will update the input_ids with new tokens and previous batch indices.
|
||||
Args:
|
||||
name: Name of the argument to store.
|
||||
tnsr_like: List of values to store.
|
||||
reset: Whether to reset the full tensor on the device to 0 before writing to it.
|
||||
"""
|
||||
# 1) flatten once
|
||||
original_shape = self.input_ids.shape
|
||||
flat = self.input_ids.flatten()
|
||||
with nvtx_range(f"ad_store_seq_info_arg_{name}"):
|
||||
tnsr_device = self._args_device[name]
|
||||
|
||||
# copy indices to the GPU
|
||||
host_idx = torch.tensor(previous_batch_indices, dtype=torch.int, pin_memory=True)
|
||||
idx = self.previous_batch_indices_cuda[: len(previous_batch_indices)]
|
||||
idx.copy_(host_idx, non_blocking=True)
|
||||
# store list object on the host
|
||||
self._args_host[name] = tnsr_like.copy()
|
||||
|
||||
# gather the exact values you want to write
|
||||
src = new_tokens[0, idx, 0]
|
||||
# pin the memory on the host
|
||||
tnsr_host = torch.tensor(tnsr_like, dtype=tnsr_device.dtype, pin_memory=True)
|
||||
|
||||
# in-place fill every slot where flat == -1 with src, in order
|
||||
flat.masked_scatter_(flat == -1, src)
|
||||
# reset/copy to the device in a non-blocking fashion
|
||||
if reset:
|
||||
tnsr_device.zero_()
|
||||
tnsr_device[: len(tnsr_like)].copy_(tnsr_host, non_blocking=True)
|
||||
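The pattern used by `_store_arg` above, staging values in pinned host memory and issuing a non-blocking host-to-device copy, can be sketched standalone as follows (the function name and buffer are illustrative only):

```python
import torch


def store_values(dst: torch.Tensor, values: list) -> None:
    """Copy `values` into the front of a preallocated device tensor without forcing a sync."""
    pin = torch.cuda.is_available()  # pinning only matters when a CUDA device is involved
    host = torch.tensor(values, dtype=dst.dtype, pin_memory=pin)
    dst[: len(values)].copy_(host, non_blocking=True)


buf = torch.zeros(8, dtype=torch.int32)  # in the real interface this buffer lives on "cuda"
store_values(buf, [3, 1, 4])
```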
|
||||
# 4) reshape back
|
||||
self.input_ids = flat.view(original_shape)
|
||||
def _store_extra_arg(
|
||||
self, name: str, tnsr_like: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]]
|
||||
) -> None:
|
||||
with nvtx_range(f"ad_store_extra_arg_{name}"):
|
||||
if tnsr_like is not None:
|
||||
if not isinstance(tnsr_like, torch.Tensor):
|
||||
if len(tnsr_like) > 1:
|
||||
tnsr_like = torch.cat(tnsr_like)
|
||||
else:
|
||||
tnsr_like = tnsr_like[0]
|
||||
self._extra_args[name] = tnsr_like.to(self.device, non_blocking=True)
|
||||
else:
|
||||
self._extra_args[name] = self._extra_none_inputs[name]
|
||||
|
||||
@nvtx_range("ad_nest_sequences")
|
||||
def nest_sequences(
|
||||
self, input_ids: Sequence[Sequence[int]], allow_realloc: bool = False
|
||||
self,
|
||||
input_ids: Sequence[Sequence[int]],
|
||||
position_ids: Optional[Sequence[Sequence[int]]] = None,
|
||||
input_pos: Optional[Union[Sequence[int], int]] = None,
|
||||
page_assignments: Optional[Sequence[Sequence[int]]] = None,
|
||||
**extra_args: Dict[str, Union[torch.Tensor, Sequence[torch.Tensor]]],
|
||||
) -> None:
|
||||
"""Create and store a flattened list of input_ids from the provided list of sequences.
|
||||
"""Create and store sequence information for the next forward pass.
|
||||
|
||||
When allow_realloc is True, the input_ids will be reallocated on the device.
|
||||
This i/f will also update any relevant sequence information.
|
||||
Args:
|
||||
input_ids: List of sequences of input_ids.
|
||||
position_ids: List of sequences of position_ids for each token.
|
||||
input_pos: Absolute starting position in the cache for each sequence.
|
||||
page_assignments: List of sequences of page assignments for each sequence.
|
||||
extra_args: Extra arguments to be stored in the interface.
|
||||
|
||||
This i/f will ensure that all sequence info args are updated accordingly.
|
||||
"""
|
||||
# set new sequence lengths
|
||||
self._update_sequence_lengths([len(ids) for ids in input_ids])
|
||||
### UPDATE METADATA ########################################################################
|
||||
# update metadata first since it's useful for other updates to have up-to-date information
|
||||
|
||||
# We'll preserve the dtype of the input_ids tensor if it is a tensor, otherwise we'll use int
|
||||
dtype = input_ids.dtype if isinstance(input_ids, torch.Tensor) else torch.int
|
||||
# set new input_ids as new tensor from flattened input_ids
|
||||
ids_list = [
|
||||
val
|
||||
for lst in input_ids
|
||||
for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst)
|
||||
]
|
||||
input_ids_host = torch.tensor(ids_list, dtype=dtype, pin_memory=True)
|
||||
# set new sequence lengths --> resetting the remaining entries to zero is important to help
|
||||
# us discern the actual number of sequences in the batch.
|
||||
self._store_arg("seq_len", [len(ids) for ids in input_ids], reset=True)
|
||||
|
||||
if allow_realloc:
|
||||
self.input_ids = input_ids_host.to(self.device).clone()
|
||||
else:
|
||||
self.input_ids_full = self.input_ids_full.flatten()
|
||||
self.input_ids_full[: self.num_tokens].copy_(input_ids_host, non_blocking=True)
|
||||
# check for updated input_pos (i.e. cache start position)
|
||||
if input_pos is not None:
|
||||
self._store_arg(
|
||||
"input_pos",
|
||||
[input_pos] * self.num_sequences if isinstance(input_pos, int) else input_pos,
|
||||
)
|
||||
|
||||
self.input_ids = self.maybe_reshape_for_generate(
|
||||
self.input_ids if allow_realloc else self.input_ids_full[: self.num_tokens]
|
||||
)
|
||||
# update position_ids
|
||||
self._update_position_ids(allow_realloc=allow_realloc)
|
||||
# check for updated page_assignments
|
||||
if page_assignments is not None:
|
||||
cache_loc, pages_per_seq = self._get_cache_locations_and_pages_per_sequence(
|
||||
page_assignments
|
||||
)
|
||||
self._store_arg("cache_loc", cache_loc)
|
||||
self._store_arg("pages_per_seq", pages_per_seq)
|
||||
|
||||
### UPDATE MAIN INPUTS #####################################################################
|
||||
# set new input_ids and make sure to flatten it
|
||||
self._store_arg("input_ids", self._flatten(input_ids))
|
||||
|
||||
# set new position_ids and make sure to flatten it
|
||||
if position_ids is None:
|
||||
position_ids = [
|
||||
[num for num in range(in_pos, in_pos + seq_len)]
|
||||
for in_pos, seq_len in zip(self.input_pos, self.seq_len)
|
||||
]
|
||||
self._store_arg("position_ids", self._flatten(position_ids))
|
||||
|
||||
### UPDATE EXTRA INPUTS ####################################################################
|
||||
# go through all extra tensor arguments and update them
|
||||
for name in self._extra_none_inputs.keys():
|
||||
self._store_extra_arg(name, extra_args.pop(name, None))
|
||||
assert not extra_args, f"Extra arguments {extra_args.keys()} not found"
|
||||
|
||||
@nvtx_range("ad_rescatter_input_ids")
|
||||
def rescatter_input_ids(
|
||||
self, ungathered_input_ids: torch.Tensor, gather_idx: List[int], scatter_ref: int
|
||||
):
|
||||
"""Re-scatter the provided ungathered input ids into the input_ids tensor.
|
||||
|
||||
Args:
|
||||
ungathered_input_ids: The input ids on the device from which to gather.
|
||||
gather_idx: The list of indices to gather from the ungathered_input_ids.
|
||||
scatter_ref: The reference index to scatter to in input_ids via masked scatter.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
This function will assume that we are in a generate-only batch.
|
||||
"""
|
||||
# store the new gather indices
|
||||
self._store_arg("_gather_idx", gather_idx)
|
||||
|
||||
# gather the provided input ids in a streaming fashion
|
||||
gather_ids_device = self._args_device["_gather_idx"][: len(gather_idx)]
|
||||
packed_input_ids = ungathered_input_ids[gather_ids_device]
|
||||
|
||||
# re-scatter the provided input ids into the input_ids tensor
|
||||
input_ids_device = self._args_device["input_ids"]
|
||||
input_ids_device.masked_scatter_(input_ids_device == scatter_ref, packed_input_ids)
|
||||
|
||||
@nvtx_range("ad_unnest_sequences")
|
||||
def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
|
||||
t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0)
|
||||
return list(torch.split(t_squeezed, self.sequence_lengths))
|
||||
return list(torch.split(t_squeezed, self.seq_len))
|
||||
|
||||
@nvtx_range("ad_update_pos")
|
||||
def update_pos(self, seq_len: Union[torch.Tensor, List[int], int], reset: bool = False) -> None:
|
||||
"""Update the starting position for each sequence in the cache.
|
||||
def add_extra_arg(
|
||||
self,
|
||||
name: str,
|
||||
none_input: torch.Tensor,
|
||||
dynamic_shape_callback: Optional[DynamicShapeCallback] = None,
|
||||
) -> None:
|
||||
"""Add an extra argument to the sequence info object.
|
||||
|
||||
If ``reset=True``, ``input_pos`` will be reset to zero before updating.
|
||||
Args:
|
||||
name: The name of the extra argument.
|
||||
none_input: None input value of the extra argument.
|
||||
dynamic_shape_callback: The callback to get the dynamic shape of the extra argument.
|
||||
|
||||
Note that the extra argument is expected to be a tensor.
|
||||
"""
|
||||
if not isinstance(seq_len, torch.Tensor):
|
||||
seq_len = torch.tensor(seq_len, dtype=torch.int, pin_memory=True)
|
||||
bs = len(seq_len) if seq_len.dim() > 0 else self.max_batch_size
|
||||
assert name not in self._named_args().keys(), f"Extra argument {name} already exists"
|
||||
|
||||
if reset:
|
||||
self.input_pos_host[:bs].copy_(seq_len, non_blocking=True)
|
||||
self._extra_args[name] = none_input.to(self.device)
|
||||
self._extra_none_inputs[name] = none_input.to(self.device)
|
||||
|
||||
if dynamic_shape_callback is None:
|
||||
self._extra_dynamic_shapes_callbacks[name] = lambda: {}
|
||||
else:
|
||||
self.input_pos_host[:bs] += seq_len
|
||||
|
||||
# update position_ids
|
||||
self._update_position_ids()
|
||||
self.input_pos[:bs].copy_(self.input_pos_host[:bs], non_blocking=True)
|
||||
|
||||
@nvtx_range("ad_assign_cache_loc")
|
||||
def assign_cache_loc(self, page_assignments: Sequence[Sequence[int]]) -> None:
|
||||
"""Set the cache location and pages_per_seq tensors from page assignments."""
|
||||
cache_loc_flat = torch.tensor(
|
||||
[p_idx for pages in page_assignments for p_idx in pages],
|
||||
dtype=torch.int,
|
||||
pin_memory=True,
|
||||
)
|
||||
self.cache_loc[: len(cache_loc_flat)].copy_(cache_loc_flat, non_blocking=True)
|
||||
|
||||
pages_per_seq = torch.tensor(
|
||||
[len(p) for p in page_assignments], dtype=torch.int, pin_memory=True
|
||||
)
|
||||
self.pages_per_seq[: len(pages_per_seq)].copy_(pages_per_seq, non_blocking=True)
|
||||
self._extra_dynamic_shapes_callbacks[name] = dynamic_shape_callback
|
||||
|
||||
|
||||
Constant = Union[int, float, str, None]
|
||||
|
||||
@@ -63,6 +63,7 @@ def scaled_dot_product_attention(
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """A carbon copy of torch.nn.functional.scaled_dot_product_attention as custom op.

@@ -78,12 +79,13 @@
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
        enable_gqa=enable_gqa,
    )


@scaled_dot_product_attention.register_fake
def scaled_dot_product_attention_fake(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False
):
    """Fake implementation of scaled_dot_product_attention."""
    return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()

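The fake (meta) implementation above now mirrors the real op's signature, including `enable_gqa`. A hedged, standalone sketch of the same registration pattern using `torch.library` directly is shown below; the op namespace and name are hypothetical and a recent PyTorch (with `enable_gqa` support) is assumed.

```python
import torch


@torch.library.custom_op("demo::sdpa", mutates_args=())
def sdpa_demo(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, enable_gqa: bool = False) -> torch.Tensor:
    # Real implementation simply dispatches to the eager SDPA kernel.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, enable_gqa=enable_gqa)


@sdpa_demo.register_fake
def _(q, k, v, enable_gqa=False):
    # Only output shape/dtype matter here; the signature must match the real op.
    return q.new_empty(*q.shape[:-1], v.shape[-1]).contiguous()
```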
@@ -18,7 +18,7 @@ from ..transformations._graph import (
)
from ..utils.logger import ad_logger
from ..utils.node_utils import is_op
from .interface import ExportPatchRegistry, apply_export_patches
from .interface import apply_export_patches

try:
    from modelopt.torch.quantization.utils import export_torch_mode as torch_export_context
@@ -229,20 +229,9 @@ def torch_export_to_gm(
        patch_list: Optional list of patch names to apply with default settings.
            Cannot be used together with patch_configs.
    """
    # Validate that both patch_configs and patch_list are not provided simultaneously
    if patch_configs is not None and patch_list is not None:
        raise ValueError("Cannot specify both patch_configs and patch_list. Use only one.")

    # Handle patch configuration
    if patch_list is not None:
        # Convert patch_list to patch_configs format
        patch_configs = {patch_name: {} for patch_name in patch_list}
    elif patch_configs is None:
        # Default patch configurations - apply all registered patches with default settings
        patch_configs = {patch_name: {} for patch_name in ExportPatchRegistry.list_patches()}

    # run export with patches and lifted to meta
    with apply_export_patches(patch_configs), lift_to_meta(model) as state_dict:
    with apply_export_patches(patch_configs, patch_list), lift_to_meta(model) as state_dict:
        # clean up args, kwargs and move to correct device
        args, kwargs = tree_to((args, kwargs or {}), device="meta")

@@ -5,7 +5,7 @@ This module defines the base classes and interfaces for all export patches.

from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Type, Union, final
from typing import Any, Callable, Dict, List, Optional, Type, Union, final

from pydantic import BaseModel, Field

@@ -183,6 +183,8 @@ class ExportPatchRegistry:
@classmethod
def get(cls, name: str) -> Type[BaseExportPatch]:
"""Get a patch class by name."""
if not cls.has(name):
raise ValueError(f"Unknown patch: {name}")
return cls._registry[name]

@classmethod
@@ -212,20 +214,29 @@ class ExportPatchRegistry:


@contextmanager
def apply_export_patches(patch_configs: Dict[str, Union[ExportPatchConfig, Dict[str, Any]]]):
def apply_export_patches(
patch_configs: Optional[Dict[str, Union[ExportPatchConfig, Dict[str, Any]]]] = None,
patch_list: Optional[List[str]] = None,
):
"""Context manager to apply multiple patches.

Args:
patch_configs: Dict mapping patch names to their configurations.
"""
patches = []
# Validate that both patch_configs and patch_list are not provided simultaneously
if patch_configs is not None and patch_list is not None:
raise ValueError("Cannot specify both patch_configs and patch_list. Use only one.")

# Handle patch configuration
if patch_list is not None:
# Convert patch_list to patch_configs format
patch_configs = {patch_name: {} for patch_name in patch_list}
elif patch_configs is None:
# Default patch configurations - apply all registered patches with default settings
patch_configs = {patch_name: {} for patch_name in ExportPatchRegistry.list_patches()}

# Create patch instances
for name, config in patch_configs.items():
if not ExportPatchRegistry.has(name):
raise ValueError(f"Unknown patch: {name}")
patch = ExportPatchRegistry.create_patch(name, config)
patches.append(patch)
patches = [ExportPatchRegistry.create_patch(k, conf) for k, conf in patch_configs.items()]

# Apply patches using nested context managers
if not patches:
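With the refactor above, apply_export_patches (and torch_export_to_gm via patch_list) accepts either patch_configs or patch_list, never both. A minimal usage sketch; "hf_llama4_vision" is one of the patch names registered later in this change:

with apply_export_patches(patch_list=["hf_llama4_vision"]):
    ...  # run export here; the patch is applied and reverted automatically

with apply_export_patches(patch_configs={"hf_llama4_vision": {}}):
    ...  # equivalent call with an explicit (empty) per-patch config

# Supplying both arguments raises:
# ValueError("Cannot specify both patch_configs and patch_list. Use only one.")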
@@ -1,19 +1,92 @@
import types
from typing import List, Optional
from typing import Any, Dict, List, Optional, Tuple

from ...executor.result import CompletionOutput
from ...inputs.registry import create_input_processor
from ...inputs.registry import DefaultInputProcessor, ExtraProcessedInputs
from ...llmapi.llm import RequestOutput, _TorchLLM
from ...llmapi.tokenizer import TokenizerBase, tokenizer_factory
from ...llmapi.tokenizer import TokenizerBase, TransformersTokenizer, tokenizer_factory
from ...sampling_params import SamplingParams
from .distributed import common as dist_ad
from .llm_args import LlmArgs
from .models.factory import ModelFactory
from .shim.demollm import DemoGenerationExecutor

class ADInputProcessor(DefaultInputProcessor):
"""Input processor for AutoDeploy backend.

This is a wrapper to either support standard TRT-LLM text-only input processing or use HF's
message chat template system to process multimodal inputs.
"""

def __init__(self, tokenizer: Optional[TokenizerBase], processor: Optional[Any] = None):
super().__init__(model_path=None, model_config=None, tokenizer=tokenizer)
# NOTE: HF's tokenizer/processor that has the apply_chat_template method
self.processor = processor or getattr(tokenizer, "tokenizer", None)

def __call__(
self, inputs: Dict[str, Any], sampling_params: SamplingParams
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
if self.processor is None:
raise ValueError("processor is required to tokenize inputs")

# construct kwargs to reflect DefaultInputProcessor
kwargs = {
"add_special_tokens": sampling_params.add_special_tokens,
}
if sampling_params.truncate_prompt_tokens is not None:
kwargs = {
"truncation": True,
"max_length": sampling_params.truncate_prompt_tokens,
}
# check for messages field and if yes, use the apply_chat_template method
if "messages" in inputs:
# TODO: we don't really need this but it makes for a good sanity check. Consider
# removing this in the future if we need to speed things up.
prompt = self.processor.apply_chat_template(
inputs["messages"],
add_generation_prompt=True,
tokenize=False,
)
inputs["prompt"] = prompt

all_args = self.processor.apply_chat_template(
inputs["messages"],
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=False, # there shouldn't be a need for padding ever...
return_attention_mask=False,
**kwargs,
)
# TODO: is there a more reliable way to avoid the attention_mask here?
all_args.pop("attention_mask", None)

# TODO: can we avoid the extra tolist() here eventually?
token_ids = all_args.pop("input_ids")
assert token_ids.shape[0] == 1, "messages should be unbatched at this point."
if all_args:
extra_processed_inputs = {"multimodal_data": all_args}
else:
extra_processed_inputs = None
return token_ids[0].tolist(), extra_processed_inputs
else:
token_ids = self.tokenizer.encode(inputs["prompt"], **kwargs)
return token_ids, None

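A hedged sketch of the two input shapes ADInputProcessor.__call__ distinguishes; only the dict keys ("prompt", "messages") and the chat-template structure come from the code above, everything else is illustrative:

text_only = {"prompt": "Hello, world!"}  # tokenized via self.tokenizer.encode

multimodal = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": pil_image},  # placeholder PIL.Image
                {"type": "text", "text": "Describe the image."},
            ],
        }
    ]
}
# the messages path goes through the HF processor's apply_chat_template; any pixel
# tensors it returns come back as extra_processed_inputs["multimodal_data"]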
class LLM(_TorchLLM):
"""LLM class is the main class for running an LLM model using AutoDeploy backend."""

args: LlmArgs
_factory: ModelFactory

@property
def factory(self) -> ModelFactory:
if not getattr(self, "_factory", None):
self._factory = self.args.create_factory()
return self._factory

def __init__(self, *args, **kwargs):
kwargs["backend"] = "_autodeploy"
@@ -23,16 +96,18 @@ class LLM(_TorchLLM):
if self.args.skip_tokenizer_init:
return None

factory = self.args.create_factory()
return tokenizer_factory(factory.init_tokenizer())
return tokenizer_factory(self.factory.init_tokenizer())

def _validate_args_for_torch_backend(self, kwargs: dict) -> None:
"""We don't need to validate args for AutoDeploy backend for now."""
pass

def _create_input_processor(self) -> ADInputProcessor:
return ADInputProcessor(self.tokenizer, self.factory.init_processor())

def _prefetch_model(self):
"""Prefetch the model for the LLM."""
self.args.create_factory().prefetch_checkpoint()
self.factory.prefetch_checkpoint()

def _build_model(self):
"""Build the model for the LLM.
@@ -47,6 +122,11 @@ class LLM(_TorchLLM):
# _autodeploy backend.
super()._build_model()

# now correct input processor
assert isinstance(self.input_processor, DefaultInputProcessor)
assert self.tokenizer is None or isinstance(self.tokenizer, TransformersTokenizer)
self.input_processor = self._create_input_processor()


class DemoLLM(LLM):
"""A simple LLM class to demo the LLM interface while debugging the e2e workflow.
@@ -63,7 +143,7 @@ class DemoLLM(LLM):
# prefetch model and load tokenizer
self._prefetch_model()
self._tokenizer = self._try_load_tokenizer()
self.input_processor = create_input_processor(None, self.tokenizer)
self.input_processor = self._create_input_processor()

# construct demo executor + engine
self._executor = DemoGenerationExecutor(

@@ -57,7 +57,7 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
description="The path to the model checkpoint or the model name from the Hugging Face Hub."
)

model_factory: Literal["AutoModelForCausalLM", "AutoModelForImageTextToText"] = Field(
model_factory: str = Field(
default="AutoModelForCausalLM",
description="The model factory to use for loading the model.",
)
@@ -3,13 +3,13 @@
import copy
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable, Dict, Optional, Type
from typing import Any, Callable, Dict, Optional, Tuple, Type

import torch
import torch.nn as nn
from torch._prims_common import DeviceLikeType

from ..custom_ops.attention_interface import CacheConfig
from ..custom_ops.attention_interface import CacheConfig, DynamicShapeCallback
from ..utils.logger import ad_logger


@@ -25,7 +25,8 @@ class ModelFactory(ABC):

NOTE: we make the assumption that the model can be prompted with a set of input_ids and
position_ids of shape (batch_size, seq_len) and will return a tensor of shape
(batch_size, seq_len, embedding_size).
(batch_size, seq_len, embedding_size) by default. Individual factories have the ability to
define additional optional inputs and their (dynamic) shapes.
"""

def __init__(
@@ -74,7 +75,7 @@ class ModelFactory(ABC):
.. code-block:: python

def forward(
self, input_ids: torch.Tensor, position_ids: torch.Tensor
self, input_ids: torch.Tensor, position_ids: torch.Tensor, *extra_args: torch.Tensor
) -> Sequence[torch.Tensor]: ...

``logits`` are assumed to be the first output of the model, i.e.,
@@ -87,6 +88,9 @@ class ModelFactory(ABC):
input_ids.shape == (batch_size, seq_len)
position_ids.shape == (batch_size, seq_len)
logits.shape == (batch_size, seq_len, vocab_size)

We allow for additional arguments to be passed to the model's forward function as defined by
the factory.
"""
# make sure model architecture is pre-fetched (no weights needed at this point)
skip_loading_weights = self.skip_loading_weights
@@ -127,6 +131,15 @@ class ModelFactory(ABC):
"""
return None

def init_processor(self) -> Optional[Any]:
"""Initialize the (multi-modal) processor for the model.

Returns:
The initialized processor for the model. If the processor is not available, then this
method should return None.
"""
return None

def prefetch_checkpoint(self, force: bool = False):
"""Try or skip prefetching the checkpoint for the model and tokenizer.

@@ -220,6 +233,35 @@ class ModelFactory(ABC):
device: The device to load the model on.
"""

def get_example_inputs(self) -> Dict[str, torch.Tensor]:
"""Return a dictionary of example inputs for the model.

This function can be overwritten by a factory when it requires a specific example input
in order to run through export.

Returns:
A dictionary of example inputs for the model where the key corresponds to the argument
name and the value corresponds to the example input.
"""
return {}

def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, Optional[DynamicShapeCallback]]]:
"""Return a dictionary of extra model inputs that behave like optional forward arguments.

Returns:
A dictionary of extra inputs for the model where the key corresponds to the argument
name and the value corresponds to a tuple of (none_input, dynamic_shape_callback):
- `none_input`: The none input value of the extra input indicating the tensor
value corresponding to the equivalent of the None input. `None` is not supported
as we require the input to be a tensor. Hence, this none_input acts as a
placeholder for the None input. We assume that the "optional" behavior of these
arguments can be represented via a placeholder tensor and an appropriate
check within the forward function using ``torch.cond``.
- `dynamic_shape_callback`: A function that returns the dynamic shape of the extra
input. Simply set to `None` if the extra input is not dynamic.
"""
return {}

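A sketch of what a subclass might return from get_extra_inputs; the tensor shape and dynamic-dimension bounds below are illustrative only (the image-text factory later in this change shows the concrete pattern):

def get_extra_inputs(self):
    # a zero-sized batch dimension acts as the tensor-valued "None" placeholder
    none_pixel_values = torch.zeros(0, 3, 336, 336)

    def _dynamic_shape():
        # illustrative bounds; real factories pick their own
        return {0: Dim("img_batch_size", max=10)}

    return {"pixel_values": (none_pixel_values, _dynamic_shape)}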
class ModelFactoryRegistry:
_registry: Dict[str, Type[ModelFactory]] = {}

@@ -3,7 +3,7 @@
import os
import types
from contextlib import contextmanager, nullcontext
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -11,11 +11,13 @@ from accelerate import init_empty_weights, load_checkpoint_in_model
from accelerate.utils import modeling
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.utils import HFValidationError, filter_repo_objects, validate_repo_id
from PIL import Image
from torch._prims_common import DeviceLikeType
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoProcessor,
AutoTokenizer,
PretrainedConfig,
)
@@ -26,7 +28,7 @@ from transformers.utils import (
WEIGHTS_NAME,
)

from ..custom_ops.attention_interface import CacheConfig
from ..custom_ops.attention_interface import CacheConfig, Dim, DynamicShapeCallback
from ..utils._config import deep_merge_dicts
from ..utils.logger import ad_logger
from .factory import ModelFactory, ModelFactoryRegistry, ShardingConfigSource
@@ -101,10 +103,6 @@ class AutoModelForCausalLMFactory(ModelFactory):
def autoconfig_from_pretrained(self):
return AutoConfig.from_pretrained

@property
def autotokenizer_from_pretrained(self):
return AutoTokenizer.from_pretrained

# TODO (@lucaslie): Do we ever want to switch to from_pretrained?
@property
def automodel_from_config(self):
@@ -114,7 +112,9 @@ class AutoModelForCausalLMFactory(ModelFactory):
def _simple_forward(model: nn.Module, input_ids: torch.Tensor, position_ids: torch.Tensor):
"""A simple forward pass for the model to functionalize the args.

This follows the standard function signature as expected by factory.py.
This follows the standard function signature as expected by factory.py. We do _not_ use the
model.forward method directly to create the patch. Instead we use the type of the model to
get the forward method to keep the patch composable with other forward patches.
"""
return type(model).forward(model, input_ids=input_ids, position_ids=position_ids)

@@ -158,11 +158,13 @@ class AutoModelForCausalLMFactory(ModelFactory):

with (init_empty_weights if device == "meta" else nullcontext)():
model = self.automodel_from_config(model_config, trust_remote_code=True)

# post-init --> this must be called explicitly for HF models the way we initialize them since
# this "gets lost" with the init_empty_weights context manager.
if hasattr(model, "post_init"):
model.post_init()
if device == "meta":
# post-init --> this must be called explicitly for HF models the way we initialize them
# since this "gets lost" with the init_empty_weights context manager.
if hasattr(model, "post_init"):
model.post_init()
else:
model.to(device)

# if present, initialize sharding config. We need head_dim for colwise sharding.
self._set_sharding_config(model.config)
@@ -211,7 +213,7 @@ class AutoModelForCausalLMFactory(ModelFactory):
"""Initialize the tokenizer—either a custom name or the model's default."""
if self.tokenizer is None:
return None
return self.autotokenizer_from_pretrained(self.tokenizer, **self.tokenizer_kwargs)
return AutoTokenizer.from_pretrained(self.tokenizer, **self.tokenizer_kwargs)

@staticmethod
def _get_ignore_patterns(repo_id: str, skip_prefetch_weights: bool) -> List[str]:
@@ -375,3 +377,102 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
@property
def automodel_from_config(self):
return AutoModelForImageTextToText.from_config

def init_tokenizer(self) -> Optional[Any]:
"""Initialize the tokenizer—either a custom name or the model's default."""
processor = self.init_processor()
if processor is None:
return None
return processor.tokenizer

def init_processor(self) -> Optional[Any]:
"""Initialize the processor for the model."""
if self.tokenizer is None:
return None
return AutoProcessor.from_pretrained(self.tokenizer, **self.tokenizer_kwargs)

@staticmethod
def _simple_forward(
model: nn.Module,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
pixel_values: torch.Tensor,
):
"""A simple forward pass for the model to functionalize the args.

This follows the standard function signature as expected by factory.py. We do _not_ use the
model.forward method directly to create the patch. Instead we use the type of the model to
get the forward method to keep the patch composable with other forward patches.
"""
return type(model).forward(
model,
input_ids=input_ids,
position_ids=position_ids,
pixel_values=pixel_values,
)

def get_example_inputs(self) -> Dict[str, torch.Tensor]:
"""Return a dictionary of example inputs for the model."""

def _prep_seq(text, img1, img2):
return [
{
"role": "user",
"content": [
{"type": "image", "image": img1},
{"type": "image", "image": img2},
{"type": "text", "text": text},
],
}
]

# Create a batch of conversations (batch_size = 2)
batch_messages = [
_prep_seq(
"Describe what you see in the two images and their differences.",
Image.new("RGB", (16, 16), color=(128, 128, 128)),
Image.new("RGB", (16, 16), color=(64, 64, 64)),
),
_prep_seq(
"What are the main differences between these two images?",
Image.new("RGB", (16, 16), color=(255, 0, 0)),
Image.new("RGB", (16, 16), color=(0, 255, 0)),
),
]

processor = AutoProcessor.from_pretrained(self.tokenizer, **self.tokenizer_kwargs)
inputs = processor.apply_chat_template(
batch_messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
return_attention_mask=False,
)

return {
"input_ids": inputs["input_ids"],
"pixel_values": inputs["pixel_values"],
}

def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, Optional[DynamicShapeCallback]]]:
"""Return a dictionary of extra inputs for the model.

Returns:
A dictionary of extra inputs for the model where the key corresponds to the argument
name and the value corresponds to a tuple of (example_input, dynamic_shape_callback).
The dynamic shape callback is a function that returns the dynamic shape of the extra
input. Simply set to `None` if the extra input is not dynamic.
"""

def _get_dynamic_shape():
return {
# TODO (lucaslie): how to set default values for dynamic shapes?
0: Dim("img_batch_size", max=10),
2: Dim("img_height", min=32, max=2048),
3: Dim("img_width", min=32, max=2048),
}

none_pixel_values = torch.zeros(0, 3, 336, 336)
return {"pixel_values": (none_pixel_values, _get_dynamic_shape)}

@@ -1,15 +1,13 @@
"""A patch to handle vision branch in Llama4ForConditionalGeneration."""

from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from _model_test_utils import _hf_model_dir_or_hub_id
from PIL import Image
from transformers import AutoConfig, AutoProcessor, Llama4ForConditionalGeneration
from transformers.models.llama4.modeling_llama4 import Llama4CausalLMOutputWithPast
from utils.llm_data import llm_models_root
from transformers import Llama4ForConditionalGeneration
from transformers.models.llama4.modeling_llama4 import Llama4CausalLMOutputWithPast, Llama4TextMoe

from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device
from ...export.interface import BaseExportPatch, ExportPatchRegistry


# Copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4/modeling_llama4.py#L1651
@@ -76,30 +74,34 @@ def _forward_with_cond(
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=None,
)
original_inputs_embeds_shape = inputs_embeds.shape

vision_flat = image_features.view(-1, image_features.size(-1))
projected_vision_flat = self.multi_modal_projector(vision_flat)
projected_vision_flat = self.multi_modal_projector(vision_flat).to(
inputs_embeds.device, inputs_embeds.dtype
)
# NOTE: get_placeholder_mask is not supported by torch.export due to numel check ###########
# special_image_mask = self.get_placeholder_mask(
# input_ids, inputs_embeds=inputs_embeds, image_features=projected_vision_flat
# )
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(
self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device
)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id

special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
final_mask = special_image_mask.to(inputs_embeds.device)
inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))
# n_image_tokens = special_image_mask.sum()
special_image_mask = (
special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
)
### END OF get_placeholder_mask ############################################################

final_mask_1d = final_mask[..., 0].reshape(-1)
# num_tokens_to_fill = final_mask_1d.sum()
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, projected_vision_flat)

# This condition statement breaks torch.export:
# TODO: sanity check on the inputs for this
# if num_tokens_to_fill != projected_vision_flat.size(0):
# raise ValueError(
# f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
# f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
# )

expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
inputs_embeds.masked_scatter_(expanded_mask, projected_vision_flat)

return inputs_embeds.view(original_inputs_embeds_shape)
return inputs_embeds

def _no_vision_branch(inputs_embeds, pixel_values, input_ids):
return inputs_embeds
@@ -132,7 +134,10 @@ def _forward_with_cond(

loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][
shift_attention_mask.to(logits.device) != 0
@@ -141,6 +146,7 @@ def _forward_with_cond(
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
@@ -161,81 +167,65 @@ def _forward_with_cond(
)

def test_build_run_llama4_vlm():
atol = 1e-3
rtol = 1e-3
@ExportPatchRegistry.register("hf_llama4_vision")
class Llama4VisionPatch(BaseExportPatch):
"""Patch for Llama4ForConditionalGeneration to make it compatible with torch.export.

model_id = _hf_model_dir_or_hub_id(
f"{llm_models_root()}/Llama-4-Scout-17B-16E-Instruct",
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
)
processor = AutoProcessor.from_pretrained(model_id)
This patch replaces the forward method of Llama4ForConditionalGeneration with
a version that uses the torch.cond to handle the optional vision branch.
"""

config = AutoConfig.from_pretrained(model_id)
config.text_config.num_hidden_layers = 2
config.text_config.intermediate_size = 64
config.text_config.intermediate_size_mlp = 128
config.vision_config.num_hidden_layers = 2

# The returned cache <class 'transformers.cache_utils.HybridChunkedCache'> breaks torch.export
config.text_config.use_cache = False

model = Llama4ForConditionalGeneration(config).eval().to("cuda").bfloat16()

img1 = Image.new("RGB", (16, 16), color=(128, 128, 128))
img2 = Image.new("RGB", (16, 16), color=(64, 64, 64))
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": img1},
{"type": "image", "image": img2},
{"type": "text", "text": "What's the difference?"},
],
},
]

inputs = (
processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
def _apply_patch(self):
"""Apply the Llama4 vision patch."""
# Store original forward method
self.original_values["Llama4ForConditionalGeneration.forward"] = (
Llama4ForConditionalGeneration.forward
)
.to(model.device)
.to(torch.bfloat16)
)

with torch.inference_mode():
# the original model queried with text-only
out_text_only = model(inputs["input_ids"], None, inputs["attention_mask"])
# Apply patch by replacing the forward method
Llama4ForConditionalGeneration.forward = _forward_with_cond

Llama4ForConditionalGeneration.forward = _forward_with_cond
def _revert_patch(self):
"""Revert the Llama4 vision patch."""
# Restore original forward method
Llama4ForConditionalGeneration.forward = self.original_values[
"Llama4ForConditionalGeneration.forward"
]

with torch.inference_mode():
out_real = model(inputs["input_ids"], inputs["pixel_values"], inputs["attention_mask"])
out_dummy = model(
inputs["input_ids"], torch.zeros_like(inputs["pixel_values"]), inputs["attention_mask"]
)
torch.testing.assert_close(out_dummy.logits, out_text_only.logits, rtol=rtol, atol=atol)

gm = torch_export_to_gm(
model,
(inputs["input_ids"], inputs["pixel_values"], inputs["attention_mask"]),
kwargs={},
)
move_to_device(gm, model.device)
def _moe_forward_with_transpose(self, hidden_states):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_scores, router_logits = self.router(hidden_states)
routed_in = hidden_states.repeat(router_scores.shape[1], 1)

with torch.inference_mode():
out_real_gm = gm(inputs["input_ids"], inputs["pixel_values"], inputs["attention_mask"])
torch.testing.assert_close(out_real.logits, out_real_gm.logits, rtol=rtol, atol=atol)
out_dummy_gm = gm(
inputs["input_ids"], torch.zeros_like(inputs["pixel_values"]), inputs["attention_mask"]
)
torch.testing.assert_close(out_dummy.logits, out_dummy_gm.logits, rtol=rtol, atol=atol)
torch.testing.assert_close(out_dummy_gm.logits, out_text_only.logits, rtol=rtol, atol=atol)
# BUG IN ORIGINAL CODE
# routed_in = routed_in * router_scores.reshape(-1, 1)
# END OF BUG IN ORIGINAL CODE

assert not torch.allclose(out_real.logits, out_dummy.logits, rtol=rtol, atol=atol), (
"Expected outputs to differ between text only input and text+image input"
)
# PATCH STARTED
routed_in = routed_in * router_scores.transpose(0, 1).reshape(-1, 1)
# PATCH ENDED

routed_out = self.experts(routed_in)
out = self.shared_expert(hidden_states)
out.add_(routed_out.reshape(router_scores.shape[1], -1, routed_out.shape[-1]).sum(dim=0))
return out, router_logits


# TODO: remove this patch once https://github.com/huggingface/transformers/pull/40609 is merged,
# gets released, and TRT-LLM updates to the relevant transformers version
@ExportPatchRegistry.register("hf_llama4_moe")
class Llama4MoEPatch(BaseExportPatch):
"""Patch for Llama4 MoE routing to fix its current accuracy issue."""

def _apply_patch(self):
"""Apply the Llama4 MoE routing patch."""
# Store original forward method
self.original_values["Llama4TextMoe.forward"] = Llama4TextMoe.forward

# Apply patch by replacing the forward method
Llama4TextMoe.forward = _moe_forward_with_transpose

def _revert_patch(self):
"""Revert the Llama4 MoE routing patch."""
Llama4TextMoe.forward = self.original_values["Llama4TextMoe.forward"]
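The routing fix above hinges on layout: hidden_states.repeat(num_experts, 1) stacks the token copies expert-major, so the router scores must be flattened expert-major as well before the elementwise multiply. A small sketch with illustrative sizes:

import torch

num_tokens, num_experts, hidden = 4, 3, 8
hidden_states = torch.randn(num_tokens, hidden)
router_scores = torch.rand(num_tokens, num_experts)  # [tokens, experts]

routed_in = hidden_states.repeat(num_experts, 1)  # [experts * tokens, hidden], expert-major

wrong = router_scores.reshape(-1, 1)                  # token-major flattening (the original bug)
right = router_scores.transpose(0, 1).reshape(-1, 1)  # expert-major flattening (the patch)

assert right.shape[0] == routed_in.shape[0]
# routed_in * right scales each token copy by the score of its own expert;
# routed_in * wrong would pair rows with scores of the wrong expert.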
@@ -1,6 +1,6 @@
from itertools import chain
from collections import defaultdict
from types import SimpleNamespace
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import torch
from torch._prims_common import DeviceLikeType
@@ -105,13 +105,19 @@ class ADEngine(ModelEngine):
max_batch_size=max_batch_size,
page_size=attn_page_size,
max_num_tokens=max_num_tokens,
device=device,
)

factory = ad_config.create_factory()

# pass in extra arguments defined by the model factory
for name, (none_input, dynamic_shape_callback) in factory.get_extra_inputs().items():
seq_info.add_extra_arg(name, none_input, dynamic_shape_callback)

# TODO (lucaslie): consider how we move args around InferenceOptimizer.__init__,
# ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm.

# construct inference optimizer
build_and_optimize = InferenceOptimizer(
factory=ad_config.create_factory(), ad_config=ad_config
)
build_and_optimize = InferenceOptimizer(factory=factory, ad_config=ad_config)

# construct engine
return cls(build_and_optimize, seq_info, device, max_beam_width)
@@ -176,7 +182,11 @@ class ADEngine(ModelEngine):
input_pos: List[int] = []
last_logit_only: List[bool] = []
page_assignments: List[List[int]] = []
previous_batch_indices: List[int] = []
flat_gather_idx: List[int] = []
extra_args: Dict[str, List[torch.Tensor]] = defaultdict(list)

dummy_token = -1

# look at context requests first
for request in context_requests:
# store input ids and pos of first token in sequence
@@ -186,6 +196,15 @@ class ADEngine(ModelEngine):
request.py_batch_idx = request.seq_slot
last_logit_only.append(True)

# get cache indices
cache_indices = kv_cache_manager.get_cache_indices(request)
page_assignments.append(cache_indices)

# store extra arguments
if request.py_multimodal_data is not None:
for k, v in request.py_multimodal_data.items():
extra_args[k].append(v)

# look at generate requests next
# TODO: we should also handle extend requests (for speculative decoding) here
for request in gen_requests:
@@ -194,30 +213,34 @@ class ADEngine(ModelEngine):
input_ids.append([request.get_token(0, request.get_num_tokens(0) - 1)])
input_pos.append(request.max_beam_num_tokens - 1)
else:
# insert a dummy token to indicate the new tokens
input_ids.append([-1])
previous_batch_indices.append(request.py_batch_idx)
input_ids.append([dummy_token])
input_pos.append(request.max_beam_num_tokens)
flat_gather_idx.append(request.py_batch_idx)

request.py_batch_idx = request.seq_slot

# return all logits
last_logit_only.append(False)

# extract cache information for all requests
for request in chain(context_requests, gen_requests):
# get cache indices
cache_indices = kv_cache_manager.get_cache_indices(request)
page_assignments.append(cache_indices)

# update the sequence info object now
si = self.cache_seq_interface.info
si.update_pos(input_pos, reset=True)
si.assign_cache_loc(page_assignments)
si.nest_sequences(input_ids)

self.cache_seq_interface.info.nest_sequences(
input_ids,
input_pos=input_pos,
page_assignments=page_assignments,
**extra_args,
)
# scatter the new tokens into the input_ids tensor if provided
if new_tokens is not None:
si.update_input_ids_with_new_tokens(new_tokens, previous_batch_indices)
self.cache_seq_interface.info.rescatter_input_ids(
ungathered_input_ids=new_tokens.flatten(), # ensure it's flattened
gather_idx=flat_gather_idx,
scatter_ref=dummy_token,
)

return last_logit_only

@nvtx_range("ad_compute_logits")
@@ -1,8 +1,9 @@
"""A demo LLM api to for debugging and testing purposes of e2e workflows."""

import gc
from collections import defaultdict
from queue import Empty
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.multiprocessing as mp
@@ -10,6 +11,7 @@ import torch.multiprocessing as mp
from ....executor import GenerationExecutor
from ....executor.request import GenerationRequest
from ....executor.result import CompletionOutput, GenerationResult
from ....inputs.multimodal import MultimodalParams
from ....sampling_params import SamplingParams
from ...pyexecutor.sampler import greedy_search_sampling_batch, top_k_sampling_batch
from ..distributed import common as dist_ad
@@ -34,8 +36,11 @@ class DemoEngine(ADEngine):
self.queue = mp.Queue()

@torch.inference_mode()
def __call__(self, requests: GenerationRequest) -> mp.Queue:
def __call__(
self, requests: GenerationRequest, multimodal_params: Optional[MultimodalParams]
) -> mp.Queue:
"""Generate tokens and put the results in a queue and return the queue."""
requests.multimodal_params = multimodal_params
output = self.generate_tokens_batched([requests])[0]
self.queue.put(output)
return self.queue
@@ -45,7 +50,7 @@ class DemoEngine(ADEngine):
self.queue.close()
self.queue.join_thread()

def _assign_pages(self) -> List[List[int]]:
def _assign_pages(self, total_lens: List[int]) -> List[List[int]]:
"""A simple heuristic to assign pages based on current sequence info.

In a nutshell, we will look at the following information to update the page assignments:
@@ -67,7 +72,6 @@ class DemoEngine(ADEngine):
unassigned page if needed.
"""
si = self.cache_seq_interface.info
total_lens = [s_l + i_p for s_l, i_p in zip(si.sequence_lengths, si.input_positions)]
page_assignments = si.page_assignments

free_pages = set(range(si.num_pages)) - {i for pages in page_assignments for i in pages}
@@ -76,7 +80,7 @@ class DemoEngine(ADEngine):
extra_tokens = t_l - len(pages) * si.page_size
num_extra_pages = (extra_tokens // si.page_size) + (extra_tokens > 0)
updated_assignments.append(pages + [free_pages.pop() for _ in range(num_extra_pages)])
si.assign_cache_loc(updated_assignments)
return updated_assignments
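The heuristic above grows each sequence's page list by (extra_tokens // page_size) + (extra_tokens > 0) pages; a quick worked example with assumed numbers:

page_size = 64
assigned_pages = 2          # pages the sequence already holds (128 tokens of capacity)
total_len = 150             # prompt + generated tokens so far

extra_tokens = total_len - assigned_pages * page_size               # 150 - 128 = 22
num_extra_pages = (extra_tokens // page_size) + (extra_tokens > 0)  # 0 + True -> 1
assert num_extra_pages == 1  # one extra page covers the 22 overflowing tokens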
def generate_tokens_batched(
self, requests: List[GenerationRequest]
@@ -91,10 +95,27 @@ class DemoEngine(ADEngine):
)
assert sampling_params.best_of == 1, "Best-of is not supported."

# set up sequence info object
# set up sequence info object for decode phase
sequence_info = self.cache_seq_interface.info

input_ids = []
total_lens = []
extra_args: Dict[str, List[torch.Tensor]] = defaultdict(list)

for request in requests:
total_lens.append(len(request.prompt_token_ids))
input_ids.append(request.prompt_token_ids)
if request.multimodal_params is not None:
for k, v in request.multimodal_params.multimodal_data.items():
extra_args[k].append(v)

sequence_info.reset()
sequence_info.nest_sequences([r.prompt_token_ids for r in requests])
sequence_info.nest_sequences(
input_ids=input_ids,
input_pos=0,
page_assignments=self._assign_pages(total_lens),
**extra_args,
)

# setup objects we want to track for the output
batch_size = sequence_info.num_sequences
@@ -105,18 +126,21 @@ class DemoEngine(ADEngine):
context_logits: Optional[List[torch.Tensor]] = None

def _generate_single_step(idx: int):
# assign pages
self._assign_pages()

# get the logits and then last token logits in each sequence ([b, 1, vocab_size])
logits = self._compute_logits()
logits_last = torch.stack([l_one_seq[-1] for l_one_seq in logits]).float().unsqueeze(1)

token_ids, _ = self._decode_tokens(logits_last, sampling_params) # [b,1]

# update sequence info accordingly for next step
sequence_info.update_pos(sequence_info.sequence_lengths)
sequence_info.nest_sequences(token_ids)
# update sequence info accordingly for next step (generate phase)
input_pos_next = sequence_info.input_pos
seq_lens_current = sequence_info.seq_len
input_pos_next = [ip + sl for ip, sl in zip(input_pos_next, seq_lens_current)]
total_lens_next = [ip + len(t_ids) for ip, t_ids in zip(input_pos_next, token_ids)]
sequence_info.nest_sequences(
token_ids,
input_pos=input_pos_next,
page_assignments=self._assign_pages(total_lens_next),
)

# nest new tokens and run stop check
for b, (new_tokens_b, new_id) in enumerate(zip(new_tokens, token_ids)):
@@ -255,6 +279,7 @@ class DemoGenerationExecutor(GenerationExecutor):
def _unpack(inputs) -> GenerationRequest:
args, kwargs = inputs # unpack the inputs
request: GenerationRequest = args[0]
request.multimodal_params: Optional[MultimodalParams] = args[1]
return request

engine = DemoEngine.build_from_config(**engine_kwargs)
@@ -309,8 +334,11 @@ class DemoGenerationExecutor(GenerationExecutor):
request.set_id(client_id)

# submit request to our demo engine and store results
# NOTE: when returning from this function, the reference request.multimodal_params will
# be cleared immediately. So we pass it in explicitly to maintain a reference even when
# requests get submitted asynchronously.
result = GenerationResult(request)
result.queue = self.engine_executor(request)
result.queue = self.engine_executor(request, request.multimodal_params)

return result

@@ -63,7 +63,7 @@ class ExportToGM(BaseTransform):
model = gm.get_submodule("factory_model")

# set the example sequence
cm.info.set_example_sequence()
cm.info.set_example_sequence(**factory.get_example_inputs())

# export the model to a graph module
gm = torch_export_to_gm(

@@ -44,9 +44,6 @@ class UpdateInOutNodes(BaseTransform):
# loop through nodes to get input, output, and get_attr nodes
input_nodes, output_nodes = get_all_input_output_nodes(gm.graph)

# we only expect one input node
assert len(input_nodes) == 2, "Expected exactly two input nodes (input_ids, position_ids)."

# NOTE: for now, we wanna make sure we *only* return the final output and no hidden states.
# Later on, we can revisit how to support returning hidden states.
assert len(output_nodes) == 1, "Expected exactly one output node!"
@@ -117,16 +114,17 @@ class InsertCachedAttention(BaseTransform):

# retrieve input nodes
input_nodes, _ = get_all_input_output_nodes(gm.graph)
input_nodes_mapping = {n.target: n for n in input_nodes}

# filtered and sorted for SequenceInfo arguments + constants (input_ids, position_ids, etc.)
inputs_from_info = [input_nodes_mapping[k] for k in cm.info.named_standard_args.keys()]
constants_from_info = cm.info.const_args_for_prepare_metadata

# insert metadata computation and extract each argument as a node
get_metadata, num_metadata = attn_descriptor.get_prepare_metadata_op()
with graph.inserting_before(input_nodes[-1].next):
ret_node = graph.call_function(
get_metadata,
args=(
*input_nodes,
cm.info.page_size,
),
get_metadata, args=(*inputs_from_info, *constants_from_info)
)
metadata_nodes = [
graph.call_function(operator.getitem, args=(ret_node, idx))
@@ -244,7 +242,7 @@ class ResizeKVCache(BaseTransform):

try:
# Let's run a forward pass to get the memory usage
cm.info._set_max_num_tokens_sample()
cm.info.set_max_num_tokens_sample()
free_mem_pre, _ = _get_mem_info_in_mb()
ad_logger.info(f"Free memory before forward pass (MB): {free_mem_pre}")

@@ -18,8 +18,7 @@ from .._utils import (KVCacheEventSerializer, global_mpi_rank, global_mpi_size,
mpi_comm, mpi_rank, nvtx_range_debug)
from ..bindings import executor as tllm
from ..builder import ConfigEncoder, Engine, EngineConfig
from ..llmapi.llm_args import (BaseLlmArgs, KvCacheConnectorConfig,
PybindMirror, TorchLlmArgs)
from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig, PybindMirror
from ..llmapi.mpi_session import set_mpi_session_cpp
from ..llmapi.tokenizer import TokenizerBase
from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
@@ -86,7 +85,9 @@ class GenerationExecutorWorker(GenerationExecutor):
self._await_response_helper = AwaitResponseHelper(
self) # TODO: make it weakref
self._executor_config = executor_config
self._is_pytorch_backend = llm_args is not None and llm_args.backend == "pytorch"
self._is_pytorch_backend = llm_args is not None and llm_args.backend in [
"pytorch", "_autodeploy"
]
self.llm_args = llm_args

if not self._is_pytorch_backend and kv_connector_config is not None:
@@ -468,7 +469,6 @@ class GenerationExecutorWorker(GenerationExecutor):
)

if self._is_pytorch_backend:
assert isinstance(self.llm_args, TorchLlmArgs)
if not self.llm_args.disable_overlap_scheduler:
is_disaggregated = self.engine.kv_cache_transceiver is not None
if is_disaggregated and (
@@ -2317,6 +2317,7 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
MODEL_NAME = "Qwen3/Qwen3-8B"

@skip_pre_hopper
@pytest.mark.skip(reason="https://nvbugs/5505402")
@pytest.mark.parametrize(
"tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler",
[(1, 1, 1, False, True, True)],

@@ -37,6 +37,7 @@ l0_b200:
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass]
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm]
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-triton]
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency] # nvbugs 5505402
- disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0] # nvbugs 5300551
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B]
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-FP8-llama-3.1-model/Llama-3.1-8B-Instruct-FP8]
@ -86,33 +86,12 @@ l0_dgx_h100:
|
||||
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype1]
|
||||
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.W4A8_CUSTOM-dtype0]
|
||||
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.W4A8_CUSTOM-dtype1]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True]
|
||||
|
||||
@ -50,6 +50,27 @@ l0_dgx_h200:
|
||||
stage: post_merge
|
||||
backend: pytorch
|
||||
tests:
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=0-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]

@ -1,5 +1,5 @@
import copy
from typing import Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Sequence

import numpy as np
import torch
@ -39,19 +39,45 @@ class FakeFactory(ModelFactory):


class SequenceEmbeddingInfo(SequenceInfo):
hidden_size: int
dtype: torch.dtype
"""A sequence info object for testing that replaces the input_ids with an embedding tensor.

def set_example_sequence(self) -> None:
super().set_example_sequence()
# set input ids to a 3D tensor (actually input embeddings)
self.input_ids = torch.rand(
*self.input_ids.shape,
This is useful to run tests without the tokenizer in the loop.
"""

def _add_hidden_dim(self, input_ids: Sequence[Sequence[Any]]) -> torch.Tensor:
return torch.rand(
*input_ids.shape,
self.hidden_size,
device=self.input_ids.device,
device=self.device,
dtype=self.dtype,
)

def __init__(self, *args, hidden_size: int, dtype: torch.dtype, **kwargs):
self._initialized = False
super().__init__(*args, **kwargs)

# overwrite input_ids with an embedding tensor and run reset again
self.hidden_size = hidden_size
self.dtype = dtype
self._args_device["input_ids"] = self._add_hidden_dim(self._args_device["input_ids"])
self._args_host["input_ids"] = self._args_device["input_ids"].cpu()
self._initialized = True
self.reset()

def nest_sequences(self, input_ids: Sequence[Sequence[Any]], *args, **kwargs) -> None:
# convert input_ids to an embedding tensor if needed
if not (isinstance(input_ids, torch.Tensor) and input_ids.ndim == 3) and self._initialized:
# first convert to a list of tensors
input_embeds = [
torch.tensor(ids, device=self.device, dtype=self.dtype) for ids in input_ids
]
# then add the hidden dimension to every tensor
input_embeds = [self._add_hidden_dim(ids) for ids in input_embeds]
else:
input_embeds = input_ids

super().nest_sequences(input_embeds, *args, **kwargs)


def count_parameters(model: torch.nn.Module):
for n, p in model.named_parameters():

@ -400,9 +400,6 @@ _SMALL_MODEL_CONFIGS = {
},
"vision_config": {
"num_hidden_layers": 1,
"hidden_size": 64,
"intermediate_size": 64,
"num_attention_heads": 2,
},
},
},

@ -0,0 +1,108 @@
import torch
from _model_test_utils import get_small_model_config
from build_and_run_ad import ExperimentConfig
from PIL import Image

from tensorrt_llm._torch.auto_deploy import LlmArgs
from tensorrt_llm._torch.auto_deploy.export import apply_export_patches, torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device


def test_build_run_llama4_vlm():
atol = 1e-3
rtol = 1e-3

experiment_config = get_small_model_config("meta-llama/Llama-4-Scout-17B-16E-Instruct")
experiment_config["args"]["model_kwargs"]["_attn_implementation"] = "eager"
experiment_config = ExperimentConfig(**experiment_config)
llm_args: LlmArgs = experiment_config.args

factory = llm_args.create_factory()
model = factory.build_model("cuda")
processor = factory.init_processor()

img1 = Image.new("RGB", (16, 16), color=(128, 128, 128))
img2 = Image.new("RGB", (16, 16), color=(64, 64, 64))
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": img1},
{"type": "image", "image": img2},
{
"type": "text",
"text": "Describe what you see in the two images and their differences.",
},
],
},
]

inputs = (
processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
.to(model.device)
.to(torch.bfloat16)
)

# get relevant inputs
input_ids = inputs["input_ids"]
position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).repeat(
input_ids.shape[0], 1
)
pixel_values = inputs["pixel_values"]

def _run_with_and_without_image(model, use_patch=True):
with apply_export_patches(patch_list=["hf_llama4_vision"] if use_patch else []):
with torch.inference_mode():
out_no_images = model(
input_ids,
position_ids,
torch.zeros_like(pixel_values) if use_patch else None,
)
out_with_images = model(
input_ids,
position_ids,
pixel_values,
)
return {"no_images": out_no_images.logits, "with_images": out_with_images.logits}

# Get output pre-patch
out_original = _run_with_and_without_image(model, use_patch=False)

# Get output post-patch
outputs_for_comparison = {}
outputs_for_comparison["model_with_patch"] = _run_with_and_without_image(model)

# Export to GM
gm = torch_export_to_gm(
model,
args=(input_ids, position_ids, pixel_values),
patch_list=[
"transformers_sdpa_mask",
"autocast_noop",
"torch_where",
"tensor_meta_device",
"sdpa_kernel_noop",
"sdpa",
"hf_llama4_vision",
],
)
move_to_device(gm, model.device)

# Get the output post export
outputs_for_comparison["gm"] = _run_with_and_without_image(gm)

# Run comparisons to out_original with no patch now...
for comp, outs in outputs_for_comparison.items():
torch.testing.assert_close(
outs,
out_original,
rtol=rtol,
atol=atol,
msg=lambda m: f"Comparison for {comp} failed:\n{m}",
)
@ -71,7 +71,6 @@ def test_engine(engine_cls: Type[ADEngine], attn_backend: str, attn_page_size: i
input_ids = [torch.tensor([0, 1, 2], device=device)]
sequence_info.reset()
sequence_info.nest_sequences(input_ids)
engine.cache_seq_interface.info.sync(sequence_info)
logits = engine._compute_logits()
logits = torch.stack(logits)
assert logits is not None, "Logits are None"
@ -106,7 +105,6 @@ def test_demo_engine_sampling(attn_page_size: int):
input_ids = [torch.tensor([1, 2, 3, 4], device=device)]
sequence_info.reset()
sequence_info.nest_sequences(input_ids)
engine.cache_seq_interface.info.sync(sequence_info)
logits = engine._compute_logits()
logits = torch.stack(logits)


@ -92,7 +92,7 @@ def test_config_params():
@patch("tensorrt_llm._torch.auto_deploy.llm.DemoGenerationExecutor")
@patch("tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface.SequenceInfo")
@patch("tensorrt_llm._torch.auto_deploy.shim.demollm.dist_ad.initialize_or_skip")
@patch("tensorrt_llm._torch.auto_deploy.llm.create_input_processor")
@patch("tensorrt_llm._torch.auto_deploy.llm.LLM._create_input_processor")
@patch("tensorrt_llm._torch.auto_deploy.llm.LLM._build_model")
def test_config_flow(
mock_build_model,
@ -147,13 +147,6 @@ def test_config_flow(
pass


def test_invalid_model_factory():
"""Test behavior with invalid model factory."""
# Pydantic validates Literal types at runtime, so this should raise ValidationError
with pytest.raises(Exception): # Could be ValidationError or ValueError
LlmArgs(model="test-model", model_factory="InvalidFactory")


@pytest.mark.parametrize(
"parallel_field,invalid_value",
[

@ -68,7 +68,7 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
get_small_model_config(
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
attn_backend="flashinfer",
compile_backend="torch-opt",
compile_backend="torch-simple",
),
get_small_model_config(
"deepseek-ai/DeepSeek-V3",

@ -54,17 +54,19 @@ class GQAWithSdpa(GQA):
self.num_key_value_groups = None

@torch.no_grad()
def forward(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(
self, input_ids: torch.Tensor, position_ids: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Forward pass with input tokens and optional position ids.
position_ids parameter added to match expected interface in kvcache.py
"""
b, s, _ = x.shape
b, s, _ = input_ids.shape

# Project input to q, k, v representations
q = self.q_proj(x) # [b, s, n*h_d]
k = self.k_proj(x) # [b, s, n_kv*h_d]
v = self.v_proj(x) # [b, s, n_kv*h_d]
q = self.q_proj(input_ids) # [b, s, n*h_d]
k = self.k_proj(input_ids) # [b, s, n_kv*h_d]
v = self.v_proj(input_ids) # [b, s, n_kv*h_d]

# Reshape to [b, s, n, h_d]
q = q.view(b, s, self.num_heads, self.head_dim)
@ -126,9 +128,9 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config):
ci = SequenceEmbeddingInfo(
max_seq_len=max_position_embeddings,
max_batch_size=batch_size,
hidden_size=hidden_size,
dtype=dtype,
)
ci.hidden_size = hidden_size
ci.dtype = dtype
cm = CachedSequenceInterface(sequence_info=ci, device="cuda")

# Create the model with SDPA and wrap it in a fake factory
@ -180,9 +182,9 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config):
cm.initialize_caches()

# Helper function to call the model with proper sequence nesting
def _call_and_unnest(x):
def _call_and_unnest(x, input_pos):
# Use nest_sequences to properly set input_ids and automatically update position_ids
cm.info.nest_sequences(x, allow_realloc=True)
cm.info.nest_sequences(x, input_pos=input_pos)

# Use the cm.args as is - it already contains the correct position_ids
y = gm(*cm.args)
@ -192,31 +194,25 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config):

# Test 1: Regular inference (all tokens at once)
cm.info.reset()
y_no_cache = _call_and_unnest(x)
y_no_cache = _call_and_unnest(x, 0)
assert all_close(y_model, y_no_cache, atol=atol, rtol=rtol)

# Test 2: Autoregressive inference with KV cache
cm.info.reset()
y_with_cache = torch.empty_like(y_model)
for i in range(x.shape[1]):
for i_p in range(x.shape[1]):
# Just pass the current token
y_with_cache[:, i : i + 1] = _call_and_unnest(x[:, i : i + 1])
# Update position for next token
cm.info.update_pos(1) # This automatically updates position_ids too
y_with_cache[:, i_p : i_p + 1] = _call_and_unnest(x[:, i_p : i_p + 1], i_p)
assert all_close(y_model, y_with_cache, atol=atol, rtol=rtol)

# Test 3: Cache continuation after random tokens
cm.info.update_pos(-num_reset_steps) # Rewind position
for i in range(num_random_steps):
_call_and_unnest(torch.rand_like(x[:, :1]))
cm.info.update_pos(1)
for i_p in range(x.shape[1] - num_reset_steps, x.shape[1] - num_reset_steps + num_random_steps):
_call_and_unnest(torch.rand_like(x[:, :1]), i_p)

# Continue inference from previous context
cm.info.reset()
cm.info.update_pos(x.shape[1] - num_reset_steps)
for i in range(x.shape[1] - num_reset_steps, x.shape[1]):
y_with_cache[:, i : i + 1] = _call_and_unnest(x[:, i : i + 1])
cm.info.update_pos(1)
for i_p in range(x.shape[1] - num_reset_steps, x.shape[1]):
y_with_cache[:, i_p : i_p + 1] = _call_and_unnest(x[:, i_p : i_p + 1], i_p)
assert all_close(y_model, y_with_cache, atol=atol, rtol=rtol)

# Test 4: Exportability of the transformed model

@ -16,6 +16,7 @@ from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))


@pytest.mark.skip(reason="https://nvbugs/5461761")
@pytest.mark.parametrize(
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill",
[