mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[https://nvbugs/5355316] fix: update torch.compile option to fix triton store_cubin error (#5865)
Signed-off-by: Zhenhuan Chen <chenzhh3671@gmail.com>
This commit is contained in:
parent
ce048eccd3
commit
d9e265d5e7
@ -123,7 +123,7 @@ def get_masked_input_and_mask(
|
|||||||
|
|
||||||
|
|
||||||
# We use torch.compile() to fuse the tiny pointwise ops before all_reduce/all_gather for Embedding module.
|
# We use torch.compile() to fuse the tiny pointwise ops before all_reduce/all_gather for Embedding module.
|
||||||
@torch.compile(mode="max-autotune-no-cudagraphs")
|
@torch.compile(options={"max-autotune": True})
|
||||||
def pre_comm_embedding_ops(
|
def pre_comm_embedding_ops(
|
||||||
input_: torch.Tensor,
|
input_: torch.Tensor,
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
|
|||||||
@ -335,7 +335,7 @@ class Eagle3OneModelWorker(nn.Module):
|
|||||||
self.max_draft_tokens = self.spec_config.max_draft_tokens
|
self.max_draft_tokens = self.spec_config.max_draft_tokens
|
||||||
self.mapping = mapping
|
self.mapping = mapping
|
||||||
|
|
||||||
@torch.compile(mode="max-autotune-no-cudagraphs")
|
@torch.compile(options={"max-autotune": True})
|
||||||
def forward(self, input_ids, position_ids, hidden_states, logits,
|
def forward(self, input_ids, position_ids, hidden_states, logits,
|
||||||
attn_metadata, spec_metadata, draft_model):
|
attn_metadata, spec_metadata, draft_model):
|
||||||
batch_size = attn_metadata.num_seqs
|
batch_size = attn_metadata.num_seqs
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user