Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

Merge branch 'main' into tools/wave-analyzer

Commit e2f48eb400
@@ -1460,6 +1460,11 @@ repos:
entry: ./scripts/format_test_list.py
language: script
files: tests/integration/test_lists/.*\.txt$
- id: waive list check
name: Checks for duplicated test items in waives.txt
entry: ./scripts/check_test_list.py --check-duplicate-waives
language: script
pass_filenames: false
- id: DCO check
name: Checks the commit message for a developer certificate of origin signature
entry: ./scripts/dco_check.py
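The new `waive list check` hook runs `scripts/check_test_list.py --check-duplicate-waives` with `pass_filenames: false`, i.e. the script is invoked once with no file arguments. A minimal sketch of reproducing that invocation outside pre-commit (assuming you run it from the repository root; the script path comes from the `entry:` line above):

```python
import subprocess
import sys

# Run the duplicate-waive check the way the pre-commit hook would:
# no filenames are passed because the hook sets `pass_filenames: false`.
result = subprocess.run(
    [sys.executable, "scripts/check_test_list.py", "--check-duplicate-waives"],
    capture_output=True,
    text=True,
)
print(result.stdout)
# A non-zero exit status is what makes pre-commit fail the hook.
sys.exit(result.returncode)
```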
@@ -536,8 +536,8 @@ void help()
"- \"num_tokens\" - The total number of tokens to benchmark\n"
"- \"bias\" - If bias should be used, 0 = no bias, 1 = bias\n"
"- \"do_final_scale\" - If final scales should be applied, 0 = no scale, 1 = scale\n"
"- \"act_fn\" - The activation function to use, 0 = identity, 1 = relu, 2 = gelu, 3 = silu, 4 = geglu, 5 = "
"swiglu\n"
"- \"act_fn\" - The activation function to use, 1 = identity, 2 = gelu, 3 = relu, 4 = silu, 5 = swiglu, 6 = "
"geglu, 7 = swiglu_bias, 8 = relu2\n"
"- \"tactic_id1, tactic_id2\"\n"
"The config for the CUTLASS GEMM. tactic_idX sets the tactic for the corresponding GEMM"
"Valid tactics are:\n"
@ -31,6 +31,7 @@ namespace
|
||||
{
|
||||
using ElemCopyType = uint4;
|
||||
using SFCopyType = uint32_t;
|
||||
using ActivationType = tensorrt_llm::kernels::cutlass_kernels::ActivationType;
|
||||
|
||||
template <typename T>
|
||||
auto constexpr bitsPerElem()
|
||||
@ -385,23 +386,43 @@ void moeActivation(InputType const* input, OutputType* output, float const* glob
|
||||
int32_t const blocks = std::min(smCount, max_num_permuted_tokens);
|
||||
int32_t const threads = kThreadsPerBlock;
|
||||
|
||||
auto kernel_array
|
||||
= std::array{&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
|
||||
cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>,
|
||||
&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
|
||||
cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::ReLu>, kThreadsPerBlock>,
|
||||
&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
|
||||
cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>,
|
||||
&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
|
||||
cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>,
|
||||
&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
|
||||
cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>,
|
||||
&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize, cutlass_kernels::SwigluBiasAdaptor,
|
||||
kThreadsPerBlock>,
|
||||
&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
|
||||
cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::Identity>, kThreadsPerBlock>};
|
||||
|
||||
auto kernel = kernel_array[static_cast<int32_t>(activation_params.activation_type)];
|
||||
auto get_act_kernel = [](ActivationType activation_type) -> void (*)(InputType const* input, OutputType* output,
|
||||
float const* global_sf, SFType* output_sf,
|
||||
int32_t const* tile_idx_to_mn_limit,
|
||||
int32_t const* num_non_exiting_tiles,
|
||||
int32_t const interm_size, int32_t const tile_size)
|
||||
{
|
||||
switch (activation_type)
|
||||
{
|
||||
case ActivationType::Identity:
|
||||
return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
|
||||
cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::Identity>, kThreadsPerBlock>;
|
||||
case ActivationType::Gelu:
|
||||
return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
|
||||
cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>;
|
||||
case ActivationType::Geglu:
|
||||
return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
|
||||
cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>;
|
||||
case ActivationType::Relu:
|
||||
return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
|
||||
cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::ReLu>, kThreadsPerBlock>;
|
||||
case ActivationType::Silu:
|
||||
return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
|
||||
cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>;
|
||||
case ActivationType::Swiglu:
|
||||
return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
|
||||
cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>;
|
||||
case ActivationType::SwigluBias:
|
||||
return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize, cutlass_kernels::SwigluBiasAdaptor,
|
||||
kThreadsPerBlock>;
|
||||
case ActivationType::Relu2:
|
||||
// Unsupported activation type
|
||||
break;
|
||||
}
|
||||
TLLM_CHECK_WITH_INFO(false, "Unsupported activation type: %d", int(activation_type));
|
||||
return nullptr;
|
||||
};
|
||||
auto kernel = get_act_kernel(activation_params.activation_type);
|
||||
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = blocks;
|
||||
|
||||
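The hunk above replaces an array of kernel pointers, indexed positionally by `activation_params.activation_type`, with an explicit `switch` over `ActivationType`. The motivation (an assumption on my part, but consistent with the enum renumbering later in this diff) is that positional indexing silently breaks when the enum values change, while a switch keyed on the enumerators does not. A small Python sketch of the same idea, with stand-in handler functions:

```python
from enum import IntEnum


class ActivationType(IntEnum):
    # Mirrors the new numbering introduced later in this diff.
    InvalidType = 0
    Identity = 1
    Gelu = 2
    Relu = 3


def identity(x):
    return x


def relu(x):
    return max(x, 0.0)


# The old code indexed a positional array with int(activation_type); that is only
# correct while the array order tracks the enum's integer values exactly.
# Dispatching on the enumerators themselves (like the new switch) cannot be broken
# by renumbering, and unsupported values fail loudly instead of picking the wrong entry.
kernels_by_enum = {
    ActivationType.Identity: identity,
    ActivationType.Relu: relu,
}


def get_kernel(act: ActivationType):
    try:
        return kernels_by_enum[act]
    except KeyError:
        raise ValueError(f"Unsupported activation type: {act!r}")


print(get_kernel(ActivationType.Relu)(-2.0))  # 0.0
```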
@@ -23,15 +23,15 @@ namespace tensorrt_llm::kernels::cutlass_kernels
// cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu::doActivationKernel().
enum class ActivationType
{
    Gelu = 0,
    Relu,
    Silu,
    Swiglu,
    Geglu,
    SwigluBias,
    Identity,
    Relu2,
    InvalidType
    InvalidType = 0,
    Identity = 1,
    Gelu = 2,
    Relu = 3,
    Silu = 4,
    Swiglu = 5,
    Geglu = 6,
    SwigluBias = 7,
    Relu2 = 8,
};

} // namespace tensorrt_llm::kernels::cutlass_kernels
@ -2244,29 +2244,39 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
|
||||
{
|
||||
// IMPORTANT: Keep the order of the activation functions in the same order as the ActivationType enum in
|
||||
// common.h
|
||||
auto fn = [&](auto block_scaling_type)
|
||||
auto fn
|
||||
= [&](auto block_scaling_type) -> void (*)(T*, GemmOutputType const*, float const*, ScaleBiasType const*,
|
||||
bool, int64_t const*, int, int64_t, float const*, bool,
|
||||
TmaWarpSpecializedGroupedGemmInput::ElementSF*, ActivationParams)
|
||||
{
|
||||
auto fn_list = std::array{
|
||||
&doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::GELU>,
|
||||
decltype(block_scaling_type)::value>, // Gelu
|
||||
&doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
|
||||
decltype(block_scaling_type)::value>, // Relu
|
||||
&doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
|
||||
decltype(block_scaling_type)::value>, // Silu
|
||||
&doActivationKernel<T, GemmOutputType, ScaleBiasType, GLUAdaptor<cutlass::epilogue::thread::SiLu>,
|
||||
decltype(block_scaling_type)::value>, // Swiglu
|
||||
&doActivationKernel<T, GemmOutputType, ScaleBiasType, GLUAdaptor<cutlass::epilogue::thread::GELU>,
|
||||
decltype(block_scaling_type)::value>, // Geglu
|
||||
&doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
|
||||
decltype(block_scaling_type)::value>, // SwigluBias
|
||||
&doActivationKernel<T, GemmOutputType, ScaleBiasType,
|
||||
IdentityAdaptor<cutlass::epilogue::thread::Identity>,
|
||||
decltype(block_scaling_type)::value>, // Identity
|
||||
&doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::Relu2>,
|
||||
decltype(block_scaling_type)::value> // Relu2
|
||||
|
||||
};
|
||||
return fn_list[static_cast<int>(activation_type.activation_type)];
|
||||
switch (activation_type.activation_type)
|
||||
{
|
||||
case ActivationType::Identity:
|
||||
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
|
||||
IdentityAdaptor<cutlass::epilogue::thread::Identity>, decltype(block_scaling_type)::value>;
|
||||
case ActivationType::Gelu:
|
||||
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
|
||||
IdentityAdaptor<cutlass::epilogue::thread::GELU>, decltype(block_scaling_type)::value>;
|
||||
case ActivationType::Relu:
|
||||
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
|
||||
IdentityAdaptor<cutlass::epilogue::thread::ReLu>, decltype(block_scaling_type)::value>;
|
||||
case ActivationType::Silu:
|
||||
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
|
||||
IdentityAdaptor<cutlass::epilogue::thread::SiLu>, decltype(block_scaling_type)::value>;
|
||||
case ActivationType::Swiglu:
|
||||
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
|
||||
GLUAdaptor<cutlass::epilogue::thread::SiLu>, decltype(block_scaling_type)::value>;
|
||||
case ActivationType::Geglu:
|
||||
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
|
||||
GLUAdaptor<cutlass::epilogue::thread::GELU>, decltype(block_scaling_type)::value>;
|
||||
case ActivationType::SwigluBias:
|
||||
return &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
|
||||
decltype(block_scaling_type)::value>;
|
||||
case ActivationType::Relu2:
|
||||
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
|
||||
IdentityAdaptor<cutlass::epilogue::thread::Relu2>, decltype(block_scaling_type)::value>;
|
||||
default: TLLM_CHECK_WITH_INFO(false, "Invalid activation type"); return nullptr;
|
||||
}
|
||||
};
|
||||
auto NVFP4 = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
|
||||
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4>{};
|
||||
|
||||
88 scripts/check_test_list.py — Normal file → Executable file

@@ -1,3 +1,4 @@
#!/usr/bin/env python3
"""
This script is used to verify test lists for L0, QA, and waives file.
@@ -110,14 +111,14 @@ def verify_qa_test_lists(llm_src):
f.write(f"{cleaned_line}\n")


def verify_waive_list(llm_src, args):
def check_waive_duplicates(llm_src):
"""Check for duplicate entries in waives.txt and write report."""
waives_list_path = f"{llm_src}/tests/integration/test_lists/waives.txt"
dup_cases_record = f"{llm_src}/dup_cases.txt"
non_existent_cases_record = f"{llm_src}/nonexits_cases.json"
# Remove prefix and markers in waives.txt
dedup_lines = {
} # Track all occurrences: processed_line -> [(line_no, original_line), ...]
processed_lines = set()

# Track all occurrences: processed_line -> [(line_no, original_line), ...]
dedup_lines = {}

with open(waives_list_path, "r") as f:
lines = f.readlines()
@ -125,6 +126,43 @@ def verify_waive_list(llm_src, args):
|
||||
original_line = line.strip()
|
||||
line = line.strip()
|
||||
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Check for SKIP marker in waives.txt and split by the first occurrence
|
||||
line = line.split(" SKIP", 1)[0].strip()
|
||||
|
||||
# Track all occurrences of each processed line
|
||||
if line in dedup_lines:
|
||||
dedup_lines[line].append((line_no, original_line))
|
||||
else:
|
||||
dedup_lines[line] = [(line_no, original_line)]
|
||||
|
||||
# Write duplicate report after processing all lines
|
||||
for processed_line, occurrences in dedup_lines.items():
|
||||
if len(occurrences) > 1:
|
||||
with open(dup_cases_record, "a") as f:
|
||||
f.write(
|
||||
f"Duplicate waive records found for '{processed_line}' ({len(occurrences)} occurrences):\n"
|
||||
)
|
||||
for i, (line_no, original_line) in enumerate(occurrences, 1):
|
||||
f.write(
|
||||
f" Occurrence {i} at line {line_no}: '{original_line}'\n"
|
||||
)
|
||||
f.write(f"\n")
|
||||
|
||||
|
||||
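As a small illustration of what `check_waive_duplicates` above does, here is the ` SKIP` normalization and the occurrence grouping applied to a few made-up waive entries (the test names and bug URLs below are hypothetical, not taken from waives.txt):

```python
from collections import defaultdict

waive_lines = [
    "unittest/foo/test_bar.py::test_baz SKIP (https://nvbugs/0000001)",
    "unittest/foo/test_bar.py::test_baz SKIP (https://nvbugs/0000002)",  # same test, different bug
    "examples/test_qux.py::test_quux SKIP (https://nvbugs/0000003)",
]

# processed_line -> [(line_no, original_line), ...], as in check_waive_duplicates()
dedup_lines = defaultdict(list)
for line_no, original_line in enumerate(waive_lines, 1):
    line = original_line.strip()
    if not line:
        continue
    # Drop everything from the first " SKIP" marker onward before comparing.
    line = line.split(" SKIP", 1)[0].strip()
    dedup_lines[line].append((line_no, original_line))

for processed_line, occurrences in dedup_lines.items():
    if len(occurrences) > 1:
        print(f"Duplicate waive records found for '{processed_line}' "
              f"({len(occurrences)} occurrences)")
```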
def verify_waive_list(llm_src, args):
|
||||
waives_list_path = f"{llm_src}/tests/integration/test_lists/waives.txt"
|
||||
non_existent_cases_record = f"{llm_src}/nonexits_cases.json"
|
||||
|
||||
processed_lines = set()
|
||||
with open(waives_list_path, "r") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
|
||||
if not line:
|
||||
continue
|
||||
|
||||
@ -135,12 +173,6 @@ def verify_waive_list(llm_src, args):
|
||||
# Check for SKIP marker in waives.txt and split by the first occurrence
|
||||
line = line.split(" SKIP", 1)[0].strip()
|
||||
|
||||
# Track all occurrences of each processed line
|
||||
if line in dedup_lines:
|
||||
dedup_lines[line].append((line_no, original_line))
|
||||
else:
|
||||
dedup_lines[line] = [(line_no, original_line)]
|
||||
|
||||
# If the line starts with 'full:', process it
|
||||
if line.startswith("full:"):
|
||||
line = line.split("/", 1)[1].lstrip("/")
|
||||
@ -173,19 +205,6 @@ def verify_waive_list(llm_src, args):
|
||||
|
||||
processed_lines.add(line)
|
||||
|
||||
# Write duplicate report after processing all lines
|
||||
for processed_line, occurrences in dedup_lines.items():
|
||||
if len(occurrences) > 1:
|
||||
with open(dup_cases_record, "a") as f:
|
||||
f.write(
|
||||
f"Duplicate waive records found for '{processed_line}' ({len(occurrences)} occurrences):\n"
|
||||
)
|
||||
for i, (line_no, original_line) in enumerate(occurrences, 1):
|
||||
f.write(
|
||||
f" Occurrence {i} at line {line_no}: '{original_line}'\n"
|
||||
)
|
||||
f.write(f"\n")
|
||||
|
||||
# Write the processed lines to a tmp file
|
||||
tmp_waives_file = f"{llm_src}/processed_waive_list.txt"
|
||||
with open(tmp_waives_file, "w") as f:
|
||||
@ -210,11 +229,19 @@ def main():
|
||||
parser.add_argument("--waive",
|
||||
action="store_true",
|
||||
help="Enable test list verification for waive file.")
|
||||
parser.add_argument(
|
||||
"--check-duplicate-waives",
|
||||
action="store_true",
|
||||
help="Enable duplicate check in waives.txt (fails if duplicates found)."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
script_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
llm_src = os.path.abspath(os.path.join(script_dir, "../"))
|
||||
|
||||
install_python_dependencies(llm_src)
|
||||
# Only install dependencies when --l0, --qa, or --waive verification is requested; skip them for a pure --check-duplicate-waives run
|
||||
if args.l0 or args.qa or args.waive:
|
||||
install_python_dependencies(llm_src)
|
||||
|
||||
pass_flag = True
|
||||
# Verify L0 test lists
|
||||
if args.l0:
|
||||
@ -243,6 +270,12 @@ def main():
|
||||
print("-----------Skipping waive list verification.-----------",
|
||||
flush=True)
|
||||
|
||||
# Check for duplicates in waives.txt if requested
|
||||
if args.check_duplicate_waives:
|
||||
print("-----------Checking for duplicates in waives.txt...-----------",
|
||||
flush=True)
|
||||
check_waive_duplicates(llm_src)
|
||||
|
||||
invalid_json_file = os.path.join(llm_src, "invalid_tests.json")
|
||||
if os.path.isfile(invalid_json_file) and os.path.getsize(
|
||||
invalid_json_file) > 0:
|
||||
@ -261,7 +294,8 @@ def main():
|
||||
print(
|
||||
"Duplicate test names found in waives.txt, please delete one or combine them first!!!\n"
|
||||
)
|
||||
# pass_flag = False
|
||||
if args.check_duplicate_waives:
|
||||
pass_flag = False
|
||||
|
||||
non_existent_cases_file = os.path.join(llm_src, "nonexits_cases.json")
|
||||
if os.path.isfile(non_existent_cases_file) and os.path.getsize(
|
||||
|
||||
@ -674,6 +674,11 @@ def _scale(weights: torch.Tensor, q_scale: torch.Tensor,
|
||||
return weights * q_scale.squeeze(-1) * s
|
||||
|
||||
|
||||
@maybe_compile(dynamic=True)
|
||||
def _to_float(hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
return hidden_states.float()
|
||||
|
||||
|
||||
class Indexer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
@ -715,7 +720,7 @@ class Indexer(nn.Module):
|
||||
self.hidden_size,
|
||||
self.n_heads,
|
||||
bias=False,
|
||||
dtype=dtype,
|
||||
dtype=torch.float32,
|
||||
quant_config=None,
|
||||
skip_create_weights_in_init=skip_create_weights_in_init,
|
||||
use_custom_cublas_mm=True)
|
||||
@ -1233,82 +1238,63 @@ class Indexer(nn.Module):
|
||||
dtype=torch.int32)
|
||||
return topk_indices_buffer
|
||||
|
||||
def weight_scale(self, hidden_states: torch.Tensor,
|
||||
indexer_weights: Optional[torch.Tensor],
|
||||
q_scale: torch.Tensor) -> torch.Tensor:
|
||||
weights = indexer_weights if indexer_weights is not None else self.weights_proj(
|
||||
hidden_states)
|
||||
def _weight_scale(self, weights: torch.Tensor,
|
||||
q_scale: torch.Tensor) -> torch.Tensor:
|
||||
weights = _scale(weights, q_scale, self.weight_scale_factor)
|
||||
return weights
|
||||
|
||||
def _qk_projection_and_rope(self, qr: torch.Tensor, indexer_k: torch.Tensor,
|
||||
position_ids: torch.Tensor):
|
||||
"""Project Q/K and apply RoPE"""
|
||||
q = self.wq_b(qr)
|
||||
k = self.k_norm(indexer_k)
|
||||
q = q.view(-1, self.n_heads, self.head_dim)
|
||||
q_pe, q_nope = q.split([self.rope_dim, self.head_dim - self.rope_dim],
|
||||
dim=-1)
|
||||
k_pe, k_nope = k.split([self.rope_dim, self.head_dim - self.rope_dim],
|
||||
dim=-1)
|
||||
q_pe, k_pe = self.rotary_emb(position_ids, [q_pe, k_pe.unsqueeze(1)])
|
||||
k_pe = k_pe[:, 0, :]
|
||||
return q_pe, q_nope, k_pe, k_nope
|
||||
|
||||
def _prep_q_or_k(self, qk_pe: torch.Tensor, qk_nope: torch.Tensor):
|
||||
"""Concatenate, rotate, and FP8 quantize for Q or K"""
|
||||
q_or_k = torch.cat([qk_pe, qk_nope], dim=-1)
|
||||
q_or_k = rotate_activation(q_or_k)
|
||||
q_or_k = q_or_k.view(-1, self.head_dim)
|
||||
q_or_k = fp8_utils.fp8_quantize_1x128_sf_transpose(
|
||||
q_or_k, use_ue8m0=self.scale_fmt == "ue8m0")
|
||||
return q_or_k
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, qr: torch.Tensor, hidden_states: torch.Tensor,
|
||||
metadata: DSAtrtllmAttentionMetadata,
|
||||
position_ids: torch.Tensor, indexer_k: Optional[torch.Tensor],
|
||||
indexer_weights: Optional[torch.Tensor]):
|
||||
position_ids: torch.Tensor, indexer_k: torch.Tensor):
|
||||
quant_block_size = metadata.kv_cache_manager.quant_block_size
|
||||
assert quant_block_size == 128, "Only support quant_block_size = 128 for now"
|
||||
|
||||
if indexer_k is not None:
|
||||
q, k = maybe_execute_in_parallel(
|
||||
lambda: self.wq_b(
|
||||
qr), # TODO: fuse wq_b and move this outside of the indexer
|
||||
lambda: self.k_norm(indexer_k),
|
||||
self.ln_events[0],
|
||||
self.ln_events[1],
|
||||
self.aux_stream,
|
||||
)
|
||||
else:
|
||||
q, k = maybe_execute_in_parallel(
|
||||
lambda: self.wq_b(qr),
|
||||
lambda: self.k_norm(self.wk(hidden_states)),
|
||||
self.ln_events[0],
|
||||
self.ln_events[1],
|
||||
self.aux_stream,
|
||||
)
|
||||
|
||||
# q/k rope + possible fast_hadamard_transform
|
||||
q = q.view(-1, self.n_heads, self.head_dim)
|
||||
|
||||
q, k = maybe_execute_in_parallel(
|
||||
lambda: torch.split(
|
||||
q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1),
|
||||
lambda: torch.split(
|
||||
k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1),
|
||||
q_and_k, weights = maybe_execute_in_parallel(
|
||||
lambda: self._qk_projection_and_rope(qr, indexer_k, position_ids),
|
||||
lambda: self.weights_proj(_to_float(hidden_states)),
|
||||
self.ln_events[0],
|
||||
self.ln_events[1],
|
||||
self.aux_stream,
|
||||
)
|
||||
|
||||
q_pe, q_nope = q
|
||||
k_pe, k_nope = k
|
||||
q_pe, k_pe = self.rotary_emb(position_ids, [q_pe, k_pe.unsqueeze(1)])
|
||||
|
||||
k_pe = k_pe[:, 0, :]
|
||||
|
||||
def _prep_q_or_k(qk_pe, qk_nope):
|
||||
q_or_k = torch.cat([qk_pe, qk_nope], dim=-1)
|
||||
q_or_k = rotate_activation(q_or_k)
|
||||
q_or_k = q_or_k.view(-1, self.head_dim)
|
||||
q_or_k = fp8_utils.fp8_quantize_1x128_sf_transpose(
|
||||
q_or_k, use_ue8m0=self.scale_fmt == "ue8m0")
|
||||
return q_or_k
|
||||
|
||||
q_pe, q_nope, k_pe, k_nope = q_and_k
|
||||
q, k = maybe_execute_in_parallel(
|
||||
lambda: _prep_q_or_k(q_pe, q_nope),
|
||||
lambda: _prep_q_or_k(k_pe, k_nope),
|
||||
lambda: self._prep_q_or_k(q_pe, q_nope),
|
||||
lambda: self._prep_q_or_k(k_pe, k_nope),
|
||||
self.ln_events[0],
|
||||
self.ln_events[1],
|
||||
self.aux_stream,
|
||||
)
|
||||
|
||||
q_fp8, q_scale = q
|
||||
k_fp8, k_scale = k
|
||||
q_fp8 = q_fp8.view(-1, self.n_heads, self.head_dim)
|
||||
q_scale = q_scale.view(-1, self.n_heads, 1)
|
||||
|
||||
weights, _ = maybe_execute_in_parallel(
|
||||
lambda: self.weight_scale(hidden_states, indexer_weights, q_scale),
|
||||
lambda: self._weight_scale(weights, q_scale),
|
||||
lambda: self._update_k_cache(
|
||||
k_fp8, k_scale, metadata), # store k_fp8 and k_scale in k cache
|
||||
self.ln_events[0],
|
||||
|
||||
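The indexer refactor above leans on `maybe_execute_in_parallel` to overlap the Q/K projection + RoPE path with the `weights_proj` GEMM on an auxiliary stream. A generic sketch of that overlap pattern in PyTorch (this is not the TensorRT-LLM helper itself, just the usual stream/event recipe such a helper is typically built on):

```python
import torch


def run_in_parallel(fn_main, fn_aux, aux_stream: torch.cuda.Stream):
    """Run fn_aux on aux_stream while fn_main runs on the current stream."""
    main_stream = torch.cuda.current_stream()
    # Make the aux stream wait for work already queued on the main stream.
    aux_stream.wait_stream(main_stream)
    with torch.cuda.stream(aux_stream):
        out_aux = fn_aux()
    out_main = fn_main()
    # Before the results are consumed, the main stream must wait for the aux work.
    main_stream.wait_stream(aux_stream)
    return out_main, out_aux


if torch.cuda.is_available():
    aux = torch.cuda.Stream()
    x = torch.randn(1024, 1024, device="cuda")
    w1 = torch.randn(1024, 1024, device="cuda")
    w2 = torch.randn(1024, 1024, device="cuda")
    y_main, y_aux = run_in_parallel(lambda: x @ w1, lambda: x @ w2, aux)
```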
@@ -114,9 +114,11 @@ transforms:
fuse_moe:
stage: post_load_fusion
enabled: true
backend: trtllm
fuse_fp8_moe:
stage: post_load_fusion
enabled: true
backend: trtllm
fuse_allreduce_residual_rmsnorm:
stage: post_load_fusion
# TODO (lucaslie): add backend selection as part of configurable inference optimizers
@ -341,7 +341,7 @@ def get_moe_configs(
|
||||
for config_file_path in config_file_paths:
|
||||
if os.path.exists(config_file_path):
|
||||
with open(config_file_path) as f:
|
||||
ad_logger.info("Using configuration from %s for MoE layer.", config_file_path)
|
||||
ad_logger.info(f"Using configuration from {config_file_path} for MoE layer.")
|
||||
# If a configuration has been found, return it
|
||||
tuned_config = json.load(f)
|
||||
# Delete triton_version from tuned_config
|
||||
@ -601,8 +601,16 @@ def triton_fused_moe(
|
||||
routing_weights: torch.Tensor,
|
||||
w1_stacked_weight: torch.Tensor,
|
||||
w2_stacked_weight: torch.Tensor,
|
||||
mlp_style: str = "mlp",
|
||||
act_fn: str = "relu2",
|
||||
) -> torch.Tensor:
|
||||
"""Triton unquantized MoE with 2-layer MLP and ReLU^2 activation."""
|
||||
|
||||
mlp_style = mlp_style.lower()
|
||||
act_fn = act_fn.lower()
|
||||
assert mlp_style == "mlp", "Triton backend only supports mlp style."
|
||||
assert act_fn == "relu2", "Triton backend only supports relu2 activation."
|
||||
|
||||
x_shape = x.shape
|
||||
x2d = x.view(-1, x_shape[-1])
|
||||
|
||||
|
||||
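`triton_fused_moe` above is restricted to the plain 2-layer MLP with ReLU^2 (`relu2`) activation. As a reference for what that computes per token and per selected expert, a dense PyTorch sketch (the naming and weight layout here are assumptions for illustration, not the Triton kernel's actual layout):

```python
import torch


def relu2_moe_reference(x, selected_experts, routing_weights, w1_stacked, w2_stacked):
    """x: [T, H]; selected_experts/routing_weights: [T, K];
    w1_stacked: [E, I, H]; w2_stacked: [E, H, I]."""
    out = torch.zeros_like(x)
    num_experts = w1_stacked.shape[0]
    for e in range(num_experts):
        token_idx, k_idx = (selected_experts == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        h = x[token_idx] @ w1_stacked[e].t()          # [t, I]
        h = torch.relu(h) ** 2                        # ReLU^2 ("relu2") activation
        y = h @ w2_stacked[e].t()                     # [t, H]
        out.index_add_(0, token_idx,
                       y * routing_weights[token_idx, k_idx].unsqueeze(-1))
    return out
```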
@ -1,25 +1,42 @@
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Literal, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils.cuda_mem_tracker import cuda_memory_tracker
|
||||
from ...utils.node_utils import bfs, extract_op_args, identify_regions_between_residuals, is_op
|
||||
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
|
||||
from ..interface import (
|
||||
BaseTransform,
|
||||
SharedConfig,
|
||||
TransformConfig,
|
||||
TransformInfo,
|
||||
TransformRegistry,
|
||||
)
|
||||
|
||||
|
||||
def _insert_fused_moe_ops(gm: GraphModule) -> int:
|
||||
def _insert_fused_moe_ops(gm: GraphModule, backend: Literal["auto", "trtllm", "triton"]) -> int:
|
||||
fused_key_counter = 0
|
||||
graph = gm.graph
|
||||
backend = backend.lower()
|
||||
|
||||
replacement_op = {
|
||||
"auto": torch.ops.auto_deploy.trtllm_moe_fused,
|
||||
"trtllm": torch.ops.auto_deploy.trtllm_moe_fused,
|
||||
"triton": torch.ops.auto_deploy.triton_moe_fused,
|
||||
}[backend]
|
||||
|
||||
for node in graph.nodes:
|
||||
if not is_op(node, torch.ops.auto_deploy.torch_moe):
|
||||
continue
|
||||
|
||||
(mlp_style_val,) = extract_op_args(node, "mlp_style")
|
||||
(mlp_style_val, act_fn_val) = extract_op_args(node, "mlp_style", "act_fn")
|
||||
assert backend != "triton" or mlp_style_val == "mlp", (
|
||||
"Triton backend only supports mlp style."
|
||||
)
|
||||
|
||||
hidden_states, selected_experts, routing_weights, w1_list, w2_list, w3_list = (
|
||||
extract_op_args(
|
||||
@ -43,15 +60,10 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int:
|
||||
dim=0,
|
||||
)
|
||||
new_key_w_up = f"fused_moe_w3_w1_stacked_{fused_key_counter}"
|
||||
# TRTLLM fused MoE op supports gated MLP only.
|
||||
replacement_op = torch.ops.auto_deploy.trtllm_moe_fused
|
||||
|
||||
elif mlp_style_val == "mlp":
|
||||
fused_w_up_experts = torch.stack([gm.get_parameter(n.target) for n in w1_list], dim=0)
|
||||
new_key_w_up = f"fused_moe_w1_stacked_{fused_key_counter}"
|
||||
# Triton fused MoE op supports mlp only.
|
||||
replacement_op = torch.ops.auto_deploy.triton_moe_fused
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown mlp_style: {mlp_style_val}")
|
||||
|
||||
@ -569,7 +581,7 @@ class MatchNVFP4MoePattern(MatchMoePattern):
|
||||
return ["input_scale", "weight_scale", "alpha"]
|
||||
|
||||
|
||||
def _stack_fp8_moe_weights(gm: GraphModule) -> int:
|
||||
def _stack_fp8_moe_weights(gm: GraphModule, backend: Literal["auto", "trtllm", "triton"]) -> int:
|
||||
"""
|
||||
Stack per-expert FP8 weights and scales by materializing stacked tensors as parameters.
|
||||
This is fast because we directly stack the tensor values (not graph nodes).
|
||||
@ -578,6 +590,13 @@ def _stack_fp8_moe_weights(gm: GraphModule) -> int:
|
||||
fused_key_counter = 0
|
||||
graph = gm.graph
|
||||
|
||||
backend = backend.lower()
|
||||
replacement_op = {
|
||||
"auto": torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused,
|
||||
"trtllm": torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused,
|
||||
"triton": torch.ops.auto_deploy.triton_quant_fp8_moe,
|
||||
}[backend]
|
||||
|
||||
for node in graph.nodes:
|
||||
if not is_op(node, torch.ops.auto_deploy.torch_quant_fp8_moe):
|
||||
continue
|
||||
@ -709,7 +728,7 @@ def _stack_fp8_moe_weights(gm: GraphModule) -> int:
|
||||
# Create new node with get_attr for stacked parameters
|
||||
with graph.inserting_before(node):
|
||||
new_node = graph.call_function(
|
||||
torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused,
|
||||
replacement_op,
|
||||
args=(
|
||||
hidden_states,
|
||||
selected_experts,
|
||||
@ -739,6 +758,15 @@ def _stack_fp8_moe_weights(gm: GraphModule) -> int:
|
||||
return fused_key_counter
|
||||
|
||||
|
||||
class FuseMoeConfig(TransformConfig):
|
||||
"""Configuration for MoE fusion transform."""
|
||||
|
||||
backend: str = Field(
|
||||
default="auto",
|
||||
description="Backend to use for MoE computation ('auto', 'trtllm' or 'triton'. default: 'auto').",
|
||||
)
|
||||
|
||||
|
||||
@TransformRegistry.register("fuse_moe")
|
||||
class FuseMoe(BaseTransform):
|
||||
"""
|
||||
@ -746,6 +774,10 @@ class FuseMoe(BaseTransform):
|
||||
torch.ops.auto_deploy.trtllm_moe_fused.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[TransformConfig]:
|
||||
return FuseMoeConfig
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
@ -754,7 +786,7 @@ class FuseMoe(BaseTransform):
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
with cuda_memory_tracker():
|
||||
fused_key_counter = _insert_fused_moe_ops(gm)
|
||||
fused_key_counter = _insert_fused_moe_ops(gm, backend=self.config.backend)
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False,
|
||||
@ -765,6 +797,15 @@ class FuseMoe(BaseTransform):
|
||||
return gm, info
|
||||
|
||||
|
||||
class FuseFP8MoeConfig(TransformConfig):
|
||||
"""Configuration for FP8 MoE fusion transform."""
|
||||
|
||||
backend: str = Field(
|
||||
default="auto",
|
||||
description="Backend to use for FP8 MoE computation ('auto', 'trtllm' or 'triton'. default: 'auto').",
|
||||
)
|
||||
|
||||
|
||||
@TransformRegistry.register("fuse_fp8_moe")
|
||||
class FuseFP8Moe(BaseTransform):
|
||||
"""
|
||||
@ -780,7 +821,7 @@ class FuseFP8Moe(BaseTransform):
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
with cuda_memory_tracker():
|
||||
fused_key_counter = _stack_fp8_moe_weights(gm)
|
||||
fused_key_counter = _stack_fp8_moe_weights(gm, backend=self.config.backend)
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=(fused_key_counter == 0),
|
||||
|
||||
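The transform now takes its backend from a pydantic `TransformConfig` (`FuseMoeConfig` / `FuseFP8MoeConfig` above) instead of hard-coding the replacement op. A self-contained sketch of that pattern, validating the backend string and resolving it to an op name (the class and op names here are illustrative stand-ins, not the AutoDeploy API):

```python
from typing import Literal

from pydantic import BaseModel, Field


class MoeFusionConfig(BaseModel):
    """Toy stand-in for a transform config with a backend knob."""
    backend: Literal["auto", "trtllm", "triton"] = Field(
        default="auto",
        description="Backend to use for the fused MoE op.",
    )


_REPLACEMENT_OPS = {
    "auto": "trtllm_moe_fused",      # in the diff above, 'auto' maps to the TRTLLM op
    "trtllm": "trtllm_moe_fused",
    "triton": "triton_moe_fused",
}


def resolve_op(cfg: MoeFusionConfig) -> str:
    return _REPLACEMENT_OPS[cfg.backend.lower()]


print(resolve_op(MoeFusionConfig()))                  # trtllm_moe_fused
print(resolve_op(MoeFusionConfig(backend="triton")))  # triton_moe_fused
```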
@ -362,9 +362,10 @@ class DeepseekV3WeightLoader:
|
||||
fused_a_scale = torch.cat(
|
||||
[q_a_proj_scale, fused_a_scale], dim=0)
|
||||
|
||||
module.weight_scale.data.copy_(fused_a_scale)
|
||||
# For DeepseekV32 with fuse_a_indexer_k_weight=True: kv_a_proj_with_mqa is oversized
|
||||
# to include indexer weights, which is filled in post_load_weights.
|
||||
module.weight_scale.data[0:fused_a_scale.
|
||||
shape[0]].copy_(fused_a_scale)
|
||||
# For DeepseekV32: kv_a_proj_with_mqa is oversized
|
||||
# to include indexer k weights, which is filled in post_load_weights.
|
||||
module.weight.data[0:fused_a.shape[0]].copy_(fused_a)
|
||||
elif names[-1] in params_map:
|
||||
module_weights = []
|
||||
@ -556,13 +557,6 @@ class DeepseekV32Attention(MLA):
|
||||
config = model_config.pretrained_config
|
||||
predicted_tokens_per_seq = model_config.spec_config.max_total_draft_tokens + 1 if model_config.spec_config is not None else 1
|
||||
|
||||
# DSV3.2 nvfp4 ckpt has kv_a_proj_with_mqa module in bfloat16
|
||||
# TODO: check it more directly/robustly, e.g., indexer_weight_quant == fuseA_quant == indexer_quant
|
||||
if model_config.get_quant_config().quant_algo == QuantAlgo.NVFP4:
|
||||
self.fuse_a_indexer_k_weight = True
|
||||
else:
|
||||
self.fuse_a_indexer_k_weight = False
|
||||
|
||||
super().__init__(hidden_size=config.hidden_size,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_key_value_heads,
|
||||
@ -586,36 +580,46 @@ class DeepseekV32Attention(MLA):
|
||||
|
||||
self.indexer = self.mqa.indexer
|
||||
|
||||
if self.fuse_a_indexer_k_weight:
|
||||
# For DeepseekV32, the kv_a_proj_with_mqa includes:
|
||||
# q_a_proj + kv_a_proj_with_mqa + indexer.wk + indexer.weights_proj
|
||||
self.kv_a_proj_with_mqa = DeepseekV3Linear(
|
||||
config.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim + self.q_lora_rank +
|
||||
self.indexer.head_dim + self.indexer.n_heads,
|
||||
bias=False,
|
||||
dtype=config.torch_dtype,
|
||||
quant_config=model_config.get_quant_config(),
|
||||
skip_create_weights_in_init=model_config.
|
||||
skip_create_weights_in_init,
|
||||
use_custom_cublas_mm=True)
|
||||
# For DeepseekV32, the kv_a_proj_with_mqa includes:
|
||||
# q_a_proj + kv_a_proj_with_mqa + indexer.wk
|
||||
self.kv_a_proj_with_mqa = DeepseekV3Linear(
|
||||
config.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim + self.q_lora_rank +
|
||||
self.indexer.head_dim,
|
||||
bias=False,
|
||||
dtype=config.torch_dtype,
|
||||
quant_config=model_config.get_quant_config(),
|
||||
skip_create_weights_in_init=model_config.
|
||||
skip_create_weights_in_init,
|
||||
use_custom_cublas_mm=True)
|
||||
|
||||
def post_load_weights(self):
|
||||
if self.fuse_a_indexer_k_weight:
|
||||
assert self.kv_a_proj_with_mqa.weight.data.dtype == self.indexer.wk.weight.data.dtype == self.indexer.weights_proj.weight.data.dtype, "all weights in kv_a_proj_with_mqa module must have matching dtype"
|
||||
# Copy indexer weights into the fused kv_a_proj_with_mqa module
|
||||
indexer_wk_weight = self.indexer.wk.weight.data
|
||||
offset = self.kv_lora_rank + self.qk_rope_head_dim + self.q_lora_rank
|
||||
self.kv_a_proj_with_mqa.weight.data[offset:offset +
|
||||
self.indexer.head_dim].copy_(
|
||||
indexer_wk_weight)
|
||||
offset += self.indexer.head_dim
|
||||
indexer_weights_proj_weight = self.indexer.weights_proj.weight.data
|
||||
self.kv_a_proj_with_mqa.weight.data[offset:offset +
|
||||
self.indexer.n_heads].copy_(
|
||||
indexer_weights_proj_weight)
|
||||
self.indexer.wk = None
|
||||
self.indexer.weights_proj = None
|
||||
"""
|
||||
Concatenate indexer.wk weights into kv_a_proj_with_mqa's last dimension, to fuse indexer.wk projection with kv_a_proj_with_mqa GEMM.
|
||||
"""
|
||||
assert self.kv_a_proj_with_mqa.weight.data.dtype == self.indexer.wk.weight.data.dtype, "all weights in kv_a_proj_with_mqa module must have matching dtype"
|
||||
# Copy indexer weights into the fused kv_a_proj_with_mqa module
|
||||
indexer_wk_weight = self.indexer.wk.weight.data
|
||||
offset = self.kv_lora_rank + self.qk_rope_head_dim + self.q_lora_rank
|
||||
self.kv_a_proj_with_mqa.weight.data[offset:offset +
|
||||
self.indexer.head_dim].copy_(
|
||||
indexer_wk_weight)
|
||||
|
||||
# Copy indexer scale data if it exists
|
||||
if hasattr(self.indexer.wk,
|
||||
'weight_scale') and self.indexer.wk.weight_scale is not None:
|
||||
indexer_wk_scale = self.indexer.wk.weight_scale.data
|
||||
assert self.kv_a_proj_with_mqa.weight_scale.dim(
|
||||
) == 2, "weight_scale must be a 2D tensor"
|
||||
group_size = self.kv_a_proj_with_mqa.weight.shape[
|
||||
0] // self.kv_a_proj_with_mqa.weight_scale.shape[0]
|
||||
scale_offset = offset // group_size
|
||||
scale_size = indexer_wk_scale.shape[0]
|
||||
# Copy indexer scale to the corresponding position in the fused module
|
||||
self.kv_a_proj_with_mqa.weight_scale.data[
|
||||
scale_offset:scale_offset + scale_size].copy_(indexer_wk_scale)
|
||||
|
||||
self.indexer.wk = None
|
||||
|
||||
|
||||
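The `post_load_weights` change above copies the indexer's blockwise weight scales into the fused `kv_a_proj_with_mqa` scale tensor, using `group_size = weight.shape[0] // weight_scale.shape[0]` and `scale_offset = offset // group_size`. A worked example with toy dimensions (the numbers are illustrative only, and computing `scale_size` from the group size assumes the indexer scales use the same grouping; the real code takes it from `indexer_wk_scale.shape[0]`):

```python
# Toy sizes, chosen only to make the arithmetic concrete (not the model's real dims).
kv_lora_rank, qk_rope_head_dim, q_lora_rank = 256, 64, 448
indexer_head_dim = 128

fused_out_dim = kv_lora_rank + qk_rope_head_dim + q_lora_rank + indexer_head_dim  # 896
weight_scale_rows = 7  # pretend the fused weight_scale tensor has 7 rows of blockwise scales

group_size = fused_out_dim // weight_scale_rows          # 128 output rows share one scale row
offset = kv_lora_rank + qk_rope_head_dim + q_lora_rank   # 768: first indexer.wk output row
scale_offset = offset // group_size                      # 768 // 128 = 6
scale_size = indexer_head_dim // group_size              # 1 scale row to copy

print(group_size, scale_offset, scale_size)              # 128 6 1
```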
class Deepseekv3RoutingImpl():
|
||||
|
||||
@ -42,8 +42,13 @@ class Gemma3InputProcessor(BaseMultimodalInputProcessor,
|
||||
model_path: str,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AutoTokenizer,
|
||||
trust_remote_code: bool = True):
|
||||
super().__init__()
|
||||
trust_remote_code: bool = True,
|
||||
**kwargs):
|
||||
super().__init__(model_path=model_path,
|
||||
config=config,
|
||||
tokenizer=tokenizer,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs)
|
||||
self._config = config
|
||||
self._tokenizer = tokenizer
|
||||
self._model_path = model_path
|
||||
|
||||
@ -572,8 +572,13 @@ class HCXVisionInputProcessor(BaseMultimodalDummyInputsBuilder,
|
||||
model_path: str,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AutoTokenizer,
|
||||
trust_remote_code: bool = True):
|
||||
super().__init__()
|
||||
trust_remote_code: bool = True,
|
||||
**kwargs):
|
||||
super().__init__(model_path=model_path,
|
||||
config=config,
|
||||
tokenizer=tokenizer,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs)
|
||||
self._config = config
|
||||
self._tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
|
||||
@ -1054,8 +1054,13 @@ class Llama4InputProcessor(BaseMultimodalInputProcessor,
|
||||
model_path: str,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AutoTokenizer,
|
||||
trust_remote_code: bool = True):
|
||||
super().__init__()
|
||||
trust_remote_code: bool = True,
|
||||
**kwargs):
|
||||
super().__init__(model_path=model_path,
|
||||
config=config,
|
||||
tokenizer=tokenizer,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs)
|
||||
self._config = config
|
||||
self._dtype = self._config.torch_dtype
|
||||
self._tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
|
||||
|
||||
@ -43,8 +43,13 @@ class LlavaNextInputProcessor(BaseMultimodalInputProcessor,
|
||||
model_path: str,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AutoTokenizer,
|
||||
trust_remote_code: bool = True):
|
||||
super().__init__()
|
||||
trust_remote_code: bool = True,
|
||||
**kwargs):
|
||||
super().__init__(model_path=model_path,
|
||||
config=config,
|
||||
tokenizer=tokenizer,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs)
|
||||
self._config = config
|
||||
self._tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
|
||||
@ -224,8 +224,13 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: Optional[AutoTokenizer],
|
||||
trust_remote_code: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(model_path=model_path,
|
||||
config=config,
|
||||
tokenizer=tokenizer,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs)
|
||||
self._config = config
|
||||
self._dtype = self._config.torch_dtype
|
||||
self._tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
|
||||
|
||||
@ -265,8 +265,15 @@ class NanoV2VLInputProcessor(BaseMultimodalInputProcessor, BaseMultimodalDummyIn
|
||||
config: transformers.PretrainedConfig,
|
||||
tokenizer: transformers.AutoTokenizer,
|
||||
trust_remote_code: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(
|
||||
model_path=model_path,
|
||||
config=config,
|
||||
tokenizer=tokenizer,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs,
|
||||
)
|
||||
if not trust_remote_code:
|
||||
raise ValueError("trust_remote_code must be True for Phi4MM")
|
||||
|
||||
|
||||
@ -763,8 +763,13 @@ class Phi4MMInputProcessor(BaseMultimodalInputProcessor,
|
||||
model_path: str,
|
||||
config: transformers.PretrainedConfig,
|
||||
tokenizer: transformers.AutoTokenizer,
|
||||
trust_remote_code: bool = True):
|
||||
super().__init__()
|
||||
trust_remote_code: bool = True,
|
||||
**kwargs):
|
||||
super().__init__(model_path=model_path,
|
||||
config=config,
|
||||
tokenizer=tokenizer,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs)
|
||||
if not trust_remote_code:
|
||||
raise ValueError("trust_remote_code must be True for Phi4MM")
|
||||
|
||||
|
||||
@ -95,10 +95,13 @@ class Qwen2VLInputProcessorBase(BaseMultimodalInputProcessor,
|
||||
model_path: str,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AutoTokenizer,
|
||||
trust_remote_code: bool = True):
|
||||
|
||||
super().__init__()
|
||||
self._config = config
|
||||
trust_remote_code: bool = True,
|
||||
**kwargs):
|
||||
super().__init__(model_path=model_path,
|
||||
config=config,
|
||||
tokenizer=tokenizer,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs)
|
||||
self._dtype = self._config.torch_dtype
|
||||
self._tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
|
||||
model_path)
|
||||
|
||||
@ -35,7 +35,8 @@ from transformers import (AutoConfig, AutoImageProcessor, AutoModel,
|
||||
PretrainedConfig, PreTrainedModel)
|
||||
|
||||
from ..._utils import nvtx_range
|
||||
from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs,
|
||||
from ...inputs import (BaseMultimodalDummyInputsBuilder,
|
||||
BaseMultimodalInputProcessor, ExtraProcessedInputs,
|
||||
MultimodalPlaceholderMetadata,
|
||||
MultimodalPlaceholderPlacement, TextPrompt,
|
||||
register_input_processor)
|
||||
@ -864,15 +865,22 @@ def _apply_chat_template(text, conv, tokenizer):
|
||||
return text
|
||||
|
||||
|
||||
class VilaInputProcessor(BaseMultimodalInputProcessor):
|
||||
class VilaInputProcessor(BaseMultimodalInputProcessor,
|
||||
BaseMultimodalDummyInputsBuilder):
|
||||
|
||||
def __init__(self,
|
||||
model_path: str,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AutoTokenizer,
|
||||
trust_remote_code: bool = True):
|
||||
super().__init__()
|
||||
trust_remote_code: bool = True,
|
||||
**kwargs):
|
||||
super().__init__(model_path=model_path,
|
||||
config=config,
|
||||
tokenizer=tokenizer,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs)
|
||||
self._config = config
|
||||
self._model_path = model_path
|
||||
llm_path, vision_tower_path, mm_projector_path = _get_model_paths(
|
||||
self.config)
|
||||
self._dtype = self.config.model_dtype
|
||||
@ -905,6 +913,9 @@ class VilaInputProcessor(BaseMultimodalInputProcessor):
|
||||
def dtype(self) -> torch.dtype:
|
||||
return self._dtype
|
||||
|
||||
def model_path(self) -> str:
|
||||
return self._model_path
|
||||
|
||||
@nvtx_range("[Vision] preprocess")
|
||||
def _preprocess(self,
|
||||
mm_data: dict[str, any],
|
||||
|
||||
@ -1221,19 +1221,11 @@ class MLA(nn.Module):
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids[..., :num_tokens]
|
||||
|
||||
if self.fuse_a_indexer_k_weight:
|
||||
q, compressed_kv, k_pe, indexer_k, indexer_weights = self.kv_a_proj_with_mqa(
|
||||
hidden_states).split([
|
||||
self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim,
|
||||
self.indexer.head_dim, self.indexer.n_heads
|
||||
], -1)
|
||||
else:
|
||||
q, compressed_kv, k_pe = self.kv_a_proj_with_mqa(
|
||||
hidden_states).split([
|
||||
self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim
|
||||
], -1)
|
||||
indexer_k = None
|
||||
indexer_weights = None
|
||||
q, compressed_kv, k_pe, indexer_k = self.kv_a_proj_with_mqa(
|
||||
hidden_states).split([
|
||||
self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim,
|
||||
self.indexer.head_dim
|
||||
], -1)
|
||||
|
||||
# TODO: possibly overlap/fuse q_a_rmsnorm + kv_a_rmsnorm + indexer.k_layernorm?
|
||||
q, compressed_kv = maybe_execute_in_parallel(
|
||||
@ -1255,7 +1247,6 @@ class MLA(nn.Module):
|
||||
attn_metadata,
|
||||
position_ids,
|
||||
indexer_k=indexer_k, # indexer K proj
|
||||
indexer_weights=indexer_weights, # indexer weights proj
|
||||
)
|
||||
|
||||
assert q.shape[
|
||||
|
||||
@@ -1,4 +1,5 @@
import contextlib
import os
import threading
from dataclasses import dataclass
from enum import Enum, IntEnum
@@ -34,15 +35,15 @@ EventType = Enum(
# IMPORTANT: Keep the same order of activation functions in this enum and the enum in
# cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
class ActivationType(IntEnum):
Gelu = 0
Relu = 1
Silu = 2
Swiglu = 3
Geglu = 4
SwigluBias = 5
Identity = 6
Relu2 = 7
InvalidType = 8
InvalidType = 0
Identity = 1
Gelu = 2
Relu = 3
Silu = 4
Swiglu = 5
Geglu = 6
SwigluBias = 7
Relu2 = 8


def set_torch_compiling(enable: bool):
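Because the Python `ActivationType` above must stay numerically aligned with the C++ enum in `cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h`, a small guard like the following can catch accidental drift (this check is an illustration, not something added by this diff; the import path matches the one used elsewhere in the change set):

```python
from tensorrt_llm._torch.utils import ActivationType

# Expected values under the new numbering, with InvalidType as the zero sentinel.
expected = {
    "InvalidType": 0,
    "Identity": 1,
    "Gelu": 2,
    "Relu": 3,
    "Silu": 4,
    "Swiglu": 5,
    "Geglu": 6,
    "SwigluBias": 7,
    "Relu2": 8,
}

for name, value in expected.items():
    assert int(ActivationType[name]) == value, f"{name} drifted from {value}"
```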
@@ -316,10 +317,16 @@ def create_lm_head_tp_mapping(mapping: Mapping, token_count: int) -> Mapping:
# We use a heuristic to determine the lm_head_tp_size
# Since token_count=256 hits the boundary of the math-bound regime,
# we use 256 // token_count to determine the lm_head_tp_size
# For more details, refer to the blog: https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/tech_blog/blog14_Scaling_Expert_Parallelism_in_TensorRT-LLM_part3.md#mtp-lm-head-tensor-parallelism
lm_head_tp_size_raw = 256 // token_count
lm_head_tp_size = nearest_in_buckets(lm_head_tp_size_raw,
[1, mapping.gpus_per_node])
assert mapping.tp_size % lm_head_tp_size == 0
# TODO: On platforms like GB200, setting lm_head_tp_size_upper_bound to world_size could be more efficient when world_size > gpus_per_node; further investigation is needed.
lm_head_tp_size_upper_bound = min(mapping.world_size, mapping.gpus_per_node)
lm_head_tp_size = int(
os.getenv(
'LM_HEAD_TP_SIZE',
nearest_in_buckets(lm_head_tp_size_raw,
[1, lm_head_tp_size_upper_bound])))
assert mapping.tp_size % lm_head_tp_size == 0, f"mapping.tp_size: {mapping.tp_size}, lm_head_tp_size: {lm_head_tp_size}"
lm_head_pp_size = mapping.pp_size * mapping.tp_size // lm_head_tp_size

return Mapping(
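To make the heuristic above concrete, a small worked example (approximating `nearest_in_buckets` by clamping into `[1, upper_bound]`, which is an assumption about that helper, not its documented behavior):

```python
import os


def choose_lm_head_tp_size(token_count: int, tp_size: int, world_size: int,
                           gpus_per_node: int) -> int:
    """Sketch of the heuristic above; nearest_in_buckets() is approximated by clamping."""
    lm_head_tp_size_raw = 256 // token_count
    upper_bound = min(world_size, gpus_per_node)
    lm_head_tp_size = int(os.getenv("LM_HEAD_TP_SIZE",
                                    max(1, min(lm_head_tp_size_raw, upper_bound))))
    assert tp_size % lm_head_tp_size == 0, (tp_size, lm_head_tp_size)
    return lm_head_tp_size


# 64 tokens -> 256 // 64 = 4; 1024 tokens -> 256 // 1024 = 0, clamped up to 1.
print(choose_lm_head_tp_size(64, tp_size=8, world_size=8, gpus_per_node=8))    # 4
print(choose_lm_head_tp_size(1024, tp_size=8, world_size=8, gpus_per_node=8))  # 1
```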
@ -3,7 +3,7 @@ import random
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (Any, Callable, Dict, List, Optional, Protocol, Tuple, Type,
|
||||
TypeVar)
|
||||
TypeVar, Union)
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
@ -118,7 +118,7 @@ class DefaultInputProcessor(InputProcessor):
|
||||
return token_ids, None
|
||||
|
||||
|
||||
class BaseMultimodalInputProcessor(InputProcessor, ABC):
|
||||
class BaseMultimodalInputProcessor(ABC):
|
||||
"""
|
||||
Base class for multimodal input processors with default implementations
|
||||
of get_num_tokens_per_image and get_num_tokens_per_video methods.
|
||||
@ -127,8 +127,16 @@ class BaseMultimodalInputProcessor(InputProcessor, ABC):
|
||||
models. Specific processors can override these methods if they need custom logic.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self,
|
||||
model_path,
|
||||
config,
|
||||
tokenizer,
|
||||
trust_remote_code: bool = True,
|
||||
**kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self._config = config
|
||||
self._model_path = model_path
|
||||
self._tokenizer = tokenizer
|
||||
self._use_fast: bool = kwargs.get('use_fast', True)
|
||||
self._multimodal_hashing_supported: Optional[bool] = None
|
||||
|
||||
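The hunk above moves per-model boilerplate (`_config`, `_tokenizer`, `_model_path`) into `BaseMultimodalInputProcessor.__init__`, which is why every processor earlier in this diff now forwards its constructor arguments via `super().__init__(...)`. A stripped-down sketch of the pattern (toy classes, not the TensorRT-LLM ones):

```python
class BaseProcessor:
    def __init__(self, model_path, config, tokenizer, trust_remote_code=True, **kwargs):
        self._model_path = model_path
        self._config = config
        self._tokenizer = tokenizer
        self._use_fast = kwargs.get("use_fast", True)

    @property
    def config(self):
        # Concrete property instead of an abstract one: subclasses inherit it for free.
        return self._config

    @property
    def tokenizer(self):
        return self._tokenizer


class MyModelProcessor(BaseProcessor):
    def __init__(self, model_path, config, tokenizer, trust_remote_code=True, **kwargs):
        # Forward everything so the base class owns the common state.
        super().__init__(model_path=model_path, config=config, tokenizer=tokenizer,
                         trust_remote_code=trust_remote_code, **kwargs)
        self._dtype = getattr(config, "torch_dtype", None)


proc = MyModelProcessor("dummy/path", config=object(), tokenizer=None)
print(proc.config, proc.tokenizer)
```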
@ -142,13 +150,13 @@ class BaseMultimodalInputProcessor(InputProcessor, ABC):
|
||||
@abstractmethod
|
||||
def tokenizer(self) -> PreTrainedTokenizerBase:
|
||||
"""The HF tokenizer for this model."""
|
||||
...
|
||||
return self._tokenizer
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def config(self) -> PretrainedConfig:
|
||||
"""The HF pretrained config for this model."""
|
||||
...
|
||||
return self._config
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
@ -306,22 +314,23 @@ class BaseMultimodalDummyInputsBuilder(ABC):
|
||||
super().__init__(**kwargs)
|
||||
self.image_max_dim = kwargs.get('image_max_dim',
|
||||
self.DEFAULT_IMAGE_MAX_DIM)
|
||||
self.img_min_dim = kwargs.get('img_min_dim', self.DEFAULT_IMAGE_MIN_DIM)
|
||||
self.image_min_dim = kwargs.get('image_min_dim',
|
||||
self.DEFAULT_IMAGE_MIN_DIM)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def tokenizer(self) -> PreTrainedTokenizerBase:
|
||||
pass
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def config(self) -> PretrainedConfig:
|
||||
pass
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def model_path(self) -> str:
|
||||
pass
|
||||
...
|
||||
|
||||
def get_dummy_image(self, max_width: int, max_height: int) -> Image.Image:
|
||||
image = Image.new("RGB", (max_width, max_height),
|
||||
@ -331,7 +340,7 @@ class BaseMultimodalDummyInputsBuilder(ABC):
|
||||
def get_dummy_prompt(self, input_seq_len: int):
|
||||
# TODO(yechank): We use the max resolution as starting point and keep reducing the resolution until the prompt length is less than the input sequence length.
|
||||
# Need to find better way to calculate the dummy prompt length as this iteration may not be efficient.
|
||||
while self.image_max_dim >= self.img_min_dim:
|
||||
while self.image_max_dim >= self.image_min_dim:
|
||||
image = self.get_dummy_image(max_width=self.image_max_dim,
|
||||
max_height=self.image_max_dim)
|
||||
|
||||
@ -549,7 +558,7 @@ def create_input_processor(
|
||||
model_path_or_dir: str,
|
||||
tokenizer,
|
||||
checkpoint_format: Optional[str] = "HF",
|
||||
) -> InputProcessor:
|
||||
) -> Union[InputProcessor, BaseMultimodalInputProcessor]:
|
||||
"""Create an input processor for a specific model.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -20,6 +20,7 @@ import numpy as np
import tensorrt as trt
import torch

from tensorrt_llm._torch.utils import ActivationType
from tensorrt_llm._utils import (get_init_params, str_dtype_to_torch,
str_dtype_to_trt)
from tensorrt_llm.layers.lora import LoraParams
@@ -49,14 +50,15 @@ from .mlp import MLP, GatedMLP

activation_str_to_int_map = {
# [WARNING] Keep the below in sync with cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
"gelu": 0,
"gelu_new": 0,
"relu": 1,
"silu": 2,
"swiglu": 3,
"geglu": 4,
"swiglu_bias": 5,
"identity": 6,
"gelu": int(ActivationType.Gelu),
"gelu_new": int(ActivationType.Gelu),
"relu": int(ActivationType.Relu),
"silu": int(ActivationType.Silu),
"swiglu": int(ActivationType.Swiglu),
"geglu": int(ActivationType.Geglu),
"swiglu_bias": int(ActivationType.SwigluBias),
"identity": int(ActivationType.Identity),
"relu2": int(ActivationType.Relu2),
}

@@ -2677,7 +2677,8 @@ class TorchLlmArgs(BaseLlmArgs):

enable_autotuner: bool = Field(
default=True,
description="Enable autotuner only when torch compile is enabled.",
description=
"Enable autotuner for all tunable ops. This flag is for debugging purposes only, and the performance may significantly degrade if set to false.",
status="prototype")

enable_layerwise_nvtx_marker: bool = Field(
@ -161,6 +161,11 @@ nvidia/Nemotron-H-56B-Base-8K:
|
||||
- quant_algo: FP8
|
||||
kv_cache_quant_algo: FP8
|
||||
accuracy: 89.27
|
||||
nvidia/Nemotron-MOE:
|
||||
- accuracy: 88.249
|
||||
- quant_algo: FP8
|
||||
kv_cache_quant_algo: FP8
|
||||
accuracy: 86.884
|
||||
nvidia/Llama-3.1-Nemotron-Nano-8B-v1:
|
||||
- accuracy: 37.15
|
||||
- quant_algo: FP8
|
||||
|
||||
@ -270,6 +270,11 @@ nvidia/Nemotron-H-56B-Base-8K:
|
||||
- quant_algo: FP8
|
||||
kv_cache_quant_algo: FP8
|
||||
accuracy: 83.82
|
||||
nvidia/Nemotron-MOE:
|
||||
- accuracy: 77.802
|
||||
- quant_algo: FP8
|
||||
kv_cache_quant_algo: FP8
|
||||
accuracy: 73.879
|
||||
microsoft/Phi-4-mini-instruct:
|
||||
- accuracy: 68.98
|
||||
- quant_algo: FP8
|
||||
|
||||
@ -18,6 +18,7 @@ import os
|
||||
import pytest
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
|
||||
from tensorrt_llm.quantization import QuantAlgo
|
||||
from tensorrt_llm.sampling_params import SamplingParams
|
||||
|
||||
from ..conftest import llm_models_root
|
||||
@ -153,7 +154,8 @@ class TestNemotronH(LlmapiAccuracyTestHarness):
|
||||
|
||||
class TestNemotronMOE(LlmapiAccuracyTestHarness):
|
||||
MODEL_NAME = "nvidia/Nemotron-MOE"
|
||||
MODEL_PATH = f"{llm_models_root()}/Nemotron-MOE/"
|
||||
MODEL_PATH_BF16 = f"{llm_models_root()}/Nemotron-Nano-3-30B-A3.5B-dev-1024"
|
||||
MODEL_PATH_FP8 = f"{llm_models_root()}/Nemotron-Nano-3-30B-A3.5B-FP8-KVFP8-dev"
|
||||
|
||||
def get_default_kwargs(self):
|
||||
return {
|
||||
@ -196,14 +198,29 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
|
||||
use_beam_search=beam_width > 1)
|
||||
|
||||
@pytest.mark.skip_less_device_memory(32000)
|
||||
def test_auto_dtype(self):
|
||||
pytest.skip("Nemotron-MOE is not in CI yet")
|
||||
def test_bf16(self):
|
||||
kwargs = self.get_default_kwargs()
|
||||
sampling_params = self.get_default_sampling_params()
|
||||
with AutoDeployLLM(model=self.MODEL_PATH,
|
||||
tokenizer=self.MODEL_PATH,
|
||||
with AutoDeployLLM(model=self.MODEL_PATH_BF16,
|
||||
tokenizer=self.MODEL_PATH_BF16,
|
||||
**kwargs) as llm:
|
||||
task = MMLU(self.MODEL_NAME)
|
||||
task.evaluate(llm, sampling_params=sampling_params)
|
||||
task = GSM8K(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
|
||||
@pytest.mark.skip_less_device_memory(32000)
|
||||
def test_fp8(self):
|
||||
kwargs = self.get_default_kwargs()
|
||||
sampling_params = self.get_default_sampling_params()
|
||||
with AutoDeployLLM(model=self.MODEL_PATH_FP8,
|
||||
tokenizer=self.MODEL_PATH_FP8,
|
||||
**kwargs) as llm:
|
||||
# Manually set quant_config for FP8 model to get the accuracy threshold
|
||||
llm.args.quant_config.quant_algo = QuantAlgo.FP8
|
||||
llm.args.quant_config.kv_cache_quant_algo = QuantAlgo.FP8
|
||||
|
||||
task = MMLU(self.MODEL_NAME)
|
||||
task.evaluate(llm, sampling_params=sampling_params)
|
||||
task = GSM8K(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
|
||||
@ -2049,6 +2049,18 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
|
||||
32,
|
||||
"TRTLLM",
|
||||
marks=pytest.mark.skip_less_mpi_world_size(8)),
|
||||
pytest.param(4,
|
||||
1,
|
||||
4,
|
||||
3,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
16,
|
||||
"CUTLASS",
|
||||
marks=pytest.mark.skip_less_mpi_world_size(4)),
|
||||
pytest.param(8,
|
||||
1,
|
||||
8,
|
||||
@ -2124,9 +2136,9 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
|
||||
],
|
||||
ids=[
|
||||
"latency", "latency_trtllmgen", "latency_adp_lmtp",
|
||||
"latency_trtllmgen_adp_lmtp", "throughput", "throughput_tp8",
|
||||
"throughput_tp4", "throughput_mtp", "throughput_bs8_mtp",
|
||||
"throughput_pp4_mtp"
|
||||
"latency_trtllmgen_adp_lmtp", "latency_adp_lmtp_tp4", "throughput",
|
||||
"throughput_tp8", "throughput_tp4", "throughput_mtp",
|
||||
"throughput_bs8_mtp", "throughput_pp4_mtp"
|
||||
])
|
||||
def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
|
||||
attention_dp, enable_lm_head_tp_in_adp,
|
||||
|
||||
@ -274,15 +274,14 @@ llm_perf_core:
|
||||
|
||||
- condition:
|
||||
ranges:
|
||||
compute_capability:
|
||||
gte: 9.0
|
||||
lt: 12.0
|
||||
system_gpu_count:
|
||||
gte: 8
|
||||
gpu_memory:
|
||||
gt: 80000
|
||||
wildcards:
|
||||
gpu:
|
||||
- '*h100*'
|
||||
- '*h200*'
|
||||
- '*h20*'
|
||||
|
||||
tests:
|
||||
# E2E trtllm-bench
|
||||
#mixtral_8x7b_v0.1_instruct
|
||||
@ -309,7 +308,7 @@ llm_perf_core:
|
||||
- perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:1000,1000-reqs:3000-ep:8-tp:8-gpus:8]
|
||||
- perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-input_output_len:128,128-ep:8-tp:8-gpus:8]
|
||||
- perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-input_output_len:512,32-ep:8-tp:8-gpus:8]
|
||||
#rcca case
|
||||
# chunked attention case
|
||||
- perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:20000-kv_frac:0.6-input_output_len:20000,2000-reqs:1000-ep:8-tp:8-gpus:8]
|
||||
|
||||
#llama_v4_scout_17b_16e_instruct_fp8
|
||||
|
||||
@ -168,11 +168,23 @@ llm_perf_sanity:
|
||||
# for chunked prefill cases
|
||||
- perf/test_perf.py::test_perf[deepseek_v3_lite_fp8-bench-pytorch-float8-maxbs:512-maxnt:2048-kv_frac:0.85-input_output_len:3000,500-reqs:200]
|
||||
- perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp8-bench-pytorch-float8-input_output_len:128,128-tp:8-gpus:8]
|
||||
- perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-input_output_len:512,32-ep:8-tp:8-gpus:8]
|
||||
- perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:20000-kv_frac:0.6-input_output_len:20000,2000-reqs:1000-ep:8-tp:8-gpus:8]
|
||||
- perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-input_output_len:512,32-ep:8-tp:8-gpus:8] TIMEOUT(100)
|
||||
- perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:20000-kv_frac:0.85-input_output_len:20000,2000-reqs:1000-ep:8-tp:8-gpus:8] TIMEOUT(100)
|
||||
- perf/test_perf.py::test_perf[qwen3_235b_a22b_fp8-bench-pytorch-float8-input_output_len:1000,2000-con:256-ep:8-gpus:8] TIMEOUT(60)
|
||||
- perf/test_perf.py::test_perf[deepseek_v3_lite_fp8-disagg_server-ctx_dp:4-gen_tp:4]
|
||||
- perf/test_perf.py::test_perf[llama_v3.1_8b-disagg_server-ctx_dp:4-gen_tp:4]
|
||||
# gpt_oss_20b_fp4
|
||||
- perf/test_perf.py::test_perf[gpt_oss_20b_fp4-bench-pytorch-float4-input_output_len:512,512]
|
||||
|
||||
# gpu_arch > Hopper, exclude GB20X, RTX 6000 for not supported
|
||||
- condition:
|
||||
ranges:
|
||||
system_gpu_count:
|
||||
gte: 8
|
||||
compute_capability:
|
||||
gte: 9.0
|
||||
lt: 12.0
|
||||
|
||||
tests:
|
||||
# chunked attention case
|
||||
- perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:20000-kv_frac:0.6-input_output_len:20000,2000-reqs:1000-ep:8-tp:8-gpus:8]
|
||||
|
||||
@@ -59,6 +59,7 @@ l0_dgx_b200:
- accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp2pp2
- accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_adp_lmtp_tp4]
- condition:
ranges:
system_gpu_count:
@ -342,11 +342,11 @@ accuracy/test_cli_flow.py::TestMinitron4BBase::test_fp8 SKIP (https://nvbugs/560
|
||||
examples/test_gpt.py::test_llm_minitron_fp8_with_pseudo_loras[4b] SKIP (https://nvbugs/5606233)
|
||||
disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_deepseek[True-False-DeepSeek-V3-Lite-fp8/fp8] SKIP (https://nvbugs/5626197)
|
||||
disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_deepseek[True-True-DeepSeek-V3-Lite-fp8/fp8] SKIP (https://nvbugs/5628952)
|
||||
accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_5_VL_7B::test_auto_dtype SKIP (https://nvbugs/5636894)
|
||||
test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-hf-nvfp4-False-False] SKIP (https://nvbugs/5629791)
|
||||
accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False] SKIP (https://nvbugs/5629792)
|
||||
accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_fp8[throughput_latency] SKIP (https://nvbugs/5631036)
|
||||
test_e2e.py::test_openai_chat_multimodal_example SKIP (https://nvbugs/5636894)
|
||||
accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-2] SKIP (https://nvbugs/5636912)
|
||||
accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-4] SKIP (https://nvbugs/5636912)
|
||||
accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_attention_dp] SKIP (https://nvbugs/5637220)
|
||||
llmapi/test_llm_examples.py::test_llmapi_example_multilora SKIP (https://nvbugs/5636857)
|
||||
unittest/_torch/modules/test_mla_helix.py::test_mla_helix_distributed SKIP (https://nvbugspro.nvidia.com/bug/5637012)
|
||||
@ -379,8 +379,6 @@ accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[True] SKIP
|
||||
disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_empty_batch[DeepSeek-V3-Lite-bf16] SKIP (https://nvbugs/5601682)
|
||||
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] SKIP (https://nvbugs/5655584)
|
||||
accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_chunked_prefill SKIP (https://nvbugs/5608930)
|
||||
accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[False] SKIP (https://nvbugspro.nvidia.com/bug/5651854)
|
||||
test_e2e.py::test_ptp_quickstart_multimodal_chunked_prefill[phi4-multimodal-instruct-fp4-multimodals/Phi-4-multimodal-instruct-FP4-0.8-image] SKIP (https://nvbugs/5568836)
|
||||
test_e2e.py::test_ptp_quickstart_multimodal_chunked_prefill[phi4-multimodal-instruct-fp4-multimodals/Phi-4-multimodal-instruct-FP4-0.8-image] SKIP (https://nvbugs/5568836)
|
||||
test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[phi4-multimodal-instruct-fp4-multimodals/Phi-4-multimodal-instruct-FP4-0.8-image] SKIP (https://nvbugs/5568836)
|
||||
test_e2e.py::test_ptp_quickstart_multimodal_multiturn[phi4-multimodal-instruct-fp4-multimodals/Phi-4-multimodal-instruct-FP4] SKIP (https://nvbugs/5568836)
|
||||
@ -415,3 +413,5 @@ accuracy/test_llm_api_pytorch_multimodal.py::TestNemotron_Nano_12B_V2_VL::test_a
|
||||
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=True] SKIP (https://nvbugs/5673743)
|
||||
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=False] SKIP (https://nvbugs/5673743)
|
||||
accuracy/test_llm_api_pytorch.py::TestNemotronNas::test_auto_dtype_tp8 SKIP (https://nvbugs/5673527)
|
||||
disaggregated/test_auto_scaling.py::test_disagg_server_restart[etcd-round_robin] SKIP (https://nvbugs/5633340)
|
||||
disaggregated/test_auto_scaling.py::test_disagg_server_restart[http-round_robin] SKIP (https://nvbugs/5633340)
|
||||
|
||||
@@ -681,8 +681,7 @@ def test_forward_sparse_mla_unified(batch_name, kv_cache_dtype: str):
hidden_states,
attn_metadata,
position_ids,
None, # indexer_k
None, # indexer_weights
indexer_k=mla.mqa.indexer.wk(hidden_states), # indexer_k
)

# Validate indexer output against expected causal indices (since seq_len < topk=2048)
@@ -136,6 +136,7 @@ def test_openai_compatible_json_schema(client: openai.OpenAI, model_name: str):
"type": "json_schema",
"json_schema": json_schema
},
temperature=0.0,
)

message = chat_completion.choices[0].message