TensorRT-LLM/tensorrt_llm/models/unet/pp/conv2d.py

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
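"""Height-sliced (patch-parallel) Conv2d for the UNet pp modules.

DistriConv2dPP splits the activation along its height dimension across the
tensor-parallel group: each rank convolves only its own slice, and boundary
(halo) rows are exchanged via allgather so the stitched per-rank outputs
match a full convolution.
"""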
from ....functional import allgather, concat, conv2d, slice, stack, unsqueeze
from ....layers import Conv2d
from ....mapping import Mapping
from ....module import Module


class DistriConv2dPP(Module):
    """Conv2d wrapper that convolves one height slice of the input per
    tensor-parallel rank, exchanging halo rows with neighbouring ranks."""

    def __init__(self,
                 conv: Conv2d,
                 mapping: Mapping = Mapping(),
                 is_first_layer: bool = False):
        super().__init__()
        self.mapping = mapping
        self.conv = conv
        self.is_first_layer = is_first_layer
    def sliced_forward(self, x):
        # Convolve only this rank's horizontal slice of the full input,
        # extending the slice by the conv padding on each side so that the
        # per-rank outputs stitch together exactly.
        mapping = self.mapping
        b, c, h, w = x.shape
        assert h % mapping.tp_size == 0

        stride = self.conv.stride[0]
        padding = self.conv.padding[0]

        output_h = x.shape[2] // stride // mapping.tp_size
        idx = mapping.tp_rank
        h_begin = output_h * idx * stride - padding
        h_end = output_h * (idx + 1) * stride + padding
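        # Worked example (illustrative numbers): with h = 64, tp_size = 2,
        # stride = 1, padding = 1, output_h = 32.
        #   rank 0: h_begin = -1 -> clamped to 0 (pre_padding[0] = 1),
        #           h_end = 33 -> reads rows [0, 33)
        #   rank 1: h_begin = 31; h_end = 65 -> clamped to 64
        #           (post_padding[0] = 1) -> reads rows [31, 64)
        # Each rank thus produces 32 output rows from its slice plus a
        # one-row halo overlapping the neighbouring slice.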
        pre_padding = [0, padding]
        post_padding = [0, padding]
        if h_begin < 0:
            # Topmost slice: clamp to the image edge and restore the
            # conv's own zero padding there.
            h_begin = 0
            pre_padding[0] = padding
        if h_end > h:
            # Bottommost slice: clamp to the image edge and restore the
            # conv's own zero padding there.
            h_end = h
            post_padding[0] = padding
        sliced_input = slice(x, [0, 0, h_begin, 0], [b, c, h_end - h_begin, w])
        return conv2d(sliced_input,
                      self.conv.weight.value,
                      None if self.conv.bias is None else self.conv.bias.value,
                      stride=self.conv.stride,
                      padding=(0, 0),
                      pre_padding=tuple(pre_padding),
                      post_padding=tuple(post_padding))
    def forward(self, x, *args, **kwargs):
        mapping = self.mapping
        if self.is_first_layer:
            # The first layer still sees the full input, so each rank
            # slices out its own patch before convolving.
            full_x = x
            output = self.sliced_forward(full_x)
        else:
            boundary_size = self.conv.padding[0]

            def create_padded_x(x, boundaries):
                # Stitch the halo rows gathered from neighbouring ranks
                # onto this rank's patch; edge ranks keep zero padding on
                # their outer border instead.
                preH = 0
                postH = 0
                if mapping.tp_rank == 0:
                    b = boundaries.select(0, mapping.tp_rank + 1).select(0, 0)
                    padded_x = concat([x, b], dim=2)
                    preH = boundary_size
                elif mapping.tp_rank == mapping.tp_size - 1:
                    b = boundaries.select(0, mapping.tp_rank - 1).select(0, 1)
                    padded_x = concat([b, x], dim=2)
                    postH = boundary_size
                else:
                    b0 = boundaries.select(0, mapping.tp_rank - 1).select(0, 1)
                    b1 = boundaries.select(0, mapping.tp_rank + 1).select(0, 0)
                    padded_x = concat([b0, x, b1], dim=2)
                return padded_x, preH, postH

            # Halo exchange: each rank contributes its top and bottom
            # `boundary_size` rows, and allgather makes every rank's
            # boundary rows visible to all ranks.
            n, c, h, w = x.shape
            b0 = slice(x, [0, 0, 0, 0], [n, c, boundary_size, w])
            b1 = slice(x, [0, 0, h - boundary_size, 0],
                       [n, c, boundary_size, w])
            boundary = stack([b0, b1], dim=0)

            boundaries = allgather(unsqueeze(boundary, 0),
                                   group=mapping.tp_group)
            padded_x, preH, postH = create_padded_x(x, boundaries)
            output = conv2d(padded_x,
                            self.conv.weight.value,
                            None if self.conv.bias is None else
                            self.conv.bias.value,
                            stride=self.conv.stride,
                            pre_padding=(preH, self.conv.padding[1]),
                            post_padding=(postH, self.conv.padding[1]))
        return output
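

# --- Usage sketch (illustrative; not part of the upstream file) ---
# A minimal sketch, assuming a 2-way tensor-parallel Mapping and a typical
# UNet input convolution; the Conv2d argument values here are hypothetical
# and `x` is assumed to be a network tensor of shape [b, c, h, w] built
# inside a TensorRT-LLM network context.
#
#   from tensorrt_llm.layers import Conv2d
#   from tensorrt_llm.mapping import Mapping
#
#   mapping = Mapping(world_size=2, tp_size=2)
#   conv = Conv2d(4, 320, kernel_size=(3, 3), padding=(1, 1))
#   pp_conv = DistriConv2dPP(conv, mapping, is_first_layer=True)
#   out = pp_conv(x)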