[https://nvbugs/5355316] fix: update torch.compile option to fix triton store_cubin error (#5865)

Signed-off-by: Zhenhuan Chen <chenzhh3671@gmail.com>
This commit is contained in:
Zhenhuan Chen 2025-07-10 11:16:57 +08:00 committed by GitHub
parent ce048eccd3
commit d9e265d5e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 2 deletions

View File

@@ -123,7 +123,7 @@ def get_masked_input_and_mask(
 # We use torch.compile() to fuse the tiny pointwise ops before all_reduce/all_gather for Embedding module.
-@torch.compile(mode="max-autotune-no-cudagraphs")
+@torch.compile(options={"max-autotune": True})
 def pre_comm_embedding_ops(
     input_: torch.Tensor,
     weight: torch.Tensor,

View File

@@ -335,7 +335,7 @@ class Eagle3OneModelWorker(nn.Module):
         self.max_draft_tokens = self.spec_config.max_draft_tokens
         self.mapping = mapping
-    @torch.compile(mode="max-autotune-no-cudagraphs")
+    @torch.compile(options={"max-autotune": True})
     def forward(self, input_ids, position_ids, hidden_states, logits,
                 attn_metadata, spec_metadata, draft_model):
         batch_size = attn_metadata.num_seqs
batch_size = attn_metadata.num_seqs batch_size = attn_metadata.num_seqs