[https://nvbugs/5803813][fix] Fix llama 4 min latency (#10724)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
This commit is contained in:
parent 93e7ae73ea
commit f02948d956
@@ -88,9 +88,10 @@ class Llama4MinLatencyLinear(Linear):
         self.enable_trtllm_gen = enable_trtllm_gen
         self.position_ids = None

-    def load_weights(self, weights: List[Dict]):
-        super().load_weights(weights)
+    def load_weights(self, weights: List[Dict], allow_partial_loading: bool):
+        super().load_weights(weights,
+                             allow_partial_loading=allow_partial_loading)

         # After loading weights, calculate the combined scale (input_scale * weight_scale) for special kernels and
         # trtllm-gen kernels.
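For context, the hunk above follows a common override pattern: the subclass accepts the new allow_partial_loading flag and forwards it to the parent before running its own post-load step (here, deriving the combined input_scale * weight_scale for the special and trtllm-gen kernels). Below is a minimal, self-contained sketch of that pattern, assuming a simplified stand-in Linear base class rather than the real TensorRT-LLM module:

from typing import Dict, List

import torch


class Linear:
    """Stand-in base class; the real TensorRT-LLM Linear is more involved."""

    def __init__(self, in_features: int, out_features: int):
        self.weight = torch.zeros(out_features, in_features)

    def load_weights(self, weights: List[Dict], allow_partial_loading: bool = False):
        # Merge the weight shards into one dict (illustrative loading logic).
        merged = {k: v for shard in weights for k, v in shard.items()}
        if "weight" in merged:
            self.weight = merged["weight"]
        elif not allow_partial_loading:
            raise KeyError("'weight' missing and partial loading is disabled")


class MinLatencyLinear(Linear):

    def load_weights(self, weights: List[Dict], allow_partial_loading: bool = False):
        # Forward the new flag to the parent, as the diff above does.
        super().load_weights(weights,
                             allow_partial_loading=allow_partial_loading)
        # Post-load step analogous to computing the combined scale for the
        # special kernels (the formula here is purely illustrative).
        self.combined_scale = float(self.weight.abs().max())


layer = MinLatencyLinear(4, 4)
layer.load_weights([{"weight": torch.ones(4, 4)}], allow_partial_loading=True)
print(layer.combined_scale)  # 1.0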
@@ -384,11 +385,6 @@ class Llama4MinLatencyAttention(Llama4Attention):
             and self.floor_scale == 8192.0 \
             and self.attn_scale == 0.1

-        qkv_shard_indices_mapping = {
-            "q": (0, self.q_size),
-            "k": (self.q_size, self.kv_size),
-            "v": (self.q_size + self.kv_size, self.kv_size),
-        }
         # When min-latency QKV gemm is enabled, override qkv_proj.
         self.qkv_proj = Llama4MinLatencyLinear(
             self.hidden_size,
@@ -405,7 +401,6 @@ class Llama4MinLatencyAttention(Llama4Attention):
             enable_fused_gemm_attn_scaling=self.
             enable_fused_gemm_attn_scaling,
             enable_trtllm_gen=True,
-            fused_weight_shard_indices_mapping=qkv_shard_indices_mapping,
         )

     def _forward_nope(
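The two hunks above drop the qkv_shard_indices_mapping dict and the fused_weight_shard_indices_mapping argument that consumed it. For reference only, a mapping of this shape pairs each of q/k/v with a (start offset, size) slice of the fused QKV dimension; the sketch below (with made-up sizes, not the TensorRT-LLM code) shows how such a mapping would split a fused projection output:

import torch

# Illustrative per-rank sizes; the real values come from the model config.
q_size, kv_size = 8, 4

# Same shape as the removed mapping: name -> (start offset, slice size).
qkv_shard_indices_mapping = {
    "q": (0, q_size),
    "k": (q_size, kv_size),
    "v": (q_size + kv_size, kv_size),
}

# Fused QKV projection output for 2 tokens: [tokens, q_size + 2 * kv_size].
fused_qkv = torch.randn(2, q_size + 2 * kv_size)
q, k, v = (fused_qkv[:, start:start + size]
           for start, size in qkv_shard_indices_mapping.values())
print(q.shape, k.shape, v.shape)  # (2, 8) (2, 4) (2, 4)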