[TRTLLM-5188] fix: [AutoDeploy] update output shape of prepare_fused_mha_metadata_fake (#4199)

* update output shape of fake kernel prepare_fused_mha_metadata_fake

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* minor

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

---------

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
This commit is contained in:
Fridah-nv 2025-05-12 08:11:40 -07:00 committed by GitHub
parent b1bee9c394
commit 3dbb087292
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -272,15 +272,18 @@ def prepare_fused_mha_metadata(
)
# TODO: Move the truncation of inputs out of this custom op
# SequenceInfo._get_sanitized_num_sequences could break in fake mode
@prepare_fused_mha_metadata.register_fake
def prepare_fused_mha_metadata_fake(
input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, page_size
):
num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len)
return (
torch.empty_like(seq_len),
torch.empty_like(input_pos),
torch.empty_like(cache_loc),
torch.empty_like(seq_len),
torch.empty_like(seq_len[:num_seq]),
torch.empty_like(input_pos[:num_seq]),
torch.empty_like(cache_loc[:num_seq]),
torch.empty_like(seq_len[:num_seq]),
)