[None][feat] Fused kernels (qknormrope + moe routing) and two-model MTP support for glm4moe (#9852)

Signed-off-by: Xuanyu Chen <xuanyuc@nvidia.com>
nvxuanyuc 2025-12-13 18:47:24 -08:00 committed by GitHub
parent 64d7796234
commit a5a37227d6
11 changed files with 289 additions and 82 deletions

View File

@ -66,6 +66,7 @@ __global__ void fusedQKNormRopeKernel(
int const num_heads_q, // Number of query heads
int const num_heads_k, // Number of key heads
int const num_heads_v, // Number of value heads
int const rotary_dim, // Dimension for RoPE
float const eps, // Epsilon for RMS normalization
__nv_bfloat16 const* q_weight, // RMSNorm weights for query
__nv_bfloat16 const* k_weight, // RMSNorm weights for key
@ -184,7 +185,7 @@ __global__ void fusedQKNormRopeKernel(
int dim_idx = laneId * numElemsPerThread + i;
int half_dim = dim_idx / 2;
float freq = powf(base, -2.0f * half_dim / static_cast<float>(head_dim));
float freq = powf(base, -2.0f * half_dim / static_cast<float>(rotary_dim));
if (factor != 1.0f)
{
@ -212,19 +213,20 @@ __global__ void fusedQKNormRopeKernel(
{
// Before data exchange within the warp, we need to sync.
__syncwarp();
int pairOffset = (rotary_dim / 2) / numElemsPerThread;
// Get the data from the other half of the warp. Fill cos_vals and sin_vals.
for (int i = 0; i < numElemsPerThread; i++)
{
elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16);
if (laneId < 16)
elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], pairOffset);
if (laneId < pairOffset)
{
elements2[i] = -elements2[i];
}
int dim_idx = laneId * numElemsPerThread + i;
dim_idx = (dim_idx * 2) % head_dim;
dim_idx = (dim_idx * 2) % rotary_dim;
int half_dim = dim_idx / 2;
float freq = powf(base, -2.0f * half_dim / static_cast<float>(head_dim));
float freq = powf(base, -2.0f * half_dim / static_cast<float>(rotary_dim));
if (factor != 1.0f)
{
@ -251,9 +253,25 @@ __global__ void fusedQKNormRopeKernel(
__syncwarp();
}
for (int i = 0; i < numElemsPerThread; i++)
bool const is_full_rope = (rotary_dim == head_dim);
if (is_full_rope)
{
elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor;
for (int i = 0; i < numElemsPerThread; i++)
{
elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor;
}
}
else
{
for (int i = 0; i < numElemsPerThread; i++)
{
int dim_idx = laneId * numElemsPerThread + i;
if (dim_idx < rotary_dim)
{
elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor;
}
}
}
// Store.
@ -284,14 +302,23 @@ __global__ void fusedQKNormRopeKernel(
}
void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_q, int const num_heads_k,
int const num_heads_v, int const head_dim, float const eps, void const* q_weight, void const* k_weight,
float const base, bool const interleave, int const* position_ids, float factor, float low, float high,
float attention_factor, cudaStream_t stream, bool is_qk_norm)
int const num_heads_v, int const head_dim, int const rotary_dim, float const eps, void const* q_weight,
void const* k_weight, float const base, bool const interleave, int const* position_ids, float factor, float low,
float high, float attention_factor, cudaStream_t stream, bool is_qk_norm)
{
if (factor == 1.0f)
{
TLLM_CHECK(attention_factor == 1.0f);
}
TLLM_CHECK_WITH_INFO(rotary_dim % 2 == 0, "rotary_dim must be even");
if (!interleave)
{
// To allow warp-level pairing for partial rope
TLLM_CHECK_WITH_INFO(
(rotary_dim * 16) % head_dim == 0, "Unsupported rotary dimension for fusedQKNormRope: %d", rotary_dim);
}
constexpr int blockSize = 256;
int const warpsPerBlock = blockSize / 32;
@ -309,7 +336,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_
case 64:
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
fusedQKNormRopeKernel<64, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, rotary_dim, eps,
reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
base, position_ids, num_tokens, factor, low, high, attention_factor, is_qk_norm);
});
@ -317,7 +344,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_
case 128:
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
fusedQKNormRopeKernel<128, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, rotary_dim, eps,
reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
base, position_ids, num_tokens, factor, low, high, attention_factor, is_qk_norm);
});
@ -325,7 +352,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_
case 256:
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
fusedQKNormRopeKernel<256, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, rotary_dim, eps,
reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
base, position_ids, num_tokens, factor, low, high, attention_factor, is_qk_norm);
});

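The kernel changes above generalize fusedQKNormRope from full to partial RoPE: frequencies are computed from rotary_dim instead of head_dim, the non-interleaved warp exchange uses pairOffset = (rotary_dim / 2) / numElemsPerThread instead of the hard-coded 16, and channels at or beyond rotary_dim pass through unrotated. Assuming each of the 32 lanes in a warp handles head_dim / 32 elements, that offset works out to 16 * rotary_dim / head_dim, which is exactly what the launcher's new (rotary_dim * 16) % head_dim == 0 check guarantees is an integer; with rotary_dim == head_dim it reduces to the previous offset of 16. Below is a minimal PyTorch sketch of the intended non-interleaved (NeoX) semantics, with illustrative names only; it is a reference for the math, not the kernel's implementation.

import torch

def partial_rope_neox_ref(x: torch.Tensor, position_ids: torch.Tensor,
                          rotary_dim: int, base: float = 10000.0) -> torch.Tensor:
    # x: [num_tokens, num_heads, head_dim]; only the first rotary_dim channels rotate.
    rot, passthrough = x[..., :rotary_dim], x[..., rotary_dim:]
    half = rotary_dim // 2
    # freq_i = base^(-2*i / rotary_dim), matching powf(base, -2.0f * half_dim / rotary_dim)
    inv_freq = base ** (-torch.arange(half, dtype=torch.float32) * 2.0 / rotary_dim)
    angles = position_ids.to(torch.float32)[:, None] * inv_freq[None, :]  # [T, half]
    cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]         # [T, 1, half]
    x1, x2 = rot[..., :half].float(), rot[..., half:].float()
    rotated = torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return torch.cat((rotated.to(x.dtype), passthrough), dim=-1)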
View File

@ -33,6 +33,7 @@ void launchFusedQKNormRope(
int const num_heads_k, // Number of key heads
int const num_heads_v, // Number of value heads
int const head_dim, // Dimension per head
int const rotary_dim, // Dimension for RoPE
float const eps, // Epsilon for RMS normalization
void const* q_weight, // RMSNorm weights for query [head_dim]
void const* k_weight, // RMSNorm weights for key [head_dim]

View File

@ -34,6 +34,7 @@ void fused_qk_norm_rope(
int64_t num_heads_k, // Number of key heads
int64_t num_heads_v, // Number of value heads
int64_t head_dim, // Dimension per head
int64_t rotary_dim, // Dimension for RoPE
double eps, // Epsilon for RMS normalization
torch::Tensor& q_weight, // RMSNorm weights for query [head_dim]
torch::Tensor& k_weight, // RMSNorm weights for key [head_dim]
@ -72,9 +73,9 @@ void fused_qk_norm_rope(
tensorrt_llm::kernels::launchFusedQKNormRope(reinterpret_cast<__nv_bfloat16*>(qkv.data_ptr()),
static_cast<int>(num_tokens), static_cast<int>(num_heads_q), static_cast<int>(num_heads_k),
static_cast<int>(num_heads_v), static_cast<int>(head_dim), static_cast<float>(eps),
reinterpret_cast<__nv_bfloat16*>(q_weight.data_ptr()), reinterpret_cast<__nv_bfloat16*>(k_weight.data_ptr()),
static_cast<float>(base),
static_cast<int>(num_heads_v), static_cast<int>(head_dim), static_cast<int>(rotary_dim),
static_cast<float>(eps), reinterpret_cast<__nv_bfloat16*>(q_weight.data_ptr()),
reinterpret_cast<__nv_bfloat16*>(k_weight.data_ptr()), static_cast<float>(base),
!is_neox, // interleave
reinterpret_cast<int const*>(position_ids.data_ptr()), static_cast<float>(factor), static_cast<float>(low),
static_cast<float>(high), static_cast<float>(attention_factor), stream, is_qk_norm);
@ -84,7 +85,8 @@ void fused_qk_norm_rope(
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"fused_qk_norm_rope(Tensor(a!) qkv, int num_heads_q, int num_heads_k, int num_heads_v, int head_dim, float "
"fused_qk_norm_rope(Tensor(a!) qkv, int num_heads_q, int num_heads_k, int num_heads_v, int head_dim, int "
"rotary_dim, float "
"eps, Tensor q_weight, Tensor k_weight, float base, bool is_neox, Tensor position_ids, float factor, float "
"low, float high, float attention_factor, bool is_qk_norm) -> ()");
}

View File

@ -31,7 +31,9 @@ class AutoModelForCausalLM(Generic[TModel, TConfig]):
"") # Strip the appended EAGLE3
if hasattr(config.pretrained_config, "draft_vocab_size"):
model_arch = "EAGLE3" + model_arch
if model_arch == "DeepseekV3ForCausalLM" and config.spec_config is not None and config.spec_config.max_draft_len == 0:
if model_arch in (
"DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"
) and config.spec_config is not None and config.spec_config.max_draft_len == 0:
model_arch = "MTPDraftModelForCausalLM"
cls = MODEL_CLASS_MAPPING.get(model_arch)

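Reading the condition above, the two-model MTP flow identifies the separately built draft engine by a spec_config whose max_draft_len is 0; that engine is remapped to MTPDraftModelForCausalLM, while the target engine keeps Glm4MoeForCausalLM (or DeepseekV3ForCausalLM). A minimal usage sketch mirroring the accuracy test added later in this commit; the checkpoint path is a placeholder and the import paths follow the LLM API:

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig, MTPDecodingConfig

model_path = "/path/to/GLM-4.6"  # placeholder
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=3,
                               mtp_eagle_one_model=False,  # two-model MTP
                               speculative_model_dir=model_path)
llm = LLM(model_path,
          speculative_config=mtp_config,
          kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7))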
View File

@ -1,3 +1,4 @@
import inspect
import math
import os
from typing import Dict, List, Optional, Tuple
@ -8,14 +9,10 @@ from tqdm import tqdm
from transformers import PretrainedConfig
from tensorrt_llm._ipc_utils import can_access_peer
from tensorrt_llm._utils import get_sm_version, is_sm_100f
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo
from tensorrt_llm.quantization.utils.fp8_utils import (
resmooth_to_fp8_e8m0,
transform_sf_into_required_layout,
)
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
@ -29,7 +26,7 @@ from ..distributed import (
from ..model_config import ModelConfig
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import MoEWeightLoadingMode, create_moe
from ..modules.fused_moe import MoE, MoEWeightLoadingMode, create_moe
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import Linear, TensorParallelMode
from ..modules.multi_stream_utils import maybe_execute_in_parallel
@ -39,7 +36,142 @@ from ..speculative import SpecMetadata
from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor
from .modeling_deepseekv3 import DeepseekV3Gate, DeepseekV3MTPHead, moe_reduce_add_shared_output
from .modeling_speculative import SpecDecOneEngineForCausalLM
from .modeling_utils import DecoderModel, EagerFusionConfig, _load_weights_impl, register_auto_model
from .modeling_utils import (
DecoderModel,
EagerFusionConfig,
duplicate_kv_weight,
filter_weights,
register_auto_model,
)
class Glm4WeightLoader:
def __init__(self, model, is_draft_model: bool = False):
self.model = model
self.config = model.config
self.model_config = model.model_config
self.is_draft_model = is_draft_model
def load_weights(self, weights: Dict, allow_partial_loading: bool = False):
def rename_moe_weight(weights: Dict, rename_rules: Dict):
result = {}
for key, value in weights.items():
new_key = key
for old, new in rename_rules.items():
new_key = new_key.replace(old, new)
result[new_key] = value
return result
params_map = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}
all_named_modules = dict(self.model.named_modules())
tp_size = (
1
if self.model_config.mapping.enable_attention_dp
else self.model_config.mapping.tp_size
)
num_kv_heads = (
self.config.num_key_value_heads
if hasattr(self.config, "num_key_value_heads")
and self.config.num_key_value_heads is not None
else self.config.num_attention_heads
)
for name, module in tqdm(all_named_modules.items(), desc="Loading weights"):
if len(module._parameters) <= 0 or name.startswith("draft_model"):
continue
else:
names = name.split(".")
if "model.layers" in name and int(names[2]) >= self.config.num_hidden_layers:
mtp_layer_idx = int(names[2]) - self.config.num_hidden_layers
names[2] = str(
mtp_layer_idx % self.config.num_nextn_predict_layers
+ self.config.num_hidden_layers
)
name = ".".join(names)
if names[-1] in params_map:
module_weights = []
for new_name in params_map[names[-1]]:
fw = filter_weights(".".join(names[:-1] + [new_name]), weights)
if new_name in ["k_proj", "v_proj"]:
num_kv_heads_list = (
[num_kv_heads] * len(fw)
if isinstance(num_kv_heads, int)
else num_kv_heads
)
fw = {
k: duplicate_kv_weight(
weight=v[:],
num_kv_heads=num_kv_heads_list[i],
tensor_parallel_size=tp_size,
)
if k in ["weight", "bias"]
else v
for i, (k, v) in enumerate(fw.items())
}
module_weights.append(fw)
module.load_weights(weights=module_weights)
elif names[-1] == "experts":
module_weights = filter_weights(name, weights)
module_weights = rename_moe_weight(
module_weights,
{
"down_proj": "w2",
"up_proj": "w3",
"gate_proj": "w1",
},
)
module.load_weights(
weights=[module_weights], allow_partial_loading=allow_partial_loading
)
elif names[-1] == "backend" and isinstance(module, MoE):
# Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE).
# Saved MoE checkpoints don't include 'backend' in their weight names.
# After the MoE refactoring, ConfigurableMoE has a backend submodule and
# weight loading happens in the backend, so the module name includes '.backend'.
# Use the parent module name (without '.backend') to match the saved weight names.
# Once the MoE refactoring is fully complete, all paths will take this branch.
parent_name = ".".join(names[:-1])
module_weights = filter_weights(parent_name, weights)
module_weights = rename_moe_weight(
module_weights,
{
"down_proj": "w2",
"up_proj": "w3",
"gate_proj": "w1",
},
)
module.load_weights(
weights=[module_weights], allow_partial_loading=allow_partial_loading
)
elif names[-1] == "self_attn":
continue
elif names[-1] == "next_layer_layernorm":
continue
else:
module_weights = filter_weights(name, weights)
if hasattr(module, "load_weights"):
args = inspect.getfullargspec(module.load_weights).args
if "allow_partial_loading" not in args:
assert not allow_partial_loading, (
"allow_partial_loading is not supported for this model"
)
module.load_weights(weights=[module_weights])
else:
module.load_weights(
weights=[module_weights],
allow_partial_loading=allow_partial_loading,
)
else:
for n, p in module.named_parameters():
if not allow_partial_loading:
assert n in module_weights
if n in module_weights:
p.data.copy_(module_weights[n][:])
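For reference, the loader maps Hugging Face expert-projection names onto the fused-MoE convention (gate_proj -> w1, down_proj -> w2, up_proj -> w3), and for the refactored ConfigurableMoE path it matches checkpoint keys against the parent module name because saved weights carry no '.backend' segment. A small illustrative sketch with hypothetical key names (the actual keys depend on the GLM-4.6 checkpoint layout):

rename_rules = {"down_proj": "w2", "up_proj": "w3", "gate_proj": "w1"}

hf_key = "model.layers.3.mlp.experts.7.up_proj.weight"   # hypothetical key
internal_key = hf_key
for old, new in rename_rules.items():
    internal_key = internal_key.replace(old, new)
assert internal_key == "model.layers.3.mlp.experts.7.w3.weight"

# Refactored MoE path: the module is named "...mlp.experts.backend", but the
# checkpoint has no ".backend" segment, so keys are filtered by the parent name.
module_name = "model.layers.3.mlp.experts.backend"
parent_name = ".".join(module_name.split(".")[:-1])      # "model.layers.3.mlp.experts"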
class Glm4Attention(QKNormRoPEAttention):
@ -61,7 +193,7 @@ class Glm4Attention(QKNormRoPEAttention):
max_position_embeddings=config.max_position_embeddings,
bias=config.attention_bias,
pos_embd_params=pos_embd_params,
fuse_qk_norm_rope=False,
fuse_qk_norm_rope=True,
layer_idx=layer_idx,
dtype=config.torch_dtype,
dense_bias=False,
@ -98,7 +230,7 @@ class Glm4MoE(nn.Module):
topk_group=config.topk_group,
routed_scaling_factor=config.routed_scaling_factor,
dtype=dtype,
fuse_routing_kernel=False,
fuse_routing_kernel=True,
apply_routing=False,
moe_backend=model_config.moe_backend,
)
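Switching fuse_routing_kernel to True moves GLM-4.6's expert routing onto the fused kernel path of DeepseekV3Gate. As a reference for what that routing computes, here is a hedged PyTorch sketch of DeepSeek-V3-style group-limited top-k selection (sigmoid scores, a per-expert correction bias used only for selection, top-k over the best expert groups, then renormalized weights scaled by routed_scaling_factor). It approximates the eager path, not the fused kernel's implementation, and the bias/group parameters are assumptions about the GLM config.

import torch

def grouped_topk_ref(logits, e_score_correction_bias, n_group, topk_group,
                     top_k, routed_scaling_factor):
    scores = torch.sigmoid(logits)                              # [tokens, experts]
    scores_for_choice = scores + e_score_correction_bias        # biased scores, selection only
    grouped = scores_for_choice.view(scores.size(0), n_group, -1)
    group_scores = grouped.topk(2, dim=-1).values.sum(-1)       # top-2 sum per expert group
    top_groups = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, top_groups, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand_as(grouped).reshape_as(scores)
    masked = scores_for_choice.masked_fill(expert_mask == 0, float("-inf"))
    topk_idx = masked.topk(top_k, dim=-1).indices
    topk_weights = scores.gather(1, topk_idx)                   # unbiased scores as weights
    topk_weights = topk_weights / topk_weights.sum(-1, keepdim=True)
    return topk_idx, topk_weights * routed_scaling_factor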
@ -872,40 +1004,11 @@ class Glm4MoeForCausalLM(SpecDecOneEngineForCausalLM[Glm4Model, PretrainedConfig
**kwargs,
)
def load_weights(self, weights: Dict):
# model.layers.91.mlp.experts.75.up_proj.weight_scale_2
_load_weights_impl(
self,
weights,
params_map={
r"(?!.*shared_experts)(?=.*experts?)(.*?)up_proj(.*)": r"\1w3\2",
r"(?!.*shared_experts)(?=.*experts?)(.*?)down_proj(.*)": r"\1w2\2",
r"(?!.*shared_experts)(?=.*experts?)(.*?)gate_proj(.*)": r"\1w1\2",
},
)
def load_weights(self, weights: Dict, allow_partial_loading: bool = False):
weight_loader = Glm4WeightLoader(self)
weight_loader.load_weights(weights, allow_partial_loading=allow_partial_loading)
def post_load_weights(self):
all_named_modules = dict(self.model.named_modules())
for name, module in tqdm(all_named_modules.items(), desc="Post loading weights"):
if len(module._parameters) <= 0 or name.startswith("draft_model"):
continue
else:
if (
self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales()
and is_sm_100f()
and hasattr(module, "weight_scale")
):
weight, weight_scale = resmooth_to_fp8_e8m0(module.weight, module.weight_scale)
transfromed_scale = transform_sf_into_required_layout(
weight_scale,
mn=weight.shape[0],
k=weight.shape[1],
recipe=(1, 128, 128),
is_sfa=False,
)
module.weight = nn.Parameter(weight, requires_grad=False)
module.weight_scale = nn.Parameter(transfromed_scale, requires_grad=False)
for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]):
if idx == self.config.num_hidden_layers - 1:
layer.next_layer_layernorm = self.model.norm

View File

@ -428,12 +428,22 @@ class MTPDraftModel(nn.Module):
torch.cuda.Stream]):
super().__init__()
# Import here to avoid circular import
from .modeling_deepseekv3 import DeepseekV3MTP
mtp_layer = DeepseekV3MTP(model_config,
layer_idx,
aux_stream_dict,
is_separate_draft_engine=True)
model_type = model_config.pretrained_config.model_type
if model_type == "glm4_moe":
from .modeling_glm import Glm4MTP
mtp_layer = Glm4MTP(model_config,
layer_idx,
aux_stream_dict,
is_separate_draft_engine=True)
elif model_type in ["deepseek_v3", "deepseek_v32"]:
from .modeling_deepseekv3 import DeepseekV3MTP
mtp_layer = DeepseekV3MTP(model_config,
layer_idx,
aux_stream_dict,
is_separate_draft_engine=True)
else:
raise ValueError(
f"MTPDraftModel does not support model_type: {model_type}")
setattr(self, f"layers.{layer_idx}", mtp_layer)
self.layers = mtp_layer
self.layer_idx = layer_idx
@ -493,8 +503,18 @@ class MTPDraftModelForCausalLM(DecoderModelForCausalLM[MTPDraftModel,
def load_weights(self, weights: Dict):
# Import here to avoid circular import
from .modeling_deepseekv3 import DeepseekV3WeightLoader
weight_loader = DeepseekV3WeightLoader(self, is_draft_model=True)
model_type = self.model_config.pretrained_config.model_type
match model_type:
case "glm4_moe":
from .modeling_glm import Glm4WeightLoader
weight_loader = Glm4WeightLoader(self, is_draft_model=True)
case "deepseek_v3" | "deepseek_v32":
from .modeling_deepseekv3 import DeepseekV3WeightLoader
weight_loader = DeepseekV3WeightLoader(self,
is_draft_model=True)
case _:
raise ValueError(
f"Model type {model_type} not supported for MTP")
weight_loader.load_weights(weights)
def load_weights_from_target_model(self,

View File

@ -229,9 +229,14 @@ class QKNormRoPEAttention(Attention):
def apply_qk_norm_rope(self, qkv, position_ids):
factor, low, high, attention_factor = compute_yarn_parameters(
self.pretrained_config)
partial_rotary_factor = self.pretrained_config.partial_rotary_factor if hasattr(
self.pretrained_config, "partial_rotary_factor") else 1.0
rotary_dim = int(self.head_dim * partial_rotary_factor)
torch.ops.trtllm.fused_qk_norm_rope(
qkv, self.num_heads, self.num_key_value_heads,
self.num_key_value_heads, self.head_dim,
self.num_key_value_heads, self.head_dim, rotary_dim,
self.q_norm.variance_epsilon, self.q_norm.weight,
self.k_norm.weight,
self.pos_embd_params.rope.theta, self.pos_embd_params.is_neox,

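On the Python side, rotary_dim is derived from the Hugging Face config's partial_rotary_factor (defaulting to 1.0, i.e. full RoPE) and passed straight through to the fused op. A small worked example, assuming a GLM-4.6-style head_dim of 128 with partial_rotary_factor 0.5 (values illustrative):

head_dim = 128
partial_rotary_factor = 0.5                           # 1.0 recovers full RoPE
rotary_dim = int(head_dim * partial_rotary_factor)    # 64: only half of each head is rotated
assert rotary_dim % 2 == 0                            # launcher precondition
assert (rotary_dim * 16) % head_dim == 0              # non-interleaved warp-pairing precondition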
View File

@ -272,6 +272,8 @@ ByteDance-Seed/Seed-OSS-36B-Instruct:
- accuracy: 90.8
zai-org/GLM-4.6:
- accuracy: 81.3
- spec_dec_algo: MTP
accuracy: 81.3
- quant_algo: NVFP4
spec_dec_algo: MTP
accuracy: 88.0

View File

@ -2869,8 +2869,11 @@ class TestGLM4_6(LlmapiAccuracyTestHarness):
@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize(
"tp_size,pp_size,mtp_nextn,cuda_graph,overlap_scheduler,chunked_prefill,max_batch_size,moe_backend",
[pytest.param(4, 1, 2, True, True, True, 16, "CUTLASS")],
ids=["throughput"])
[
pytest.param(4, 1, 2, True, True, True, 16, "CUTLASS"),
pytest.param(4, 1, 2, True, True, True, 16, "TRTLLM")
],
ids=["throughput", "throughput_trtllm"])
def test_nvfp4_multi_gpus(self, tp_size, pp_size, mtp_nextn, cuda_graph,
overlap_scheduler, chunked_prefill,
max_batch_size, moe_backend):
@ -2897,6 +2900,39 @@ class TestGLM4_6(LlmapiAccuracyTestHarness):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize(
"tp_size,cuda_graph,overlap_scheduler,chunked_prefill,max_batch_size,moe_backend",
[
pytest.param(4, True, True, True, 16, "CUTLASS"),
pytest.param(4, True, True, True, 16, "TRTLLM"),
],
ids=["2model", "2model_trtllm"])
def test_nvfp4_2_model_mtp(self, tp_size, cuda_graph, overlap_scheduler,
chunked_prefill, max_batch_size, moe_backend):
model_path = f"{llm_models_root()}/glm-4.6-fp4"
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.70)
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
moe_config=MoeConfig(backend=moe_backend))
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=3,
mtp_eagle_one_model=False,
speculative_model_dir=model_path)
with LLM(model_path,
max_batch_size=max_batch_size,
tensor_parallel_size=tp_size,
kv_cache_config=kv_cache_config,
**pytorch_config,
speculative_config=mtp_config,
enable_chunked_prefill=chunked_prefill) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@pytest.mark.timeout(7200)
@pytest.mark.skip_less_device_memory(100000)

View File

@ -507,6 +507,9 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[latency
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[baseline_fp8kv]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[latency]
accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_multi_gpus[throughput]
accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_multi_gpus[throughput_trtllm]
accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_2_model_mtp[2model]
accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_2_model_mtp[2model_trtllm]
accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_bf16[multi_gpus_no_cache]
accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=False]

View File

@ -8,8 +8,8 @@ from tensorrt_llm._torch.modules.rotary_embedding import RotaryEmbedding
@torch.inference_mode()
def torch_ref_rms_norm_rope(qkv, num_heads_q, num_heads_k, num_heads_v,
head_dim, eps, q_weight, k_weight, base, is_neox,
position_ids):
head_dim, rotary_dim, eps, q_weight, k_weight, base,
is_neox, position_ids):
"""
PyTorch reference implementation of RMSNorm+RoPE for verification.
@ -22,6 +22,7 @@ def torch_ref_rms_norm_rope(qkv, num_heads_q, num_heads_k, num_heads_v,
num_heads_k: Number of key heads
num_heads_v: Number of value heads (unused for normalization/RoPE but needed for tensor splitting)
head_dim: Dimension of each head
rotary_dim: Dimension for RoPE
eps: Epsilon value for RMS normalization
q_weight: RMSNorm weights for query [head_dim]
k_weight: RMSNorm weights for key [head_dim]
@ -65,7 +66,7 @@ def torch_ref_rms_norm_rope(qkv, num_heads_q, num_heads_k, num_heads_v,
# Create and apply RotaryEmbedding module
rope_params = RopeParams(
dim=head_dim, # Set the rotary dimension to match the head dimension
dim=rotary_dim, # Set the rotary dimension
theta=base, # Base value for RoPE calculations
max_positions=8192 # Large enough for any reasonable hidden size
)
@ -88,10 +89,12 @@ num_heads_groups = [
(16, 8, 8), # Qwen3-0.6B, Qwen3-1.7B
(32, 8, 8), # Qwen3-4B, Qwen3-8B, Qwen3-30B-A3B
(40, 8, 8), # Qwen3-14B
(64, 8, 8) # Qwen3-32B, Qwen3-235B-A22B
(64, 8, 8), # Qwen3-32B, Qwen3-235B-A22B
(24, 8, 8), # GLM 4.6
]
num_tokens_list = [1, 3, 8, 32, 256]
is_neox_list = [False, True]
partial_rotary_factor_list = [1.0, 0.5]
dtypes = [torch.bfloat16] # TODO: support float16
@ -100,8 +103,9 @@ dtypes = [torch.bfloat16] # TODO: support float16
@pytest.mark.parametrize("num_tokens", num_tokens_list)
@pytest.mark.parametrize("is_neox", is_neox_list)
@pytest.mark.parametrize("dtype", dtypes)
def test_fused_qk_norm_rope(head_dim, num_heads_group, num_tokens, is_neox,
dtype):
@pytest.mark.parametrize("partial_rotary_factor", partial_rotary_factor_list)
def test_fused_qk_norm_rope(head_dim, num_heads_group, num_tokens,
partial_rotary_factor, is_neox, dtype):
"""
Test the fused QK RMSNorm + RoPE operation with various configurations.
@ -143,18 +147,20 @@ def test_fused_qk_norm_rope(head_dim, num_heads_group, num_tokens, is_neox,
base = 10000.0
factor, low, high, attention_factor = 1.0, 0, 0, 1.0
rotary_dim = int(head_dim * partial_rotary_factor)
# Run the custom fusedQKNormRope operation
torch.ops.trtllm.fused_qk_norm_rope(qkv, num_heads_q, num_heads_k,
num_heads_v, head_dim, eps, q_weight,
k_weight, base, is_neox, position_ids,
factor, low, high, attention_factor,
True)
num_heads_v, head_dim, rotary_dim, eps,
q_weight, k_weight, base, is_neox,
position_ids, factor, low, high,
attention_factor, True)
output = qkv # This op is inplace
# Compute reference output using TensorRT LLM modules
ref_output = torch_ref_rms_norm_rope(qkv_copy, num_heads_q, num_heads_k,
num_heads_v, head_dim, eps, q_weight,
k_weight, base, is_neox, position_ids)
num_heads_v, head_dim, rotary_dim, eps,
q_weight, k_weight, base, is_neox,
position_ids)
# Compare outputs from custom kernel vs reference implementation
torch.testing.assert_close(