# mirror of https://github.com/NVIDIA/TensorRT-LLM.git
# synced 2026-01-14 06:27:45 +08:00
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
|
|
|
|
|
|
def save_raw(data: np.ndarray, filename: str):
    """Dump the raw bytes of *data* (C-contiguous order) to *filename*."""
    raw = data.tobytes()
    with open(filename, 'wb') as out:
        out.write(raw)
|
# Problem configuration. NOTE(review): these presumably must match whatever
# C++ test produced data.bin / ref_cpp.bin — verify against the producer.
headElems = 256  # hidden size per attention head
nbKHeads = 1  # number of K heads (grouped-query attention)
headGrpSize = 32  # query heads sharing one KV head

beamWidth = 1

nbVHeads = nbKHeads  # V heads mirror K heads
nbQHeads = nbKHeads * headGrpSize
# Activations (fused QKV input and output) are fp16; the KV cache is int8.
inputElem = np.float16
inputElemSize = np.dtype(inputElem).itemsize  # unused in this script
cacheElem = np.int8
cacheElemSize = np.dtype(cacheElem).itemsize

batchSize = 1
seqLen = 256

# Dequantization scales: 1 for a 2-byte (fp16) cache, 1/4 for int8.
kScale = 1 if cacheElemSize == 2 else 1 / 4
vScale = kScale
# Softmax scale 1/sqrt(headElems), with K dequantization folded in.
qkScale = (headElems**-0.5) * kScale
|
# Raw dump consumed by this script: the int8 KV-cache region comes first,
# followed by the fp16 I/O region (fused QKV input, then the kernel output).
dataBuf = open("data.bin", 'rb').read()
cache_data = np.frombuffer(dataBuf[0:cacheElemSize * headElems * seqLen *
                                   (nbKHeads + nbVHeads) * batchSize],
                           dtype=cacheElem)
offset = 0

# K cache: one (seqLen x headElems) page per (batch, kv-head).
k_shape = (batchSize, nbKHeads, seqLen, headElems)
offset_next = offset + np.prod(k_shape)
k = np.reshape(cache_data[offset:offset_next], k_shape)
offset = offset_next

# V cache lies directly after K with the same layout.
v_shape = (batchSize, nbKHeads, seqLen, headElems)
offset_next = offset + np.prod(v_shape)
v = np.reshape(cache_data[offset:offset_next], v_shape)
offset = offset_next

# Everything after the cache region is fp16 I/O data; restart the element
# offset for the new dtype.
io_data = np.frombuffer(dataBuf[cacheElemSize * offset:], dtype=inputElem)
offset = 0

# Fused QKV projection: Q heads first, then K, then V, per token.
# (Renamed from `input`, which shadowed the builtin of the same name.)
input_shape = (batchSize, beamWidth, (nbQHeads + nbKHeads + nbVHeads),
               headElems)
offset_next = offset + np.prod(input_shape)
qkv_input = np.reshape(io_data[offset:offset_next], input_shape)
offset = offset_next

# Take the Q part and regroup so the query heads sharing one KV head are
# contiguous: (batch, kvHead, beam, grpSize, headElems), then merge beam
# and grpSize into a single row dimension for the matmul below.
q = np.reshape(qkv_input[:, :, :nbQHeads, :],
               (batchSize, beamWidth, nbKHeads, headGrpSize, headElems))
q = np.transpose(q, axes=[0, 2, 1, 3, 4])
assert q.shape == (batchSize, nbKHeads, beamWidth, headGrpSize, headElems)
q = np.reshape(q, (batchSize, nbKHeads, beamWidth * headGrpSize, headElems))
|
# NumPy reference: softmax(q @ k^T * qkScale) @ v per (batch, kv-head) pair,
# accumulated in fp32 with the exponentials rounded through fp16 to mimic
# the kernel's storage precision.
ref = np.zeros((batchSize, nbKHeads, beamWidth, headGrpSize, headElems),
               dtype=np.float16)
for req in range(batchSize):
    for g in range(nbKHeads):
        # Logits in fp32; K dequantization is folded into qkScale.
        qk = q[req, g].astype(np.float32) @ k[req, g].astype(
            np.float32).T * qkScale
        row_max = np.max(qk, axis=1, keepdims=True)
        # Round the exponentials through fp16, as the kernel stores them.
        qk = np.exp(qk - row_max).astype(np.float16).astype(np.float32)
        row_sum = np.sum(qk, axis=1, keepdims=True)
        qk = qk / row_sum
        # BUG FIX: the original did `np.mat(v)` with the full 4-D V array,
        # which np.matrix rejects ("matrix must be 2-dimensional"); use this
        # pair's 2-D V page. np.mat is also removed in NumPy 2.0, so the
        # matmuls now use the `@` operator on plain ndarrays.
        qkv = (qk @ v[req, g].astype(np.float32)).astype(
            np.float16) * vScale
        ref[req, g] = np.reshape(qkv, (beamWidth, headGrpSize, headElems))
|
# Slice the kernel's output tensor from the fp16 I/O region.
out_shape = (batchSize, beamWidth, nbQHeads, headElems)
out_elems = np.prod(out_shape)
out = np.reshape(io_data[offset:offset + out_elems], out_shape)
offset += out_elems
# The output must consume the remainder of the fp16 region exactly.
assert offset == io_data.shape[0]

# C++ reference result: stored as fp32, rounded to fp16 for comparison.
ref_cpp_raw = np.frombuffer(open("ref_cpp.bin", 'rb').read(),
                            dtype=np.float32)
