mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[None][chore] Update tinygemm kernel name (#10248)
Signed-off-by: Jonas Li <6110159+longlee0622@users.noreply.github.com>
This commit is contained in:
parent
f4f0fe85e9
commit
ecea71ca7a
@ -61,7 +61,7 @@ void launch_tinygemm2(__nv_bfloat16* gA, __nv_bfloat16* gB, __nv_bfloat16* gC, _
|
||||
int smem_size
|
||||
= STAGES * STAGE_UNROLL * (TILE_M * TILE_K * sizeof(__nv_bfloat16) + TILE_N * TILE_K * sizeof(__nv_bfloat16));
|
||||
|
||||
gpuErrChk(cudaFuncSetAttribute(kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES, STAGE_UNROLL, PROFILE>,
|
||||
gpuErrChk(cudaFuncSetAttribute(tinygemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES, STAGE_UNROLL, PROFILE>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
|
||||
int tiles_m = (output_features + TILE_M - 1) / TILE_M;
|
||||
@ -82,8 +82,8 @@ void launch_tinygemm2(__nv_bfloat16* gA, __nv_bfloat16* gB, __nv_bfloat16* gC, _
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = 1;
|
||||
config.numAttrs = 1;
|
||||
|
||||
cudaLaunchKernelEx(&config, &kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES, STAGE_UNROLL, PROFILE>, gC, gA, gB,
|
||||
bias, output_features, batch_size, input_features, weight_map, activation_map, nullptr);
|
||||
cudaLaunchKernelEx(&config, &tinygemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES, STAGE_UNROLL, PROFILE>,
|
||||
gC, gA, gB, bias, output_features, batch_size, input_features, weight_map, activation_map, nullptr);
|
||||
}
|
||||
|
||||
torch::Tensor tinygemm2_cuda_forward(torch::Tensor input, torch::Tensor weight, torch::Tensor bias)
|
||||
|
||||
@ -172,7 +172,7 @@ struct Profile
|
||||
};
|
||||
|
||||
template <int WARP_TILE_M, int TILE_M, int TILE_N, int TILE_K, int STAGES, int STAGE_UNROLL, bool PROFILE>
|
||||
__global__ __launch_bounds__(384, 1) void kernel(__nv_bfloat16* output, __nv_bfloat16* weights,
|
||||
__global__ __launch_bounds__(384, 1) void tinygemm_kernel(__nv_bfloat16* output, __nv_bfloat16* weights,
|
||||
__nv_bfloat16* activations, __nv_bfloat16* bias, int M, int N, int K,
|
||||
const __grid_constant__ CUtensorMap weight_map, const __grid_constant__ CUtensorMap activation_map,
|
||||
Profile* profile = nullptr)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user