Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-16 15:55:08 +08:00)
[TRTLLM-10030][perf] avoid sync in PyTorchModelEngine when using beam search (#11341)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
parent: ffc0f54959
commit: 03b38e9fbf
```diff
@@ -2714,7 +2714,8 @@ class PyTorchModelEngine(ModelEngine):
             #Copy cache indirection to local buffer with offsets changing: seq_slots[i] -> i
+            # Convert to GPU tensor to avoid implicit sync
             gen_request_seq_slots_tensor = torch.tensor(
-                gen_request_seq_slots, dtype=torch.long, device='cuda')
+                gen_request_seq_slots, dtype=torch.long,
+                pin_memory=True).to(device='cuda', non_blocking=True)
             self.cache_indirection_attention[:num_generation_requests].copy_(
                 cache_indirection_buffer[gen_request_seq_slots_tensor])
         if cache_indirection_buffer is not None or is_cuda_graph_during_warmup:
```
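The change replaces `torch.tensor(..., device='cuda')`, which stages the Python list through pageable host memory and blocks the host until the copy finishes, with a pinned-memory staging tensor plus a `non_blocking=True` host-to-device copy. Below is a minimal, self-contained sketch of that pattern; the function and buffer names are illustrative stand-ins, not the actual TensorRT-LLM code.

```python
import torch

def gather_rows_async(gpu_buffer: torch.Tensor, host_indices: list[int]) -> torch.Tensor:
    # Avoided path: torch.tensor(host_indices, dtype=torch.long, device='cuda')
    # copies from pageable host memory and blocks the host until done.
    # Instead, stage the indices in page-locked (pinned) host memory...
    idx = torch.tensor(host_indices, dtype=torch.long, pin_memory=True)
    # ...and enqueue an asynchronous host-to-device copy on the current
    # CUDA stream; the host does not wait for it to complete.
    idx = idx.to(device='cuda', non_blocking=True)
    # Indexing with a device-resident index tensor keeps the gather on the
    # GPU, so no implicit device synchronization is triggered.
    return gpu_buffer.index_select(0, idx)

if torch.cuda.is_available():
    # Hypothetical stand-in for the cache indirection buffer.
    cache_indirection_buffer = torch.arange(32, device='cuda').reshape(8, 4)
    local_view = gather_rows_async(cache_indirection_buffer, [3, 1, 5])
    torch.cuda.synchronize()  # demo only: wait before printing the result
    print(local_view)
```

Because the index tensor ends up resident on the GPU, the subsequent advanced indexing into the buffer also stays on-device, and the enqueued work is ordered by the CUDA stream rather than by a host-side wait.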
```diff
@@ -898,6 +898,8 @@ class AsyncWorkerMixin:
 class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
     DEFAULT_MAX_TOPK_LOGPROBS = 20

     SampleState = SampleStateTorch

+    @override
+    def get_cache_indirection(self) -> torch.Tensor | None:
+        return self.store.cache_indirection
+
```
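The `get_cache_indirection()` accessor lets the model engine fetch the sampler-owned beam-search buffer instead of holding its own copy. A simplified sketch of that accessor pattern follows; the store class, constructor parameters, and buffer shape are assumptions for illustration, not the real TensorRT-LLM types.

```python
from dataclasses import dataclass

import torch

@dataclass
class _Store:
    # Stand-in for the sampler's internal state container.
    cache_indirection: torch.Tensor | None = None

class _SamplerSketch:
    def __init__(self, max_seqs: int, beam_width: int, max_len: int):
        # Illustrative shape: one indirection entry per
        # (sequence slot, beam, position); the real layout may differ.
        self.store = _Store(cache_indirection=torch.zeros(
            max_seqs, beam_width, max_len, dtype=torch.int32))

    def get_cache_indirection(self) -> torch.Tensor | None:
        # Expose the sampler-owned buffer to callers such as the engine.
        return self.store.cache_indirection

# Engine-side consumer: fetch the buffer and branch on availability,
# mirroring the `if cache_indirection_buffer is not None` check above.
sampler = _SamplerSketch(max_seqs=4, beam_width=2, max_len=16)
cache_indirection_buffer = sampler.get_cache_indirection()
if cache_indirection_buffer is not None:
    print(cache_indirection_buffer.shape)  # torch.Size([4, 2, 16])
```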