mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[TRTLLM-10358][feat] Added proper rescaling of FP4 weights (#10378)
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
This commit is contained in:
parent
c0b3c2b919
commit
0d1f5ad7a2
@ -304,6 +304,16 @@ def is_any_lin_op(node: Node) -> bool:
|
||||
return is_linear_op(node) or is_fake_quantized_linear_op(node)
|
||||
|
||||
|
||||
def is_fp4_op(node: Node) -> bool:
    """Return whether ``node`` is one of the auto_deploy NVFP4 linear ops.

    The check is delegated to ``is_op`` and covers both the quantized and the
    fake-quantized NVFP4 linear custom ops.
    """
    auto_deploy_ops = torch.ops.auto_deploy
    nvfp4_linear_ops = [
        auto_deploy_ops.torch_quant_nvfp4_linear,
        auto_deploy_ops.torch_fake_quant_nvfp4_linear,
    ]
    return is_op(node, nvfp4_linear_ops)
|
||||
|
||||
|
||||
def is_any_moe_op(node: Node) -> bool:
|
||||
return is_op(
|
||||
node,
|
||||
@ -733,16 +743,20 @@ def subgraph(
|
||||
return subgraph_nodes
|
||||
|
||||
|
||||
def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[int, List[int]]]:
    """Return the weight shape of a linear node, or a single entry of it.

    Args:
        node: Graph node to inspect.
        dim: Optional index into the weight shape. When given, only that
            entry is returned instead of the whole shape.

    Returns:
        ``None`` when ``node`` is not a linear op or its weight shape is
        empty. Otherwise the shape as a list when ``dim`` is ``None``, or the
        entry at index ``dim``.
    """
    if not is_any_lin_op(node):
        return None

    weight_dims = list(shape(extract_weight_node(node)))
    if not weight_dims:
        return None

    if is_fp4_op(node):
        # FP4 weights are stored packed as uint8 with two FP4 values per
        # element, so the last dimension is doubled to undo the packing.
        weight_dims[-1] *= 2

    return weight_dims if dim is None else weight_dims[dim]
|
||||
|
||||
|
||||
def get_layer_after_linear_node(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user