ref_cpp = np.reshape(ref_cpp_raw, ref.shape).astype(np.float16)
|
def is_close(a, b):
    """True when the largest absolute elementwise difference is below 0.01."""
    worst = np.max(np.abs(a - b))
    return worst < 0.01
|
# Report the worst deviation between the NumPy and C++ references, then
# require them to agree within tolerance.
abs_diff = np.abs(ref - ref_cpp)
print("maxDiff: %f\n" % np.max(abs_diff))
assert is_close(ref, ref_cpp)
|
debug_refcheck = False  # only for batchSize 1 and seqLen 256
if debug_refcheck:
    #tiled to emulate kernel implementation (for no ctaRowMax update)
    # Re-layout Q/K/V into per-CTA tiles and dump every intermediate so it
    # can be diffed against the kernel's buffers. Tile sizes are hard-coded:
    # 32 query rows, 4 K/V tiles of 64 tokens, 8 fragments of 32 head elems.
    # NOTE(review): the *_f16 file names for k/v do not match cacheElem
    # (int8 in this configuration) — confirm the intended cache dtype here.
    q = np.reshape(q, (32, headElems))
    save_raw(np.transpose(np.reshape(q, (32, 8, 32)), axes=[1, 0, 2]),
             'q_8x32x32_f16.bin')
    k = np.reshape(k, (seqLen, headElems))
    save_raw(np.transpose(np.reshape(k, (4, 64, 8, 32)), axes=[2, 0, 1, 3]),
             'k_8x4x64x32_f16.bin')
    # Raw (unscaled) logits, split into 4 tiles of 64 columns each.
    qk = np.mat(q.astype(np.float32)) * np.mat(k.astype(np.float32)).T
    qk_tiles = np.transpose(np.reshape(np.array(qk), (32, 4, 64)),
                            axes=[1, 0, 2])
    assert qk_tiles.shape == (4, 32, 64)
    save_raw(qk_tiles, 'qk_4x32x64_f32.bin')
    # Per-tile row maxima. NOTE(review): the max is taken on the UNSCALED
    # logits while the exponent below scales by qkScale; the per-row
    # constant cancels in the final division by cta_row_sum, but it weakens
    # overflow protection — confirm this matches the kernel.
    tile_row_max = np.max(qk_tiles, axis=2, keepdims=True)
    save_raw(tile_row_max, 'tileRowMax_4x32_f32.bin')
    x = np.exp(qk_tiles * qkScale - tile_row_max).astype(np.float16)
    save_raw(x, 'x_4x32x64_f16.bin')
    tile_row_sum = np.sum(x.astype(np.float32), axis=2, keepdims=True)
    save_raw(tile_row_sum, 'tileRowSum_4x32_f32.bin')

    # Online-softmax accumulation over the 4 tiles: keep a running row max
    # and row sum, rescaling previous partial results whenever the max grows.
    cta_row_max = np.full((32, 1), fill_value=-np.inf)
    cta_row_sum = np.zeros((32, 1))
    acc1 = np.zeros((4, 32, 256), dtype=np.float32)  # first dim is for steps
    v = np.reshape(v, (seqLen, headElems))
    save_raw(np.transpose(np.reshape(v, (8, 32, 4, 64)), axes=[0, 2, 1, 3]),
             'v_8x4x32x64_f16.bin')
    for i in range(4):
        cta_row_max_old = cta_row_max
        cta_row_max = np.maximum(cta_row_max, tile_row_max[i])
        # Rescale this tile's exponentials and partial sums to the new max.
        xScale = np.exp(tile_row_max[i] - cta_row_max)
        x[i] = x[i] * xScale
        tile_row_sum[i] = tile_row_sum[i] * xScale
        # Rescale the accumulator carried over from the previous step.
        acc1Scale = np.exp(cta_row_max_old - cta_row_max)
        acc1[i] = acc1Scale * (acc1[i - 1] if i > 0 else np.zeros(
            (32, 256), dtype=np.float32))
        cta_row_sum = cta_row_sum * acc1Scale + tile_row_sum[i]
        # Accumulate this tile's contribution: x_tile @ v_tile (64 tokens).
        acc1[i] = acc1[i] + np.mat(x[i]).astype(np.float32) * np.mat(
            v[64 * i:(64 * i + 64), :]).astype(np.float32)

    save_raw(acc1, 'acc1PerStep_4x32x256_f32.bin')
    # Final normalization; acc1[3] holds the fully accumulated numerator.
    ref_tiled = (acc1[3] / cta_row_sum * vScale).astype(np.float16)
    save_raw(ref_tiled, 'out_32x256_f16.bin')

    assert is_close(ref_tiled, np.reshape(ref, (32, headElems)))
|
def compute(q, k, v, kvScale, headElems):
    """Return the softmax intermediates of one attention head.

    q, k, v: 2-D arrays of shape (rows, headElems), (seq, headElems),
        (seq, headElems). kvScale: dequantization scale folded into the
        softmax scale. headElems: head size used for the 1/sqrt scaling.

    Returns (x, row_max, row_sum): the un-normalized exponentials
    exp(q @ k.T * qkScale - row_max), the per-row maxima, and the per-row
    sums (both shaped (rows, 1)).
    """
    qkScale = (headElems**-0.5) * kvScale
    qk = q @ k.T * qkScale
    row_max = np.max(qk, axis=1).reshape(-1, 1)
    x = np.exp(qk - row_max)
    row_sum = np.sum(x, axis=1).reshape(-1, 1)
    # BUG FIX: a dead statement `x @ v * (kvScale / row_sum)` computed the
    # attention output and discarded it; removed. Callers can obtain the
    # normalized output as `x @ v * (kvScale / row_sum)` themselves.
    return x, row_max, row_sum