[None][fix] Replace hash method with unique_id for cutedsl MoE runners. (#9569)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
This commit is contained in:
Yukun He 2025-12-01 17:02:33 +08:00 committed by GitHub
parent bc25fff039
commit 730eb3d859
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -54,7 +54,6 @@ if IS_CUTLASS_DSL_AVAILABLE:
f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100 and SM 103"
)
# self.alpha is deliberately excluded from the identity: its value does not
# influence which tactic is selected, so it must not fragment the cache.
def unique_id(self):
    """Return a hashable key identifying this runner for tuning-cache lookups."""
    key = (self.output_dtype,)
    return key
@ -531,6 +530,17 @@ if IS_CUTLASS_DSL_AVAILABLE:
f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
)
def unique_id(self):
    """Hashable tuple identifying this runner configuration for cache lookups."""
    # Collect the identity attributes by name; order is part of the contract.
    fields = ("num_experts", "top_k", "num_local_experts",
              "local_expert_offset", "tile_size", "output_dtype",
              "scaling_vector_size")
    return tuple(getattr(self, name) for name in fields)
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
@ -571,7 +581,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
return valid_tactics
def get_tuning_config(self) -> TuningConfig:
key = hash(self)
key = self.unique_id()
if key not in self.__class__.tuning_config_cache:
helper = GroupedGemmInputsHelper(self.num_experts, self.top_k,
self.num_local_experts,
@ -807,6 +817,17 @@ if IS_CUTLASS_DSL_AVAILABLE:
f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
)
def unique_id(self):
    """Hashable tuple identifying this runner configuration for cache lookups."""
    # Gather identity attributes explicitly; the ordering is part of the key.
    attrs = ("num_experts", "top_k", "num_local_experts",
             "local_expert_offset", "tile_size", "output_dtype",
             "scaling_vector_size")
    return tuple(getattr(self, a) for a in attrs)
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
@ -847,7 +868,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
return valid_tactics
def get_tuning_config(self) -> TuningConfig:
key = hash(self)
key = self.unique_id()
if key not in self.__class__.tuning_config_cache:
helper = GroupedGemmInputsHelper(self.num_experts, self.top_k,
self.num_local_experts,
@ -1124,6 +1145,16 @@ if IS_CUTLASS_DSL_AVAILABLE:
f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
)
def unique_id(self):
    """Hashable tuple identifying this runner configuration for cache lookups.

    Note: output_dtype is not part of this variant's identity.
    """
    names = ("num_experts", "top_k", "num_local_experts",
             "local_expert_offset", "tile_size", "scaling_vector_size")
    return tuple(getattr(self, n) for n in names)
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
@ -1164,7 +1195,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
return valid_tactics
def get_tuning_config(self) -> TuningConfig:
key = hash(self)
key = self.unique_id()
if key not in self.__class__.tuning_config_cache:
helper = GroupedGemmInputsHelper(self.num_experts, self.top_k,
self.num_local_experts,
@ -1443,6 +1474,16 @@ if IS_CUTLASS_DSL_AVAILABLE:
self.output_dtype = output_dtype
self.scaling_vector_size = scaling_vector_size
def unique_id(self):
    """Hashable tuple identifying this fused-MoE runner for cache lookups.

    Note: tile_size is not part of this variant's identity.
    """
    names = ("num_experts", "top_k", "num_local_experts",
             "local_expert_offset", "output_dtype", "scaling_vector_size")
    return tuple(getattr(self, n) for n in names)
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
@ -1452,7 +1493,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
return [128]
def get_tuning_config(self) -> TuningConfig:
key = hash(self)
key = self.unique_id()
if key not in self.__class__.tuning_config_cache:
helper = FusedMoEInputsHelper(self.num_experts, self.top_k,
self.num_local_experts,