TensorRT-LLMs/tensorrt_llm/_torch/pipeline_interface.py
Dom Brown 8709fe8b53
chore: bump version to 0.19.0 (#3598) (#3841)
test: add test cases for 0.19 release (#3608)

* fix test name
* add quickstart test for nemotron-ultra
* add rcca multi-node test case for deepseek-v3
* add rcca info

squash (#3642)

fix: nvbugs/5187237: fix deterministic mode crash (#3448)

* nvbugs/5187237 nvbugs/5112075: fix deterministic mode error
* remove waive
* Revert "remove waive"
  (This reverts commit 0bf5486d19906d692bfb7a6262333c296b0087ac.)
* revert ar fusion

update fp8 doc (#3647)

tests: change qa perf test to trtllm-bench (#3619)

fix: FP8 quantized lm_head (NvBug 5214229) (#3567)

infra: Add PR approval protection for the release branch (#3634)

fix: nvbugs/5231298: pytorch allreduce issue (#3673)

Fix: nvbugs/5222698 variable not defined (#3630)

* Fix: nvbugs/5222698 variable not defined
* Tidy code

test: sync waives.txt from main branch by disabling the test_perf/gpt_350m-cppmanager case (#3685)

test: restore fp8 kv cache testing for L0 (#3671)

doc: Update DeepSeek perf docs (#3693)

* Update DeepSeek perf docs
* update
* Apply suggestions from code review

tests: waive test_llm_multi_node (#3664)

fix: update test_user_buffers_mm_add_prologue atol (#3711)

Fix: cherry-pick hmac encryption from main branch (#3635)

* security fix: cherry-pick changes from main
* fix hmac in remote mpi session (#3649)

Un-waive DS-V3-Lite tests (#3621)

fix: FP8 kv accuracy (#3675)

* fix FP8 kv accuracy
* update doc

Fix script options for engines (#3622)

unwaive multi-node test (#3721)

chore: Split more tests out of gpt tests (#3524) (#3674)

doc: add torch examples link into torch backend documentation (#3749)

test: Get Eagle tests working (#3593) (#3722)

Waive L0 test (#3756)

waive failed case in perf test, change default max_batch_size to 512, and write config.json to the output log (#3656)

Update DS-V3 parameters in stress test (#3676)

waive gemma on L20 (#3766)

https://nvbugs/5141291: Fix convert.py script for Qwen model (#3758)

* Include Qwen2VLDecoderLayer in the smooth_qwen2_model function.

fix: PP4 fixes and cleanup (#3688)

remove benchmark test list (#3643)

skip disagg deepseek test if sm != 90 (#3720)

test: skip failed cases on B200 (#3710)

* add skip condition to tests
* fix error

test: [nvbug: 5234494] skip_pre_ada for fp8 cases (#3718)

* skip_pre_ada for fp8 cases
* update
* update after rebase

add known issue to deepseek doc (#3800)

Fix ModelOpt Mixtral AWQ OOM (#3714) (#3761)

Waive L0 tests (#3826)

fix: Reduce memory usage in the fused MoE op associated with AutoTuning and fix the MoE fallback issue (#3793)

* Reduce memory usage in the fused MoE op associated with AutoTuning.
* Replace the pre-defined bucket size strategy with a generating function based on tune_max_num_tokens.
* Add free-memory logic for the workspace in the min_latency_mode fused MoE path.
* Fix fused_moe fallback issue (#3652): min_latency_mode is only ever set to False during the warmup phase, so when it becomes True during inference, every tactic lookup falls back to the default tactic, causing a perf regression (sketched below).
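The fallback failure mode described in the last bullet is easy to see in miniature. Below is a minimal sketch, assuming a tactic cache keyed on (num_tokens, min_latency_mode); gen_buckets, TacticCache, and the tactic names are hypothetical illustrations, not the actual TensorRT-LLM internals.

def gen_buckets(tune_max_num_tokens: int):
    """Hypothetical bucket generator: powers of two up to the tuning cap,
    standing in for the generating function that replaced the fixed list."""
    n = 1
    while n <= tune_max_num_tokens:
        yield n
        n *= 2

DEFAULT_TACTIC = 'default'

class TacticCache:
    def __init__(self):
        self._best = {}  # (num_tokens, min_latency_mode) -> tactic name

    def warmup(self, tune_max_num_tokens: int):
        # Warmup only ever profiles with min_latency_mode=False, so no
        # (*, True) entry is ever cached.
        for n in gen_buckets(tune_max_num_tokens):
            self._best[(n, False)] = f'tuned_for_{n}'

    def lookup(self, num_tokens: int, min_latency_mode: bool) -> str:
        # With min_latency_mode=True at inference time, every lookup misses
        # and silently falls back to the default tactic: the perf regression.
        return self._best.get((num_tokens, min_latency_mode), DEFAULT_TACTIC)

cache = TacticCache()
cache.warmup(tune_max_num_tokens=256)
assert cache.lookup(128, False) == 'tuned_for_128'
assert cache.lookup(128, True) == DEFAULT_TACTIC  # the regression before the fix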



[doc] Better documentation for Draft-Target-Model (DTM) speculative decoding (#3797)

Fix pre-commit

Fix again

Address some review comments for the MI

Signed-off-by: Dom Brown <3886319+DomBrown@users.noreply.github.com>
Co-authored-by: Zhanrui Sun <184402041+ZhanruiSunCh@users.noreply.github.com>
2025-04-29 16:57:22 +08:00


from typing import Optional, Union

import torch

from .distributed import PPComm


class PipelineInterface:
    """
    A container class for passing intermediate tensors between pipeline parallel ranks.

    It contains two intermediate tensors, [hidden_states, residual], supporting:
      - Dict access: pp['hidden_states'], pp['residual']
      - Unpacking: hidden, residual = pp
      - PP communication: pp.send(), pp.recv()
      - Slicing: pp[start:end]

    Note: When using this interface in pipeline parallelism, the packing/unpacking
    and send/recv operations must be used symmetrically within a stage and between
    successive ranks.
    """

    # Shared communicator, initialized once per process via init_pp_comm().
    _pp_comm = None

    def __init__(self,
                 hidden_states: Optional[torch.Tensor] = None,
                 residual: Optional[torch.Tensor] = None):
        self.hidden_states = hidden_states
        self.residual = residual
        # Fixed tag that matches send/recv pairs between adjacent ranks.
        self.tag = 1234

    @classmethod
    def init_pp_comm(cls, mapping):
        """Initialize PPComm once at startup."""
        cls._pp_comm = PPComm(mapping)

    def __getitem__(self, key: Union[str, slice]):
        if isinstance(key, str):
            if key == 'hidden_states':
                return self.hidden_states
            elif key == 'residual':
                return self.residual
            raise KeyError(f"Unknown key: {key}")
        elif isinstance(key, slice):
            # Slicing returns a new PipelineInterface over views of the tensors.
            return PipelineInterface(
                hidden_states=self.hidden_states[key]
                if self.hidden_states is not None else None,
                residual=self.residual[key]
                if self.residual is not None else None)

    def __setitem__(self, key: Union[str, slice], value: torch.Tensor):
        if isinstance(key, str):
            if key == 'hidden_states':
                self.hidden_states = value
            elif key == 'residual':
                self.residual = value
            else:
                raise KeyError(f"Unknown key: {key}")
        elif isinstance(key, slice):
            if self.hidden_states is not None:
                self.hidden_states[key] = value
            if self.residual is not None:
                self.residual[key] = value

    def __iter__(self):
        # Enables unpacking: hidden_states, residual = pp
        return iter((self.hidden_states, self.residual))

    def recv(self):
        """Receive tensors from the previous rank into pre-allocated buffers."""
        if self.hidden_states is not None:
            self._pp_comm.recv(self.hidden_states, tag=self.tag)
        if self.residual is not None:
            self._pp_comm.recv(self.residual, tag=self.tag)

    def send(self):
        """Send tensors to the next rank."""
        # pp_comm.send returns once the NCCL send kernel is enqueued. The event
        # sync waits until the previous send's kernel finishes, preventing an
        # earlier PP rank from running multiple microbatches ahead of a later rank.
        self._pp_comm.send_event.synchronize()
        if self.hidden_states is not None:
            self._pp_comm.send(self.hidden_states, tag=self.tag)
        if self.residual is not None:
            self._pp_comm.send(self.residual, tag=self.tag)
        self._pp_comm.send_event.record()
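
For reference, a minimal usage sketch of the class above, as it might run on an intermediate PP rank. `mapping`, `num_tokens`, and `hidden_size` are assumed to come from the surrounding model/runtime and are illustrative, not part of this module.

import torch

# Assumed: `mapping` describes the pipeline topology (e.g. a tensorrt_llm
# Mapping), and this rank both receives from and sends to a neighbor.
PipelineInterface.init_pp_comm(mapping)  # once at startup, on every rank

num_tokens, hidden_size = 256, 4096  # illustrative shapes

# recv() fills pre-allocated buffers, so their shapes must match what the
# previous stage sends -- the symmetry the class docstring warns about.
pp = PipelineInterface(
    hidden_states=torch.empty(num_tokens, hidden_size, device='cuda'),
    residual=torch.empty(num_tokens, hidden_size, device='cuda'),
)

pp.recv()                      # from the previous PP rank
hidden_states, residual = pp   # unpacking via __iter__
chunk = pp[:num_tokens // 2]   # slicing returns a new PipelineInterface

# ... run this stage's layers on hidden_states/residual, then store results ...
pp['hidden_states'] = hidden_states
pp['residual'] = residual

pp.send()                      # to the next PP rank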