TensorRT-LLM/tests/functional/torch_ref.py
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
from einops import rearrange


def geglu(x):
    """GeGLU activation: split the last dimension in half and gate the first
    half with GELU applied to the second half."""
    a, b = x.chunk(2, dim=-1)
    return a * torch.nn.functional.gelu(b)


def swiglu(x):
    """SwiGLU activation: split the last dimension in half and gate the first
    half with SiLU applied to the second half."""
    x, gate = x.chunk(2, dim=-1)
    return torch.nn.functional.silu(gate) * x
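

# Illustrative sketch (not part of the original file): both gated activations
# halve the last dimension, so a feed-forward layer using them must first
# project to twice the intended hidden size. Shapes below are arbitrary
# example values.
def _example_gated_activations():
    x = torch.randn(2, 8, 64)  # (batch, seq, 2 * hidden)
    assert geglu(x).shape == (2, 8, 32)
    assert swiglu(x).shape == (2, 8, 32)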


def generate_qkv(x, Wqkv, nheads, kvpacked=False, qkvpacked=False):
    """
    Arguments:
        x: (batch_size, seqlen, nheads * d)
        Wqkv: nn.Linear(nheads * d, 3 * nheads * d)
        kvpacked: pack k and v into a single (batch_size * seqlen, 2, nheads, d) tensor.
        qkvpacked: pack q, k and v into a single (batch_size * seqlen, 3, nheads, d) tensor.
            Mutually exclusive with kvpacked.
    """
assert not (kvpacked and qkvpacked)
batch_size, seqlen, dim = x.shape
q, k, v = Wqkv(x).chunk(3, dim=-1)
q_unpad = rearrange(q, 'b s (h d) -> (b s) h d', h=nheads)
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen,
step=seqlen,
dtype=torch.int32,
device=q_unpad.device)
max_seqlen_q = seqlen
output_pad_fn = lambda output_unpad: rearrange(
output_unpad, '(b s) h d -> b s h d', b=batch_size)
k_unpad = rearrange(k, 'b s (h d) -> (b s) h d', h=nheads)
v_unpad = rearrange(v, 'b s (h d) -> (b s) h d', h=nheads)
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen,
step=seqlen,
dtype=torch.int32,
device=q_unpad.device)
max_seqlen_k = seqlen
if qkvpacked:
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = rearrange(torch.stack([q, k, v], dim=2),
'b s t (h d) -> b s t h d',
h=nheads)
dqkv_pad_fn = lambda dqkv_unpad: rearrange(
dqkv_unpad, '(b s) t h d -> b s t h d', b=batch_size)
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn,
dqkv_pad_fn)
elif kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
q = rearrange(q, 'b s (h d) -> b s h d', h=nheads)
kv = rearrange(torch.stack([k, v], dim=2),
'b s t (h d) -> b s t h d',
h=nheads)
dq_pad_fn = output_pad_fn
dkv_pad_fn = lambda dkv_unpad: rearrange(
dkv_unpad, '(b s) t h d -> b s t h d', b=batch_size)
return (q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
max_seqlen_k, q, kv, output_pad_fn, dq_pad_fn, dkv_pad_fn)
else:
q, k, v = [
rearrange(z, 'b s (h d) -> b s h d', h=nheads) for z in [q, k, v]
]
dq_pad_fn = output_pad_fn
dk_pad_fn = lambda dk_unpad: rearrange(
dk_unpad, '(b s) h d -> b s h d', b=batch_size)
return (q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, q, k, v, output_pad_fn, dq_pad_fn,
dk_pad_fn)
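

# Illustrative sketch (not part of the original file): typical use of
# generate_qkv with qkvpacked=True, e.g. to prepare inputs for a varlen
# attention kernel. The sizes and the Wqkv projection are placeholder
# assumptions, not values from the original tests.
def _example_generate_qkv():
    batch_size, seqlen, nheads, d = 2, 16, 4, 32
    x = torch.randn(batch_size, seqlen, nheads * d)
    Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d)
    (qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn,
     dqkv_pad_fn) = generate_qkv(x, Wqkv, nheads, qkvpacked=True)
    # qkv_unpad flattens the batch and sequence dimensions; cu_seqlens marks
    # the start offset of every sequence in that flattened token dimension.
    assert qkv_unpad.shape == (batch_size * seqlen, 3, nheads, d)
    assert cu_seqlens.tolist() == [0, 16, 32]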


def attention_ref(q,
k,
v,
causal=False,
bias=None,
upcast=True,
reorder_ops=False):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k: (batch_size, seqlen_k, nheads, head_dim)
v: (batch_size, seqlen_k, nheads, head_dim)
bias: (batch_size, nheads, seqlen_q, seqlen_k)
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
                     without changing the math. This is to estimate the numerical error from operation
                     reordering.
Output:
output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), the softmax attention
            probabilities (no dropout is applied in this reference).
"""
dtype_og = q.dtype
if upcast:
q, k, v = q.float(), k.float(), v.float()
seqlen_q, seqlen_k = q.shape[1], k.shape[1]
d = q.shape[-1]
if not reorder_ops:
scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(d), k)
else:
scores = torch.einsum('bthd,bshd->bhts', q, k / math.sqrt(d))
if bias is not None:
scores = (scores + bias).to(dtype=scores.dtype)
if causal:
causal_mask = torch.triu(
torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device),
1)
scores.masked_fill_(causal_mask, float('-inf'))
attention = torch.softmax(scores, dim=-1)
output = torch.einsum('bhts,bshd->bthd', attention, v)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
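

# Illustrative sketch (not part of the original file): for float32 inputs with
# equal query/key lengths, attention_ref should closely match PyTorch's
# scaled_dot_product_attention (available in PyTorch >= 2.0) once the layouts
# are aligned. Shapes and tolerances are arbitrary example values.
def _example_attention_ref():
    q, k, v = [torch.randn(2, 16, 4, 32) for _ in range(3)]  # (b, s, h, d)
    out, _ = attention_ref(q, k, v, causal=True)
    # scaled_dot_product_attention expects (b, h, s, d), so transpose in and out.
    sdpa = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        is_causal=True).transpose(1, 2)
    assert torch.allclose(out, sdpa, atol=1e-4)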


def attention_kvpacked_ref(q, kv, causal=False, upcast=True, reorder_ops=False):
return attention_ref(q,
kv[:, :, 0],
kv[:, :, 1],
upcast=upcast,
causal=causal,
reorder_ops=reorder_ops)


def attention_qkvpacked_ref(qkv,
causal=False,
bias=None,
upcast=True,
reorder_ops=False):
return attention_ref(qkv[:, :, 0],
qkv[:, :, 1],
qkv[:, :, 2],
upcast=upcast,
causal=causal,
bias=bias,
reorder_ops=reorder_ops)
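

# Illustrative sketch (not part of the original file): the packed variants are
# thin wrappers, so slicing the packed tensor and calling attention_ref
# directly should give identical results. Shapes are arbitrary example values.
def _example_packed_refs():
    qkv = torch.randn(2, 16, 3, 4, 32)  # (b, s, 3, h, d)
    out_packed, _ = attention_qkvpacked_ref(qkv, causal=True)
    out_ref, _ = attention_ref(qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2],
                               causal=True)
    assert torch.equal(out_packed, out_ref)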