diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py index 1773e16f7b..817068dd6c 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py @@ -272,15 +272,18 @@ def prepare_fused_mha_metadata( ) +# TODO: Move the truncation of inputs out of this custom op +# SequenceInfo._get_sanitized_num_sequences could break in fake mode @prepare_fused_mha_metadata.register_fake def prepare_fused_mha_metadata_fake( input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, page_size ): + num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len) return ( - torch.empty_like(seq_len), - torch.empty_like(input_pos), - torch.empty_like(cache_loc), - torch.empty_like(seq_len), + torch.empty_like(seq_len[:num_seq]), + torch.empty_like(input_pos[:num_seq]), + torch.empty_like(cache_loc[:num_seq]), + torch.empty_like(seq_len[:num_seq]), )