mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[https://nvbugs/5427801][fix] Torch compile support for Llama4 and Ea… (#6978)
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
parent df00c81aea
commit 69846c6586
@@ -165,22 +165,9 @@ class Llama4Attention(Attention):
         q, k, v = self.split_qkv(q, k, v)
         q = self._attention_scaling(q, position_ids)

-        out_scale = None
-        out_scale_sf = None
-        if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales:
-            out_scale = self.o_proj.inv_input_scale
-        if self.o_proj.has_nvfp4 and self.support_nvfp4_output:
-            out_scale_sf = self.o_proj.input_scale
-
         q, k, v = self.convert_qkv(q, k, v)
-        attn_output = self.attn.forward(q,
-                                        k,
-                                        v,
-                                        attn_metadata,
-                                        out_scale=out_scale,
-                                        out_scale_sf=out_scale_sf,
-                                        attention_mask=attention_mask,
-                                        mrope_config=mrope_config)
+        attn_output = self.forward_impl(q, k, v, attn_metadata, attention_mask,
+                                        None, None, mrope_config)

         if isinstance(attn_output, tuple):
             attn_output = Fp4QuantizedTensor(attn_output[0], attn_output[1])
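Note on the hunk above: the two positional None arguments line up with attention_window_size and attention_mask_data in the forward_impl signature added later in this commit, and the out_scale selection that was deleted here moves into the shared attention path. An annotated restatement of the new call (argument labels inferred from that signature, not part of the source):

    # Sketch only: same call as in the hunk above, with each positional
    # argument labeled against the forward_impl signature introduced below.
    attn_output = self.forward_impl(
        q, k, v,
        attn_metadata,
        attention_mask,
        None,          # attention_window_size
        None,          # attention_mask_data
        mrope_config)  # mrope_config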
@@ -362,6 +362,12 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
                                           model_config,
                                           model_config.mapping)

+        if draft_config is not None:
+            for key, value in draft_config.extra_attrs.items():
+                assert key in ('attn_layers', 'mla_layers')
+                assert key in model_config.extra_attrs
+                model_config.extra_attrs[key].update(value)
+
     def forward(
         self,
         attn_metadata: AttentionMetadata,
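The added block folds the draft (speculative) model's per-layer attention registries into the target model_config.extra_attrs, presumably so both engines resolve layers from one place. A self-contained sketch of the merge semantics with hypothetical layer entries (the dict contents below are illustrative, not taken from the code):

    # Illustrative only: hypothetical registries showing what the update()
    # loop above does to model_config.extra_attrs.
    model_extra_attrs = {"attn_layers": {0: "target_attn_0"}, "mla_layers": {}}
    draft_extra_attrs = {"attn_layers": {32: "draft_attn_0"}}

    for key, value in draft_extra_attrs.items():
        assert key in ("attn_layers", "mla_layers")  # only these registries may be merged
        assert key in model_extra_attrs              # target must already define the registry
        model_extra_attrs[key].update(value)

    print(model_extra_attrs["attn_layers"])
    # {0: 'target_attn_0', 32: 'draft_attn_0'} -- draft layers now visible alongside the target's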
@@ -92,7 +92,7 @@ def attn_custom_op_inplace(
                mrope_position_deltas,
                attention_window_size,
                attention_mask_data,
-               False,
+               enable_attn_nvfp4_output=False,
                output=output)

@@ -372,6 +372,58 @@ class Attention(nn.Module):
             return attn_output[0], attn_output[1]
         return attn_output, None

+    def forward_impl(
+        self,
+        q: torch.Tensor,
+        k: Optional[torch.Tensor],
+        v: Optional[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        attention_mask: AttentionMask,
+        attention_window_size: Optional[int],
+        attention_mask_data: Optional[torch.Tensor],
+        mrope_config: Optional[dict],
+    ):
+        mrope_rotary_cos_sin = None
+        mrope_position_deltas = None
+        if mrope_config is not None:
+            if "mrope_rotary_cos_sin" in mrope_config:
+                mrope_rotary_cos_sin = mrope_config["mrope_rotary_cos_sin"]
+            if "mrope_position_deltas" in mrope_config:
+                mrope_position_deltas = mrope_config["mrope_position_deltas"]
+
+        # Currently only TRTLLM and FLASHINFER are torch compile compatible backends.
+        # Only enable custom inplace op when torch compiling.
+        use_custom_inplace_op = (self.register_to_config
+                                 and (self.attn_backend == "TRTLLM"
+                                      or self.attn_backend == "FLASHINFER")
+                                 and is_torch_compiling())
+
+        if use_custom_inplace_op:
+            output = self.create_output(q)
+            attn_custom_op_inplace(
+                q,
+                k,
+                v,
+                attention_mask,
+                mrope_rotary_cos_sin,
+                mrope_position_deltas,
+                attention_window_size,
+                attention_mask_data,
+                self.layer_idx_str,
+                output,
+            )
+        else:
+            output, output_sf = self._attn_impl(q, k, v, attn_metadata,
+                                                attention_mask,
+                                                mrope_rotary_cos_sin,
+                                                mrope_position_deltas,
+                                                attention_window_size,
+                                                attention_mask_data)
+            if output_sf is not None:
+                output = Fp4QuantizedTensor(output, output_sf)
+
+        return output
+
     def forward(
         self,
         position_ids: Optional[torch.IntTensor],
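The is_torch_compiling() gate above is the heart of the torch.compile fix: when Dynamo is tracing (and the backend is TRTLLM or FLASHINFER with register_to_config set), attention is routed through a registered in-place custom op, so the compiled graph sees one opaque node that writes into a pre-allocated output instead of tracing the backend's Python internals. Below is a minimal, generic sketch of that pattern using the public torch.library / torch.compiler APIs; the op name demo::attn_inplace and the toy _attn_impl are placeholders and not TensorRT-LLM's actual registration code.

    import torch

    def _attn_impl(q: torch.Tensor) -> torch.Tensor:
        # Stand-in for the backend attention implementation.
        return q * 2.0

    # Register a custom op that mutates its caller-provided output tensor.
    @torch.library.custom_op("demo::attn_inplace", mutates_args=("output",))
    def attn_inplace(q: torch.Tensor, output: torch.Tensor) -> None:
        output.copy_(_attn_impl(q))

    @attn_inplace.register_fake
    def _(q: torch.Tensor, output: torch.Tensor) -> None:
        # Shape/dtype propagation only; no real computation during tracing.
        return None

    def forward_impl(q: torch.Tensor) -> torch.Tensor:
        if torch.compiler.is_compiling():
            # Compiled path: route through the opaque, in-place custom op.
            output = torch.empty_like(q)
            attn_inplace(q, output)
            return output
        # Eager path: call the implementation directly.
        return _attn_impl(q)

    compiled = torch.compile(forward_impl, fullgraph=True)
    print(compiled(torch.ones(4)))      # tensor([2., 2., 2., 2.])
    print(forward_impl(torch.ones(4)))  # same result on the eager path

The design point is the same one the diff makes in its comments: the custom op keeps the compiled graph free of backend-specific Python, while the eager path stays unchanged.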
@@ -414,54 +466,18 @@ class Attention(nn.Module):
         if qkv_lora is not None:
             qkv = qkv + qkv_lora

-        mrope_rotary_cos_sin = None
-        mrope_position_deltas = None
-        if mrope_config is not None:
-            if "mrope_rotary_cos_sin" in mrope_config:
-                mrope_rotary_cos_sin = mrope_config["mrope_rotary_cos_sin"]
-            if "mrope_position_deltas" in mrope_config:
-                mrope_position_deltas = mrope_config["mrope_position_deltas"]
-
-        output = None
-
         q, k, v = qkv, None, None
         q, k, v = self.apply_rope(q, k, v, position_ids)
         q, k, v = self.convert_qkv(q, k, v)

-        # Currently only TRTLLM and FLASHINFER are torch compile compatible backends.
-        # Only enable custom inplace op when torch compiling.
-        use_custom_inplace_op = (self.register_to_config
-                                 and (self.attn_backend == "TRTLLM"
-                                      or self.attn_backend == "FLASHINFER")
-                                 and is_torch_compiling())
-        if use_custom_inplace_op:
-            output = self.create_output(q)
-            attn_custom_op_inplace(
-                q,
-                k,
-                v,
-                attention_mask,
-                mrope_rotary_cos_sin,
-                mrope_position_deltas,
-                attention_window_size,
-                attention_mask_data,
-                self.layer_idx_str,
-                output=output,
-            )
-        else:
-            output, output_sf = self._attn_impl(
-                q,
-                k,
-                v,
-                attn_metadata,
-                attention_mask,
-                mrope_rotary_cos_sin,
-                mrope_position_deltas,
-                attention_window_size,
-                attention_mask_data,
-            )
-            if output_sf is not None:
-                output = Fp4QuantizedTensor(output, output_sf)
+        output = self.forward_impl(q,
+                                   k,
+                                   v,
+                                   attn_metadata,
+                                   attention_mask,
+                                   attention_window_size,
+                                   attention_mask_data,
+                                   mrope_config=mrope_config)

         attn_output = self.o_proj(output,
                                   all_reduce_params=all_reduce_params,
@@ -550,11 +550,14 @@ class TestLlama4MaverickInstruct(LlmapiAccuracyTestHarness):
                                             speculative_model_dir=eagle_model_dir)
         kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                         free_gpu_memory_fraction=0.75)
+        torch_compile_config = TorchCompileConfig(
+            enable_fullgraph=True,
+            enable_piecewise_cuda_graph=True,
+            max_num_streams=3) if torch_compile else None
         pytorch_config = dict(
             cuda_graph_config=CudaGraphConfig(max_batch_size=8),
             enable_attention_dp=False,
-            torch_compile_config=TorchCompileConfig(
-                enable_fullgraph=torch_compile))
+            torch_compile_config=torch_compile_config)
         with LLM(model_path,
                  kv_cache_config=kv_cache_config,
                  tensor_parallel_size=tp_size,
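For context, a hedged sketch of how the updated test wiring comes together end to end. Only the option names and values come from the hunk above; the import locations, the model path placeholder, the tp_size value, and the keyword splat are assumptions.

    # Sketch under assumptions: llmapi export paths and LLM keyword splat are inferred.
    from tensorrt_llm import LLM
    from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig,
                                     TorchCompileConfig)

    torch_compile = True  # the real test parametrizes this flag

    kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                    free_gpu_memory_fraction=0.75)
    # Build the compile config once; pass None to disable torch.compile entirely.
    torch_compile_config = TorchCompileConfig(
        enable_fullgraph=True,
        enable_piecewise_cuda_graph=True,
        max_num_streams=3) if torch_compile else None

    pytorch_config = dict(
        cuda_graph_config=CudaGraphConfig(max_batch_size=8),
        enable_attention_dp=False,
        torch_compile_config=torch_compile_config)

    llm = LLM("<model_path>",          # hypothetical placeholder path
              kv_cache_config=kv_cache_config,
              tensor_parallel_size=8,  # tp8, per the un-waived test name below
              **pytorch_config)        # splat assumed, matching pytorch_config = dict(...)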
@@ -260,7 +260,6 @@ examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padd
 test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True] SKIP (https://nvbugs/5430124)
 examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5431132)
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5434320)
-accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_eagle3[tp8-torch_compile=True] SKIP (https://nvbugs/5427801)
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5434320)
 accuracy/test_llm_api.py::TestLlama3_2_1B::test_int4_awq_int8_kv_cache SKIP (https://nvbugs/5433541)
 accuracy/test_llm_api.py::TestLlama3_2_1B::test_fp8_pp2 SKIP (https://nvbugs/5433541)