/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11050)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif

#include "tensorrt_llm/runtime/utils/debugUtils.h"

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"

#include <cfloat>
#include <cstdlib>
#include <cuda/functional>
#include <optional>
#include <sstream>

namespace
{

template <typename T>
__global__ void checkTensorInvalidKernel(T const* data, std::size_t size, int* foundInvalid)
{
    auto tidx = blockIdx.x * blockDim.x + threadIdx.x;
    int32_t found = 0;

    // Grid-stride loop: each thread scans a strided subset of the buffer and
    // stops early once it has seen a single NaN/Inf.
    for (auto idx = tidx; idx < size; idx += blockDim.x * gridDim.x)
    {
        auto value = static_cast<float>(data[idx]);
        if (isnan(value) || isinf(value))
        {
            found = 1;
            break;
        }
    }

    // The block size here must match kThreadsPerCta in the launcher below.
    typedef cub::BlockReduce<int32_t, 256> BlockReduceT;

    // Allocate shared memory for BlockReduce
    __shared__ typename BlockReduceT::TempStorage tempStorage;

    // Compute block-wide maximum
    int blockFound = BlockReduceT(tempStorage).Reduce(found, cuda::maximum<int32_t>());

    // Have thread 0 write out the block's result
    if (threadIdx.x == 0)
    {
        atomicCAS(foundInvalid, 0, blockFound);
    }
}

__global__ void stallStreamKernel(int const microSeconds)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
    // __nanosleep is only available on Volta (sm_70) and newer.
    for (int i = 0; i < microSeconds; ++i)
    {
        __nanosleep(1000);
    }
#endif
}

} // namespace

using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;

namespace tensorrt_llm::runtime::utils
{

template <typename T>
void invokeCheckTensorInvalidKernel(T const* data, std::size_t size, int* foundInvalid, cudaStream_t stream)
{
    constexpr uint32_t kThreadsPerCta = 256;
    checkTensorInvalidKernel<<<tc::ceilDiv(size, kThreadsPerCta), kThreadsPerCta, 0, stream>>>(
        data, size, foundInvalid);
}

template void invokeCheckTensorInvalidKernel<float>(
    float const* data, std::size_t size, int* foundInvalid, cudaStream_t stream);
template void invokeCheckTensorInvalidKernel<half>(
    half const* data, std::size_t size, int* foundInvalid, cudaStream_t stream);
template void invokeCheckTensorInvalidKernel<__nv_bfloat16>(
    __nv_bfloat16 const* data, std::size_t size, int* foundInvalid, cudaStream_t stream);
template void invokeCheckTensorInvalidKernel<__nv_fp8_e4m3>(
    __nv_fp8_e4m3 const* data, std::size_t size, int* foundInvalid, cudaStream_t stream);

template <typename T>
void printLogitsKeyInfo(ITensor const& tensor, std::string const& infoStr)
{
    auto const& shape = tensor.getShape();
    auto const volume = ITensor::volume(shape);

    // If the tensor lives on the GPU, copy it to the host before reading.
    BufferManager::ITensorPtr host{};
    T const* hostData;
    if (tensor.getMemoryType() == MemoryType::kGPU)
    {
        auto streamPtr = std::make_shared<CudaStream>();
        BufferManager manager{streamPtr};
        host = manager.copyFrom(tensor, MemoryType::kCPU);
        streamPtr->synchronize();
        hostData = bufferCast<T>(*host);
    }
    else
    {
        hostData = bufferCast<T>(tensor);
    }

    std::stringstream ss;
    ss << infoStr;
    ss << " Shape: " << shape;
    ss << "; Top 5: ";
    for (size_t ki = 0; ki < 5; ++ki)
    {
        ss << static_cast<float>(hostData[ki]) << ", ";
    }
    ss << " Last 5: ";
    for (size_t ki = volume - 5; ki < volume; ++ki)
    {
        ss << static_cast<float>(hostData[ki]) << ", ";
    }

    // Find max, min, and average over the whole tensor.
    double mSum = 0.0;
    float mMax = -FLT_MAX;
    float mMin = FLT_MAX;
    for (size_t ki = 0; ki < volume; ++ki)
    {
        float value = static_cast<float>(hostData[ki]);
        mSum += value;
        if (value > mMax)
        {
            mMax = value;
        }
        if (value < mMin)
        {
            mMin = value;
        }
    }
    float mAvg = static_cast<float>(mSum / volume);
    ss << " avg: " << mAvg << ", min: " << mMin << ", max: " << mMax << std::endl;

    TLLM_LOG_TRACE(ss.str());
}

template void printLogitsKeyInfo<float>(ITensor const& tensor, std::string const& infoStr);
template void printLogitsKeyInfo<half>(ITensor const& tensor, std::string const& infoStr);
template void printLogitsKeyInfo<__nv_bfloat16>(ITensor const& tensor, std::string const& infoStr);
template void printLogitsKeyInfo<__nv_fp8_e4m3>(ITensor const& tensor, std::string const& infoStr);
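// Usage sketch (illustration only, not part of the library API): how the
// launcher above could be driven directly on a raw device buffer. The managed
// allocation and the variable names (deviceLogits, numElements) are
// assumptions made for this example:
//
//   int* foundInvalid{};
//   TLLM_CUDA_CHECK(cudaMallocManaged(&foundInvalid, sizeof(int)));
//   *foundInvalid = 0; // kernel only flips the flag from 0 to 1
//   invokeCheckTensorInvalidKernel<float>(deviceLogits, numElements, foundInvalid, stream);
//   TLLM_CUDA_CHECK(cudaStreamSynchronize(stream));
//   bool const hasInvalid = (*foundInvalid != 0);
//   TLLM_CUDA_CHECK(cudaFree(foundInvalid));
//
// tensorHasInvalid below wraps the same flow, but uses pinned pool memory for
// the flag instead of a managed allocation.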
template <typename T>
bool tensorHasInvalid(ITensor const& tensor, BufferManager const& manager, std::string const& infoStr)
{
    printLogitsKeyInfo<T>(tensor, infoStr);

    // Pinned host memory is visible to both the kernel and the host, so the
    // flag can be read directly once the stream has synchronized.
    auto foundInvalid = BufferManager::pinnedPool(ITensor::makeShape({1}), nvinfer1::DataType::kINT32);
    auto foundInvalidPtr = bufferCast<int32_t>(*foundInvalid);
    foundInvalidPtr[0] = 0;
    auto const size = tensor.getSize();
    invokeCheckTensorInvalidKernel(bufferCast<T>(tensor), size, foundInvalidPtr, manager.getStream().get());
    manager.getStream().synchronize();
    return static_cast<bool>(foundInvalidPtr[0]);
}

template bool tensorHasInvalid<float>(ITensor const& tensor, BufferManager const& manager, std::string const& infoStr);
template bool tensorHasInvalid<half>(ITensor const& tensor, BufferManager const& manager, std::string const& infoStr);
template bool tensorHasInvalid<__nv_bfloat16>(
    ITensor const& tensor, BufferManager const& manager, std::string const& infoStr);
template bool tensorHasInvalid<__nv_fp8_e4m3>(
    ITensor const& tensor, BufferManager const& manager, std::string const& infoStr);

bool tensorHasInvalid(
    size_t M, size_t K, nvinfer1::DataType type, void const* data, cudaStream_t stream, std::string const& infoStr)
{
    // Wrap the raw device pointer in a non-owning M x K tensor view, then
    // dispatch on the element type.
    auto tensorView = ITensor::wrap(
        const_cast<void*>(data), type, ITensor::makeShape({static_cast<int64_t>(M), static_cast<int64_t>(K)}));
    auto manager = BufferManager(std::make_shared<CudaStream>(stream));
    if (type == nvinfer1::DataType::kFLOAT)
    {
        return tensorHasInvalid<float>(*tensorView, manager, infoStr);
    }
    else if (type == nvinfer1::DataType::kHALF)
    {
        return tensorHasInvalid<half>(*tensorView, manager, infoStr);
    }
    else if (type == nvinfer1::DataType::kBF16)
    {
        return tensorHasInvalid<__nv_bfloat16>(*tensorView, manager, infoStr);
    }
    else if (type == nvinfer1::DataType::kFP8)
    {
        return tensorHasInvalid<__nv_fp8_e4m3>(*tensorView, manager, infoStr);
    }
    else
    {
        TLLM_THROW("Unsupported data type for NaN check");
    }
}

int stallStream(char const* name, std::optional<cudaStream_t> stream, std::optional<int> delay)
{
    int delay_val = 0;
    if (delay)
    {
        delay_val = delay.value();
    }
    else
    {
        // Fall back to the environment variable named by `name`.
        char const* const env = std::getenv(name);
        if (env != nullptr)
        {
            delay_val = std::stoi(env);
        }
    }
    if (stream && delay_val > 0)
    {
        stallStreamKernel<<<1, 32, 0, stream.value()>>>(delay_val);
    }
    return delay_val;
}

} // namespace tensorrt_llm::runtime::utils
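// Usage sketch (illustration only, not code that exists elsewhere in this
// repository): a caller might guard a suspect kernel's output with the
// raw-pointer overload and stall the stream via an environment variable. The
// names deviceOut, M, K, and "TLLM_DEBUG_STALL" are assumptions for the
// example:
//
//   using namespace tensorrt_llm::runtime::utils;
//
//   if (tensorHasInvalid(M, K, nvinfer1::DataType::kHALF, deviceOut, stream, "gemm output"))
//   {
//       TLLM_LOG_ERROR("NaN/Inf detected in gemm output");
//   }
//   // Stall `stream` for a number of microseconds taken from TLLM_DEBUG_STALL
//   // (a no-op when the variable is unset or zero):
//   stallStream("TLLM_DEBUG_STALL", stream, std::nullopt);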