#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

def save_raw(data: np.ndarray, filename: str):
    with open(filename, 'wb') as file:
        file.write(data.tobytes())


# Problem configuration; must match the test that wrote data.bin / ref_cpp.bin.
headElems = 256  # elements per head
nbKHeads = 1  # number of K (i.e. KV) heads
headGrpSize = 32  # query heads per KV head (grouped-query attention)
beamWidth = 1
nbVHeads = nbKHeads
nbQHeads = nbKHeads * headGrpSize
inputElem = np.float16  # dtype of the Q/K/V input and of the output
inputElemSize = np.dtype(inputElem).itemsize
cacheElem = np.int8  # dtype of the KV cache
cacheElemSize = np.dtype(cacheElem).itemsize
batchSize = 1
seqLen = 256
# Dequantization scales: 1 for an fp16 cache, 1/4 for an int8 cache. qkScale
# folds the 1/sqrt(headElems) softmax scaling together with kScale.
kScale = 1 if cacheElemSize == 2 else 1 / 4
vScale = kScale
qkScale = (headElems**-0.5) * kScale
# data.bin layout: the K cache and V cache (cacheElem) come first, followed by
# the fp16 (inputElem) Q/K/V input heads and the kernel output.
dataBuf = open("data.bin", 'rb').read()
cache_data = np.frombuffer(dataBuf[0:cacheElemSize * headElems * seqLen *
                                   (nbKHeads + nbVHeads) * batchSize],
                           dtype=cacheElem)
offset = 0
k_shape = (batchSize, nbKHeads, seqLen, headElems)
offset_next = offset + np.prod(k_shape)
k = np.reshape(cache_data[offset:offset_next], k_shape)
offset = offset_next
v_shape = (batchSize, nbKHeads, seqLen, headElems)
offset_next = offset + np.prod(v_shape)
v = np.reshape(cache_data[offset:offset_next], v_shape)
offset = offset_next
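# Optional sanity check: K and V together should consume the whole cache slice.
assert offset == cache_data.shape[0]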
io_data = np.frombuffer(dataBuf[cacheElemSize * offset:], dtype=inputElem)
offset = 0
input_shape = (batchSize, beamWidth, (nbQHeads + nbKHeads + nbVHeads),
               headElems)
offset_next = offset + np.prod(input_shape)
input = np.reshape(io_data[offset:offset_next], input_shape)
offset = offset_next
q = np.reshape(input[:, :, :nbQHeads, :],
               (batchSize, beamWidth, nbKHeads, headGrpSize, headElems))
q = np.transpose(q, axes=[0, 2, 1, 3, 4])
assert q.shape == (batchSize, nbKHeads, beamWidth, headGrpSize, headElems)
q = np.reshape(q, (batchSize, nbKHeads, beamWidth * headGrpSize, headElems))
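# q is regrouped so that each KV head owns its beamWidth * headGrpSize query
# rows contiguously: (batchSize, nbKHeads, beamWidth * headGrpSize, headElems).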
ref = np.zeros((batchSize, nbKHeads, beamWidth, headGrpSize, headElems),
               dtype=np.float16)
# Per-(request, KV-head) fp32 reference: softmax(q @ k^T * qkScale) @ v, with
# the exp() results rounded through fp16 to mimic the kernel's arithmetic.
for req in range(batchSize):
    for g in range(nbKHeads):
        qk = np.mat(q[req, g]).astype(np.float32) * np.mat(k[req, g]).astype(
            np.float32).T * qkScale
        row_max = np.max(qk, axis=1)
        qk = np.exp(qk - row_max).astype(np.float16).astype(np.float32)
        row_sum = np.sum(qk, axis=1)
        qk = qk / row_sum
        qkv = (qk.astype(np.float32) * np.mat(v[req, g]).astype(
            np.float32)).astype(np.float16) * vScale
        ref[req, g] = np.reshape(np.array(qkv),
                                 (beamWidth, headGrpSize, headElems))
out_shape = (batchSize, beamWidth, nbQHeads, headElems)
offset_next = offset + np.prod(out_shape)
out = np.reshape(io_data[offset:offset_next], out_shape)
offset = offset_next
assert offset == io_data.shape[0]
ref_cpp = np.reshape(
    np.frombuffer(open("ref_cpp.bin", 'rb').read(), dtype=np.float32),
    ref.shape).astype(np.float16)

def is_close(a, b):
    return np.max(np.abs(a - b)) < 0.01

print("maxDiff: %f\n" % np.max(np.abs(ref - ref_cpp)))
assert is_close(ref, ref_cpp)
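# A minimal vectorized sketch of the same math (disabled by default; not part
# of the original harness): the per-head loop above is equivalent to one
# batched softmax-attention computed with einsum.
vectorized_refcheck = False
if vectorized_refcheck:
    qf = q.astype(np.float32)  # (batch, kvHead, beam * grpSize, headElems)
    kf = k.astype(np.float32)  # (batch, kvHead, seqLen, headElems)
    vf = v.astype(np.float32)  # (batch, kvHead, seqLen, headElems)
    s = np.einsum('bhqd,bhsd->bhqs', qf, kf) * qkScale
    s = np.exp(s - np.max(s, axis=-1, keepdims=True))
    s = s / np.sum(s, axis=-1, keepdims=True)
    o = np.einsum('bhqs,bhsd->bhqd', s, vf) * vScale
    # Should agree with `ref` up to the fp16 rounding of the intermediate exp().
    print("vectorized maxDiff: %f" %
          np.max(np.abs(ref.astype(np.float32).reshape(o.shape) - o)))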
debug_refcheck = False # only for batchSize 1 and seqLen 256
if debug_refcheck:
    # tiled to emulate kernel implementation (for no ctaRowMax update)
    q = np.reshape(q, (32, headElems))
    save_raw(np.transpose(np.reshape(q, (32, 8, 32)), axes=[1, 0, 2]),
             'q_8x32x32_f16.bin')
    k = np.reshape(k, (seqLen, headElems))
    save_raw(np.transpose(np.reshape(k, (4, 64, 8, 32)), axes=[2, 0, 1, 3]),
             'k_8x4x64x32_f16.bin')
    qk = np.mat(q.astype(np.float32)) * np.mat(k.astype(np.float32)).T
    qk_tiles = np.transpose(np.reshape(np.array(qk), (32, 4, 64)),
                            axes=[1, 0, 2])
    assert qk_tiles.shape == (4, 32, 64)
    save_raw(qk_tiles, 'qk_4x32x64_f32.bin')
    tile_row_max = np.max(qk_tiles, axis=2, keepdims=True)
    save_raw(tile_row_max, 'tileRowMax_4x32_f32.bin')
    x = np.exp(qk_tiles * qkScale - tile_row_max).astype(np.float16)
    save_raw(x, 'x_4x32x64_f16.bin')
    tile_row_sum = np.sum(x.astype(np.float32), axis=2, keepdims=True)
    save_raw(tile_row_sum, 'tileRowSum_4x32_f32.bin')
    cta_row_max = np.full((32, 1), fill_value=-np.inf)
    cta_row_sum = np.zeros((32, 1))
    acc1 = np.zeros((4, 32, 256), dtype=np.float32)  # first dim is for steps
    v = np.reshape(v, (seqLen, headElems))
    save_raw(np.transpose(np.reshape(v, (8, 32, 4, 64)), axes=[0, 2, 1, 3]),
             'v_8x4x32x64_f16.bin')
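    # The tile loop below follows the standard online-softmax (flash-attention
    # style) recurrence: per 64-column tile i,
    #   m_i = max(m_{i-1}, rowmax_i),
    #   acc_i = acc_{i-1} * exp(m_{i-1} - m_i) + x_i @ v_i,
    #   l_i = l_{i-1} * exp(m_{i-1} - m_i) + rowsum(x_i),
    # with x_i rescaled by exp(rowmax_i - m_i); the output is acc_n / l_n * vScale.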
    for i in range(4):
        cta_row_max_old = cta_row_max
        cta_row_max = np.maximum(cta_row_max, tile_row_max[i])
        xScale = np.exp(tile_row_max[i] - cta_row_max)
        x[i] = x[i] * xScale
        tile_row_sum[i] = tile_row_sum[i] * xScale
        acc1Scale = np.exp(cta_row_max_old - cta_row_max)
        acc1[i] = acc1Scale * (acc1[i - 1] if i > 0 else np.zeros(
            (32, 256), dtype=np.float32))
        cta_row_sum = cta_row_sum * acc1Scale + tile_row_sum[i]
        acc1[i] = acc1[i] + np.mat(x[i]).astype(np.float32) * np.mat(
            v[64 * i:(64 * i + 64), :]).astype(np.float32)
    save_raw(acc1, 'acc1PerStep_4x32x256_f32.bin')
    ref_tiled = (acc1[3] / cta_row_sum * vScale).astype(np.float16)
    save_raw(ref_tiled, 'out_32x256_f16.bin')
    assert is_close(ref_tiled, np.reshape(ref, (32, headElems)))


# Standalone helper mirroring the reference math above for a single
# (request, KV head) slice.
def compute(q, k, v, kvScale, headElems):
    qkScale = (headElems**-0.5) * kvScale
    qk = q @ k.T * qkScale
    row_max = np.max(qk, axis=1).reshape(-1, 1)
    x = np.exp(qk - row_max)
    row_sum = np.sum(x, axis=1).reshape(-1, 1)
    out = x @ v * (kvScale / row_sum)  # normalized attention output
    return out, row_max, row_sum
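

# Minimal usage sketch (illustrative only; with debug_refcheck left False the
# tensors above keep their 4-D shapes): applying compute() to one
# (request, KV head) slice should reproduce `ref` up to fp16 rounding, e.g.
#
#   out, _, _ = compute(q[0, 0].astype(np.float32), k[0, 0].astype(np.float32),
#                       v[0, 0].astype(np.float32), kScale, headElems)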