Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[TRTLLM-6100] fix: Nvbug 5356427: autotuned TRTLLM Gen fp8 block scale MoE illegal memory access (#5676)
Signed-off-by: Dom Brown <3886319+DomBrown@users.noreply.github.com>
Parent: 4d8920982a
Commit: afaa388bee
@@ -42,7 +42,9 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, at::Tensor
     TORCH_CHECK(sm == 100, "Only SM100 is supported by FP8 block scale MOE");
     TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float, "routing_logits must be float.");
     TORCH_CHECK(routing_logits.dim() == 2, "routing_logits must be 2D.");
-    TORCH_CHECK(routing_logits.sizes()[1] == num_experts, "routing_logits has incorrect shape.");
+    TORCH_CHECK(routing_logits.sizes()[0] == hidden_states.sizes()[0],
+        "routing_logits and hidden_states must have the same number of tokens.");
+    TORCH_CHECK(routing_logits.sizes()[1] == num_experts, "routing_logits dim1 must match num_experts.");
     TORCH_CHECK(
         routing_bias.scalar_type() == at::ScalarType::BFloat16 || routing_bias.scalar_type() == at::ScalarType::Float,
         "routing_bias must be bfloat16 or float.");
@@ -150,8 +152,9 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, at::Tensor
     TORCH_CHECK(hidden_states.scalar_type() == at::ScalarType::Float8_e4m3fn, "hidden_states must be fp8.");
     TORCH_CHECK(hidden_states_scale.scalar_type() == at::ScalarType::Float, "hidden_states_scale must be float.");
     TORCH_CHECK(hidden_states_scale.dim() == 2, "hidden_states_scale must be 2D.");
-    TORCH_CHECK(
-        hidden_states_scale.sizes()[0] == hidden_states.sizes()[1] / 128, "hidden_states_scale has incorrect shape.");
+    TORCH_CHECK(hidden_states_scale.sizes()[0] == hidden_states.sizes()[1] / 128,
+        "hidden_states_scale dim0 must match hidden_states dim1 / 128.");
+    TORCH_CHECK(hidden_states_scale.sizes()[1] == args.num_tokens, "hidden_states_scale dim1 must match num_tokens.");
     TORCH_CHECK(gemm1_weights.scalar_type() == at::ScalarType::Float8_e4m3fn, "gemm1_weights must be fp8.");
     TORCH_CHECK(gemm1_weights.dim() == 3, "gemm1_weights must be 3D.");
     TORCH_CHECK(gemm1_weights.sizes()[1] % 2 == 0, "the second dimension of weights must be even.");
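For reference, the shape relationships these checks now enforce can be written out as a small PyTorch sketch. The concrete sizes below are illustrative assumptions, not values taken from the kernel, and the float8 cast needs a PyTorch build with float8_e4m3fn support.

    import torch

    # Illustrative sizes only (assumptions for this sketch).
    num_tokens, hidden_size, num_experts = 8, 1024, 16

    hidden_states = torch.randn(num_tokens, hidden_size).to(torch.float8_e4m3fn)
    routing_logits = torch.randn(num_tokens, num_experts, dtype=torch.float32)
    # One float scale per 128-wide block of the hidden dimension, per token.
    hidden_states_scale = torch.ones(hidden_size // 128, num_tokens, dtype=torch.float32)

    # The same invariants the TORCH_CHECKs above assert on the C++ side.
    assert routing_logits.shape[0] == hidden_states.shape[0]   # same token count
    assert routing_logits.shape[1] == num_experts               # one logit per expert
    assert hidden_states_scale.shape[0] == hidden_states.shape[1] // 128
    assert hidden_states_scale.shape[1] == num_tokens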
@@ -367,8 +367,11 @@ class FP8BlockScaleMoERunner(TunableRunner):
         HIDDEN_STATES_IDX = 2
         TUNED_DIM = 0
 
-        m_values = get_last_power_of_2_num_tokens_buckets(2048)
-        round_rule = lambda x: min(last_positive_power_of_2(x), 2048)
+        MAX_PROFILE_BUCKET = 4096
+
+        m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)
+        round_rule = lambda x: min(last_positive_power_of_2(x),
+                                   MAX_PROFILE_BUCKET)
 
         specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values,
                                    round_rule), )
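A minimal sketch of what this bucketing amounts to, assuming get_last_power_of_2_num_tokens_buckets enumerates power-of-two bucket sizes up to the cap and last_positive_power_of_2 rounds a token count down to the nearest power of two. The helper bodies here are written for illustration only and are not the library implementations.

    MAX_PROFILE_BUCKET = 4096

    def last_positive_power_of_2(x: int) -> int:
        # Largest power of two that is <= x (assumed behaviour of the helper).
        p = 1
        while p * 2 <= x:
            p *= 2
        return p

    def get_last_power_of_2_num_tokens_buckets(cap: int):
        # Power-of-two bucket sizes to profile: 1, 2, 4, ..., cap (assumed behaviour).
        buckets, b = [], 1
        while b <= cap:
            buckets.append(b)
            b *= 2
        return tuple(buckets)

    m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)
    round_rule = lambda x: min(last_positive_power_of_2(x), MAX_PROFILE_BUCKET)

    print(m_values)           # (1, 2, 4, ..., 4096)
    print(round_rule(3000))   # 2048: a real num_tokens of 3000 reuses the 2048 bucket
    print(round_rule(9000))   # 9000 rounds down to 8192, then clamps to 4096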
@@ -377,7 +380,31 @@ class FP8BlockScaleMoERunner(TunableRunner):
 
     @classmethod
     def get_constraint_specs(cls) -> Tuple[ConstraintSpec, ...]:
-        return ()
+
+        def _constrain_to_num_tokens(shapes: Tuple[torch.Size]) -> int:
+            num_tokens = shapes[2][0]
+
+            return num_tokens
+
+        HS_SCALE_IDX = 3
+        CONSTRAINED_HS_SCALE_DIM = 1
+
+        constraint_hidden_states_scale = ConstraintSpec(
+            HS_SCALE_IDX, CONSTRAINED_HS_SCALE_DIM, _constrain_to_num_tokens)
+
+        ROUTER_LOGITS_IDX = 0
+        CONSTRAINED_RL_DIM = 0
+
+        constraint_routing_logits = ConstraintSpec(ROUTER_LOGITS_IDX,
+                                                   CONSTRAINED_RL_DIM,
+                                                   _constrain_to_num_tokens)
+
+        constraint_specs_tuple = (
+            constraint_hidden_states_scale,
+            constraint_routing_logits,
+        )
+
+        return constraint_specs_tuple
 
     @classmethod
     @lru_cache(maxsize=None)
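The two ConstraintSpec entries tie the otherwise free dimensions of the profiling inputs to the tuned token count: whatever num_tokens the tuner picks for hidden_states (input index 2, tuned dim 0) is propagated to routing_logits dim 0 and hidden_states_scale dim 1, keeping the dummy tensors consistent with the shape checks added above. A standalone sketch of that propagation follows; the input ordering and sizes are assumptions made for illustration, not the autotuner's actual machinery.

    import torch
    from typing import Tuple

    def _constrain_to_num_tokens(shapes: Tuple[torch.Size, ...]) -> int:
        # hidden_states is input index 2; its leading dimension is num_tokens.
        return shapes[2][0]

    # Hypothetical profiling shapes, assumed ordering:
    # (routing_logits, routing_bias, hidden_states, hidden_states_scale).
    shapes = (torch.Size([64, 16]), torch.Size([16]),
              torch.Size([64, 1024]), torch.Size([8, 64]))

    num_tokens = _constrain_to_num_tokens(shapes)
    # Dims the constraint specs force to num_tokens before a profiling run:
    assert shapes[0][0] == num_tokens   # routing_logits dim 0 (ROUTER_LOGITS_IDX, dim 0)
    assert shapes[3][1] == num_tokens   # hidden_states_scale dim 1 (HS_SCALE_IDX, dim 1)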