[TRTLLM-6100] fix: Nvbug 5356427: autotuned TRTLLM Gen fp8 block scale MoE illegal memory access (#5676)

Signed-off-by: Dom Brown <3886319+DomBrown@users.noreply.github.com>
Author: Dom Brown
Date: 2025-07-04 03:38:08 +01:00
Committed by: Zhenhuan Chen
Parent: 4d8920982a
Commit: afaa388bee

2 changed files with 36 additions and 6 deletions


@@ -42,7 +42,9 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, at::Tensor
     TORCH_CHECK(sm == 100, "Only SM100 is supported by FP8 block scale MOE");
     TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float, "routing_logits must be float.");
     TORCH_CHECK(routing_logits.dim() == 2, "routing_logits must be 2D.");
-    TORCH_CHECK(routing_logits.sizes()[1] == num_experts, "routing_logits has incorrect shape.");
+    TORCH_CHECK(routing_logits.sizes()[0] == hidden_states.sizes()[0],
+        "routing_logits and hidden_states must have the same number of tokens.");
+    TORCH_CHECK(routing_logits.sizes()[1] == num_experts, "routing_logits dim1 must match num_experts.");
     TORCH_CHECK(
         routing_bias.scalar_type() == at::ScalarType::BFloat16 || routing_bias.scalar_type() == at::ScalarType::Float,
         "routing_bias must be bfloat16 or float.");
@@ -150,8 +152,9 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, at::Tensor
     TORCH_CHECK(hidden_states.scalar_type() == at::ScalarType::Float8_e4m3fn, "hidden_states must be fp8.");
     TORCH_CHECK(hidden_states_scale.scalar_type() == at::ScalarType::Float, "hidden_states_scale must be float.");
     TORCH_CHECK(hidden_states_scale.dim() == 2, "hidden_states_scale must be 2D.");
-    TORCH_CHECK(
-        hidden_states_scale.sizes()[0] == hidden_states.sizes()[1] / 128, "hidden_states_scale has incorrect shape.");
+    TORCH_CHECK(hidden_states_scale.sizes()[0] == hidden_states.sizes()[1] / 128,
+        "hidden_states_scale dim0 must match hidden_states dim1 / 128.");
+    TORCH_CHECK(hidden_states_scale.sizes()[1] == args.num_tokens, "hidden_states_scale dim1 must match num_tokens.");
     TORCH_CHECK(gemm1_weights.scalar_type() == at::ScalarType::Float8_e4m3fn, "gemm1_weights must be fp8.");
     TORCH_CHECK(gemm1_weights.dim() == 3, "gemm1_weights must be 3D.");
     TORCH_CHECK(gemm1_weights.sizes()[1] % 2 == 0, "the second dimension of weights must be even.");
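The second hunk splits the old combined check and adds validation of hidden_states_scale dim1, which was previously unchecked: a scale tensor sized for a different token count could reach the kernel and be read out of bounds. A hedged sketch of the block-scale layout these checks imply (128-wide scaling blocks along the hidden dimension; sizes are illustrative):

```python
import torch

num_tokens, hidden_size = 64, 7168
BLOCK = 128  # fp8 block-scale granularity along the hidden dimension

hidden_states = torch.zeros(num_tokens, hidden_size, dtype=torch.float8_e4m3fn)
# One float scale per 128-element block, laid out [hidden_size // 128, num_tokens].
hidden_states_scale = torch.ones(hidden_size // BLOCK, num_tokens, dtype=torch.float32)

assert hidden_states_scale.shape[0] == hidden_states.shape[1] // BLOCK
assert hidden_states_scale.shape[1] == num_tokens  # the check added in this commit
```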


@@ -367,8 +367,11 @@ class FP8BlockScaleMoERunner(TunableRunner):
         HIDDEN_STATES_IDX = 2
         TUNED_DIM = 0
 
-        m_values = get_last_power_of_2_num_tokens_buckets(2048)
-        round_rule = lambda x: min(last_positive_power_of_2(x), 2048)
+        MAX_PROFILE_BUCKET = 4096
+
+        m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)
+        round_rule = lambda x: min(last_positive_power_of_2(x),
+                                   MAX_PROFILE_BUCKET)
 
         specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values,
                                    round_rule), )
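Profiling now covers power-of-two token buckets up to 4096 instead of 2048, and runtime token counts are rounded into that range. A rough standalone sketch of the rounding behaviour; the real get_last_power_of_2_num_tokens_buckets and last_positive_power_of_2 helpers come from TensorRT-LLM, so these re-implementations are assumptions for illustration only:

```python
MAX_PROFILE_BUCKET = 4096

def last_positive_power_of_2(x: int) -> int:
    # Largest power of two that is <= x (assumed behaviour of the helper).
    p = 1
    while p * 2 <= x:
        p *= 2
    return p

def power_of_2_buckets(max_bucket: int) -> list[int]:
    # Assumed bucket set: 1, 2, 4, ..., max_bucket.
    buckets, b = [], 1
    while b <= max_bucket:
        buckets.append(b)
        b *= 2
    return buckets

round_rule = lambda x: min(last_positive_power_of_2(x), MAX_PROFILE_BUCKET)

print(power_of_2_buckets(MAX_PROFILE_BUCKET))  # [1, 2, 4, ..., 4096]
print(round_rule(3000))   # 2048
print(round_rule(10000))  # 4096, clamped to the largest profiled bucket
```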
@@ -377,7 +380,31 @@ class FP8BlockScaleMoERunner(TunableRunner):
 
     @classmethod
     def get_constraint_specs(cls) -> Tuple[ConstraintSpec, ...]:
-        return ()
+
+        def _constrain_to_num_tokens(shapes: Tuple[torch.Size]) -> int:
+            num_tokens = shapes[2][0]
+            return num_tokens
+
+        HS_SCALE_IDX = 3
+        CONSTRAINED_HS_SCALE_DIM = 1
+
+        constraint_hidden_states_scale = ConstraintSpec(
+            HS_SCALE_IDX, CONSTRAINED_HS_SCALE_DIM, _constrain_to_num_tokens)
+
+        ROUTER_LOGITS_IDX = 0
+        CONSTRAINED_RL_DIM = 0
+
+        constraint_routing_logits = ConstraintSpec(ROUTER_LOGITS_IDX,
+                                                   CONSTRAINED_RL_DIM,
+                                                   _constrain_to_num_tokens)
+
+        constraint_specs_tuple = (
+            constraint_hidden_states_scale,
+            constraint_routing_logits,
+        )
+
+        return constraint_specs_tuple
 
     @classmethod
     @lru_cache(maxsize=None)
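The new constraint specs tell the autotuner that, whichever bucketed token count it picks for hidden_states (input index 2, dim 0), the profiling inputs for hidden_states_scale (index 3, dim 1) and routing_logits (index 0, dim 0) must be built with that same token count; previously those tensors kept their runtime shapes while hidden_states was bucketed, which is the mismatch behind the illegal memory access. A minimal sketch of how such a constraint could be applied when deriving profiling shapes, using simplified stand-ins for ConstraintSpec and the tuner (assumptions only, not the TensorRT-LLM implementation):

```python
from typing import Callable, NamedTuple, Tuple
import torch

class ConstraintSpec(NamedTuple):
    # Simplified stand-in: which input tensor, which of its dims, and how to
    # compute that dim from the full set of input shapes.
    input_idx: int
    dim_idx: int
    infer_shape: Callable[[Tuple[torch.Size, ...]], int]

def apply_constraints(shapes: Tuple[torch.Size, ...],
                      specs: Tuple[ConstraintSpec, ...]) -> Tuple[torch.Size, ...]:
    # Rewrite the constrained dims so dependent tensors track the tuned one.
    out = [list(s) for s in shapes]
    for spec in specs:
        out[spec.input_idx][spec.dim_idx] = spec.infer_shape(shapes)
    return tuple(torch.Size(s) for s in out)

def _constrain_to_num_tokens(shapes: Tuple[torch.Size, ...]) -> int:
    return shapes[2][0]  # num_tokens comes from hidden_states (input 2, dim 0)

specs = (
    ConstraintSpec(3, 1, _constrain_to_num_tokens),  # hidden_states_scale dim1
    ConstraintSpec(0, 0, _constrain_to_num_tokens),  # routing_logits dim0
)

# Example: the tuner has bucketed num_tokens to 2048, but the scale and logits
# shapes still reflect the original 3000 tokens.
shapes = (torch.Size([3000, 256]),   # routing_logits
          torch.Size([256]),         # routing_bias
          torch.Size([2048, 7168]),  # hidden_states, tuned dim already bucketed
          torch.Size([56, 3000]))    # hidden_states_scale
print(apply_constraints(shapes, specs))
# routing_logits dim0 and hidden_states_scale dim1 both become 2048
```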