chore: Cleanup disable_fp4_allgather. (#6006)
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
parent 8ef8e73002
commit fc2347eaf5
@@ -511,8 +511,6 @@ class Deepseekv3MoE(nn.Module):
         # max-throughput
         use_dp_padding = False
         if self.use_dp and self.mapping.tp_size > 1:
-            # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
-            # to reduce allreduce BW
             if isinstance(self.experts, TRTLLMGenFusedMoE):
                 hidden_states = allgather(hidden_states,
                                           self.mapping,
@@ -20,7 +20,6 @@ from ..modules.fused_moe import (BaseMoeRoutingMethod, CutlassFusedMoE, MoE,
 from ..modules.linear import TensorParallelMode
 from ..modules.rms_norm import RMSNorm
 from ..speculative import SpecMetadata
-from ..utils import disable_fp4_allgather
 from .modeling_qwen3 import Qwen3Attention
 from .modeling_speculative import SpecDecOneEngineForCausalLM
 from .modeling_utils import (DecoderModel, EagerFusionConfig,
@@ -133,11 +132,7 @@ class Qwen3MoE(nn.Module):
         assert not self.enable_attention_dp

         if self.enable_attention_dp and self.mapping.tp_size > 1:
-            # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
-            # to reduce allreduce BW
-            if (disable_fp4_allgather()
-                    and not self.experts.enable_alltoall) or isinstance(
-                        self.experts, TRTLLMGenFusedMoE):
+            if isinstance(self.experts, TRTLLMGenFusedMoE):
                 hidden_states = allgather(hidden_states,
                                           self.mapping,
                                           dim=0,
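To make the effect of the Qwen3MoE hunk concrete, here is a hypothetical, simplified sketch of the new gating. The stub classes and the helper name needs_upfront_bf16_allgather are invented for illustration and are not part of the library: only the TRTLLMGen backend still gathers bf16 hidden states up front, while the Cutlass path relies on the allgather performed after top-k and FP4 quantization, and the disable_fp4_allgather()/enable_alltoall checks are gone.

# Hypothetical sketch only; stub classes stand in for the real fused-MoE backends.
class CutlassFusedMoE:
    pass

class TRTLLMGenFusedMoE:
    pass

def needs_upfront_bf16_allgather(experts, enable_attention_dp: bool, tp_size: int) -> bool:
    # After this commit the decision depends only on the backend type and the
    # attention-DP / TP configuration, not on TLLM_DISABLE_FP4_ALLGATHER.
    return enable_attention_dp and tp_size > 1 and isinstance(experts, TRTLLMGenFusedMoE)

# Example: the Cutlass backend no longer triggers the up-front bf16 allgather.
assert needs_upfront_bf16_allgather(TRTLLMGenFusedMoE(), True, 4) is True
assert needs_upfront_bf16_allgather(CutlassFusedMoE(), True, 4) is False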
@@ -4,8 +4,7 @@ import torch

 from ...distributed import allgather, reducescatter
 from ...model_config import ModelConfig
-from ...utils import (EventType, Fp4QuantizedTensor, ceil_div,
-                      disable_fp4_allgather, swizzle_sf)
+from ...utils import EventType, Fp4QuantizedTensor, ceil_div, swizzle_sf
 from .interface import MoE
 from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
                            FP8QDQFusedMoEMethod, MoEWeightLoadingMode,
@@ -220,8 +219,7 @@ class CutlassFusedMoE(MoE):
         # TODO: remove this once we have correct fusedmoe kernel ready
         token_final_scales = None

-        use_allgather = self.use_dp and self.parallel_size > 1 and not disable_fp4_allgather(
-        )
+        use_allgather = self.use_dp and self.parallel_size > 1

         # quantize inputs
         use_deepseek_fp8_block_scale = False
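The behavioral change in CutlassFusedMoE can be summarized with a small self-contained sketch. The function names are invented for illustration, and use_dp / parallel_size stand in for the module attributes of the same name: before the cleanup, setting TLLM_DISABLE_FP4_ALLGATHER=1 could switch the allgather path off; afterwards the decision depends only on data parallelism and the parallel size.

import os

def use_allgather_before(use_dp: bool, parallel_size: int) -> bool:
    # Old behavior: an environment variable could veto the FP4 allgather path.
    disable = os.getenv("TLLM_DISABLE_FP4_ALLGATHER", "0") == "1"
    return use_dp and parallel_size > 1 and not disable

def use_allgather_after(use_dp: bool, parallel_size: int) -> bool:
    # New behavior: the escape hatch is removed.
    return use_dp and parallel_size > 1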
@@ -1,5 +1,4 @@
 import contextlib
-import os
 import threading
 from dataclasses import dataclass
 from enum import Enum
@@ -100,13 +99,6 @@ class Fp4QuantizedTensor:
         return self.fp4_tensor.shape


-_disable_fp4_allgather = os.getenv("TLLM_DISABLE_FP4_ALLGATHER", "0") == "1"
-
-
-def disable_fp4_allgather():
-    return _disable_fp4_allgather
-
-
 def compute_swizzled_sf_shape(row: int, col: int):
     padded_row = pad_up(row, 128)
     padded_col = pad_up(col, 4)
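For reference, the deleted helper was a module-level flag read once at import time; the standalone sketch below is reconstructed from the removed lines above. After this commit the TLLM_DISABLE_FP4_ALLGATHER variable no longer has a reader in the code shown in this diff.

import os

# Evaluated once when the module is imported; changing the environment
# variable afterwards had no effect on a running process.
_disable_fp4_allgather = os.getenv("TLLM_DISABLE_FP4_ALLGATHER", "0") == "1"

def disable_fp4_allgather() -> bool:
    return _disable_fp4_allgather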