mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[None][feat] spark cublas LUT table for llama-8b-bf16 perf (#9811)
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
This commit is contained in:
parent
e4e09867d1
commit
98d72c7648
@ -14,6 +14,7 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cublasScaledMMLut.h"
|
||||
#include "tensorrt_llm/common/cublasMMWrapper.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/kernels/userbuffers/ub_interface.h"
|
||||
@ -22,10 +23,8 @@
|
||||
#include "tensorrt_llm/runtime/torchUtils.h"
|
||||
#include "tensorrt_llm/thop/thUtils.h"
|
||||
#include "userbuffersTensor.h"
|
||||
#include <array>
|
||||
#include <cublasLt.h>
|
||||
#include <torch/extension.h>
|
||||
#include <unordered_map>
|
||||
|
||||
using torch::Tensor;
|
||||
|
||||
@ -39,67 +38,7 @@ namespace
|
||||
|
||||
using tensorrt_llm::common::check;
|
||||
using tensorrt_llm::common::CublasMMWrapper;
|
||||
|
||||
struct hash_tuple
|
||||
{
|
||||
size_t operator()(std::tuple<int, int, int> const& x) const
|
||||
{
|
||||
return std::get<0>(x) ^ std::get<1>(x) ^ std::get<2>(x);
|
||||
}
|
||||
};
|
||||
|
||||
// got from cublasTest matmultFind
|
||||
// {mp2, k, n}: {algo, m_tile, m_stages, m_numsK, m_reduction, m_swizzle, m_custom, m_cga}
|
||||
using AlgoListType = std::unordered_map<std::tuple<int32_t, int32_t, int32_t>, std::array<int, 8>, hash_tuple>;
|
||||
|
||||
// bf16*bf16->fp32->bf16
|
||||
AlgoListType spark_bf16_algo_list = {
|
||||
// GPT-OSS-20b
|
||||
//-m201088 -n1 -algo21 -m_tile11 -m_stages20 -m_workmem0 -k2880
|
||||
{{8, 2880, 201088}, {21, 11, 20, 1, 0, 0, 0, 0}},
|
||||
//-m32 -n1 -algo14 -m_reduction2 -m_numsK10 -m_workmem1024 -k2880
|
||||
{{8, 2880, 32}, {14, 0, 0, 10, 2, 0, 0, 0}},
|
||||
//-m32 -n2048 -algo21 -m_tile11 -m_stages13 -m_reduction1 -m_numsK9 -m_workmem1024
|
||||
//-k2880
|
||||
{{2048, 2880, 32}, {21, 11, 13, 9, 1, 0, 0, 0}},
|
||||
//-m32 -n2175 -algo21 -m_tile11 -m_stages19 -m_reduction1 -m_numsK11
|
||||
//-m_workmem1024 -k2880
|
||||
{{4096, 2880, 32}, {21, 11, 19, 11, 1, 0, 0, 0}},
|
||||
//-m5120 -n1 -algo23 -m_tile11 -m_stages8 -m_reduction1 -m_numsK2
|
||||
//-m_workmem1024 -k2880
|
||||
{{8, 2880, 5120}, {23, 11, 8, 2, 1, 0, 0, 0}},
|
||||
//-m5120 -n2048 -algo21 -m_tile20 -m_stages15 -m_workmem1024 -k2880
|
||||
{{2048, 2880, 5120}, {21, 20, 15, 1, 0, 0, 0, 0}},
|
||||
//-m5120 -n2175 -algo21 -m_tile20 -m_stages15 -m_workmem1024 -k2880
|
||||
{{4096, 2880, 5120}, {21, 20, 15, 1, 0, 0, 0, 0}},
|
||||
//-m2880 -n1 -algo23 -m_tile11 -m_stages14 -m_reduction1 -m_numsK24 -m_workmem1024 -k4096
|
||||
{{8, 4096, 2880}, {23, 11, 14, 24, 1, 0, 0, 0}},
|
||||
//-m2880 -n2048 -ldc2880 -Poutt -ldd2880 -Ps -Pscales -algo21 -m_tile20 -m_stages15 -m_workmem1024 -k4096
|
||||
{{2048, 4096, 2880}, {21, 20, 15, 1, 0, 0, 0, 0}},
|
||||
//-m2880 -n2175 -ldc2880 -Poutt -ldd2880 -Ps -Pscales -algo21 -m_tile20 -m_stages15 -m_workmem1024 -k4096
|
||||
{{4096, 4096, 2880}, {21, 20, 15, 1, 0, 0, 0, 0}},
|
||||
};
|
||||
|
||||
// bf16*bf16->fp32->bf16
|
||||
AlgoListType bf16_algo_list = {
|
||||
// Deepseek v3/R1 router gemm
|
||||
// [-algo66 -m_tile10 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom3 -m_mma0 -m_cga2 -m_scheduling1]
|
||||
{{8, 7168, 256}, {66, 10, 35, 1, 0, 0, 3, 2}},
|
||||
{{512, 7168, 256}, {66, 48, 35, 1, 0, 0, 0, 2}},
|
||||
{{1024, 7168, 256}, {66, 13, 35, 1, 0, 0, 1, 3}},
|
||||
};
|
||||
|
||||
// fp8*fp8->fp32->fp16
|
||||
AlgoListType fp8_algo_list = {
|
||||
// Llama-3.1-70B
|
||||
// [-algo66 -m_tile393 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom5 -m_mma0 -m_cga2 -m_scheduling1]
|
||||
{{8, 8192, 8192}, {66, 393, 36, 1, 0, 0, 5, 2}},
|
||||
// [-algo66 -m_tile10 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom1 -m_mma0 -m_cga2 -m_scheduling1]
|
||||
{{8, 8192, 57344}, {66, 10, 36, 1, 0, 0, 1, 2}},
|
||||
// Llama-3.3-70B TP4 (this is the default algo on B200. Here we aim to use the same algo on GB200.)
|
||||
// [-algo66 -m_tile393 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom1 -m_mma0 -m_cga4 -m_scheduling1]
|
||||
{{8, 8192, 14336}, {66, 393, 36, 1, 0, 1, 1, 4}},
|
||||
};
|
||||
using cublas_lut::AlgoListType;
|
||||
|
||||
void set_algo_attr(cublasLtMatmulAlgo_t& algo, std::array<int, 8> const& attr_list)
|
||||
{
|
||||
@ -127,17 +66,18 @@ bool find_special_algo(cublasLtMatmulAlgo_t& algo, std::shared_ptr<CublasMMWrapp
|
||||
cudaDataType_t bType, cudaDataType_t outType)
|
||||
{
|
||||
int32_t mp2 = std::max(nextPowerOfTwo(m), 8);
|
||||
AlgoListType algo_list;
|
||||
AlgoListType const* algo_list = nullptr;
|
||||
if ((aType == CUDA_R_16BF || aType == CUDA_R_16F) && (outType == aType || outType == CUDA_R_32F)
|
||||
&& compType == CUBLAS_COMPUTE_32F)
|
||||
{
|
||||
// TODO: remove this after cublas fix the heuristic for Spark
|
||||
algo_list = tensorrt_llm::common::getSMVersion(/*queryRealSmArch=*/true) == 121 ? spark_bf16_algo_list
|
||||
: bf16_algo_list;
|
||||
algo_list = tensorrt_llm::common::getSMVersion(/*queryRealSmArch=*/true) == 121
|
||||
? &cublas_lut::spark_bf16_algo_list
|
||||
: &cublas_lut::bf16_algo_list;
|
||||
}
|
||||
else if (aType == CUDA_R_8F_E4M3 && compType == CUBLAS_COMPUTE_32F)
|
||||
{
|
||||
algo_list = fp8_algo_list;
|
||||
algo_list = &cublas_lut::fp8_algo_list;
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -145,11 +85,12 @@ bool find_special_algo(cublasLtMatmulAlgo_t& algo, std::shared_ptr<CublasMMWrapp
|
||||
"No special cublasLt algo found for aType=%d, outType=%d, compType=%d\n", aType, outType, compType);
|
||||
return false;
|
||||
}
|
||||
if (auto algo_iter = algo_list.find({mp2, k, n}); algo_iter != algo_list.end())
|
||||
if (auto algo_iter = algo_list->find({mp2, k, n}); algo_iter != algo_list->end())
|
||||
{
|
||||
int const algoID = algo_iter->second[0];
|
||||
check_cuda_error(cublasLtMatmulAlgoInit(
|
||||
cublasWrapper->getCublasLtHandle(), compType, scaleType, aType, bType, outType, outType, algoID, &algo));
|
||||
TLLM_LOG_DEBUG("Found special cublasLt algo for m=%d, k=%d, n=%d\n", m, k, n);
|
||||
set_algo_attr(algo, algo_iter->second);
|
||||
}
|
||||
else
|
||||
|
||||
99
cpp/tensorrt_llm/thop/cublasScaledMMLut.h
Normal file
99
cpp/tensorrt_llm/thop/cublasScaledMMLut.h
Normal file
@ -0,0 +1,99 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace torch_ext
|
||||
{
|
||||
namespace cublas_lut
|
||||
{
|
||||
|
||||
struct HashTuple
|
||||
{
|
||||
size_t operator()(std::tuple<int32_t, int32_t, int32_t> const& x) const
|
||||
{
|
||||
return std::get<0>(x) ^ std::get<1>(x) ^ std::get<2>(x);
|
||||
}
|
||||
};
|
||||
|
||||
// {mp2, k, n}: {algo, m_tile, m_stages, m_numsK, m_reduction, m_swizzle, m_custom, m_cga}
|
||||
using AlgoListType = std::unordered_map<std::tuple<int32_t, int32_t, int32_t>, std::array<int, 8>, HashTuple>;
|
||||
|
||||
inline const AlgoListType spark_bf16_algo_list = {
|
||||
// llama 8b instruct fp16 decode
|
||||
// [-algo67 -m_tile6 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom130 -m_mma0 -m_cga2 -m_scheduling1]
|
||||
{{8, 4096, 4096}, {67, 6, 35, 1, 0, 0, 130, 2}},
|
||||
// [-algo67 -m_tile393 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom142 -m_mma0 -m_cga2 -m_scheduling1]
|
||||
{{8, 4096, 6144}, {67, 393, 35, 1, 0, 0, 142, 2}},
|
||||
// [-algo67 -m_tile393 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom142 -m_mma0 -m_cga2 -m_scheduling1]
|
||||
{{8, 4096, 128256}, {67, 393, 35, 1, 0, 0, 142, 2}},
|
||||
|
||||
// gpt-oss mxfp4-fp16 decode
|
||||
// [-algo67 -m_tile393 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom142 -m_mma0 -m_cga2 -m_scheduling1]
|
||||
{{8, 2880, 201088}, {67, 393, 35, 1, 0, 0, 142, 2}},
|
||||
// [-algo14 -m_tile0 -m_stages35 -m_numsK10 -m_reduction2 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
|
||||
{{8, 2880, 32}, {14, 0, 0, 10, 2, 0, 0, 0}},
|
||||
// [-algo21 -m_tile11 -m_stages13 -m_numsK9 -m_reduction1 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
|
||||
//-k2880
|
||||
{{2048, 2880, 32}, {21, 11, 13, 9, 1, 0, 0, 0}},
|
||||
// [-algo21 -m_tile11 -m_stages19 -m_numsK11 -m_reduction1 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
|
||||
//-m_workmem1024 -k2880
|
||||
{{4096, 2880, 32}, {21, 11, 19, 11, 1, 0, 0, 0}},
|
||||
// [-algo23 -m_tile11 -m_stages8 -m_numsK2 -m_reduction1 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
|
||||
//-m_workmem1024 -k2880
|
||||
{{8, 2880, 5120}, {23, 11, 8, 2, 1, 0, 0, 0}},
|
||||
// [-algo21 -m_tile20 -m_stages15 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
|
||||
{{2048, 2880, 5120}, {21, 20, 15, 1, 0, 0, 0, 0}},
|
||||
// [-algo21 -m_tile20 -m_stages15 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
|
||||
{{4096, 2880, 5120}, {21, 20, 15, 1, 0, 0, 0, 0}},
|
||||
// [-algo23 -m_tile11 -m_stages14 -m_numsK24 -m_reduction1 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
|
||||
{{8, 4096, 2880}, {23, 11, 14, 24, 1, 0, 0, 0}},
|
||||
// [-algo21 -m_tile20 -m_stages15 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
|
||||
{{2048, 4096, 2880}, {21, 20, 15, 1, 0, 0, 0, 0}},
|
||||
// [-algo21 -m_tile20 -m_stages15 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
|
||||
{{4096, 4096, 2880}, {21, 20, 15, 1, 0, 0, 0, 0}},
|
||||
|
||||
};
|
||||
|
||||
// bf16*bf16->fp32->bf16
|
||||
inline const AlgoListType bf16_algo_list = {
|
||||
// Deepseek v3/R1 router gemm
|
||||
// [-algo66 -m_tile10 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom3 -m_mma0 -m_cga2 -m_scheduling1]
|
||||
{{8, 7168, 256}, {66, 10, 35, 1, 0, 0, 3, 2}},
|
||||
{{512, 7168, 256}, {66, 48, 35, 1, 0, 0, 0, 2}},
|
||||
{{1024, 7168, 256}, {66, 13, 35, 1, 0, 0, 1, 3}},
|
||||
};
|
||||
|
||||
// fp8*fp8->fp32->fp16
|
||||
inline const AlgoListType fp8_algo_list = {
|
||||
// Llama-3.1-70B
|
||||
// [-algo66 -m_tile393 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom5 -m_mma0 -m_cga2 -m_scheduling1]
|
||||
{{8, 8192, 8192}, {66, 393, 36, 1, 0, 0, 5, 2}},
|
||||
// [-algo66 -m_tile10 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom1 -m_mma0 -m_cga2 -m_scheduling1]
|
||||
{{8, 8192, 57344}, {66, 10, 36, 1, 0, 0, 1, 2}},
|
||||
// Llama-3.3-70B TP4 (this is the default algo on B200. Here we aim to use the same algo on GB200.)
|
||||
// [-algo66 -m_tile393 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom1 -m_mma0 -m_cga4 -m_scheduling1]
|
||||
{{8, 8192, 14336}, {66, 393, 36, 1, 0, 1, 1, 4}},
|
||||
};
|
||||
|
||||
} // namespace cublas_lut
|
||||
} // namespace torch_ext
|
||||
@ -230,6 +230,7 @@ class LlamaAttention(Attention):
|
||||
self,
|
||||
model_config: ModelConfig[LlamaConfig],
|
||||
layer_idx: Optional[int] = None,
|
||||
use_custom_cublas_mm: bool = False,
|
||||
):
|
||||
config = model_config.pretrained_config
|
||||
super().__init__(
|
||||
@ -245,6 +246,7 @@ class LlamaAttention(Attention):
|
||||
layer_idx=layer_idx,
|
||||
dtype=config.torch_dtype,
|
||||
config=model_config,
|
||||
use_custom_cublas_mm=use_custom_cublas_mm,
|
||||
)
|
||||
|
||||
|
||||
@ -618,6 +620,7 @@ class LlamaDecoderLayer(DecoderLayer):
|
||||
self,
|
||||
model_config: ModelConfig[LlamaConfig],
|
||||
layer_idx: int,
|
||||
use_custom_cublas_mm: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
super().__init__()
|
||||
config = model_config.pretrained_config
|
||||
@ -634,6 +637,7 @@ class LlamaDecoderLayer(DecoderLayer):
|
||||
self.self_attn = LlamaAttention(
|
||||
model_config,
|
||||
layer_idx=layer_idx,
|
||||
use_custom_cublas_mm=use_custom_cublas_mm,
|
||||
)
|
||||
|
||||
self.mlp = GatedMLP(
|
||||
@ -643,6 +647,7 @@ class LlamaDecoderLayer(DecoderLayer):
|
||||
dtype=config.torch_dtype,
|
||||
config=model_config,
|
||||
layer_idx=layer_idx,
|
||||
use_custom_cublas_mm=use_custom_cublas_mm,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
|
||||
eps=config.rms_norm_eps,
|
||||
@ -889,6 +894,8 @@ class LlamaModel(DecoderModel):
|
||||
config = self.model_config.pretrained_config
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
|
||||
self.use_custom_cublas_mm = get_sm_version() == 121
|
||||
|
||||
vocab_size = config.vocab_size
|
||||
# TODO smor- we load manually only if there is a single lora dir, need to come up with a better solution
|
||||
self.has_custom_embed_tokens = False
|
||||
@ -909,6 +916,7 @@ class LlamaModel(DecoderModel):
|
||||
vocab_size,
|
||||
config.hidden_size,
|
||||
dtype=config.torch_dtype,
|
||||
use_custom_cublas_mm=self.use_custom_cublas_mm,
|
||||
)
|
||||
else:
|
||||
self.embed_tokens = Embedding(
|
||||
@ -918,6 +926,7 @@ class LlamaModel(DecoderModel):
|
||||
mapping=model_config.mapping,
|
||||
tensor_parallel_mode=TensorParallelMode.COLUMN,
|
||||
gather_output=True,
|
||||
use_custom_cublas_mm=self.use_custom_cublas_mm,
|
||||
)
|
||||
|
||||
if self.has_custom_embed_tokens:
|
||||
@ -932,7 +941,8 @@ class LlamaModel(DecoderModel):
|
||||
self.embed_tokens.weight.data.copy_(x)
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
LlamaDecoderLayer(model_config, layer_idx)
|
||||
LlamaDecoderLayer(model_config, layer_idx,
|
||||
self.use_custom_cublas_mm)
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(hidden_size=config.hidden_size,
|
||||
|
||||
@ -32,6 +32,7 @@ class GatedMLP(nn.Module):
|
||||
layer_idx: Optional[int] = None,
|
||||
use_cute_dsl_blockscaling_mm: bool = False,
|
||||
disable_deep_gemm: bool = False,
|
||||
use_custom_cublas_mm: bool = False,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
@ -83,6 +84,7 @@ class GatedMLP(nn.Module):
|
||||
use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm,
|
||||
disable_deep_gemm=disable_deep_gemm,
|
||||
fused_weight_shard_indices_mapping=gateup_shard_indices_mapping,
|
||||
use_custom_cublas_mm=use_custom_cublas_mm,
|
||||
)
|
||||
|
||||
self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
|
||||
@ -103,6 +105,7 @@ class GatedMLP(nn.Module):
|
||||
force_dynamic_quantization=config.force_dynamic_quantization,
|
||||
use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm,
|
||||
disable_deep_gemm=disable_deep_gemm,
|
||||
use_custom_cublas_mm=use_custom_cublas_mm,
|
||||
)
|
||||
|
||||
# These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user