diff --git a/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu b/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
index 73326af8c4..a73ea79270 100644
--- a/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
+++ b/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
@@ -66,6 +66,7 @@ __global__ void fusedQKNormRopeKernel(
     int const num_heads_q,         // Number of query heads
     int const num_heads_k,         // Number of key heads
     int const num_heads_v,         // Number of value heads
+    int const rotary_dim,          // Dimension for RoPE
     float const eps,               // Epsilon for RMS normalization
     __nv_bfloat16 const* q_weight, // RMSNorm weights for query
     __nv_bfloat16 const* k_weight, // RMSNorm weights for key
@@ -184,7 +185,7 @@ __global__ void fusedQKNormRopeKernel(
                 int dim_idx = laneId * numElemsPerThread + i;
                 int half_dim = dim_idx / 2;
-                float freq = powf(base, -2.0f * half_dim / static_cast<float>(head_dim));
+                float freq = powf(base, -2.0f * half_dim / static_cast<float>(rotary_dim));
 
                 if (factor != 1.0f)
                 {
@@ -212,19 +213,20 @@ __global__ void fusedQKNormRopeKernel(
         {
             // Before data exchange with in warp, we need to sync.
             __syncwarp();
+            int pairOffset = (rotary_dim / 2) / numElemsPerThread;
             // Get the data from the other half of the warp. Fill cos_vals and sin_vals.
             for (int i = 0; i < numElemsPerThread; i++)
             {
-                elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16);
-                if (laneId < 16)
+                elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], pairOffset);
+                if (laneId < pairOffset)
                 {
                     elements2[i] = -elements2[i];
                 }
                 int dim_idx = laneId * numElemsPerThread + i;
-                dim_idx = (dim_idx * 2) % head_dim;
+                dim_idx = (dim_idx * 2) % rotary_dim;
                 int half_dim = dim_idx / 2;
-                float freq = powf(base, -2.0f * half_dim / static_cast<float>(head_dim));
+                float freq = powf(base, -2.0f * half_dim / static_cast<float>(rotary_dim));
 
                 if (factor != 1.0f)
                 {
@@ -251,9 +253,25 @@ __global__ void fusedQKNormRopeKernel(
         __syncwarp();
     }
 
-    for (int i = 0; i < numElemsPerThread; i++)
+    bool const is_full_rope = (rotary_dim == head_dim);
+    if (is_full_rope)
     {
-        elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor;
+        for (int i = 0; i < numElemsPerThread; i++)
+        {
+            elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor;
+        }
+    }
+    else
+    {
+        for (int i = 0; i < numElemsPerThread; i++)
+        {
+            int dim_idx = laneId * numElemsPerThread + i;
+
+            if (dim_idx < rotary_dim)
+            {
+                elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor;
+            }
+        }
     }
 
     // Store.
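The new `pairOffset` above generalizes the previously hard-coded shuffle distance of 16 lanes to partial RoPE. A rough Python sketch of the arithmetic, assuming the kernel's existing layout of one 32-lane warp per head with `numElemsPerThread = head_dim / 32`; the `pair_offset` helper below is illustrative only, not kernel code:

```python
# Non-interleaved (NeoX) RoPE pairs channel d with channel d + rotary_dim // 2,
# so each lane's partner value sits (rotary_dim // 2) // numElemsPerThread lanes away.
def pair_offset(head_dim: int, rotary_dim: int) -> int:
    num_elems_per_thread = head_dim // 32
    # The launcher check added below, (rotary_dim * 16) % head_dim == 0, is exactly
    # the condition for the partner to land on a whole lane.
    assert (rotary_dim * 16) % head_dim == 0
    return (rotary_dim // 2) // num_elems_per_thread

assert pair_offset(128, 128) == 16  # full RoPE: the old hard-coded XOR distance
assert pair_offset(128, 64) == 8    # partial RoPE, e.g. rotary_dim = head_dim / 2
```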
@@ -284,14 +302,23 @@ __global__ void fusedQKNormRopeKernel(
 }
 
 void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_q, int const num_heads_k,
-    int const num_heads_v, int const head_dim, float const eps, void const* q_weight, void const* k_weight,
-    float const base, bool const interleave, int const* position_ids, float factor, float low, float high,
-    float attention_factor, cudaStream_t stream, bool is_qk_norm)
+    int const num_heads_v, int const head_dim, int const rotary_dim, float const eps, void const* q_weight,
+    void const* k_weight, float const base, bool const interleave, int const* position_ids, float factor, float low,
+    float high, float attention_factor, cudaStream_t stream, bool is_qk_norm)
 {
     if (factor == 1.0f)
     {
         TLLM_CHECK(attention_factor == 1.0f);
     }
+
+    TLLM_CHECK_WITH_INFO(rotary_dim % 2 == 0, "rotary_dim must be even");
+    if (!interleave)
+    {
+        // To allow warp-level pairing for partial rope
+        TLLM_CHECK_WITH_INFO(
+            (rotary_dim * 16) % head_dim == 0, "Unsupported rotary dimension for fusedQKNormRope: %d", rotary_dim);
+    }
+
     constexpr int blockSize = 256;
     int const warpsPerBlock = blockSize / 32;
@@ -309,7 +336,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_
     case 64:
         DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
             fusedQKNormRopeKernel<64, INTERLEAVE><<>>(
-                reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
+                reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, rotary_dim, eps,
                 reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
                 base, position_ids, num_tokens, factor, low, high, attention_factor, is_qk_norm);
         });
@@ -317,7 +344,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_
     case 128:
         DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
             fusedQKNormRopeKernel<128, INTERLEAVE><<>>(
-                reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
+                reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, rotary_dim, eps,
                 reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
                 base, position_ids, num_tokens, factor, low, high, attention_factor, is_qk_norm);
         });
@@ -325,7 +352,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_
     case 256:
         DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
             fusedQKNormRopeKernel<256, INTERLEAVE><<>>(
-                reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
+                reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, rotary_dim, eps,
                 reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
                 base, position_ids, num_tokens, factor, low, high, attention_factor, is_qk_norm);
         });
diff --git a/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.h b/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.h
index 7dab7dbbb2..c976f2a0fe 100644
--- a/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.h
+++ b/cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.h
@@ -33,6 +33,7 @@ void launchFusedQKNormRope(
     int const num_heads_k,  // Number of key heads
     int const num_heads_v,  // Number of value heads
     int const head_dim,     // Dimension per head
+    int const rotary_dim,   // Dimension for RoPE
     float const eps,        // Epsilon for RMS normalization
     void const* q_weight,   // RMSNorm weights for query [head_dim]
     void const* k_weight,   // RMSNorm weights for key [head_dim]
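For reference, the behavior the new `is_full_rope` / `dim_idx < rotary_dim` branch implements — rotate only the first `rotary_dim` channels of each head and pass the remaining channels through unchanged — looks like this in plain PyTorch for the non-interleaved (NeoX) layout. A minimal sketch for orientation; the helper name and shapes are assumptions, not code from the repository:

```python
import torch

def partial_rope_neox_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                          rotary_dim: int) -> torch.Tensor:
    # x: [..., head_dim]; cos/sin: [..., rotary_dim // 2] for the token's position.
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot[..., :rotary_dim // 2], x_rot[..., rotary_dim // 2:]
    # Rotate the first rotary_dim channels; channels >= rotary_dim are left untouched,
    # which is what the dim_idx < rotary_dim guard in the kernel expresses.
    rotated = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return torch.cat((rotated, x_pass), dim=-1)
```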
diff --git a/cpp/tensorrt_llm/thop/fusedQKNormRopeOp.cpp b/cpp/tensorrt_llm/thop/fusedQKNormRopeOp.cpp
index 14bf8578dc..a6635c0285 100644
--- a/cpp/tensorrt_llm/thop/fusedQKNormRopeOp.cpp
+++ b/cpp/tensorrt_llm/thop/fusedQKNormRopeOp.cpp
@@ -34,6 +34,7 @@ void fused_qk_norm_rope(
     int64_t num_heads_k,     // Number of key heads
     int64_t num_heads_v,     // Number of value heads
     int64_t head_dim,        // Dimension per head
+    int64_t rotary_dim,      // Dimension for RoPE
     double eps,              // Epsilon for RMS normalization
     torch::Tensor& q_weight, // RMSNorm weights for query [head_dim]
     torch::Tensor& k_weight, // RMSNorm weights for key [head_dim]
@@ -72,9 +73,9 @@ void fused_qk_norm_rope(
 
     tensorrt_llm::kernels::launchFusedQKNormRope(reinterpret_cast<__nv_bfloat16*>(qkv.data_ptr()),
         static_cast<int>(num_tokens), static_cast<int>(num_heads_q), static_cast<int>(num_heads_k),
-        static_cast<int>(num_heads_v), static_cast<int>(head_dim), static_cast<float>(eps),
-        reinterpret_cast<__nv_bfloat16*>(q_weight.data_ptr()), reinterpret_cast<__nv_bfloat16*>(k_weight.data_ptr()),
-        static_cast<float>(base),
+        static_cast<int>(num_heads_v), static_cast<int>(head_dim), static_cast<int>(rotary_dim),
+        static_cast<float>(eps), reinterpret_cast<__nv_bfloat16*>(q_weight.data_ptr()),
+        reinterpret_cast<__nv_bfloat16*>(k_weight.data_ptr()), static_cast<float>(base),
         !is_neox, // interleave
         reinterpret_cast<int const*>(position_ids.data_ptr()), static_cast<float>(factor), static_cast<float>(low),
         static_cast<float>(high), static_cast<float>(attention_factor), stream, is_qk_norm);
@@ -84,7 +85,8 @@ void fused_qk_norm_rope(
 TORCH_LIBRARY_FRAGMENT(trtllm, m)
 {
     m.def(
-        "fused_qk_norm_rope(Tensor(a!) qkv, int num_heads_q, int num_heads_k, int num_heads_v, int head_dim, float "
+        "fused_qk_norm_rope(Tensor(a!) qkv, int num_heads_q, int num_heads_k, int num_heads_v, int head_dim, int "
+        "rotary_dim, float "
         "eps, Tensor q_weight, Tensor k_weight, float base, bool is_neox, Tensor position_ids, float factor, float "
         "low, float high, float attention_factor, bool is_qk_norm) -> ()");
 }
diff --git a/tensorrt_llm/_torch/models/modeling_auto.py b/tensorrt_llm/_torch/models/modeling_auto.py
index ff48edc5cb..84c8f73c5a 100644
--- a/tensorrt_llm/_torch/models/modeling_auto.py
+++ b/tensorrt_llm/_torch/models/modeling_auto.py
@@ -31,7 +31,9 @@ class AutoModelForCausalLM(Generic[TModel, TConfig]):
                 "")  # Strip the appended EAGLE3
         if hasattr(config.pretrained_config, "draft_vocab_size"):
             model_arch = "EAGLE3" + model_arch
-        if model_arch == "DeepseekV3ForCausalLM" and config.spec_config is not None and config.spec_config.max_draft_len == 0:
+        if model_arch in (
+                "DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"
+        ) and config.spec_config is not None and config.spec_config.max_draft_len == 0:
             model_arch = "MTPDraftModelForCausalLM"
 
         cls = MODEL_CLASS_MAPPING.get(model_arch)
diff --git a/tensorrt_llm/_torch/models/modeling_glm.py b/tensorrt_llm/_torch/models/modeling_glm.py
index be300bcf08..868e43195b 100644
--- a/tensorrt_llm/_torch/models/modeling_glm.py
+++ b/tensorrt_llm/_torch/models/modeling_glm.py
@@ -1,3 +1,4 @@
+import inspect
 import math
 import os
 from typing import Dict, List, Optional, Tuple
@@ -8,14 +9,10 @@ from tqdm import tqdm
 from transformers import PretrainedConfig
 
 from tensorrt_llm._ipc_utils import can_access_peer
-from tensorrt_llm._utils import get_sm_version, is_sm_100f
+from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.functional import PositionEmbeddingType
 from tensorrt_llm.models.modeling_utils import QuantConfig
 from tensorrt_llm.quantization.mode import QuantAlgo
-from tensorrt_llm.quantization.utils.fp8_utils import (
-    resmooth_to_fp8_e8m0,
-    transform_sf_into_required_layout,
-)
 
 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
@@ -29,7 +26,7 @@ from ..distributed import (
 from ..model_config import ModelConfig
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
-from ..modules.fused_moe import MoEWeightLoadingMode, create_moe
+from ..modules.fused_moe import MoE, MoEWeightLoadingMode, create_moe
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -39,7 +36,142 @@ from ..speculative import SpecMetadata
 from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .modeling_deepseekv3 import DeepseekV3Gate, DeepseekV3MTPHead, moe_reduce_add_shared_output
 from .modeling_speculative import SpecDecOneEngineForCausalLM
-from .modeling_utils import DecoderModel, EagerFusionConfig, _load_weights_impl, register_auto_model
+from .modeling_utils import (
+    DecoderModel,
+    EagerFusionConfig,
+    duplicate_kv_weight,
+    filter_weights,
+    register_auto_model,
+)
+
+
+class Glm4WeightLoader:
+    def __init__(self, model, is_draft_model: bool = False):
+        self.model = model
+        self.config = model.config
+        self.model_config = model.model_config
+        self.is_draft_model = is_draft_model
+
+    def load_weights(self, weights: Dict, allow_partial_loading: bool = False):
+        def rename_moe_weight(weights: Dict, rename_rules: Dict):
+            result = {}
+            for key, value in weights.items():
+                new_key = key
+                for old, new in rename_rules.items():
+                    new_key = new_key.replace(old, new)
+                result[new_key] = value
+            return result
+
+        params_map = {
+            "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+            "gate_up_proj": ["gate_proj", "up_proj"],
+        }
+        all_named_modules = dict(self.model.named_modules())
+
+        tp_size = (
+            1
+            if self.model_config.mapping.enable_attention_dp
+            else self.model_config.mapping.tp_size
+        )
+        num_kv_heads = (
+            self.config.num_key_value_heads
+            if hasattr(self.config, "num_key_value_heads")
+            and self.config.num_key_value_heads is not None
+            else self.config.num_attention_heads
+        )
+
+        for name, module in tqdm(all_named_modules.items(), desc="Loading weights"):
+            if len(module._parameters) <= 0 or name.startswith("draft_model"):
+                continue
+            else:
+                names = name.split(".")
+                if "model.layers" in name and int(names[2]) >= self.config.num_hidden_layers:
+                    mtp_layer_idx = int(names[2]) - self.config.num_hidden_layers
+                    names[2] = str(
+                        mtp_layer_idx % self.config.num_nextn_predict_layers
+                        + self.config.num_hidden_layers
+                    )
+                    name = ".".join(names)
+
+                if names[-1] in params_map:
+                    module_weights = []
+                    for new_name in params_map[names[-1]]:
+                        fw = filter_weights(".".join(names[:-1] + [new_name]), weights)
+                        if new_name in ["k_proj", "v_proj"]:
+                            num_kv_heads_list = (
+                                [num_kv_heads] * len(fw)
+                                if isinstance(num_kv_heads, int)
+                                else num_kv_heads
+                            )
+                            fw = {
+                                k: duplicate_kv_weight(
+                                    weight=v[:],
+                                    num_kv_heads=num_kv_heads_list[i],
+                                    tensor_parallel_size=tp_size,
+                                )
+                                if k in ["weight", "bias"]
+                                else v
+                                for i, (k, v) in enumerate(fw.items())
+                            }
+                        module_weights.append(fw)
+                    module.load_weights(weights=module_weights)
+                elif names[-1] == "experts":
+                    module_weights = filter_weights(name, weights)
+                    module_weights = rename_moe_weight(
+                        module_weights,
+                        {
+                            "down_proj": "w2",
+                            "up_proj": "w3",
+                            "gate_proj": "w1",
+                        },
+                    )
+                    module.load_weights(
+                        weights=[module_weights], allow_partial_loading=allow_partial_loading
+                    )
+                elif names[-1] == "backend" and isinstance(module, MoE):
+                    # Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
+                    # Currently saved MoE weights don't include 'backend' in their names.
+                    # After MoE refactoring, ConfigurableMoE now has a backend submodule,
+                    # and weights loading is done in the backend, so module name includes '.backend'.
+                    # We need to use parent module name (without .backend) to match saved weight names.
+                    # After MoE refactoring is fully complete, all paths will follow this branch.
+                    parent_name = ".".join(names[:-1])
+                    module_weights = filter_weights(parent_name, weights)
+                    module_weights = rename_moe_weight(
+                        module_weights,
+                        {
+                            "down_proj": "w2",
+                            "up_proj": "w3",
+                            "gate_proj": "w1",
+                        },
+                    )
+                    module.load_weights(
+                        weights=[module_weights], allow_partial_loading=allow_partial_loading
+                    )
+                elif names[-1] == "self_attn":
+                    continue
+                elif names[-1] == "next_layer_layernorm":
+                    continue
+                else:
+                    module_weights = filter_weights(name, weights)
+                    if hasattr(module, "load_weights"):
+                        args = inspect.getfullargspec(module.load_weights).args
+                        if "allow_partial_loading" not in args:
+                            assert not allow_partial_loading, (
+                                "allow_partial_loading is not supported for this model"
+                            )
+                            module.load_weights(weights=[module_weights])
+                        else:
+                            module.load_weights(
+                                weights=[module_weights],
+                                allow_partial_loading=allow_partial_loading,
+                            )
+                    else:
+                        for n, p in module.named_parameters():
+                            if not allow_partial_loading:
+                                assert n in module_weights
+                            if n in module_weights:
+                                p.data.copy_(module_weights[n][:])
 
 
 class Glm4Attention(QKNormRoPEAttention):
@@ -61,7 +193,7 @@ class Glm4Attention(QKNormRoPEAttention):
             max_position_embeddings=config.max_position_embeddings,
             bias=config.attention_bias,
             pos_embd_params=pos_embd_params,
-            fuse_qk_norm_rope=False,
+            fuse_qk_norm_rope=True,
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             dense_bias=False,
@@ -98,7 +230,7 @@ class Glm4MoE(nn.Module):
             topk_group=config.topk_group,
             routed_scaling_factor=config.routed_scaling_factor,
             dtype=dtype,
-            fuse_routing_kernel=False,
+            fuse_routing_kernel=True,
             apply_routing=False,
             moe_backend=model_config.moe_backend,
         )
@@ -872,40 +1004,11 @@ class Glm4MoeForCausalLM(SpecDecOneEngineForCausalLM[Glm4Model, PretrainedConfig
             **kwargs,
         )
 
-    def load_weights(self, weights: Dict):
-        # model.layers.91.mlp.experts.75.up_proj.weight_scale_2
-        _load_weights_impl(
-            self,
-            weights,
-            params_map={
-                r"(?!.*shared_experts)(?=.*experts?)(.*?)up_proj(.*)": r"\1w3\2",
-                r"(?!.*shared_experts)(?=.*experts?)(.*?)down_proj(.*)": r"\1w2\2",
-                r"(?!.*shared_experts)(?=.*experts?)(.*?)gate_proj(.*)": r"\1w1\2",
-            },
-        )
+    def load_weights(self, weights: Dict, allow_partial_loading: bool = False):
+        weight_loader = Glm4WeightLoader(self)
+        weight_loader.load_weights(weights, allow_partial_loading=allow_partial_loading)
 
     def post_load_weights(self):
-        all_named_modules = dict(self.model.named_modules())
-        for name, module in tqdm(all_named_modules.items(), desc="Post loading weights"):
-            if len(module._parameters) <= 0 or name.startswith("draft_model"):
-                continue
-            else:
-                if (
-                    self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales()
-                    and is_sm_100f()
-                    and hasattr(module, "weight_scale")
-                ):
-                    weight, weight_scale = resmooth_to_fp8_e8m0(module.weight, module.weight_scale)
-                    transfromed_scale = transform_sf_into_required_layout(
-                        weight_scale,
-                        mn=weight.shape[0],
-                        k=weight.shape[1],
-                        recipe=(1, 128, 128),
-                        is_sfa=False,
-                    )
-                    module.weight = nn.Parameter(weight, requires_grad=False)
-                    module.weight_scale = nn.Parameter(transfromed_scale, requires_grad=False)
-
         for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]):
             if idx == self.config.num_hidden_layers - 1:
                 layer.next_layer_layernorm = self.model.norm
diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py
index a94e288172..17d3aba15f 100755
--- a/tensorrt_llm/_torch/models/modeling_speculative.py
+++ b/tensorrt_llm/_torch/models/modeling_speculative.py
@@ -428,12 +428,22 @@ class MTPDraftModel(nn.Module):
                  torch.cuda.Stream]):
         super().__init__()
         # Import here to avoid circular import
-        from .modeling_deepseekv3 import DeepseekV3MTP
-
-        mtp_layer = DeepseekV3MTP(model_config,
-                                  layer_idx,
-                                  aux_stream_dict,
-                                  is_separate_draft_engine=True)
+        model_type = model_config.pretrained_config.model_type
+        if model_type == "glm4_moe":
+            from .modeling_glm import Glm4MTP
+            mtp_layer = Glm4MTP(model_config,
+                                layer_idx,
+                                aux_stream_dict,
+                                is_separate_draft_engine=True)
+        elif model_type in ["deepseek_v3", "deepseek_v32"]:
+            from .modeling_deepseekv3 import DeepseekV3MTP
+            mtp_layer = DeepseekV3MTP(model_config,
+                                      layer_idx,
+                                      aux_stream_dict,
+                                      is_separate_draft_engine=True)
+        else:
+            raise ValueError(
+                f"MTPDraftModel does not support model_type: {model_type}")
         setattr(self, f"layers.{layer_idx}", mtp_layer)
         self.layers = mtp_layer
         self.layer_idx = layer_idx
@@ -493,8 +503,18 @@ class MTPDraftModelForCausalLM(DecoderModelForCausalLM[MTPDraftModel,
 
     def load_weights(self, weights: Dict):
         # Import here to avoid circular import
-        from .modeling_deepseekv3 import DeepseekV3WeightLoader
-        weight_loader = DeepseekV3WeightLoader(self, is_draft_model=True)
+        model_type = self.model_config.pretrained_config.model_type
+        match model_type:
+            case "glm4_moe":
+                from .modeling_glm import Glm4WeightLoader
+                weight_loader = Glm4WeightLoader(self, is_draft_model=True)
+            case "deepseek_v3" | "deepseek_v32":
+                from .modeling_deepseekv3 import DeepseekV3WeightLoader
+                weight_loader = DeepseekV3WeightLoader(self,
+                                                       is_draft_model=True)
+            case _:
+                raise ValueError(
+                    f"Model type {model_type} not supported for MTP")
         weight_loader.load_weights(weights)
 
     def load_weights_from_target_model(self,
diff --git a/tensorrt_llm/_torch/modules/qk_norm_attention.py b/tensorrt_llm/_torch/modules/qk_norm_attention.py
index 5a794783a1..771e7f79a5 100644
--- a/tensorrt_llm/_torch/modules/qk_norm_attention.py
+++ b/tensorrt_llm/_torch/modules/qk_norm_attention.py
@@ -229,9 +229,14 @@ class QKNormRoPEAttention(Attention):
     def apply_qk_norm_rope(self, qkv, position_ids):
         factor, low, high, attention_factor = compute_yarn_parameters(
            self.pretrained_config)
+
+        partial_rotary_factor = self.pretrained_config.partial_rotary_factor if hasattr(
+            self.pretrained_config, "partial_rotary_factor") else 1.0
+        rotary_dim = int(self.head_dim * partial_rotary_factor)
+
         torch.ops.trtllm.fused_qk_norm_rope(
             qkv, self.num_heads, self.num_key_value_heads,
-            self.num_key_value_heads, self.head_dim,
+            self.num_key_value_heads, self.head_dim, rotary_dim,
             self.q_norm.variance_epsilon, self.q_norm.weight,
             self.k_norm.weight, self.pos_embd_params.rope.theta,
             self.pos_embd_params.is_neox,
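Putting the pieces together, a hypothetical call of the extended op as `apply_qk_norm_rope` now issues it: `rotary_dim` is derived from `partial_rotary_factor` and passed immediately after `head_dim`. Head counts, `head_dim`, dtype, and the packed qkv layout below are illustrative assumptions (GLM-4.x-style 24/8/8 heads with `partial_rotary_factor = 0.5`); the trailing scalars disable YaRN scaling, as in the unit test further down:

```python
import torch

num_heads_q, num_heads_k, num_heads_v, head_dim = 24, 8, 8, 128  # assumed sizes
rotary_dim = int(head_dim * 0.5)  # partial_rotary_factor = 0.5 -> rotate 64 of 128 dims

num_tokens = 4
# Assumed packed layout: [num_tokens, (Hq + Hk + Hv) * head_dim]; the op updates qkv in place.
qkv = torch.randn(num_tokens, (num_heads_q + num_heads_k + num_heads_v) * head_dim,
                  dtype=torch.bfloat16, device="cuda")
q_weight = torch.ones(head_dim, dtype=torch.bfloat16, device="cuda")
k_weight = torch.ones(head_dim, dtype=torch.bfloat16, device="cuda")
position_ids = torch.arange(num_tokens, dtype=torch.int32, device="cuda")

torch.ops.trtllm.fused_qk_norm_rope(
    qkv, num_heads_q, num_heads_k, num_heads_v, head_dim, rotary_dim,
    1e-6, q_weight, k_weight, 10000.0,  # eps, RMSNorm weights, RoPE base
    True,                               # is_neox
    position_ids,
    1.0, 0, 0, 1.0,                     # factor, low, high, attention_factor (no YaRN)
    True)                               # is_qk_norm
```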
diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml
index 33f7dddc6b..20143e4540 100644
--- a/tests/integration/defs/accuracy/references/gsm8k.yaml
+++ b/tests/integration/defs/accuracy/references/gsm8k.yaml
@@ -272,6 +272,8 @@ ByteDance-Seed/Seed-OSS-36B-Instruct:
   - accuracy: 90.8
 zai-org/GLM-4.6:
   - accuracy: 81.3
+  - spec_dec_algo: MTP
+    accuracy: 81.3
   - quant_algo: NVFP4
     spec_dec_algo: MTP
     accuracy: 88.0
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 2f27c5dc18..538277ba0b 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -2869,8 +2869,11 @@ class TestGLM4_6(LlmapiAccuracyTestHarness):
     @pytest.mark.skip_less_device(4)
     @pytest.mark.parametrize(
         "tp_size,pp_size,mtp_nextn,cuda_graph,overlap_scheduler,chunked_prefill,max_batch_size,moe_backend",
-        [pytest.param(4, 1, 2, True, True, True, 16, "CUTLASS")],
-        ids=["throughput"])
+        [
+            pytest.param(4, 1, 2, True, True, True, 16, "CUTLASS"),
+            pytest.param(4, 1, 2, True, True, True, 16, "TRTLLM")
+        ],
+        ids=["throughput", "throughput_trtllm"])
     def test_nvfp4_multi_gpus(self, tp_size, pp_size, mtp_nextn, cuda_graph,
                               overlap_scheduler, chunked_prefill,
                               max_batch_size, moe_backend):
@@ -2897,6 +2900,39 @@ class TestGLM4_6(LlmapiAccuracyTestHarness):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 
+    @pytest.mark.skip_less_device(4)
+    @pytest.mark.parametrize(
+        "tp_size,cuda_graph,overlap_scheduler,chunked_prefill,max_batch_size,moe_backend",
+        [
+            pytest.param(4, True, True, True, 16, "CUTLASS"),
+            pytest.param(4, True, True, True, 16, "TRTLLM"),
+        ],
+        ids=["2model", "2model_trtllm"])
+    def test_nvfp4_2_model_mtp(self, tp_size, cuda_graph, overlap_scheduler,
+                               chunked_prefill, max_batch_size, moe_backend):
+        model_path = f"{llm_models_root()}/glm-4.6-fp4"
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.70)
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
+            moe_config=MoeConfig(backend=moe_backend))
+
+        mtp_config = MTPDecodingConfig(num_nextn_predict_layers=3,
+                                       mtp_eagle_one_model=False,
+                                       speculative_model_dir=model_path)
+
+        with LLM(model_path,
+                 max_batch_size=max_batch_size,
+                 tensor_parallel_size=tp_size,
+                 kv_cache_config=kv_cache_config,
+                 **pytorch_config,
+                 speculative_config=mtp_config,
+                 enable_chunked_prefill=chunked_prefill) as llm:
+
+            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.timeout(7200)
     @pytest.mark.skip_less_device_memory(100000)
diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt
index 5b5ad88d3b..eab8fea284 100644
--- a/tests/integration/test_lists/qa/llm_function_core.txt
+++ b/tests/integration/test_lists/qa/llm_function_core.txt
@@ -507,6 +507,9 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[latency
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[baseline_fp8kv]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[latency]
 accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_multi_gpus[throughput]
+accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_multi_gpus[throughput_trtllm]
+accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_2_model_mtp[2model]
+accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_2_model_mtp[2model_trtllm]
 accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
 accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_bf16[multi_gpus_no_cache]
 accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=False]
diff --git a/tests/unittest/_torch/thop/parallel/test_fused_qk_norm_rope.py b/tests/unittest/_torch/thop/parallel/test_fused_qk_norm_rope.py
index ab8db650a4..565f8b3b58 100644
--- a/tests/unittest/_torch/thop/parallel/test_fused_qk_norm_rope.py
+++ b/tests/unittest/_torch/thop/parallel/test_fused_qk_norm_rope.py
@@ -8,8 +8,8 @@ from tensorrt_llm._torch.modules.rotary_embedding import RotaryEmbedding
 
 @torch.inference_mode()
 def torch_ref_rms_norm_rope(qkv, num_heads_q, num_heads_k, num_heads_v,
-                            head_dim, eps, q_weight, k_weight, base, is_neox,
-                            position_ids):
+                            head_dim, rotary_dim, eps, q_weight, k_weight, base,
+                            is_neox, position_ids):
     """
     PyTorch reference implementation of RMSNorm+RoPE for verification.
 
@@ -22,6 +22,7 @@ def torch_ref_rms_norm_rope(qkv, num_heads_q, num_heads_k, num_heads_v,
         num_heads_k: Number of key heads
         num_heads_v: Number of value heads (unused for normalization/RoPE but needed for tensor splitting)
         head_dim: Dimension of each head
+        rotary_dim: Dimension for RoPE
         eps: Epsilon value for RMS normalization
         q_weight: RMSNorm weights for query [head_dim]
         k_weight: RMSNorm weights for key [head_dim]
@@ -65,7 +66,7 @@ def torch_ref_rms_norm_rope(qkv, num_heads_q, num_heads_k, num_heads_v,
 
     # Create and apply RotaryEmbedding module
     rope_params = RopeParams(
-        dim=head_dim,  # Set the rotary dimension to match the head dimension
+        dim=rotary_dim,  # Set the rotary dimension
        theta=base,  # Base value for RoPE calculations
        max_positions=8192  # Large enough for any reasonable hidden size
     )
@@ -88,10 +89,12 @@ num_heads_groups = [
     (16, 8, 8),  # Qwen3-0.6B, Qwen3-1.7B
     (32, 8, 8),  # Qwen3-4B, Qwen3-8B, Qwen3-30B-A3B
     (40, 8, 8),  # Qwen3-14B
-    (64, 8, 8)  # Qwen3-32B, Qwen3-235B-A22B
+    (64, 8, 8),  # Qwen3-32B, Qwen3-235B-A22B
+    (24, 8, 8),  # GLM 4.6
 ]
 num_tokens_list = [1, 3, 8, 32, 256]
 is_neox_list = [False, True]
+partial_rotary_factor_list = [1.0, 0.5]
 dtypes = [torch.bfloat16]  # TODO: support float16
 
 
@@ -100,8 +103,9 @@ dtypes = [torch.bfloat16]  # TODO: support float16
 @pytest.mark.parametrize("num_tokens", num_tokens_list)
 @pytest.mark.parametrize("is_neox", is_neox_list)
 @pytest.mark.parametrize("dtype", dtypes)
-def test_fused_qk_norm_rope(head_dim, num_heads_group, num_tokens, is_neox,
-                            dtype):
+@pytest.mark.parametrize("partial_rotary_factor", partial_rotary_factor_list)
+def test_fused_qk_norm_rope(head_dim, num_heads_group, num_tokens,
+                            partial_rotary_factor, is_neox, dtype):
     """
     Test the fused QK RMSNorm + RoPE operation with various configurations.
@@ -143,18 +147,20 @@ def test_fused_qk_norm_rope(head_dim, num_heads_group, num_tokens, is_neox,
     base = 10000.0
     factor, low, high, attention_factor = 1.0, 0, 0, 1.0
+    rotary_dim = int(head_dim * partial_rotary_factor)
 
     # Run the custom fusedQKNormRope operation
     torch.ops.trtllm.fused_qk_norm_rope(qkv, num_heads_q, num_heads_k,
-                                        num_heads_v, head_dim, eps, q_weight,
-                                        k_weight, base, is_neox, position_ids,
-                                        factor, low, high, attention_factor,
-                                        True)
+                                        num_heads_v, head_dim, rotary_dim, eps,
+                                        q_weight, k_weight, base, is_neox,
+                                        position_ids, factor, low, high,
+                                        attention_factor, True)
     output = qkv  # This op is inplace
 
     # Compute reference output using TensorRT LLM modules
     ref_output = torch_ref_rms_norm_rope(qkv_copy, num_heads_q, num_heads_k,
-                                         num_heads_v, head_dim, eps, q_weight,
-                                         k_weight, base, is_neox, position_ids)
+                                         num_heads_v, head_dim, rotary_dim, eps,
+                                         q_weight, k_weight, base, is_neox,
+                                         position_ids)
 
     # Compare outputs from custom kernel vs reference implementation
     torch.testing.assert_close(