[Bug] Fix torch device issue for MOE permute (#44005)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-05-29 14:55:00 -04:00
committed by GitHub
parent 8b9deeec4b
commit 739096a028
@@ -74,6 +74,9 @@ class MoEPermuteScratch:
self.sort_workspace = torch.empty(
sorter_size, dtype=torch.int8, device=self.device
)
# torch.device("cuda") in config, after initialized,
# will be changed to cuda:{index}, so we need to refresh here.
self.device = self.token_expert_indices.device
def validate(self, hidden_states: torch.Tensor, topk_ids: torch.Tensor) -> None:
n_token, n_hidden = hidden_states.shape