Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-03 17:52:19 +08:00
[None][fix] Fix Piecewise Cuda Graph for GPTOSS (#10631)
Signed-off-by: Dongfeng Yu <dongfengy@nvidia.com>
This commit is contained in:
parent 0256c7234f
commit 6dfb8d7084
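For context: piecewise CUDA graph capture runs the model through torch.compile, which traces custom ops on fake (meta) tensors; any trtllm op without a registered fake kernel breaks that trace. The registrations below supply shape-and-dtype-only implementations for the GPT-OSS path. A minimal sketch of the mechanism, assuming PyTorch >= 2.4 (the op name "mylib::scale_rows" is hypothetical, not part of this commit):

import torch

# Eager implementation of a custom op.
@torch.library.custom_op("mylib::scale_rows", mutates_args=())
def scale_rows(x: torch.Tensor, s: float) -> torch.Tensor:
    return x * s

# Fake kernel: runs on meta tensors during tracing, so it must only
# allocate an output with the right shape and dtype, never touch data.
@torch.library.register_fake("mylib::scale_rows")
def _(x: torch.Tensor, s: float) -> torch.Tensor:
    return x.new_empty(x.shape)

@torch.compile(fullgraph=True)
def f(x):
    return scale_rows(x, 2.0)

print(f(torch.ones(2, 3)))  # traced via the fake kernel, executed for real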
@@ -226,6 +226,81 @@ def _register_fake():
        return (input.new_empty(output_shape, dtype=torch.uint8),
                global_scale.new_empty(scale_shape, dtype=torch.uint8))

    @torch.library.register_fake("trtllm::mxfp8_quantize")
    def _(
        input: torch.Tensor,
        swizzled_layout: bool = True,
        alignment: int = 32,
    ):
        SF_VEC_SIZE = 32

        def pad_up(x, m: int):
            return (x + m - 1) // m * m

        m_val = 1
        for d in input.shape[:-1]:
            m_val = m_val * d

        k = input.shape[-1]
        padded_k = pad_up(k, alignment)

        out_shape = list(input.shape)
        out_shape[-1] = padded_k

        # Output tensor: float8_e4m3fn, last dim padded to alignment
        val_mxfp8 = input.new_empty(out_shape, dtype=torch.float8_e4m3fn)

        # Scale tensor: 1D uint8, size depends on swizzled vs linear layout
        cols = padded_k // SF_VEC_SIZE
        if swizzled_layout:
            sf_size = pad_up(m_val, 128) * pad_up(cols, 4)
        else:
            sf_size = m_val * cols
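        # e.g. for input of shape (5, 100) with alignment=32: padded_k = 128,
        # cols = 128 // 32 = 4; the swizzled layout pads rows to 128 and cols
        # to 4, so sf_size = 128 * 4 = 512, while the linear layout needs
        # only 5 * 4 = 20.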

        scale_fp8_sf = input.new_empty((sf_size, ), dtype=torch.uint8)
        return val_mxfp8, scale_fp8_sf

    @torch.library.register_fake("trtllm::mxe4m3_mxe2m1_block_scale_moe_runner")
    def _(
        routing_logits: Optional[torch.Tensor],
        routing_bias: Optional[torch.Tensor],
        hidden_states: torch.Tensor,
        hidden_states_scale: torch.Tensor,
        gemm1_weights: torch.Tensor,
        gemm1_weights_scale: torch.Tensor,
        gemm1_bias: Optional[torch.Tensor],
        gemm1_alpha: Optional[torch.Tensor],
        gemm1_beta: Optional[torch.Tensor],
        gemm1_clamp_limit: Optional[torch.Tensor],
        gemm2_weights: torch.Tensor,
        gemm2_weights_scale: torch.Tensor,
        gemm2_bias: Optional[torch.Tensor],
        num_experts: int,
        top_k: int,
        n_group: Optional[int],
        topk_group: Optional[int],
        intermediate_size: int,
        valid_hidden_size: Optional[int],
        valid_intermediate_size: Optional[int],
        local_expert_offset: int,
        local_num_experts: int,
        routed_scaling_factor: Optional[float],
        routing_method_type: int,
        act_type: int,
        topk_weights: Optional[torch.Tensor] = None,
        topk_ids: Optional[torch.Tensor] = None,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        num_tokens = hidden_states.shape[0]
        hidden_size = hidden_states.shape[1]
        out_hidden_size = valid_hidden_size if valid_hidden_size is not None else hidden_size

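        # If the caller passed a preallocated output buffer, return that same
        # tensor so tracing sees the aliasing of the real runner, which is
        # expected to write into `output` in place.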
        if output is not None:
            return output

        return hidden_states.new_empty((num_tokens, out_hidden_size),
                                       dtype=torch.bfloat16)

    @torch.library.register_fake("trtllm::calculate_nvfp4_global_scale")
    def _(input: torch.Tensor, tokens_per_batch: Optional[torch.Tensor]):
        # Unpack the leading dims: new_empty needs a flat sequence of ints,
        # not a nested (torch.Size, int) tuple.
        return input.new_empty((*input.shape[:-1], 1), dtype=torch.float32)
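As a quick sanity check (not part of the commit), the new fakes can be exercised under FakeTensorMode, assuming tensorrt_llm is installed so that the trtllm ops and these registrations are loaded; the expected shapes follow the worked arithmetic in the mxfp8_quantize fake above:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

import tensorrt_llm  # assumed to register the trtllm custom ops and fakes

with FakeTensorMode():
    x = torch.empty(5, 100, dtype=torch.bfloat16)
    val, sf = torch.ops.trtllm.mxfp8_quantize(x, True, 32)
    print(val.shape, sf.shape)  # torch.Size([5, 128]) torch.Size([512])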