Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[None][feat] Fused kernels (qknormrope + moe routing) and two-model MTP support for glm4moe (#9852)
Signed-off-by: Xuanyu Chen <xuanyuc@nvidia.com>
parent 64d7796234
commit a5a37227d6
@@ -66,6 +66,7 @@ __global__ void fusedQKNormRopeKernel(
     int const num_heads_q,         // Number of query heads
     int const num_heads_k,         // Number of key heads
     int const num_heads_v,         // Number of value heads
+    int const rotary_dim,          // Dimension for RoPE
     float const eps,               // Epsilon for RMS normalization
     __nv_bfloat16 const* q_weight, // RMSNorm weights for query
     __nv_bfloat16 const* k_weight, // RMSNorm weights for key
@@ -184,7 +185,7 @@ __global__ void fusedQKNormRopeKernel(

             int dim_idx = laneId * numElemsPerThread + i;
             int half_dim = dim_idx / 2;
-            float freq = powf(base, -2.0f * half_dim / static_cast<float>(head_dim));
+            float freq = powf(base, -2.0f * half_dim / static_cast<float>(rotary_dim));

             if (factor != 1.0f)
             {
@@ -212,19 +213,20 @@ __global__ void fusedQKNormRopeKernel(
     {
         // Before data exchange with in warp, we need to sync.
         __syncwarp();
+        int pairOffset = (rotary_dim / 2) / numElemsPerThread;
         // Get the data from the other half of the warp. Fill cos_vals and sin_vals.
         for (int i = 0; i < numElemsPerThread; i++)
         {
-            elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], 16);
-            if (laneId < 16)
+            elements2[i] = __shfl_xor_sync(0xffffffff, elements[i], pairOffset);
+            if (laneId < pairOffset)
             {
                 elements2[i] = -elements2[i];
             }

             int dim_idx = laneId * numElemsPerThread + i;
-            dim_idx = (dim_idx * 2) % head_dim;
+            dim_idx = (dim_idx * 2) % rotary_dim;
             int half_dim = dim_idx / 2;
-            float freq = powf(base, -2.0f * half_dim / static_cast<float>(head_dim));
+            float freq = powf(base, -2.0f * half_dim / static_cast<float>(rotary_dim));

             if (factor != 1.0f)
             {
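The shuffle distance is no longer hard-coded to 16: each lane now exchanges its elements with the lane that holds the paired half of the rotary section. A small sketch of the arithmetic, assuming numElemsPerThread = head_dim / 32 (one 32-lane warp per head), which is consistent with the old hard-coded distance of 16 for head_dim = 128:

# Hedged sketch of the new pairOffset arithmetic; numElemsPerThread = head_dim // 32 is an assumption.
def pair_offset(head_dim: int, rotary_dim: int) -> int:
    num_elems_per_thread = head_dim // 32
    return (rotary_dim // 2) // num_elems_per_thread

assert pair_offset(128, 128) == 16  # full RoPE: same distance the old code hard-coded
assert pair_offset(128, 64) == 8    # partial RoPE, e.g. partial_rotary_factor = 0.5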
@@ -251,9 +253,25 @@ __global__ void fusedQKNormRopeKernel(
         __syncwarp();
     }

-    for (int i = 0; i < numElemsPerThread; i++)
+    bool const is_full_rope = (rotary_dim == head_dim);
+    if (is_full_rope)
     {
-        elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor;
+        for (int i = 0; i < numElemsPerThread; i++)
+        {
+            elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor;
+        }
+    }
+    else
+    {
+        for (int i = 0; i < numElemsPerThread; i++)
+        {
+            int dim_idx = laneId * numElemsPerThread + i;
+
+            if (dim_idx < rotary_dim)
+            {
+                elements[i] = (elements[i] * cos_vals[i] + elements2[i] * sin_vals[i]) * attention_factor;
+            }
+        }
     }

     // Store.
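Functionally, the new else-branch rotates only the first rotary_dim elements of each head and leaves the remaining head_dim - rotary_dim elements untouched. A hedged PyTorch sketch of that semantics for a single non-interleaved (NeoX-style) head, not the kernel itself:

import torch

def ref_partial_rope(x: torch.Tensor, position: int, rotary_dim: int,
                     base: float = 10000.0, attention_factor: float = 1.0) -> torch.Tensor:
    # x: [head_dim]; only the first rotary_dim elements are rotated, the rest pass through.
    out = x.clone()
    half = rotary_dim // 2
    inv_freq = base ** (-2.0 * torch.arange(half, dtype=torch.float32) / rotary_dim)
    angle = position * inv_freq
    cos, sin = torch.cos(angle), torch.sin(angle)
    x1, x2 = x[:half].float(), x[half:rotary_dim].float()  # NeoX pairing: (i, i + rotary_dim/2)
    out[:half] = ((x1 * cos - x2 * sin) * attention_factor).to(x.dtype)
    out[half:rotary_dim] = ((x2 * cos + x1 * sin) * attention_factor).to(x.dtype)
    return out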
@@ -284,14 +302,23 @@ __global__ void fusedQKNormRopeKernel(
 }

 void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_q, int const num_heads_k,
-    int const num_heads_v, int const head_dim, float const eps, void const* q_weight, void const* k_weight,
-    float const base, bool const interleave, int const* position_ids, float factor, float low, float high,
-    float attention_factor, cudaStream_t stream, bool is_qk_norm)
+    int const num_heads_v, int const head_dim, int const rotary_dim, float const eps, void const* q_weight,
+    void const* k_weight, float const base, bool const interleave, int const* position_ids, float factor, float low,
+    float high, float attention_factor, cudaStream_t stream, bool is_qk_norm)
 {
     if (factor == 1.0f)
     {
         TLLM_CHECK(attention_factor == 1.0f);
     }
+
+    TLLM_CHECK_WITH_INFO(rotary_dim % 2 == 0, "rotary_dim must be even");
+    if (!interleave)
+    {
+        // To allow warp-level pairing for partial rope
+        TLLM_CHECK_WITH_INFO(
+            (rotary_dim * 16) % head_dim == 0, "Unsupported rotary dimension for fusedQKNormRope: %d", rotary_dim);
+    }
+
     constexpr int blockSize = 256;

     int const warpsPerBlock = blockSize / 32;
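The two new checks bound which rotary dimensions the non-interleaved path accepts. A small sketch of the same predicate (it mirrors the TLLM_CHECK_WITH_INFO conditions above; it is not an additional API):

def rotary_dim_supported(head_dim: int, rotary_dim: int, interleave: bool) -> bool:
    if rotary_dim % 2 != 0:          # rotary_dim must be even
        return False
    if not interleave:               # warp-level pairing constraint for partial RoPE
        return (rotary_dim * 16) % head_dim == 0
    return True

assert rotary_dim_supported(128, 64, interleave=False)       # GLM-4.6-style partial RoPE
assert not rotary_dim_supported(128, 100, interleave=False)  # rejected by the pairing constraint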
@@ -309,7 +336,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_
     case 64:
         DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
             fusedQKNormRopeKernel<64, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
-                reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
+                reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, rotary_dim, eps,
                 reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
                 base, position_ids, num_tokens, factor, low, high, attention_factor, is_qk_norm);
         });
@@ -317,7 +344,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_
     case 128:
         DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
             fusedQKNormRopeKernel<128, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
-                reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
+                reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, rotary_dim, eps,
                 reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
                 base, position_ids, num_tokens, factor, low, high, attention_factor, is_qk_norm);
         });
@@ -325,7 +352,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_
     case 256:
         DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
             fusedQKNormRopeKernel<256, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
-                reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
+                reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, rotary_dim, eps,
                 reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
                 base, position_ids, num_tokens, factor, low, high, attention_factor, is_qk_norm);
         });
@@ -33,6 +33,7 @@ void launchFusedQKNormRope(
     int const num_heads_k, // Number of key heads
     int const num_heads_v, // Number of value heads
     int const head_dim,    // Dimension per head
+    int const rotary_dim,  // Dimension for RoPE
     float const eps,       // Epsilon for RMS normalization
     void const* q_weight,  // RMSNorm weights for query [head_dim]
     void const* k_weight,  // RMSNorm weights for key [head_dim]
@@ -34,6 +34,7 @@ void fused_qk_norm_rope(
     int64_t num_heads_k,     // Number of key heads
     int64_t num_heads_v,     // Number of value heads
     int64_t head_dim,        // Dimension per head
+    int64_t rotary_dim,      // Dimension for RoPE
     double eps,              // Epsilon for RMS normalization
     torch::Tensor& q_weight, // RMSNorm weights for query [head_dim]
     torch::Tensor& k_weight, // RMSNorm weights for key [head_dim]
@@ -72,9 +73,9 @@ void fused_qk_norm_rope(

     tensorrt_llm::kernels::launchFusedQKNormRope(reinterpret_cast<__nv_bfloat16*>(qkv.data_ptr()),
         static_cast<int>(num_tokens), static_cast<int>(num_heads_q), static_cast<int>(num_heads_k),
-        static_cast<int>(num_heads_v), static_cast<int>(head_dim), static_cast<float>(eps),
-        reinterpret_cast<__nv_bfloat16*>(q_weight.data_ptr()), reinterpret_cast<__nv_bfloat16*>(k_weight.data_ptr()),
-        static_cast<float>(base),
+        static_cast<int>(num_heads_v), static_cast<int>(head_dim), static_cast<int>(rotary_dim),
+        static_cast<float>(eps), reinterpret_cast<__nv_bfloat16*>(q_weight.data_ptr()),
+        reinterpret_cast<__nv_bfloat16*>(k_weight.data_ptr()), static_cast<float>(base),
         !is_neox, // interleave
         reinterpret_cast<int const*>(position_ids.data_ptr()), static_cast<float>(factor), static_cast<float>(low),
         static_cast<float>(high), static_cast<float>(attention_factor), stream, is_qk_norm);
@@ -84,7 +85,8 @@ void fused_qk_norm_rope(
 TORCH_LIBRARY_FRAGMENT(trtllm, m)
 {
     m.def(
-        "fused_qk_norm_rope(Tensor(a!) qkv, int num_heads_q, int num_heads_k, int num_heads_v, int head_dim, float "
+        "fused_qk_norm_rope(Tensor(a!) qkv, int num_heads_q, int num_heads_k, int num_heads_v, int head_dim, int "
+        "rotary_dim, float "
         "eps, Tensor q_weight, Tensor k_weight, float base, bool is_neox, Tensor position_ids, float factor, float "
         "low, float high, float attention_factor, bool is_qk_norm) -> ()");
 }
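For reference, a hedged call sketch against the updated schema (argument order follows the schema string above and the unit test further down; the packed qkv layout of [num_tokens, (num_heads_q + num_heads_k + num_heads_v) * head_dim] is an assumption):

import torch

num_tokens, num_heads_q, num_heads_k, num_heads_v, head_dim = 4, 24, 8, 8, 128
rotary_dim = 64  # partial RoPE; equals head_dim for the previous full-RoPE behavior
qkv = torch.randn(num_tokens, (num_heads_q + num_heads_k + num_heads_v) * head_dim,
                  dtype=torch.bfloat16, device="cuda")
q_weight = torch.ones(head_dim, dtype=torch.bfloat16, device="cuda")
k_weight = torch.ones(head_dim, dtype=torch.bfloat16, device="cuda")
position_ids = torch.arange(num_tokens, dtype=torch.int32, device="cuda")

# In-place op: q/k heads are RMS-normalized, then the first rotary_dim dims of each are rotated.
torch.ops.trtllm.fused_qk_norm_rope(
    qkv, num_heads_q, num_heads_k, num_heads_v, head_dim, rotary_dim, 1e-6,
    q_weight, k_weight, 10000.0, True,       # base, is_neox
    position_ids, 1.0, 0.0, 0.0, 1.0, True)  # factor, low, high, attention_factor, is_qk_norm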
@@ -31,7 +31,9 @@ class AutoModelForCausalLM(Generic[TModel, TConfig]):
             "")  # Strip the appended EAGLE3
         if hasattr(config.pretrained_config, "draft_vocab_size"):
             model_arch = "EAGLE3" + model_arch
-        if model_arch == "DeepseekV3ForCausalLM" and config.spec_config is not None and config.spec_config.max_draft_len == 0:
+        if model_arch in (
+                "DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"
+        ) and config.spec_config is not None and config.spec_config.max_draft_len == 0:
             model_arch = "MTPDraftModelForCausalLM"

         cls = MODEL_CLASS_MAPPING.get(model_arch)
@@ -1,3 +1,4 @@
+import inspect
 import math
 import os
 from typing import Dict, List, Optional, Tuple
@@ -8,14 +9,10 @@ from tqdm import tqdm
 from transformers import PretrainedConfig

 from tensorrt_llm._ipc_utils import can_access_peer
-from tensorrt_llm._utils import get_sm_version, is_sm_100f
+from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.functional import PositionEmbeddingType
 from tensorrt_llm.models.modeling_utils import QuantConfig
 from tensorrt_llm.quantization.mode import QuantAlgo
-from tensorrt_llm.quantization.utils.fp8_utils import (
-    resmooth_to_fp8_e8m0,
-    transform_sf_into_required_layout,
-)

 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
@@ -29,7 +26,7 @@ from ..distributed import (
 from ..model_config import ModelConfig
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
-from ..modules.fused_moe import MoEWeightLoadingMode, create_moe
+from ..modules.fused_moe import MoE, MoEWeightLoadingMode, create_moe
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -39,7 +36,142 @@ from ..speculative import SpecMetadata
 from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .modeling_deepseekv3 import DeepseekV3Gate, DeepseekV3MTPHead, moe_reduce_add_shared_output
 from .modeling_speculative import SpecDecOneEngineForCausalLM
-from .modeling_utils import DecoderModel, EagerFusionConfig, _load_weights_impl, register_auto_model
+from .modeling_utils import (
+    DecoderModel,
+    EagerFusionConfig,
+    duplicate_kv_weight,
+    filter_weights,
+    register_auto_model,
+)


+class Glm4WeightLoader:
+    def __init__(self, model, is_draft_model: bool = False):
+        self.model = model
+        self.config = model.config
+        self.model_config = model.model_config
+        self.is_draft_model = is_draft_model
+
+    def load_weights(self, weights: Dict, allow_partial_loading: bool = False):
+        def rename_moe_weight(weights: Dict, rename_rules: Dict):
+            result = {}
+            for key, value in weights.items():
+                new_key = key
+                for old, new in rename_rules.items():
+                    new_key = new_key.replace(old, new)
+                result[new_key] = value
+            return result
+
+        params_map = {
+            "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+            "gate_up_proj": ["gate_proj", "up_proj"],
+        }
+        all_named_modules = dict(self.model.named_modules())
+
+        tp_size = (
+            1
+            if self.model_config.mapping.enable_attention_dp
+            else self.model_config.mapping.tp_size
+        )
+        num_kv_heads = (
+            self.config.num_key_value_heads
+            if hasattr(self.config, "num_key_value_heads")
+            and self.config.num_key_value_heads is not None
+            else self.config.num_attention_heads
+        )
+
+        for name, module in tqdm(all_named_modules.items(), desc="Loading weights"):
+            if len(module._parameters) <= 0 or name.startswith("draft_model"):
+                continue
+            else:
+                names = name.split(".")
+                if "model.layers" in name and int(names[2]) >= self.config.num_hidden_layers:
+                    mtp_layer_idx = int(names[2]) - self.config.num_hidden_layers
+                    names[2] = str(
+                        mtp_layer_idx % self.config.num_nextn_predict_layers
+                        + self.config.num_hidden_layers
+                    )
+                    name = ".".join(names)
+
+                if names[-1] in params_map:
+                    module_weights = []
+                    for new_name in params_map[names[-1]]:
+                        fw = filter_weights(".".join(names[:-1] + [new_name]), weights)
+                        if new_name in ["k_proj", "v_proj"]:
+                            num_kv_heads_list = (
+                                [num_kv_heads] * len(fw)
+                                if isinstance(num_kv_heads, int)
+                                else num_kv_heads
+                            )
+                            fw = {
+                                k: duplicate_kv_weight(
+                                    weight=v[:],
+                                    num_kv_heads=num_kv_heads_list[i],
+                                    tensor_parallel_size=tp_size,
+                                )
+                                if k in ["weight", "bias"]
+                                else v
+                                for i, (k, v) in enumerate(fw.items())
+                            }
+                        module_weights.append(fw)
+                    module.load_weights(weights=module_weights)
+                elif names[-1] == "experts":
+                    module_weights = filter_weights(name, weights)
+                    module_weights = rename_moe_weight(
+                        module_weights,
+                        {
+                            "down_proj": "w2",
+                            "up_proj": "w3",
+                            "gate_proj": "w1",
+                        },
+                    )
+                    module.load_weights(
+                        weights=[module_weights], allow_partial_loading=allow_partial_loading
+                    )
+                elif names[-1] == "backend" and isinstance(module, MoE):
+                    # Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
+                    # Currently saved MoE weights don't include 'backend' in their names.
+                    # After MoE refactoring, ConfigurableMoE now has a backend submodule,
+                    # and weights loading is done in the backend, so module name includes '.backend'.
+                    # We need to use parent module name (without .backend) to match saved weight names.
+                    # After MoE refactoring is fully complete, all paths will follow this branch.
+                    parent_name = ".".join(names[:-1])
+                    module_weights = filter_weights(parent_name, weights)
+                    module_weights = rename_moe_weight(
+                        module_weights,
+                        {
+                            "down_proj": "w2",
+                            "up_proj": "w3",
+                            "gate_proj": "w1",
+                        },
+                    )
+                    module.load_weights(
+                        weights=[module_weights], allow_partial_loading=allow_partial_loading
+                    )
+                elif names[-1] == "self_attn":
+                    continue
+                elif names[-1] == "next_layer_layernorm":
+                    continue
+                else:
+                    module_weights = filter_weights(name, weights)
+                    if hasattr(module, "load_weights"):
+                        args = inspect.getfullargspec(module.load_weights).args
+                        if "allow_partial_loading" not in args:
+                            assert not allow_partial_loading, (
+                                "allow_partial_loading is not supported for this model"
+                            )
+                            module.load_weights(weights=[module_weights])
+                        else:
+                            module.load_weights(
+                                weights=[module_weights],
+                                allow_partial_loading=allow_partial_loading,
+                            )
+                    else:
+                        for n, p in module.named_parameters():
+                            if not allow_partial_loading:
+                                assert n in module_weights
+                            if n in module_weights:
+                                p.data.copy_(module_weights[n][:])
+
+
 class Glm4Attention(QKNormRoPEAttention):
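The rename_moe_weight helper inside Glm4WeightLoader.load_weights only does string substitution on checkpoint keys before they are handed to the MoE module; a self-contained sketch with hypothetical per-expert sub-keys:

def rename_moe_weight(weights, rename_rules):
    # Same logic as the nested helper above.
    result = {}
    for key, value in weights.items():
        new_key = key
        for old, new in rename_rules.items():
            new_key = new_key.replace(old, new)
        result[new_key] = value
    return result

# Hypothetical keys, shaped like what filter_weights would return for an "experts" module.
renamed = rename_moe_weight(
    {"0.gate_proj.weight": 0, "0.up_proj.weight": 1, "0.down_proj.weight": 2},
    {"down_proj": "w2", "up_proj": "w3", "gate_proj": "w1"})
assert set(renamed) == {"0.w1.weight", "0.w3.weight", "0.w2.weight"}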
@@ -61,7 +193,7 @@ class Glm4Attention(QKNormRoPEAttention):
             max_position_embeddings=config.max_position_embeddings,
             bias=config.attention_bias,
             pos_embd_params=pos_embd_params,
-            fuse_qk_norm_rope=False,
+            fuse_qk_norm_rope=True,
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             dense_bias=False,
@@ -98,7 +230,7 @@ class Glm4MoE(nn.Module):
             topk_group=config.topk_group,
             routed_scaling_factor=config.routed_scaling_factor,
             dtype=dtype,
-            fuse_routing_kernel=False,
+            fuse_routing_kernel=True,
             apply_routing=False,
             moe_backend=model_config.moe_backend,
         )
@@ -872,40 +1004,11 @@ class Glm4MoeForCausalLM(SpecDecOneEngineForCausalLM[Glm4Model, PretrainedConfig
             **kwargs,
         )

-    def load_weights(self, weights: Dict):
-        # model.layers.91.mlp.experts.75.up_proj.weight_scale_2
-        _load_weights_impl(
-            self,
-            weights,
-            params_map={
-                r"(?!.*shared_experts)(?=.*experts?)(.*?)up_proj(.*)": r"\1w3\2",
-                r"(?!.*shared_experts)(?=.*experts?)(.*?)down_proj(.*)": r"\1w2\2",
-                r"(?!.*shared_experts)(?=.*experts?)(.*?)gate_proj(.*)": r"\1w1\2",
-            },
-        )
+    def load_weights(self, weights: Dict, allow_partial_loading: bool = False):
+        weight_loader = Glm4WeightLoader(self)
+        weight_loader.load_weights(weights, allow_partial_loading=allow_partial_loading)

     def post_load_weights(self):
-        all_named_modules = dict(self.model.named_modules())
-        for name, module in tqdm(all_named_modules.items(), desc="Post loading weights"):
-            if len(module._parameters) <= 0 or name.startswith("draft_model"):
-                continue
-            else:
-                if (
-                    self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales()
-                    and is_sm_100f()
-                    and hasattr(module, "weight_scale")
-                ):
-                    weight, weight_scale = resmooth_to_fp8_e8m0(module.weight, module.weight_scale)
-                    transfromed_scale = transform_sf_into_required_layout(
-                        weight_scale,
-                        mn=weight.shape[0],
-                        k=weight.shape[1],
-                        recipe=(1, 128, 128),
-                        is_sfa=False,
-                    )
-                    module.weight = nn.Parameter(weight, requires_grad=False)
-                    module.weight_scale = nn.Parameter(transfromed_scale, requires_grad=False)
-
         for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]):
             if idx == self.config.num_hidden_layers - 1:
                 layer.next_layer_layernorm = self.model.norm
@@ -428,12 +428,22 @@ class MTPDraftModel(nn.Module):
                  torch.cuda.Stream]):
         super().__init__()
         # Import here to avoid circular import
-        from .modeling_deepseekv3 import DeepseekV3MTP
-
-        mtp_layer = DeepseekV3MTP(model_config,
-                                  layer_idx,
-                                  aux_stream_dict,
-                                  is_separate_draft_engine=True)
+        model_type = model_config.pretrained_config.model_type
+        if model_type == "glm4_moe":
+            from .modeling_glm import Glm4MTP
+            mtp_layer = Glm4MTP(model_config,
+                                layer_idx,
+                                aux_stream_dict,
+                                is_separate_draft_engine=True)
+        elif model_type in ["deepseek_v3", "deepseek_v32"]:
+            from .modeling_deepseekv3 import DeepseekV3MTP
+            mtp_layer = DeepseekV3MTP(model_config,
+                                      layer_idx,
+                                      aux_stream_dict,
+                                      is_separate_draft_engine=True)
+        else:
+            raise ValueError(
+                f"MTPDraftModel does not support model_type: {model_type}")
         setattr(self, f"layers.{layer_idx}", mtp_layer)
         self.layers = mtp_layer
         self.layer_idx = layer_idx
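With this dispatch in place, the two-model MTP path can be driven for GLM-4.6 the same way the new accuracy test below does. A hedged sketch; the model path is a placeholder and the import locations are assumed to match the LLM API used in that test:

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig, MTPDecodingConfig

model_path = "zai-org/GLM-4.6"  # placeholder; the test uses a local glm-4.6-fp4 checkpoint
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=3,
                               mtp_eagle_one_model=False,  # two-model path -> MTPDraftModel
                               speculative_model_dir=model_path)

llm = LLM(model_path,
          kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.70),
          speculative_config=mtp_config)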
@@ -493,8 +503,18 @@ class MTPDraftModelForCausalLM(DecoderModelForCausalLM[MTPDraftModel,

     def load_weights(self, weights: Dict):
         # Import here to avoid circular import
-        from .modeling_deepseekv3 import DeepseekV3WeightLoader
-        weight_loader = DeepseekV3WeightLoader(self, is_draft_model=True)
+        model_type = self.model_config.pretrained_config.model_type
+        match model_type:
+            case "glm4_moe":
+                from .modeling_glm import Glm4WeightLoader
+                weight_loader = Glm4WeightLoader(self, is_draft_model=True)
+            case "deepseek_v3" | "deepseek_v32":
+                from .modeling_deepseekv3 import DeepseekV3WeightLoader
+                weight_loader = DeepseekV3WeightLoader(self,
+                                                       is_draft_model=True)
+            case _:
+                raise ValueError(
+                    f"Model type {model_type} not supported for MTP")
         weight_loader.load_weights(weights)

     def load_weights_from_target_model(self,
@@ -229,9 +229,14 @@ class QKNormRoPEAttention(Attention):
     def apply_qk_norm_rope(self, qkv, position_ids):
         factor, low, high, attention_factor = compute_yarn_parameters(
             self.pretrained_config)
+
+        partial_rotary_factor = self.pretrained_config.partial_rotary_factor if hasattr(
+            self.pretrained_config, "partial_rotary_factor") else 1.0
+        rotary_dim = int(self.head_dim * partial_rotary_factor)
+
         torch.ops.trtllm.fused_qk_norm_rope(
             qkv, self.num_heads, self.num_key_value_heads,
-            self.num_key_value_heads, self.head_dim,
+            self.num_key_value_heads, self.head_dim, rotary_dim,
             self.q_norm.variance_epsilon, self.q_norm.weight,
             self.k_norm.weight,
             self.pos_embd_params.rope.theta, self.pos_embd_params.is_neox,
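A quick worked example of the new rotary_dim computation, assuming GLM-4.6-style attention settings (head_dim = 128 and partial_rotary_factor = 0.5 are assumed values here, not read from a real config):

head_dim = 128
partial_rotary_factor = 0.5
rotary_dim = int(head_dim * partial_rotary_factor)
assert rotary_dim == 64  # only the first 64 dims of each q/k head are rotated; the rest pass through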
@@ -272,6 +272,8 @@ ByteDance-Seed/Seed-OSS-36B-Instruct:
   - accuracy: 90.8
 zai-org/GLM-4.6:
   - accuracy: 81.3
+  - spec_dec_algo: MTP
+    accuracy: 81.3
   - quant_algo: NVFP4
     spec_dec_algo: MTP
     accuracy: 88.0
@@ -2869,8 +2869,11 @@ class TestGLM4_6(LlmapiAccuracyTestHarness):
     @pytest.mark.skip_less_device(4)
     @pytest.mark.parametrize(
         "tp_size,pp_size,mtp_nextn,cuda_graph,overlap_scheduler,chunked_prefill,max_batch_size,moe_backend",
-        [pytest.param(4, 1, 2, True, True, True, 16, "CUTLASS")],
-        ids=["throughput"])
+        [
+            pytest.param(4, 1, 2, True, True, True, 16, "CUTLASS"),
+            pytest.param(4, 1, 2, True, True, True, 16, "TRTLLM")
+        ],
+        ids=["throughput", "throughput_trtllm"])
     def test_nvfp4_multi_gpus(self, tp_size, pp_size, mtp_nextn, cuda_graph,
                               overlap_scheduler, chunked_prefill,
                               max_batch_size, moe_backend):
@@ -2897,6 +2900,39 @@ class TestGLM4_6(LlmapiAccuracyTestHarness):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)

+    @pytest.mark.skip_less_device(4)
+    @pytest.mark.parametrize(
+        "tp_size,cuda_graph,overlap_scheduler,chunked_prefill,max_batch_size,moe_backend",
+        [
+            pytest.param(4, True, True, True, 16, "CUTLASS"),
+            pytest.param(4, True, True, True, 16, "TRTLLM"),
+        ],
+        ids=["2model", "2model_trtllm"])
+    def test_nvfp4_2_model_mtp(self, tp_size, cuda_graph, overlap_scheduler,
+                               chunked_prefill, max_batch_size, moe_backend):
+        model_path = f"{llm_models_root()}/glm-4.6-fp4"
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.70)
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
+            moe_config=MoeConfig(backend=moe_backend))
+
+        mtp_config = MTPDecodingConfig(num_nextn_predict_layers=3,
+                                       mtp_eagle_one_model=False,
+                                       speculative_model_dir=model_path)
+
+        with LLM(model_path,
+                 max_batch_size=max_batch_size,
+                 tensor_parallel_size=tp_size,
+                 kv_cache_config=kv_cache_config,
+                 **pytorch_config,
+                 speculative_config=mtp_config,
+                 enable_chunked_prefill=chunked_prefill) as llm:
+
+            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
+
     @pytest.mark.timeout(7200)
     @pytest.mark.skip_less_device_memory(100000)
@@ -507,6 +507,9 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[latency
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[baseline_fp8kv]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[latency]
 accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_multi_gpus[throughput]
+accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_multi_gpus[throughput_trtllm]
+accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_2_model_mtp[2model]
+accuracy/test_llm_api_pytorch.py::TestGLM4_6::test_nvfp4_2_model_mtp[2model_trtllm]
 accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
 accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_bf16[multi_gpus_no_cache]
 accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=False]
@@ -8,8 +8,8 @@ from tensorrt_llm._torch.modules.rotary_embedding import RotaryEmbedding

 @torch.inference_mode()
 def torch_ref_rms_norm_rope(qkv, num_heads_q, num_heads_k, num_heads_v,
-                            head_dim, eps, q_weight, k_weight, base, is_neox,
-                            position_ids):
+                            head_dim, rotary_dim, eps, q_weight, k_weight, base,
+                            is_neox, position_ids):
     """
     PyTorch reference implementation of RMSNorm+RoPE for verification.

@@ -22,6 +22,7 @@ def torch_ref_rms_norm_rope(qkv, num_heads_q, num_heads_k, num_heads_v,
         num_heads_k: Number of key heads
         num_heads_v: Number of value heads (unused for normalization/RoPE but needed for tensor splitting)
         head_dim: Dimension of each head
+        rotary_dim: Dimension for RoPE
         eps: Epsilon value for RMS normalization
         q_weight: RMSNorm weights for query [head_dim]
         k_weight: RMSNorm weights for key [head_dim]
@@ -65,7 +66,7 @@ def torch_ref_rms_norm_rope(qkv, num_heads_q, num_heads_k, num_heads_v,

     # Create and apply RotaryEmbedding module
     rope_params = RopeParams(
-        dim=head_dim,  # Set the rotary dimension to match the head dimension
+        dim=rotary_dim,  # Set the rotary dimension
         theta=base,  # Base value for RoPE calculations
         max_positions=8192  # Large enough for any reasonable hidden size
     )
@@ -88,10 +89,12 @@ num_heads_groups = [
     (16, 8, 8),  # Qwen3-0.6B, Qwen3-1.7B
     (32, 8, 8),  # Qwen3-4B, Qwen3-8B, Qwen3-30B-A3B
     (40, 8, 8),  # Qwen3-14B
-    (64, 8, 8)  # Qwen3-32B, Qwen3-235B-A22B
+    (64, 8, 8),  # Qwen3-32B, Qwen3-235B-A22B
+    (24, 8, 8),  # GLM 4.6
 ]
 num_tokens_list = [1, 3, 8, 32, 256]
 is_neox_list = [False, True]
+partial_rotary_factor_list = [1.0, 0.5]
 dtypes = [torch.bfloat16]  # TODO: support float16

@@ -100,8 +103,9 @@ dtypes = [torch.bfloat16]  # TODO: support float16
 @pytest.mark.parametrize("num_tokens", num_tokens_list)
 @pytest.mark.parametrize("is_neox", is_neox_list)
 @pytest.mark.parametrize("dtype", dtypes)
-def test_fused_qk_norm_rope(head_dim, num_heads_group, num_tokens, is_neox,
-                            dtype):
+@pytest.mark.parametrize("partial_rotary_factor", partial_rotary_factor_list)
+def test_fused_qk_norm_rope(head_dim, num_heads_group, num_tokens,
+                            partial_rotary_factor, is_neox, dtype):
     """
     Test the fused QK RMSNorm + RoPE operation with various configurations.

@@ -143,18 +147,20 @@ def test_fused_qk_norm_rope(head_dim, num_heads_group, num_tokens,
     base = 10000.0

     factor, low, high, attention_factor = 1.0, 0, 0, 1.0
+    rotary_dim = int(head_dim * partial_rotary_factor)
     # Run the custom fusedQKNormRope operation
     torch.ops.trtllm.fused_qk_norm_rope(qkv, num_heads_q, num_heads_k,
-                                        num_heads_v, head_dim, eps, q_weight,
-                                        k_weight, base, is_neox, position_ids,
-                                        factor, low, high, attention_factor,
-                                        True)
+                                        num_heads_v, head_dim, rotary_dim, eps,
+                                        q_weight, k_weight, base, is_neox,
+                                        position_ids, factor, low, high,
+                                        attention_factor, True)
     output = qkv  # This op is inplace

     # Compute reference output using TensorRT LLM modules
     ref_output = torch_ref_rms_norm_rope(qkv_copy, num_heads_q, num_heads_k,
-                                         num_heads_v, head_dim, eps, q_weight,
-                                         k_weight, base, is_neox, position_ids)
+                                         num_heads_v, head_dim, rotary_dim, eps,
+                                         q_weight, k_weight, base, is_neox,
+                                         position_ids)

     # Compare outputs from custom kernel vs reference implementation
     torch.testing.assert_close(