[https://nvbugs/5427801][fix] Torch compile support for Llama4 and Ea… (#6978)

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Jin Li 2025-08-20 15:06:56 +08:00 committed by GitHub
parent df00c81aea
commit 69846c6586
5 changed files with 74 additions and 63 deletions

@@ -165,22 +165,9 @@ class Llama4Attention(Attention):
        q, k, v = self.split_qkv(q, k, v)
        q = self._attention_scaling(q, position_ids)
        out_scale = None
        out_scale_sf = None
        if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales:
            out_scale = self.o_proj.inv_input_scale
        if self.o_proj.has_nvfp4 and self.support_nvfp4_output:
            out_scale_sf = self.o_proj.input_scale
        q, k, v = self.convert_qkv(q, k, v)
        attn_output = self.attn.forward(q,
                                        k,
                                        v,
                                        attn_metadata,
                                        out_scale=out_scale,
                                        out_scale_sf=out_scale_sf,
                                        attention_mask=attention_mask,
                                        mrope_config=mrope_config)
        attn_output = self.forward_impl(q, k, v, attn_metadata, attention_mask,
                                        None, None, mrope_config)
        if isinstance(attn_output, tuple):
            attn_output = Fp4QuantizedTensor(attn_output[0], attn_output[1])
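
Note: in the new forward_impl call above, the two bare None arguments correspond to attention_window_size and attention_mask_data in the forward_impl signature added later in this commit. A keyword-argument rendering of the same call, shown only for readability (the positional form is what the diff actually adds):

attn_output = self.forward_impl(q, k, v, attn_metadata,
                                attention_mask=attention_mask,
                                attention_window_size=None,
                                attention_mask_data=None,
                                mrope_config=mrope_config)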

@@ -362,6 +362,12 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
                                           model_config,
                                           model_config.mapping)
        if draft_config is not None:
            for key, value in draft_config.extra_attrs.items():
                assert key in ('attn_layers', 'mla_layers')
                assert key in model_config.extra_attrs
                model_config.extra_attrs[key].update(value)

    def forward(
        self,
        attn_metadata: AttentionMetadata,
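
The merge added here copies the draft model's 'attn_layers' / 'mla_layers' entries into the target model_config.extra_attrs, presumably so the custom-op path introduced below, which resolves layers via self.layer_idx_str, can find draft-model layers as well. A minimal, self-contained sketch of the same dict-merge pattern; the layer indices and string values are hypothetical stand-ins for the real module entries:

# Hypothetical stand-ins for the per-layer registries kept in extra_attrs.
draft_extra = {"attn_layers": {48: "draft_attn_layer"}}
target_extra = {"attn_layers": {0: "attn_layer_0"}, "mla_layers": {}}

for key, value in draft_extra.items():
    assert key in ("attn_layers", "mla_layers")
    assert key in target_extra
    target_extra[key].update(value)

print(sorted(target_extra["attn_layers"]))  # [0, 48]: draft layers join the target registry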

@@ -92,7 +92,7 @@ def attn_custom_op_inplace(
                    mrope_position_deltas,
                    attention_window_size,
                    attention_mask_data,
                    False,
                    enable_attn_nvfp4_output=False,
                    output=output)
@@ -372,6 +372,58 @@ class Attention(nn.Module):
            return attn_output[0], attn_output[1]
        return attn_output, None

    def forward_impl(
        self,
        q: torch.Tensor,
        k: Optional[torch.Tensor],
        v: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata,
        attention_mask: AttentionMask,
        attention_window_size: Optional[int],
        attention_mask_data: Optional[torch.Tensor],
        mrope_config: Optional[dict],
    ):
        mrope_rotary_cos_sin = None
        mrope_position_deltas = None
        if mrope_config is not None:
            if "mrope_rotary_cos_sin" in mrope_config:
                mrope_rotary_cos_sin = mrope_config["mrope_rotary_cos_sin"]
            if "mrope_position_deltas" in mrope_config:
                mrope_position_deltas = mrope_config["mrope_position_deltas"]
        # Currently only TRTLLM and FLASHINFER are torch compile compatible backends.
        # Only enable custom inplace op when torch compiling.
        use_custom_inplace_op = (self.register_to_config
                                 and (self.attn_backend == "TRTLLM"
                                      or self.attn_backend == "FLASHINFER")
                                 and is_torch_compiling())
        if use_custom_inplace_op:
            output = self.create_output(q)
            attn_custom_op_inplace(
                q,
                k,
                v,
                attention_mask,
                mrope_rotary_cos_sin,
                mrope_position_deltas,
                attention_window_size,
                attention_mask_data,
                self.layer_idx_str,
                output,
            )
        else:
            output, output_sf = self._attn_impl(q, k, v, attn_metadata,
                                                attention_mask,
                                                mrope_rotary_cos_sin,
                                                mrope_position_deltas,
                                                attention_window_size,
                                                attention_mask_data)
            if output_sf is not None:
                output = Fp4QuantizedTensor(output, output_sf)
        return output

    def forward(
        self,
        position_ids: Optional[torch.IntTensor],
@@ -414,54 +466,18 @@
        if qkv_lora is not None:
            qkv = qkv + qkv_lora
        mrope_rotary_cos_sin = None
        mrope_position_deltas = None
        if mrope_config is not None:
            if "mrope_rotary_cos_sin" in mrope_config:
                mrope_rotary_cos_sin = mrope_config["mrope_rotary_cos_sin"]
            if "mrope_position_deltas" in mrope_config:
                mrope_position_deltas = mrope_config["mrope_position_deltas"]
        output = None
        q, k, v = qkv, None, None
        q, k, v = self.apply_rope(q, k, v, position_ids)
        q, k, v = self.convert_qkv(q, k, v)
        # Currently only TRTLLM and FLASHINFER are torch compile compatible backends.
        # Only enable custom inplace op when torch compiling.
        use_custom_inplace_op = (self.register_to_config
                                 and (self.attn_backend == "TRTLLM"
                                      or self.attn_backend == "FLASHINFER")
                                 and is_torch_compiling())
        if use_custom_inplace_op:
            output = self.create_output(q)
            attn_custom_op_inplace(
                q,
                k,
                v,
                attention_mask,
                mrope_rotary_cos_sin,
                mrope_position_deltas,
                attention_window_size,
                attention_mask_data,
                self.layer_idx_str,
                output=output,
            )
        else:
            output, output_sf = self._attn_impl(
                q,
                k,
                v,
                attn_metadata,
                attention_mask,
                mrope_rotary_cos_sin,
                mrope_position_deltas,
                attention_window_size,
                attention_mask_data,
            )
            if output_sf is not None:
                output = Fp4QuantizedTensor(output, output_sf)
        output = self.forward_impl(q,
                                   k,
                                   v,
                                   attn_metadata,
                                   attention_mask,
                                   attention_window_size,
                                   attention_mask_data,
                                   mrope_config=mrope_config)
        attn_output = self.o_proj(output,
                                  all_reduce_params=all_reduce_params,
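
The forward_impl / attn_custom_op_inplace split keeps the mutating attention kernel visible to torch.compile as a single opaque op writing into a pre-allocated output, rather than letting Dynamo trace into backend-specific code. Below is a minimal sketch of that general pattern using the public torch.library.custom_op API (PyTorch 2.4+); the op name, tensor shapes, and stand-in math are illustrative assumptions, not TensorRT-LLM's actual registration code:

import torch

# Sketch: expose an in-place attention kernel to torch.compile as a custom op
# that mutates a caller-provided output buffer (the spirit of attn_custom_op_inplace).
@torch.library.custom_op("example::attn_inplace", mutates_args=("output",))
def attn_inplace(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                 output: torch.Tensor) -> None:
    # Stand-in math; the real op would dispatch to the TRTLLM/FLASHINFER backend.
    scores = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1]**0.5, dim=-1)
    output.copy_(scores @ v)

@attn_inplace.register_fake
def _(q, k, v, output) -> None:
    # Nothing to allocate: the caller pre-creates `output` (cf. self.create_output(q)).
    return None

def attention(q, k, v):
    output = torch.empty_like(q)   # pre-allocated output buffer
    attn_inplace(q, k, v, output)  # opaque to Dynamo, so full-graph capture succeeds
    return output

compiled = torch.compile(attention, fullgraph=True)
q = k = v = torch.randn(2, 8, 16, 32)
out = compiled(q, k, v)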

@@ -550,11 +550,14 @@ class TestLlama4MaverickInstruct(LlmapiAccuracyTestHarness):
                                          speculative_model_dir=eagle_model_dir)
        kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                        free_gpu_memory_fraction=0.75)
        torch_compile_config = TorchCompileConfig(
            enable_fullgraph=True,
            enable_piecewise_cuda_graph=True,
            max_num_streams=3) if torch_compile else None
        pytorch_config = dict(
            cuda_graph_config=CudaGraphConfig(max_batch_size=8),
            enable_attention_dp=False,
            torch_compile_config=TorchCompileConfig(
                enable_fullgraph=torch_compile))
            torch_compile_config=torch_compile_config)
        with LLM(model_path,
                 kv_cache_config=kv_cache_config,
                 tensor_parallel_size=tp_size,
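
For reference, the configuration the test now enables can be assembled the same way outside the accuracy harness. The import paths and the model path below are assumptions for illustration; only the class names and constructor arguments appear in the diff above:

from tensorrt_llm import LLM
# Assumed import location for the config classes used by the test.
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, TorchCompileConfig

torch_compile = True  # mirror the test's torch_compile parameter
torch_compile_config = TorchCompileConfig(
    enable_fullgraph=True,
    enable_piecewise_cuda_graph=True,
    max_num_streams=3) if torch_compile else None

llm = LLM("<path/to/llama4-maverick-checkpoint>",  # placeholder model path
          kv_cache_config=KvCacheConfig(enable_block_reuse=False,
                                        free_gpu_memory_fraction=0.75),
          tensor_parallel_size=8,
          cuda_graph_config=CudaGraphConfig(max_batch_size=8),
          enable_attention_dp=False,
          torch_compile_config=torch_compile_config)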

@@ -260,7 +260,6 @@ examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padd
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True] SKIP (https://nvbugs/5430124)
examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5431132)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5434320)
accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_eagle3[tp8-torch_compile=True] SKIP (https://nvbugs/5427801)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5434320)
accuracy/test_llm_api.py::TestLlama3_2_1B::test_int4_awq_int8_kv_cache SKIP (https://nvbugs/5433541)
accuracy/test_llm_api.py::TestLlama3_2_1B::test_fp8_pp2 SKIP (https://nvbugs/5433541)