mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[None][fix] Replace hash method with unique_id for cutedsl MoE runners. (#9569)
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
This commit is contained in:
parent
bc25fff039
commit
730eb3d859
@ -54,7 +54,6 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100 and SM 103"
|
||||
)
|
||||
|
||||
# rewrite the hash function because the value of self.alpha doesn't affect the tactic.
|
||||
def unique_id(self):
|
||||
return (self.output_dtype, )
|
||||
|
||||
@ -531,6 +530,17 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
|
||||
)
|
||||
|
||||
def unique_id(self):
|
||||
return (
|
||||
self.num_experts,
|
||||
self.top_k,
|
||||
self.num_local_experts,
|
||||
self.local_expert_offset,
|
||||
self.tile_size,
|
||||
self.output_dtype,
|
||||
self.scaling_vector_size,
|
||||
)
|
||||
|
||||
def get_valid_tactics(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
@ -571,7 +581,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
return valid_tactics
|
||||
|
||||
def get_tuning_config(self) -> TuningConfig:
|
||||
key = hash(self)
|
||||
key = self.unique_id()
|
||||
if key not in self.__class__.tuning_config_cache:
|
||||
helper = GroupedGemmInputsHelper(self.num_experts, self.top_k,
|
||||
self.num_local_experts,
|
||||
@ -807,6 +817,17 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
|
||||
)
|
||||
|
||||
def unique_id(self):
|
||||
return (
|
||||
self.num_experts,
|
||||
self.top_k,
|
||||
self.num_local_experts,
|
||||
self.local_expert_offset,
|
||||
self.tile_size,
|
||||
self.output_dtype,
|
||||
self.scaling_vector_size,
|
||||
)
|
||||
|
||||
def get_valid_tactics(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
@ -847,7 +868,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
return valid_tactics
|
||||
|
||||
def get_tuning_config(self) -> TuningConfig:
|
||||
key = hash(self)
|
||||
key = self.unique_id()
|
||||
if key not in self.__class__.tuning_config_cache:
|
||||
helper = GroupedGemmInputsHelper(self.num_experts, self.top_k,
|
||||
self.num_local_experts,
|
||||
@ -1124,6 +1145,16 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
|
||||
)
|
||||
|
||||
def unique_id(self):
|
||||
return (
|
||||
self.num_experts,
|
||||
self.top_k,
|
||||
self.num_local_experts,
|
||||
self.local_expert_offset,
|
||||
self.tile_size,
|
||||
self.scaling_vector_size,
|
||||
)
|
||||
|
||||
def get_valid_tactics(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
@ -1164,7 +1195,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
return valid_tactics
|
||||
|
||||
def get_tuning_config(self) -> TuningConfig:
|
||||
key = hash(self)
|
||||
key = self.unique_id()
|
||||
if key not in self.__class__.tuning_config_cache:
|
||||
helper = GroupedGemmInputsHelper(self.num_experts, self.top_k,
|
||||
self.num_local_experts,
|
||||
@ -1443,6 +1474,16 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
self.output_dtype = output_dtype
|
||||
self.scaling_vector_size = scaling_vector_size
|
||||
|
||||
def unique_id(self):
|
||||
return (
|
||||
self.num_experts,
|
||||
self.top_k,
|
||||
self.num_local_experts,
|
||||
self.local_expert_offset,
|
||||
self.output_dtype,
|
||||
self.scaling_vector_size,
|
||||
)
|
||||
|
||||
def get_valid_tactics(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
@ -1452,7 +1493,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
return [128]
|
||||
|
||||
def get_tuning_config(self) -> TuningConfig:
|
||||
key = hash(self)
|
||||
key = self.unique_id()
|
||||
if key not in self.__class__.tuning_config_cache:
|
||||
helper = FusedMoEInputsHelper(self.num_experts, self.top_k,
|
||||
self.num_local_experts,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user