# TensorRT-LLMs/tensorrt_llm/tools/plugin_gen/shape_infer.py

from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from lark import Lark, Token, Tree
if TYPE_CHECKING:
from tensorrt_llm.tools.plugin_gen.core import Argument
parser = Lark(r"""
value: SIGNED_NUMBER
| name
| expr
| "(" expr ")"
expr: value "+" value -> add
| value "-" value -> sub
| value "*" value -> mul
| value "/" value -> div
| value "///" value -> cdiv
| value
shaped_tensor: name "[" value ("," value)* ("," "*")? "]" -> tensor
| name "[" "*" "]" -> wildcard_tensor
tensors: shaped_tensor ("," shaped_tensor)*
deduce_shape: tensors "->" tensors
deduce_dim_size_arg: tensors ":" expr "->" name
name: CNAME
?start: deduce_shape | deduce_dim_size_arg
%import common.SIGNED_NUMBER
%import common.WS
%import common.CNAME
%ignore WS
""".strip())
class TargetType(Enum):
CONCRETE = 0 # to produce size_t
SYMBOLIC = 1 # to produce IDimensionExpr
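# For instance, a NumberAST(4) lowers to "4" when CONCRETE and to
# "exprBuilder.constant(4)" when SYMBOLIC (see CppCodeTranspiler.emit_expr).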
# Here we introduce a set of AST nodes to represent the target expression.
# The Tree/Token nodes produced by lark are not convenient to work with
# directly.
class _AST:
pass
@dataclass
class NumberAST(_AST):
value: int
target_type: TargetType = TargetType.CONCRETE
@dataclass
class BinaryAST(_AST):
op: str
left: _AST
right: _AST
target_type: TargetType = TargetType.CONCRETE
@dataclass
class ShapeAST:
dims: List[_AST]
@dataclass
class DimAST(_AST):
name: str
@dataclass
class ShapedTensorAST(_AST):
arg_name: str
shape: ShapeAST
@dataclass
class DeduceShapeRule(_AST):
left: List[ShapedTensorAST]
right: List[ShapedTensorAST]
@dataclass
class DeduceDimSizeArgRule(_AST):
left: List[ShapedTensorAST]
expr: _AST
right: str
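# For example, the dimension expression `m * 2` is represented (roughly) as
#   BinaryAST("*", DimAST("m"), NumberAST(2))
# with the target_type of the nodes chosen by the rule being parsed.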
class ToAst:
def __call__(self,
tree: Tree) -> Union[DeduceShapeRule, DeduceDimSizeArgRule]:
if tree.data == "deduce_shape":
assert len(tree.children) == 2
return self.visit_DeduceShape(tree.children[0], tree.children[1])
elif tree.data == "deduce_dim_size_arg":
assert len(tree.children) == 3
return self.visit_DeduceDimSizeArg(tree.children[0],
tree.children[1],
tree.children[2])
raise NotImplementedError()
def visit_DeduceShape(self, left: Tree, right: Tree) -> DeduceShapeRule:
assert left.data == "tensors"
assert right.data == "tensors"
lefts = self.visit_tensors(left, TargetType.SYMBOLIC)
rights = self.visit_tensors(right, TargetType.SYMBOLIC)
return DeduceShapeRule(lefts, rights)
def visit_DeduceDimSizeArg(self, left: Tree, expr: Tree,
right: Tree) -> DeduceDimSizeArgRule:
lefts = self.visit_tensors(left, TargetType.CONCRETE)
_expr = self.visit_expr(expr, TargetType.CONCRETE)
rights = self.visit_name(right)
return DeduceDimSizeArgRule(lefts, _expr, rights)
def visit_tensors(self, tree: Tree,
target_type: TargetType) -> List[ShapedTensorAST]:
assert tree.data == "tensors", repr(tree)
return [
self.visit_tensor(child, target_type) for child in tree.children
]
def visit_tensor(self, tree: Tree,
target_type: TargetType) -> ShapedTensorAST:
if tree.data == "tensor":
arg_name = self.visit_name(tree.children[0])
dims = [
self.visit_expr(child, target_type)
for child in tree.children[1:]
]
return ShapedTensorAST(arg_name, ShapeAST(dims))
assert tree.data == "wildcard_tensor", repr(tree)
arg_name = self.visit_name(tree.children[0])
return ShapedTensorAST(arg_name, ShapeAST([DimAST("*")]))
def visit_number(self, v: str) -> _AST:
return NumberAST(int(v))
def visit_expr(self, tree: Tree, target_type: TargetType) -> _AST:
'''
        Visit an expression over dims, e.g. `m * 2 + 1`.
'''
def visit(tree: Union[Tree, Token]) -> _AST:
if isinstance(tree, Token):
if tree.type == "SIGNED_NUMBER":
return NumberAST(int(tree.value), target_type)
elif tree.type == "CNAME":
return DimAST(tree.value)
raise ValueError("Unexpected token: %s" % tree)
elif isinstance(tree.data, Token): # RULE; CNAME
tree_type = tree.data.value
if tree_type == 'name':
return DimAST(tree.children[0].value)
elif tree_type == 'value':
return visit(tree.children[0])
elif tree_type == 'expr':
return visit(tree.children[0])
elif tree.data == "SIGNED_NUMBER":
                    return NumberAST(int(tree.children[0].value))
else:
raise ValueError(f"Unexpected tree: {repr(tree)}")
# (add, sub, mul) have operator overloading for IDimensionExpr
# no need to do anything special
elif tree.data == "add":
assert len(tree.children) == 2
return BinaryAST("+", visit(tree.children[0]),
visit(tree.children[1]))
elif tree.data == "sub":
assert len(tree.children) == 2
return BinaryAST("-", visit(tree.children[0]),
visit(tree.children[1]))
elif tree.data == "mul":
assert len(tree.children) == 2
return BinaryAST("*", visit(tree.children[0]),
visit(tree.children[1]))
elif tree.data == "div":
assert len(tree.children) == 2
return BinaryAST("/", visit(tree.children[0]),
visit(tree.children[1]), target_type)
elif tree.data == "cdiv":
assert len(tree.children) == 2
return BinaryAST("///", visit(tree.children[0]),
visit(tree.children[1]), target_type)
else:
raise ValueError(f"Unexpected tree: {repr(tree)}")
return visit(tree)
def visit_name(self, tree: Tree) -> str:
assert isinstance(tree.data, Token) and tree.data.value == "name"
return tree.children[0].value
@dataclass
class Dim:
arg: "Argument"
dim_off: int
@dataclass
class CppCodeTranspiler:
    # The mapping from an arg_name in the expression to the corresponding Argument.
name_to_arg: Dict[str, "Argument"]
# The mapping from a dim_name in the expression to the corresponding Dim in an Argument.
name_to_dim: Dict[str, Dim] = field(default_factory=dict, init=False)
def __call__(self, exprs: List[str]) -> Tuple[List[str], Dict[str, str]]:
asts = [self.to_ast(expr) for expr in exprs]
return self.codegen(asts)
def to_ast(self, expr: str) -> _AST:
self.cur_expr = expr
ast = parser.parse(expr)
ast = ToAst()(ast)
return ast
def codegen(self, asts: List[_AST]) -> Tuple[List[str], Dict[str, str]]:
'''
Parse an expression group and generate the corresponding C++ code.
        Each expression takes one of the following forms:
- `name[expr, expr, ...] -> name[expr, expr, ...]`
- `name[expr, expr, ...]:expr -> dim_arg`
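        For example (with illustrative tensor names and offsets, assuming A
        and B map to input offsets 0 and 1 and C to output offset 0), the rule
        `A[m, k], B[k, n] -> C[m, n]` generates C++ along the lines of:
            if (outputIndex == 0) {
              outputDims.nbDims = 2;
              outputDims.d[0] = inputDims[0].d[0];
              outputDims.d[1] = inputDims[1].d[1];
            }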
'''
shape_infer_code = []
dim_size_infer_code = {}
for ast in asts:
if isinstance(ast, DeduceShapeRule):
self.dim_cpp_repr = lambda arg_idx, dim_idx: f"inputDims[{arg_idx}].d[{dim_idx}]"
shape_infer_code.extend(self.emit_DeduceShapeRule(ast))
elif isinstance(ast, DeduceDimSizeArgRule):
self.dim_cpp_repr = lambda arg_idx, dim_idx: f"inputDesc[{arg_idx}].dims.d[{dim_idx}]"
dim_size_infer_code[ast.right] = self.emit_DeduceDimSizeArgRule(
ast)
else:
raise ValueError("Unexpected ast: %s" % repr(ast))
return shape_infer_code, dim_size_infer_code
@staticmethod
def is_cur_identical_dims(item: ShapedTensorAST):
return len(item.shape.dims) == 1 and isinstance(
item.shape.dims[0], DimAST) and item.shape.dims[0].name == "*"
def collect_dims_from_left(self, lefts: List[ShapedTensorAST]):
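        '''
        Record, for each named dimension on the left-hand side, which Argument
        and dimension offset it refers to. For example (illustrative names),
        `A[m, k], B[k, n]` records m -> (A, 0), k -> (B, 0) and n -> (B, 1);
        when the same name appears in several tensors, the last occurrence
        wins. Returns whether the first left-hand tensor uses the wildcard
        X[*] form.
        '''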
self.name_to_dim.clear()
is_left_identical_dims = self.is_cur_identical_dims(lefts[0])
# process left, and record the named dimensions
for left in lefts:
arg_name = left.arg_name
argument = self.name_to_arg[arg_name]
for off, dim in enumerate(left.shape.dims):
                assert isinstance(
                    dim, DimAST
                ), f"Wrong syntax in '{self.cur_expr}': on the left-hand side, each dimension must be a plain name, not an expression"
self.name_to_dim[dim.name] = Dim(argument, off)
return is_left_identical_dims
def emit_DeduceShapeRule(self, rule: DeduceShapeRule) -> List[str]:
from tensorrt_llm.tools.plugin_gen.core import code
is_left_identical_dims = self.collect_dims_from_left(rule.left)
first_left_tensor = rule.left[0]
first_left_tensor_arg = self.name_to_arg[first_left_tensor.arg_name]
ret = []
# process right, and generate the code for each dimensions
# TODO: support more wildcard cases, currently only A[*] -> B[*], C[*] is supported
is_right_identical_dims = False
for off, item in enumerate(rule.right):
is_cur_identical_dims = self.is_cur_identical_dims(item)
            assert not (is_right_identical_dims and not is_cur_identical_dims), (
                "Wrong syntax in '%s': in a deduce_shape rule, once one output "
                "tensor uses the X[*] form, all output tensors must use it" %
                self.cur_expr)
is_right_identical_dims = is_cur_identical_dims
        assert is_left_identical_dims == is_right_identical_dims, (
            "Wrong syntax in '%s': in a deduce_shape rule, the left and right "
            "sides must either both use the X[*] form or neither" %
            self.cur_expr)
for off, tensor in enumerate(rule.right):
out_arg = self.name_to_arg[tensor.arg_name]
ret.append(code(f"if (outputIndex == {out_arg.offset}) {{"))
if is_right_identical_dims:
ret.append(
code(
f" outputDims = inputDims[{first_left_tensor_arg.offset}];"
))
else:
ret.append(
code(f" outputDims.nbDims = {len(tensor.shape.dims)};"))
for dim_off, dim in enumerate(tensor.shape.dims):
ret.append(
code(
f" outputDims.d[{dim_off}] = {self.emit_expr(dim)};"
))
ret.append(code(f"}}"))
return ret
def emit_DeduceDimSizeArgRule(self, rule: DeduceDimSizeArgRule) -> str:
self.collect_dims_from_left(rule.left)
return self.emit_expr(rule.expr)
def emit_expr(self, expr: _AST) -> str:
if isinstance(expr, NumberAST):
if expr.target_type == TargetType.SYMBOLIC:
return f"exprBuilder.constant({expr.value})"
else:
return str(expr.value)
elif isinstance(expr, DimAST):
return self.emit_dim(expr)
elif isinstance(expr, BinaryAST):
return self.emit_binary(expr)
raise ValueError("Unexpected expr: %s" % expr)
    def emit_dim(self, dim: DimAST) -> str:
        dim_: Dim = self.name_to_dim[dim.name]
        return self.dim_cpp_repr(dim_.arg.offset, dim_.dim_off)
def emit_binary(self, binary: BinaryAST) -> str:
left = self.emit_expr(binary.left)
right = self.emit_expr(binary.right)
if binary.op == "/" and binary.target_type == TargetType.SYMBOLIC:
return f"exprBuilder.operation(nvinfer1::DimensionOperation::kFLOOR_DIV, *{left}, *{right})"
elif binary.op == "///":
if binary.target_type == TargetType.SYMBOLIC:
return f"exprBuilder.operation(nvinfer1::DimensionOperation::kCEIL_DIV, *{left}, *{right})"
else:
return f"(({left} + {right} - 1) / {right})"
else:
return f"({left} {binary.op} {right})"