DeepEP LL dispatch FP4 (#6296)
Signed-off-by: Yilin Zhang <18275976+yilin-void@users.noreply.github.com>
parent 93a0fd0a23
commit f172face98
@@ -1,4 +1,4 @@
-set(DEEP_EP_COMMIT 7b15af835942675df041eca2dcb9930b880287e1)
+set(DEEP_EP_COMMIT edf3ea2b086a393d3163bf2773eab69d9191cc01)
 set(NVSHMEM_URL_HASH
     SHA256=eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a)
 
@@ -154,6 +154,24 @@ class VariableLengthLowLatencyBuffer:
         # Later, you can use our GEMM library to do the computation with this specific format
         return recv_hidden_states, recv_expert_count, handle
 
+    def low_latency_dispatch_fp4(self, hidden_states: torch.Tensor,
+                                 scales: torch.Tensor, topk_idx: torch.Tensor,
+                                 num_max_dispatch_tokens_per_rank: int,
+                                 num_experts: int):
+        assert num_experts == self.num_experts
+
+        # Do MoE dispatch; compatible with CUDA graph (but you may need to restore some buffer status once you replay)
+        recv_hidden_states, recv_scales, recv_expert_count, handle, event, hook = \
+            self.buffer.low_latency_dispatch_fp4(hidden_states, scales, topk_idx, num_max_dispatch_tokens_per_rank, num_experts)
+        assert event.event is None
+        assert hook is None
+
+        # NOTE: the actual tensor is not received until you call `hook()`;
+        # this is useful for double-batch overlapping, and occupies no SMs.
+        # If you don't want to overlap, set `return_recv_hook=False`.
+        # Later, you can use our GEMM library to do the computation with this specific format
+        return recv_hidden_states, recv_scales, recv_expert_count, handle
+
     def low_latency_combine(self, hidden_states: torch.Tensor,
                             topk_idx: torch.Tensor, topk_weights: torch.Tensor,
                             handle: Tuple):
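For orientation, here is a minimal usage sketch of the new wrapper, assuming an already-constructed `VariableLengthLowLatencyBuffer` named `buffer` and illustrative sizes; the shape relations follow the assertions added in `WideEPMoE` below, everything else is hypothetical:

```python
import torch

# Illustrative sizes only; `buffer` is an initialized
# VariableLengthLowLatencyBuffer (assumed, not shown in this commit).
token_num, hidden_size, num_experts, top_k = 128, 7168, 256, 8

# Packed NVFP4 payload: two 4-bit values per byte -> hidden_size // 2 bytes per token.
hidden_states = torch.empty(token_num, hidden_size // 2,
                            dtype=torch.uint8, device='cuda')
# Scale factors: one byte per 16 elements -> hidden_size // 16 bytes per token.
scales = torch.empty(token_num, hidden_size // 16,
                     dtype=torch.uint8, device='cuda')
topk_idx = torch.randint(num_experts, (token_num, top_k), device='cuda')

recv_hidden_states, recv_scales, recv_expert_count, handle = \
    buffer.low_latency_dispatch_fp4(hidden_states, scales, topk_idx,
                                    num_max_dispatch_tokens_per_rank=256,
                                    num_experts=num_experts)
```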
@@ -588,43 +588,26 @@ class WideEPMoE(MoE):
                 x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2,
                                   self.scaling_vector_size)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
-            assert x_sf is not None and self.has_nvfp4
             token_num = x_row
             hidden_size = x_col
+            assert x_sf is not None and self.has_nvfp4
             assert hidden_size % 32 == 0
-            x_sf_dtype = x_sf.dtype
-            x_dtype = x.dtype
-            assert x_sf_dtype == torch.uint8 and x_dtype == torch.uint8
-            x_sf = x_sf.view(torch.bfloat16)
+            assert x.dtype == torch.uint8 and x_sf.dtype == torch.uint8
             assert x_sf.shape[0] == token_num and x_sf.shape[
-                1] == hidden_size // 16 // 2
-            x = x.view(torch.bfloat16)
-            assert x.shape[0] == token_num and x.shape[1] == hidden_size // 4
-            # DeepEP LL dispatch only supports bf16 tensors with a hidden size of 2560, 4096, 5120, or 7168 as input. A hidden size of 2560 is sufficient to accommodate packed FP4 data.
-            packed_hidden_size = 2560
-            assert x.shape[1] + x_sf.shape[1] <= packed_hidden_size
-            fp4_packed_tensor = torch.empty((token_num, packed_hidden_size),
-                                            dtype=torch.bfloat16,
-                                            device=x.device)
-            fp4_packed_tensor[:, :x.shape[1]] = x
-            fp4_packed_tensor[:,
-                              x.shape[1]:x.shape[1] + x_sf.shape[1]] = x_sf
+                1] == hidden_size // 16
+            assert x.shape[0] == token_num and x.shape[1] == hidden_size // 2
 
             deep_ep_topk_idx = token_selected_slots
             deep_ep_topk_weights = token_final_scales
 
             assert all_rank_max_num_tokens <= self.deep_ep_max_num_tokens
-            fp4_packed_tensor, recv_expert_count, deep_ep_handle = \
-                self.deep_ep_buffer.low_latency_dispatch(fp4_packed_tensor, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots)
-            deep_ep_handle = list(deep_ep_handle)
-            deep_ep_handle[3] = hidden_size
-            deep_ep_handle = tuple(deep_ep_handle)
+            x, x_sf, recv_expert_count, deep_ep_handle = \
+                self.deep_ep_buffer.low_latency_dispatch_fp4(x, x_sf, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots)
+            assert x.dtype == torch.uint8 and x_sf.dtype == torch.uint8
+            assert x.dim() == 3 and x_sf.dim() == 3
+            assert x.shape[2] == hidden_size // 2 and x_sf.shape[
+                2] == hidden_size // 16
 
-            assert fp4_packed_tensor.ndim == 3 and fp4_packed_tensor.shape[
-                2] == packed_hidden_size
-            x_sf = fp4_packed_tensor[:, :, x.shape[1]:x.shape[1] +
-                                     x_sf.shape[1]].contiguous()
-            x = fp4_packed_tensor[:, :, :x.shape[1]].contiguous()
             mask = torch.arange(
                 x.shape[1], dtype=torch.int32, device=x.device).expand(
                     x.shape[0], x.shape[1]) < recv_expert_count.unsqueeze(1)
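The new assertions encode the NVFP4 layout directly: the payload carries two 4-bit values per uint8 byte (`hidden_size // 2` bytes per token) and one scale byte per 16 elements (`hidden_size // 16` per token). The deleted workaround instead viewed both buffers as bf16 and copied them into a fixed 2560-element carrier tensor, since DeepEP LL dispatch previously accepted only bf16 inputs at hidden sizes 2560, 4096, 5120, or 7168. A standalone arithmetic check (illustrative, not code from this commit) confirms why 2560 elements always sufficed:

```python
# Standalone check (not part of the commit): packed FP4 data plus scale
# factors always fit in the old 2560-element bf16 carrier tensor.
for hidden_size in (2560, 4096, 5120, 7168):
    fp4_bytes = hidden_size // 2      # two 4-bit values per uint8 byte
    scale_bytes = hidden_size // 16   # one scale byte per 16 elements
    bf16_elems = (fp4_bytes + scale_bytes) // 2  # one bf16 element = 2 bytes
    assert bf16_elems <= 2560
    print(f"hidden_size={hidden_size}: {bf16_elems} bf16 elements")
    # Largest case, hidden_size=7168: 3584 + 448 = 4032 bytes = 2016 elements.
```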
@@ -634,9 +617,9 @@ class WideEPMoE(MoE):
                     x.shape[0] * (self.mapping.moe_ep_rank + 1),
                     dtype=torch.int32,
                     device=x.device).unsqueeze(1), self.num_slots)
-            x = x.reshape(x.shape[0] * x.shape[1], x.shape[2]).view(x_dtype)
+            x = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
             x_sf = x_sf.reshape(x_sf.shape[0] * x_sf.shape[1],
-                                x_sf.shape[2]).view(x_sf_dtype)
+                                x_sf.shape[2])
             x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2,
                               self.scaling_vector_size)
             token_selected_slots = token_selected_slots.view(x.shape[0], 1)
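One consequence of the native FP4 path, visible in this hunk: the tensors stay `torch.uint8` end to end, so the `.view(x_dtype)` / `.view(x_sf_dtype)` round-trips are no longer needed after flattening. A small sketch of that flatten step, assuming (as one reading of the dimension-3 assertions above, not something the diff states) that dimension 0 indexes local experts and dimension 1 indexes received token slots:

```python
import torch

# Hypothetical shapes for the 3-D dispatch outputs asserted above:
# (local experts, received token slots, bytes per token).
num_local_experts, token_slots, hidden_size = 8, 256, 7168
x = torch.empty(num_local_experts, token_slots, hidden_size // 2,
                dtype=torch.uint8)
x_sf = torch.empty(num_local_experts, token_slots, hidden_size // 16,
                   dtype=torch.uint8)

# Flatten to (experts * slots, bytes); dtype is already uint8, so no
# .view(dtype) round-trip is required before swizzle_sf.
x = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
x_sf = x_sf.reshape(x_sf.shape[0] * x_sf.shape[1], x_sf.shape[2])
```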