mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Bug] Fix torch device issue for MOE permute (#44005)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -74,6 +74,9 @@ class MoEPermuteScratch:
|
||||
self.sort_workspace = torch.empty(
|
||||
sorter_size, dtype=torch.int8, device=self.device
|
||||
)
|
||||
# torch.device("cuda") in config, after initialized,
|
||||
# will be changed to cuda:{index}, so we need to refresh here.
|
||||
self.device = self.token_expert_indices.device
|
||||
|
||||
def validate(self, hidden_states: torch.Tensor, topk_ids: torch.Tensor) -> None:
|
||||
n_token, n_hidden = hidden_states.shape
|
||||
|
||||
Reference in New Issue
Block a user