[TRTLLM-1316] refactor: Remove unnecessary pipeline parallelism logic from postProcessRequest (#5489)
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
parent ca7b6ec8d8
commit 4cd8543d8c
@@ -1842,40 +1842,6 @@ void TrtGptModelInflightBatching::postProcessRequest(
         bufferManager.getStream().synchronize();
     }
 
-    if (mWorldConfig.isPipelineParallel())
-    {
-        // Send context logits from last to first PP rank
-        if (llmReq.getReturnContextLogits())
-        {
-            if (mWorldConfig.isLastPipelineParallelRank())
-            {
-                mMpiCommPipelinePara->send(
-                    *(llmReq.getContextLogitsHost()), 0, mpi::MpiTag::kTrtGptModelInflightBatchingContextLogits);
-            }
-            else if (mWorldConfig.isFirstPipelineParallelRank())
-            {
-                mMpiCommPipelinePara->recv(*(llmReq.getContextLogitsHost()), mWorldConfig.getPipelineParallelism() - 1,
-                    mpi::MpiTag::kTrtGptModelInflightBatchingContextLogits);
-            }
-        }
-
-        // Send generation logits from last to first PP rank
-        if (llmReq.getReturnGenerationLogits())
-        {
-            if (mWorldConfig.isLastPipelineParallelRank())
-            {
-                mMpiCommPipelinePara->send(
-                    *(llmReq.getGenerationLogitsHost()), 0, mpi::MpiTag::kTrtGptModelInflightBatchingGenerationLogits);
-            }
-            else if (mWorldConfig.isFirstPipelineParallelRank())
-            {
-                mMpiCommPipelinePara->recv(*(llmReq.getGenerationLogitsHost()),
-                    mWorldConfig.getPipelineParallelism() - 1,
-                    mpi::MpiTag::kTrtGptModelInflightBatchingGenerationLogits);
-            }
-        }
-    }
-
     if (reqBeamWidth == 1)
     {
         TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
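For context, the block removed above relayed host-side context and generation logits from the last pipeline-parallel rank to the first one over the pipeline-parallel MPI communicator. Below is a minimal Python sketch of that point-to-point relay using mpi4py (which the test scripts later in this commit already import); the tag value and the logits shape are placeholders, not values taken from TensorRT-LLM.

    # Sketch only: mimics the send/recv pattern removed from postProcessRequest.
    # TAG_CONTEXT_LOGITS and the logits shape are made up for illustration.
    from mpi4py import MPI
    import numpy as np

    comm = MPI.COMM_WORLD      # stands in for the pipeline-parallel communicator
    pp_size = comm.Get_size()
    rank = comm.Get_rank()

    logits = np.zeros((8, 32000), dtype=np.float32)  # host logits buffer
    TAG_CONTEXT_LOGITS = 42                          # hypothetical MPI tag

    if pp_size > 1:
        if rank == pp_size - 1:
            # The last pipeline-parallel rank holds the logits and ships them to rank 0.
            comm.Send(logits, dest=0, tag=TAG_CONTEXT_LOGITS)
        elif rank == 0:
            # The first rank receives them so it can attach logits to the response.
            comm.Recv(logits, source=pp_size - 1, tag=TAG_CONTEXT_LOGITS)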
@@ -1901,7 +1901,7 @@ namespace
 void runTest(Executor& executor, fs::path const& inputPath, ModelIds const& modelIds,
     FlakyTestInfo const& flakyTestInfo, bool streaming, SizeType32 const vocabSizePadded, BeamResult const& beamResult,
     OutputConfig const& outConfig, bool isSpeculativeDecoding, int maxWaitMs, bool returnAllGeneratedTokens,
-    SizeType32 const numReturnSequences, bool isNonGreedySampling)
+    SizeType32 const numReturnSequences, bool isNonGreedySampling, SizeType32 const modelParallelism)
 {
     auto const beamWidth = beamResult.beamWidth;
 
@@ -1948,7 +1948,6 @@ void runTest(Executor& executor, fs::path const& inputPath, ModelIds const& mode
 
     auto& comm = tensorrt_llm::mpi::MpiComm::world();
     auto const worldRank = comm.getRank();
-    auto const worldSize = comm.getSize();
 
     // Expected return sizes.
     auto const numSequences = beamWidth > 1 ? 1 : numReturnSequences;
@@ -2021,15 +2020,18 @@ void runTest(Executor& executor, fs::path const& inputPath, ModelIds const& mode
 
             if (!isNonGreedySampling)
             {
+                float const logitsAtol = modelParallelism > 1 ? 1e-1 : 1e-2;
+                float const logitsRtol = modelParallelism > 1 ? 1e-2 : 1e-3;
+
                 testData.verifyLogProbs(outConfig.returnLogProbs, streaming, outConfig.excludeInputFromOutput,
                     givenInputLengths.at(batchId), beamWidth, beamTokens, cumLogProbs, logProbs, batchId,
                     flakyTestInfo);
                 testData.validateContextLogits(outConfig.returnContextLogits, givenInputLengths.at(batchId),
-                    beamWidth, contextLogits, vocabSizePadded, batchId);
+                    beamWidth, contextLogits, vocabSizePadded, batchId, logitsAtol, logitsRtol);
                 testData.validateGenerationLogits(outConfig.returnGenerationLogits, result.isSequenceFinal,
                     streaming, outConfig.excludeInputFromOutput, givenInputLengths.at(batchId),
                     reqMaxNewTokens.at(batchId), beamWidth, beamTokens, genLogits, vocabSizePadded, batchId,
-                    returnAllGeneratedTokens);
+                    returnAllGeneratedTokens, logitsAtol, logitsRtol);
             }
 
             // Ignore first iteration as it doesn't use draft tokens
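The hunk above loosens the comparison thresholds whenever the engine is split across more than one rank, since tensor- and pipeline-parallel execution accumulates logits in a different order than a single-GPU engine. A rough numpy sketch of the same comparison follows; the helper name is illustrative and only the threshold values mirror the hunk above.

    # Illustrative helper, not TensorRT-LLM's compareLogits: looser tolerances
    # when the model is split across more than one rank.
    import numpy as np

    def logits_close(expected: np.ndarray, actual: np.ndarray,
                     model_parallelism: int = 1) -> bool:
        atol = 1e-1 if model_parallelism > 1 else 1e-2
        rtol = 1e-2 if model_parallelism > 1 else 1e-3
        return np.allclose(actual, expected, atol=atol, rtol=rtol)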
@@ -2063,12 +2065,14 @@ void runTest(Executor& executor, fs::path const& inputPath, ModelIds const& mode
 void runTest(fs::path const& modelPath, ExecutorConfig const& executorConfig, fs::path const& inputPath,
     ModelIds const& modelIds, FlakyTestInfo const& flakyTestInfo, bool streaming, SizeType32 const vocabSizePadded,
     BeamResult const& beamResult, OutputConfig const& outConfig, bool isSpeculativeDecoding, int maxWaitMs,
-    bool returnAllGeneratedTokens, SizeType32 const numReturnSequences, bool isNonGreedySampling)
+    bool returnAllGeneratedTokens, SizeType32 const numReturnSequences, bool isNonGreedySampling,
+    SizeType32 const modelParallelism)
 {
     auto executor = Executor{modelPath, ModelType::kDECODER_ONLY, executorConfig};
 
     runTest(executor, inputPath, modelIds, flakyTestInfo, streaming, vocabSizePadded, beamResult, outConfig,
-        isSpeculativeDecoding, maxWaitMs, returnAllGeneratedTokens, numReturnSequences, isNonGreedySampling);
+        isSpeculativeDecoding, maxWaitMs, returnAllGeneratedTokens, numReturnSequences, isNonGreedySampling,
+        modelParallelism);
 }
 
 ExecutorConfig createExecutorConfig(SizeType32 maxBeamWidth, bool useOrchestratorMode, bool gatherGenerationLogits,
@@ -2193,6 +2197,12 @@ TEST_P(AllParamsTest, TokenComparison)
             beamResult.resultsFile = resultsPath / PathUtil::FP16_PLUGIN_PACKED_PAGED_RESULT_TP2_PP2_FILE();
             modelPath = LLAMA_MODEL_PATH / PathUtil::FP16_GPT_ATTENTION_PACKED_PAGED_DIR() / "tp2-pp2-cp1-gpu";
         }
+        beamResult.genLogitsFile = resultsPath / PathUtil::FP16_PLUGIN_PACKED_PAGED_GENERATION_LOGITS_TP4_PP1_FILE();
+        if (outConfig.returnLogProbs)
+        {
+            beamResult.cumLogProbsFile = resultsPath / PathUtil::FP16_PLUGIN_PACKED_PAGED_CUM_LOG_PROBS_TP4_PP1_FILE();
+            beamResult.logProbsFile = resultsPath / PathUtil::FP16_PLUGIN_PACKED_PAGED_LOG_PROBS_TP4_PP1_FILE();
+        }
     }
     else if (modelName == "medusa")
     {
@@ -2283,9 +2293,9 @@ TEST_P(AllParamsTest, TokenComparison)
         GTEST_SKIP() << "Skipping Llama test";
     }
 
-    if (outConfig.returnLogProbs || outConfig.returnContextLogits || outConfig.returnGenerationLogits)
+    if (outConfig.returnContextLogits)
     {
-        GTEST_SKIP() << "Skipping logits and log probs tests for mpi runs";
+        GTEST_SKIP() << "Skipping context logits tests for mpi runs";
     }
 
     // Check that it was launched with right number of MPI ranks
@@ -2302,11 +2312,12 @@ TEST_P(AllParamsTest, TokenComparison)
        }
        auto decoderJsonConfig = tensorrt_llm::runtime::GptJsonConfig::parse(modelPath / "config.json");
 
-       auto modelTP = decoderJsonConfig.getTensorParallelism();
-       auto modelPP = decoderJsonConfig.getPipelineParallelism();
+       auto const modelTP = decoderJsonConfig.getTensorParallelism();
+       auto const modelPP = decoderJsonConfig.getPipelineParallelism();
+       auto const modelParallelism = modelTP * modelPP;
        int deviceCount = -1;
        TLLM_CUDA_CHECK(cudaGetDeviceCount(&deviceCount));
-       std::optional<std::vector<SizeType32>> deviceIds = std::vector<SizeType32>(modelTP * modelPP);
+       std::optional<std::vector<SizeType32>> deviceIds = std::vector<SizeType32>(modelParallelism);
        for (auto i = 0; i < deviceIds->size(); i++)
        {
            deviceIds->at(i) = i % deviceCount;
@@ -2354,7 +2365,8 @@ TEST_P(AllParamsTest, TokenComparison)
         std::move(deviceIds), std::move(participantIds));
 
     runTest(modelPath, executorConfig, inputPath, modelIds, flakyTestInfo, streaming, vocabSizePadded, beamResult,
-        outConfig, isSpeculativeDecoding, mMaxWaitMs, returnAllGeneratedTokens, numReturnSequences, false);
+        outConfig, isSpeculativeDecoding, mMaxWaitMs, returnAllGeneratedTokens, numReturnSequences, false,
+        modelParallelism);
 }
 
 TEST_F(GptExecutorTest, ChangeBeamWidth)
@@ -2455,7 +2467,7 @@ void doTokenComparisonChangeBeamWidth(bool enableReuse, SizeType32 maxWaitMs)
         beamResult.genLogitsFile = resultsPath / PathUtil::FP16_PLUGIN_PACKED_PAGED_GENERATION_LOGITS_FILE();
 
         runTest(executor, inputPath, modelIds, flakyTestInfo, streaming, vocabSizePadded, beamResult, outConfig,
-            isSpeculativeDecoding, maxWaitMs, false, 1, false);
+            isSpeculativeDecoding, maxWaitMs, false, 1, false, 1);
     }
 }
 
@@ -2497,7 +2509,7 @@ TEST_F(GptExecutorTest, NReturnRandomness)
     beamResult.genLogitsFile = resultsPath / PathUtil::FP16_PLUGIN_PACKED_PAGED_GENERATION_LOGITS_FILE();
 
     runTest(executor, inputPath, modelIds, flakyTestInfo, streaming, vocabSizePadded, beamResult, outConfig,
-        isSpeculativeDecoding, mMaxWaitMs, false, 1, true);
+        isSpeculativeDecoding, mMaxWaitMs, false, 1, true, 1);
 }
 
 TEST_F(GptExecutorTest, TimedOut)
@@ -4526,10 +4538,10 @@ INSTANTIATE_TEST_SUITE_P(LlamaExecutorTest, AllParamsTest,
     testing::Combine( //
         testing::Values(false, true), // streaming
         testing::Values(1, 2), // beamWidth
-        testing::Values(false), // computeLogProbs
+        testing::Values(true), // computeLogProbs
         testing::Values(false, true), // excludeInputInOutput
         testing::Values(false), // returnContextLogits
-        testing::Values(false), // returnGenerationLogits
+        testing::Values(true), // returnGenerationLogits
        testing::Values("llama_tp1_pp4_cp1", "llama_tp4_pp1_cp1", "llama_tp2_pp2_cp1"), // modelName
        testing::Values(false, true), // useOrchestratorMode
        testing::Values(false), // returnAllGeneratedTokens
@@ -18,6 +18,7 @@ import argparse as _arg
 import pathlib as _pl
 import platform as _pf
 import sys as _sys
+import time
 
 from build_engines_utils import run_command, wincopy
 
@@ -93,8 +94,10 @@ def build_engines(model_cache: str, only_multi_gpu: bool):
     tp_pp_cp_sizes = [(1, 4, 1), (4, 1, 1), (1, 2, 1), (2, 2, 1), (2, 1, 1),
                       (1, 1, 2), (2, 1, 2)]
     for tp_size, pp_size, cp_size in tp_pp_cp_sizes:
-        tp_pp_cp_dir = f"tp{tp_size}-pp{pp_size}-cp{cp_size}-gpu"
         print(f"\nBuilding fp16 tp{tp_size} pp{pp_size} cp{cp_size} engine")
+        start_time = time.time()
+
+        tp_pp_cp_dir = f"tp{tp_size}-pp{pp_size}-cp{cp_size}-gpu"
         model_spec_obj.use_tensor_parallelism(tp_size)
         model_spec_obj.use_pipeline_parallelism(pp_size)
         model_spec_obj.use_context_parallelism(cp_size)
@@ -106,8 +109,15 @@ def build_engines(model_cache: str, only_multi_gpu: bool):
             f'--cp_size={cp_size}'
         ], ['--use_paged_context_fmha=disable'])
 
+        duration = time.time() - start_time
+        print(
+            f"Building fp16 tp{tp_size} pp{pp_size} cp{cp_size} engine took {duration} seconds"
+        )
+
     if not only_multi_gpu:
         print(f"\nBuilding lookahead engine")
+        start_time = time.time()
+
         model_spec_obj.use_tensor_parallelism(1)
         model_spec_obj.use_pipeline_parallelism(1)
         model_spec_obj.use_context_parallelism(1)
@@ -120,6 +130,9 @@ def build_engines(model_cache: str, only_multi_gpu: bool):
             '--speculative_decoding_mode=lookahead_decoding'
         ])
 
+        duration = time.time() - start_time
+        print(f"Building lookahead engine took {duration} seconds")
+
     print("Done.")
@@ -65,27 +65,22 @@ def generate_output(engine: str,
 
     base_output_name = os.path.splitext(model_spec_obj.get_results_file())[0]
 
-    output_logits_args = []
-    if output_logits:
-        logits_file = base_output_name + '_logits.npy'
-        output_logits_args = [
-            '--output_logits_npy',
-            str(output_dir / logits_file),
-            '--output_generation_logits',
-        ]
-
-    results_file = str(output_dir / (base_output_name + '.npy'))
-    results_csv = str(output_dir / (base_output_name + '.csv'))
-
     args_list = [
-        '--engine_dir',
-        str(engine_dir), '--input_file',
-        str(input_file), '--tokenizer_dir',
-        str(models_dir / model), '--output_npy', results_file, '--output_csv',
-        results_csv, '--max_output_len',
-        str(max_output_len), '--num_beams',
-        str(num_beams), '--use_py_session'
-    ] + output_logits_args
+        f'--engine_dir={engine_dir}',
+        f'--input_file={input_file}',
+        f'--tokenizer_dir={models_dir / model}',
+        f'--output_npy={output_dir / (base_output_name + ".npy")}',
+        f'--output_csv={output_dir / (base_output_name + ".csv")}',
+        f'--max_output_len={max_output_len}',
+        f'--num_beams={num_beams}',
+        '--use_py_session',
+    ]
+
+    if output_logits:
+        args_list.extend([
+            f'--output_logits_npy={output_dir / (base_output_name + "_logits.npy")}',
+            '--output_generation_logits',
+        ])
 
     # Generate context_fmha_fp32_acc enabled results for GptExecutorTest.GenerationLogitsEarlyStop
     if model_spec_obj.get_enable_context_fmha_fp32_acc():
@@ -93,14 +88,12 @@ def generate_output(engine: str,
 
     if output_cum_log_probs:
         args_list.extend([
-            '--output_cum_log_probs_npy',
-            f'{output_dir / model_spec_obj.get_cum_log_probs_file()}'
+            f'--output_cum_log_probs_npy={output_dir / model_spec_obj.get_cum_log_probs_file()}'
         ])
 
     if output_log_probs:
         args_list.extend([
-            '--output_log_probs_npy',
-            f'{output_dir / model_spec_obj.get_log_probs_file()}'
+            f'--output_log_probs_npy={output_dir / model_spec_obj.get_log_probs_file()}'
         ])
 
     args = run.parse_arguments(args_list)
@@ -16,6 +16,7 @@
 
 import argparse as _arg
 import os
+import time
 from pathlib import Path
 
 from mpi4py.MPI import COMM_WORLD
@@ -34,12 +35,14 @@ def generate_output(engine: str,
                     tp_size: int = 1,
                     pp_size: int = 1,
                     cp_size: int = 1,
-                    max_output_len: int = 8):
+                    max_output_len: int = 8,
+                    output_logits: bool = False,
+                    output_cum_log_probs: bool = False,
+                    output_log_probs: bool = False):
 
     model = 'Llama-3.2-1B'
     resources_dir = Path(__file__).parent.resolve().parent
     models_dir = resources_dir / 'models'
-    hf_dir = models_dir / model
     tp_pp_cp_dir = 'tp' + str(tp_size) + '-pp' + str(pp_size) + '-cp' + str(
         cp_size) + '-gpu/'
     engine_dir = models_dir / 'rt_engine' / model / engine / tp_pp_cp_dir
@@ -54,16 +57,34 @@ def generate_output(engine: str,
 
     base_output_name = os.path.splitext(model_spec_obj.get_results_file())[0]
 
-    args = run.parse_arguments([
-        '--engine_dir',
-        str(engine_dir), '--input_file',
-        str(input_file), '--tokenizer_dir',
-        str(hf_dir), '--output_npy',
-        str(output_dir / (base_output_name + '.npy')), '--output_csv',
-        str(output_dir / (base_output_name + '.csv')), '--max_output_len',
-        str(max_output_len), '--num_beams',
-        str(num_beams), '--use_py_session'
-    ])
+    args_list = [
+        f'--engine_dir={engine_dir}',
+        f'--input_file={input_file}',
+        f'--tokenizer_dir={models_dir / model}',
+        f'--output_npy={output_dir / (base_output_name + ".npy")}',
+        f'--output_csv={output_dir / (base_output_name + ".csv")}',
+        f'--max_output_len={max_output_len}',
+        f'--num_beams={num_beams}',
+        '--use_py_session',
+    ]
+
+    if output_logits:
+        args_list.extend([
+            f'--output_logits_npy={output_dir / (base_output_name + "_logits.npy")}',
+            '--output_generation_logits',
+        ])
+
+    if output_cum_log_probs:
+        args_list.extend([
+            f'--output_cum_log_probs_npy={output_dir / model_spec_obj.get_cum_log_probs_file()}'
+        ])
+
+    if output_log_probs:
+        args_list.extend([
+            f'--output_log_probs_npy={output_dir / model_spec_obj.get_log_probs_file()}'
+        ])
+
+    args = run.parse_arguments(args_list)
     run.main(args)
@@ -85,8 +106,18 @@ def generate_outputs(num_beams, only_multi_gpu=False):
 
     for tp_size, pp_size, cp_size in tp_pp_cp_sizes:
         print(
-            f'Generating outputs for Llama FP16 with TP={tp_size}, PP={pp_size} and CP={cp_size}'
+            f'Generating outputs for Llama FP16 with TP={tp_size}, PP={pp_size}, CP={cp_size}, BW={num_beams}'
         )
+        start_time = time.time()
+
+        output_logits = False
+        output_log_probs = False
+        output_cum_log_probs = False
+        if tp_size == 4 and pp_size == 1:
+            output_logits = True
+            output_log_probs = True
+            output_cum_log_probs = True
+
         model_spec_obj.use_tensor_parallelism(tp_size)
         model_spec_obj.use_pipeline_parallelism(pp_size)
         model_spec_obj.use_context_parallelism(cp_size)
@@ -95,7 +126,15 @@ def generate_outputs(num_beams, only_multi_gpu=False):
                         tp_size=tp_size,
                         pp_size=pp_size,
                         cp_size=cp_size,
-                        model_spec_obj=model_spec_obj)
+                        model_spec_obj=model_spec_obj,
+                        output_logits=output_logits,
+                        output_log_probs=output_log_probs,
+                        output_cum_log_probs=output_cum_log_probs)
+
+        duration = time.time() - start_time
+        print(
+            f"Generating outputs for Llama FP16 with TP={tp_size}, PP={pp_size}, CP={cp_size}, BW={num_beams} took {duration} seconds"
+        )
 
 
 if __name__ == '__main__':
@@ -133,6 +133,26 @@ std::string PathUtil::FP16_PLUGIN_PACKED_PAGED_RESULT_TP2_PP1_FILE()
     return ModelSpec::getDefaultModelSpec().useTensorParallelism(2).getResultsFile();
 }
 
+std::string PathUtil::FP16_PLUGIN_PACKED_PAGED_CONTEXT_LOGITS_TP4_PP1_FILE()
+{
+    return ModelSpec::getDefaultModelSpec().useTensorParallelism(4).getContextLogitsFile();
+}
+
+std::string PathUtil::FP16_PLUGIN_PACKED_PAGED_GENERATION_LOGITS_TP4_PP1_FILE()
+{
+    return ModelSpec::getDefaultModelSpec().useTensorParallelism(4).getGenerationLogitsFile();
+}
+
+std::string PathUtil::FP16_PLUGIN_PACKED_PAGED_CUM_LOG_PROBS_TP4_PP1_FILE()
+{
+    return ModelSpec::getDefaultModelSpec().useTensorParallelism(4).getCumLogProbsFile();
+}
+
+std::string PathUtil::FP16_PLUGIN_PACKED_PAGED_LOG_PROBS_TP4_PP1_FILE()
+{
+    return ModelSpec::getDefaultModelSpec().useTensorParallelism(4).getLogProbsFile();
+}
+
 std::string PathUtil::FP16_PLUGIN_PACKED_PAGED_GATHER_CONTEXTFMHAFP32ACC_RESULT_FILE()
 {
     return ModelSpec::getDefaultModelSpec().gatherLogits().enableContextFMHAFp32Acc().getResultsFile();
@@ -564,7 +584,8 @@ void TestData::verifyLogProbs(bool computeLogProbs, bool streaming, bool exclude
 }
 
 void TestData::validateContextLogits(bool getContextLogits, SizeType32 inputLength, SizeType32 beamWidth,
-    std::optional<executor::Tensor> const& contextLogits, SizeType32 vocabSizePadded, SizeType32 batchId)
+    std::optional<executor::Tensor> const& contextLogits, SizeType32 vocabSizePadded, SizeType32 batchId, float atol,
+    float rtol)
 {
     if (getContextLogits)
     {
@@ -577,7 +598,8 @@ void TestData::validateContextLogits(bool getContextLogits, SizeType32 inputLeng
             if (beamWidth == 1)
             {
                 cudaDeviceSynchronize(); // Make sure the logits copy is complete.
-                EXPECT_TRUE(compareLogits(*expectedContextLogits, *(executor::detail::toITensor(contextLogits.value()))));
+                EXPECT_TRUE(compareLogits(
+                    *expectedContextLogits, *(executor::detail::toITensor(contextLogits.value())), atol, rtol));
             }
         }
         else
@@ -589,7 +611,7 @@ void TestData::validateContextLogits(bool getContextLogits, SizeType32 inputLeng
 void TestData::validateGenerationLogits(bool getGenLogits, bool isFinal, bool streaming, bool excludeInputFromOutput,
     SizeType32 inputLength, SizeType32 maxOutputLen, SizeType32 beamWidth, executor::BeamTokens const& beamTokens,
     std::optional<executor::Tensor> const& genLogits, SizeType32 vocabSizePadded, SizeType32 batchId,
-    bool const returnAllGeneratedTokens)
+    bool const returnAllGeneratedTokens, float atol, float rtol)
 {
     auto const numReturnBeams = beamTokens.size();
 
@@ -632,7 +654,7 @@ void TestData::validateGenerationLogits(bool getGenLogits, bool isFinal, bool st
                 numGeneratedToken)); // [numGeneratedToken, vocabSizePadded]
 
             cudaDeviceSynchronize(); // Make sure the logits copy is complete.
-            EXPECT_TRUE(compareLogits(*expectedGenerationLogitsSlice, *outputGenerationLogits));
+            EXPECT_TRUE(compareLogits(*expectedGenerationLogitsSlice, *outputGenerationLogits, atol, rtol));
         }
         else
         {
@@ -643,7 +665,7 @@ void TestData::validateGenerationLogits(bool getGenLogits, bool isFinal, bool st
             if (isFinal && beamWidth == 1)
             {
                 cudaDeviceSynchronize(); // Make sure the logits copy is complete.
-                EXPECT_TRUE(compareLogits(*expectedGenerationLogits, *outputGenerationLogits));
+                EXPECT_TRUE(compareLogits(*expectedGenerationLogits, *outputGenerationLogits, atol, rtol));
             }
         }
         EXPECT_EQ(genLogits.value().getShape()[2], vocabSizePadded);
@@ -101,6 +101,10 @@ public:
     static std::string FP16_PLUGIN_PACKED_PAGED_RESULT_TP1_PP4_FILE();
     static std::string FP16_PLUGIN_PACKED_PAGED_RESULT_TP1_PP2_FILE();
     static std::string FP16_PLUGIN_PACKED_PAGED_RESULT_TP2_PP1_FILE();
+    static std::string FP16_PLUGIN_PACKED_PAGED_CONTEXT_LOGITS_TP4_PP1_FILE();
+    static std::string FP16_PLUGIN_PACKED_PAGED_GENERATION_LOGITS_TP4_PP1_FILE();
+    static std::string FP16_PLUGIN_PACKED_PAGED_CUM_LOG_PROBS_TP4_PP1_FILE();
+    static std::string FP16_PLUGIN_PACKED_PAGED_LOG_PROBS_TP4_PP1_FILE();
     // GptExecutorTest.GenerationLogitsEarlyStop requires to use context_fmha_fp32_acc flag in runtime for better
     // accuracy
     static std::string FP16_PLUGIN_PACKED_PAGED_GATHER_CONTEXTFMHAFP32ACC_RESULT_FILE();
@@ -198,12 +202,13 @@ public:
         FlakyTestInfo flakyTestInfo);
 
     void validateContextLogits(bool getContextLogits, SizeType32 inputLength, SizeType32 beamWidth,
-        std::optional<executor::Tensor> const& contextLogits, SizeType32 vocabSizePadded, SizeType32 batchId);
+        std::optional<executor::Tensor> const& contextLogits, SizeType32 vocabSizePadded, SizeType32 batchId,
+        float atol = 1e-2, float rtol = 1e-3);
 
     void validateGenerationLogits(bool getGenLogits, bool isFinal, bool streaming, bool excludeInputFromOutput,
         SizeType32 inputLength, SizeType32 maxOutputLen, SizeType32 beamWidth, executor::BeamTokens const& beamTokens,
         std::optional<executor::Tensor> const& genLogits, SizeType32 vocabSizePadded, SizeType32 batchId,
-        bool returnAllGeneratedTokens);
+        bool returnAllGeneratedTokens, float atol = 1e-2, float rtol = 1e-3);
 
     SizeType32 nbGivenInputs{};
     SizeType32 beamWidth{};