[None][feat] spark cublas LUT table for llama-8b-bf16 perf (#9811)

Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
Authored by Faraz on 2025-12-12 22:37:56 -05:00; committed by GitHub
parent e4e09867d1
commit 98d72c7648
4 changed files with 122 additions and 69 deletions
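
The commit moves the hand-tuned cublasLt algorithm tables out of the cuBLAS scaled-MM op into a shared header, cublasScaledMMLut.h, and adds a Spark (SM 121) bf16 table covering the Llama-8B decode GEMM shapes. Each table maps a shape key {mp2, k, n}, where mp2 is m rounded up to the next power of two and clamped to at least 8, to the eight algorithm knobs recorded by cublasTest matmultFind. Below is a minimal sketch of that lookup pattern; next_power_of_two and lookup_tuned_algo are illustrative stand-ins, not names from the TensorRT-LLM source (which uses nextPowerOfTwo and find_special_algo).

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <tuple>
#include <unordered_map>

// Illustrative stand-in for tensorrt_llm's nextPowerOfTwo.
static int32_t next_power_of_two(int32_t v)
{
    int32_t p = 1;
    while (p < v)
    {
        p <<= 1;
    }
    return p;
}

// Same key/value layout as the new header: {mp2, k, n} ->
// {algo, m_tile, m_stages, m_numsK, m_reduction, m_swizzle, m_custom, m_cga}.
struct HashTuple
{
    size_t operator()(std::tuple<int32_t, int32_t, int32_t> const& x) const
    {
        return std::get<0>(x) ^ std::get<1>(x) ^ std::get<2>(x);
    }
};
using AlgoListType = std::unordered_map<std::tuple<int32_t, int32_t, int32_t>, std::array<int, 8>, HashTuple>;

// Returns the tuned entry for a GEMM shape, or nullptr when the shape is not in
// the table and the caller should fall back to the cublasLt heuristic.
static std::array<int, 8> const* lookup_tuned_algo(AlgoListType const& table, int32_t m, int32_t k, int32_t n)
{
    int32_t const mp2 = std::max<int32_t>(next_power_of_two(m), 8);
    auto it = table.find({mp2, k, n});
    return it == table.end() ? nullptr : &it->second;
}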


@@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cublasScaledMMLut.h"
#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/userbuffers/ub_interface.h"
@@ -22,10 +23,8 @@
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/thop/thUtils.h"
#include "userbuffersTensor.h"
#include <array>
#include <cublasLt.h>
#include <torch/extension.h>
#include <unordered_map>
using torch::Tensor;
@@ -39,67 +38,7 @@ namespace
using tensorrt_llm::common::check;
using tensorrt_llm::common::CublasMMWrapper;
struct hash_tuple
{
size_t operator()(std::tuple<int, int, int> const& x) const
{
return std::get<0>(x) ^ std::get<1>(x) ^ std::get<2>(x);
}
};
// got from cublasTest matmultFind
// {mp2, k, n}: {algo, m_tile, m_stages, m_numsK, m_reduction, m_swizzle, m_custom, m_cga}
using AlgoListType = std::unordered_map<std::tuple<int32_t, int32_t, int32_t>, std::array<int, 8>, hash_tuple>;
// bf16*bf16->fp32->bf16
AlgoListType spark_bf16_algo_list = {
// GPT-OSS-20b
//-m201088 -n1 -algo21 -m_tile11 -m_stages20 -m_workmem0 -k2880
{{8, 2880, 201088}, {21, 11, 20, 1, 0, 0, 0, 0}},
//-m32 -n1 -algo14 -m_reduction2 -m_numsK10 -m_workmem1024 -k2880
{{8, 2880, 32}, {14, 0, 0, 10, 2, 0, 0, 0}},
//-m32 -n2048 -algo21 -m_tile11 -m_stages13 -m_reduction1 -m_numsK9 -m_workmem1024
//-k2880
{{2048, 2880, 32}, {21, 11, 13, 9, 1, 0, 0, 0}},
//-m32 -n2175 -algo21 -m_tile11 -m_stages19 -m_reduction1 -m_numsK11
//-m_workmem1024 -k2880
{{4096, 2880, 32}, {21, 11, 19, 11, 1, 0, 0, 0}},
//-m5120 -n1 -algo23 -m_tile11 -m_stages8 -m_reduction1 -m_numsK2
//-m_workmem1024 -k2880
{{8, 2880, 5120}, {23, 11, 8, 2, 1, 0, 0, 0}},
//-m5120 -n2048 -algo21 -m_tile20 -m_stages15 -m_workmem1024 -k2880
{{2048, 2880, 5120}, {21, 20, 15, 1, 0, 0, 0, 0}},
//-m5120 -n2175 -algo21 -m_tile20 -m_stages15 -m_workmem1024 -k2880
{{4096, 2880, 5120}, {21, 20, 15, 1, 0, 0, 0, 0}},
//-m2880 -n1 -algo23 -m_tile11 -m_stages14 -m_reduction1 -m_numsK24 -m_workmem1024 -k4096
{{8, 4096, 2880}, {23, 11, 14, 24, 1, 0, 0, 0}},
//-m2880 -n2048 -ldc2880 -Poutt -ldd2880 -Ps -Pscales -algo21 -m_tile20 -m_stages15 -m_workmem1024 -k4096
{{2048, 4096, 2880}, {21, 20, 15, 1, 0, 0, 0, 0}},
//-m2880 -n2175 -ldc2880 -Poutt -ldd2880 -Ps -Pscales -algo21 -m_tile20 -m_stages15 -m_workmem1024 -k4096
{{4096, 4096, 2880}, {21, 20, 15, 1, 0, 0, 0, 0}},
};
// bf16*bf16->fp32->bf16
AlgoListType bf16_algo_list = {
// Deepseek v3/R1 router gemm
// [-algo66 -m_tile10 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom3 -m_mma0 -m_cga2 -m_scheduling1]
{{8, 7168, 256}, {66, 10, 35, 1, 0, 0, 3, 2}},
{{512, 7168, 256}, {66, 48, 35, 1, 0, 0, 0, 2}},
{{1024, 7168, 256}, {66, 13, 35, 1, 0, 0, 1, 3}},
};
// fp8*fp8->fp32->fp16
AlgoListType fp8_algo_list = {
// Llama-3.1-70B
// [-algo66 -m_tile393 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom5 -m_mma0 -m_cga2 -m_scheduling1]
{{8, 8192, 8192}, {66, 393, 36, 1, 0, 0, 5, 2}},
// [-algo66 -m_tile10 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom1 -m_mma0 -m_cga2 -m_scheduling1]
{{8, 8192, 57344}, {66, 10, 36, 1, 0, 0, 1, 2}},
// Llama-3.3-70B TP4 (this is the default algo on B200. Here we aim to use the same algo on GB200.)
// [-algo66 -m_tile393 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom1 -m_mma0 -m_cga4 -m_scheduling1]
{{8, 8192, 14336}, {66, 393, 36, 1, 0, 1, 1, 4}},
};
using cublas_lut::AlgoListType;
void set_algo_attr(cublasLtMatmulAlgo_t& algo, std::array<int, 8> const& attr_list)
{
@@ -127,17 +66,18 @@ bool find_special_algo(cublasLtMatmulAlgo_t& algo, std::shared_ptr<CublasMMWrapp
cudaDataType_t bType, cudaDataType_t outType)
{
int32_t mp2 = std::max(nextPowerOfTwo(m), 8);
AlgoListType algo_list;
AlgoListType const* algo_list = nullptr;
if ((aType == CUDA_R_16BF || aType == CUDA_R_16F) && (outType == aType || outType == CUDA_R_32F)
&& compType == CUBLAS_COMPUTE_32F)
{
// TODO: remove this after cublas fix the heuristic for Spark
algo_list = tensorrt_llm::common::getSMVersion(/*queryRealSmArch=*/true) == 121 ? spark_bf16_algo_list
: bf16_algo_list;
algo_list = tensorrt_llm::common::getSMVersion(/*queryRealSmArch=*/true) == 121
? &cublas_lut::spark_bf16_algo_list
: &cublas_lut::bf16_algo_list;
}
else if (aType == CUDA_R_8F_E4M3 && compType == CUBLAS_COMPUTE_32F)
{
algo_list = fp8_algo_list;
algo_list = &cublas_lut::fp8_algo_list;
}
else
{
@@ -145,11 +85,12 @@ bool find_special_algo(cublasLtMatmulAlgo_t& algo, std::shared_ptr<CublasMMWrapp
"No special cublasLt algo found for aType=%d, outType=%d, compType=%d\n", aType, outType, compType);
return false;
}
if (auto algo_iter = algo_list.find({mp2, k, n}); algo_iter != algo_list.end())
if (auto algo_iter = algo_list->find({mp2, k, n}); algo_iter != algo_list->end())
{
int const algoID = algo_iter->second[0];
check_cuda_error(cublasLtMatmulAlgoInit(
cublasWrapper->getCublasLtHandle(), compType, scaleType, aType, bType, outType, outType, algoID, &algo));
TLLM_LOG_DEBUG("Found special cublasLt algo for m=%d, k=%d, n=%d\n", m, k, n);
set_algo_attr(algo, algo_iter->second);
}
else

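The body of set_algo_attr is untouched by this change and therefore elided from the hunk above; only the lookup that feeds it was rewritten to hold a pointer into the shared LUT. For orientation, here is a plausible sketch of how such an eight-field entry can be applied through the public cuBLASLt config attributes after cublasLtMatmulAlgoInit has consumed the algo ID (field 0). The per-field attribute mapping and the helper name apply_algo_attrs are assumptions inferred from the LUT comment {algo, m_tile, m_stages, m_numsK, m_reduction, m_swizzle, m_custom, m_cga}, not the actual TensorRT-LLM implementation.

#include <array>
#include <cstdint>
#include <cublasLt.h>

// Sketch only: applies the seven non-algo-ID fields of a LUT entry to an
// already-initialized algo descriptor. Error checking is omitted for brevity,
// and the attribute chosen for each field is an assumption.
inline void apply_algo_attrs(cublasLtMatmulAlgo_t& algo, std::array<int, 8> const& a)
{
    auto set = [&](cublasLtMatmulAlgoConfigAttributes_t attr, int32_t value)
    { cublasLtMatmulAlgoConfigSetAttribute(&algo, attr, &value, sizeof(value)); };

    set(CUBLASLT_ALGO_CONFIG_TILE_ID, a[1]);          // m_tile
    set(CUBLASLT_ALGO_CONFIG_STAGES_ID, a[2]);        // m_stages
    set(CUBLASLT_ALGO_CONFIG_SPLITK_NUM, a[3]);       // m_numsK
    set(CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, a[4]); // m_reduction
    set(CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, a[5]);    // m_swizzle
    set(CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, a[6]);    // m_custom
    // Requires a cuBLASLt version that exposes cluster (CGA) shape configuration.
    set(CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID, a[7]); // m_cga
}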

@@ -0,0 +1,99 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <array>
#include <cstddef>
#include <cstdint>
#include <tuple>
#include <unordered_map>
namespace torch_ext
{
namespace cublas_lut
{
struct HashTuple
{
size_t operator()(std::tuple<int32_t, int32_t, int32_t> const& x) const
{
return std::get<0>(x) ^ std::get<1>(x) ^ std::get<2>(x);
}
};
// {mp2, k, n}: {algo, m_tile, m_stages, m_numsK, m_reduction, m_swizzle, m_custom, m_cga}
using AlgoListType = std::unordered_map<std::tuple<int32_t, int32_t, int32_t>, std::array<int, 8>, HashTuple>;
inline const AlgoListType spark_bf16_algo_list = {
// llama 8b instruct fp16 decode
// [-algo67 -m_tile6 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom130 -m_mma0 -m_cga2 -m_scheduling1]
{{8, 4096, 4096}, {67, 6, 35, 1, 0, 0, 130, 2}},
// [-algo67 -m_tile393 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom142 -m_mma0 -m_cga2 -m_scheduling1]
{{8, 4096, 6144}, {67, 393, 35, 1, 0, 0, 142, 2}},
// [-algo67 -m_tile393 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom142 -m_mma0 -m_cga2 -m_scheduling1]
{{8, 4096, 128256}, {67, 393, 35, 1, 0, 0, 142, 2}},
// gpt-oss mxfp4-fp16 decode
// [-algo67 -m_tile393 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom142 -m_mma0 -m_cga2 -m_scheduling1]
{{8, 2880, 201088}, {67, 393, 35, 1, 0, 0, 142, 2}},
// [-algo14 -m_tile0 -m_stages35 -m_numsK10 -m_reduction2 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
{{8, 2880, 32}, {14, 0, 0, 10, 2, 0, 0, 0}},
// [-algo21 -m_tile11 -m_stages13 -m_numsK9 -m_reduction1 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
//-k2880
{{2048, 2880, 32}, {21, 11, 13, 9, 1, 0, 0, 0}},
// [-algo21 -m_tile11 -m_stages19 -m_numsK11 -m_reduction1 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
//-m_workmem1024 -k2880
{{4096, 2880, 32}, {21, 11, 19, 11, 1, 0, 0, 0}},
// [-algo23 -m_tile11 -m_stages8 -m_numsK2 -m_reduction1 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
//-m_workmem1024 -k2880
{{8, 2880, 5120}, {23, 11, 8, 2, 1, 0, 0, 0}},
// [-algo21 -m_tile20 -m_stages15 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
{{2048, 2880, 5120}, {21, 20, 15, 1, 0, 0, 0, 0}},
// [-algo21 -m_tile20 -m_stages15 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
{{4096, 2880, 5120}, {21, 20, 15, 1, 0, 0, 0, 0}},
// [-algo23 -m_tile11 -m_stages14 -m_numsK24 -m_reduction1 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
{{8, 4096, 2880}, {23, 11, 14, 24, 1, 0, 0, 0}},
// [-algo21 -m_tile20 -m_stages15 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
{{2048, 4096, 2880}, {21, 20, 15, 1, 0, 0, 0, 0}},
// [-algo21 -m_tile20 -m_stages15 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom0 -m_mma0 -m_cga0 -m_scheduling1]
{{4096, 4096, 2880}, {21, 20, 15, 1, 0, 0, 0, 0}},
};
// bf16*bf16->fp32->bf16
inline const AlgoListType bf16_algo_list = {
// Deepseek v3/R1 router gemm
// [-algo66 -m_tile10 -m_stages35 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom3 -m_mma0 -m_cga2 -m_scheduling1]
{{8, 7168, 256}, {66, 10, 35, 1, 0, 0, 3, 2}},
{{512, 7168, 256}, {66, 48, 35, 1, 0, 0, 0, 2}},
{{1024, 7168, 256}, {66, 13, 35, 1, 0, 0, 1, 3}},
};
// fp8*fp8->fp32->fp16
inline const AlgoListType fp8_algo_list = {
// Llama-3.1-70B
// [-algo66 -m_tile393 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom5 -m_mma0 -m_cga2 -m_scheduling1]
{{8, 8192, 8192}, {66, 393, 36, 1, 0, 0, 5, 2}},
// [-algo66 -m_tile10 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom1 -m_mma0 -m_cga2 -m_scheduling1]
{{8, 8192, 57344}, {66, 10, 36, 1, 0, 0, 1, 2}},
// Llama-3.3-70B TP4 (this is the default algo on B200. Here we aim to use the same algo on GB200.)
// [-algo66 -m_tile393 -m_stages36 -m_numsK1 -m_reduction0 -m_swizzle0 -m_custom1 -m_mma0 -m_cga4 -m_scheduling1]
{{8, 8192, 14336}, {66, 393, 36, 1, 0, 1, 1, 4}},
};
} // namespace cublas_lut
} // namespace torch_ext
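
The three {8, 4096, n} entries at the top of the Spark table are the Llama-8B bf16 decode GEMMs the commit title refers to: with only a handful of tokens in flight, mp2 clamps to 8, k is the hidden size 4096, and n = 4096, 6144 and 128256 line up with the attention output projection, the fused QKV projection and the lm_head of Llama-3.1-8B, which is presumably what the "llama 8b instruct" comment means (the exact layer mapping is not spelled out in the diff). A minimal usage sketch, assuming cublasScaledMMLut.h is on the include path:

#include "cublasScaledMMLut.h"

#include <cstdio>

// Resolve a decode-time GEMM with m = 1 token, k = 4096, n = 4096 against the
// Spark table: mp2 = max(nextPowerOfTwo(1), 8) = 8, so the key is {8, 4096, 4096}
// and the entry selects algo 67 with tile 6, 35 stages, no split-K, custom
// option 130 and a CGA size of 2.
int main()
{
    using torch_ext::cublas_lut::spark_bf16_algo_list;
    auto it = spark_bf16_algo_list.find({8, 4096, 4096});
    if (it != spark_bf16_algo_list.end())
    {
        auto const& a = it->second;
        std::printf("algo=%d tile=%d stages=%d numsK=%d reduction=%d swizzle=%d custom=%d cga=%d\n", a[0], a[1],
            a[2], a[3], a[4], a[5], a[6], a[7]);
    }
    return 0;
}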


@@ -230,6 +230,7 @@ class LlamaAttention(Attention):
self,
model_config: ModelConfig[LlamaConfig],
layer_idx: Optional[int] = None,
use_custom_cublas_mm: bool = False,
):
config = model_config.pretrained_config
super().__init__(
@@ -245,6 +246,7 @@ class LlamaAttention(Attention):
layer_idx=layer_idx,
dtype=config.torch_dtype,
config=model_config,
use_custom_cublas_mm=use_custom_cublas_mm,
)
@@ -618,6 +620,7 @@ class LlamaDecoderLayer(DecoderLayer):
self,
model_config: ModelConfig[LlamaConfig],
layer_idx: int,
use_custom_cublas_mm: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
super().__init__()
config = model_config.pretrained_config
@@ -634,6 +637,7 @@ class LlamaDecoderLayer(DecoderLayer):
self.self_attn = LlamaAttention(
model_config,
layer_idx=layer_idx,
use_custom_cublas_mm=use_custom_cublas_mm,
)
self.mlp = GatedMLP(
@@ -643,6 +647,7 @@ class LlamaDecoderLayer(DecoderLayer):
dtype=config.torch_dtype,
config=model_config,
layer_idx=layer_idx,
use_custom_cublas_mm=use_custom_cublas_mm,
)
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
@@ -889,6 +894,8 @@ class LlamaModel(DecoderModel):
config = self.model_config.pretrained_config
self.num_hidden_layers = config.num_hidden_layers
self.use_custom_cublas_mm = get_sm_version() == 121
vocab_size = config.vocab_size
# TODO smor- we load manually only if there is a single lora dir, need to come up with a better solution
self.has_custom_embed_tokens = False
@@ -909,6 +916,7 @@ class LlamaModel(DecoderModel):
vocab_size,
config.hidden_size,
dtype=config.torch_dtype,
use_custom_cublas_mm=self.use_custom_cublas_mm,
)
else:
self.embed_tokens = Embedding(
@@ -918,6 +926,7 @@ class LlamaModel(DecoderModel):
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
use_custom_cublas_mm=self.use_custom_cublas_mm,
)
if self.has_custom_embed_tokens:
@@ -932,7 +941,8 @@ class LlamaModel(DecoderModel):
self.embed_tokens.weight.data.copy_(x)
self.layers = nn.ModuleList([
LlamaDecoderLayer(model_config, layer_idx)
LlamaDecoderLayer(model_config, layer_idx,
self.use_custom_cublas_mm)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(hidden_size=config.hidden_size,


@@ -32,6 +32,7 @@ class GatedMLP(nn.Module):
layer_idx: Optional[int] = None,
use_cute_dsl_blockscaling_mm: bool = False,
disable_deep_gemm: bool = False,
use_custom_cublas_mm: bool = False,
):
super().__init__()
@@ -83,6 +84,7 @@ class GatedMLP(nn.Module):
use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm,
disable_deep_gemm=disable_deep_gemm,
fused_weight_shard_indices_mapping=gateup_shard_indices_mapping,
use_custom_cublas_mm=use_custom_cublas_mm,
)
self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
@@ -103,6 +105,7 @@ class GatedMLP(nn.Module):
force_dynamic_quantization=config.force_dynamic_quantization,
use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm,
disable_deep_gemm=disable_deep_gemm,
use_custom_cublas_mm=use_custom_cublas_mm,
)
# These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used,