shuyixiong 2026-01-13 19:23:22 +08:00 committed by GitHub
commit 96dc51a806
13 changed files with 1168 additions and 658 deletions

View File

@ -1124,6 +1124,26 @@ class ConfigurableMoE(MoE):
)
return self.backend.post_load_weights()
def process_weights_after_loading(self):
"""
Process weights after loading - delegated to backend
"""
assert hasattr(self.backend, "process_weights_after_loading"), (
f"Backend {self.backend.__class__.__name__} must implement process_weights_after_loading()"
)
return self.backend.process_weights_after_loading()
def pre_reload_weights(self):
"""
Prepare for weight reloading - delegated to backend
"""
assert hasattr(self.backend, "pre_reload_weights"), (
f"Backend {self.backend.__class__.__name__} must implement pre_reload_weights()"
)
return self.backend.pre_reload_weights()
# ========== Communication and Quantization Properties ==========
@property

View File

@ -1,3 +1,4 @@
import inspect
import os
from functools import cached_property
from typing import Dict, List, Optional, Tuple, Union
@ -862,16 +863,22 @@ class CutlassFusedMoE(MoE):
assert len(weights) == 1
weights = weights[0]
if not isinstance(self.quant_method, UnquantizedFusedMoEMethod):
assert not allow_partial_loading, "Partial loading is not supported for quantized MoE now"
self.quant_method.load_weights(self, weights,
self.weight_loading_mode)
else:
self.quant_method.load_weights(
self,
weights,
self.weight_loading_mode,
allow_partial_loading=allow_partial_loading)
kargs = {}
if "allow_partial_loading" in inspect.getfullargspec(
self.quant_method.load_weights).args:
kargs["allow_partial_loading"] = allow_partial_loading
self.quant_method.load_weights(self, weights, self.weight_loading_mode,
**kargs)
def post_load_weights(self):
self.quant_method.post_load_weights(self)
def process_weights_after_loading(self):
if hasattr(self.quant_method, 'process_weights_after_loading'):
self.quant_method.process_weights_after_loading(self)
def pre_reload_weights(self):
assert hasattr(
self.quant_method, 'pre_reload_weights'
), "pre_reload_weights is not supported for this quant method"
self.quant_method.pre_reload_weights(self)
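The `inspect.getfullargspec` guard in `load_weights` above (and repeated in the TRTLLM-Gen and WideEP variants below) forwards `allow_partial_loading` only when the underlying quant method's `load_weights` signature declares it, so quant methods that predate partial loading keep working unchanged. A minimal, runnable sketch of that dispatch pattern; the class names here are illustrative and not part of the repository:

import inspect

class LegacyQuantMethod:
    # Older signature: no partial-loading support.
    def load_weights(self, module, weights, weight_loading_mode):
        return ("legacy", weight_loading_mode)

class PartialAwareQuantMethod:
    # Newer signature: accepts allow_partial_loading.
    def load_weights(self, module, weights, weight_loading_mode,
                     allow_partial_loading=False):
        return ("partial-aware", weight_loading_mode, allow_partial_loading)

def call_load_weights(quant_method, module, weights, mode, allow_partial_loading):
    kwargs = {}
    # Pass the flag only if the target signature knows about it.
    if "allow_partial_loading" in inspect.getfullargspec(
            quant_method.load_weights).args:
        kwargs["allow_partial_loading"] = allow_partial_loading
    return quant_method.load_weights(module, weights, mode, **kwargs)

print(call_load_weights(LegacyQuantMethod(), None, [{}], "VANILLA", True))
print(call_load_weights(PartialAwareQuantMethod(), None, [{}], "VANILLA", True))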

View File

@ -1401,3 +1401,13 @@ class TritonFusedMoE(MoE):
def post_load_weights(self):
self.quant_method.post_load_weights(self)
def process_weights_after_loading(self):
if hasattr(self.quant_method, 'process_weights_after_loading'):
self.quant_method.process_weights_after_loading(self)
def pre_reload_weights(self):
assert hasattr(
self.quant_method, 'pre_reload_weights'
), "pre_reload_weights is not supported for this quant method"
self.quant_method.pre_reload_weights(self)

View File

@ -1,3 +1,4 @@
import inspect
import os
from functools import cached_property
from typing import Dict, List, Optional, Union
@ -21,9 +22,9 @@ from .interface import AlltoallMethodType, MoE, MoEWeightLoadingMode
# isort: off
from .quantization import (
DeepSeekFP8BlockScalesFusedMoEMethod, NVFP4TRTLLMGenFusedMoEBaseMethod,
NVFP4TRTLLMGenFusedMoEMethod, UnquantizedFusedMoEMethod,
W4A8MXFP4FP8TRTLLMGenFusedMoEMethod, W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod,
W4A8NVFP4FP8TRTLLMGenFusedMoEMethod, W4A16MXFP4TRTLLMGenFusedMoEMethod)
NVFP4TRTLLMGenFusedMoEMethod, W4A8MXFP4FP8TRTLLMGenFusedMoEMethod,
W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod, W4A8NVFP4FP8TRTLLMGenFusedMoEMethod,
W4A16MXFP4TRTLLMGenFusedMoEMethod)
# isort: on
from .routing import BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod
@ -273,20 +274,26 @@ class TRTLLMGenFusedMoE(MoE):
assert len(weights) == 1
weights = weights[0]
if not isinstance(self.quant_method, UnquantizedFusedMoEMethod):
assert not allow_partial_loading, "Partial loading is not supported for quantized MoE now"
self.quant_method.load_weights(self, weights,
self.weight_loading_mode)
else:
self.quant_method.load_weights(
self,
weights,
self.weight_loading_mode,
allow_partial_loading=allow_partial_loading)
kargs = {}
if "allow_partial_loading" in inspect.getfullargspec(
self.quant_method.load_weights).args:
kargs["allow_partial_loading"] = allow_partial_loading
self.quant_method.load_weights(self, weights, self.weight_loading_mode,
**kargs)
def post_load_weights(self):
self.quant_method.post_load_weights(self)
def process_weights_after_loading(self):
if hasattr(self.quant_method, 'process_weights_after_loading'):
self.quant_method.process_weights_after_loading(self)
def pre_reload_weights(self):
assert hasattr(
self.quant_method, 'pre_reload_weights'
), "pre_reload_weights is not supported for this quant method"
self.quant_method.pre_reload_weights(self)
def quantize_input(self, x, post_quant_comm: bool = True):
"""Quantize inputs prior to post-communication (alltoall/allgather) or before MoE computation.

View File

@ -1,3 +1,4 @@
import inspect
import os
from typing import Dict, List, Optional, Tuple, Union
@ -935,20 +936,26 @@ class WideEPMoE(MoE):
assert len(weights) == 1
weights = weights[0]
if not isinstance(self.quant_method, UnquantizedFusedMoEMethod):
assert not allow_partial_loading, "Partial loading is not supported for quantized MoE now"
self.quant_method.load_weights(self, weights,
self.weight_loading_mode)
else:
self.quant_method.load_weights(
self,
weights,
self.weight_loading_mode,
allow_partial_loading=allow_partial_loading)
kargs = {}
if "allow_partial_loading" in inspect.getfullargspec(
self.quant_method.load_weights).args:
kargs["allow_partial_loading"] = allow_partial_loading
self.quant_method.load_weights(self, weights, self.weight_loading_mode,
**kargs)
def post_load_weights(self):
self.quant_method.post_load_weights(self)
def process_weights_after_loading(self):
if hasattr(self.quant_method, 'process_weights_after_loading'):
self.quant_method.process_weights_after_loading(self)
def pre_reload_weights(self):
assert hasattr(
self.quant_method, 'pre_reload_weights'
), "pre_reload_weights is not supported for this quant method"
self.quant_method.pre_reload_weights(self)
def forward_fake(
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],

View File

@ -521,12 +521,20 @@ class MoE(nn.Module):
raise NotImplementedError
@abstractmethod
def load_weights(self, weights: List[Dict]):
def load_weights(self,
weights: List[Dict],
allow_partial_loading: bool = False):
raise NotImplementedError
def process_weights_after_loading(self):
pass
def post_load_weights(self):
pass
def pre_reload_weights(self):
pass
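Together, `load_weights(allow_partial_loading=...)`, `process_weights_after_loading`, `post_load_weights`, and `pre_reload_weights` define the weight-update lifecycle this commit introduces. A simplified, hedged sketch of the intended call order; the real driver is `WorkerExtension.update_weights` near the end of this diff, and `model` and `weight_shards` below are illustrative placeholders:

def update_model_weights(model, weight_shards):
    # 1) Undo any parameter rebuilds (padding, layout transforms) from the
    #    previous load so original shapes/dtypes are back in place.
    for module in model.modules():
        if hasattr(module, "pre_reload_weights"):
            module.pre_reload_weights()

    # 2) Stream weight shards in; finalization is deferred while
    #    allow_partial_loading=True.
    for module, shard in weight_shards:
        module.load_weights([shard], allow_partial_loading=True)

    # 3) Finalize once per module: fold staged scales, requantize, and
    #    re-apply any layout transforms.
    for module in model.modules():
        if hasattr(module, "process_weights_after_loading"):
            module.process_weights_after_loading()
        if hasattr(module, "post_load_weights"):
            module.post_load_weights()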
@abstractmethod
def quantize_input(
self,

View File

@ -18,7 +18,8 @@ from tensorrt_llm.quantization.utils.fp4_utils import (
from tensorrt_llm.quantization.utils.fp8_utils import (
resmooth_to_fp8_e8m0, transform_sf_into_required_layout)
from ...utils import swizzle_sf, unswizzle_sf
from ...utils import (replace_parameter_and_save_metadata, swizzle_sf,
unswizzle_sf)
from ..linear import TensorParallelMode, load_weight_shard
from .interface import MoEWeightLoadingMode
@ -236,6 +237,8 @@ class FusedMoEMethodBase(ABC):
module.w3_w1_bias = None
module.w2_bias = None
module.rebuild_tensor_metadata = {}
def load_expert_weights_to_dst(
self,
module: torch.nn.Module,
@ -330,6 +333,12 @@ class FusedMoEMethodBase(ABC):
weights: List[Dict],
weight_loading_mode: MoEWeightLoadingMode,
allow_partial_loading: bool = False):
if allow_partial_loading:
assert isinstance(
self, (UnquantizedFusedMoEMethod, FP8QDQFusedMoEMethod,
DeepSeekFP8BlockScalesFusedMoEMethod,
DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm)
), "Partial loading is only supported for unquantized and FP8 models"
additional_kargs = {}
if "allow_partial_loading" in inspect.getfullargspec(
self.load_expert_weights_to_dst).args:
@ -402,6 +411,9 @@ class FusedMoEMethodBase(ABC):
local_shared_w2_bias_tensors if module.bias else None,
**additional_kargs)
if not allow_partial_loading:
self.process_weights_after_loading(module)
def post_load_weights(self, module: torch.nn.Module):
if self.need_load_shared_weights(module):
weight_fns = {
@ -432,6 +444,9 @@ class FusedMoEMethodBase(ABC):
def load_quant_scales(self, module: torch.nn.Module, weights: List[Dict]):
pass
def process_weights_after_loading(self, module: torch.nn.Module):
pass
@abstractmethod
def setup_quant_scales(self, module: torch.nn.Module):
raise NotImplementedError
@ -473,21 +488,19 @@ class FusedMoEMethodBase(ABC):
TensorParallelMode.COLUMN,
device=device) if w3_weight is not None else None
src_w3_size_shard = w3_weight_shard.shape[
0] if w3_weight_shard is not None else 0
src_w1_size_shard = w1_weight_shard.shape[
0] if w1_weight_shard is not None else 0
if w1_weight is not None:
dst_w1_weight = dst_w3_w1_weight.narrow(dim=0,
start=src_w3_size_shard,
length=src_w1_size_shard)
dst_w1_weight.copy_(w1_weight_shard.contiguous().view(
dst_w3_w1_weight.dtype),
non_blocking=True)
if w3_weight is not None:
dst_w3_weight = dst_w3_w1_weight.narrow(dim=0,
start=0,
length=src_w3_size_shard)
dst_w3_weight, dst_w1_weight = dst_w3_w1_weight.chunk(2, dim=0)
if w1_weight_shard is not None and w1_weight_shard.shape[0] != 0:
w1_weight_shard_viewed = w1_weight_shard.contiguous().view(
dst_w3_w1_weight.dtype)
if w1_weight_shard_viewed.shape[0] == dst_w3_w1_weight.shape[0]:
# w3_weight (gate_proj) should be empty for Nemotron-H MoE model.
dst_w3_w1_weight.copy_(w1_weight_shard_viewed,
non_blocking=True)
elif w1_weight_shard_viewed.shape[0] == dst_w1_weight.shape[0]:
dst_w1_weight.copy_(w1_weight_shard_viewed, non_blocking=True)
else:
raise ValueError("Shape mismatch!")
if w3_weight_shard is not None and w3_weight_shard.shape[0] != 0:
dst_w3_weight.copy_(w3_weight_shard.contiguous().view(
dst_w3_w1_weight.dtype),
non_blocking=True)
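A small sketch of the shape dispatch above on toy tensors: the fused buffer stacks w3 (gate) on top of w1 (up) along dim 0, and `chunk` returns views, so writing into either half writes into the fused buffer. When the incoming w1 shard spans the whole buffer, the gate projection is treated as empty (the Nemotron-H case noted in the comment).

import torch

dst_w3_w1 = torch.zeros(8, 4)                 # rows 0-3: w3 (gate), rows 4-7: w1 (up)
dst_w3, dst_w1 = dst_w3_w1.chunk(2, dim=0)    # views into dst_w3_w1

def copy_w1_shard(shard: torch.Tensor):
    if shard.shape[0] == dst_w3_w1.shape[0]:
        dst_w3_w1.copy_(shard)                # gate is empty: w1 fills the whole buffer
    elif shard.shape[0] == dst_w1.shape[0]:
        dst_w1.copy_(shard)                   # usual case: w1 fills its half
    else:
        raise ValueError("Shape mismatch!")

copy_w1_shard(torch.ones(4, 4))               # regular model
copy_w1_shard(torch.full((8, 4), 2.0))        # empty-gate model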
@ -516,6 +529,16 @@ class FusedMoEMethodBase(ABC):
dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype),
non_blocking=True)
def pre_reload_weights(self, module: torch.nn.Module):
for param_name, metadata in module.rebuild_tensor_metadata.items():
logger.warning(
f"Pre-reloading weight '{param_name}' requires tensor re-creation, which will invalidate existing CUDA graphs."
)
param = torch.nn.Parameter(torch.empty_like(metadata,
device="cuda"),
requires_grad=False)
module.register_parameter(param_name, param)
class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
@ -539,8 +562,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
def load_expert_fc31_input_scale_fp8_qdq(w1_input_scale, w3_input_scale,
dst_fc31_input_scale: torch.Tensor):
dst_fc31_input_scale.copy_(
max(w1_input_scale[...].reshape([]), w3_input_scale[...].reshape([])))
if w1_input_scale is not None and w1_input_scale.numel() != 0:
w1_input_scale = w1_input_scale[...].reshape([])
dst_fc31_input_scale[0].copy_(w1_input_scale)
if w3_input_scale is not None and w3_input_scale.numel() != 0:
w3_input_scale = w3_input_scale[...].reshape([])
dst_fc31_input_scale[1].copy_(w3_input_scale)
def load_expert_fc2_input_scale_fp8_qdq(w2_input_scale,
@ -549,35 +576,45 @@ def load_expert_fc2_input_scale_fp8_qdq(w2_input_scale,
def load_activation_scales_fp8_qdq(module: torch.nn.Module, weights: Dict):
tmp_fc31_input_scale = torch.empty(module.num_experts, dtype=torch.float32)
tmp_fc2_input_scale = torch.empty(module.num_experts, dtype=torch.float32)
if not hasattr(module, 'tmp_fc31_input_scale'):
module.tmp_fc31_input_scale = torch.empty(
(module.num_experts, 2),
dtype=torch.float32,
device=module.fc31_dequant.device)
tmp_fc31_input_scale = module.tmp_fc31_input_scale
if not hasattr(module, 'tmp_fc2_input_scale'):
module.tmp_fc2_input_scale = torch.empty(
module.num_experts,
dtype=torch.float32,
device=module.fc2_dequant.device)
tmp_fc2_input_scale = module.tmp_fc2_input_scale
for expert_id in range(module.num_experts):
if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
w1_input_scale = weights[f"{expert_id}.w1.input_scale"]
w3_input_scale = weights[f"{expert_id}.w3.input_scale"]
w2_input_scale = weights[f"{expert_id}.w2.input_scale"]
w1_input_scale = weights[
f"{expert_id}.w1.input_scale"] if f"{expert_id}.w1.input_scale" in weights else None
w3_input_scale = weights[
f"{expert_id}.w3.input_scale"] if f"{expert_id}.w3.input_scale" in weights else None
w2_input_scale = weights[
f"{expert_id}.w2.input_scale"] if f"{expert_id}.w2.input_scale" in weights else None
elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
w1_input_scale = weights[f"gate_up_proj_input_scale"]
w3_input_scale = weights[f"gate_up_proj_input_scale"]
w2_input_scale = weights[f"down_proj_input_scale"]
w1_input_scale = weights[
f"gate_up_proj_input_scale"] if f"gate_up_proj_input_scale" in weights else None
w3_input_scale = weights[
f"gate_up_proj_input_scale"] if f"gate_up_proj_input_scale" in weights else None
w2_input_scale = weights[
f"down_proj_input_scale"] if f"down_proj_input_scale" in weights else None
else:
raise NotImplementedError(
f"Unknown weight loading mode in MoE: {module.weight_loading_mode}"
)
load_expert_fc31_input_scale_fp8_qdq(w1_input_scale, w3_input_scale,
tmp_fc31_input_scale[expert_id])
if w1_input_scale is not None or w3_input_scale is not None:
load_expert_fc31_input_scale_fp8_qdq(
w1_input_scale, w3_input_scale, tmp_fc31_input_scale[expert_id])
load_expert_fc2_input_scale_fp8_qdq(w2_input_scale,
tmp_fc2_input_scale[expert_id])
# max_fc31_input_scale is the maximum of all w1 input scales and w3 input scales.
# It's used to quantize fc31 input inside the MOE op
max_fc31_input_scale = tmp_fc31_input_scale.max()
# max_fc2_input_scale is the maximum of all w2 input scales.
max_fc2_input_scale = tmp_fc2_input_scale.max()
return max_fc31_input_scale, max_fc2_input_scale
if w2_input_scale is not None:
load_expert_fc2_input_scale_fp8_qdq(w2_input_scale,
tmp_fc2_input_scale[expert_id])
def requantize_expert_w3_w1_weight_fp8_qdq(module: torch.nn.Module,
@ -654,9 +691,12 @@ class FP8QDQFusedMoEMethod(FusedMoEMethodBase):
def load_expert_w3_w1_weight_scale_fp8_qdq(
self, w1_weight_scale, w3_weight_scale,
dst_w3_w1_weight_scale: torch.Tensor):
w1_weight_scale = w1_weight_scale[...].reshape([])
w3_weight_scale = w3_weight_scale[...].reshape([])
dst_w3_w1_weight_scale.copy_(max(w1_weight_scale, w3_weight_scale))
if w1_weight_scale is not None and w1_weight_scale.numel() != 0:
w1_weight_scale = w1_weight_scale[...].reshape([])
dst_w3_w1_weight_scale[0].copy_(w1_weight_scale)
if w3_weight_scale is not None and w3_weight_scale.numel() != 0:
w3_weight_scale = w3_weight_scale[...].reshape([])
dst_w3_w1_weight_scale[1].copy_(w3_weight_scale)
def load_expert_w2_weight_scale_fp8(self, w2_weight_scale,
dst_w2_weight_scale: torch.Tensor):
@ -664,25 +704,38 @@ class FP8QDQFusedMoEMethod(FusedMoEMethodBase):
def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
# Step1: Load input scales.
max_fc31_input_scale, max_fc2_input_scale = load_activation_scales_fp8_qdq(
module, weights)
load_activation_scales_fp8_qdq(module, weights)
# Step2: Load weight scales and requantize w3_w1_weight.
tmp_w3_w1_weight_scale = torch.empty(module.expert_size_per_partition,
dtype=torch.float32)
tmp_w2_weight_scale = torch.empty(module.expert_size_per_partition,
dtype=torch.float32)
# Step2: Load weight scales
if not hasattr(module, 'tmp_w3_w1_weight_scale'):
module.tmp_w3_w1_weight_scale = torch.empty(
(module.expert_size_per_partition, 2),
dtype=torch.float32,
device=module.fc31_dequant.device)
if not hasattr(module, 'tmp_w2_weight_scale'):
module.tmp_w2_weight_scale = torch.empty(
module.expert_size_per_partition,
dtype=torch.float32,
device=module.fc2_dequant.device)
tmp_w3_w1_weight_scale = module.tmp_w3_w1_weight_scale
tmp_w2_weight_scale = module.tmp_w2_weight_scale
for local_slot_id, expert_id in enumerate(
module.initial_local_expert_ids):
if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"]
w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"]
w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"]
w1_weight_scale = weights[
f"{expert_id}.w1.weight_scale"] if f"{expert_id}.w1.weight_scale" in weights else None
w3_weight_scale = weights[
f"{expert_id}.w3.weight_scale"] if f"{expert_id}.w3.weight_scale" in weights else None
w2_weight_scale = weights[
f"{expert_id}.w2.weight_scale"] if f"{expert_id}.w2.weight_scale" in weights else None
elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
w1_weight_scale = weights[f"gate_up_proj_weight_scale"]
w3_weight_scale = weights[f"gate_up_proj_weight_scale"]
w2_weight_scale = weights[f"down_proj_weight_scale"]
w1_weight_scale = weights[
f"gate_up_proj_weight_scale"] if f"gate_up_proj_weight_scale" in weights else None
w3_weight_scale = weights[
f"gate_up_proj_weight_scale"] if f"gate_up_proj_weight_scale" in weights else None
w2_weight_scale = weights[
f"down_proj_weight_scale"] if f"down_proj_weight_scale" in weights else None
else:
raise NotImplementedError(
f"Unknown weight loading mode in MoE: {module.weight_loading_mode}"
@ -690,24 +743,45 @@ class FP8QDQFusedMoEMethod(FusedMoEMethodBase):
expert_idx = local_slot_id
self.load_expert_w3_w1_weight_scale_fp8_qdq(
w1_weight_scale, w3_weight_scale,
tmp_w3_w1_weight_scale[expert_idx])
if w1_weight_scale is not None or w3_weight_scale is not None:
self.load_expert_w3_w1_weight_scale_fp8_qdq(
w1_weight_scale, w3_weight_scale,
tmp_w3_w1_weight_scale[expert_idx])
if w2_weight_scale is not None:
self.load_expert_w2_weight_scale_fp8(
w2_weight_scale, tmp_w2_weight_scale[expert_idx])
def process_weights_after_loading(self, module: torch.nn.Module):
# max_fc31_input_scale is the maximum of all w1 input scales and w3 input scales.
# It's used to quantize fc31 input inside the MOE op
max_fc31_input_scale = module.tmp_fc31_input_scale.max()
# max_fc2_input_scale is the maximum of all w2 input scales.
max_fc2_input_scale = module.tmp_fc2_input_scale.max()
# Requantize w3_w1_weight
for local_slot_id, _ in enumerate(module.initial_local_expert_ids):
expert_idx = local_slot_id
requantize_expert_w3_w1_weight_fp8_qdq(
module, w1_weight_scale, w3_weight_scale,
module, module.tmp_w3_w1_weight_scale[expert_idx][0],
module.tmp_w3_w1_weight_scale[expert_idx][1],
module.w3_w1_weight.data[expert_idx])
self.load_expert_w2_weight_scale_fp8(
w2_weight_scale, tmp_w2_weight_scale[expert_idx])
# Step3: calculate and store final loaded weights
module.fc31_dequant.data.copy_(tmp_w3_w1_weight_scale *
# Calculate and store final loaded weights
max_w3_w1_weight_scale = module.tmp_w3_w1_weight_scale.max(dim=1).values
module.fc31_dequant.data.copy_(max_w3_w1_weight_scale *
max_fc31_input_scale)
module.fc2_quant.data.copy_(max_fc2_input_scale.reciprocal())
module.fc2_dequant.data.copy_(tmp_w2_weight_scale * max_fc2_input_scale)
module.fc2_dequant.data.copy_(module.tmp_w2_weight_scale *
max_fc2_input_scale)
module.fc31_input_dequant.data.copy_(max_fc31_input_scale)
self.setup_quant_scales(module)
delattr(module, 'tmp_w3_w1_weight_scale')
delattr(module, 'tmp_w2_weight_scale')
delattr(module, 'tmp_fc31_input_scale')
delattr(module, 'tmp_fc2_input_scale')
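The finalization above collapses the staged per-expert scales into the per-module quantization factors consumed by the fused MoE op. A toy numeric sketch of just that reduction on plain tensors (the per-expert requantize call is omitted, since its body is not shown in this hunk):

import torch

num_experts = 4
# Staged per-expert scales, as collected into the tmp_* buffers during loading:
# column 0 = w1, column 1 = w3 for the fused fc31 projection.
tmp_fc31_input_scale = torch.rand(num_experts, 2) + 0.5
tmp_w3_w1_weight_scale = torch.rand(num_experts, 2) + 0.5
tmp_fc2_input_scale = torch.rand(num_experts) + 0.5
tmp_w2_weight_scale = torch.rand(num_experts) + 0.5

# Finalization: collapse to the per-module factors used by the fused MoE op.
max_fc31_input_scale = tmp_fc31_input_scale.max()
max_fc2_input_scale = tmp_fc2_input_scale.max()
max_w3_w1_weight_scale = tmp_w3_w1_weight_scale.max(dim=1).values  # per expert

fc31_dequant = max_w3_w1_weight_scale * max_fc31_input_scale  # [num_experts]
fc2_quant = max_fc2_input_scale.reciprocal()                  # scalar
fc2_dequant = tmp_w2_weight_scale * max_fc2_input_scale       # [num_experts]
fc31_input_dequant = max_fc31_input_scale                     # scalar

print(fc31_dequant.shape, fc2_dequant.shape, float(fc2_quant))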
def post_load_weights(self, module):
super().post_load_weights(module)
@ -733,11 +807,15 @@ class FP8QDQFusedMoEMethod(FusedMoEMethodBase):
module.w2_weight, cutlass_fp8_row_alignment,
cutlass_fp8_row_alignment)
if is_padded_w3_w1_weight:
module.w3_w1_weight = nn.Parameter(padded_w3_w1_weight,
requires_grad=False)
replace_parameter_and_save_metadata(
module, "w3_w1_weight",
nn.Parameter(padded_w3_w1_weight, requires_grad=False),
module.rebuild_tensor_metadata)
if is_padded_w2_weight:
module.w2_weight = nn.Parameter(padded_w2_weight,
requires_grad=False)
replace_parameter_and_save_metadata(
module, "w2_weight",
nn.Parameter(padded_w2_weight, requires_grad=False),
module.rebuild_tensor_metadata)
class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
@ -778,9 +856,13 @@ class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
self.setup_quant_scales(module)
def load_weights(self, module: torch.nn.Module, weights: List[Dict],
weight_loading_mode: MoEWeightLoadingMode):
super().load_weights(module, weights, weight_loading_mode)
def load_weights(self,
module: torch.nn.Module,
weights: List[Dict],
weight_loading_mode: MoEWeightLoadingMode,
allow_partial_loading: bool = False):
super().load_weights(module, weights, weight_loading_mode,
allow_partial_loading)
def setup_quant_scales(self, module: torch.nn.Module):
module.quant_scales = FusedMoEQuantScalesDeepSeekFP8BlockScales(
@ -795,42 +877,53 @@ class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
for local_slot_id, expert_id in enumerate(load_expert_ids):
if module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
w3_scale = weights['gate_up_proj_weight_scale'][
expert_id].transpose(0, 1).contiguous()
w1_scale = None
expert_id].transpose(0, 1).contiguous(
) if "gate_up_proj_weight_scale" in weights else None
w2_scale = weights['down_proj_weight_scale'][
expert_id].transpose(0, 1).contiguous()
expert_id].transpose(0, 1).contiguous(
) if "down_proj_weight_scale" in weights else None
w3_w1_scale_shard = load_weight_shard(w3_scale,
module.tp_size,
module.tp_rank,
TensorParallelMode.COLUMN,
device=device)
dst_w3_w1_weight_scale[local_slot_id].copy_(w3_w1_scale_shard)
elif module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
w3_scale = weights[f"{expert_id}.w3.weight_scale_inv"]
w1_scale = weights[f"{expert_id}.w1.weight_scale_inv"]
w2_scale = weights[f"{expert_id}.w2.weight_scale_inv"]
w3_scale = weights[
f"{expert_id}.w3.weight_scale_inv"] if f"{expert_id}.w3.weight_scale_inv" in weights else None
w1_scale = weights[
f"{expert_id}.w1.weight_scale_inv"] if f"{expert_id}.w1.weight_scale_inv" in weights else None
w2_scale = weights[
f"{expert_id}.w2.weight_scale_inv"] if f"{expert_id}.w2.weight_scale_inv" in weights else None
dst_w3_weight_scale, dst_w1_weight_scale = dst_w3_w1_weight_scale[
local_slot_id].chunk(2, dim=0)
if w1_scale is not None:
w1_scale_shard = load_weight_shard(
w1_scale,
module.tp_size,
module.tp_rank,
TensorParallelMode.COLUMN,
device=device)
dst_w1_weight_scale.copy_(w1_scale_shard)
if w3_scale is not None:
w3_scale_shard = load_weight_shard(
w3_scale,
module.tp_size,
module.tp_rank,
TensorParallelMode.COLUMN,
device=device)
dst_w3_weight_scale.copy_(w3_scale_shard)
else:
raise NotImplementedError(
f"Unknown weight loading mode in MoE: {module.weight_loading_mode}"
)
w3_w1_scale_shard = load_weight_shard(w3_scale,
module.tp_size,
module.tp_rank,
TensorParallelMode.COLUMN,
device=device)
if w1_scale is not None:
w1_scale_shard = load_weight_shard(w1_scale,
if w2_scale is not None:
w2_scale_shard = load_weight_shard(w2_scale,
module.tp_size,
module.tp_rank,
TensorParallelMode.COLUMN,
TensorParallelMode.ROW,
device=device)
w3_w1_scale_shard = torch.cat(
[w3_w1_scale_shard, w1_scale_shard], dim=-2)
dst_w3_w1_weight_scale[local_slot_id].copy_(w3_w1_scale_shard)
w2_scale_shard = load_weight_shard(w2_scale,
module.tp_size,
module.tp_rank,
TensorParallelMode.ROW,
device=device)
dst_w2_weight_scale[local_slot_id].copy_(w2_scale_shard)
dst_w2_weight_scale[local_slot_id].copy_(w2_scale_shard)
def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
self.load_expert_all_weight_scale_fp8_block_scale(
@ -843,16 +936,30 @@ class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
if self.need_load_shared_weights(module):
local_shared_load_expert_ids = module.layer_load_balancer.get_load_expert_ids(
)
local_shared_w3_w1_scale_tensors = torch.empty(
(len(local_shared_load_expert_ids), ) +
module.w3_w1_weight_scaling_factor.data.shape[1:],
dtype=module.w3_w1_weight_scaling_factor.data.dtype,
device='cpu')
local_shared_w2_scale_tensors = torch.empty(
(len(local_shared_load_expert_ids), ) +
module.w2_weight_scaling_factor.data.shape[1:],
dtype=module.w2_weight_scaling_factor.data.dtype,
device='cpu')
if getattr(module, 'local_shared_w3_w1_scale_tensors',
None) is not None:
local_shared_w3_w1_scale_tensors = getattr(
module, 'local_shared_w3_w1_scale_tensors')
else:
local_shared_w3_w1_scale_tensors = torch.empty(
(len(local_shared_load_expert_ids), ) +
module.w3_w1_weight_scaling_factor.data.shape[1:],
dtype=module.w3_w1_weight_scaling_factor.data.dtype,
device='cpu')
setattr(module, 'local_shared_w3_w1_scale_tensors',
local_shared_w3_w1_scale_tensors)
if getattr(module, 'local_shared_w2_scale_tensors',
None) is not None:
local_shared_w2_scale_tensors = getattr(
module, 'local_shared_w2_scale_tensors')
else:
local_shared_w2_scale_tensors = torch.empty(
(len(local_shared_load_expert_ids), ) +
module.w2_weight_scaling_factor.data.shape[1:],
dtype=module.w2_weight_scaling_factor.data.dtype,
device='cpu')
setattr(module, 'local_shared_w2_scale_tensors',
local_shared_w2_scale_tensors)
self.load_expert_all_weight_scale_fp8_block_scale(
module,
weights,
@ -860,19 +967,32 @@ class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
local_shared_w3_w1_scale_tensors,
local_shared_w2_scale_tensors,
device=torch.device("cpu"))
module.register_all_parameter_slot_and_to_fix_weight_fns({
'w3_w1_weight_scaling_factor':
local_shared_w3_w1_scale_tensors,
'w2_weight_scaling_factor':
local_shared_w2_scale_tensors,
})
def post_load_weights(self, module: torch.nn.Module):
if self.need_load_shared_weights(module):
weight_fns = {}
if hasattr(module, 'local_shared_w3_w1_scale_tensors'):
weight_fns['w3_w1_weight_scaling_factor'] = getattr(
module, 'local_shared_w3_w1_scale_tensors')
delattr(module, 'local_shared_w3_w1_scale_tensors')
if hasattr(module, 'local_shared_w2_scale_tensors'):
weight_fns['w2_weight_scaling_factor'] = getattr(
module, 'local_shared_w2_scale_tensors')
delattr(module, 'local_shared_w2_scale_tensors')
if weight_fns:
module.register_all_parameter_slot_and_to_fix_weight_fns(
weight_fns)
super().post_load_weights(module)
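Because quant-scale loading can now arrive in several partial calls, the shared-expert scale staging buffers are cached on the module (created on first use via getattr/setattr) and only registered and released once in `post_load_weights`. A generic sketch of that stage-across-calls pattern, with hypothetical names:

import torch

def stage_partial_scales(module, num_local_experts, slot_ids, scale_chunk):
    # Create the CPU staging buffer on first use; reuse it on later partial loads.
    buf = getattr(module, 'staged_scale_tensors', None)
    if buf is None:
        buf = torch.zeros(num_local_experts, scale_chunk.shape[-1], device='cpu')
        setattr(module, 'staged_scale_tensors', buf)
    buf[slot_ids] = scale_chunk
    return buf

def finalize_staged_scales(module, register_fn):
    # Consume the staged buffer exactly once, then drop it.
    if hasattr(module, 'staged_scale_tensors'):
        register_fn({'weight_scaling_factor': module.staged_scale_tensors})
        delattr(module, 'staged_scale_tensors')

class _Module:  # stand-in for a torch.nn.Module holding the staged buffer
    pass

m = _Module()
stage_partial_scales(m, num_local_experts=4, slot_ids=[0, 1], scale_chunk=torch.ones(2, 8))
stage_partial_scales(m, num_local_experts=4, slot_ids=[2, 3], scale_chunk=torch.full((2, 8), 2.0))
finalize_staged_scales(m, register_fn=print)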
class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
DeepSeekFP8BlockScalesFusedMoEMethod):
def load_weights(self, module: torch.nn.Module, weights: List[Dict],
weight_loading_mode: MoEWeightLoadingMode):
def load_weights(self,
module: torch.nn.Module,
weights: List[Dict],
weight_loading_mode: MoEWeightLoadingMode,
allow_partial_loading: bool = False):
if is_sm_100f():
expert_ids = set(module.initial_local_expert_ids)
if self.need_load_shared_weights(module):
@ -888,7 +1008,8 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
scale = weights[name][:]
weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
weight, scale)
super().load_weights(module, weights, weight_loading_mode)
super().load_weights(module, weights, weight_loading_mode,
allow_partial_loading)
def post_load_weights(self, module: torch.nn.Module):
super().post_load_weights(module)
@ -900,8 +1021,12 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
recipe=(1, 128, 128),
num_groups=module.w3_w1_weight.shape[0],
is_sfa=False)
module.w3_w1_weight_scaling_factor = nn.Parameter(
transformed_w3_w1_weight_scaling_factor = nn.Parameter(
transfromed_w3_w1_scale, requires_grad=False)
replace_parameter_and_save_metadata(
module, "w3_w1_weight_scaling_factor",
transformed_w3_w1_weight_scaling_factor,
module.rebuild_tensor_metadata)
transfromed_w2_scale = transform_sf_into_required_layout(
module.quant_scales[1],
mn=module.w2_weight.shape[1],
@ -909,8 +1034,12 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
recipe=(1, 128, 128),
num_groups=module.w3_w1_weight.shape[0],
is_sfa=False)
module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale,
requires_grad=False)
transformed_w2_weight_scaling_factor = nn.Parameter(
transfromed_w2_scale, requires_grad=False)
replace_parameter_and_save_metadata(
module, "w2_weight_scaling_factor",
transformed_w2_weight_scaling_factor,
module.rebuild_tensor_metadata)
self.setup_quant_scales(module)

View File

@ -28,7 +28,8 @@ from tensorrt_llm.quantization.utils.fp8_utils import (
from ..._utils import get_sm_version, is_sm_100f
from ...models.modeling_utils import QuantConfig
from ..utils import Fp4QuantizedTensor, get_model_extra_attrs, unswizzle_sf
from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
replace_parameter_and_save_metadata, unswizzle_sf)
class WeightMode(str, enum.Enum):
@ -39,6 +40,40 @@ class WeightMode(str, enum.Enum):
# weight of a fused gate and up linear layer
FUSED_GATE_UP_LINEAR = 'fused_gate_up_linear'
@property
def int_value(self) -> int:
_INT_MAP = {
WeightMode.VANILLA: 1,
WeightMode.FUSED_GATE_UP_LINEAR: 2,
WeightMode.FUSED_QKV_LINEAR: 3,
}
return _INT_MAP[self]
@property
def shard_keys(self) -> list[str] | None:
_SHARD_KEYS_MAP = {
WeightMode.VANILLA: None,
WeightMode.FUSED_GATE_UP_LINEAR: ['gate', 'up'],
WeightMode.FUSED_QKV_LINEAR: ['q', 'k', 'v'],
}
return _SHARD_KEYS_MAP[self]
@property
def shard_key_to_index(self) -> dict[str, int] | None:
_SHARD_KEY_TO_INDEX_MAP = {
WeightMode.VANILLA: None,
WeightMode.FUSED_GATE_UP_LINEAR: {
'gate': 0,
'up': 1
},
WeightMode.FUSED_QKV_LINEAR: {
'q': 0,
'k': 1,
'v': 2
},
}
return _SHARD_KEY_TO_INDEX_MAP[self]
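The new `WeightMode` properties give each fused layout a shard count (`int_value`), an ordered list of shard keys, and a key-to-index mapping; the FP8 linear methods below use them to size their staging buffers and to place per-shard scales during partial loads. A short usage sketch with a standalone mirror of the enum (member string values for the non-fused modes are assumed, and the mirror is trimmed to what the sketch needs):

import enum
import torch

class WeightMode(str, enum.Enum):
    VANILLA = 'vanilla'
    FUSED_GATE_UP_LINEAR = 'fused_gate_up_linear'
    FUSED_QKV_LINEAR = 'fused_qkv_linear'

    @property
    def int_value(self):
        return {WeightMode.VANILLA: 1,
                WeightMode.FUSED_GATE_UP_LINEAR: 2,
                WeightMode.FUSED_QKV_LINEAR: 3}[self]

    @property
    def shard_keys(self):
        return {WeightMode.VANILLA: None,
                WeightMode.FUSED_GATE_UP_LINEAR: ['gate', 'up'],
                WeightMode.FUSED_QKV_LINEAR: ['q', 'k', 'v']}[self]

    @property
    def shard_key_to_index(self):
        return {WeightMode.VANILLA: None,
                WeightMode.FUSED_GATE_UP_LINEAR: {'gate': 0, 'up': 1},
                WeightMode.FUSED_QKV_LINEAR: {'q': 0, 'k': 1, 'v': 2}}[self]

mode = WeightMode.FUSED_QKV_LINEAR
tmp_weight_scales = torch.zeros(mode.int_value)   # one staged scale per shard
for key, scale in {'q': 0.9, 'v': 1.1}.items():   # partial load: k not present yet
    tmp_weight_scales[mode.shard_key_to_index[key]] = scale
print(tmp_weight_scales)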
@dataclass(kw_only=True)
class WeightsLoadingConfig:
@ -327,6 +362,9 @@ class LinearMethodBase(ABC):
else:
raise ValueError(f'unsupported weight mode: {weight_mode}')
if not allow_partial_loading:
self.process_weights_after_loading(module)
def post_load_weights(self, module: Linear):
pass
@ -367,6 +405,36 @@ class LinearMethodBase(ABC):
"""
raise NotImplementedError
def process_weights_after_loading(self, module: Linear):
"""
Process quantization weights and scales after loading weights.
"""
weight_mode = module.weights_loading_config.weight_mode
if weight_mode == WeightMode.VANILLA:
self.process_weights_after_loading_vanilla(module)
elif weight_mode == WeightMode.FUSED_QKV_LINEAR:
self.process_weights_after_loading_fused_qkv_linear(module)
elif weight_mode == WeightMode.FUSED_GATE_UP_LINEAR:
self.process_weights_after_loading_fused_gate_up_linear(module)
else:
raise ValueError(f'unsupported weight mode: {weight_mode}')
def process_weights_after_loading_vanilla(self, module: Linear):
"""
Process quantization weights and scales after loading weights for vanilla linear layer.
"""
def process_weights_after_loading_fused_qkv_linear(self, module: Linear):
"""
Process quantization weights and scales after loading weights for fused QKV linear layer.
"""
def process_weights_after_loading_fused_gate_up_linear(
self, module: Linear):
"""
Process quantization weights and scales after loading weights for fused gate up linear layer.
"""
class UnquantizedLinearMethod(LinearMethodBase):
@ -382,6 +450,8 @@ class UnquantizedLinearMethod(LinearMethodBase):
else:
module.register_parameter("bias", None)
module.rebuild_tensor_metadata = {}
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
if module.use_custom_cublas_mm:
@ -440,8 +510,17 @@ class UnquantizedLinearMethod(LinearMethodBase):
copy_weight_shard(module.weight, weight, shard_offset,
shard_size)
def pre_reload_weights(self, module: Linear):
for param_name, metadata in module.rebuild_tensor_metadata.items():
logger.warning(
f"Pre-reloading weight '{param_name}' requires tensor re-creation, which will invalidate existing CUDA graphs."
)
param = Parameter(torch.empty_like(metadata, device="cuda"),
requires_grad=False)
module.register_parameter(param_name, param)
class FP8QDQLinearMethod(LinearMethodBase):
class FP8QDQLinearMethod(UnquantizedLinearMethod):
def create_weights(self, module: Linear, in_features: int,
out_features: int, bias: bool, dtype: torch.dtype):
@ -468,6 +547,8 @@ class FP8QDQLinearMethod(LinearMethodBase):
else:
module.register_parameter("bias", None)
module.rebuild_tensor_metadata = {}
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
cur_input_scale = module.input_scale
@ -518,93 +599,224 @@ class FP8QDQLinearMethod(LinearMethodBase):
v_scale.append(w["v_scale"][...].reshape([]))
return k_scale, v_scale
def load_weight_scales(self, weights: List[Dict]):
input_scale, weight_scale = [], []
for w in weights:
if "input_scale" in w:
input_scale.append(w["input_scale"][...].reshape([]))
if "weight_scale" in w:
weight_scale.append(w["weight_scale"][...].reshape([]))
return input_scale, weight_scale
def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
load_weights_vanilla_helper(module, weights)
input_scale, weight_scale = self.load_weight_scales(weights)
if len(input_scale) != 0:
# Static quantization
copy_weight(module.input_scale, input_scale[0])
module.inv_input_scale.data = 1.0 / module.input_scale
def load_weight_scales(self,
weights: List[Dict],
shard_keys: list[str] = None):
input_scales, weight_scales = {}, {}
if shard_keys is None:
for w in weights:
if "input_scale" in w:
input_scales[None] = w["input_scale"][...].reshape([])
if "weight_scale" in w:
weight_scales[None] = w["weight_scale"][...].reshape([])
else:
# Dynamic quantization
for shard_key, w in zip(shard_keys, weights):
if "input_scale" in w:
input_scales[shard_key] = w["input_scale"][...].reshape([])
if "weight_scale" in w:
weight_scales[shard_key] = w["weight_scale"][...].reshape(
[])
return input_scales, weight_scales
def load_weights_vanilla(self,
module: Linear,
weights: List[Dict],
allow_partial_loading: bool = False) -> None:
super().load_weights_vanilla(
module, weights, allow_partial_loading=allow_partial_loading)
input_scale, weight_scale = self.load_weight_scales(weights)
if input_scale:
copy_weight(module.input_scale, input_scale[None])
module.inv_input_scale.data = 1.0 / module.input_scale
setattr(module, "has_static_input_scale", True)
if weight_scale:
copy_weight(module.weight_scale, weight_scale[None])
def process_weights_after_loading_vanilla(self, module: Linear):
if not getattr(module, "has_static_input_scale", False):
module.input_scale = None
module.inv_input_scale = None
copy_weight(module.weight_scale, weight_scale[0])
delattr(module, "has_static_input_scale")
def load_weights_fused_qkv_linear(self, module: Linear,
weights: List[Dict]) -> None:
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
module, weights)
def load_weights_fused_qkv_linear(
self,
module: Linear,
weights: List[Dict],
allow_partial_loading: bool = False) -> None:
"""
Load weights for fused QKV linear layer.
input_scale, weight_scale = self.load_weight_scales(weights)
if len(input_scale) != 0:
# Static quantization
copy_weight(module.input_scale, max(input_scale))
else:
# Dynamic quantization
module.input_scale = None
In partial loading mode, only loads weights and scales to their designated positions.
The actual rescaling is deferred to process_weights_after_loading_fused_qkv_linear.
"""
# Parent class handles weight loading
super().load_weights_fused_qkv_linear(
module, weights, allow_partial_loading=allow_partial_loading)
weight_mode = module.weights_loading_config.weight_mode
if not hasattr(module, "tmp_input_scales"):
module.tmp_input_scales = torch.empty(
weight_mode.int_value,
dtype=torch.float32,
device=module.input_scale.device)
if not hasattr(module, "tmp_weight_scales"):
module.tmp_weight_scales = torch.empty(
weight_mode.int_value,
dtype=torch.float32,
device=module.weight_scale.device)
# Load input_scale and weight_scale to tmp_qkv_input_scales and tmp_qkv_weight_scales
# q -> index 0, k -> index 1, v -> index 2
input_scales, weight_scales = self.load_weight_scales(
weights, shard_keys=weight_mode.shard_keys)
shard_key_to_index = weight_mode.shard_key_to_index
copy_weight(module.weight_scale, max(weight_scale))
for shard_key, scale in input_scales.items():
idx = shard_key_to_index[shard_key]
module.tmp_input_scales[idx] = scale
setattr(module, "has_static_input_scale", True)
# use in-place multiplication and division to avoid extra memory allocation
q_weight = q_weight.to(module.dtype).mul_(weight_scale[0])
k_weight = k_weight.to(module.dtype).mul_(weight_scale[1])
v_weight = v_weight.to(module.dtype).mul_(weight_scale[2])
fused_weight = torch.cat((q_weight, k_weight, v_weight))
fused_weight = fused_weight.div_(
module.weight_scale.to(fused_weight.device)).to(torch.float8_e4m3fn)
copy_weight(module.weight, fused_weight)
for shard_key, scale in weight_scales.items():
idx = shard_key_to_index[shard_key]
module.tmp_weight_scales[idx] = scale
# Load k and v scales, used for NVFP4 KV cache
# Store them temporarily for post-processing
k_scale, v_scale = self.load_kv_scales(weights)
# NOTE: Currently the calibrated kv scales may cause overflow for certain inputs, so they are disabled by default.
if k_scale:
if getattr(module, "tmp_k_scales", None) is None:
module.tmp_k_scales = []
module.tmp_k_scales.extend(k_scale)
if v_scale:
if getattr(module, "tmp_v_scales", None) is None:
module.tmp_v_scales = []
module.tmp_v_scales.extend(v_scale)
def rescale_fused_weights(self, module: Linear):
"""
Helper function to rescale fused weights.
This method:
1. Computes the max input_scale across all shards (qkv or gate/up) and updates the input_scale parameter to that max
2. Computes the max weight_scale across all shards (qkv or gate/up)
3. Rescales each weight shard: weight * (original_scale / max_scale)
4. Updates the weight_scale parameter to the unified max value
"""
weight_mode = module.weights_loading_config.weight_mode
shard_key_to_index = weight_mode.shard_key_to_index
# Handle input_scale
if getattr(module, "has_static_input_scale", False):
# Compute max and replace input_scale with a new parameter
max_input_scale = module.tmp_input_scales.max()
module.input_scale.data.copy_(max_input_scale)
module.inv_input_scale.data = 1.0 / module.input_scale
delattr(module, "has_static_input_scale")
else:
module.input_scale = None
module.inv_input_scale = None
# Compute max weight_scale
max_weight_scale = module.tmp_weight_scales.max()
module.weight_scale.data.copy_(max_weight_scale)
# Rescale each weight shard: weight * (original_scale / max_scale)
for shard_key in weight_mode.shard_keys:
idx = shard_key_to_index[shard_key]
original_scale = module.tmp_weight_scales[idx]
# Get shard position from mapping
shard_offset, shard_size = module.fused_weight_shard_indices_mapping[
shard_key]
# Rescale: FP8 -> BF16 -> multiply by ratio -> FP8
weight_shard = module.weight.data[shard_offset:shard_offset +
shard_size]
rescaled_weight = (weight_shard.to(module.dtype).mul_(
original_scale / max_weight_scale).to(torch.float8_e4m3fn))
module.weight.data[shard_offset:shard_offset +
shard_size] = rescaled_weight
delattr(module, "tmp_input_scales")
delattr(module, "tmp_weight_scales")
def process_weights_after_loading_fused_qkv_linear(self, module: Linear):
"""
Post-process weights after all partial loads are complete.
"""
self.rescale_fused_weights(module)
# Handle kv_scales for NVFP4 KV cache
if os.environ.get("TRTLLM_LOAD_KV_SCALES", "0") == "1":
if len(k_scale) != 0:
assert len(v_scale) != 0
k_scales = getattr(module, "tmp_k_scales", [])
v_scales = getattr(module, "tmp_v_scales", [])
if k_scales:
assert v_scales, "k_scale and v_scale must be loaded together"
# The calibrated KV scales are amax / (6 * 448), but the requested KV scales are amax / 448,
# to avoid overflow when dequantizing NVFP4 in attention kernels.
copy_weight(
module.kv_scales,
torch.tensor(
[1.0, max(k_scale) * 6.0,
max(v_scale) * 6.0],
dtype=torch.float32))
torch.tensor([
1.0,
max(k_scales).item() * 6.0,
max(v_scales).item() * 6.0
],
dtype=torch.float32))
module.inv_kv_scales.data = 1.0 / module.kv_scales
def load_weights_fused_gate_up_linear(self, module: Linear,
weights: List[Dict]) -> None:
input_scale, weight_scale = self.load_weight_scales(weights)
if len(input_scale) != 0:
# Static quantization
copy_weight(module.input_scale, max(input_scale))
else:
# Dynamic quantization
module.input_scale = None
copy_weight(module.weight_scale, max(weight_scale))
# Clean up temporary attributes
if hasattr(module, "tmp_k_scales"):
delattr(module, "tmp_k_scales")
if hasattr(module, "tmp_v_scales"):
delattr(module, "tmp_v_scales")
gate_weight, up_weight = load_weights_fused_gate_up_helper(
module, weights)
def load_weights_fused_gate_up_linear(
self,
module: Linear,
weights: List[Dict],
allow_partial_loading: bool = False) -> None:
"""
Load weights for fused gate/up linear layer.
# use in-place multiplication and division to avoid extra memory allocation
gate_weight = gate_weight.to(module.dtype).mul_(weight_scale[0])
up_weight = up_weight.to(module.dtype).mul_(weight_scale[1])
fused_weight = torch.cat((gate_weight, up_weight))
fused_weight = fused_weight.div_(
module.weight_scale.to(fused_weight.device)).to(torch.float8_e4m3fn)
copy_weight(module.weight, fused_weight)
In partial loading mode, only loads weights and scales to their designated positions.
The actual rescaling is deferred to process_weights_after_loading_fused_gate_up_linear.
"""
# Parent class handles weight loading
super().load_weights_fused_gate_up_linear(
module, weights, allow_partial_loading=allow_partial_loading)
weight_mode = module.weights_loading_config.weight_mode
if not hasattr(module, "tmp_input_scales"):
module.tmp_input_scales = torch.empty(
weight_mode.int_value,
dtype=torch.float32,
device=module.input_scale.device)
if not hasattr(module, "tmp_weight_scales"):
module.tmp_weight_scales = torch.empty(
weight_mode.int_value,
dtype=torch.float32,
device=module.weight_scale.device)
# Load input_scale and weight_scale to their designated positions
# gate -> index 0, up -> index 1
input_scales, weight_scales = self.load_weight_scales(
weights, shard_keys=weight_mode.shard_keys)
shard_key_to_index = weight_mode.shard_key_to_index
for shard_key, scale in input_scales.items():
idx = shard_key_to_index[shard_key]
module.tmp_input_scales[idx] = scale
setattr(module, "has_static_input_scale", True)
for shard_key, scale in weight_scales.items():
idx = shard_key_to_index[shard_key]
module.tmp_weight_scales[idx] = scale
def process_weights_after_loading_fused_gate_up_linear(
self, module: Linear):
"""
Post-process weights after all partial loads are complete.
"""
self.rescale_fused_weights(module)
class FP8RowwiseLinearMethod(LinearMethodBase):
class FP8RowwiseLinearMethod(UnquantizedLinearMethod):
def create_weights(self, module: Linear, in_features: int,
out_features: int, bias: bool, dtype: torch.dtype):
@ -628,6 +840,8 @@ class FP8RowwiseLinearMethod(LinearMethodBase):
else:
module.register_parameter("bias", None)
module.rebuild_tensor_metadata = {}
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
# FP8 tensor inputs are from attention. Directly use ones as scale.
@ -661,51 +875,72 @@ class FP8RowwiseLinearMethod(LinearMethodBase):
scale_name = "weight_scale"
return scale_name
def load_weights_vanilla(self, module: Linear, weights: List[Dict]):
load_weights_vanilla_helper(module, weights)
def load_weights_vanilla(self,
module: Linear,
weights: List[Dict],
allow_partial_loading: bool = False):
super().load_weights_vanilla(
module, weights, allow_partial_loading=allow_partial_loading)
scale_name = self._get_scale_name(weights)
weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
module.tp_rank, module.tp_mode)
copy_weight(module.weight_scale, weight_scale)
if scale_name in weights[0]:
weight_scale = load_weight_shard(weights[0][scale_name],
module.tp_size, module.tp_rank,
module.tp_mode)
copy_weight(module.weight_scale, weight_scale)
if "input_scale" in weights[0]:
copy_weight(module.input_scale, weights[0]["input_scale"])
module.inv_input_scale.data = 1.0 / module.input_scale
def load_weights_fused_qkv_linear(self, module: Linear,
weights: List[Dict]):
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
module, weights)
fused_weight = torch.cat((q_weight, k_weight, v_weight))
copy_weight(module.weight, fused_weight)
def load_weights_fused_qkv_linear(self,
module: Linear,
weights: List[Dict],
allow_partial_loading: bool = False):
super().load_weights_fused_qkv_linear(
module, weights, allow_partial_loading=allow_partial_loading)
scale_name = self._get_scale_name(weights)
q_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
module.tp_rank, module.tp_mode)
k_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
module.tp_rank, module.tp_mode)
v_scale = load_weight_shard(weights[2][scale_name], module.tp_size,
module.tp_rank, module.tp_mode)
fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))
copy_weight(module.weight_scale, fused_fp8_block_scale)
def load_weights_fused_gate_up_linear(self, module: Linear,
weights: List[Dict]):
gate_weight, up_weight = load_weights_fused_gate_up_helper(
module, weights)
fused_weight = torch.cat((gate_weight, up_weight))
copy_weight(module.weight, fused_weight)
q_scale = load_weight_shard(
weights[0][scale_name], module.tp_size, module.tp_rank,
module.tp_mode) if scale_name in weights[0] else None
k_scale = load_weight_shard(
weights[1][scale_name], module.tp_size, module.tp_rank,
module.tp_mode) if scale_name in weights[1] else None
v_scale = load_weight_shard(
weights[2][scale_name], module.tp_size, module.tp_rank,
module.tp_mode) if scale_name in weights[2] else None
for shard_key, scale in zip(
module.fused_weight_shard_indices_mapping.keys(),
[q_scale, k_scale, v_scale]):
if scale is not None:
shard_offset, shard_size = module.fused_weight_shard_indices_mapping[
shard_key]
copy_weight_shard(module.weight_scale, scale, shard_offset,
shard_size)
def load_weights_fused_gate_up_linear(
self,
module: Linear,
weights: List[Dict],
allow_partial_loading: bool = False) -> None:
super().load_weights_fused_gate_up_linear(
module, weights, allow_partial_loading=allow_partial_loading)
scale_name = self._get_scale_name(weights)
left_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
module.tp_rank, module.tp_mode)
right_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
module.tp_rank, module.tp_mode)
fused_scale = torch.cat((left_scale, right_scale))
copy_weight(module.weight_scale, fused_scale)
gate_scale = load_weight_shard(
weights[0][scale_name], module.tp_size, module.tp_rank,
module.tp_mode) if scale_name in weights[0] else None
up_scale = load_weight_shard(
weights[1][scale_name], module.tp_size, module.tp_rank,
module.tp_mode) if scale_name in weights[1] else None
for shard_key, scale in zip(
module.fused_weight_shard_indices_mapping.keys(),
[gate_scale, up_scale]):
if scale is not None:
shard_offset, shard_size = module.fused_weight_shard_indices_mapping[
shard_key]
copy_weight_shard(module.weight_scale, scale, shard_offset,
shard_size)
class FP8BlockScalesLinearMethod(LinearMethodBase):
class FP8BlockScalesLinearMethod(UnquantizedLinearMethod):
def create_weights(self, module: Linear, in_features: int,
out_features: int, bias: bool, dtype: torch.dtype):
@ -732,6 +967,8 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
else:
module.register_parameter("bias", None)
module.rebuild_tensor_metadata = {}
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
if input.dtype == torch.float8_e4m3fn:
@ -770,75 +1007,107 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
def _get_scale_name(self, weights: List[Dict]):
# `weight_scale_inv` for DS recipe and `weight_scale` for ModelOpt recipe.
# Actually they hold identical values of data_amax / 448.
scale_name = "weight_scale_inv"
if scale_name not in weights[0]:
scale_name = "weight_scale"
return scale_name
for w in weights:
if "weight_scale_inv" in w:
return "weight_scale_inv"
return "weight_scale"
def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
load_weights_vanilla_helper(module, weights)
def load_weights_vanilla(self,
module: Linear,
weights: List[Dict],
allow_partial_loading: bool = False) -> None:
super().load_weights_vanilla(
module, weights, allow_partial_loading=allow_partial_loading)
scale_name = self._get_scale_name(weights)
full_weight_scale = weights[0][scale_name]
# modelopt fp8_pb_wo can have 2 extra singleton dimensions
if full_weight_scale.dim() == 4:
full_weight_scale = full_weight_scale.squeeze(1).squeeze(-1)
weight_scale = load_weight_shard(full_weight_scale, module.tp_size,
module.tp_rank, module.tp_mode)
copy_weight(module.weight_scale, weight_scale)
if scale_name in weights[0]:
full_weight_scale = weights[0][scale_name]
# modelopt fp8_pb_wo can have 2 extra singleton dimensions
if full_weight_scale.dim() == 4:
full_weight_scale = full_weight_scale.squeeze(1).squeeze(-1)
weight_scale = load_weight_shard(full_weight_scale, module.tp_size,
module.tp_rank, module.tp_mode)
copy_weight(module.weight_scale, weight_scale)
if "input_scale" in weights[0]:
copy_weight(module.input_scale, weights[0]["input_scale"])
module.inv_input_scale.data = 1.0 / module.input_scale
def load_weights_fused_qkv_linear(self, module: Linear,
weights: List[Dict]) -> None:
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
module, weights)
fused_weight = torch.cat((q_weight, k_weight, v_weight))
def remap_fused_shard_indices_by_divisible_factor(self, mapping: Dict,
divisible_factor: int):
"""
Remap fused weight shard indices to scale coordinates by dividing by divisible_factor.
Args:
mapping: Dict of {shard_key: (offset, size)} in weight coordinates
divisible_factor: Block size (e.g., 128 for block-scale quantization)
Returns:
Dict of {shard_key: (scale_offset, scale_size)} in scale coordinates
"""
result = {}
for key, (offset, size) in mapping.items():
scale_offset = math.ceil(offset / divisible_factor)
scale_size = math.ceil(size / divisible_factor)
result[key] = (scale_offset, scale_size)
return result
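For block-scale quantization, one scale row covers a 128-row block of weight rows, so the fused shard offsets/sizes kept in weight coordinates must be divided (with ceiling) before indexing into weight_scale, which is what the call with divisible_factor=128 below does. A small usage sketch with hypothetical QKV sizes:

import math

def remap_by_block(mapping, block=128):
    # Same arithmetic as remap_fused_shard_indices_by_divisible_factor above.
    return {k: (math.ceil(off / block), math.ceil(size / block))
            for k, (off, size) in mapping.items()}

# Hypothetical fused QKV layout in weight rows: q=[0, 4096), k=[4096, 5120), v=[5120, 6144)
weight_mapping = {'q': (0, 4096), 'k': (4096, 1024), 'v': (5120, 1024)}
print(remap_by_block(weight_mapping))
# -> {'q': (0, 32), 'k': (32, 8), 'v': (40, 8)}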
def load_weights_fused_qkv_linear(
self,
module: Linear,
weights: List[Dict],
allow_partial_loading: bool = False) -> None:
super().load_weights_fused_qkv_linear(
module, weights, allow_partial_loading=allow_partial_loading)
scale_name = self._get_scale_name(weights)
full_q_scale = weights[0][scale_name]
full_k_scale = weights[1][scale_name]
full_v_scale = weights[2][scale_name]
# modelopt fp8_pb_wo can have 2 extra singleton dimensions
if full_q_scale.dim() == 4:
full_q_scale = full_q_scale.squeeze(1).squeeze(-1)
if full_k_scale.dim() == 4:
full_k_scale = full_k_scale.squeeze(1).squeeze(-1)
if full_v_scale.dim() == 4:
full_v_scale = full_v_scale.squeeze(1).squeeze(-1)
q_scale = load_weight_shard(full_q_scale, module.tp_size,
module.tp_rank, module.tp_mode)
k_scale = load_weight_shard(full_k_scale, module.tp_size,
module.tp_rank, module.tp_mode)
v_scale = load_weight_shard(full_v_scale, module.tp_size,
module.tp_rank, module.tp_mode)
fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))
full_scales = [
w[scale_name] if scale_name in w else None for w in weights[:3]
]
full_scales_squeezed = [
s.squeeze(1).squeeze(-1) if s is not None and s.dim() == 4 else s
for s in full_scales
]
copy_weight(module.weight, fused_weight)
copy_weight(module.weight_scale, fused_fp8_block_scale)
scales = [
load_weight_shard(s, module.tp_size, module.tp_rank, module.tp_mode)
if s is not None else None for s in full_scales_squeezed
]
processed_mapping = self.remap_fused_shard_indices_by_divisible_factor(
module.fused_weight_shard_indices_mapping, 128)
for shard_key, scale in zip(processed_mapping.keys(), scales):
if scale is not None:
shard_offset, shard_size = processed_mapping[shard_key]
copy_weight_shard(module.weight_scale, scale, shard_offset,
shard_size)
def load_weights_fused_gate_up_linear(self, module: Linear,
weights: List[Dict]) -> None:
gate_weight, up_weight = load_weights_fused_gate_up_helper(
module, weights)
fused_weight = torch.cat((gate_weight, up_weight))
def load_weights_fused_gate_up_linear(
self,
module: Linear,
weights: List[Dict],
allow_partial_loading: bool = False) -> None:
super().load_weights_fused_gate_up_linear(
module, weights, allow_partial_loading=allow_partial_loading)
scale_name = self._get_scale_name(weights)
full_left_scale = weights[0][scale_name]
full_right_scale = weights[1][scale_name]
# modelopt fp8_pb_wo can have 2 extra singleton dimensions
if full_left_scale.dim() == 4:
full_left_scale = full_left_scale.squeeze(1).squeeze(-1)
if full_right_scale.dim() == 4:
full_right_scale = full_right_scale.squeeze(1).squeeze(-1)
left_scale = load_weight_shard(full_left_scale, module.tp_size,
module.tp_rank, module.tp_mode)
right_scale = load_weight_shard(full_right_scale, module.tp_size,
module.tp_rank, module.tp_mode)
fused_scale = torch.cat([left_scale, right_scale], dim=0)
copy_weight(module.weight, fused_weight)
copy_weight(module.weight_scale, fused_scale)
full_scales = [
w[scale_name] if scale_name in w else None for w in weights[:2]
]
full_scales_squeezed = [
s.squeeze(1).squeeze(-1) if s is not None and s.dim() == 4 else s
for s in full_scales
]
scales = [
load_weight_shard(s, module.tp_size, module.tp_rank, module.tp_mode)
if s is not None else None for s in full_scales_squeezed
]
processed_mapping = self.remap_fused_shard_indices_by_divisible_factor(
module.fused_weight_shard_indices_mapping, 128)
for shard_key, scale in zip(processed_mapping.keys(), scales):
if scale is not None:
shard_offset, shard_size = processed_mapping[shard_key]
copy_weight_shard(module.weight_scale, scale, shard_offset,
shard_size)
def post_load_weights(self, module: Linear):
super().post_load_weights(module)
@ -847,17 +1116,19 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
get_sm_version() == 120:
weight, weight_scale = resmooth_to_fp8_e8m0(module.weight,
module.weight_scale)
transfromed_scale = transform_sf_into_required_layout(
transformed_scale = transform_sf_into_required_layout(
weight_scale,
mn=weight.shape[0],
k=weight.shape[1],
recipe=(1, 128, 128),
is_sfa=False)
module.weight = nn.Parameter(weight, requires_grad=False)
module.weight_scale = nn.Parameter(
transfromed_scale,
requires_grad=False,
)
replace_parameter_and_save_metadata(
module, "weight", nn.Parameter(weight, requires_grad=False),
module.rebuild_tensor_metadata)
replace_parameter_and_save_metadata(
module, "weight_scale",
nn.Parameter(transformed_scale, requires_grad=False),
module.rebuild_tensor_metadata)
class NVFP4LinearMethod(LinearMethodBase):
@ -2293,5 +2564,14 @@ class Linear(nn.Module):
weight_mode,
allow_partial_loading=allow_partial_loading)
def process_weights_after_loading(self):
self.quant_method.process_weights_after_loading(self)
def post_load_weights(self):
self.quant_method.post_load_weights(self)
def pre_reload_weights(self):
assert hasattr(
self.quant_method, "pre_reload_weights"
), "pre_reload_weights is not supported for this quant method"
self.quant_method.pre_reload_weights(self)

View File

@ -414,3 +414,15 @@ def maybe_compiled_copy_(dst, src):
@maybe_compile
def maybe_compiled_cat(tensors, dim):
return torch.cat(tensors, dim)
def replace_parameter_and_save_metadata(module: torch.nn.Module,
param_name: str,
new_param: torch.nn.Parameter,
metadata_dict: Dict):
"""
Replace a parameter in a module and save the metadata of the original parameter.
"""
if param_name not in metadata_dict:
metadata_dict[param_name] = getattr(module, param_name).to("meta")
module.register_parameter(param_name, new_param)
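This helper pairs with the `pre_reload_weights` methods added above: the first time a parameter is replaced (padded or re-laid-out), a zero-storage `.to("meta")` copy records its original shape and dtype, and `pre_reload_weights` later uses `torch.empty_like(metadata, device="cuda")` to rebuild a fresh parameter for the next weight load. A hedged round-trip sketch, using a toy module and CPU instead of CUDA so it runs anywhere:

import torch
from torch import nn

class TinyLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(6, 4), requires_grad=False)
        self.rebuild_tensor_metadata = {}

def replace_parameter_and_save_metadata(module, param_name, new_param, metadata_dict):
    # Same idea as the helper above: remember the original shape/dtype once.
    if param_name not in metadata_dict:
        metadata_dict[param_name] = getattr(module, param_name).to("meta")
    module.register_parameter(param_name, new_param)

m = TinyLinear()

# post_load_weights-style transform: pad rows from 6 to 8.
padded = torch.zeros(8, 4)
padded[:6].copy_(m.weight)
replace_parameter_and_save_metadata(m, "weight",
                                    nn.Parameter(padded, requires_grad=False),
                                    m.rebuild_tensor_metadata)

# pre_reload_weights-style rebuild: restore an empty tensor of the original shape.
for name, meta in m.rebuild_tensor_metadata.items():
    m.register_parameter(name, nn.Parameter(torch.empty_like(meta, device="cpu"),
                                            requires_grad=False))
print(m.weight.shape)   # torch.Size([6, 4]) again, ready for the next load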

View File

@ -50,6 +50,13 @@ class WorkerExtension:
Exception: Re-raises any exception encountered during weight update.
"""
try:
if not hasattr(self.engine.model_engine.model, "first_pre_reload_weights"):
for module in self.engine.model_engine.model.modules():
if hasattr(module, "pre_reload_weights") and not getattr(
module, "_weights_removed", False
):
module.pre_reload_weights()
setattr(self.engine.model_engine.model, "first_pre_reload_weights", True)
if ipc_handles is not None:
logger.info("Update weights from IPC handles")
device_uuid = get_device_uuid(self.device_id)
@ -82,6 +89,10 @@ class WorkerExtension:
else:
logger.info("Finalize update weights")
for module in self.engine.model_engine.model.modules():
if hasattr(module, "process_weights_after_loading") and not getattr(
module, "_weights_removed", False
):
module.process_weights_after_loading()
if hasattr(module, "post_load_weights") and not getattr(
module, "_weights_removed", False
):
@ -93,6 +104,7 @@ class WorkerExtension:
moe_load_balancer.finalize_model()
logger.info("moe_load_balancer finalize model done")
self.engine.reset_prefix_cache()
delattr(self.engine.model_engine.model, "first_pre_reload_weights")
except Exception as e:
logger.error("Encountered an error in update_weights")

View File

@ -15,112 +15,21 @@
import asyncio
import os
from functools import partial
from typing import List, Tuple
from typing import List
import pytest
import ray
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoTokenizer
from utils.llm_data import llm_models_root
from utils.torch_ref import RefHFModel
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams
class HFModel:
def __init__(self, model_name: str, device_id: int):
self.device_id = device_id
self.model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=torch.bfloat16
).to(f"cuda:{device_id}")
def generate_batch_with_padding(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: torch.Tensor,
responses: List[List[int]],
prompt_max_len: int = 1024,
micro_batch_size: int = 16,
):
"""
Synchronous inference on a batch with micro-batching.
Directly extracts response logprobs to save memory.
Args:
input_ids: [batch_size, seq_len]
attention_mask: [batch_size, seq_len]
position_ids: [batch_size, seq_len]
responses: List of response token IDs for each sample
prompt_max_len: Maximum prompt length (default 1024)
micro_batch_size: Size of each micro batch to avoid OOM
Returns:
List of logprobs tensors, one per sample [response_len]
"""
# Move tensors to the correct device
input_ids = input_ids.to(f"cuda:{self.device_id}")
attention_mask = attention_mask.to(f"cuda:{self.device_id}")
position_ids = position_ids.to(f"cuda:{self.device_id}")
batch_size = input_ids.shape[0]
num_micro_batches = (batch_size + micro_batch_size - 1) // micro_batch_size
all_response_logprobs = []
with torch.no_grad():
for micro_idx in range(num_micro_batches):
start_idx = micro_idx * micro_batch_size
end_idx = min((micro_idx + 1) * micro_batch_size, batch_size)
# Extract micro batch
micro_input_ids = input_ids[start_idx:end_idx]
micro_attention_mask = attention_mask[start_idx:end_idx]
micro_position_ids = position_ids[start_idx:end_idx]
# Forward pass
outputs = self.model(
input_ids=micro_input_ids,
attention_mask=micro_attention_mask,
position_ids=micro_position_ids,
)
# Extract response logprobs for each sample in this micro batch
micro_logits = outputs.logits # [micro_batch_size, seq_len, vocab_size]
for i in range(micro_logits.shape[0]):
sample_idx = start_idx + i
response = responses[sample_idx]
response_len = len(response)
# Extract logits for predicting response tokens
# For predicting response[j], we need logits at position prompt_max_len-1+j
response_logits = micro_logits[
i, prompt_max_len - 1 : prompt_max_len - 1 + response_len, :
]
# Convert to logprobs
response_logprobs = torch.log_softmax(response_logits, dim=-1)
# Extract logprobs for the actual generated tokens
response_tensor = torch.tensor(
response, dtype=torch.long, device=response_logprobs.device
)
ref_logprob_for_tokens = torch.gather(
response_logprobs, dim=-1, index=response_tensor.unsqueeze(-1)
).squeeze(-1)
all_response_logprobs.append(ref_logprob_for_tokens)
# Free memory immediately after processing each micro batch
del outputs, micro_logits
torch.cuda.empty_cache()
return all_response_logprobs
async def generate_batch_async(
hf_model: HFModel,
hf_model: RefHFModel,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: torch.Tensor,
@ -133,7 +42,7 @@ async def generate_batch_async(
Runs the synchronous model inference in a thread pool.
Args:
hf_model: HFModel instance
hf_model: RefHFModel instance
input_ids: Input token IDs
attention_mask: Attention mask
position_ids: Position IDs
@ -163,89 +72,6 @@ async def generate_batch_async(
return result
def pad_data(
original_prompts: List[List[int]],
generated_token_ids_list: List[List[int]],
prompt_max_len: int = 1024,
response_max_len: int = 1024,
pad_token_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Pad the data to the maximum length.
Structure:
[left_pad | actual_prompt | actual_response | right_pad]
|<-- prompt_max_len=1024 -->|<-- response_max_len=1024 -->|
Args:
original_prompts: List of prompt token IDs, len = batch_size
generated_token_ids_list: List of response token IDs, len = batch_size
prompt_max_len: Maximum length for prompt section (default 1024)
response_max_len: Maximum length for response section (default 1024)
pad_token_id: Token ID for padding (default 0)
Returns:
input_ids: Tensor of shape [batch_size, prompt_max_len + response_max_len]
attention_mask: Tensor of shape [batch_size, prompt_max_len + response_max_len]
position_ids: Tensor of shape [batch_size, prompt_max_len + response_max_len]
"""
batch_size = len(original_prompts)
total_len = prompt_max_len + response_max_len
for i, (prompt, response) in enumerate(zip(original_prompts, generated_token_ids_list)):
assert len(prompt) <= prompt_max_len, (
f"Batch {i}: Prompt length {len(prompt)} exceeds max {prompt_max_len}"
)
assert len(response) <= response_max_len, (
f"Batch {i}: Response length {len(response)} exceeds max {response_max_len}"
)
# Build batch tensors [batch_size, 2048]
batch_input_ids = torch.full(
(batch_size, total_len), pad_token_id, dtype=torch.long, device="cuda"
)
batch_attention_mask = torch.zeros((batch_size, total_len), dtype=torch.long, device="cuda")
batch_position_ids = torch.zeros((batch_size, total_len), dtype=torch.long, device="cuda")
response_lens = []
for i in range(batch_size):
prompt_tokens = original_prompts[i]
response_tokens = generated_token_ids_list[i]
prompt_len = len(prompt_tokens)
response_len = len(response_tokens)
response_lens.append(response_len)
left_pad_len = prompt_max_len - prompt_len
# Fill input_ids: [left_pad | prompt | response | right_pad]
prompt_start = left_pad_len
prompt_end = prompt_max_len
response_start = prompt_max_len
response_end = prompt_max_len + response_len
batch_input_ids[i, prompt_start:prompt_end] = torch.tensor(
prompt_tokens, dtype=torch.long, device="cuda"
)
batch_input_ids[i, response_start:response_end] = torch.tensor(
response_tokens, dtype=torch.long, device="cuda"
)
# Fill attention_mask: 1 for actual tokens, 0 for padding
batch_attention_mask[i, prompt_start:response_end] = 1
# Fill position_ids: sequential for actual tokens
actual_seq_len = prompt_len + response_len
batch_position_ids[i, prompt_start:response_end] = torch.arange(
actual_seq_len, dtype=torch.long, device="cuda"
)
# Right padding keeps the last position value
if response_len < response_max_len:
batch_position_ids[i, response_end:] = actual_seq_len - 1
return batch_input_ids, batch_attention_mask, batch_position_ids
def compare_logprobs(logprobs_list, ref_new_token_logprobs_list):
"""
logprobs_list: List[torch.Tensor] - LLM logprob values
@ -337,7 +163,7 @@ def test_accuracy_with_allreduce_strategy(model_dir, sampler_type, allreduce_str
ray.shutdown()
torch.cuda.empty_cache()
input_ids, attention_mask, position_ids = pad_data(test_prompts, llm_responses)
input_ids, attention_mask, position_ids = RefHFModel.pad_data(test_prompts, llm_responses)
# Split data across GPUs
num_gpus = 4
@ -347,7 +173,7 @@ def test_accuracy_with_allreduce_strategy(model_dir, sampler_type, allreduce_str
dp_hf_models = []
for device_id in range(num_gpus):
hf_model = HFModel(model_dir, device_id)
hf_model = RefHFModel(model_dir, device_id)
dp_hf_models.append(hf_model)
# Split input data and responses into chunks for each GPU
@ -367,7 +193,7 @@ def test_accuracy_with_allreduce_strategy(model_dir, sampler_type, allreduce_str
responses_chunks.append(llm_responses[start_idx:end_idx])
# Process each chunk on its corresponding GPU asynchronously
async def process_all_chunks(hf_models: List[HFModel]):
async def process_all_chunks(hf_models: List[RefHFModel]):
tasks = []
for i, (input_chunk, attn_chunk, pos_chunk, resp_chunk) in enumerate(
zip(input_ids_chunks, attention_mask_chunks, position_ids_chunks, responses_chunks)

View File

@ -1,62 +1,64 @@
from typing import Callable, List, Optional
import json
import os
import re
import shutil
from tempfile import TemporaryDirectory
from typing import Callable, List, Optional, Tuple
import pytest
import torch
from torch.multiprocessing.reductions import reduce_tensor
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from utils.llm_data import llm_models_root
from utils.torch_ref import RefHFModel
from tensorrt_llm import LLM
from tensorrt_llm._torch.utils import get_device_uuid
from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams
class HFModel:
def __init__(self, model_name: str):
class RefHFModelWithIPCHandles(RefHFModel):
def __init__(self, model_dir: str, device_id: int = 0, num_hidden_layers: int = 4):
self.device_id = device_id
config = AutoConfig.from_pretrained(model_dir)
config.num_hidden_layers = num_hidden_layers
self.model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=torch.bfloat16
).to("cuda")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.cuda_device = torch.cuda.current_device()
model_dir, config=config, torch_dtype=torch.bfloat16, attn_implementation="eager"
).to(f"cuda:{device_id}")
self.all_weights = {}
self.device_uuid = [HFModel.get_device_uuid(i) for i in range(torch.cuda.device_count())]
self.device_uuid = [get_device_uuid(i) for i in range(torch.cuda.device_count())]
self._replicate_weights()
@staticmethod
def get_device_uuid(cuda_device: int):
from tensorrt_llm._torch.utils import get_device_uuid
return get_device_uuid(cuda_device)
def _replicate_weights(self):
model_weights = []
for n, p in self.model.named_parameters():
model_weights.append((n, p.detach().clone()))
self.all_weights[self.cuda_device] = model_weights
self.all_weights[self.device_id] = model_weights
for i in range(torch.cuda.device_count()):
if i != self.cuda_device:
if i != self.device_id:
cur_weights = []
for n, p in self.all_weights[self.cuda_device]:
for n, p in self.all_weights[self.device_id]:
cur_weights.append((n, p.to("cuda:" + str(i))))
self.all_weights[i] = cur_weights
def get_weight_ipc_handles(
self,
cuda_device: Optional[List[int]] = None,
device_ids: Optional[List[int]] = None,
weight_filter: Optional[Callable[[str], bool]] = None,
):
"""
Get IPC handles for model weights with flexible filtering.
Args:
cuda_device: List of CUDA device indices to get weights from
device_ids: List of device indices to get weights from
weight_filter: Optional function that takes a weight name and returns True if the weight should be included
Returns:
ret: Dictionary containing weight handles
"""
ret = {}
device_list = list(range(torch.cuda.device_count())) if cuda_device is None else cuda_device
device_list = list(range(torch.cuda.device_count())) if device_ids is None else device_ids
for device in device_list:
all_handles = []
@ -71,50 +73,6 @@ class HFModel:
return ret
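A hedged usage sketch of the filtering API above; model_dir is assumed to point at a local HF checkpoint, llm is the LLM instance created in the tests below, and the per-device handle payload (built via reduce_tensor) is treated as opaque:

    ref = RefHFModelWithIPCHandles(model_dir, device_id=0, num_hidden_layers=4)
    # Only ship the attention output projections of device 0 in this round.
    attn_handles = ref.get_weight_ipc_handles(
        device_ids=[0],
        weight_filter=lambda name: name.endswith("o_proj.weight"),
    )
    llm._collective_rpc("update_weights", (attn_handles,))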
def generate_batch_incremental(
self, original_prompts: List[str], generated_token_ids_list: List[List[int]]
):
"""
Generate tokens incrementally for each prompt in the batch: [prompt, prompt+token0, prompt+token0+token1, ...]
"""
logits_list = []
for i in range(len(original_prompts)):
base_token_ids = self.tokenizer.encode(original_prompts[i], return_tensors="pt")[0].to(
"cuda"
)
cur_logits = []
for j in range(len(generated_token_ids_list[i])):
if j > 0:
cur_gen_tokens = torch.tensor(generated_token_ids_list[i][:j]).to("cuda")
cur_token_ids = torch.cat([base_token_ids, cur_gen_tokens], dim=-1)
else:
cur_token_ids = base_token_ids
ret = self.model.generate(
input_ids=cur_token_ids.unsqueeze(0).cuda(),
max_new_tokens=1,
do_sample=False,
return_dict_in_generate=True,
output_scores=True,
)
cur_logits.append(ret["scores"][0])
cur_logits = torch.stack(cur_logits, dim=0)
logits_list.append(cur_logits.squeeze(1))
return logits_list
def extract_tokens_from_outputs(outputs):
"""Extract individual tokens from LLM outputs using token IDs directly"""
tokens_list = []
for output in outputs:
# Get token IDs directly from the output
token_ids = output.outputs[0].token_ids
tokens_list.append(token_ids)
return tokens_list
def compare_logits(
logits_list: List[torch.Tensor],
@ -123,7 +81,6 @@ def compare_logits(
threshold: float = 0.9,
):
assert len(logits_list) == len(ref_logits_list)
for i in range(len(logits_list)):
assert logits_list[i].shape == ref_logits_list[i].shape
lhs_idx = torch.topk(logits_list[i], topk, dim=-1).indices
@ -142,119 +99,168 @@ def compare_logits(
)
def run_generate(llm, hf_model, prompts, sampling_params):
outputs = llm.generate(prompts, sampling_params)
def run_generate(
llm: LLM,
hf_model: RefHFModel,
prompts: List[List[int]],
sampling_params: SamplingParams,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
llm_responses = []
llm_logits = []
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
llm_logits.append(output.outputs[0].generation_logits)
generated_token_ids_list = extract_tokens_from_outputs(outputs)
ref_logits = hf_model.generate_batch_incremental(prompts, generated_token_ids_list)
llm_responses.append(output.outputs[0].token_ids)
input_ids, attention_mask, position_ids = RefHFModel.pad_data(prompts, llm_responses)
ref_logits = hf_model.generate_batch_with_padding(
input_ids, attention_mask, position_ids, llm_responses, return_logits=True
)
return llm_logits, ref_logits
def process_and_copy_folder(src_folder, dst_folder, num_hidden_layers: int = 4):
if os.path.exists(dst_folder):
shutil.rmtree(dst_folder)
os.makedirs(dst_folder)
for root, dirs, files in os.walk(src_folder):
rel_path = os.path.relpath(root, src_folder)
dest_dir = os.path.join(dst_folder, rel_path)
if not os.path.exists(dest_dir):
os.makedirs(dest_dir)
for file in files:
src_path = os.path.join(root, file)
dest_path = os.path.join(dest_dir, file)
if "safetensor" in file:
continue
if file == "config.json":
with open(src_path, "r", encoding="utf-8") as f:
config = json.load(f)
config["num_hidden_layers"] = num_hidden_layers
with open(dest_path, "w", encoding="utf-8") as f:
json.dump(config, f, indent=2, ensure_ascii=False)
else:
shutil.copy2(src_path, dest_path)
@pytest.mark.parametrize(
"model_dir",
["Qwen2.5-0.5B-Instruct", "Qwen3/Qwen3-8B", "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"],
[
"llama-models-v2/TinyLlama-1.1B-Chat-v1.0",
"Qwen2.5-0.5B-Instruct",
"Qwen3/Qwen3-8B",
"Qwen3/Qwen3-30B-A3B",
"Qwen3/Qwen3-8B-FP8",
"Qwen3/Qwen3-30B-A3B-FP8",
],
)
def test_llm_update_weights(model_dir):
model_dir = str(llm_models_root() / model_dir)
kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
with TemporaryDirectory() as tmp_model_dir:
num_hidden_layers = 1
process_and_copy_folder(model_dir, tmp_model_dir, num_hidden_layers=num_hidden_layers)
hf_model = RefHFModelWithIPCHandles(model_dir, num_hidden_layers=num_hidden_layers)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
llm = LLM(
model=tmp_model_dir,
ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
tensor_parallel_size=1,
load_format="dummy",
pipeline_parallel_size=1,
kv_cache_config=kv_cache_config,
)
hf_model = HFModel(model_dir)
# Generate texts from the prompts.
prompts_texts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
prompts = [tokenizer.encode(prompt) for prompt in prompts_texts]
del tokenizer
sampling_params = SamplingParams(
temperature=0, return_generation_logits=True, max_tokens=1024
)
llm = LLM(
model=model_dir,
ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
tensor_parallel_size=1,
load_format="dummy",
pipeline_parallel_size=1,
kv_cache_config=kv_cache_config,
)
ipc_handles = hf_model.get_weight_ipc_handles([0])
# Generate texts from the prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
llm._collective_rpc("update_weights", (ipc_handles,))
# Finalize the update weights
llm._collective_rpc("update_weights", (None,))
sampling_params = SamplingParams(temperature=0, return_generation_logits=True)
ipc_handles = hf_model.get_weight_ipc_handles([0])
llm._collective_rpc("update_weights", (ipc_handles,))
# Finalize the update weights
llm._collective_rpc("update_weights", (None,))
llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
compare_logits(llm_logits, ref_logits)
llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
compare_logits(llm_logits, ref_logits)
@pytest.mark.parametrize(
"model_dir",
["Qwen2.5-0.5B-Instruct", "Qwen3/Qwen3-8B", "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"],
[
"llama-models-v2/TinyLlama-1.1B-Chat-v1.0",
"Qwen2.5-0.5B-Instruct",
"Qwen3/Qwen3-8B",
"Qwen3/Qwen3-30B-A3B",
"Qwen3/Qwen3-8B-FP8",
"Qwen3/Qwen3-30B-A3B-FP8",
],
)
def test_llm_partial_update_weights(model_dir):
model_dir = str(llm_models_root() / model_dir)
kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
with TemporaryDirectory() as tmp_model_dir:
num_hidden_layers = 1
process_and_copy_folder(model_dir, tmp_model_dir, num_hidden_layers=num_hidden_layers)
hf_model = RefHFModelWithIPCHandles(model_dir, num_hidden_layers=num_hidden_layers)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
hf_model = HFModel(model_dir)
llm = LLM(
model=model_dir,
ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
tensor_parallel_size=1,
load_format="dummy",
pipeline_parallel_size=1,
kv_cache_config=kv_cache_config,
)
# Generate texts from the prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0, return_generation_logits=True)
ipc_handles = hf_model.get_weight_ipc_handles([0])
def common_filter(filter_name: str) -> Callable[[str], bool]:
def filter_fn(name: str) -> bool:
return filter_name in name
return filter_fn
filter_list = [
"q_proj.weight",
"k_proj.weight",
"v_proj.weight",
"o_proj.weight",
"gate_proj.weight",
"up_proj.weight",
"down_proj.weight",
"norm.weight",
"embed_tokens.weight",
"lm_head.weight",
]
if "Qwen2.5" in model_dir or "Qwen2" in model_dir:
filter_list.extend(
[
"q_proj.bias",
"k_proj.bias",
"v_proj.bias",
]
llm = LLM(
model=tmp_model_dir,
ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
tensor_parallel_size=1,
load_format="dummy",
pipeline_parallel_size=1,
kv_cache_config=kv_cache_config,
)
for filter_name in filter_list:
weight_filter = common_filter(filter_name=filter_name)
ipc_handles = hf_model.get_weight_ipc_handles([0], weight_filter=weight_filter)
llm._collective_rpc("update_weights", (ipc_handles,), non_block=True)
# Finalize the update weights
llm._collective_rpc("update_weights", (None,))
llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
compare_logits(llm_logits, ref_logits)
# Generate texts from the prompts.
prompts_texts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
prompts = [tokenizer.encode(prompt) for prompt in prompts_texts]
del tokenizer
sampling_params = SamplingParams(
temperature=0, return_generation_logits=True, max_tokens=1024
)
def common_filter(filter_name: str) -> Callable[[str], bool]:
def filter_fn(name: str) -> bool:
return name.endswith(filter_name)
return filter_fn
# Generate filter_list from model weight keys by removing layer prefix
# e.g., "model.layers.41.input_layernorm.weight" -> "input_layernorm.weight"
layer_prefix_pattern = re.compile(r"^model\.layers\.\d+\.")
filter_set = set()
for name, _ in hf_model.all_weights[hf_model.device_id]:
suffix = layer_prefix_pattern.sub("", name)
filter_set.add(suffix)
filter_list = list(filter_set)
for filter_name in filter_list:
weight_filter = common_filter(filter_name=filter_name)
ipc_handles = hf_model.get_weight_ipc_handles([0], weight_filter=weight_filter)
llm._collective_rpc("update_weights", (ipc_handles,))
# Finalize the update weights
llm._collective_rpc("update_weights", (None,))
llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
compare_logits(llm_logits, ref_logits)

View File

@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers import AutoModelForCausalLM
def geglu(x):
@ -1239,3 +1240,188 @@ class recurrent_ref(nn.Module):
x = self.linear_out(x)
return x, conv_state, lru_state
class RefHFModel:
def __init__(self,
model_dir: str,
device_id: int = 0,
additional_model_kargs: Optional[Dict[str, Any]] = None):
self.device_id = device_id
self.model = AutoModelForCausalLM.from_pretrained(
model_dir, **(additional_model_kargs or {})).to(f"cuda:{device_id}")
def generate_batch_with_padding(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: torch.Tensor,
responses: List[List[int]],
prompt_max_len: int = 1024,
micro_batch_size: int = 16,
return_logits: bool = False,
):
"""
Synchronous inference on a batch with micro-batching.
Directly extracts response logits or logprobs to save memory.
Args:
input_ids: [batch_size, seq_len]
attention_mask: [batch_size, seq_len]
position_ids: [batch_size, seq_len]
responses: List of response token IDs for each sample
prompt_max_len: Maximum prompt length (default 1024)
micro_batch_size: Size of each micro batch to avoid OOM
return_logits: If True, return per-token logits; otherwise return logprobs
Returns:
List of tensors, one per sample: logits of shape [response_len, vocab_size] or logprobs of shape [response_len]
"""
# Move tensors to the correct device
input_ids = input_ids.to(f"cuda:{self.device_id}")
attention_mask = attention_mask.to(f"cuda:{self.device_id}")
position_ids = position_ids.to(f"cuda:{self.device_id}")
batch_size = input_ids.shape[0]
num_micro_batches = (batch_size + micro_batch_size -
1) // micro_batch_size
ref_results = []
with torch.no_grad():
for micro_idx in range(num_micro_batches):
start_idx = micro_idx * micro_batch_size
end_idx = min((micro_idx + 1) * micro_batch_size, batch_size)
# Extract micro batch
micro_input_ids = input_ids[start_idx:end_idx]
micro_attention_mask = attention_mask[start_idx:end_idx]
micro_position_ids = position_ids[start_idx:end_idx]
# Forward pass
outputs = self.model(
input_ids=micro_input_ids,
attention_mask=micro_attention_mask,
position_ids=micro_position_ids,
)
# Extract response logprobs for each sample in this micro batch
for i in range(outputs.logits.shape[0]):
sample_idx = start_idx + i
response = responses[sample_idx]
response_len = len(response)
# Extract logits for predicting response tokens
# For predicting response[j], we need logits at position prompt_max_len-1+j
response_logits = outputs.logits[i, prompt_max_len -
1:prompt_max_len - 1 +
response_len, :]
if return_logits:
ref_results.append(response_logits)
else:
# Convert to logprobs
response_logprobs = torch.log_softmax(response_logits,
dim=-1)
# Extract logprobs for the actual generated tokens
response_tensor = torch.tensor(
response,
dtype=torch.long,
device=response_logprobs.device)
ref_logprob_for_tokens = torch.gather(
response_logprobs,
dim=-1,
index=response_tensor.unsqueeze(-1)).squeeze(-1)
ref_results.append(ref_logprob_for_tokens)
# Free memory immediately after processing each micro batch
del outputs
torch.cuda.empty_cache()
return ref_results
@staticmethod
def pad_data(
original_prompts: List[List[int]],
generated_token_ids_list: List[List[int]],
prompt_max_len: int = 1024,
response_max_len: int = 1024,
pad_token_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Pad the data to the maximum length.
Structure:
[left_pad | actual_prompt | actual_response | right_pad]
|<-- prompt_max_len=1024 -->|<-- response_max_len=1024 -->|
Args:
original_prompts: List of prompt token IDs, len = batch_size
generated_token_ids_list: List of response token IDs, len = batch_size
prompt_max_len: Maximum length for prompt section (default 1024)
response_max_len: Maximum length for response section (default 1024)
pad_token_id: Token ID for padding (default 0)
Returns:
input_ids: Tensor of shape [batch_size, prompt_max_len + response_max_len]
attention_mask: Tensor of shape [batch_size, prompt_max_len + response_max_len]
position_ids: Tensor of shape [batch_size, prompt_max_len + response_max_len]
"""
batch_size = len(original_prompts)
total_len = prompt_max_len + response_max_len
for i, (prompt, response) in enumerate(
zip(original_prompts, generated_token_ids_list)):
assert len(prompt) <= prompt_max_len, (
f"Batch {i}: Prompt length {len(prompt)} exceeds max {prompt_max_len}"
)
assert len(response) <= response_max_len, (
f"Batch {i}: Response length {len(response)} exceeds max {response_max_len}"
)
# Build batch tensors [batch_size, total_len]
batch_input_ids = torch.full((batch_size, total_len),
pad_token_id,
dtype=torch.long,
device="cuda")
batch_attention_mask = torch.zeros((batch_size, total_len),
dtype=torch.long,
device="cuda")
batch_position_ids = torch.zeros((batch_size, total_len),
dtype=torch.long,
device="cuda")
response_lens = []
for i in range(batch_size):
prompt_tokens = original_prompts[i]
response_tokens = generated_token_ids_list[i]
prompt_len = len(prompt_tokens)
response_len = len(response_tokens)
response_lens.append(response_len)
left_pad_len = prompt_max_len - prompt_len
# Fill input_ids: [left_pad | prompt | response | right_pad]
prompt_start = left_pad_len
prompt_end = prompt_max_len
response_start = prompt_max_len
response_end = prompt_max_len + response_len
batch_input_ids[i, prompt_start:prompt_end] = torch.tensor(
prompt_tokens, dtype=torch.long, device="cuda")
batch_input_ids[i, response_start:response_end] = torch.tensor(
response_tokens, dtype=torch.long, device="cuda")
# Fill attention_mask: 1 for actual tokens, 0 for padding
batch_attention_mask[i, prompt_start:response_end] = 1
# Fill position_ids: sequential for actual tokens
actual_seq_len = prompt_len + response_len
batch_position_ids[i, prompt_start:response_end] = torch.arange(
actual_seq_len, dtype=torch.long, device="cuda")
# Right padding keeps the last position value
if response_len < response_max_len:
batch_position_ids[i, response_end:] = actual_seq_len - 1
return batch_input_ids, batch_attention_mask, batch_position_ids
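To make the layout concrete, a self-contained sketch with toy token IDs (values hypothetical) that mirrors the [left_pad | prompt | response | right_pad] scheme on CPU rather than calling pad_data directly (which allocates on CUDA):

    import torch

    prompt, response = [11, 12, 13], [21, 22]      # hypothetical token IDs
    prompt_max_len, response_max_len = 5, 4
    total_len = prompt_max_len + response_max_len

    input_ids = torch.zeros(total_len, dtype=torch.long)
    attention_mask = torch.zeros(total_len, dtype=torch.long)
    position_ids = torch.zeros(total_len, dtype=torch.long)

    left_pad = prompt_max_len - len(prompt)
    response_end = prompt_max_len + len(response)
    input_ids[left_pad:prompt_max_len] = torch.tensor(prompt)
    input_ids[prompt_max_len:response_end] = torch.tensor(response)
    attention_mask[left_pad:response_end] = 1
    seq_len = len(prompt) + len(response)
    position_ids[left_pad:response_end] = torch.arange(seq_len)
    position_ids[response_end:] = seq_len - 1      # right pad keeps the last position value

    print(input_ids.tolist())       # [0, 0, 11, 12, 13, 21, 22, 0, 0]
    print(attention_mask.tolist())  # [0, 0, 1, 1, 1, 1, 1, 0, 0]
    print(position_ids.tolist())    # [0, 0, 0, 1, 2, 3, 4, 4, 4]

With this layout, the logits at position prompt_max_len - 1 + j predict response[j], which is the index arithmetic generate_batch_with_padding relies on when slicing out response logits.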