mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[TRTLLM-5188] fix: [AutoDeploy] update output shape of prepare_fused_mha_metadata_fake (#4199)
* update output shape of fake kernel prepare_fused_mha_metadata_fake

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* minor

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

---------

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
This commit is contained in:
parent
b1bee9c394
commit
3dbb087292
@ -272,15 +272,18 @@ def prepare_fused_mha_metadata(
|
||||
)
|
||||
|
||||
|
||||
# TODO: Move the truncation of inputs out of this custom op
# SequenceInfo._get_sanitized_num_sequences could break in fake mode
@prepare_fused_mha_metadata.register_fake
def prepare_fused_mha_metadata_fake(
    input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, page_size
):
    """Fake (meta) implementation of ``prepare_fused_mha_metadata``.

    Registered via ``register_fake`` so that fake-tensor tracing
    (torch.compile / export) can infer output shapes without running the
    real kernel. Each output is truncated to the number of active
    sequences so the traced shapes match what the real op produces.

    Args:
        input_ids: Token-id tensor; only used to derive the active
            sequence count.
        position_ids: Unused here; kept to mirror the real op's signature.
        seq_len: Per-sequence lengths; template for outputs 0 and 3.
        input_pos: Per-sequence input positions; template for output 1.
        cache_loc: Cache-location tensor; template for output 2.
        pages_per_seq: Unused here; kept to mirror the real op's signature.
        page_size: Unused here; kept to mirror the real op's signature.

    Returns:
        A 4-tuple of empty tensors shaped like ``seq_len[:num_seq]``,
        ``input_pos[:num_seq]``, ``cache_loc[:num_seq]`` and
        ``seq_len[:num_seq]`` respectively (dtype/device inherited from
        each template via ``torch.empty_like``).
    """
    # NOTE(review): _get_sanitized_num_sequences is assumed to be safe to
    # call under fake tensors (see TODO above) — confirm if tracing breaks.
    num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len)
    return (
        torch.empty_like(seq_len[:num_seq]),
        torch.empty_like(input_pos[:num_seq]),
        torch.empty_like(cache_loc[:num_seq]),
        torch.empty_like(seq_len[:num_seq]),
    )
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user