Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
This commit is contained in:
parent ce556290c9
commit 00f341be49
@@ -351,13 +351,14 @@ def maybe_pad_for_cuda_graph(func):
         def _call_func():
             return func(self, scheduled_requests, resource_manager, *args, **kwargs)
 
-        # check if we use cuda graph and we can run it
-        if not (self.cuda_graph_used and scheduled_requests.can_run_cuda_graph):
-            return _call_func()
+        # check conditions for current rank
+        can_run_cuda_graph = self.cuda_graph_used and scheduled_requests.can_run_cuda_graph
+        batch_size = scheduled_requests.batch_size
 
         # generate a persistent dummy request right away to ensure we can reserve the necessary
-        # resources (kv page and slot)
-        if self.padding_dummy_request is None:
+        # resources (kv page and slot) the first time we can actually run cuda graph according to
+        # this rank
+        if can_run_cuda_graph and self.padding_dummy_request is None:
             self.padding_dummy_request = _generate_dummy_request(
                 resource_manager,
                 request_id=CUDA_GRAPH_DUMMY_REQUEST_ID,
@@ -367,20 +368,48 @@ def maybe_pad_for_cuda_graph(func):
                 max_beam_width=self.max_beam_width,
             )
 
-        # check closest cuda graph batch size
-        closest_cg_bs = _round_up_to_closest(
-            self.cuda_graph_batch_sizes, scheduled_requests.batch_size
-        )
+        # check if we can pad the batch based on the availability of the dummy request
+        can_pad = self.padding_dummy_request is not None
 
-        # check if we need to pad
-        num_padding = closest_cg_bs - scheduled_requests.batch_size
+        # in attention DP mode, we check all ranks
+        if self.enable_attention_dp and self.mapping.tp_size > 1:
+            assert self.dist is not None, "Distributed object is required for attention DP mode"
+            all_rank_info = self.dist.tp_allgather([can_run_cuda_graph, can_pad, batch_size])
+        else:
+            all_rank_info = [[can_run_cuda_graph, can_pad, batch_size]]
 
-        if num_padding <= 0:
+        # now let's check if we can in principle run cuda graph across all ranks
+        can_run_cuda_graph_all = all(r_info[0] for r_info in all_rank_info)
+
+        if not can_run_cuda_graph_all:
             return _call_func()
 
-        # check if we have a dummy request to use
-        if self.padding_dummy_request is None:
-            ad_logger.info("No CUDA graph padding possible due to missing dummy request.")
+        # get closest cudagraph batch size based on max_batch_size across ALL ranks
+        # NOTE: we assume uniform cudagraph batch sizes across all ranks ensuring all ranks get the
+        # same closest cudagraph batch size here based on the max batch size across all ranks
+        max_batch_size = max(r_info[2] for r_info in all_rank_info)
+        cg_batch_size = _round_up_to_closest(self.cuda_graph_batch_sizes, max_batch_size)
+
+        if cg_batch_size is None:
             return _call_func()
 
+        # let's check if all ranks can pad the batch if they need to
+        can_pad_all = all(r_info[1] or (r_info[2] == cg_batch_size) for r_info in all_rank_info)
+
+        # fall back if we cannot run cudagraph due to padding issues
+        if not can_pad_all:
+            return _call_func()
+
+        # check actual amount of padding needed
+        num_padding = cg_batch_size - batch_size
+
+        # we should only hit this point for either of these conditions
+        assert num_padding == 0 or (num_padding > 0 and self.padding_dummy_request is not None), (
+            "Padding should not be needed or available at this point"
+        )
+
+        # no padding needed on current rank
+        if num_padding == 0:
+            return _call_func()
+
         # pad the scheduled requests with the dummy request
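For reference, a minimal, self-contained sketch of the rank-consensus padding decision the hunk above implements. The `_round_up_to_closest` helper is assumed to return the smallest captured CUDA graph batch size that fits (its real implementation is not part of this diff), and the triples mirror the allgathered `[can_run_cuda_graph, can_pad, batch_size]` per rank:

from typing import List, Optional


def _round_up_to_closest(batch_sizes: List[int], value: int) -> Optional[int]:
    # assumed behavior: smallest captured cuda graph batch size >= value, else None
    fitting = [bs for bs in sorted(batch_sizes) if bs >= value]
    return fitting[0] if fitting else None


def decide_num_padding(
    all_rank_info: List[List],
    cuda_graph_batch_sizes: List[int],
    local_batch_size: int,
) -> Optional[int]:
    # returns the number of dummy requests this rank should pad with, or None to fall
    # back to eager execution (mirrors the early-return paths in the hunk above)
    if not all(info[0] for info in all_rank_info):
        return None  # some rank cannot run cuda graph at all
    max_batch_size = max(info[2] for info in all_rank_info)
    cg_batch_size = _round_up_to_closest(cuda_graph_batch_sizes, max_batch_size)
    if cg_batch_size is None:
        return None  # batch too large for any captured graph
    if not all(info[1] or info[2] == cg_batch_size for info in all_rank_info):
        return None  # some rank would need padding but has no dummy request yet
    return cg_batch_size - local_batch_size


# two attention-DP ranks with batch sizes 3 and 5, graphs captured for [1, 2, 4, 8]:
# both ranks pad up to 8, so rank 0 adds 5 dummy requests and rank 1 adds 3
assert decide_num_padding([[True, True, 3], [True, True, 5]], [1, 2, 4, 8], 3) == 5
assert decide_num_padding([[True, True, 3], [True, True, 5]], [1, 2, 4, 8], 5) == 3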
@@ -411,7 +440,12 @@ class ADEngine(ModelEngine):
         return self.cache_seq_interface.device
 
     @classmethod
-    def build_from_config(cls, ad_config: LlmArgs, mapping: Optional[Mapping] = None):
+    def build_from_config(
+        cls,
+        ad_config: LlmArgs,
+        mapping: Optional[Mapping] = None,
+        dist: Optional[Distributed] = None,
+    ):
         """Build the ADEngine using the LlmArgs that gets passed through from the LLM."""
 
         max_batch_size = ad_config.max_batch_size
@@ -453,6 +487,7 @@ class ADEngine(ModelEngine):
             device,
             ad_config=ad_config,
             mapping=mapping,
+            dist=dist,
             reporting_info=reporting_info,
         )
 
@@ -464,6 +499,7 @@ class ADEngine(ModelEngine):
         device: DeviceLikeType,
         ad_config: Optional[LlmArgs] = None,
         mapping: Optional[Mapping] = None,
+        dist: Optional[Distributed] = None,
         reporting_info: ReportingInfo = ReportingInfo(),
     ) -> None:
         """Initialize the engine with model and sequence information."""
@@ -484,7 +520,7 @@ class ADEngine(ModelEngine):
         self.iter_states = {}
 
         # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
-        self.enable_attention_dp = False
+        self.enable_attention_dp = mapping.enable_attention_dp if mapping else False
 
         if ad_config is not None:
             self.max_beam_width = ad_config.max_beam_width
@@ -537,6 +573,7 @@ class ADEngine(ModelEngine):
 
         # Reuse _execute_logit_post_processors from PyTorchModelEngine
         self.mapping = mapping
+        self.dist = dist
         self._execute_logit_post_processors = types.MethodType(
             PyTorchModelEngine._execute_logit_post_processors, self
         )
@@ -1005,13 +1042,23 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     # initialize process groups
     world_size = mpi_world_size()
    rank = mpi_rank()
-    dist_mapping = Mapping(rank=rank, world_size=world_size, tp_size=world_size)
+    enable_attention_dp = ad_config.transforms.get("detect_sharding", {}).get(
+        "enable_attention_dp", False
+    )
+    dist_mapping = Mapping(
+        rank=rank,
+        world_size=world_size,
+        tp_size=world_size,
+        enable_attention_dp=enable_attention_dp,
+    )
     dist = Distributed.get(dist_mapping)
     ad_logger.set_rank(rank)
     torch.cuda.set_device(rank)
     port = dist.broadcast(get_free_port())  # use MPI broadcast to pick a free port
     initialize_or_skip(rank, world_size, port)
 
+    ad_logger.info(f"{dist_mapping=}, {dist=}, {port=}")
+
     # Setup AutoTuner with distributed state for allreduce autotuning
     AutoTuner.get().setup_distributed_state(dist_mapping)
 
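For orientation, a small sketch of how the attention-DP flag flows from the transforms section into the process-group mapping built above. The dict shape mirrors the accuracy test further down; `Mapping` is replaced by a dataclass stand-in because constructing the real one requires the full runtime (an assumption for illustration only):

from dataclasses import dataclass


@dataclass
class _MappingSketch:
    # stand-in for tensorrt_llm's Mapping; only the fields touched in this diff
    rank: int
    world_size: int
    tp_size: int
    enable_attention_dp: bool = False


# transforms section as it would appear in ad_config (shape mirrors the test below)
transforms = {"detect_sharding": {"enable_attention_dp": True}}
enable_attention_dp = transforms.get("detect_sharding", {}).get("enable_attention_dp", False)

dist_mapping = _MappingSketch(
    rank=0, world_size=2, tp_size=2, enable_attention_dp=enable_attention_dp
)
assert dist_mapping.enable_attention_dp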
@@ -1030,7 +1077,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     )
 
     # initialize model engine
-    engine = ADEngine.build_from_config(ad_config=ad_config, mapping=dist_mapping)
+    engine = ADEngine.build_from_config(ad_config=ad_config, mapping=dist_mapping, dist=dist)
 
     spec_config = ad_config.speculative_config
     if spec_config is not None and not (
@@ -151,6 +151,11 @@ class ShardingTransformConfig(TransformConfig):
 
     process_grid: Dict[ShardingDim, int] = Field(default_factory=dict)
 
+    enable_attention_dp: bool = Field(
+        default=False,
+        description="When True, skip TP sharding as attention data parallelism is enabled.",
+    )
+
     def validate_config(self, sources: Union[ShardingSource, List[ShardingSource]] = None) -> bool:
         init_process_grid_from_config(self)
         if sources is None:
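A quick sketch of the new config field in isolation, using a simplified Pydantic model in place of the full ShardingTransformConfig (only this field is shown; everything else is omitted by assumption):

from pydantic import BaseModel, Field


class _ShardingConfigSketch(BaseModel):
    # simplified stand-in for ShardingTransformConfig showing only the new field
    enable_attention_dp: bool = Field(
        default=False,
        description="When True, skip TP sharding as attention data parallelism is enabled.",
    )


assert _ShardingConfigSketch().enable_attention_dp is False
assert _ShardingConfigSketch(enable_attention_dp=True).enable_attention_dp is True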
@@ -738,8 +743,9 @@ class Sharding(BaseTransform):
             f"Using allreduce strategy: {config.allreduce_strategy.name}, dist backend: {config.dist_backend}"
         )
 
-        if world_size < 2:
-            ad_logger.info("Skipping sharding for single device")
+        if world_size < 2 or config.enable_attention_dp:
+            reason = "single device" if world_size < 2 else "attention DP enabled"
+            ad_logger.info(f"Skipping sharding: {reason}")
             return gm, TransformInfo(
                 skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
             )
@@ -81,6 +81,24 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm, sampling_params=sampling_params)
 
+    @pytest.mark.skip_less_device_memory(32000)
+    @pytest.mark.skip_less_device(2)
+    @pytest.mark.parametrize("world_size", [2, 4])
+    def test_attention_dp(self, world_size):
+        """Test attention data parallelism mode where TP sharding is disabled."""
+        kwargs = self.get_default_kwargs(enable_chunked_prefill=True)
+        # Enable attention DP - this disables TP sharding
+        kwargs["transforms"]["detect_sharding"] = {"enable_attention_dp": True}
+        sampling_params = self.get_default_sampling_params()
+        with AutoDeployLLM(model=self.MODEL_PATH,
+                           tokenizer=self.MODEL_PATH,
+                           world_size=world_size,
+                           **kwargs) as llm:
+            task = CnnDailymail(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm, sampling_params=sampling_params)
+
 
 class TestNemotronH(LlmapiAccuracyTestHarness):
     MODEL_NAME = "nvidia/Nemotron-H-8B-Base-8K"
@@ -333,3 +333,4 @@ l0_dgx_h100:
   - accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_bf16
   - accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_fp8[4]
   - accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_fp8[8]
+  - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_attention_dp[4]
@@ -1,4 +1,3 @@
-from types import SimpleNamespace
 from typing import List, Optional, Type
 
 import pytest
@@ -9,6 +8,7 @@ from tensorrt_llm import SamplingParams
 from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo
 from tensorrt_llm._torch.auto_deploy.shim.ad_executor import ADEngine
 from tensorrt_llm._torch.auto_deploy.shim.demollm import DemoEngine
+from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
 
 
 class TransformerLikeModelwithFakeCachePool(nn.Module):
@@ -196,19 +196,22 @@ def test_ad_engine_chunked_prefill_equivalence(attn_page_size: int):
 
     # No-chunk: whole prompt in one request
     req_full = _DummyRequest(tokens=tokens, begin=0, size=len(tokens), seq_slot=0)
-    scheduled_full = SimpleNamespace(context_requests=[req_full], generation_requests=[])
-    logits_full_last = engine.forward(scheduled_full, resource_manager)["logits"][-1]
+    scheduled_requests = ScheduledRequests()
+    scheduled_requests.context_requests.append(req_full)
+    logits_full_last = engine.forward(scheduled_requests, resource_manager)["logits"][-1]
 
     # Chunked: split into two context chunks
     split = len(tokens) // 2
     req_part1 = _DummyRequest(tokens=tokens, begin=0, size=split, seq_slot=0)
     req_part2 = _DummyRequest(tokens=tokens, begin=split, size=len(tokens) - split, seq_slot=0)
 
-    scheduled_part1 = SimpleNamespace(context_requests=[req_part1], generation_requests=[])
-    scheduled_part2 = SimpleNamespace(context_requests=[req_part2], generation_requests=[])
+    scheduled_requests_part1 = ScheduledRequests()
+    scheduled_requests_part1.context_requests.append(req_part1)
+    scheduled_requests_part2 = ScheduledRequests()
+    scheduled_requests_part2.context_requests.append(req_part2)
 
     # Run first chunk (ignored output), then compare second chunk logits to full
-    _ = engine.forward(scheduled_part1, resource_manager)
-    logits_chunked_last = engine.forward(scheduled_part2, resource_manager)["logits"][-1]
+    _ = engine.forward(scheduled_requests_part1, resource_manager)
+    logits_chunked_last = engine.forward(scheduled_requests_part2, resource_manager)["logits"][-1]
 
     torch.testing.assert_close(logits_full_last, logits_chunked_last)  # , atol=1e-5)