[TRTLLM-1316] refactor: Remove unnecessary pipeline parallelism logic from postProcessRequest (#5489)

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
Robin Kobus 2025-07-02 10:13:31 +02:00 committed by GitHub
parent ca7b6ec8d8
commit 4cd8543d8c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 146 additions and 96 deletions


@@ -1842,40 +1842,6 @@ void TrtGptModelInflightBatching::postProcessRequest(
bufferManager.getStream().synchronize();
}
if (mWorldConfig.isPipelineParallel())
{
// Send context logits from last to first PP rank
if (llmReq.getReturnContextLogits())
{
if (mWorldConfig.isLastPipelineParallelRank())
{
mMpiCommPipelinePara->send(
*(llmReq.getContextLogitsHost()), 0, mpi::MpiTag::kTrtGptModelInflightBatchingContextLogits);
}
else if (mWorldConfig.isFirstPipelineParallelRank())
{
mMpiCommPipelinePara->recv(*(llmReq.getContextLogitsHost()), mWorldConfig.getPipelineParallelism() - 1,
mpi::MpiTag::kTrtGptModelInflightBatchingContextLogits);
}
}
// Send generation logits from last to first PP rank
if (llmReq.getReturnGenerationLogits())
{
if (mWorldConfig.isLastPipelineParallelRank())
{
mMpiCommPipelinePara->send(
*(llmReq.getGenerationLogitsHost()), 0, mpi::MpiTag::kTrtGptModelInflightBatchingGenerationLogits);
}
else if (mWorldConfig.isFirstPipelineParallelRank())
{
mMpiCommPipelinePara->recv(*(llmReq.getGenerationLogitsHost()),
mWorldConfig.getPipelineParallelism() - 1,
mpi::MpiTag::kTrtGptModelInflightBatchingGenerationLogits);
}
}
}
if (reqBeamWidth == 1)
{
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);


@@ -1901,7 +1901,7 @@ namespace
void runTest(Executor& executor, fs::path const& inputPath, ModelIds const& modelIds,
FlakyTestInfo const& flakyTestInfo, bool streaming, SizeType32 const vocabSizePadded, BeamResult const& beamResult,
OutputConfig const& outConfig, bool isSpeculativeDecoding, int maxWaitMs, bool returnAllGeneratedTokens,
SizeType32 const numReturnSequences, bool isNonGreedySampling)
SizeType32 const numReturnSequences, bool isNonGreedySampling, SizeType32 const modelParallelism)
{
auto const beamWidth = beamResult.beamWidth;
@@ -1948,7 +1948,6 @@ void runTest(Executor& executor, fs::path const& inputPath, ModelIds const& mode
auto& comm = tensorrt_llm::mpi::MpiComm::world();
auto const worldRank = comm.getRank();
auto const worldSize = comm.getSize();
// Expected return sizes.
auto const numSequences = beamWidth > 1 ? 1 : numReturnSequences;
@@ -2021,15 +2020,18 @@ void runTest(Executor& executor, fs::path const& inputPath, ModelIds const& mode
if (!isNonGreedySampling)
{
float const logitsAtol = modelParallelism > 1 ? 1e-1 : 1e-2;
float const logitsRtol = modelParallelism > 1 ? 1e-2 : 1e-3;
testData.verifyLogProbs(outConfig.returnLogProbs, streaming, outConfig.excludeInputFromOutput,
givenInputLengths.at(batchId), beamWidth, beamTokens, cumLogProbs, logProbs, batchId,
flakyTestInfo);
testData.validateContextLogits(outConfig.returnContextLogits, givenInputLengths.at(batchId),
beamWidth, contextLogits, vocabSizePadded, batchId);
beamWidth, contextLogits, vocabSizePadded, batchId, logitsAtol, logitsRtol);
testData.validateGenerationLogits(outConfig.returnGenerationLogits, result.isSequenceFinal,
streaming, outConfig.excludeInputFromOutput, givenInputLengths.at(batchId),
reqMaxNewTokens.at(batchId), beamWidth, beamTokens, genLogits, vocabSizePadded, batchId,
returnAllGeneratedTokens);
returnAllGeneratedTokens, logitsAtol, logitsRtol);
}
// Ignore first iteration as it doesn't use draft tokens
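The relaxed tolerances above feed into compareLogits through the new atol/rtol parameters. As a minimal sketch of what such a mixed absolute/relative check could look like (the helper name approxEqualLogits and the element-wise formula are assumptions for illustration; compareLogits itself is not shown in this diff):

#include <cmath>
#include <cstddef>
#include <vector>

// Sketch only: element-wise |actual - expected| <= atol + rtol * |expected|.
bool approxEqualLogits(
    std::vector<float> const& expected, std::vector<float> const& actual, float atol, float rtol)
{
    if (expected.size() != actual.size())
    {
        return false;
    }
    for (std::size_t i = 0; i < expected.size(); ++i)
    {
        if (std::abs(actual[i] - expected[i]) > atol + rtol * std::abs(expected[i]))
        {
            return false;
        }
    }
    return true;
}

// Usage mirroring the test: looser tolerances when modelParallelism > 1.
// approxEqualLogits(reference, output, /*atol=*/1e-1f, /*rtol=*/1e-2f);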
@@ -2063,12 +2065,14 @@ void runTest(Executor& executor, fs::path const& inputPath, ModelIds const& mode
void runTest(fs::path const& modelPath, ExecutorConfig const& executorConfig, fs::path const& inputPath,
ModelIds const& modelIds, FlakyTestInfo const& flakyTestInfo, bool streaming, SizeType32 const vocabSizePadded,
BeamResult const& beamResult, OutputConfig const& outConfig, bool isSpeculativeDecoding, int maxWaitMs,
bool returnAllGeneratedTokens, SizeType32 const numReturnSequences, bool isNonGreedySampling)
bool returnAllGeneratedTokens, SizeType32 const numReturnSequences, bool isNonGreedySampling,
SizeType32 const modelParallelism)
{
auto executor = Executor{modelPath, ModelType::kDECODER_ONLY, executorConfig};
runTest(executor, inputPath, modelIds, flakyTestInfo, streaming, vocabSizePadded, beamResult, outConfig,
isSpeculativeDecoding, maxWaitMs, returnAllGeneratedTokens, numReturnSequences, isNonGreedySampling);
isSpeculativeDecoding, maxWaitMs, returnAllGeneratedTokens, numReturnSequences, isNonGreedySampling,
modelParallelism);
}
ExecutorConfig createExecutorConfig(SizeType32 maxBeamWidth, bool useOrchestratorMode, bool gatherGenerationLogits,
@@ -2193,6 +2197,12 @@ TEST_P(AllParamsTest, TokenComparison)
beamResult.resultsFile = resultsPath / PathUtil::FP16_PLUGIN_PACKED_PAGED_RESULT_TP2_PP2_FILE();
modelPath = LLAMA_MODEL_PATH / PathUtil::FP16_GPT_ATTENTION_PACKED_PAGED_DIR() / "tp2-pp2-cp1-gpu";
}
beamResult.genLogitsFile = resultsPath / PathUtil::FP16_PLUGIN_PACKED_PAGED_GENERATION_LOGITS_TP4_PP1_FILE();
if (outConfig.returnLogProbs)
{
beamResult.cumLogProbsFile = resultsPath / PathUtil::FP16_PLUGIN_PACKED_PAGED_CUM_LOG_PROBS_TP4_PP1_FILE();
beamResult.logProbsFile = resultsPath / PathUtil::FP16_PLUGIN_PACKED_PAGED_LOG_PROBS_TP4_PP1_FILE();
}
}
else if (modelName == "medusa")
{
@@ -2283,9 +2293,9 @@ TEST_P(AllParamsTest, TokenComparison)
GTEST_SKIP() << "Skipping Llama test";
}
if (outConfig.returnLogProbs || outConfig.returnContextLogits || outConfig.returnGenerationLogits)
if (outConfig.returnContextLogits)
{
GTEST_SKIP() << "Skipping logits and log probs tests for mpi runs";
GTEST_SKIP() << "Skipping context logits tests for mpi runs";
}
// Check that it was launched with right number of MPI ranks
@@ -2302,11 +2312,12 @@ TEST_P(AllParamsTest, TokenComparison)
}
auto decoderJsonConfig = tensorrt_llm::runtime::GptJsonConfig::parse(modelPath / "config.json");
auto modelTP = decoderJsonConfig.getTensorParallelism();
auto modelPP = decoderJsonConfig.getPipelineParallelism();
auto const modelTP = decoderJsonConfig.getTensorParallelism();
auto const modelPP = decoderJsonConfig.getPipelineParallelism();
auto const modelParallelism = modelTP * modelPP;
int deviceCount = -1;
TLLM_CUDA_CHECK(cudaGetDeviceCount(&deviceCount));
std::optional<std::vector<SizeType32>> deviceIds = std::vector<SizeType32>(modelTP * modelPP);
std::optional<std::vector<SizeType32>> deviceIds = std::vector<SizeType32>(modelParallelism);
for (auto i = 0; i < deviceIds->size(); i++)
{
deviceIds->at(i) = i % deviceCount;
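For concreteness, a standalone sketch of the modulo device mapping above, wrapping modelParallelism ranks onto the visible GPUs (four ranks on two devices is an assumed example, not output from the test):

#include <cstdio>
#include <vector>

int main()
{
    // Assumed example values: four model ranks (e.g. tp2-pp2) on two visible GPUs.
    int const modelParallelism = 4;
    int const deviceCount = 2;

    std::vector<int> deviceIds(modelParallelism);
    for (int i = 0; i < modelParallelism; ++i)
    {
        deviceIds[i] = i % deviceCount; // wrap ranks onto devices: 0, 1, 0, 1
    }

    for (int id : deviceIds)
    {
        std::printf("%d ", id); // prints: 0 1 0 1
    }
    std::printf("\n");
    return 0;
}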
@@ -2354,7 +2365,8 @@ TEST_P(AllParamsTest, TokenComparison)
std::move(deviceIds), std::move(participantIds));
runTest(modelPath, executorConfig, inputPath, modelIds, flakyTestInfo, streaming, vocabSizePadded, beamResult,
outConfig, isSpeculativeDecoding, mMaxWaitMs, returnAllGeneratedTokens, numReturnSequences, false);
outConfig, isSpeculativeDecoding, mMaxWaitMs, returnAllGeneratedTokens, numReturnSequences, false,
modelParallelism);
}
TEST_F(GptExecutorTest, ChangeBeamWidth)
@@ -2455,7 +2467,7 @@ void doTokenComparisonChangeBeamWidth(bool enableReuse, SizeType32 maxWaitMs)
beamResult.genLogitsFile = resultsPath / PathUtil::FP16_PLUGIN_PACKED_PAGED_GENERATION_LOGITS_FILE();
runTest(executor, inputPath, modelIds, flakyTestInfo, streaming, vocabSizePadded, beamResult, outConfig,
isSpeculativeDecoding, maxWaitMs, false, 1, false);
isSpeculativeDecoding, maxWaitMs, false, 1, false, 1);
}
}
@@ -2497,7 +2509,7 @@ TEST_F(GptExecutorTest, NReturnRandomness)
beamResult.genLogitsFile = resultsPath / PathUtil::FP16_PLUGIN_PACKED_PAGED_GENERATION_LOGITS_FILE();
runTest(executor, inputPath, modelIds, flakyTestInfo, streaming, vocabSizePadded, beamResult, outConfig,
isSpeculativeDecoding, mMaxWaitMs, false, 1, true);
isSpeculativeDecoding, mMaxWaitMs, false, 1, true, 1);
}
TEST_F(GptExecutorTest, TimedOut)
@@ -4526,10 +4538,10 @@ INSTANTIATE_TEST_SUITE_P(LlamaExecutorTest, AllParamsTest,
testing::Combine( //
testing::Values(false, true), // streaming
testing::Values(1, 2), // beamWidth
testing::Values(false), // computeLogProbs
testing::Values(true), // computeLogProbs
testing::Values(false, true), // excludeInputInOutput
testing::Values(false), // returnContextLogits
testing::Values(false), // returnGenerationLogits
testing::Values(true), // returnGenerationLogits
testing::Values("llama_tp1_pp4_cp1", "llama_tp4_pp1_cp1", "llama_tp2_pp2_cp1"), // modelName
testing::Values(false, true), // useOrchestratorMode
testing::Values(false), // returnAllGeneratedTokens


@@ -18,6 +18,7 @@ import argparse as _arg
import pathlib as _pl
import platform as _pf
import sys as _sys
import time
from build_engines_utils import run_command, wincopy
@@ -93,8 +94,10 @@ def build_engines(model_cache: str, only_multi_gpu: bool):
tp_pp_cp_sizes = [(1, 4, 1), (4, 1, 1), (1, 2, 1), (2, 2, 1), (2, 1, 1),
(1, 1, 2), (2, 1, 2)]
for tp_size, pp_size, cp_size in tp_pp_cp_sizes:
tp_pp_cp_dir = f"tp{tp_size}-pp{pp_size}-cp{cp_size}-gpu"
print(f"\nBuilding fp16 tp{tp_size} pp{pp_size} cp{cp_size} engine")
start_time = time.time()
tp_pp_cp_dir = f"tp{tp_size}-pp{pp_size}-cp{cp_size}-gpu"
model_spec_obj.use_tensor_parallelism(tp_size)
model_spec_obj.use_pipeline_parallelism(pp_size)
model_spec_obj.use_context_parallelism(cp_size)
@@ -106,8 +109,15 @@ def build_engines(model_cache: str, only_multi_gpu: bool):
f'--cp_size={cp_size}'
], ['--use_paged_context_fmha=disable'])
duration = time.time() - start_time
print(
f"Building fp16 tp{tp_size} pp{pp_size} cp{cp_size} engine took {duration} seconds"
)
if not only_multi_gpu:
print(f"\nBuilding lookahead engine")
start_time = time.time()
model_spec_obj.use_tensor_parallelism(1)
model_spec_obj.use_pipeline_parallelism(1)
model_spec_obj.use_context_parallelism(1)
@@ -120,6 +130,9 @@ def build_engines(model_cache: str, only_multi_gpu: bool):
'--speculative_decoding_mode=lookahead_decoding'
])
duration = time.time() - start_time
print(f"Building lookahead engine took {duration} seconds")
print("Done.")


@@ -65,27 +65,22 @@ def generate_output(engine: str,
base_output_name = os.path.splitext(model_spec_obj.get_results_file())[0]
output_logits_args = []
if output_logits:
logits_file = base_output_name + '_logits.npy'
output_logits_args = [
'--output_logits_npy',
str(output_dir / logits_file),
'--output_generation_logits',
]
results_file = str(output_dir / (base_output_name + '.npy'))
results_csv = str(output_dir / (base_output_name + '.csv'))
args_list = [
'--engine_dir',
str(engine_dir), '--input_file',
str(input_file), '--tokenizer_dir',
str(models_dir / model), '--output_npy', results_file, '--output_csv',
results_csv, '--max_output_len',
str(max_output_len), '--num_beams',
str(num_beams), '--use_py_session'
] + output_logits_args
f'--engine_dir={engine_dir}',
f'--input_file={input_file}',
f'--tokenizer_dir={models_dir / model}',
f'--output_npy={output_dir / (base_output_name + ".npy")}',
f'--output_csv={output_dir / (base_output_name + ".csv")}',
f'--max_output_len={max_output_len}',
f'--num_beams={num_beams}',
'--use_py_session',
]
if output_logits:
args_list.extend([
f'--output_logits_npy={output_dir / (base_output_name + "_logits.npy")}',
'--output_generation_logits',
])
# Generate context_fmha_fp32_acc enabled results for GptExecutorTest.GenerationLogitsEarlyStop
if model_spec_obj.get_enable_context_fmha_fp32_acc():
@@ -93,14 +88,12 @@ def generate_output(engine: str,
if output_cum_log_probs:
args_list.extend([
'--output_cum_log_probs_npy',
f'{output_dir / model_spec_obj.get_cum_log_probs_file()}'
f'--output_cum_log_probs_npy={output_dir / model_spec_obj.get_cum_log_probs_file()}'
])
if output_log_probs:
args_list.extend([
'--output_log_probs_npy',
f'{output_dir / model_spec_obj.get_log_probs_file()}'
f'--output_log_probs_npy={output_dir / model_spec_obj.get_log_probs_file()}'
])
args = run.parse_arguments(args_list)


@@ -16,6 +16,7 @@
import argparse as _arg
import os
import time
from pathlib import Path
from mpi4py.MPI import COMM_WORLD
@@ -34,12 +35,14 @@ def generate_output(engine: str,
tp_size: int = 1,
pp_size: int = 1,
cp_size: int = 1,
max_output_len: int = 8):
max_output_len: int = 8,
output_logits: bool = False,
output_cum_log_probs: bool = False,
output_log_probs: bool = False):
model = 'Llama-3.2-1B'
resources_dir = Path(__file__).parent.resolve().parent
models_dir = resources_dir / 'models'
hf_dir = models_dir / model
tp_pp_cp_dir = 'tp' + str(tp_size) + '-pp' + str(pp_size) + '-cp' + str(
cp_size) + '-gpu/'
engine_dir = models_dir / 'rt_engine' / model / engine / tp_pp_cp_dir
@@ -54,16 +57,34 @@ def generate_output(engine: str,
base_output_name = os.path.splitext(model_spec_obj.get_results_file())[0]
args = run.parse_arguments([
'--engine_dir',
str(engine_dir), '--input_file',
str(input_file), '--tokenizer_dir',
str(hf_dir), '--output_npy',
str(output_dir / (base_output_name + '.npy')), '--output_csv',
str(output_dir / (base_output_name + '.csv')), '--max_output_len',
str(max_output_len), '--num_beams',
str(num_beams), '--use_py_session'
])
args_list = [
f'--engine_dir={engine_dir}',
f'--input_file={input_file}',
f'--tokenizer_dir={models_dir / model}',
f'--output_npy={output_dir / (base_output_name + ".npy")}',
f'--output_csv={output_dir / (base_output_name + ".csv")}',
f'--max_output_len={max_output_len}',
f'--num_beams={num_beams}',
'--use_py_session',
]
if output_logits:
args_list.extend([
f'--output_logits_npy={output_dir / (base_output_name + "_logits.npy")}',
'--output_generation_logits',
])
if output_cum_log_probs:
args_list.extend([
f'--output_cum_log_probs_npy={output_dir / model_spec_obj.get_cum_log_probs_file()}'
])
if output_log_probs:
args_list.extend([
f'--output_log_probs_npy={output_dir / model_spec_obj.get_log_probs_file()}'
])
args = run.parse_arguments(args_list)
run.main(args)
@@ -85,8 +106,18 @@ def generate_outputs(num_beams, only_multi_gpu=False):
for tp_size, pp_size, cp_size in tp_pp_cp_sizes:
print(
f'Generating outputs for Llama FP16 with TP={tp_size}, PP={pp_size} and CP={cp_size}'
f'Generating outputs for Llama FP16 with TP={tp_size}, PP={pp_size}, CP={cp_size}, BW={num_beams}'
)
start_time = time.time()
output_logits = False
output_log_probs = False
output_cum_log_probs = False
if tp_size == 4 and pp_size == 1:
output_logits = True
output_log_probs = True
output_cum_log_probs = True
model_spec_obj.use_tensor_parallelism(tp_size)
model_spec_obj.use_pipeline_parallelism(pp_size)
model_spec_obj.use_context_parallelism(cp_size)
@@ -95,7 +126,15 @@ def generate_outputs(num_beams, only_multi_gpu=False):
tp_size=tp_size,
pp_size=pp_size,
cp_size=cp_size,
model_spec_obj=model_spec_obj)
model_spec_obj=model_spec_obj,
output_logits=output_logits,
output_log_probs=output_log_probs,
output_cum_log_probs=output_cum_log_probs)
duration = time.time() - start_time
print(
f"Generating outputs for Llama FP16 with TP={tp_size}, PP={pp_size}, CP={cp_size}, BW={num_beams} took {duration} seconds"
)
if __name__ == '__main__':


@@ -133,6 +133,26 @@ std::string PathUtil::FP16_PLUGIN_PACKED_PAGED_RESULT_TP2_PP1_FILE()
return ModelSpec::getDefaultModelSpec().useTensorParallelism(2).getResultsFile();
}
std::string PathUtil::FP16_PLUGIN_PACKED_PAGED_CONTEXT_LOGITS_TP4_PP1_FILE()
{
return ModelSpec::getDefaultModelSpec().useTensorParallelism(4).getContextLogitsFile();
}
std::string PathUtil::FP16_PLUGIN_PACKED_PAGED_GENERATION_LOGITS_TP4_PP1_FILE()
{
return ModelSpec::getDefaultModelSpec().useTensorParallelism(4).getGenerationLogitsFile();
}
std::string PathUtil::FP16_PLUGIN_PACKED_PAGED_CUM_LOG_PROBS_TP4_PP1_FILE()
{
return ModelSpec::getDefaultModelSpec().useTensorParallelism(4).getCumLogProbsFile();
}
std::string PathUtil::FP16_PLUGIN_PACKED_PAGED_LOG_PROBS_TP4_PP1_FILE()
{
return ModelSpec::getDefaultModelSpec().useTensorParallelism(4).getLogProbsFile();
}
std::string PathUtil::FP16_PLUGIN_PACKED_PAGED_GATHER_CONTEXTFMHAFP32ACC_RESULT_FILE()
{
return ModelSpec::getDefaultModelSpec().gatherLogits().enableContextFMHAFp32Acc().getResultsFile();
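Each of the new PathUtil helpers above follows the same fluent pattern: start from the default ModelSpec, set the parallelism, and ask for the matching reference file. A minimal sketch of that builder style (FileSpec and its file-name format are invented for illustration and are not the real ModelSpec API):

#include <string>

// Illustrative fluent builder in the spirit of ModelSpec; NOT the real class.
class FileSpec
{
public:
    FileSpec& useTensorParallelism(int tp)
    {
        mTp = tp;
        return *this; // return *this so calls can be chained
    }

    std::string getGenerationLogitsFile() const
    {
        // Invented naming scheme, for the sketch only.
        return "generation_logits_tp" + std::to_string(mTp) + "_pp" + std::to_string(mPp) + ".npy";
    }

private:
    int mTp{1};
    int mPp{1};
};

// Usage, mirroring FP16_PLUGIN_PACKED_PAGED_GENERATION_LOGITS_TP4_PP1_FILE():
// auto file = FileSpec{}.useTensorParallelism(4).getGenerationLogitsFile();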
@@ -564,7 +584,8 @@ void TestData::verifyLogProbs(bool computeLogProbs, bool streaming, bool exclude
}
void TestData::validateContextLogits(bool getContextLogits, SizeType32 inputLength, SizeType32 beamWidth,
std::optional<executor::Tensor> const& contextLogits, SizeType32 vocabSizePadded, SizeType32 batchId)
std::optional<executor::Tensor> const& contextLogits, SizeType32 vocabSizePadded, SizeType32 batchId, float atol,
float rtol)
{
if (getContextLogits)
{
@@ -577,7 +598,8 @@ void TestData::validateContextLogits(bool getContextLogits, SizeType32 inputLeng
if (beamWidth == 1)
{
cudaDeviceSynchronize(); // Make sure the logits copy is complete.
EXPECT_TRUE(compareLogits(*expectedContextLogits, *(executor::detail::toITensor(contextLogits.value()))));
EXPECT_TRUE(compareLogits(
*expectedContextLogits, *(executor::detail::toITensor(contextLogits.value())), atol, rtol));
}
}
else
@@ -589,7 +611,7 @@ void TestData::validateContextLogits(bool getContextLogits, SizeType32 inputLeng
void TestData::validateGenerationLogits(bool getGenLogits, bool isFinal, bool streaming, bool excludeInputFromOutput,
SizeType32 inputLength, SizeType32 maxOutputLen, SizeType32 beamWidth, executor::BeamTokens const& beamTokens,
std::optional<executor::Tensor> const& genLogits, SizeType32 vocabSizePadded, SizeType32 batchId,
bool const returnAllGeneratedTokens)
bool const returnAllGeneratedTokens, float atol, float rtol)
{
auto const numReturnBeams = beamTokens.size();
@@ -632,7 +654,7 @@ void TestData::validateGenerationLogits(bool getGenLogits, bool isFinal, bool st
numGeneratedToken)); // [numGeneratedToken, vocabSizePadded]
cudaDeviceSynchronize(); // Make sure the logits copy is complete.
EXPECT_TRUE(compareLogits(*expectedGenerationLogitsSlice, *outputGenerationLogits));
EXPECT_TRUE(compareLogits(*expectedGenerationLogitsSlice, *outputGenerationLogits, atol, rtol));
}
else
{
@@ -643,7 +665,7 @@ void TestData::validateGenerationLogits(bool getGenLogits, bool isFinal, bool st
if (isFinal && beamWidth == 1)
{
cudaDeviceSynchronize(); // Make sure the logits copy is complete.
EXPECT_TRUE(compareLogits(*expectedGenerationLogits, *outputGenerationLogits));
EXPECT_TRUE(compareLogits(*expectedGenerationLogits, *outputGenerationLogits, atol, rtol));
}
}
EXPECT_EQ(genLogits.value().getShape()[2], vocabSizePadded);


@@ -101,6 +101,10 @@ public:
static std::string FP16_PLUGIN_PACKED_PAGED_RESULT_TP1_PP4_FILE();
static std::string FP16_PLUGIN_PACKED_PAGED_RESULT_TP1_PP2_FILE();
static std::string FP16_PLUGIN_PACKED_PAGED_RESULT_TP2_PP1_FILE();
static std::string FP16_PLUGIN_PACKED_PAGED_CONTEXT_LOGITS_TP4_PP1_FILE();
static std::string FP16_PLUGIN_PACKED_PAGED_GENERATION_LOGITS_TP4_PP1_FILE();
static std::string FP16_PLUGIN_PACKED_PAGED_CUM_LOG_PROBS_TP4_PP1_FILE();
static std::string FP16_PLUGIN_PACKED_PAGED_LOG_PROBS_TP4_PP1_FILE();
// GptExecutorTest.GenerationLogitsEarlyStop requires to use context_fmha_fp32_acc flag in runtime for better
// accuracy
static std::string FP16_PLUGIN_PACKED_PAGED_GATHER_CONTEXTFMHAFP32ACC_RESULT_FILE();
@@ -198,12 +202,13 @@ public:
FlakyTestInfo flakyTestInfo);
void validateContextLogits(bool getContextLogits, SizeType32 inputLength, SizeType32 beamWidth,
std::optional<executor::Tensor> const& contextLogits, SizeType32 vocabSizePadded, SizeType32 batchId);
std::optional<executor::Tensor> const& contextLogits, SizeType32 vocabSizePadded, SizeType32 batchId,
float atol = 1e-2, float rtol = 1e-3);
void validateGenerationLogits(bool getGenLogits, bool isFinal, bool streaming, bool excludeInputFromOutput,
SizeType32 inputLength, SizeType32 maxOutputLen, SizeType32 beamWidth, executor::BeamTokens const& beamTokens,
std::optional<executor::Tensor> const& genLogits, SizeType32 vocabSizePadded, SizeType32 batchId,
bool returnAllGeneratedTokens);
bool returnAllGeneratedTokens, float atol = 1e-2, float rtol = 1e-3);
SizeType32 nbGivenInputs{};
SizeType32 beamWidth{};