From 3dbb08729249641db030ebe412f027585fb156b9 Mon Sep 17 00:00:00 2001
From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Date: Mon, 12 May 2025 08:11:40 -0700
Subject: [PATCH] [TRTLLM-5188] fix: [AutoDeploy] update output shape of
 prepare_fused_mha_metadata_fake (#4199)

* update output shape of fake kernel prepare_fused_mha_metadata_fake

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* minor

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

---------

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
---
 .../_torch/auto_deploy/custom_ops/triton_attention.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py
index 1773e16f7b..817068dd6c 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py
@@ -272,15 +272,18 @@ def prepare_fused_mha_metadata(
     )
 
 
+# TODO: Move the truncation of inputs out of this custom op
+# SequenceInfo._get_sanitized_num_sequences could break in fake mode
 @prepare_fused_mha_metadata.register_fake
 def prepare_fused_mha_metadata_fake(
     input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, page_size
 ):
+    num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len)
     return (
-        torch.empty_like(seq_len),
-        torch.empty_like(input_pos),
-        torch.empty_like(cache_loc),
-        torch.empty_like(seq_len),
+        torch.empty_like(seq_len[:num_seq]),
+        torch.empty_like(input_pos[:num_seq]),
+        torch.empty_like(cache_loc[:num_seq]),
+        torch.empty_like(seq_len[:num_seq]),
     )
 
 
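
The change above fixes a real/fake kernel mismatch: the real prepare_fused_mha_metadata evidently truncates its metadata outputs to the actual number of sequences, while the fake kernel previously reported shapes sized to the full padded inputs. Because fake kernels run under FakeTensorMode during torch.compile/torch.export tracing, a shape mismatch between the fake and real kernels desynchronizes the traced graph from eager execution. Below is a minimal standalone sketch of the pattern, using the same torch.library custom-op/register_fake API the patched code relies on; the demo::trunc_copy op is hypothetical and not part of this patch.

import torch


# Hypothetical op for illustration only; the custom_op / register_fake
# machinery is the same one used by prepare_fused_mha_metadata above.
@torch.library.custom_op("demo::trunc_copy", mutates_args=())
def trunc_copy(x: torch.Tensor, n: int) -> torch.Tensor:
    # Real kernel: only the first n rows are meaningful, so return those.
    return x[:n].clone()


@trunc_copy.register_fake
def trunc_copy_fake(x: torch.Tensor, n: int) -> torch.Tensor:
    # The fake kernel runs under FakeTensorMode and must report the same
    # output shape as the real kernel. Returning torch.empty_like(x) here
    # would claim shape x.shape rather than (n, ...) -- the class of bug
    # this patch fixes for the metadata tensors.
    return torch.empty_like(x[:n])

With this pairing, tracing trunc_copy(x, 3) on an 8-row input records an output of shape (3, ...), matching what eager execution produces, instead of the padded (8, ...) shape the naive empty_like would report.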