[feat] Support torch compile for attention dp (#5086)

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
liji-nv authored on 2025-07-02 01:48:52 +08:00; committed by GitHub
parent f9a455651b
commit c345f5876c
33 changed files with 299 additions and 127 deletions

View File

@ -70,7 +70,15 @@ public:
outputShape[0] *= mGroup.size();
}
auto output = torch::empty(outputShape, input.options());
if (sizes.has_value())
bool use_nccl_allgather = !sizes.has_value()
|| std::all_of(sizes.value().begin(), sizes.value().end(),
[&sizes](int64_t size) { return size == sizes.value()[0]; });
if (use_nccl_allgather)
{
NCCLCHECK_THROW(ncclAllGather(input.data_ptr(), output.mutable_data_ptr(), input.numel(),
(*getDtypeMap())[type], *mNcclComm, stream));
}
else
{
size_t numel_base = std::accumulate(outputShape.cbegin() + 1, outputShape.cend(), 1, std::multiplies<>{});
int64_t split_offset = 0;
@ -85,11 +93,6 @@ public:
}
ncclGroupEnd();
}
else
{
NCCLCHECK_THROW(ncclAllGather(input.data_ptr(), output.mutable_data_ptr(), input.numel(),
(*getDtypeMap())[type], *mNcclComm, stream));
}
return output;
}
@ -155,8 +158,8 @@ std::vector<torch::Tensor> allgather_list(
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("allgather(Tensor input, int[]? sizes, int[] group) -> Tensor");
m.def("allgather_list(Tensor[] input_list, int[]? sizes, int[] group) -> Tensor[]");
m.def("allgather(Tensor input, SymInt[]? sizes, int[] group) -> Tensor");
m.def("allgather_list(Tensor[] input_list, SymInt[]? sizes, int[] group) -> Tensor[]");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
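
The int[]? → SymInt[]? schema change lets torch.compile trace these collectives with symbolic token counts instead of specializing on the concrete sizes seen at capture time. Below is a minimal sketch of such a schema plus a fake (meta) kernel, using a hypothetical demo namespace rather than the real trtllm registration, and assuming a PyTorch version that provides torch.library.register_fake:

import torch

# Hypothetical "demo" op mirroring the SymInt[]? schema defined above.
torch.library.define(
    "demo::allgather", "(Tensor input, SymInt[]? sizes, int[] group) -> Tensor")

# Fake (meta) kernel: tells torch.compile the output shape without running NCCL.
@torch.library.register_fake("demo::allgather")
def _(input, sizes, group):
    shape = list(input.shape)
    # Uneven per-rank sizes gather to their (possibly symbolic) sum; otherwise
    # every rank contributes an equally sized slice.
    shape[0] = sum(sizes) if sizes is not None else shape[0] * len(group)
    return input.new_empty(shape)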

View File

@ -718,10 +718,10 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
", int num_heads"
", int num_kv_heads"
", int head_size"
", int? tokens_per_block"
", int max_num_requests"
", int max_context_length"
", int attention_window_size"
", SymInt? tokens_per_block"
", SymInt max_num_requests"
", SymInt max_context_length"
", SymInt attention_window_size"
", int sink_token_length"
", int beam_width"
", int mask_type"

View File

@ -524,8 +524,20 @@ static auto e2m1_and_ufp8sf_scale_to_float
static auto e2m1_and_ufp8sf_scale_to_float_v2 = torch::RegisterOperators(
"tensorrt_llm::e2m1_and_ufp8sf_scale_to_float_v2", &torch_ext::E2M1AndUFP8SFScaleToFloatV2);
static auto nvfp4_block_scale_interleave
= torch::RegisterOperators("tensorrt_llm::nvfp4_block_scale_interleave", &torch_ext::NVFP4BlockScaleInterleave);
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("nvfp4_block_scale_interleave(Tensor input) -> Tensor");
m.def("nvfp4_block_scale_interleave_reverse(Tensor input) -> Tensor");
}
static auto nvfp4_block_scale_interleave_reverse = torch::RegisterOperators(
"tensorrt_llm::nvfp4_block_scale_interleave_reverse", &torch_ext::NVFP4BlockScaleInterleaveReverse);
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("nvfp4_block_scale_interleave", &torch_ext::NVFP4BlockScaleInterleave);
m.impl("nvfp4_block_scale_interleave_reverse", &torch_ext::NVFP4BlockScaleInterleaveReverse);
}
TORCH_LIBRARY_IMPL(trtllm, CPU, m)
{
m.impl("nvfp4_block_scale_interleave", &torch_ext::NVFP4BlockScaleInterleave);
m.impl("nvfp4_block_scale_interleave_reverse", &torch_ext::NVFP4BlockScaleInterleaveReverse);
}
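
Registering these ops through TORCH_LIBRARY_FRAGMENT / TORCH_LIBRARY_IMPL in the trtllm namespace (instead of the legacy torch::RegisterOperators under tensorrt_llm) gives them a dispatcher-visible schema with separate CUDA and CPU kernels, which is also why the Python call sites later in this commit switch to torch.ops.trtllm.nvfp4_block_scale_interleave. A rough Python analogue of the def/impl split, with a hypothetical demo namespace and a placeholder CPU body:

import torch

# Define the schema once (the TORCH_LIBRARY_FRAGMENT step) ...
torch.library.define("demo::interleave", "(Tensor input) -> Tensor")

# ... then attach a per-backend kernel (the TORCH_LIBRARY_IMPL step).
@torch.library.impl("demo::interleave", "cpu")
def _interleave_cpu(input: torch.Tensor) -> torch.Tensor:
    # Placeholder body; the real kernel reorders FP4 block scales into the
    # swizzled layout expected by the GEMM kernels.
    return input.clone()

x = torch.zeros(8, dtype=torch.uint8)
assert torch.ops.demo.interleave(x).shape == x.shape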

View File

@ -79,7 +79,15 @@ public:
outputShape[0] = outputShape[0] / mGroup.size();
}
auto output = torch::empty(outputShape, input.options());
if (sizes.has_value())
bool use_nccl_reducescatter = !sizes.has_value()
|| std::all_of(sizes.value().begin(), sizes.value().end(),
[&sizes](int64_t size) { return size == sizes.value()[0]; });
if (use_nccl_reducescatter)
{
NCCLCHECK_THROW(ncclReduceScatter(input.data_ptr(), output.mutable_data_ptr(), output.numel(),
(*getDtypeMap())[type], ncclSum, *mNcclComm, stream));
}
else
{
size_t numel_base = std::accumulate(outputShape.cbegin() + 1, outputShape.cend(), 1, std::multiplies<>{});
int64_t split_offset = 0;
@ -95,11 +103,6 @@ public:
}
ncclGroupEnd();
}
else
{
NCCLCHECK_THROW(ncclReduceScatter(input.data_ptr(), output.mutable_data_ptr(), output.numel(),
(*getDtypeMap())[type], ncclSum, *mNcclComm, stream));
}
return output;
}
@ -167,8 +170,8 @@ extern std::vector<torch::Tensor> reducescatter_list(
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("reducescatter(Tensor input, int[]? sizes, int[] group) -> Tensor");
m.def("reducescatter_list(Tensor[] input_list, int[]? sizes, int[] group) -> Tensor[]");
m.def("reducescatter(Tensor input, SymInt[]? sizes, int[] group) -> Tensor");
m.def("reducescatter_list(Tensor[] input_list, SymInt[]? sizes, int[] group) -> Tensor[]");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
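
As with allgather above, the reduce-scatter now takes the plain ncclReduceScatter path whenever every rank contributes the same number of rows, and falls back to grouped ncclSend/ncclRecv only for uneven splits. A single-process reference for the uneven case (illustrative semantics only, not the NCCL implementation):

import torch

def reducescatter_ref(inputs, sizes):
    # Sum across "ranks" (ncclSum), then scatter uneven row blocks so that
    # rank r keeps the r-th block of sizes[r] rows.
    reduced = torch.stack(inputs).sum(dim=0)
    return list(reduced.split(sizes, dim=0))

inputs = [torch.full((6, 4), float(r + 1)) for r in range(2)]  # two "ranks"
out = reducescatter_ref(inputs, sizes=[2, 4])
assert out[0].shape == (2, 4) and out[1].shape == (4, 4)
assert torch.all(out[0] == 3.0)  # 1 + 2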

View File

@ -118,7 +118,13 @@ class AttentionMetadata:
runtime_features: AttentionRuntimeFeatures = field(
default_factory=AttentionRuntimeFeatures)
all_rank_num_tokens: Optional[List[int]] = None
# The number of tokens in each rank.
_all_rank_num_tokens: Optional[List[int]] = field(init=False,
default=None,
repr=False)
all_rank_num_tokens: Optional[List[int]]
# The max number of tokens among all ranks.
all_rank_max_num_tokens: Optional[int] = None
# These fields are set when changing seq_lens and _num_contexts to avoid computation
# during execution. If the calculation happens during execution, torch compile treats it
@ -149,6 +155,16 @@ class AttentionMetadata:
elif self._seq_lens is not None:
self._num_tokens = self._seq_lens.sum().item()
@property
def all_rank_num_tokens(self) -> Optional[List[int]]:
return self._all_rank_num_tokens
@all_rank_num_tokens.setter
def all_rank_num_tokens(self, value: Optional[List[int]]):
value = value if value is not AttentionMetadata.all_rank_num_tokens else None
self._all_rank_num_tokens = value
self.all_rank_max_num_tokens = max(value) if value is not None else None
@property
def seq_lens(self) -> Optional[torch.Tensor]:
return self._seq_lens
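
Turning all_rank_num_tokens into a property lets all_rank_max_num_tokens be derived once, at the point where the metadata is assigned, rather than by calling max() on a Python list inside the compiled forward. A standalone sketch of the dataclass-plus-property pattern used here (names shortened; not the actual AttentionMetadata class):

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Meta:
    _tokens: Optional[List[int]] = field(init=False, default=None, repr=False)
    tokens: Optional[List[int]]  # the property below becomes this field's default
    max_tokens: Optional[int] = None  # derived eagerly in the setter

    @property
    def tokens(self) -> Optional[List[int]]:
        return self._tokens

    @tokens.setter
    def tokens(self, value: Optional[List[int]]):
        # When the generated __init__ omits the argument it passes the
        # class-level property object; treat that sentinel as "not set".
        value = value if value is not Meta.tokens else None
        self._tokens = value
        self.max_tokens = max(value) if value is not None else None

m = Meta()
m.tokens = [3, 5, 2]
assert m.max_tokens == 5
m.tokens = None
assert m.max_tokens is None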

View File

@ -300,7 +300,7 @@ class FP4QuantizationImpl(QuantizationImpl):
weight_scale = state_dict[weight_name + "_scale"].view(float4_sf_dtype)
ori_shape = weight_scale.shape
state_dict[weight_name + "_scale"] = (
torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
torch.ops.trtllm.nvfp4_block_scale_interleave(
weight_scale.view(torch.uint8).cpu().contiguous()
)
.reshape(ori_shape)

View File

@ -171,9 +171,10 @@ def _register_fake():
global_scale: torch.Tensor,
sf_vec_size: int,
sf_use_ue8m0=False,
swizzled_layout=True,
):
output_shape, scale_shape = fp4_utils.get_fp4_shape(
input.shape, sf_vec_size)
input.shape, sf_vec_size, swizzled_layout)
return (input.new_empty(output_shape, dtype=torch.uint8),
global_scale.new_empty(scale_shape, dtype=torch.uint8))
@ -473,3 +474,88 @@ def _register_fake():
hidden_size_val = int(hidden_size)
return gemm2_output.new_empty((num_rows_val, hidden_size_val),
dtype=gemm2_output.dtype)
@torch.library.register_fake("trtllm::allgather_list")
def _(input_list, sizes, group):
assert len(input_list) > 0
def create_output_tensor(i):
shape = list(i.shape)
if sizes is None:
shape[0] *= len(group)
else:
shape[0] = sum(sizes)
return i.new_empty(shape)
return [create_output_tensor(i) for i in input_list]
@torch.library.register_fake("trtllm::reducescatter")
def _(input, sizes, group):
import tensorrt_llm
local_rank = tensorrt_llm.mpi_rank()
shape = list(input.shape)
if sizes is None:
shape[0] = shape[0] // len(group)
else:
shape[0] = sizes[local_rank]
return input.new_empty(shape)
@torch.library.register_fake("trtllm::fp4_block_scale_moe_runner")
def _(
routing_logits,
routing_bias,
hidden_states,
hidden_states_scale,
gemm1_weights,
gemm1_weights_scale,
gemm2_weights,
gemm2_weights_scale,
output1_scale_scalar,
output1_scale_gate_scalar,
output2_scale_scalar,
num_experts,
top_k,
n_group,
topk_group,
intermediate_size,
local_expert_offset,
local_num_experts,
routed_scaling_factor,
tile_tokens_dim,
routing_method_type,
do_finalize,
) -> List[torch.Tensor]:
num_tokens = hidden_states.shape[0]
hidden_size = hidden_states.shape[1] * 2
if do_finalize:
return [
hidden_states.new_empty((num_tokens, hidden_size),
dtype=torch.bfloat16)
]
expanded_row_count = num_tokens * top_k
max_padding_required = (tile_tokens_dim - 1) * num_experts
max_num_padded_tokens = fp4_utils.pad_up(
expanded_row_count + max_padding_required, tile_tokens_dim)
wt_dtype = routing_bias.dtype if routing_bias is not None else torch.bfloat16
return [
hidden_states.new_empty((max_num_padded_tokens, hidden_size),
dtype=torch.bfloat16),
hidden_states.new_empty((num_tokens, top_k), dtype=wt_dtype),
hidden_states.new_empty((num_tokens, top_k), dtype=torch.int32)
]
@torch.library.register_fake("trtllm::nvfp4_block_scale_interleave")
def _(sf: torch.Tensor):
rows = sf.shape[-2]
cols = sf.shape[-1]
expert_out_size = fp4_utils.pad_up(rows, 128) * fp4_utils.pad_up(
cols, 4)
num_experts = sf.shape[0] if len(sf.shape) == 3 else 1
return sf.new_empty((num_experts * expert_out_size, ),
dtype=torch.uint8)
@torch.library.register_fake("trtllm::nvfp4_block_scale_interleave_reverse")
def _(sf: torch.Tensor):
return torch.empty_like(sf, dtype=torch.uint8)
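
Each register_fake above only needs to reproduce the real op's output shapes and dtypes so that torch.compile can trace through the custom op without launching any kernel. For nvfp4_block_scale_interleave, that means matching the padded 128x4 swizzled scale layout. A quick check of the shape arithmetic, with pad_up re-implemented locally for the sketch:

def pad_up(x: int, y: int) -> int:
    # Round x up to the next multiple of y (assumed to match fp4_utils.pad_up).
    return ((x + y - 1) // y) * y

rows, cols, num_experts = 96, 6, 2
expert_out_size = pad_up(rows, 128) * pad_up(cols, 4)  # 128 * 8 = 1024
assert num_experts * expert_out_size == 2048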

View File

@ -224,6 +224,7 @@ def _(
ep_rank: int = 0,
cluster_size: int = 1,
cluster_rank: int = 0,
enable_alltoall: bool = False,
use_deepseek_fp8_block_scale: bool = False,
use_w4a8_group_scaling: bool = False,
min_latency_mode: bool = False,

View File

@ -1,7 +1,6 @@
import math
import os
import threading
from itertools import accumulate
from typing import List, Optional, Tuple, Union
import torch
@ -126,13 +125,14 @@ def filter_valid_input(
return input_list, valid_list
def restore_full_output(output_list: List[torch.Tensor],
def restore_full_output(valid_outputs: List[torch.Tensor],
valid_list: List[bool]) -> List[torch.Tensor]:
index_list = list(accumulate(map(int, valid_list)))
output_list = list(
map(lambda valid, index: output_list[index - 1]
if valid else None, valid_list, index_list))
return output_list
idx = 0
full_outputs = []
for v in valid_list:
full_outputs.append(valid_outputs[idx] if v else None)
idx += int(v)
return full_outputs
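
The rewritten restore_full_output keeps the previous contract — positions whose valid flag is False become None, and the remaining positions consume the valid outputs in order — but as a plain loop instead of the accumulate/map construction. A small self-contained check of that contract (the function body is copied from above):

def restore_full_output(valid_outputs, valid_list):
    idx, full_outputs = 0, []
    for v in valid_list:
        full_outputs.append(valid_outputs[idx] if v else None)
        idx += int(v)
    return full_outputs

assert restore_full_output(["a", "b"], [True, False, True]) == ["a", None, "b"]
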
def allgather(
@ -178,12 +178,6 @@ def allgather(
val.shape[dim] == sizes[mapping.tp_rank] for val in input
if val is not None
])
# 'sizes' is not needed if all inputs in the same TP group have the same shape
for split_size in sizes[1:]:
if split_size != sizes[0]:
break
else:
sizes = None
# Inputs are reshaped in this way to pass necessary shape information to the allgather op
if isinstance(input, torch.Tensor):
@ -247,12 +241,6 @@ def reducescatter(
val.shape[dim] == sum_split_size for val in input
if val is not None
])
# 'sizes' is not needed if all outputs in the same TP group have the same shape
for split_size in sizes[1:]:
if split_size != sizes[0]:
break
else:
sizes = None
def convert_input(x, x_info):
# Inputs are reshaped in this way to pass necessary shape information to the reducescatter op

View File

@ -507,7 +507,8 @@ class Deepseekv3MoE(nn.Module):
return shared_tp_size, shared_output_scale
def compute_routed_output(self, hidden_states, hidden_states_fp4,
all_rank_num_tokens, do_finalize):
all_rank_num_tokens, all_rank_max_num_tokens,
do_finalize):
# max-throughput
use_dp_padding = False
if self.use_dp and self.mapping.tp_size > 1:
@ -522,10 +523,9 @@ class Deepseekv3MoE(nn.Module):
not self.experts.has_fp8_qdq and self.experts.has_nvfp4):
# Use padding when not using the cutlass path or when x_sf in self.experts is not None
use_dp_padding = True
max_num_token = max(all_rank_num_tokens)
hidden_states = torch.nn.functional.pad(
hidden_states,
(0, 0, 0, max_num_token - hidden_states.shape[0]))
(0, 0, 0, all_rank_max_num_tokens - hidden_states.shape[0]))
router_logits = self.gate(hidden_states)
@ -535,6 +535,7 @@ class Deepseekv3MoE(nn.Module):
do_finalize=do_finalize,
output_dtype=hidden_states.dtype,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=use_dp_padding,
)
@ -545,6 +546,7 @@ class Deepseekv3MoE(nn.Module):
hidden_states: torch.Tensor,
hidden_states_fp4: Optional[Fp4QuantizedTensor] = None,
all_rank_num_tokens: Optional[list[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
final_all_reduce_params: Optional[AllReduceParams] = None,
do_finalize: Optional[bool] = True,
) -> torch.Tensor:
@ -562,6 +564,7 @@ class Deepseekv3MoE(nn.Module):
routed_output = self.compute_routed_output(hidden_states,
hidden_states_fp4,
all_rank_num_tokens,
all_rank_max_num_tokens,
do_finalize)
return routed_output
@ -774,6 +777,7 @@ class DeepseekV3DecoderLayer(DecoderLayer):
hidden_states,
hidden_states_fp4,
all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
final_all_reduce_params=AllReduceParams(
enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
or self.mapping.tp_size == 1)),
@ -934,6 +938,7 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
embed_tokens: Embedding,
attn_metadata: AttentionMetadata,
all_rank_num_tokens: Optional[List[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
@ -975,6 +980,7 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
hidden_states = self.mlp(
hidden_states,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=all_rank_max_num_tokens,
final_all_reduce_params=AllReduceParams(
enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
or self.mapping.tp_size == 1)),
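
Passing all_rank_max_num_tokens down from the attention metadata, instead of computing max(all_rank_num_tokens) inside each MoE forward, keeps the padding target as a single precomputed value that every layer can reuse under torch.compile. A minimal sketch of the padding step, assuming the precomputed maximum is supplied by the caller:

import torch
import torch.nn.functional as F

def pad_to_max(hidden_states: torch.Tensor, all_rank_max_num_tokens: int) -> torch.Tensor:
    # Pad the token (first) dimension up to the per-iteration maximum across DP ranks.
    return F.pad(hidden_states,
                 (0, 0, 0, all_rank_max_num_tokens - hidden_states.shape[0]))

h = torch.randn(3, 8)
assert pad_to_max(h, 5).shape == (5, 8)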

View File

@ -292,22 +292,22 @@ class Llama4MoE(nn.Module):
self.aux_stream = aux_stream
def compute_routed_output(self, hidden_states, all_rank_num_tokens,
all_rank_max_num_tokens,
cutlass_min_latency_mode):
use_dp_padding = False
if self.enable_attention_dp and self.mapping.tp_size > 1:
# Use padding here to keep the behavior unchanged
use_dp_padding = True
max_num_token_across_dp_ranks = max(all_rank_num_tokens)
hidden_states = torch.nn.functional.pad(
hidden_states,
(0, 0, 0,
max_num_token_across_dp_ranks - hidden_states.shape[0]))
(0, 0, 0, all_rank_max_num_tokens - hidden_states.shape[0]))
router_logits = self.router(hidden_states)
routed_output = self.experts(
hidden_states,
router_logits,
do_finalize=not cutlass_min_latency_mode,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=use_dp_padding,
)
return routed_output
@ -316,6 +316,7 @@ class Llama4MoE(nn.Module):
self,
hidden_states: torch.Tensor,
all_rank_num_tokens=None,
all_rank_max_num_tokens=None,
final_all_reduce_params: Optional[AllReduceParams] = None,
cutlass_min_latency_mode: Optional[bool] = False,
) -> torch.Tensor:
@ -323,7 +324,8 @@ class Llama4MoE(nn.Module):
# This design is mainly for low latency use case. Need to improve for max throughput use case.
fn0 = lambda: self.shared_expert(hidden_states)
fn1 = lambda: self.compute_routed_output(
hidden_states, all_rank_num_tokens, cutlass_min_latency_mode)
hidden_states, all_rank_num_tokens, all_rank_max_num_tokens,
cutlass_min_latency_mode)
shared_output, routed_output = maybe_execute_in_parallel(
fn0, fn1, self.moe_event[0], self.moe_event[1], self.aux_stream)
if cutlass_min_latency_mode:
@ -479,6 +481,7 @@ class Llama4DecoderLayer(DecoderLayer):
hidden_states = self.feed_forward(
hidden_states,
all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
final_all_reduce_params=AllReduceParams(enable_allreduce=not (
self.fusion_config.POST_MOE_FUSION
or self.fusion_config.POST_MLP_FUSION

View File

@ -61,19 +61,20 @@ class MixtralMoE(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
all_rank_num_tokens = attn_metadata.all_rank_num_tokens
all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens
use_dp_padding = False
if self.enable_attention_dp and len(all_rank_num_tokens) > 1:
# Use padding here to keep the behavior unchanged
use_dp_padding = True
max_num_token = max(all_rank_num_tokens)
hidden_states = torch.nn.functional.pad(
hidden_states,
(0, 0, 0, max_num_token - hidden_states.shape[0]))
(0, 0, 0, all_rank_max_num_tokens - hidden_states.shape[0]))
router_logits = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=use_dp_padding)
return final_hidden_states

View File

@ -126,6 +126,7 @@ class Qwen3MoE(nn.Module):
hidden_states = hidden_states.view(-1, self.hidden_dim)
use_dp_padding = False
all_rank_num_tokens = attn_metadata.all_rank_num_tokens
all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens
if not do_finalize:
assert not self.enable_attention_dp
@ -142,16 +143,16 @@ class Qwen3MoE(nn.Module):
not self.experts.has_fp8_qdq and self.experts.has_nvfp4):
# Use padding when not using the cutlass path or when x_sf in self.experts is not None
use_dp_padding = True
max_num_token = max(all_rank_num_tokens)
hidden_states = torch.nn.functional.pad(
hidden_states,
(0, 0, 0, max_num_token - hidden_states.shape[0]))
(0, 0, 0, all_rank_max_num_tokens - hidden_states.shape[0]))
router_logits = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=use_dp_padding,
do_finalize=do_finalize,
)

View File

@ -85,11 +85,13 @@ class QwenMoE(nn.Module):
hidden_states = hidden_states.view(-1, self.hidden_dim)
all_rank_num_tokens = attn_metadata.all_rank_num_tokens
all_rank_max_num_tokens = attn_metadata.all_rank_max_num_tokens
router_logits = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=False)
shared_expert_output = self.shared_expert(hidden_states)

View File

@ -326,6 +326,7 @@ class CutlassFusedMoE(MoE):
do_finalize: bool = True, # used by other MoE backends
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
use_dp_padding: Optional[bool] = None,
) -> torch.Tensor:
assert do_finalize, "CutlassFusedMoE does not support do_finalize=False"
@ -341,7 +342,7 @@ class CutlassFusedMoE(MoE):
1) // self.moe_max_num_tokens
if use_dp_padding:
all_rank_num_tokens_padded = [max(all_rank_num_tokens)
all_rank_num_tokens_padded = [all_rank_max_num_tokens
] * len(all_rank_num_tokens)
else:
all_rank_num_tokens_padded = all_rank_num_tokens

View File

@ -327,6 +327,7 @@ class WideEPMoE(MoE):
cutlass_min_latency_mode: bool = False,
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
use_dp_padding: Optional[bool] = None,
repeating_info: Tuple = (True, True),
) -> torch.Tensor:
@ -400,7 +401,7 @@ class WideEPMoE(MoE):
token_count = x.shape[0]
alltoall_info = None
x, token_selected_slots, token_final_scales, gathered_loadbalancer_local_statistic_info, alltoall_info = \
self.alltoall_prepare_maybe_dispatch(all_rank_num_tokens,
self.alltoall_prepare_maybe_dispatch(all_rank_max_num_tokens,
x,
token_selected_slots,
token_final_scales,
@ -663,6 +664,7 @@ class WideEPMoE(MoE):
do_finalize: bool = True,
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
use_dp_padding: Optional[bool] = None,
) -> torch.Tensor:
if self.use_dp:
@ -684,7 +686,7 @@ class WideEPMoE(MoE):
), "cutlass_min_latency_mode must be used with a single chunk and reduce_results must be False"
if use_dp_padding:
all_rank_num_tokens_padded = [max(all_rank_num_tokens)
all_rank_num_tokens_padded = [all_rank_max_num_tokens
] * len(all_rank_num_tokens)
else:
all_rank_num_tokens_padded = all_rank_num_tokens
@ -697,6 +699,7 @@ class WideEPMoE(MoE):
cutlass_min_latency_mode,
output_dtype,
all_rank_num_tokens=all_rank_num_tokens_padded,
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=use_dp_padding,
repeating_info=(is_first_call, is_last_call))
outputs = self.reducescatter_or_allreduce(
@ -720,13 +723,20 @@ class WideEPMoE(MoE):
all_rank_num_tokens_list = [[
val[idx_chunk] for val in all_rank_chunk_size_list
] for idx_chunk in range(num_chunks)]
all_rank_max_num_tokens_list = split_chunk(
all_rank_max_num_tokens, num_chunks)
chunk_size_list = all_rank_chunk_size_list[self.rank]
if self.enable_alltoall:
all_rank_num_tokens_list = [[
1 if val == 0 else val for val in val_list
] for val_list in all_rank_num_tokens_list]
all_rank_max_num_tokens_list = [
1 if val == 0 else val
for val in all_rank_max_num_tokens_list
]
else:
all_rank_num_tokens_list = [None] * num_chunks
all_rank_max_num_tokens_list = [None] * num_chunks
chunk_size_list = split_chunk(x.shape[0], num_chunks)
x_list = x.split(chunk_size_list)
@ -751,6 +761,9 @@ class WideEPMoE(MoE):
router_logits,
all_rank_num_tokens=all_rank_num_tokens_list[
idx_chunk] if self.use_dp else None,
all_rank_max_num_tokens=
all_rank_max_num_tokens_list[idx_chunk]
if self.use_dp else None,
use_dp_padding=use_dp_padding,
repeating_info=(is_first_call, is_last_call))
if idx_chunk > 0:
@ -765,6 +778,8 @@ class WideEPMoE(MoE):
router_logits,
all_rank_num_tokens=all_rank_num_tokens_list[
idx_chunk] if self.use_dp else None,
all_rank_max_num_tokens=all_rank_max_num_tokens_list[
idx_chunk] if self.use_dp else None,
use_dp_padding=use_dp_padding,
repeating_info=(is_first_call, is_last_call))
with torch.cuda.stream(self.aux_stream):
@ -805,30 +820,30 @@ class WideEPMoE(MoE):
return outputs
def alltoall_prepare_maybe_dispatch(
self, all_rank_num_tokens: list, x: torch.Tensor,
self, all_rank_max_num_tokens: int, x: torch.Tensor,
token_selected_slots: torch.Tensor,
token_final_scales: torch.Tensor,
local_statistic_tensor: Optional[torch.Tensor]):
top_k = self.routing_method.experts_per_token
max_num_token = max(all_rank_num_tokens)
# TODO: support alltoall without allgather for top_k % 4 != 0
if self.enable_alltoall_without_allgather and top_k % 4 == 0:
alltoall_info, token_selected_slots, token_final_scales, gathered_local_statistic_tensor = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
token_selected_slots, token_final_scales,
local_statistic_tensor, self.alltoall_prepare_workspace,
max_num_token, self.ep_rank, self.ep_size, self.num_experts,
self.num_slots, top_k)
all_rank_max_num_tokens, self.ep_rank, self.ep_size,
self.num_experts, self.num_slots, top_k)
else:
if max_num_token > token_selected_slots.shape[0]:
if all_rank_max_num_tokens > token_selected_slots.shape[0]:
token_selected_slots = torch.nn.functional.pad(
token_selected_slots,
(0, 0, 0, max_num_token - token_selected_slots.shape[0]),
(0, 0, 0,
all_rank_max_num_tokens - token_selected_slots.shape[0]),
'constant', self.num_slots)
if max_num_token > token_final_scales.shape[0]:
if all_rank_max_num_tokens > token_final_scales.shape[0]:
token_final_scales = torch.nn.functional.pad(
token_final_scales,
(0, 0, 0, max_num_token - token_final_scales.shape[0]))
(0, 0, 0,
all_rank_max_num_tokens - token_final_scales.shape[0]))
gathered_token_selected_slots, gathered_token_final_scales, gathered_local_statistic_tensor = allgather(
[
token_selected_slots, token_final_scales,
@ -848,8 +863,8 @@ class WideEPMoE(MoE):
gathered_token_selected_slots, self.num_slots, self.ep_size)
alltoall_info, token_selected_slots, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv_prepare(
gathered_target_rank_ids, None, gathered_token_selected_slots,
gathered_token_final_scales, max_num_token, self.num_slots,
top_k, self.ep_rank, self.ep_size)
gathered_token_final_scales, all_rank_max_num_tokens,
self.num_slots, top_k, self.ep_rank, self.ep_size)
if not self.use_postquant_alltoall:
assert not isinstance(

View File

@ -1097,7 +1097,7 @@ class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod):
orig_shape = dst_w3_w1_weight_scale.shape
dst_w3_w1_weight_scale.copy_(
torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
torch.ops.trtllm.nvfp4_block_scale_interleave(
dst_w3_w1_weight_scale.view(float4_sf_dtype)).view(
self.block_scales_dtype).reshape(orig_shape))
@ -1114,7 +1114,7 @@ class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod):
orig_shape = dst_w2_weight_scale.shape
dst_w2_weight_scale.copy_(
torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
torch.ops.trtllm.nvfp4_block_scale_interleave(
dst_w2_weight_scale.view(float4_sf_dtype)).view(
self.block_scales_dtype).reshape(orig_shape))
@ -1296,7 +1296,7 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
# Assert should only be removed during debugging
assert w3_w1_weight_scale.is_cuda, "w3_w1_weight_scale.is_cuda should be true or suffer from slow speed"
# Interleave the weight.
processed_w3_w1_weight_scale = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
processed_w3_w1_weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave(
w3_w1_weight_scale.view(float4_sf_dtype).reshape(orig_shape))
# Copy the result into device buffer
dst_w3_w1_weight_scale.copy_(
@ -1331,7 +1331,7 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
w_shuffled = torch.ops.trtllm.shuffle_matrix(
dst_w2_weight_scale.view(dtype=float4_sf_dtype), permute_indices)
# Interleave the weight.
processed_w2_weight_scale = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
processed_w2_weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave(
w_shuffled)
# Copy the result into device buffer
dst_w2_weight_scale.copy_(

View File

@ -589,7 +589,7 @@ class NVFP4LinearMethod(LinearMethodBase):
assert len(weights) == 1
weight_scale = weight_scale[0]
# Swizzle weight scale
weight_scale = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave(
weight_scale)
copy_weight(module.input_scale, input_scale)
@ -610,7 +610,7 @@ class NVFP4LinearMethod(LinearMethodBase):
tp_mode=module.tp_mode)
# Swizzle weight scales after concatenation
weight_scale = torch.cat(weight_scales, 0)
weight_scale = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave(
weight_scale)
copy_weight(module.input_scale, input_scale)
copy_weight(module.weight_scale, weight_scale)
@ -633,7 +633,7 @@ class NVFP4LinearMethod(LinearMethodBase):
tp_mode=module.tp_mode)
# Swizzle weight scales after concatenation
weight_scale = torch.cat(weight_scales, 0)
weight_scale = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave(
weight_scale)
copy_weight(module.input_scale, input_scale)
copy_weight(module.weight_scale, weight_scale)
@ -717,7 +717,7 @@ class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
assert len(weights) == 1
weight_scale = weight_scale[0]
# Swizzle weight scale
weight_scale = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave(
weight_scale)
copy_weight(module.weight_scale, weight_scale)
@ -733,7 +733,7 @@ class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
tp_rank=module.tp_rank,
tp_mode=module.tp_mode)
weight_scale = torch.cat(weight_scale, 0)
weight_scale = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave(
weight_scale)
copy_weight(module.weight_scale, weight_scale)
@ -750,7 +750,7 @@ class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
tp_mode=module.tp_mode)
# Swizzle weight scales after concatenation
weight_scale = torch.cat(weight_scale, 0)
weight_scale = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
weight_scale = torch.ops.trtllm.nvfp4_block_scale_interleave(
weight_scale)
copy_weight(module.weight_scale, weight_scale)

View File

@ -7,7 +7,14 @@ from ..attention_backend.interface import AttentionMetadata
from ..speculative.interface import SpecMetadata
from ..utils import make_weak_ref, set_piecewise_cuda_graph_flag
_local = threading.local()
class graph_capturing_local(threading.local):
def __init__(self):
self.is_graph_capturing = False
_local = graph_capturing_local()
def set_graph_capturing(enable: bool):
@ -15,8 +22,6 @@ def set_graph_capturing(enable: bool):
def is_graph_capturing() -> bool:
if not hasattr(_local, 'is_graph_capturing'):
return False
return _local.is_graph_capturing
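
Subclassing threading.local gives every thread a default is_graph_capturing value, so the hasattr() guard above can be dropped; with a bare threading.local(), threads that never called set_graph_capturing would not have the attribute at all. A standalone sketch of the pattern:

import threading

class GraphCapturingLocal(threading.local):
    def __init__(self):
        # __init__ runs once per thread, so the flag always exists with a default.
        self.is_graph_capturing = False

_local_demo = GraphCapturingLocal()
assert _local_demo.is_graph_capturing is False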

View File

@ -156,7 +156,13 @@ class SpecMetadata:
# The number of tokens for speculative model/layer
num_tokens: int = 0
# The number of tokens for speculative model/layer of different rank
all_rank_num_tokens: Optional[List[int]] = None
_all_rank_num_tokens: Optional[List[int]] = field(init=False,
default=None,
repr=False)
all_rank_num_tokens: Optional[List[int]]
# The max number of tokens among all ranks.
all_rank_max_num_tokens: Optional[int] = None
# The number of sequences for speculative model/layer of different rank
all_rank_num_seqs: Optional[List[int]] = None
# The number of extra kv tokens
@ -193,3 +199,13 @@ class SpecMetadata:
Some spec decode algorithms require hidden states from the target
model. Use this method to record them. By default, does nothing.
"""
@property
def all_rank_num_tokens(self) -> Optional[List[int]]:
return self._all_rank_num_tokens
@all_rank_num_tokens.setter
def all_rank_num_tokens(self, value: Optional[List[int]]):
value = value if value is not SpecMetadata.all_rank_num_tokens else None
self._all_rank_num_tokens = value
self.all_rank_max_num_tokens = max(value) if value is not None else None

View File

@ -1137,6 +1137,8 @@ class MTPEagleWorker(MTPWorker):
hidden_states = mtp_layers[0](
embed_tokens=embed_tokens,
all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
all_rank_max_num_tokens=spec_metadata.
all_rank_max_num_tokens,
**inputs)
start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] *
(self.mtp_num_modules + 1)).long()
@ -1146,10 +1148,15 @@ class MTPEagleWorker(MTPWorker):
gather_ids = torch.concat(
[last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
else:
hidden_states = mtp_layers[0](embed_tokens=embed_tokens,
all_rank_num_tokens=spec_metadata.
subseq_all_rank_num_tokens,
**inputs)
hidden_states = mtp_layers[0](
embed_tokens=embed_tokens,
all_rank_num_tokens=spec_metadata.
subseq_all_rank_num_tokens,
all_rank_max_num_tokens=max(
spec_metadata.subseq_all_rank_num_tokens)
if spec_metadata.subseq_all_rank_num_tokens is not None else
None,
**inputs)
# All of the seq_len are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
logits = mtp_layers[0].shared_head(hidden_states[gather_ids],

View File

@ -128,7 +128,7 @@ def swizzle_sf(sf: torch.Tensor,
"""
sf_cols = ceil_div(cols, scaling_vector_size)
sf = sf.view(-1, rows, sf_cols)
return torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(sf)
return torch.ops.trtllm.nvfp4_block_scale_interleave(sf)
def unswizzle_sf(sf: torch.Tensor,
@ -146,14 +146,15 @@ def unswizzle_sf(sf: torch.Tensor,
"""
sf_cols = ceil_div(cols, scaling_vector_size)
sf = sf.view(-1, rows, sf_cols)
return torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(sf).view(
return torch.ops.trtllm.nvfp4_block_scale_interleave_reverse(sf).view(
-1, sf_cols)
@torch.library.custom_op("trtllm::reswizzle_sf", mutates_args=())
def reswizzle_sf(sf: torch.Tensor,
rows: int,
cols: int,
scaling_vector_size: int = 16):
scaling_vector_size: int = 16) -> torch.Tensor:
"""Reswizzle FP4 scaling factors using C++ torch op implementation.
It unswizzles the scaling factors in each partition first, then concatenates them together, and finally swizzles them back.
Args:
@ -189,6 +190,16 @@ def reswizzle_sf(sf: torch.Tensor,
return swizzle_sf(sf_concatenated, total_rows, cols, scaling_vector_size)
@torch.library.register_fake("trtllm::reswizzle_sf")
def _(sf, rows, cols, scaling_vector_size=16):
sf_cols = ceil_div(cols, scaling_vector_size)
padded_rows, padded_sf_cols = compute_swizzled_sf_shape(rows, sf_cols)
num_partitions = sf.numel() // (padded_rows * padded_sf_cols)
total_rows = num_partitions * rows
sz = pad_up(total_rows, 128) * pad_up(cols, 4)
return sf.new_empty(sz)
def next_positive_power_of_2(x: int) -> int:
if x < 1:
return 1

View File

@ -692,7 +692,7 @@ class MOEWeightWrapper(Module):
weights = stack_weights(tllm_key, weights)
if tllm_key.endswith("weights_block_scaling_factor_interleaved"):
weights = stack_weights(tllm_key, weights)
weights = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
weights = torch.ops.trtllm.nvfp4_block_scale_interleave(
weights.to(torch.float8_e4m3fn).view(
torch.uint8).cpu().contiguous()).reshape(
weights.shape).view(torch.float8_e4m3fn)

View File

@ -1811,7 +1811,7 @@ def preprocess_perlayer_weights(weights,
weights[new_name] = weights[name]
weights[
new_name +
"_interleaved"] = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
"_interleaved"] = torch.ops.trtllm.nvfp4_block_scale_interleave(
weights[name].view(fp4_utils.float4_sf_dtype).cpu(
).contiguous()).reshape(nrows, ncols).view(
fp4_utils.float4_sf_dtype)

View File

@ -2218,7 +2218,7 @@ class FP4Linear(Linear):
qkv_block_scale,
tllm_key.replace(
'weight', "weights_block_scaling_factor_interleaved"):
torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
torch.ops.trtllm.nvfp4_block_scale_interleave(
qkv_block_scale.view(
torch.uint8).cpu().contiguous()).reshape(
qkv_block_scale.shape).view(
@ -2238,7 +2238,7 @@ class FP4Linear(Linear):
elif tllm_key.endswith("weights_block_scaling_factor"):
return weights
elif tllm_key.endswith("weights_block_scaling_factor_interleaved"):
return torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
return torch.ops.trtllm.nvfp4_block_scale_interleave(
weights.view(torch.uint8).cpu().contiguous()).reshape(
weights.shape).view(torch.float8_e4m3fn)
elif tllm_key.endswith("weights_global_scaling_factor"):
@ -2379,7 +2379,7 @@ class FP4RowLinear(RowLinear):
elif tllm_key.endswith("weights_block_scaling_factor"):
return weights
elif tllm_key.endswith("weights_block_scaling_factor_interleaved"):
return torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
return torch.ops.trtllm.nvfp4_block_scale_interleave(
weights.view(torch.uint8).cpu().contiguous()).reshape(
weights.shape).view(torch.float8_e4m3fn)
elif tllm_key.endswith("weights_global_scaling_factor"):

View File

@ -28,7 +28,7 @@ class FP4GemmType(IntEnum):
W4A8_MXFP4_MXFP8 = 1
def get_fp4_shape(input_shape, sf_vec_size):
def get_fp4_shape(input_shape, sf_vec_size, is_swizzled_layout=True):
m = 1
for i in range(len(input_shape) - 1):
m *= input_shape[i]
@ -36,7 +36,9 @@ def get_fp4_shape(input_shape, sf_vec_size):
output_shape = [i for i in input_shape]
output_shape[-1] //= 2
scale_shape = pad_up(m, 128) * pad_up(input_shape[-1] // sf_vec_size, 4)
scale_shape = pad_up(m, 128) * pad_up(
input_shape[-1] // sf_vec_size,
4) if is_swizzled_layout else m * (input_shape[-1] // sf_vec_size)
return output_shape, scale_shape
@ -205,4 +207,4 @@ def shuffle_matrix_sf_a(
input_tensor, row_indices.to(input_tensor.device))
# 128x4
return torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(w_shuffled)
return torch.ops.trtllm.nvfp4_block_scale_interleave(w_shuffled)

View File

@ -516,8 +516,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
overlap_scheduler, torch_compile):
if torch_compile and mtp_nextn > 0:
pytest.skip("https://nvbugs/5252313")
if torch_compile and attention_dp:
pytest.skip("https://nvbugs/5252559")
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
@ -558,14 +557,13 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
torch_compile):
if torch_compile and mtp_nextn > 0:
pytest.skip("https://nvbugs/5252313")
if torch_compile and attention_dp:
pytest.skip("https://nvbugs/5252559")
if torch_compile and pp_size > 1:
pytest.skip("PP with torch.compile is not supported yet.")
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
enable_piecewise_cuda_graph=cuda_graph
and not attention_dp) if torch_compile else None
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
@ -602,8 +600,6 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
overlap_scheduler, torch_compile):
if torch_compile and mtp != "disable":
pytest.skip("https://nvbugs/5252313")
if torch_compile and attention_dp:
pytest.skip("https://nvbugs/5252559")
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
@ -775,14 +771,13 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
overlap_scheduler, torch_compile):
if torch_compile and mtp_nextn > 0:
pytest.skip("https://nvbugs/5252313")
if torch_compile and attention_dp:
pytest.skip("https://nvbugs/5252559")
if torch_compile and pp_size > 1:
pytest.skip("PP with torch.compile is not supported yet.")
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
enable_piecewise_cuda_graph=cuda_graph
and not attention_dp) if torch_compile else None
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
@ -989,8 +984,6 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
torch_compile, mtp_nextn, moe_backend):
if torch_compile and mtp_nextn > 0:
pytest.skip("https://nvbugs/5252313")
if torch_compile and attention_dp:
pytest.skip("https://nvbugs/5252559")
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
@ -1048,16 +1041,16 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
torch_compile, mtp_nextn, moe_backend):
if torch_compile and mtp_nextn > 0:
pytest.skip("https://nvbugs/5252313")
if torch_compile and attention_dp:
pytest.skip("https://nvbugs/5252559")
if torch_compile and pp_size > 1:
pytest.skip("PP with torch.compile is not supported yet.")
if not attention_dp and (tp_size > 1 or ep_size > 1):
pytest.skip("https://nvbugs/5336321")
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
# Piecewise CUDA graph cannot be enabled for nvfp4 attention dp.
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
enable_piecewise_cuda_graph=cuda_graph
and not attention_dp) if torch_compile else None
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,

View File

@ -214,11 +214,13 @@ def test_fused_moe_alltoall(alltoall_method_type):
x,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=m,
use_dp_padding=False)
ref_output = ref_model.forward(
x,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
all_rank_max_num_tokens=m,
use_dp_padding=False)
# Evaluate outputs
@ -547,17 +549,17 @@ def test_fused_moe_nvfp4(dtype):
w1_weight_nvfp4, w1_sf_block = torch.ops.trtllm.fp4_quantize(
w1_weight, w3_w1_global, SCALING_VECTOR_SIZE, False)
w1_sf_block_unswizzled = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(
w1_sf_block_unswizzled = torch.ops.trtllm.nvfp4_block_scale_interleave_reverse(
w1_sf_block.cpu().view(INTERMEDIATE_SIZE, -1))
w2_weight_nvfp4, w2_sf_block = torch.ops.trtllm.fp4_quantize(
w2_weight, w2_sf_global, SCALING_VECTOR_SIZE, False)
w2_sf_block_unswizzled = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(
w2_sf_block_unswizzled = torch.ops.trtllm.nvfp4_block_scale_interleave_reverse(
w2_sf_block.cpu().view(HIDDEN_SIZE, -1))
w3_weight_nvfp4, w3_sf_block = torch.ops.trtllm.fp4_quantize(
w3_weight, w3_w1_global, SCALING_VECTOR_SIZE, False)
w3_sf_block_unswizzled = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(
w3_sf_block_unswizzled = torch.ops.trtllm.nvfp4_block_scale_interleave_reverse(
w3_sf_block.cpu().view(INTERMEDIATE_SIZE, -1))
w1_input_scale = x_sf_global.cuda()

View File

@ -977,7 +977,7 @@ def run_single_rank_ub_pass_fp4(
def block_scale_unswizzled(scale):
sz = fp4_utils.pad_up(hidden_size, 128)
return torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(
return torch.ops.trtllm.nvfp4_block_scale_interleave_reverse(
scale.cpu().view(sz, -1))[0:hidden_size]
l0_weight_scale_block_unswizzled = block_scale_unswizzled(

View File

@ -172,8 +172,8 @@ def test_fp4_sf_interleave(b, m, k):
w_cuda = w.cuda()
# The cpu and cuda kernels are different
w_out_cpu = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(w)
w_out_cuda = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(w_cuda)
w_out_cpu = torch.ops.trtllm.nvfp4_block_scale_interleave(w)
w_out_cuda = torch.ops.trtllm.nvfp4_block_scale_interleave(w_cuda)
torch.cuda.synchronize()
torch.testing.assert_allclose(w_out_cpu.cuda(), w_out_cuda)

View File

@ -39,7 +39,7 @@ def test_fp4_linear(dtype):
assert l_fp4.weight_scale.dtype == fp4_utils.float4_sf_dtype
w_sf_block_unswizzled = (
torch.ops.tensorrt_llm.nvfp4_block_scale_interleave_reverse(
torch.ops.trtllm.nvfp4_block_scale_interleave_reverse(
w_sf_block.cpu().view(HIDDEN_SIZE, -1)))
l_fp4.load_weights([{

View File

@ -70,10 +70,8 @@ class TestFunctional(unittest.TestCase):
mat_b[0][0] = 36
mat_b_ref[0][0] = 2.0
a_block_sf = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
a_block_sf)
b_block_sf = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
b_block_sf)
a_block_sf = torch.ops.trtllm.nvfp4_block_scale_interleave(a_block_sf)
b_block_sf = torch.ops.trtllm.nvfp4_block_scale_interleave(b_block_sf)
c = (torch.ops.trtllm.w4a8_mxfp4_fp8_gemm(mat_a, mat_b, a_block_sf,
b_block_sf, a_sf,

View File

@ -1186,7 +1186,7 @@ class TestMoE(unittest.TestCase):
moe_weight_wrapper.weights_block_scaling_factor_interleaved.value = (
np.ascontiguousarray(
torch_to_numpy(
torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
torch.ops.trtllm.nvfp4_block_scale_interleave(
scale_factor.view(torch.uint8).contiguous()).view(
scale_factor.dtype).reshape(
scale_factor.shape).view(torch.uint8))))