[None][feat] Fused kernels (qknormrope + moe routing) and two-model MTP support for glm4moe (#9852)

Signed-off-by: Xuanyu Chen <xuanyuc@nvidia.com>

Parent: 64d7796234
Commit: a5a37227d6
@@ -66,6 +66,7 @@ __global__ void fusedQKNormRopeKernel(
     int const num_heads_q, // Number of query heads
     int const num_heads_k, // Number of key heads
     int const num_heads_v, // Number of value heads
+    int const rotary_dim, // Dimension for RoPE
     float const eps, // Epsilon for RMS normalization
     __nv_bfloat16 const* q_weight, // RMSNorm weights for query
     __nv_bfloat16 const* k_weight, // RMSNorm weights for key
@@ -184,7 +185,7 @@ __global__ void fusedQKNormRopeKernel(
             int dim_idx = laneId * numElemsPerThread + i;
             int half_dim = dim_idx / 2;
-            float freq = powf(base, -2.0f * half_dim / static_cast<float>(head_dim));
+            float freq = powf(base, -2.0f * half_dim / static_cast<float>(rotary_dim));

             if (factor != 1.0f)
             {
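Note on the change above: the inverse frequency now divides by rotary_dim instead of head_dim, so a partial-RoPE model (rotary_dim < head_dim) gets the same frequency spectrum as a full-RoPE model whose head size equals rotary_dim. A minimal Python sketch of that formula (illustrative only, not part of the diff; names mirror the kernel):

    import math

    def rope_freq(half_dim: int, rotary_dim: int, base: float = 10000.0) -> float:
        # Same expression as the kernel's powf(base, -2.0f * half_dim / rotary_dim):
        # the rotation frequency of the pair whose frequency index is half_dim.
        return math.pow(base, -2.0 * half_dim / rotary_dim)

    # Example: head_dim = 128 with partial_rotary_factor = 0.5 gives rotary_dim = 64.
    # Dividing by rotary_dim (not head_dim) keeps the spectrum identical to a model
    # that has head_dim = 64 and applies RoPE to every channel.
    print(rope_freq(half_dim=0, rotary_dim=64))   # 1.0, fastest-rotating pair
    print(rope_freq(half_dim=31, rotary_dim=64))  # slowest pair inside the rotary span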
@@ -212,19 +213,20 @@ __global__ void fusedQKNormRopeKernel(
         {
             // Before data exchange with in warp, we need to sync.
             __syncwarp();
+            int pairOffset = (rotary_dim / 2) / numElemsPerThread;
             // Get the data from the other half of the warp. Fill cos_vals and sin_vals.
             for (int i = 0; i < numElemsPerThread; i++)
             {
-                elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16);
-                if (laneId < 16)
+                elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], pairOffset);
+                if (laneId < pairOffset)
                 {
                     elements2[i] = -elements2[i];
                 }

                 int dim_idx = laneId * numElemsPerThread + i;
-                dim_idx = (dim_idx * 2) % head_dim;
+                dim_idx = (dim_idx * 2) % rotary_dim;
                 int half_dim = dim_idx / 2;
-                float freq = powf(base, -2.0f * half_dim / static_cast<float>(head_dim));
+                float freq = powf(base, -2.0f * half_dim / static_cast<float>(rotary_dim));

                 if (factor != 1.0f)
                 {
@@ -251,9 +253,25 @@ __global__ void fusedQKNormRopeKernel(
             __syncwarp();
         }

-        for (int i = 0; i < numElemsPerThread; i++)
-        {
-            elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor;
-        }
+        bool const is_full_rope = (rotary_dim == head_dim);
+        if (is_full_rope)
+        {
+            for (int i = 0; i < numElemsPerThread; i++)
+            {
+                elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor;
+            }
+        }
+        else
+        {
+            for (int i = 0; i < numElemsPerThread; i++)
+            {
+                int dim_idx = laneId * numElemsPerThread + i;
+
+                if (dim_idx < rotary_dim)
+                {
+                    elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor;
+                }
+            }
+        }

         // Store.
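For orientation, the partial path above amounts to the usual NeoX rotate-half applied only to the first rotary_dim channels of each head, with the remaining channels passed through unrotated. A rough NumPy reference of that behavior (my own illustration under that assumption, not code from this diff; the kernel instead fetches the pair partner from another lane via __shfl_xor_sync with pairOffset):

    import numpy as np

    def neox_partial_rope(x: np.ndarray, pos: int, rotary_dim: int, base: float = 10000.0) -> np.ndarray:
        # x is one head vector of length head_dim; only x[:rotary_dim] is rotated.
        half = rotary_dim // 2
        inv_freq = base ** (-2.0 * np.arange(half) / rotary_dim)  # same formula as the kernel
        angle = pos * np.concatenate([inv_freq, inv_freq])        # each frequency serves both pair members
        cos, sin = np.cos(angle), np.sin(angle)
        out = x.astype(np.float64).copy()
        x1, x2 = out[:half].copy(), out[half:rotary_dim].copy()
        partner = np.concatenate([-x2, x1])                       # sign flip for the first half, like elements2[i]
        out[:rotary_dim] = out[:rotary_dim] * cos + partner * sin
        return out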
@@ -284,14 +302,23 @@ __global__ void fusedQKNormRopeKernel(
 }

 void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_q, int const num_heads_k,
-    int const num_heads_v, int const head_dim, float const eps, void const* q_weight, void const* k_weight,
-    float const base, bool const interleave, int const* position_ids, float factor, float low, float high,
-    float attention_factor, cudaStream_t stream, bool is_qk_norm)
+    int const num_heads_v, int const head_dim, int const rotary_dim, float const eps, void const* q_weight,
+    void const* k_weight, float const base, bool const interleave, int const* position_ids, float factor, float low,
+    float high, float attention_factor, cudaStream_t stream, bool is_qk_norm)
 {
     if (factor == 1.0f)
     {
         TLLM_CHECK(attention_factor == 1.0f);
     }

+    TLLM_CHECK_WITH_INFO(rotary_dim % 2 == 0, "rotary_dim must be even");
+    if (!interleave)
+    {
+        // To allow warp-level pairing for partial rope
+        TLLM_CHECK_WITH_INFO(
+            (rotary_dim * 16) % head_dim == 0, "Unsupported rotary dimension for fusedQKNormRope: %d", rotary_dim);
+    }
+
     constexpr int blockSize = 256;

     int const warpsPerBlock = blockSize / 32;
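The non-interleaved check added here, (rotary_dim * 16) % head_dim == 0, is what keeps the warp-level pairing well defined: each of the 32 lanes owns head_dim / 32 consecutive elements, and the rotate-half partner sits rotary_dim / 2 elements away, which must be a whole number of lanes. A small arithmetic sketch (illustrative only):

    def pair_offset_ok(head_dim: int, rotary_dim: int) -> bool:
        # Each of the 32 lanes in a warp holds head_dim // 32 consecutive elements.
        elems_per_lane = head_dim // 32
        # The partner element lives rotary_dim // 2 positions away; that distance must
        # be an integer number of lanes for the lane-to-lane shuffle to reach it.
        return (rotary_dim // 2) % elems_per_lane == 0  # same as (rotary_dim * 16) % head_dim == 0

    # head_dim = 128 -> 4 elements per lane; rotary_dim = 64 -> partner is 32 elements (8 lanes) away.
    assert pair_offset_ok(128, 64)
    assert not pair_offset_ok(128, 2)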
@@ -309,7 +336,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_
     case 64:
         DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
             fusedQKNormRopeKernel<64, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
-                reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
+                reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, rotary_dim, eps,
                 reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
                 base, position_ids, num_tokens, factor, low, high, attention_factor, is_qk_norm);
         });
@@ -317,7 +344,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_
     case 128:
         DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
             fusedQKNormRopeKernel<128, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
-                reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
+                reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, rotary_dim, eps,
                 reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
                 base, position_ids, num_tokens, factor, low, high, attention_factor, is_qk_norm);
         });
@@ -325,7 +352,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_
     case 256:
         DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
             fusedQKNormRopeKernel<256, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
-                reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
+                reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, rotary_dim, eps,
                 reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
                 base, position_ids, num_tokens, factor, low, high, attention_factor, is_qk_norm);
         });
@@ -33,6 +33,7 @@ void launchFusedQKNormRope(
     int const num_heads_k, // Number of key heads
     int const num_heads_v, // Number of value heads
     int const head_dim, // Dimension per head
+    int const rotary_dim, // Dimension for RoPE
     float const eps, // Epsilon for RMS normalization
     void const* q_weight, // RMSNorm weights for query [head_dim]
     void const* k_weight, // RMSNorm weights for key [head_dim]
@@ -34,6 +34,7 @@ void fused_qk_norm_rope(
     int64_t num_heads_k, // Number of key heads
     int64_t num_heads_v, // Number of value heads
     int64_t head_dim, // Dimension per head
+    int64_t rotary_dim, // Dimension for RoPE
     double eps, // Epsilon for RMS normalization
     torch::Tensor& q_weight, // RMSNorm weights for query [head_dim]
     torch::Tensor& k_weight, // RMSNorm weights for key [head_dim]
@@ -72,9 +73,9 @@ void fused_qk_norm_rope(

     tensorrt_llm::kernels::launchFusedQKNormRope(reinterpret_cast<__nv_bfloat16*>(qkv.data_ptr()),
         static_cast<int>(num_tokens), static_cast<int>(num_heads_q), static_cast<int>(num_heads_k),
-        static_cast<int>(num_heads_v), static_cast<int>(head_dim), static_cast<float>(eps),
-        reinterpret_cast<__nv_bfloat16*>(q_weight.data_ptr()), reinterpret_cast<__nv_bfloat16*>(k_weight.data_ptr()),
-        static_cast<float>(base),
+        static_cast<int>(num_heads_v), static_cast<int>(head_dim), static_cast<int>(rotary_dim),
+        static_cast<float>(eps), reinterpret_cast<__nv_bfloat16*>(q_weight.data_ptr()),
+        reinterpret_cast<__nv_bfloat16*>(k_weight.data_ptr()), static_cast<float>(base),
         !is_neox, // interleave
         reinterpret_cast<int const*>(position_ids.data_ptr()), static_cast<float>(factor), static_cast<float>(low),
         static_cast<float>(high), static_cast<float>(attention_factor), stream, is_qk_norm);
@@ -84,7 +85,8 @@ void fused_qk_norm_rope(
 TORCH_LIBRARY_FRAGMENT(trtllm, m)
 {
     m.def(
-        "fused_qk_norm_rope(Tensor(a!) qkv, int num_heads_q, int num_heads_k, int num_heads_v, int head_dim, float "
+        "fused_qk_norm_rope(Tensor(a!) qkv, int num_heads_q, int num_heads_k, int num_heads_v, int head_dim, int "
+        "rotary_dim, float "
         "eps, Tensor q_weight, Tensor k_weight, float base, bool is_neox, Tensor position_ids, float factor, float "
         "low, float high, float attention_factor, bool is_qk_norm) -> ()");
 }
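With the extended schema, Python callers pass rotary_dim right after head_dim. A hedged usage sketch mirroring the argument order used by the updated unit test further down (shapes and values are illustrative; assumes the trtllm extension is loaded, a CUDA device, and a packed bf16 qkv of shape [num_tokens, (num_heads_q + num_heads_k + num_heads_v) * head_dim]):

    import torch

    num_tokens, num_heads_q, num_heads_k, num_heads_v, head_dim = 4, 24, 8, 8, 128
    rotary_dim = head_dim // 2  # partial RoPE, e.g. partial_rotary_factor = 0.5

    qkv = torch.randn(num_tokens, (num_heads_q + num_heads_k + num_heads_v) * head_dim,
                      dtype=torch.bfloat16, device="cuda")
    q_weight = torch.ones(head_dim, dtype=torch.bfloat16, device="cuda")
    k_weight = torch.ones(head_dim, dtype=torch.bfloat16, device="cuda")
    position_ids = torch.arange(num_tokens, dtype=torch.int32, device="cuda")

    # In-place: RMSNorm on q/k heads, then RoPE on the first rotary_dim channels of each head.
    torch.ops.trtllm.fused_qk_norm_rope(
        qkv, num_heads_q, num_heads_k, num_heads_v, head_dim, rotary_dim,
        1e-6, q_weight, k_weight, 10000.0, True, position_ids,
        1.0, 0.0, 0.0, 1.0,  # factor, low, high, attention_factor (no YaRN scaling)
        True)                # is_qk_norm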
@@ -31,7 +31,9 @@ class AutoModelForCausalLM(Generic[TModel, TConfig]):
             "")  # Strip the appended EAGLE3
         if hasattr(config.pretrained_config, "draft_vocab_size"):
             model_arch = "EAGLE3" + model_arch
-        if model_arch == "DeepseekV3ForCausalLM" and config.spec_config is not None and config.spec_config.max_draft_len == 0:
+        if model_arch in (
+                "DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"
+        ) and config.spec_config is not None and config.spec_config.max_draft_len == 0:
             model_arch = "MTPDraftModelForCausalLM"

         cls = MODEL_CLASS_MAPPING.get(model_arch)
@@ -1,3 +1,4 @@
+import inspect
 import math
 import os
 from typing import Dict, List, Optional, Tuple
@@ -8,14 +9,10 @@ from tqdm import tqdm
 from transformers import PretrainedConfig

 from tensorrt_llm._ipc_utils import can_access_peer
-from tensorrt_llm._utils import get_sm_version, is_sm_100f
+from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.functional import PositionEmbeddingType
 from tensorrt_llm.models.modeling_utils import QuantConfig
 from tensorrt_llm.quantization.mode import QuantAlgo
-from tensorrt_llm.quantization.utils.fp8_utils import (
-    resmooth_to_fp8_e8m0,
-    transform_sf_into_required_layout,
-)

 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
@@ -29,7 +26,7 @@ from ..distributed import (
 from ..model_config import ModelConfig
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
-from ..modules.fused_moe import MoEWeightLoadingMode, create_moe
+from ..modules.fused_moe import MoE, MoEWeightLoadingMode, create_moe
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -39,7 +36,142 @@ from ..speculative import SpecMetadata
 from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .modeling_deepseekv3 import DeepseekV3Gate, DeepseekV3MTPHead, moe_reduce_add_shared_output
 from .modeling_speculative import SpecDecOneEngineForCausalLM
-from .modeling_utils import DecoderModel, EagerFusionConfig, _load_weights_impl, register_auto_model
+from .modeling_utils import (
+    DecoderModel,
+    EagerFusionConfig,
+    duplicate_kv_weight,
+    filter_weights,
+    register_auto_model,
+)
+
+
+class Glm4WeightLoader:
+
+    def __init__(self, model, is_draft_model: bool = False):
+        self.model = model
+        self.config = model.config
+        self.model_config = model.model_config
+        self.is_draft_model = is_draft_model
+
+    def load_weights(self, weights: Dict, allow_partial_loading: bool = False):
+
+        def rename_moe_weight(weights: Dict, rename_rules: Dict):
+            result = {}
+            for key, value in weights.items():
+                new_key = key
+                for old, new in rename_rules.items():
+                    new_key = new_key.replace(old, new)
+                result[new_key] = value
+            return result
+
+        params_map = {
+            "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+            "gate_up_proj": ["gate_proj", "up_proj"],
+        }
+        all_named_modules = dict(self.model.named_modules())
+
+        tp_size = (
+            1
+            if self.model_config.mapping.enable_attention_dp
+            else self.model_config.mapping.tp_size
+        )
+        num_kv_heads = (
+            self.config.num_key_value_heads
+            if hasattr(self.config, "num_key_value_heads")
+            and self.config.num_key_value_heads is not None
+            else self.config.num_attention_heads
+        )
+
+        for name, module in tqdm(all_named_modules.items(), desc="Loading weights"):
+            if len(module._parameters) <= 0 or name.startswith("draft_model"):
+                continue
+            else:
+                names = name.split(".")
+                if "model.layers" in name and int(names[2]) >= self.config.num_hidden_layers:
+                    mtp_layer_idx = int(names[2]) - self.config.num_hidden_layers
+                    names[2] = str(
+                        mtp_layer_idx % self.config.num_nextn_predict_layers
+                        + self.config.num_hidden_layers
+                    )
+                    name = ".".join(names)
+
+                if names[-1] in params_map:
+                    module_weights = []
+                    for new_name in params_map[names[-1]]:
+                        fw = filter_weights(".".join(names[:-1] + [new_name]), weights)
+                        if new_name in ["k_proj", "v_proj"]:
+                            num_kv_heads_list = (
+                                [num_kv_heads] * len(fw)
+                                if isinstance(num_kv_heads, int)
+                                else num_kv_heads
+                            )
+                            fw = {
+                                k: duplicate_kv_weight(
+                                    weight=v[:],
+                                    num_kv_heads=num_kv_heads_list[i],
+                                    tensor_parallel_size=tp_size,
+                                )
+                                if k in ["weight", "bias"]
+                                else v
+                                for i, (k, v) in enumerate(fw.items())
+                            }
+                        module_weights.append(fw)
+                    module.load_weights(weights=module_weights)
+                elif names[-1] == "experts":
+                    module_weights = filter_weights(name, weights)
+                    module_weights = rename_moe_weight(
+                        module_weights,
+                        {
+                            "down_proj": "w2",
+                            "up_proj": "w3",
+                            "gate_proj": "w1",
+                        },
+                    )
+                    module.load_weights(
+                        weights=[module_weights], allow_partial_loading=allow_partial_loading
+                    )
+                elif names[-1] == "backend" and isinstance(module, MoE):
+                    # Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
+                    # Currently saved MoE weights don't include 'backend' in their names.
+                    # After MoE refactoring, ConfigurableMoE now has a backend submodule,
+                    # and weights loading is done in the backend, so module name includes '.backend'.
+                    # We need to use parent module name (without .backend) to match saved weight names.
+                    # After MoE refactoring is fully complete, all paths will follow this branch.
+                    parent_name = ".".join(names[:-1])
+                    module_weights = filter_weights(parent_name, weights)
+                    module_weights = rename_moe_weight(
+                        module_weights,
+                        {
+                            "down_proj": "w2",
+                            "up_proj": "w3",
+                            "gate_proj": "w1",
+                        },
+                    )
+                    module.load_weights(
+                        weights=[module_weights], allow_partial_loading=allow_partial_loading
+                    )
+                elif names[-1] == "self_attn":
+                    continue
+                elif names[-1] == "next_layer_layernorm":
+                    continue
+                else:
+                    module_weights = filter_weights(name, weights)
+                    if hasattr(module, "load_weights"):
+                        args = inspect.getfullargspec(module.load_weights).args
+                        if "allow_partial_loading" not in args:
+                            assert not allow_partial_loading, (
+                                "allow_partial_loading is not supported for this model"
+                            )
+                            module.load_weights(weights=[module_weights])
+                        else:
+                            module.load_weights(
+                                weights=[module_weights],
+                                allow_partial_loading=allow_partial_loading,
+                            )
+                    else:
+                        for n, p in module.named_parameters():
+                            if not allow_partial_loading:
+                                assert n in module_weights
+                            if n in module_weights:
+                                p.data.copy_(module_weights[n][:])


 class Glm4Attention(QKNormRoPEAttention):
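The rename_moe_weight helper in the loader above only rewrites expert weight names from the Hugging Face checkpoint convention to the fused-MoE convention (gate_proj -> w1, down_proj -> w2, up_proj -> w3). A tiny standalone illustration of that mapping (the checkpoint key below is made up for the example):

    def rename_moe_weight(weights: dict, rename_rules: dict) -> dict:
        # Same logic as the loader: plain substring replacement on every key.
        result = {}
        for key, value in weights.items():
            new_key = key
            for old, new in rename_rules.items():
                new_key = new_key.replace(old, new)
            result[new_key] = value
        return result

    sample = {"model.layers.3.mlp.experts.up_proj.weight": None}  # hypothetical key
    print(rename_moe_weight(sample, {"down_proj": "w2", "up_proj": "w3", "gate_proj": "w1"}))
    # {'model.layers.3.mlp.experts.w3.weight': None}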
@@ -61,7 +193,7 @@ class Glm4Attention(QKNormRoPEAttention):
             max_position_embeddings=config.max_position_embeddings,
             bias=config.attention_bias,
             pos_embd_params=pos_embd_params,
-            fuse_qk_norm_rope=False,
+            fuse_qk_norm_rope=True,
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             dense_bias=False,
@@ -98,7 +230,7 @@ class Glm4MoE(nn.Module):
             topk_group=config.topk_group,
             routed_scaling_factor=config.routed_scaling_factor,
             dtype=dtype,
-            fuse_routing_kernel=False,
+            fuse_routing_kernel=True,
             apply_routing=False,
             moe_backend=model_config.moe_backend,
         )
@@ -872,40 +1004,11 @@ class Glm4MoeForCausalLM(SpecDecOneEngineForCausalLM[Glm4Model, PretrainedConfig
             **kwargs,
         )

-    def load_weights(self, weights: Dict):
-        # model.layers.91.mlp.experts.75.up_proj.weight_scale_2
-        _load_weights_impl(
-            self,
-            weights,
-            params_map={
-                r"(?!.*shared_experts)(?=.*experts?)(.*?)up_proj(.*)": r"\1w3\2",
-                r"(?!.*shared_experts)(?=.*experts?)(.*?)down_proj(.*)": r"\1w2\2",
-                r"(?!.*shared_experts)(?=.*experts?)(.*?)gate_proj(.*)": r"\1w1\2",
-            },
-        )
+    def load_weights(self, weights: Dict, allow_partial_loading: bool = False):
+        weight_loader = Glm4WeightLoader(self)
+        weight_loader.load_weights(weights, allow_partial_loading=allow_partial_loading)

     def post_load_weights(self):
-        all_named_modules = dict(self.model.named_modules())
-        for name, module in tqdm(all_named_modules.items(), desc="Post loading weights"):
-            if len(module._parameters) <= 0 or name.startswith("draft_model"):
-                continue
-            else:
-                if (
-                    self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales()
-                    and is_sm_100f()
-                    and hasattr(module, "weight_scale")
-                ):
-                    weight, weight_scale = resmooth_to_fp8_e8m0(module.weight, module.weight_scale)
-                    transfromed_scale = transform_sf_into_required_layout(
-                        weight_scale,
-                        mn=weight.shape[0],
-                        k=weight.shape[1],
-                        recipe=(1, 128, 128),
-                        is_sfa=False,
-                    )
-                    module.weight = nn.Parameter(weight, requires_grad=False)
-                    module.weight_scale = nn.Parameter(transfromed_scale, requires_grad=False)
-
         for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]):
             if idx == self.config.num_hidden_layers - 1:
                 layer.next_layer_layernorm = self.model.norm
@@ -428,12 +428,22 @@ class MTPDraftModel(nn.Module):
                  torch.cuda.Stream]):
         super().__init__()
-        # Import here to avoid circular import
-        from .modeling_deepseekv3 import DeepseekV3MTP
-
-        mtp_layer = DeepseekV3MTP(model_config,
-                                  layer_idx,
-                                  aux_stream_dict,
-                                  is_separate_draft_engine=True)
+        model_type = model_config.pretrained_config.model_type
+        if model_type == "glm4_moe":
+            from .modeling_glm import Glm4MTP
+            mtp_layer = Glm4MTP(model_config,
+                                layer_idx,
+                                aux_stream_dict,
+                                is_separate_draft_engine=True)
+        elif model_type in ["deepseek_v3", "deepseek_v32"]:
+            from .modeling_deepseekv3 import DeepseekV3MTP
+            mtp_layer = DeepseekV3MTP(model_config,
+                                      layer_idx,
+                                      aux_stream_dict,
+                                      is_separate_draft_engine=True)
+        else:
+            raise ValueError(
+                f"MTPDraftModel does not support model_type: {model_type}")
         setattr(self, f"layers.{layer_idx}", mtp_layer)
         self.layers = mtp_layer
         self.layer_idx = layer_idx
@@ -493,8 +503,18 @@ class MTPDraftModelForCausalLM(DecoderModelForCausalLM[MTPDraftModel,

     def load_weights(self, weights: Dict):
-        # Import here to avoid circular import
-        from .modeling_deepseekv3 import DeepseekV3WeightLoader
-        weight_loader = DeepseekV3WeightLoader(self, is_draft_model=True)
+        model_type = self.model_config.pretrained_config.model_type
+        match model_type:
+            case "glm4_moe":
+                from .modeling_glm import Glm4WeightLoader
+                weight_loader = Glm4WeightLoader(self, is_draft_model=True)
+            case "deepseek_v3" | "deepseek_v32":
+                from .modeling_deepseekv3 import DeepseekV3WeightLoader
+                weight_loader = DeepseekV3WeightLoader(self,
+                                                       is_draft_model=True)
+            case _:
+                raise ValueError(
+                    f"Model type {model_type} not supported for MTP")
         weight_loader.load_weights(weights)

     def load_weights_from_target_model(self,
@@ -229,9 +229,14 @@ class QKNormRoPEAttention(Attention):
     def apply_qk_norm_rope(self, qkv, position_ids):
         factor, low, high, attention_factor = compute_yarn_parameters(
             self.pretrained_config)
+
+        partial_rotary_factor = self.pretrained_config.partial_rotary_factor if hasattr(
+            self.pretrained_config, "partial_rotary_factor") else 1.0
+        rotary_dim = int(self.head_dim * partial_rotary_factor)
+
         torch.ops.trtllm.fused_qk_norm_rope(
             qkv, self.num_heads, self.num_key_value_heads,
-            self.num_key_value_heads, self.head_dim,
+            self.num_key_value_heads, self.head_dim, rotary_dim,
             self.q_norm.variance_epsilon, self.q_norm.weight,
             self.k_norm.weight,
             self.pos_embd_params.rope.theta, self.pos_embd_params.is_neox,
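partial_rotary_factor defaults to 1.0 when the pretrained config does not define it, so full-RoPE models keep their previous behavior; a GLM-4.6-style config with partial_rotary_factor = 0.5 and head_dim = 128 ends up with rotary_dim = 64. The same computation as a standalone sketch (values are illustrative):

    def compute_rotary_dim(head_dim: int, config) -> int:
        # Mirrors the attention module: fall back to full RoPE when the config
        # does not define partial_rotary_factor.
        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
        return int(head_dim * partial_rotary_factor)

    class _Cfg:  # stand-in config object for the example
        partial_rotary_factor = 0.5

    assert compute_rotary_dim(128, _Cfg()) == 64     # partial RoPE
    assert compute_rotary_dim(128, object()) == 128  # attribute missing -> full RoPE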
@@ -272,6 +272,8 @@ ByteDance-Seed/Seed-OSS-36B-Instruct:
   - accuracy: 90.8
 zai-org/GLM-4.6:
   - accuracy: 81.3
   - spec_dec_algo: MTP
     accuracy: 81.3
+  - quant_algo: NVFP4
+    spec_dec_algo: MTP
+    accuracy: 88.0
@@ -2869,8 +2869,11 @@ class TestGLM4_6(LlmapiAccuracyTestHarness):
     @pytest.mark.skip_less_device(4)
     @pytest.mark.parametrize(
         "tp_size,pp_size,mtp_nextn,cuda_graph,overlap_scheduler,chunked_prefill,max_batch_size,moe_backend",
-        [pytest.param(4, 1, 2, True, True, True, 16, "CUTLASS")],
-        ids=["throughput"])
+        [
+            pytest.param(4, 1, 2, True, True, True, 16, "CUTLASS"),
+            pytest.param(4, 1, 2, True, True, True, 16, "TRTLLM")
+        ],
+        ids=["throughput", "throughput_trtllm"])
     def test_nvfp4_multi_gpus(self, tp_size, pp_size, mtp_nextn, cuda_graph,
                               overlap_scheduler, chunked_prefill,
                               max_batch_size, moe_backend):
@@ -2897,6 +2900,39 @@ class TestGLM4_6(LlmapiAccuracyTestHarness):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)

+    @pytest.mark.skip_less_device(4)
+    @pytest.mark.parametrize(
+        "tp_size,cuda_graph,overlap_scheduler,chunked_prefill,max_batch_size,moe_backend",
+        [
+            pytest.param(4, True, True, True, 16, "CUTLASS"),
+            pytest.param(4, True, True, True, 16, "TRTLLM"),
+        ],
+        ids=["2model", "2model_trtllm"])
+    def test_nvfp4_2_model_mtp(self, tp_size, cuda_graph, overlap_scheduler,
+                               chunked_prefill, max_batch_size, moe_backend):
+        model_path = f"{llm_models_root()}/glm-4.6-fp4"
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.70)
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
+            moe_config=MoeConfig(backend=moe_backend))
+
+        mtp_config = MTPDecodingConfig(num_nextn_predict_layers=3,
+                                       mtp_eagle_one_model=False,
+                                       speculative_model_dir=model_path)
+
+        with LLM(model_path,
+                 max_batch_size=max_batch_size,
+                 tensor_parallel_size=tp_size,
+                 kv_cache_config=kv_cache_config,
+                 **pytorch_config,
+                 speculative_config=mtp_config,
+                 enable_chunked_prefill=chunked_prefill) as llm:
+
+            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+

 @pytest.mark.timeout(7200)
 @pytest.mark.skip_less_device_memory(100000)
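The new test drives the two-model MTP path (mtp_eagle_one_model=False) with a separate GLM-4.6 draft engine loaded from the same checkpoint. Outside pytest, roughly the same setup can be expressed through the LLM API; a sketch under the assumption that an NVFP4 GLM-4.6 checkpoint is available locally (the path is a placeholder):

    from tensorrt_llm import LLM
    from tensorrt_llm.llmapi import KvCacheConfig, MTPDecodingConfig

    llm = LLM("/models/glm-4.6-fp4",  # placeholder path to an NVFP4 GLM-4.6 checkpoint
              tensor_parallel_size=4,
              kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.70),
              speculative_config=MTPDecodingConfig(
                  num_nextn_predict_layers=3,
                  mtp_eagle_one_model=False,  # two-model flow: separate MTP draft engine
                  speculative_model_dir="/models/glm-4.6-fp4"))

    outputs = llm.generate(["The capital of France is"])
    print(outputs[0].outputs[0].text)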
@@ -507,6 +507,9 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[latency]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[baseline_fp8kv]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[latency]
 accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_multi_gpus[throughput]
+accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_multi_gpus[throughput_trtllm]
+accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_2_model_mtp[2model]
+accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_2_model_mtp[2model_trtllm]
 accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
 accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_bf16[multi_gpus_no_cache]
 accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=False]
@@ -8,8 +8,8 @@ from tensorrt_llm._torch.modules.rotary_embedding import RotaryEmbedding

 @torch.inference_mode()
 def torch_ref_rms_norm_rope(qkv, num_heads_q, num_heads_k, num_heads_v,
-                            head_dim, eps, q_weight, k_weight, base, is_neox,
-                            position_ids):
+                            head_dim, rotary_dim, eps, q_weight, k_weight, base,
+                            is_neox, position_ids):
     """
     PyTorch reference implementation of RMSNorm+RoPE for verification.
@@ -22,6 +22,7 @@ def torch_ref_rms_norm_rope(qkv, num_heads_q, num_heads_k, num_heads_v,
         num_heads_k: Number of key heads
         num_heads_v: Number of value heads (unused for normalization/RoPE but needed for tensor splitting)
         head_dim: Dimension of each head
+        rotary_dim: Dimension for RoPE
         eps: Epsilon value for RMS normalization
         q_weight: RMSNorm weights for query [head_dim]
         k_weight: RMSNorm weights for key [head_dim]
@@ -65,7 +66,7 @@ def torch_ref_rms_norm_rope(qkv, num_heads_q, num_heads_k, num_heads_v,

     # Create and apply RotaryEmbedding module
     rope_params = RopeParams(
-        dim=head_dim,  # Set the rotary dimension to match the head dimension
+        dim=rotary_dim,  # Set the rotary dimension
         theta=base,  # Base value for RoPE calculations
         max_positions=8192  # Large enough for any reasonable hidden size
     )
@@ -88,10 +89,12 @@ num_heads_groups = [
     (16, 8, 8),  # Qwen3-0.6B, Qwen3-1.7B
     (32, 8, 8),  # Qwen3-4B, Qwen3-8B, Qwen3-30B-A3B
     (40, 8, 8),  # Qwen3-14B
-    (64, 8, 8)  # Qwen3-32B, Qwen3-235B-A22B
+    (64, 8, 8),  # Qwen3-32B, Qwen3-235B-A22B
+    (24, 8, 8),  # GLM 4.6
 ]
 num_tokens_list = [1, 3, 8, 32, 256]
 is_neox_list = [False, True]
+partial_rotary_factor_list = [1.0, 0.5]
 dtypes = [torch.bfloat16]  # TODO: support float16
|
||||
@ -100,8 +103,9 @@ dtypes = [torch.bfloat16] # TODO: support float16
|
||||
@pytest.mark.parametrize("num_tokens", num_tokens_list)
|
||||
@pytest.mark.parametrize("is_neox", is_neox_list)
|
||||
@pytest.mark.parametrize("dtype", dtypes)
|
||||
def test_fused_qk_norm_rope(head_dim, num_heads_group, num_tokens, is_neox,
|
||||
dtype):
|
||||
@pytest.mark.parametrize("partial_rotary_factor", partial_rotary_factor_list)
|
||||
def test_fused_qk_norm_rope(head_dim, num_heads_group, num_tokens,
|
||||
partial_rotary_factor, is_neox, dtype):
|
||||
"""
|
||||
Test the fused QK RMSNorm + RoPE operation with various configurations.
|
||||
|
||||
@@ -143,18 +147,20 @@ def test_fused_qk_norm_rope(head_dim, num_heads_group, num_tokens, is_neox,
     base = 10000.0

     factor, low, high, attention_factor = 1.0, 0, 0, 1.0
+    rotary_dim = int(head_dim * partial_rotary_factor)
     # Run the custom fusedQKNormRope operation
     torch.ops.trtllm.fused_qk_norm_rope(qkv, num_heads_q, num_heads_k,
-                                        num_heads_v, head_dim, eps, q_weight,
-                                        k_weight, base, is_neox, position_ids,
-                                        factor, low, high, attention_factor,
-                                        True)
+                                        num_heads_v, head_dim, rotary_dim, eps,
+                                        q_weight, k_weight, base, is_neox,
+                                        position_ids, factor, low, high,
+                                        attention_factor, True)
     output = qkv  # This op is inplace

     # Compute reference output using TensorRT LLM modules
     ref_output = torch_ref_rms_norm_rope(qkv_copy, num_heads_q, num_heads_k,
-                                         num_heads_v, head_dim, eps, q_weight,
-                                         k_weight, base, is_neox, position_ids)
+                                         num_heads_v, head_dim, rotary_dim, eps,
+                                         q_weight, k_weight, base, is_neox,
+                                         position_ids)

     # Compare outputs from custom kernel vs reference implementation
     torch.testing.assert_close(