[TRTLLM-10030][perf] avoid sync in PyTorchModelEngine when using beam search (#11341)

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
mpikulski 2026-02-07 05:31:11 +01:00 committed by GitHub
parent ffc0f54959
commit 03b38e9fbf
2 changed files with 4 additions and 1 deletion


@@ -2714,7 +2714,8 @@ class PyTorchModelEngine(ModelEngine):
             # Copy cache indirection to local buffer with offsets changing: seq_slots[i] -> i
             # Convert to GPU tensor to avoid implicit sync
             gen_request_seq_slots_tensor = torch.tensor(
-                gen_request_seq_slots, dtype=torch.long, device='cuda')
+                gen_request_seq_slots, dtype=torch.long,
+                pin_memory=True).to(device='cuda', non_blocking=True)
             self.cache_indirection_attention[:num_generation_requests].copy_(
                 cache_indirection_buffer[gen_request_seq_slots_tensor])
         if cache_indirection_buffer is not None or is_cuda_graph_during_warmup:
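The change swaps a blocking construction of the index tensor for a pinned-memory staging buffer plus an asynchronous host-to-device copy. Building a CUDA tensor directly from a Python list stages it through pageable host memory and generally stalls the host until the copy completes, whereas a tensor created with pin_memory=True can be moved with non_blocking=True so the host keeps queueing work. A minimal standalone sketch of the two variants (seq_slots here is an illustrative list, not the engine's actual data):

    import torch

    seq_slots = [3, 0, 7, 2]  # illustrative per-request sequence slots

    # Blocking variant: the list is staged through pageable host memory and the
    # host waits for the host-to-device copy before returning.
    slots_sync = torch.tensor(seq_slots, dtype=torch.long, device='cuda')

    # Non-blocking variant matching the change above: create the tensor in pinned
    # (page-locked) host memory, then issue an asynchronous copy to the GPU.
    slots_async = torch.tensor(seq_slots, dtype=torch.long,
                               pin_memory=True).to(device='cuda', non_blocking=True)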


@@ -898,6 +898,8 @@ class AsyncWorkerMixin:
 class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
     DEFAULT_MAX_TOPK_LOGPROBS = 20
     SampleState = SampleStateTorch
 
+    @override
+    def get_cache_indirection(self) -> torch.Tensor | None:
+        return self.store.cache_indirection
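Read together, the two changes suggest a flow where the model engine asks the sampler for its cache indirection buffer and gathers the rows belonging to the active generation requests, using the non-blocking transfer shown above to avoid a host/device sync. A hedged sketch of that flow, in which the helper name gather_cache_indirection, its arguments, and the buffer shapes are assumptions for illustration rather than the actual TensorRT-LLM interfaces:

    import torch

    def gather_cache_indirection(sampler, seq_slots, local_buffer):
        # Sampler-owned beam-search indirection buffer, e.g. shaped
        # [num_slots, beam_width, max_seq_len]; None when beam search is off.
        cache_indirection = sampler.get_cache_indirection()
        if cache_indirection is None:
            return
        # Stage the slot indices through pinned memory and copy asynchronously,
        # mirroring the engine-side change above.
        slots_gpu = torch.tensor(seq_slots, dtype=torch.long,
                                 pin_memory=True).to(device='cuda', non_blocking=True)
        # Remap rows so that seq_slots[i] lands at position i in the local buffer.
        local_buffer[:len(seq_slots)].copy_(cache_indirection[slots_gpu])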