[TRTLLM-10358][feat] Added proper rescaling of FP4 weights (#10378)

Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
This commit is contained in:
Grzegorz Kwasniewski 2026-01-03 22:26:16 +01:00 committed by GitHub
parent c0b3c2b919
commit 0d1f5ad7a2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -304,6 +304,16 @@ def is_any_lin_op(node: Node) -> bool:
return is_linear_op(node) or is_fake_quantized_linear_op(node)
def is_fp4_op(node: Node) -> bool:
return is_op(
node,
[
torch.ops.auto_deploy.torch_quant_nvfp4_linear,
torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear,
],
)
def is_any_moe_op(node: Node) -> bool:
return is_op(
node,
@ -733,16 +743,20 @@ def subgraph(
return subgraph_nodes
def get_weight_shape(
node: Node, dim: Optional[int] = None
) -> Optional[Union[int, Tuple[int, ...]]]:
def get_weight_shape(node: Node, dim: Optional[int] = None) -> Optional[Union[int, List[int]]]:
"""Get the shape of the weight node."""
if not is_any_lin_op(node):
return None
s = list(shape(extract_weight_node(node)))
if len(s) == 0:
return None
if is_fp4_op(node):
# FP4 weights are packed as uint8 type with 2 FP4 values per element
s[-1] *= 2
if dim is None:
return shape(extract_weight_node(node))
return s
else:
return shape(extract_weight_node(node))[dim]
return s[dim]
def get_layer_after_linear_node(