Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

[feat] Support torch compile for attention dp (#5086)
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>

Commit c345f5876c (parent f9a455651b)
@@ -70,7 +70,15 @@ public:
outputShape[0] *= mGroup.size();
}
auto output = torch::empty(outputShape, input.options());
if (sizes.has_value())
bool use_nccl_allgather = !sizes.has_value()
|| std::all_of(sizes.value().begin(), sizes.value().end(),
[&sizes](int64_t size) { return size == sizes.value()[0]; });
if (use_nccl_allgather)
{
NCCLCHECK_THROW(ncclAllGather(input.data_ptr(), output.mutable_data_ptr(), input.numel(),
(*getDtypeMap())[type], *mNcclComm, stream));
}
else
{
size_t numel_base = std::accumulate(outputShape.cbegin() + 1, outputShape.cend(), 1, std::multiplies<>{});
int64_t split_offset = 0;

@@ -85,11 +93,6 @@ public:
}
ncclGroupEnd();
}
else
{
NCCLCHECK_THROW(ncclAllGather(input.data_ptr(), output.mutable_data_ptr(), input.numel(),
(*getDtypeMap())[type], *mNcclComm, stream));
}
return output;
}

@@ -155,8 +158,8 @@ std::vector<torch::Tensor> allgather_list(

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("allgather(Tensor input, int[]? sizes, int[] group) -> Tensor");
m.def("allgather_list(Tensor[] input_list, int[]? sizes, int[] group) -> Tensor[]");
m.def("allgather(Tensor input, SymInt[]? sizes, int[] group) -> Tensor");
m.def("allgather_list(Tensor[] input_list, SymInt[]? sizes, int[] group) -> Tensor[]");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
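Editor's note: a minimal Python sketch (not the library's code) of the size check the C++ op above performs. A single ncclAllGather only works when every rank contributes the same number of rows, so uneven per-rank token counts fall back to the grouped send/recv path.

from typing import List, Optional

def can_use_uniform_allgather(sizes: Optional[List[int]]) -> bool:
    # No per-rank sizes given: every rank is assumed to send the same shape.
    if sizes is None:
        return True
    # Uniform sizes also allow the single-collective fast path.
    return all(s == sizes[0] for s in sizes)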
@@ -718,10 +718,10 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
", int num_heads"
", int num_kv_heads"
", int head_size"
", int? tokens_per_block"
", int max_num_requests"
", int max_context_length"
", int attention_window_size"
", SymInt? tokens_per_block"
", SymInt max_num_requests"
", SymInt max_context_length"
", SymInt attention_window_size"
", int sink_token_length"
", int beam_width"
", int mask_type"
@@ -524,8 +524,20 @@ static auto e2m1_and_ufp8sf_scale_to_float
static auto e2m1_and_ufp8sf_scale_to_float_v2 = torch::RegisterOperators(
"tensorrt_llm::e2m1_and_ufp8sf_scale_to_float_v2", &torch_ext::E2M1AndUFP8SFScaleToFloatV2);

static auto nvfp4_block_scale_interleave
= torch::RegisterOperators("tensorrt_llm::nvfp4_block_scale_interleave", &torch_ext::NVFP4BlockScaleInterleave);
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("nvfp4_block_scale_interleave(Tensor input) -> Tensor");
m.def("nvfp4_block_scale_interleave_reverse(Tensor input) -> Tensor");
}

static auto nvfp4_block_scale_interleave_reverse = torch::RegisterOperators(
"tensorrt_llm::nvfp4_block_scale_interleave_reverse", &torch_ext::NVFP4BlockScaleInterleaveReverse);
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("nvfp4_block_scale_interleave", &torch_ext::NVFP4BlockScaleInterleave);
m.impl("nvfp4_block_scale_interleave_reverse", &torch_ext::NVFP4BlockScaleInterleaveReverse);
}

TORCH_LIBRARY_IMPL(trtllm, CPU, m)
{
m.impl("nvfp4_block_scale_interleave", &torch_ext::NVFP4BlockScaleInterleave);
m.impl("nvfp4_block_scale_interleave_reverse", &torch_ext::NVFP4BlockScaleInterleaveReverse);
}
@@ -79,7 +79,15 @@ public:
outputShape[0] = outputShape[0] / mGroup.size();
}
auto output = torch::empty(outputShape, input.options());
if (sizes.has_value())
bool use_nccl_reducescatter = !sizes.has_value()
|| std::all_of(sizes.value().begin(), sizes.value().end(),
[&sizes](int64_t size) { return size == sizes.value()[0]; });
if (use_nccl_reducescatter)
{
NCCLCHECK_THROW(ncclReduceScatter(input.data_ptr(), output.mutable_data_ptr(), output.numel(),
(*getDtypeMap())[type], ncclSum, *mNcclComm, stream));
}
else
{
size_t numel_base = std::accumulate(outputShape.cbegin() + 1, outputShape.cend(), 1, std::multiplies<>{});
int64_t split_offset = 0;

@@ -95,11 +103,6 @@ public:
}
ncclGroupEnd();
}
else
{
NCCLCHECK_THROW(ncclReduceScatter(input.data_ptr(), output.mutable_data_ptr(), output.numel(),
(*getDtypeMap())[type], ncclSum, *mNcclComm, stream));
}
return output;
}

@@ -167,8 +170,8 @@ extern std::vector<torch::Tensor> reducescatter_list(

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("reducescatter(Tensor input, int[]? sizes, int[] group) -> Tensor");
m.def("reducescatter_list(Tensor[] input_list, int[]? sizes, int[] group) -> Tensor[]");
m.def("reducescatter(Tensor input, SymInt[]? sizes, int[] group) -> Tensor");
m.def("reducescatter_list(Tensor[] input_list, SymInt[]? sizes, int[] group) -> Tensor[]");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
@@ -118,7 +118,13 @@ class AttentionMetadata:
runtime_features: AttentionRuntimeFeatures = field(
default_factory=AttentionRuntimeFeatures)

all_rank_num_tokens: Optional[List[int]] = None
# The number of tokens in each rank.
_all_rank_num_tokens: Optional[List[int]] = field(init=False,
default=None,
repr=False)
all_rank_num_tokens: Optional[List[int]]
# The max number of tokens among all ranks.
all_rank_max_num_tokens: Optional[int] = None

# These fields are set when changing seq_lens and _num_contexts to avoid computation
# during execution. If the calculation happens during execution, torch compile treats it

@@ -149,6 +155,16 @@ class AttentionMetadata:
elif self._seq_lens is not None:
self._num_tokens = self._seq_lens.sum().item()

@property
def all_rank_num_tokens(self) -> Optional[List[int]]:
return self._all_rank_num_tokens

@all_rank_num_tokens.setter
def all_rank_num_tokens(self, value: Optional[List[int]]):
value = value if value is not AttentionMetadata.all_rank_num_tokens else None
self._all_rank_num_tokens = value
self.all_rank_max_num_tokens = max(value) if value is not None else None

@property
def seq_lens(self) -> Optional[torch.Tensor]:
return self._seq_lens
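Editor's note: a reduced sketch of the intent behind the AttentionMetadata change above (plain class instead of the dataclass/property sentinel trick): cache the cross-rank maximum when the per-rank token list is assigned, so the compiled forward never has to call max() on a Python list at run time. Names here are illustrative.

from typing import List, Optional

class _DPTokenInfo:
    def __init__(self) -> None:
        self._all_rank_num_tokens: Optional[List[int]] = None
        self.all_rank_max_num_tokens: Optional[int] = None

    @property
    def all_rank_num_tokens(self) -> Optional[List[int]]:
        return self._all_rank_num_tokens

    @all_rank_num_tokens.setter
    def all_rank_num_tokens(self, value: Optional[List[int]]) -> None:
        # Cache the maximum once, at assignment time.
        self._all_rank_num_tokens = value
        self.all_rank_max_num_tokens = max(value) if value is not None else None

info = _DPTokenInfo()
info.all_rank_num_tokens = [7, 5, 9, 9]
assert info.all_rank_max_num_tokens == 9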
@@ -300,7 +300,7 @@ class FP4QuantizationImpl(QuantizationImpl):
weight_scale = state_dict[weight_name + "_scale"].view(float4_sf_dtype)
ori_shape = weight_scale.shape
state_dict[weight_name + "_scale"] = (
torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
torch.ops.trtllm.nvfp4_block_scale_interleave(
weight_scale.view(torch.uint8).cpu().contiguous()
)
.reshape(ori_shape)
@@ -171,9 +171,10 @@ def _register_fake():
global_scale: torch.Tensor,
sf_vec_size: int,
sf_use_ue8m0=False,
swizzled_layout=True,
):
output_shape, scale_shape = fp4_utils.get_fp4_shape(
input.shape, sf_vec_size)
input.shape, sf_vec_size, swizzled_layout)

return (input.new_empty(output_shape, dtype=torch.uint8),
global_scale.new_empty(scale_shape, dtype=torch.uint8))

@@ -473,3 +474,88 @@ def _register_fake():
hidden_size_val = int(hidden_size)
return gemm2_output.new_empty((num_rows_val, hidden_size_val),
dtype=gemm2_output.dtype)

@torch.library.register_fake("trtllm::allgather_list")
def _(input_list, sizes, group):
assert len(input_list) > 0

def create_output_tensor(i):
shape = list(i.shape)
if sizes is None:
shape[0] *= len(group)
else:
shape[0] = sum(sizes)
return i.new_empty(shape)

return [create_output_tensor(i) for i in input_list]

@torch.library.register_fake("trtllm::reducescatter")
def _(input, sizes, group):
import tensorrt_llm
local_rank = tensorrt_llm.mpi_rank()

shape = list(input.shape)
if sizes is None:
shape[0] = shape[0] // len(group)
else:
shape[0] = sizes[local_rank]
return input.new_empty(shape)

@torch.library.register_fake("trtllm::fp4_block_scale_moe_runner")
def _(
routing_logits,
routing_bias,
hidden_states,
hidden_states_scale,
gemm1_weights,
gemm1_weights_scale,
gemm2_weights,
gemm2_weights_scale,
output1_scale_scalar,
output1_scale_gate_scalar,
output2_scale_scalar,
num_experts,
top_k,
n_group,
topk_group,
intermediate_size,
local_expert_offset,
local_num_experts,
routed_scaling_factor,
tile_tokens_dim,
routing_method_type,
do_finalize,
) -> List[torch.Tensor]:
num_tokens = hidden_states.shape[0]
hidden_size = hidden_states.shape[1] * 2
if do_finalize:
return [
hidden_states.new_empty((num_tokens, hidden_size),
dtype=torch.bfloat16)
]

expanded_row_count = num_tokens * top_k
max_padding_required = (tile_tokens_dim - 1) * num_experts
max_num_padded_tokens = fp4_utils.pad_up(
expanded_row_count + max_padding_required, tile_tokens_dim)
wt_dtype = routing_bias.dtype if routing_bias is not None else torch.bfloat16
return [
hidden_states.new_empty((max_num_padded_tokens, hidden_size),
dtype=torch.bfloat16),
hidden_states.new_empty((num_tokens, top_k), dtype=wt_dtype),
hidden_states.new_empty((num_tokens, top_k), dtype=torch.int32)
]

@torch.library.register_fake("trtllm::nvfp4_block_scale_interleave")
def _(sf: torch.Tensor):
rows = sf.shape[-2]
cols = sf.shape[-1]
expert_out_size = fp4_utils.pad_up(rows, 128) * fp4_utils.pad_up(
cols, 4)
num_experts = sf.shape[0] if len(sf.shape) == 3 else 1
return sf.new_empty((num_experts * expert_out_size, ),
dtype=torch.uint8)

@torch.library.register_fake("trtllm::nvfp4_block_scale_interleave_reverse")
def _(sf: torch.Tensor):
return torch.empty_like(sf, dtype=torch.uint8)
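Editor's note: a generic sketch of the pattern used throughout this file. The real kernels are registered from C++ (TORCH_LIBRARY_IMPL above), while Python adds a "fake" kernel that only does shape/dtype math; that is what lets torch.compile trace these ops when their sizes are symbolic (SymInt). The op name "myns::pad_rows" is illustrative, not part of TensorRT-LLM.

import torch

# Assumed schema for illustration only.
torch.library.define("myns::pad_rows", "(Tensor x, SymInt target_rows) -> Tensor")

@torch.library.register_fake("myns::pad_rows")
def _(x, target_rows):
    # Never touches real data; just reports the output shape for the tracer.
    return x.new_empty((target_rows, *x.shape[1:]))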
@@ -224,6 +224,7 @@ def _(
ep_rank: int = 0,
cluster_size: int = 1,
cluster_rank: int = 0,
enable_alltoall: bool = False,
use_deepseek_fp8_block_scale: bool = False,
use_w4a8_group_scaling: bool = False,
min_latency_mode: bool = False,
@@ -1,7 +1,6 @@
import math
import os
import threading
from itertools import accumulate
from typing import List, Optional, Tuple, Union

import torch

@@ -126,13 +125,14 @@ def filter_valid_input(
return input_list, valid_list


def restore_full_output(output_list: List[torch.Tensor],
def restore_full_output(valid_outputs: List[torch.Tensor],
valid_list: List[bool]) -> List[torch.Tensor]:
index_list = list(accumulate(map(int, valid_list)))
output_list = list(
map(lambda valid, index: output_list[index - 1]
if valid else None, valid_list, index_list))
return output_list
idx = 0
full_outputs = []
for v in valid_list:
full_outputs.append(valid_outputs[idx] if v else None)
idx += int(v)
return full_outputs


def allgather(

@@ -178,12 +178,6 @@ def allgather(
val.shape[dim] == sizes[mapping.tp_rank] for val in input
if val is not None
])
# 'sizes' is not needed if all inputs in the same TP group have the same shape
for split_size in sizes[1:]:
if split_size != sizes[0]:
break
else:
sizes = None

# Inputs are reshaped in this way to pass necessary shape information to the allgather op
if isinstance(input, torch.Tensor):

@@ -247,12 +241,6 @@ def reducescatter(
val.shape[dim] == sum_split_size for val in input
if val is not None
])
# 'sizes' is not needed if all outputs in the same TP group have the same shape
for split_size in sizes[1:]:
if split_size != sizes[0]:
break
else:
sizes = None

def convert_input(x, x_info):
# Inputs are reshaped in this way to pass necessary shape information to the reducescatter op
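Editor's note: a quick usage sketch for the filter/restore helpers above (values are illustrative). None inputs are dropped before the collective runs and placeholders are re-inserted afterwards, so the caller's list keeps its original positions.

valid_outputs = ["rank_input0_out", "rank_input2_out"]  # results for the non-None entries only
valid_list = [True, False, True]                        # which original slots held tensors
# restore_full_output(valid_outputs, valid_list) == ["rank_input0_out", None, "rank_input2_out"]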
@@ -507,7 +507,8 @@ class Deepseekv3MoE(nn.Module):
return shared_tp_size, shared_output_scale

def compute_routed_output(self, hidden_states, hidden_states_fp4,
all_rank_num_tokens, do_finalize):
all_rank_num_tokens, all_rank_max_num_tokens,
do_finalize):
# max-throughput
use_dp_padding = False
if self.use_dp and self.mapping.tp_size > 1:

@@ -522,10 +523,9 @@ class Deepseekv3MoE(nn.Module):
not self.experts.has_fp8_qdq and self.experts.has_nvfp4):
# Use padding when not using the cutlass path or when x_sf in self.experts is not None
use_dp_padding = True
max_num_token = max(all_rank_num_tokens)
hidden_states = torch.nn.functional.pad(
hidden_states,
(0, 0, 0, max_num_token - hidden_states.shape[0]))
(0, 0, 0, all_rank_max_num_tokens - hidden_states.shape[0]))

router_logits = self.gate(hidden_states)

@@ -535,6 +535,7 @@ class Deepseekv3MoE(nn.Module):
do_finalize=do_finalize,
output_dtype=hidden_states.dtype,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=use_dp_padding,
)

@@ -545,6 +546,7 @@ class Deepseekv3MoE(nn.Module):
hidden_states: torch.Tensor,
hidden_states_fp4: Optional[Fp4QuantizedTensor] = None,
all_rank_num_tokens: Optional[list[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
final_all_reduce_params: Optional[AllReduceParams] = None,
do_finalize: Optional[bool] = True,
) -> torch.Tensor:

@@ -562,6 +564,7 @@ class Deepseekv3MoE(nn.Module):
routed_output = self.compute_routed_output(hidden_states,
hidden_states_fp4,
all_rank_num_tokens,
all_rank_max_num_tokens,
do_finalize)
return routed_output

@@ -774,6 +777,7 @@ class DeepseekV3DecoderLayer(DecoderLayer):
hidden_states,
hidden_states_fp4,
all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
final_all_reduce_params=AllReduceParams(
enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
or self.mapping.tp_size == 1)),

@@ -934,6 +938,7 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
embed_tokens: Embedding,
attn_metadata: AttentionMetadata,
all_rank_num_tokens: Optional[List[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:

@@ -975,6 +980,7 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
hidden_states = self.mlp(
hidden_states,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=all_rank_max_num_tokens,
final_all_reduce_params=AllReduceParams(
enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
or self.mapping.tp_size == 1)),
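Editor's note: a standalone sketch of the padding step these model changes share. With attention DP, each rank pads its local tokens up to the precomputed all_rank_max_num_tokens instead of calling max(all_rank_num_tokens) inside the forward pass; tensor names below are illustrative.

import torch

def pad_to_max(hidden_states: torch.Tensor, all_rank_max_num_tokens: int) -> torch.Tensor:
    # Pad rows (the token dimension) up to the cross-rank maximum; columns untouched.
    return torch.nn.functional.pad(
        hidden_states, (0, 0, 0, all_rank_max_num_tokens - hidden_states.shape[0]))

local = torch.randn(5, 8)       # this rank has 5 tokens
padded = pad_to_max(local, 9)   # another rank has 9, so pad to 9 rows
assert padded.shape == (9, 8)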
@@ -292,22 +292,22 @@ class Llama4MoE(nn.Module):
self.aux_stream = aux_stream

def compute_routed_output(self, hidden_states, all_rank_num_tokens,
all_rank_max_num_tokens,
cutlass_min_latency_mode):
use_dp_padding = False
if self.enable_attention_dp and self.mapping.tp_size > 1:
# Use padding here to keep the behavior unchanged
use_dp_padding = True
max_num_token_across_dp_ranks = max(all_rank_num_tokens)
hidden_states = torch.nn.functional.pad(
hidden_states,
(0, 0, 0,
max_num_token_across_dp_ranks - hidden_states.shape[0]))
(0, 0, 0, all_rank_max_num_tokens - hidden_states.shape[0]))
router_logits = self.router(hidden_states)
routed_output = self.experts(
hidden_states,
router_logits,
do_finalize=not cutlass_min_latency_mode,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=use_dp_padding,
)
return routed_output

@@ -316,6 +316,7 @@ class Llama4MoE(nn.Module):
self,
hidden_states: torch.Tensor,
all_rank_num_tokens=None,
all_rank_max_num_tokens=None,
final_all_reduce_params: Optional[AllReduceParams] = None,
cutlass_min_latency_mode: Optional[bool] = False,
) -> torch.Tensor:

@@ -323,7 +324,8 @@ class Llama4MoE(nn.Module):
# This design is mainly for low latency use case. Need to improve for max throughput use case.
fn0 = lambda: self.shared_expert(hidden_states)
fn1 = lambda: self.compute_routed_output(
hidden_states, all_rank_num_tokens, cutlass_min_latency_mode)
hidden_states, all_rank_num_tokens, all_rank_max_num_tokens,
cutlass_min_latency_mode)
shared_output, routed_output = maybe_execute_in_parallel(
fn0, fn1, self.moe_event[0], self.moe_event[1], self.aux_stream)
if cutlass_min_latency_mode:

@@ -479,6 +481,7 @@ class Llama4DecoderLayer(DecoderLayer):
hidden_states = self.feed_forward(
hidden_states,
all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
final_all_reduce_params=AllReduceParams(enable_allreduce=not (
self.fusion_config.POST_MOE_FUSION
or self.fusion_config.POST_MLP_FUSION
@@ -61,19 +61,20 @@ class MixtralMoE(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
all_rank_num_tokens = attn_metadata.all_rank_num_tokens
all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens
use_dp_padding = False
if self.enable_attention_dp and len(all_rank_num_tokens) > 1:
# Use padding here to keep the behavior unchanged
use_dp_padding = True
max_num_token = max(all_rank_num_tokens)
hidden_states = torch.nn.functional.pad(
hidden_states,
(0, 0, 0, max_num_token - hidden_states.shape[0]))
(0, 0, 0, all_rank_max_num_tokens - hidden_states.shape[0]))
router_logits = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=use_dp_padding)
return final_hidden_states
@@ -126,6 +126,7 @@ class Qwen3MoE(nn.Module):
hidden_states = hidden_states.view(-1, self.hidden_dim)
use_dp_padding = False
all_rank_num_tokens = attn_metadata.all_rank_num_tokens
all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens

if not do_finalize:
assert not self.enable_attention_dp

@@ -142,16 +143,16 @@ class Qwen3MoE(nn.Module):
not self.experts.has_fp8_qdq and self.experts.has_nvfp4):
# Use padding when not using the cutlass path or when x_sf in self.experts is not None
use_dp_padding = True
max_num_token = max(all_rank_num_tokens)
hidden_states = torch.nn.functional.pad(
hidden_states,
(0, 0, 0, max_num_token - hidden_states.shape[0]))
(0, 0, 0, all_rank_max_num_tokens - hidden_states.shape[0]))

router_logits = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=use_dp_padding,
do_finalize=do_finalize,
)
@@ -85,11 +85,13 @@ class QwenMoE(nn.Module):
hidden_states = hidden_states.view(-1, self.hidden_dim)

all_rank_num_tokens = attn_metadata.all_rank_num_tokens
all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens
router_logits = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=False)

shared_expert_output = self.shared_expert(hidden_states)
@@ -326,6 +326,7 @@ class CutlassFusedMoE(MoE):
do_finalize: bool = True, # used by other MoE backends
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
use_dp_padding: Optional[bool] = None,
) -> torch.Tensor:
assert do_finalize, "CutlassFusedMoE does not support do_finalize=False"

@@ -341,7 +342,7 @@ class CutlassFusedMoE(MoE):
1) // self.moe_max_num_tokens

if use_dp_padding:
all_rank_num_tokens_padded = [max(all_rank_num_tokens)
all_rank_num_tokens_padded = [all_rank_max_num_tokens
] * len(all_rank_num_tokens)
else:
all_rank_num_tokens_padded = all_rank_num_tokens
@@ -327,6 +327,7 @@ class WideEPMoE(MoE):
cutlass_min_latency_mode: bool = False,
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
use_dp_padding: Optional[bool] = None,
repeating_info: Tuple = (True, True),
) -> torch.Tensor:

@@ -400,7 +401,7 @@ class WideEPMoE(MoE):
token_count = x.shape[0]
alltoall_info = None
x, token_selected_slots, token_final_scales, gathered_loadbalancer_local_statistic_info, alltoall_info = \
self.alltoall_prepare_maybe_dispatch(all_rank_num_tokens,
self.alltoall_prepare_maybe_dispatch(all_rank_max_num_tokens,
x,
token_selected_slots,
token_final_scales,

@@ -663,6 +664,7 @@ class WideEPMoE(MoE):
do_finalize: bool = True,
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
use_dp_padding: Optional[bool] = None,
) -> torch.Tensor:
if self.use_dp:

@@ -684,7 +686,7 @@ class WideEPMoE(MoE):
), "cutlass_min_latency_mode must be used with a single chunk and reduce_results must be False"

if use_dp_padding:
all_rank_num_tokens_padded = [max(all_rank_num_tokens)
all_rank_num_tokens_padded = [all_rank_max_num_tokens
] * len(all_rank_num_tokens)
else:
all_rank_num_tokens_padded = all_rank_num_tokens

@@ -697,6 +699,7 @@ class WideEPMoE(MoE):
cutlass_min_latency_mode,
output_dtype,
all_rank_num_tokens=all_rank_num_tokens_padded,
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=use_dp_padding,
repeating_info=(is_first_call, is_last_call))
outputs = self.reducescatter_or_allreduce(

@@ -720,13 +723,20 @@ class WideEPMoE(MoE):
all_rank_num_tokens_list = [[
val[idx_chunk] for val in all_rank_chunk_size_list
] for idx_chunk in range(num_chunks)]
all_rank_max_num_tokens_list = split_chunk(
all_rank_max_num_tokens, num_chunks)
chunk_size_list = all_rank_chunk_size_list[self.rank]
if self.enable_alltoall:
all_rank_num_tokens_list = [[
1 if val == 0 else val for val in val_list
] for val_list in all_rank_num_tokens_list]
all_rank_max_num_tokens_list = [
1 if val == 0 else val
for val in all_rank_max_num_tokens_list
]
else:
all_rank_num_tokens_list = [None] * num_chunks
all_rank_max_num_tokens_list = [None] * num_chunks
chunk_size_list = split_chunk(x.shape[0], num_chunks)

x_list = x.split(chunk_size_list)

@@ -751,6 +761,9 @@ class WideEPMoE(MoE):
router_logits,
all_rank_num_tokens=all_rank_num_tokens_list[
idx_chunk] if self.use_dp else None,
all_rank_max_num_tokens=
all_rank_max_num_tokens_list[idx_chunk]
if self.use_dp else None,
use_dp_padding=use_dp_padding,
repeating_info=(is_first_call, is_last_call))
if idx_chunk > 0:

@@ -765,6 +778,8 @@ class WideEPMoE(MoE):
router_logits,
all_rank_num_tokens=all_rank_num_tokens_list[
idx_chunk] if self.use_dp else None,
all_rank_max_num_tokens=all_rank_max_num_tokens_list[
idx_chunk] if self.use_dp else None,
use_dp_padding=use_dp_padding,
repeating_info=(is_first_call, is_last_call))
with torch.cuda.stream(self.aux_stream):

@@ -805,30 +820,30 @@ class WideEPMoE(MoE):
return outputs

def alltoall_prepare_maybe_dispatch(
self, all_rank_num_tokens: list, x: torch.Tensor,
self, all_rank_max_num_tokens: int, x: torch.Tensor,
token_selected_slots: torch.Tensor,
token_final_scales: torch.Tensor,
local_statistic_tensor: Optional[torch.Tensor]):
top_k = self.routing_method.experts_per_token
max_num_token = max(all_rank_num_tokens)

# TODO: support alltoall without allgather for top_k % 4 != 0
if self.enable_alltoall_without_allgather and top_k % 4 == 0:
alltoall_info, token_selected_slots, token_final_scales, gathered_local_statistic_tensor = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
token_selected_slots, token_final_scales,
local_statistic_tensor, self.alltoall_prepare_workspace,
max_num_token, self.ep_rank, self.ep_size, self.num_experts,
self.num_slots, top_k)
all_rank_max_num_tokens, self.ep_rank, self.ep_size,
self.num_experts, self.num_slots, top_k)
else:
if max_num_token > token_selected_slots.shape[0]:
if all_rank_max_num_tokens > token_selected_slots.shape[0]:
token_selected_slots = torch.nn.functional.pad(
token_selected_slots,
(0, 0, 0, max_num_token - token_selected_slots.shape[0]),
(0, 0, 0,
all_rank_max_num_tokens - token_selected_slots.shape[0]),
'constant', self.num_slots)
if max_num_token > token_final_scales.shape[0]:
if all_rank_max_num_tokens > token_final_scales.shape[0]:
token_final_scales = torch.nn.functional.pad(
token_final_scales,
(0, 0, 0, max_num_token - token_final_scales.shape[0]))
(0, 0, 0,
all_rank_max_num_tokens - token_final_scales.shape[0]))
gathered_token_selected_slots, gathered_token_final_scales, gathered_local_statistic_tensor = allgather(
[
token_selected_slots, token_final_scales,

@@ -848,8 +863,8 @@ class WideEPMoE(MoE):
gathered_token_selected_slots, self.num_slots, self.ep_size)
alltoall_info, token_selected_slots, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv_prepare(
gathered_target_rank_ids, None, gathered_token_selected_slots,
gathered_token_final_scales, max_num_token, self.num_slots,
top_k, self.ep_rank, self.ep_size)
gathered_token_final_scales, all_rank_max_num_tokens,
self.num_slots, top_k, self.ep_rank, self.ep_size)

if not self.use_postquant_alltoall:
assert not isinstance(
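Editor's note: a hedged sketch of the per-chunk bookkeeping above; split_chunk here is an assumed helper that splits a token count into num_chunks near-equal pieces (the real helper lives elsewhere in the repo and may differ).

def split_chunk(total: int, num_chunks: int):
    # Distribute the remainder over the first chunks so sizes sum to total.
    base, rem = divmod(total, num_chunks)
    return [base + (1 if i < rem else 0) for i in range(num_chunks)]

# e.g. a cross-rank max of 10 tokens processed in 3 chunks -> [4, 3, 3]
assert split_chunk(10, 3) == [4, 3, 3]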
@@ -1097,7 +1097,7 @@ class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod):
orig_shape = dst_w3_w1_weight_scale.shape

dst_w3_w1_weight_scale.copy_(
torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
torch.ops.trtllm.nvfp4_block_scale_interleave(
dst_w3_w1_weight_scale.view(float4_sf_dtype)).view(
self.block_scales_dtype).reshape(orig_shape))

@@ -1114,7 +1114,7 @@ class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod):
orig_shape = dst_w2_weight_scale.shape

dst_w2_weight_scale.copy_(
torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
torch.ops.trtllm.nvfp4_block_scale_interleave(
dst_w2_weight_scale.view(float4_sf_dtype)).view(
self.block_scales_dtype).reshape(orig_shape))

@@ -1296,7 +1296,7 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
# Assert should only be removed during debugging
assert w3_w1_weight_scale.is_cuda, "w3_w1_weight_scale.is_cuda should be true or suffer from slow speed"
# Interleave the weight.
processed_w3_w1_weight_scale = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
processed_w3_w1_weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave(
w3_w1_weight_scale.view(float4_sf_dtype).reshape(orig_shape))
# Copy the result into device buffer
dst_w3_w1_weight_scale.copy_(

@@ -1331,7 +1331,7 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
w_shuffled = torch.ops.trtllm.shuffle_matrix(
dst_w2_weight_scale.view(dtype=float4_sf_dtype), permute_indices)
# Interleave the weight.
processed_w2_weight_scale = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
processed_w2_weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave(
w_shuffled)
# Copy the result into device buffer
dst_w2_weight_scale.copy_(
@@ -589,7 +589,7 @@ class NVFP4LinearMethod(LinearMethodBase):
assert len(weights) == 1
weight_scale = weight_scale[0]
# Swizzle weight scale
weight_scale = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave(
weight_scale)

copy_weight(module.input_scale, input_scale)

@@ -610,7 +610,7 @@ class NVFP4LinearMethod(LinearMethodBase):
tp_mode=module.tp_mode)
# Swizzle weight scales after concatenation
weight_scale = torch.cat(weight_scales, 0)
weight_scale = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave(
weight_scale)
copy_weight(module.input_scale, input_scale)
copy_weight(module.weight_scale, weight_scale)

@@ -633,7 +633,7 @@ class NVFP4LinearMethod(LinearMethodBase):
tp_mode=module.tp_mode)
# Swizzle weight scales after concatenation
weight_scale = torch.cat(weight_scales, 0)
weight_scale = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave(
weight_scale)
copy_weight(module.input_scale, input_scale)
copy_weight(module.weight_scale, weight_scale)

@@ -717,7 +717,7 @@ class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
assert len(weights) == 1
weight_scale = weight_scale[0]
# Swizzle weight scale
weight_scale = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave(
weight_scale)
copy_weight(module.weight_scale, weight_scale)

@@ -733,7 +733,7 @@ class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
tp_rank=module.tp_rank,
tp_mode=module.tp_mode)
weight_scale = torch.cat(weight_scale, 0)
weight_scale = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave(
weight_scale)
copy_weight(module.weight_scale, weight_scale)

@@ -750,7 +750,7 @@ class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
tp_mode=module.tp_mode)
# Swizzle weight scales after concatenation
weight_scale = torch.cat(weight_scale, 0)
weight_scale = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave(
weight_scale)
copy_weight(module.weight_scale, weight_scale)
@@ -7,7 +7,14 @@ from ..attention_backend.interface import AttentionMetadata
from ..speculative.interface import SpecMetadata
from ..utils import make_weak_ref, set_piecewise_cuda_graph_flag

_local = threading.local()

class graph_capturing_local(threading.local):

def __init__(self):
self.is_graph_capturing = False


_local = graph_capturing_local()


def set_graph_capturing(enable: bool):

@@ -15,8 +22,6 @@ def set_graph_capturing(enable: bool):


def is_graph_capturing() -> bool:
if not hasattr(_local, 'is_graph_capturing'):
return False
return _local.is_graph_capturing
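Editor's note: a minimal sketch of the threading.local change above. Subclassing threading.local gives every thread a default is_graph_capturing value, so the flag can be read directly instead of going through the hasattr() guard that the diff removes; the class name below is illustrative.

import threading

class _GraphCapturingLocal(threading.local):
    def __init__(self):
        # Each thread gets its own copy, initialized to False.
        self.is_graph_capturing = False

_local = _GraphCapturingLocal()
assert _local.is_graph_capturing is False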
@@ -156,7 +156,13 @@ class SpecMetadata:
# The number of tokens for speculative model/layer
num_tokens: int = 0
# The number of tokens for speculative model/layer of different rank
all_rank_num_tokens: Optional[List[int]] = None
_all_rank_num_tokens: Optional[List[int]] = field(init=False,
default=None,
repr=False)
all_rank_num_tokens: Optional[List[int]]
# The max number of tokens among all ranks.
all_rank_max_num_tokens: Optional[int] = None

# The number of sequences for speculative model/layer of different rank
all_rank_num_seqs: Optional[List[int]] = None
# The number of extra kv tokens

@@ -193,3 +199,13 @@ class SpecMetadata:
Some spec decode algorithms require hidden states from the target
model. Use this method to record them. By default, does nothing.
"""

@property
def all_rank_num_tokens(self) -> Optional[List[int]]:
return self._all_rank_num_tokens

@all_rank_num_tokens.setter
def all_rank_num_tokens(self, value: Optional[List[int]]):
value = value if value is not SpecMetadata.all_rank_num_tokens else None
self._all_rank_num_tokens = value
self.all_rank_max_num_tokens = max(value) if value is not None else None
@@ -1137,6 +1137,8 @@ class MTPEagleWorker(MTPWorker):
hidden_states = mtp_layers[0](
embed_tokens=embed_tokens,
all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
all_rank_max_num_tokens=spec_metadata.
all_rank_max_num_tokens,
**inputs)
start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] *
(self.mtp_num_modules + 1)).long()

@@ -1146,10 +1148,15 @@ class MTPEagleWorker(MTPWorker):
gather_ids = torch.concat(
[last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
else:
hidden_states = mtp_layers[0](embed_tokens=embed_tokens,
all_rank_num_tokens=spec_metadata.
subseq_all_rank_num_tokens,
**inputs)
hidden_states = mtp_layers[0](
embed_tokens=embed_tokens,
all_rank_num_tokens=spec_metadata.
subseq_all_rank_num_tokens,
all_rank_max_num_tokens=max(
spec_metadata.subseq_all_rank_num_tokens)
if spec_metadata.subseq_all_rank_num_tokens is not None else
None,
**inputs)
# All of the seq_len are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
logits = mtp_layers[0].shared_head(hidden_states[gather_ids],
@@ -128,7 +128,7 @@ def swizzle_sf(sf: torch.Tensor,
"""
sf_cols = ceil_div(cols, scaling_vector_size)
sf = sf.view(-1, rows, sf_cols)
return torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(sf)
return torch.ops.trtllm.nvfp4_block_scale_interleave(sf)


def unswizzle_sf(sf: torch.Tensor,

@@ -146,14 +146,15 @@ def unswizzle_sf(sf: torch.Tensor,
"""
sf_cols = ceil_div(cols, scaling_vector_size)
sf = sf.view(-1, rows, sf_cols)
return torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(sf).view(
return torch.ops.trtllm.nvfp4_block_scale_interleave_reverse(sf).view(
-1, sf_cols)


@torch.library.custom_op("trtllm::reswizzle_sf", mutates_args=())
def reswizzle_sf(sf: torch.Tensor,
rows: int,
cols: int,
scaling_vector_size: int = 16):
scaling_vector_size: int = 16) -> torch.Tensor:
"""Reswizzle FP4 scaling factors using C++ torch op implementation.
It unswizzles the scaling factors in each partition first, then concatenates them together, and finally swizzles them back.
Args:

@@ -189,6 +190,16 @@ def reswizzle_sf(sf: torch.Tensor,
return swizzle_sf(sf_concatenated, total_rows, cols, scaling_vector_size)


@torch.library.register_fake("trtllm::reswizzle_sf")
def _(sf, rows, cols, scaling_vector_size=16):
sf_cols = ceil_div(cols, scaling_vector_size)
padded_rows, padded_sf_cols = compute_swizzled_sf_shape(rows, sf_cols)
num_partitions = sf.numel() // (padded_rows * padded_sf_cols)
total_rows = num_partitions * rows
sz = pad_up(total_rows, 128) * pad_up(cols, 4)
return sf.new_empty(sz)


def next_positive_power_of_2(x: int) -> int:
if x < 1:
return 1
@@ -692,7 +692,7 @@ class MOEWeightWrapper(Module):
weights = stack_weights(tllm_key, weights)
if tllm_key.endswith("weights_block_scaling_factor_interleaved"):
weights = stack_weights(tllm_key, weights)
weights = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
weights = torch.ops.trtllm.nvfp4_block_scale_interleave(
weights.to(torch.float8_e4m3fn).view(
torch.uint8).cpu().contiguous()).reshape(
weights.shape).view(torch.float8_e4m3fn)

@@ -1811,7 +1811,7 @@ def preprocess_perlayer_weights(weights,
weights[new_name] = weights[name]
weights[
new_name +
"_interleaved"] = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
"_interleaved"] = torch.ops.trtllm.nvfp4_block_scale_interleave(
weights[name].view(fp4_utils.float4_sf_dtype).cpu(
).contiguous()).reshape(nrows, ncols).view(
fp4_utils.float4_sf_dtype)

@@ -2218,7 +2218,7 @@ class FP4Linear(Linear):
qkv_block_scale,
tllm_key.replace(
'weight', "weights_block_scaling_factor_interleaved"):
torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
torch.ops.trtllm.nvfp4_block_scale_interleave(
qkv_block_scale.view(
torch.uint8).cpu().contiguous()).reshape(
qkv_block_scale.shape).view(

@@ -2238,7 +2238,7 @@ class FP4Linear(Linear):
elif tllm_key.endswith("weights_block_scaling_factor"):
return weights
elif tllm_key.endswith("weights_block_scaling_factor_interleaved"):
return torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
return torch.ops.trtllm.nvfp4_block_scale_interleave(
weights.view(torch.uint8).cpu().contiguous()).reshape(
weights.shape).view(torch.float8_e4m3fn)
elif tllm_key.endswith("weights_global_scaling_factor"):

@@ -2379,7 +2379,7 @@ class FP4RowLinear(RowLinear):
elif tllm_key.endswith("weights_block_scaling_factor"):
return weights
elif tllm_key.endswith("weights_block_scaling_factor_interleaved"):
return torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
return torch.ops.trtllm.nvfp4_block_scale_interleave(
weights.view(torch.uint8).cpu().contiguous()).reshape(
weights.shape).view(torch.float8_e4m3fn)
elif tllm_key.endswith("weights_global_scaling_factor"):

@@ -28,7 +28,7 @@ class FP4GemmType(IntEnum):
W4A8_MXFP4_MXFP8 = 1


def get_fp4_shape(input_shape, sf_vec_size):
def get_fp4_shape(input_shape, sf_vec_size, is_swizzled_layout=True):
m = 1
for i in range(len(input_shape) - 1):
m *= input_shape[i]

@@ -36,7 +36,9 @@ def get_fp4_shape(input_shape, sf_vec_size):
output_shape = [i for i in input_shape]
output_shape[-1] //= 2

scale_shape = pad_up(m, 128) * pad_up(input_shape[-1] // sf_vec_size, 4)
scale_shape = pad_up(m, 128) * pad_up(
input_shape[-1] // sf_vec_size,
4) if is_swizzled_layout else m * (input_shape[-1] // sf_vec_size)
return output_shape, scale_shape


@@ -205,4 +207,4 @@ def shuffle_matrix_sf_a(
input_tensor, row_indices.to(input_tensor.device))

# 128x4
return torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(w_shuffled)
return torch.ops.trtllm.nvfp4_block_scale_interleave(w_shuffled)
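Editor's note: a worked example of the scale_shape change above, with illustrative values. For an input of shape (3, 64) and sf_vec_size=16, the swizzled layout pads to 128x4 tiles while the linear layout keeps one scale per 16-element block.

def pad_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y

m, last_dim, sf_vec_size = 3, 64, 16
swizzled = pad_up(m, 128) * pad_up(last_dim // sf_vec_size, 4)  # 128 * 4 = 512
linear = m * (last_dim // sf_vec_size)                          # 3 * 4 = 12
assert (swizzled, linear) == (512, 12)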
@@ -516,8 +516,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
overlap_scheduler, torch_compile):
if torch_compile and mtp_nextn > 0:
pytest.skip("https://nvbugs/5252313")
if torch_compile and attention_dp:
pytest.skip("https://nvbugs/5252559")

kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,

@@ -558,14 +557,13 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
torch_compile):
if torch_compile and mtp_nextn > 0:
pytest.skip("https://nvbugs/5252313")
if torch_compile and attention_dp:
pytest.skip("https://nvbugs/5252559")
if torch_compile and pp_size > 1:
pytest.skip("PP with torch.compile is not supported yet.")
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
enable_piecewise_cuda_graph=cuda_graph
and not attention_dp) if torch_compile else None
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,

@@ -602,8 +600,6 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
overlap_scheduler, torch_compile):
if torch_compile and mtp != "disable":
pytest.skip("https://nvbugs/5252313")
if torch_compile and attention_dp:
pytest.skip("https://nvbugs/5252559")
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,

@@ -775,14 +771,13 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
overlap_scheduler, torch_compile):
if torch_compile and mtp_nextn > 0:
pytest.skip("https://nvbugs/5252313")
if torch_compile and attention_dp:
pytest.skip("https://nvbugs/5252559")
if torch_compile and pp_size > 1:
pytest.skip("PP with torch.compile is not supported yet.")
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
enable_piecewise_cuda_graph=cuda_graph
and not attention_dp) if torch_compile else None
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,

@@ -989,8 +984,6 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
torch_compile, mtp_nextn, moe_backend):
if torch_compile and mtp_nextn > 0:
pytest.skip("https://nvbugs/5252313")
if torch_compile and attention_dp:
pytest.skip("https://nvbugs/5252559")
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,

@@ -1048,16 +1041,16 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
torch_compile, mtp_nextn, moe_backend):
if torch_compile and mtp_nextn > 0:
pytest.skip("https://nvbugs/5252313")
if torch_compile and attention_dp:
pytest.skip("https://nvbugs/5252559")
if torch_compile and pp_size > 1:
pytest.skip("PP with torch.compile is not supported yet.")
if not attention_dp and (tp_size > 1 or ep_size > 1):
pytest.skip("https://nvbugs/5336321")
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
# Picewise Cuda Graph cannot be enabled for nvfp4 attention dp.
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
enable_piecewise_cuda_graph=cuda_graph
and not attention_dp) if torch_compile else None
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
@@ -214,11 +214,13 @@ def test_fused_moe_alltoall(alltoall_method_type):
x,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=m,
use_dp_padding=False)
ref_output = ref_model.forward(
x,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=m,
use_dp_padding=False)

# Evaluate outputs

@@ -547,17 +549,17 @@ def test_fused_moe_nvfp4(dtype):

w1_weight_nvfp4, w1_sf_block = torch.ops.trtllm.fp4_quantize(
w1_weight, w3_w1_global, SCALING_VECTOR_SIZE, False)
w1_sf_block_unswizzled = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(
w1_sf_block_unswizzled = torch.ops.trtllm.nvfp4_block_scale_interleave_reverse(
w1_sf_block.cpu().view(INTERMEDIATE_SIZE, -1))

w2_weight_nvfp4, w2_sf_block = torch.ops.trtllm.fp4_quantize(
w2_weight, w2_sf_global, SCALING_VECTOR_SIZE, False)
w2_sf_block_unswizzled = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(
w2_sf_block_unswizzled = torch.ops.trtllm.nvfp4_block_scale_interleave_reverse(
w2_sf_block.cpu().view(HIDDEN_SIZE, -1))

w3_weight_nvfp4, w3_sf_block = torch.ops.trtllm.fp4_quantize(
w3_weight, w3_w1_global, SCALING_VECTOR_SIZE, False)
w3_sf_block_unswizzled = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(
w3_sf_block_unswizzled = torch.ops.trtllm.nvfp4_block_scale_interleave_reverse(
w3_sf_block.cpu().view(INTERMEDIATE_SIZE, -1))

w1_input_scale = x_sf_global.cuda()
@@ -977,7 +977,7 @@ def run_single_rank_ub_pass_fp4(

def block_scale_unswizzled(scale):
sz = fp4_utils.pad_up(hidden_size, 128)
return torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(
return torch.ops.trtllm.nvfp4_block_scale_interleave_reverse(
scale.cpu().view(sz, -1))[0:hidden_size]

l0_weight_scale_block_unswizzled = block_scale_unswizzled(

@@ -172,8 +172,8 @@ def test_fp4_sf_interleave(b, m, k):
w_cuda = w.cuda()

# The cpu and cuda kernels are different
w_out_cpu = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(w)
w_out_cuda = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(w_cuda)
w_out_cpu = torch.ops.trtllm.nvfp4_block_scale_interleave(w)
w_out_cuda = torch.ops.trtllm.nvfp4_block_scale_interleave(w_cuda)
torch.cuda.synchronize()

torch.testing.assert_allclose(w_out_cpu.cuda(), w_out_cuda)

@@ -39,7 +39,7 @@ def test_fp4_linear(dtype):
assert l_fp4.weight_scale.dtype == fp4_utils.float4_sf_dtype

w_sf_block_unswizzled = (
torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(
torch.ops.trtllm.nvfp4_block_scale_interleave_reverse(
w_sf_block.cpu().view(HIDDEN_SIZE, -1)))

l_fp4.load_weights([{

@@ -70,10 +70,8 @@ class TestFunctional(unittest.TestCase):
mat_b[0][0] = 36
mat_b_ref[0][0] = 2.0

a_block_sf = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
a_block_sf)
b_block_sf = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
b_block_sf)
a_block_sf = torch.ops.trtllm.nvfp4_block_scale_interleave(a_block_sf)
b_block_sf = torch.ops.trtllm.nvfp4_block_scale_interleave(b_block_sf)

c = (torch.ops.trtllm.w4a8_mxfp4_fp8_gemm(mat_a, mat_b, a_block_sf,
b_block_sf, a_sf,

@@ -1186,7 +1186,7 @@ class TestMoE(unittest.TestCase):
moe_weight_wrapper.weights_block_scaling_factor_interleaved.value = (
np.ascontiguousarray(
torch_to_numpy(
torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
torch.ops.trtllm.nvfp4_block_scale_interleave(
scale_factor.view(torch.uint8).contiguous()).view(
scale_factor.dtype).reshape(
scale_factor.shape).view(torch.uint8))))