[None][fix] Fix Piecewise Cuda Graph for GPTOSS (#10631)

Signed-off-by: Dongfeng Yu <dongfengy@nvidia.com>
dongfengy 2026-01-16 15:47:34 +08:00 committed by GitHub
parent 0256c7234f
commit 6dfb8d7084

@ -226,6 +226,81 @@ def _register_fake():
        return (input.new_empty(output_shape, dtype=torch.uint8),
                global_scale.new_empty(scale_shape, dtype=torch.uint8))

    @torch.library.register_fake("trtllm::mxfp8_quantize")
    def _(
        input: torch.Tensor,
        swizzled_layout: bool = True,
        alignment: int = 32,
    ):
        SF_VEC_SIZE = 32

        def pad_up(x, m: int):
            return (x + m - 1) // m * m

        m_val = 1
        for d in input.shape[:-1]:
            m_val = m_val * d
        k = input.shape[-1]
        padded_k = pad_up(k, alignment)
        out_shape = list(input.shape)
        out_shape[-1] = padded_k
        # Output tensor: float8_e4m3fn, last dim padded to alignment
        val_mxfp8 = input.new_empty(out_shape, dtype=torch.float8_e4m3fn)
        # Scale tensor: 1D uint8, size depends on swizzled vs linear layout
        cols = padded_k // SF_VEC_SIZE
        if swizzled_layout:
            sf_size = pad_up(m_val, 128) * pad_up(cols, 4)
        else:
            sf_size = m_val * cols
        scale_fp8_sf = input.new_empty((sf_size, ), dtype=torch.uint8)
        return val_mxfp8, scale_fp8_sf
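
The scale-factor sizing above can be sanity-checked in isolation. A minimal, standalone sketch (not part of this diff; the input shape is hypothetical) that reproduces the same arithmetic:

# Illustrative only: mirrors the mxfp8_quantize fake's size math for a
# hypothetical activation of shape (3, 5, 100) with the default alignment.
def pad_up(x, m):
    return (x + m - 1) // m * m

shape, alignment, SF_VEC_SIZE = (3, 5, 100), 32, 32
m_val = 1
for d in shape[:-1]:
    m_val *= d                                  # 3 * 5 = 15 flattened rows
padded_k = pad_up(shape[-1], alignment)         # 100 -> 128
cols = padded_k // SF_VEC_SIZE                  # 128 // 32 = 4 scale columns
print(pad_up(m_val, 128) * pad_up(cols, 4))     # swizzled layout: 128 * 4 = 512
print(m_val * cols)                             # linear layout:   15 * 4 = 60

The swizzled layout pads rows to 128 and scale columns to 4, which is why it reports a larger scale buffer than the linear layout.
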
    @torch.library.register_fake("trtllm::mxe4m3_mxe2m1_block_scale_moe_runner")
    def _(
        routing_logits: Optional[torch.Tensor],
        routing_bias: Optional[torch.Tensor],
        hidden_states: torch.Tensor,
        hidden_states_scale: torch.Tensor,
        gemm1_weights: torch.Tensor,
        gemm1_weights_scale: torch.Tensor,
        gemm1_bias: Optional[torch.Tensor],
        gemm1_alpha: Optional[torch.Tensor],
        gemm1_beta: Optional[torch.Tensor],
        gemm1_clamp_limit: Optional[torch.Tensor],
        gemm2_weights: torch.Tensor,
        gemm2_weights_scale: torch.Tensor,
        gemm2_bias: Optional[torch.Tensor],
        num_experts: int,
        top_k: int,
        n_group: Optional[int],
        topk_group: Optional[int],
        intermediate_size: int,
        valid_hidden_size: Optional[int],
        valid_intermediate_size: Optional[int],
        local_expert_offset: int,
        local_num_experts: int,
        routed_scaling_factor: Optional[float],
        routing_method_type: int,
        act_type: int,
        topk_weights: Optional[torch.Tensor] = None,
        topk_ids: Optional[torch.Tensor] = None,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        num_tokens = hidden_states.shape[0]
        hidden_size = hidden_states.shape[1]
        out_hidden_size = valid_hidden_size if valid_hidden_size is not None else hidden_size
        # Reuse the caller's preallocated output buffer if one was passed;
        # otherwise report a fresh (num_tokens, out_hidden_size) bf16 tensor.
        if output is not None:
            return output
        return hidden_states.new_empty((num_tokens, out_hidden_size),
                                       dtype=torch.bfloat16)
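
Both fakes above only produce output metadata, so fake/meta tracing can see the right shapes and dtypes without launching the real kernels. A minimal, self-contained sketch of that mechanism, assuming PyTorch 2.4+ with torch.library.custom_op; the demo::row_max op here is hypothetical and not part of TensorRT-LLM:

# Illustrative only: how a registered fake is exercised during fake tracing.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode


@torch.library.custom_op("demo::row_max", mutates_args=())
def row_max(x: torch.Tensor) -> torch.Tensor:
    return x.max(dim=-1).values


@torch.library.register_fake("demo::row_max")
def _(x: torch.Tensor) -> torch.Tensor:
    # Only shape/dtype are produced, mirroring the trtllm fakes above.
    return x.new_empty(x.shape[:-1])


with FakeTensorMode():
    y = torch.ops.demo.row_max(torch.empty(8, 16))
    print(y.shape)  # torch.Size([8]) -- no real kernel was launched
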
    @torch.library.register_fake("trtllm::calculate_nvfp4_global_scale")
    def _(input: torch.Tensor, tokens_per_batch: Optional[torch.Tensor]):
        # One fp32 global scale per row of the input.
        return input.new_empty((*input.shape[:-1], 1), dtype=torch.float32)
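
The last fake reports one fp32 global scale per row of the input. A quick meta-device check of that shape (the sizes here are made up, not from the commit):

# Illustrative only: per-row fp32 global-scale shape reported by the fake.
import torch

x = torch.empty(128, 4096, device="meta")                     # hypothetical activations
scale = x.new_empty((*x.shape[:-1], 1), dtype=torch.float32)
print(scale.shape, scale.dtype)  # torch.Size([128, 1]) torch.float32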