Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

Merge c37c80d3ac into 38296a472b
This commit is contained in: commit 96dc51a806
@@ -1124,6 +1124,26 @@ class ConfigurableMoE(MoE):
        )
        return self.backend.post_load_weights()

    def process_weights_after_loading(self):
        """
        Process weights after loading - delegated to backend

        """
        assert hasattr(self.backend, "process_weights_after_loading"), (
            f"Backend {self.backend.__class__.__name__} must implement process_weights_after_loading()"
        )
        return self.backend.process_weights_after_loading()

    def pre_reload_weights(self):
        """
        Pre reload weights - delegated to backend

        """
        assert hasattr(self.backend, "pre_reload_weights"), (
            f"Backend {self.backend.__class__.__name__} must implement pre_reload_weights()"
        )
        return self.backend.pre_reload_weights()

    # ========== Communication and Quantization Properties ==========

    @property

@@ -1,3 +1,4 @@
import inspect
import os
from functools import cached_property
from typing import Dict, List, Optional, Tuple, Union

@@ -862,16 +863,22 @@ class CutlassFusedMoE(MoE):
        assert len(weights) == 1
        weights = weights[0]

        if not isinstance(self.quant_method, UnquantizedFusedMoEMethod):
            assert not allow_partial_loading, "Partial loading is not supported for quantized MoE now"
            self.quant_method.load_weights(self, weights,
                                           self.weight_loading_mode)
        else:
            self.quant_method.load_weights(
                self,
                weights,
                self.weight_loading_mode,
                allow_partial_loading=allow_partial_loading)
        kargs = {}
        if "allow_partial_loading" in inspect.getfullargspec(
                self.quant_method.load_weights).args:
            kargs["allow_partial_loading"] = allow_partial_loading
        self.quant_method.load_weights(self, weights, self.weight_loading_mode,
                                       **kargs)

    def post_load_weights(self):
        self.quant_method.post_load_weights(self)

    def process_weights_after_loading(self):
        if hasattr(self.quant_method, 'process_weights_after_loading'):
            self.quant_method.process_weights_after_loading(self)

    def pre_reload_weights(self):
        assert hasattr(
            self.quant_method, 'pre_reload_weights'
        ), "pre_reload_weights is not supported for this quant method"
        self.quant_method.pre_reload_weights(self)

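The `load_weights` change above forwards `allow_partial_loading` only when the underlying quant method's signature actually declares it, so quant methods that were not updated keep working unchanged. A minimal standalone sketch of that introspection pattern (the `QuantMethodA`/`QuantMethodB`/`dispatch_load` names are illustrative, not from the diff):

```python
import inspect


class QuantMethodA:  # legacy signature, no partial-loading kwarg
    def load_weights(self, module, weights, mode):
        return f"A: mode={mode}"


class QuantMethodB:  # new signature, supports partial loading
    def load_weights(self, module, weights, mode, allow_partial_loading=False):
        return f"B: mode={mode}, partial={allow_partial_loading}"


def dispatch_load(quant_method, module, weights, mode, allow_partial_loading):
    # Forward the kwarg only if the target signature declares it.
    kargs = {}
    if "allow_partial_loading" in inspect.getfullargspec(
            quant_method.load_weights).args:
        kargs["allow_partial_loading"] = allow_partial_loading
    return quant_method.load_weights(module, weights, mode, **kargs)


print(dispatch_load(QuantMethodA(), None, {}, "VANILLA", True))  # kwarg dropped
print(dispatch_load(QuantMethodB(), None, {}, "VANILLA", True))  # kwarg forwarded
```

The same `inspect.getfullargspec(...).args` check reappears in the TRTLLMGen, WideEP, and base quantization hunks below.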
@@ -1401,3 +1401,13 @@ class TritonFusedMoE(MoE):

    def post_load_weights(self):
        self.quant_method.post_load_weights(self)

    def process_weights_after_loading(self):
        if hasattr(self.quant_method, 'process_weights_after_loading'):
            self.quant_method.process_weights_after_loading(self)

    def pre_reload_weights(self):
        assert hasattr(
            self.quant_method, 'pre_reload_weights'
        ), "pre_reload_weights is not supported for this quant method"
        self.quant_method.pre_reload_weights(self)

@@ -1,3 +1,4 @@
import inspect
import os
from functools import cached_property
from typing import Dict, List, Optional, Union

@@ -21,9 +22,9 @@ from .interface import AlltoallMethodType, MoE, MoEWeightLoadingMode
# isort: off
from .quantization import (
    DeepSeekFP8BlockScalesFusedMoEMethod, NVFP4TRTLLMGenFusedMoEBaseMethod,
    NVFP4TRTLLMGenFusedMoEMethod, UnquantizedFusedMoEMethod,
    W4A8MXFP4FP8TRTLLMGenFusedMoEMethod, W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod,
    W4A8NVFP4FP8TRTLLMGenFusedMoEMethod, W4A16MXFP4TRTLLMGenFusedMoEMethod)
    NVFP4TRTLLMGenFusedMoEMethod, W4A8MXFP4FP8TRTLLMGenFusedMoEMethod,
    W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod, W4A8NVFP4FP8TRTLLMGenFusedMoEMethod,
    W4A16MXFP4TRTLLMGenFusedMoEMethod)
# isort: on
from .routing import BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod

@@ -273,20 +274,26 @@ class TRTLLMGenFusedMoE(MoE):
        assert len(weights) == 1
        weights = weights[0]

        if not isinstance(self.quant_method, UnquantizedFusedMoEMethod):
            assert not allow_partial_loading, "Partial loading is not supported for quantized MoE now"
            self.quant_method.load_weights(self, weights,
                                           self.weight_loading_mode)
        else:
            self.quant_method.load_weights(
                self,
                weights,
                self.weight_loading_mode,
                allow_partial_loading=allow_partial_loading)
        kargs = {}
        if "allow_partial_loading" in inspect.getfullargspec(
                self.quant_method.load_weights).args:
            kargs["allow_partial_loading"] = allow_partial_loading
        self.quant_method.load_weights(self, weights, self.weight_loading_mode,
                                       **kargs)

    def post_load_weights(self):
        self.quant_method.post_load_weights(self)

    def process_weights_after_loading(self):
        if hasattr(self.quant_method, 'process_weights_after_loading'):
            self.quant_method.process_weights_after_loading(self)

    def pre_reload_weights(self):
        assert hasattr(
            self.quant_method, 'pre_reload_weights'
        ), "pre_reload_weights is not supported for this quant method"
        self.quant_method.pre_reload_weights(self)

    def quantize_input(self, x, post_quant_comm: bool = True):
        """Quantize inputs prior to post-communication (alltoall/allgather) or before MoE computation.

@@ -1,3 +1,4 @@
import inspect
import os
from typing import Dict, List, Optional, Tuple, Union

@@ -935,20 +936,26 @@ class WideEPMoE(MoE):
        assert len(weights) == 1
        weights = weights[0]

        if not isinstance(self.quant_method, UnquantizedFusedMoEMethod):
            assert not allow_partial_loading, "Partial loading is not supported for quantized MoE now"
            self.quant_method.load_weights(self, weights,
                                           self.weight_loading_mode)
        else:
            self.quant_method.load_weights(
                self,
                weights,
                self.weight_loading_mode,
                allow_partial_loading=allow_partial_loading)
        kargs = {}
        if "allow_partial_loading" in inspect.getfullargspec(
                self.quant_method.load_weights).args:
            kargs["allow_partial_loading"] = allow_partial_loading
        self.quant_method.load_weights(self, weights, self.weight_loading_mode,
                                       **kargs)

    def post_load_weights(self):
        self.quant_method.post_load_weights(self)

    def process_weights_after_loading(self):
        if hasattr(self.quant_method, 'process_weights_after_loading'):
            self.quant_method.process_weights_after_loading(self)

    def pre_reload_weights(self):
        assert hasattr(
            self.quant_method, 'pre_reload_weights'
        ), "pre_reload_weights is not supported for this quant method"
        self.quant_method.pre_reload_weights(self)

    def forward_fake(
        self,
        x: Union[torch.Tensor, Fp4QuantizedTensor],

@@ -521,12 +521,20 @@ class MoE(nn.Module):
        raise NotImplementedError

    @abstractmethod
    def load_weights(self, weights: List[Dict]):
    def load_weights(self,
                     weights: List[Dict],
                     allow_partial_loading: bool = False):
        raise NotImplementedError

    def process_weights_after_loading(self):
        pass

    def post_load_weights(self):
        pass

    def pre_reload_weights(self):
        pass

    @abstractmethod
    def quantize_input(
        self,

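The MoE interface hunk above adds `allow_partial_loading` to `load_weights` and introduces `process_weights_after_loading`/`post_load_weights`/`pre_reload_weights` hooks, which implies a two-phase loading flow: copy each checkpoint shard as it arrives, then finalize once. A hedged sketch of how a caller might drive that flow under those assumptions (`moe_layer` and `shard_iter` are placeholders, not names from the diff):

```python
from typing import Dict, Iterable, List


def load_moe_from_shards(moe_layer, shard_iter: Iterable[List[Dict]]) -> None:
    """Sketch of the intended two-phase flow, assuming each shard provides
    only a subset of the experts/tensors."""
    for shard_weights in shard_iter:
        # Phase 1: copy whatever this shard provides into the destination
        # tensors; scale fusion and requantization are deferred.
        moe_layer.load_weights(shard_weights, allow_partial_loading=True)

    # Phase 2: once every shard has been consumed, fold the temporarily
    # stored scales into the fused parameters and free the scratch buffers.
    moe_layer.process_weights_after_loading()
    moe_layer.post_load_weights()


def reload_moe(moe_layer, shard_iter: Iterable[List[Dict]]) -> None:
    """Before reloading, rebuild any parameters whose storage was replaced
    during the previous post-processing step (see pre_reload_weights)."""
    moe_layer.pre_reload_weights()
    load_moe_from_shards(moe_layer, shard_iter)
```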
@@ -18,7 +18,8 @@ from tensorrt_llm.quantization.utils.fp4_utils import (
from tensorrt_llm.quantization.utils.fp8_utils import (
    resmooth_to_fp8_e8m0, transform_sf_into_required_layout)

from ...utils import swizzle_sf, unswizzle_sf
from ...utils import (replace_parameter_and_save_metadata, swizzle_sf,
                      unswizzle_sf)
from ..linear import TensorParallelMode, load_weight_shard
from .interface import MoEWeightLoadingMode

@@ -236,6 +237,8 @@ class FusedMoEMethodBase(ABC):
        module.w3_w1_bias = None
        module.w2_bias = None

        module.rebuild_tensor_metadata = {}

    def load_expert_weights_to_dst(
        self,
        module: torch.nn.Module,

@@ -330,6 +333,12 @@ class FusedMoEMethodBase(ABC):
                     weights: List[Dict],
                     weight_loading_mode: MoEWeightLoadingMode,
                     allow_partial_loading: bool = False):
        if allow_partial_loading:
            assert isinstance(
                self, (UnquantizedFusedMoEMethod, FP8QDQFusedMoEMethod,
                       DeepSeekFP8BlockScalesFusedMoEMethod,
                       DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm)
            ), "Partial loading is only supported for unquantized and FP8 models"
        additional_kargs = {}
        if "allow_partial_loading" in inspect.getfullargspec(
                self.load_expert_weights_to_dst).args:

@@ -402,6 +411,9 @@ class FusedMoEMethodBase(ABC):
            local_shared_w2_bias_tensors if module.bias else None,
            **additional_kargs)

        if not allow_partial_loading:
            self.process_weights_after_loading(module)

    def post_load_weights(self, module: torch.nn.Module):
        if self.need_load_shared_weights(module):
            weight_fns = {

@@ -432,6 +444,9 @@ class FusedMoEMethodBase(ABC):
    def load_quant_scales(self, module: torch.nn.Module, weights: List[Dict]):
        pass

    def process_weights_after_loading(self, module: torch.nn.Module):
        pass

    @abstractmethod
    def setup_quant_scales(self, module: torch.nn.Module):
        raise NotImplementedError

@@ -473,21 +488,19 @@ class FusedMoEMethodBase(ABC):
                                           TensorParallelMode.COLUMN,
                                           device=device) if w3_weight is not None else None

        src_w3_size_shard = w3_weight_shard.shape[
            0] if w3_weight_shard is not None else 0
        src_w1_size_shard = w1_weight_shard.shape[
            0] if w1_weight_shard is not None else 0
        if w1_weight is not None:
            dst_w1_weight = dst_w3_w1_weight.narrow(dim=0,
                                                    start=src_w3_size_shard,
                                                    length=src_w1_size_shard)
            dst_w1_weight.copy_(w1_weight_shard.contiguous().view(
                dst_w3_w1_weight.dtype),
                non_blocking=True)
        if w3_weight is not None:
            dst_w3_weight = dst_w3_w1_weight.narrow(dim=0,
                                                    start=0,
                                                    length=src_w3_size_shard)
        dst_w3_weight, dst_w1_weight = dst_w3_w1_weight.chunk(2, dim=0)
        if w1_weight_shard is not None and w1_weight_shard.shape[0] != 0:
            w1_weight_shard_viewed = w1_weight_shard.contiguous().view(
                dst_w3_w1_weight.dtype)
            if w1_weight_shard_viewed.shape[0] == dst_w3_w1_weight.shape[0]:
                # w3_weight (gate_proj) should be empty for Nemotron-H MoE model.
                dst_w3_w1_weight.copy_(w1_weight_shard_viewed,
                                       non_blocking=True)
            elif w1_weight_shard_viewed.shape[0] == dst_w1_weight.shape[0]:
                dst_w1_weight.copy_(w1_weight_shard_viewed, non_blocking=True)
            else:
                raise ValueError("Shape mismatch!")
        if w3_weight_shard is not None and w3_weight_shard.shape[0] != 0:
            dst_w3_weight.copy_(w3_weight_shard.contiguous().view(
                dst_w3_w1_weight.dtype),
                non_blocking=True)

@@ -516,6 +529,16 @@ class FusedMoEMethodBase(ABC):
        dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype),
                            non_blocking=True)

    def pre_reload_weights(self, module: torch.nn.Module):
        for param_name, metadata in module.rebuild_tensor_metadata.items():
            logger.warning(
                f"Pre-reloading weight '{param_name}' requires tensor re-creation, which will invalidate existing CUDA graphs."
            )
            param = torch.nn.Parameter(torch.empty_like(metadata,
                                                        device="cuda"),
                                       requires_grad=False)
            module.register_parameter(param_name, param)


class UnquantizedFusedMoEMethod(FusedMoEMethodBase):

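The `rebuild_tensor_metadata` dict introduced above pairs with `replace_parameter_and_save_metadata` (imported from `...utils` in the same hunk): when post-processing swaps a parameter for a padded or re-laid-out tensor, the original tensor's shape and dtype are remembered so `pre_reload_weights` can recreate a fresh parameter before the next checkpoint load. A hedged sketch of how such a helper could behave, inferred purely from the call sites in this diff (it is not the actual utility implementation):

```python
import torch
import torch.nn as nn


def replace_parameter_and_save_metadata(module: nn.Module, name: str,
                                        new_param: nn.Parameter,
                                        metadata: dict) -> None:
    """Illustrative only: record the pre-replacement layout, then swap the
    parameter. An empty meta tensor is enough to carry shape/dtype for the
    later torch.empty_like()."""
    old = getattr(module, name)
    metadata[name] = torch.empty_like(old, device="meta")
    module.register_parameter(name, new_param)


def pre_reload_weights(module: nn.Module) -> None:
    """Mirror of the pre_reload_weights added above: recreate each replaced
    parameter with its original layout before weights are reloaded."""
    for param_name, meta in module.rebuild_tensor_metadata.items():
        param = nn.Parameter(torch.empty_like(meta, device="cuda"),
                             requires_grad=False)
        module.register_parameter(param_name, param)
```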
@@ -539,8 +562,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):


def load_expert_fc31_input_scale_fp8_qdq(w1_input_scale, w3_input_scale,
                                         dst_fc31_input_scale: torch.Tensor):
    dst_fc31_input_scale.copy_(
        max(w1_input_scale[...].reshape([]), w3_input_scale[...].reshape([])))
    if w1_input_scale is not None and w1_input_scale.numel() != 0:
        w1_input_scale = w1_input_scale[...].reshape([])
        dst_fc31_input_scale[0].copy_(w1_input_scale)
    if w3_input_scale is not None and w3_input_scale.numel() != 0:
        w3_input_scale = w3_input_scale[...].reshape([])
        dst_fc31_input_scale[1].copy_(w3_input_scale)


def load_expert_fc2_input_scale_fp8_qdq(w2_input_scale,

@@ -549,35 +576,45 @@ def load_expert_fc2_input_scale_fp8_qdq(w2_input_scale,


def load_activation_scales_fp8_qdq(module: torch.nn.Module, weights: Dict):
    tmp_fc31_input_scale = torch.empty(module.num_experts, dtype=torch.float32)
    tmp_fc2_input_scale = torch.empty(module.num_experts, dtype=torch.float32)
    if not hasattr(module, 'tmp_fc31_input_scale'):
        module.tmp_fc31_input_scale = torch.empty(
            (module.num_experts, 2),
            dtype=torch.float32,
            device=module.fc31_dequant.device)
    tmp_fc31_input_scale = module.tmp_fc31_input_scale
    if not hasattr(module, 'tmp_fc2_input_scale'):
        module.tmp_fc2_input_scale = torch.empty(
            module.num_experts,
            dtype=torch.float32,
            device=module.fc2_dequant.device)
    tmp_fc2_input_scale = module.tmp_fc2_input_scale
    for expert_id in range(module.num_experts):
        if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
            w1_input_scale = weights[f"{expert_id}.w1.input_scale"]
            w3_input_scale = weights[f"{expert_id}.w3.input_scale"]
            w2_input_scale = weights[f"{expert_id}.w2.input_scale"]
            w1_input_scale = weights[
                f"{expert_id}.w1.input_scale"] if f"{expert_id}.w1.input_scale" in weights else None
            w3_input_scale = weights[
                f"{expert_id}.w3.input_scale"] if f"{expert_id}.w3.input_scale" in weights else None
            w2_input_scale = weights[
                f"{expert_id}.w2.input_scale"] if f"{expert_id}.w2.input_scale" in weights else None
        elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
            w1_input_scale = weights[f"gate_up_proj_input_scale"]
            w3_input_scale = weights[f"gate_up_proj_input_scale"]
            w2_input_scale = weights[f"down_proj_input_scale"]
            w1_input_scale = weights[
                f"gate_up_proj_input_scale"] if f"gate_up_proj_input_scale" in weights else None
            w3_input_scale = weights[
                f"gate_up_proj_input_scale"] if f"gate_up_proj_input_scale" in weights else None
            w2_input_scale = weights[
                f"down_proj_input_scale"] if f"down_proj_input_scale" in weights else None
        else:
            raise NotImplementedError(
                f"Unknown weight loading mode in MoE: {module.weight_loading_mode}"
            )

        load_expert_fc31_input_scale_fp8_qdq(w1_input_scale, w3_input_scale,
                                             tmp_fc31_input_scale[expert_id])
        if w1_input_scale is not None or w3_input_scale is not None:
            load_expert_fc31_input_scale_fp8_qdq(
                w1_input_scale, w3_input_scale, tmp_fc31_input_scale[expert_id])

        load_expert_fc2_input_scale_fp8_qdq(w2_input_scale,
                                            tmp_fc2_input_scale[expert_id])

    # max_fc31_input_scale is the maximum of all w1 input scales and w3 input scales.
    # It's used to quantize fc31 input inside the MOE op
    max_fc31_input_scale = tmp_fc31_input_scale.max()
    # max_fc2_input_scale is the maximum of all w2 input scales.
    max_fc2_input_scale = tmp_fc2_input_scale.max()

    return max_fc31_input_scale, max_fc2_input_scale
        if w2_input_scale is not None:
            load_expert_fc2_input_scale_fp8_qdq(w2_input_scale,
                                                tmp_fc2_input_scale[expert_id])


def requantize_expert_w3_w1_weight_fp8_qdq(module: torch.nn.Module,

@@ -654,9 +691,12 @@ class FP8QDQFusedMoEMethod(FusedMoEMethodBase):
    def load_expert_w3_w1_weight_scale_fp8_qdq(
            self, w1_weight_scale, w3_weight_scale,
            dst_w3_w1_weight_scale: torch.Tensor):
        w1_weight_scale = w1_weight_scale[...].reshape([])
        w3_weight_scale = w3_weight_scale[...].reshape([])
        dst_w3_w1_weight_scale.copy_(max(w1_weight_scale, w3_weight_scale))
        if w1_weight_scale is not None and w1_weight_scale.numel() != 0:
            w1_weight_scale = w1_weight_scale[...].reshape([])
            dst_w3_w1_weight_scale[0].copy_(w1_weight_scale)
        if w3_weight_scale is not None and w3_weight_scale.numel() != 0:
            w3_weight_scale = w3_weight_scale[...].reshape([])
            dst_w3_w1_weight_scale[1].copy_(w3_weight_scale)

    def load_expert_w2_weight_scale_fp8(self, w2_weight_scale,
                                        dst_w2_weight_scale: torch.Tensor):

@@ -664,25 +704,38 @@ class FP8QDQFusedMoEMethod(FusedMoEMethodBase):

    def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
        # Step1: Load input scales.
        max_fc31_input_scale, max_fc2_input_scale = load_activation_scales_fp8_qdq(
            module, weights)
        load_activation_scales_fp8_qdq(module, weights)

        # Step2: Load weight scales and requantize w3_w1_weight.
        tmp_w3_w1_weight_scale = torch.empty(module.expert_size_per_partition,
                                             dtype=torch.float32)
        tmp_w2_weight_scale = torch.empty(module.expert_size_per_partition,
                                          dtype=torch.float32)
        # Step2: Load weight scales
        if not hasattr(module, 'tmp_w3_w1_weight_scale'):
            module.tmp_w3_w1_weight_scale = torch.empty(
                (module.expert_size_per_partition, 2),
                dtype=torch.float32,
                device=module.fc31_dequant.device)
        if not hasattr(module, 'tmp_w2_weight_scale'):
            module.tmp_w2_weight_scale = torch.empty(
                module.expert_size_per_partition,
                dtype=torch.float32,
                device=module.fc2_dequant.device)
        tmp_w3_w1_weight_scale = module.tmp_w3_w1_weight_scale
        tmp_w2_weight_scale = module.tmp_w2_weight_scale

        for local_slot_id, expert_id in enumerate(
                module.initial_local_expert_ids):
            if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
                w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"]
                w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"]
                w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"]
                w1_weight_scale = weights[
                    f"{expert_id}.w1.weight_scale"] if f"{expert_id}.w1.weight_scale" in weights else None
                w3_weight_scale = weights[
                    f"{expert_id}.w3.weight_scale"] if f"{expert_id}.w3.weight_scale" in weights else None
                w2_weight_scale = weights[
                    f"{expert_id}.w2.weight_scale"] if f"{expert_id}.w2.weight_scale" in weights else None
            elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
                w1_weight_scale = weights[f"gate_up_proj_weight_scale"]
                w3_weight_scale = weights[f"gate_up_proj_weight_scale"]
                w2_weight_scale = weights[f"down_proj_weight_scale"]
                w1_weight_scale = weights[
                    f"gate_up_proj_weight_scale"] if f"gate_up_proj_weight_scale" in weights else None
                w3_weight_scale = weights[
                    f"gate_up_proj_weight_scale"] if f"gate_up_proj_weight_scale" in weights else None
                w2_weight_scale = weights[
                    f"down_proj_weight_scale"] if f"down_proj_weight_scale" in weights else None
            else:
                raise NotImplementedError(
                    f"Unknown weight loading mode in MoE: {module.weight_loading_mode}"

@@ -690,24 +743,45 @@ class FP8QDQFusedMoEMethod(FusedMoEMethodBase):

            expert_idx = local_slot_id

            self.load_expert_w3_w1_weight_scale_fp8_qdq(
                w1_weight_scale, w3_weight_scale,
                tmp_w3_w1_weight_scale[expert_idx])
            if w1_weight_scale is not None or w3_weight_scale is not None:
                self.load_expert_w3_w1_weight_scale_fp8_qdq(
                    w1_weight_scale, w3_weight_scale,
                    tmp_w3_w1_weight_scale[expert_idx])

            if w2_weight_scale is not None:
                self.load_expert_w2_weight_scale_fp8(
                    w2_weight_scale, tmp_w2_weight_scale[expert_idx])

    def process_weights_after_loading(self, module: torch.nn.Module):
        # max_fc31_input_scale is the maximum of all w1 input scales and w3 input scales.
        # It's used to quantize fc31 input inside the MOE op
        max_fc31_input_scale = module.tmp_fc31_input_scale.max()
        # max_fc2_input_scale is the maximum of all w2 input scales.
        max_fc2_input_scale = module.tmp_fc2_input_scale.max()
        # Requantize w3_w1_weight
        for local_slot_id, _ in enumerate(module.initial_local_expert_ids):
            expert_idx = local_slot_id
            requantize_expert_w3_w1_weight_fp8_qdq(
                module, w1_weight_scale, w3_weight_scale,
                module, module.tmp_w3_w1_weight_scale[expert_idx][0],
                module.tmp_w3_w1_weight_scale[expert_idx][1],
                module.w3_w1_weight.data[expert_idx])

            self.load_expert_w2_weight_scale_fp8(
                w2_weight_scale, tmp_w2_weight_scale[expert_idx])

        # Step3: calculate and store final loaded weights
        module.fc31_dequant.data.copy_(tmp_w3_w1_weight_scale *
        # Calculate and store final loaded weights
        max_w3_w1_weight_scale = module.tmp_w3_w1_weight_scale.max(dim=1).values
        module.fc31_dequant.data.copy_(max_w3_w1_weight_scale *
                                       max_fc31_input_scale)
        module.fc2_quant.data.copy_(max_fc2_input_scale.reciprocal())
        module.fc2_dequant.data.copy_(tmp_w2_weight_scale * max_fc2_input_scale)
        module.fc2_dequant.data.copy_(module.tmp_w2_weight_scale *
                                      max_fc2_input_scale)
        module.fc31_input_dequant.data.copy_(max_fc31_input_scale)

        self.setup_quant_scales(module)

        delattr(module, 'tmp_w3_w1_weight_scale')
        delattr(module, 'tmp_w2_weight_scale')
        delattr(module, 'tmp_fc31_input_scale')
        delattr(module, 'tmp_fc2_input_scale')

    def post_load_weights(self, module):
        super().post_load_weights(module)

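In the FP8 QDQ path, the `process_weights_after_loading` shown above reduces the temporary per-expert scale tables into the products the fused MoE op consumes: `fc31_dequant` is the per-expert max of the w1/w3 weight scales times the global max input scale, `fc2_quant` is the reciprocal of the max fc2 input scale, and `fc2_dequant` is the per-expert w2 weight scale times that same max. A hedged sketch of that reduction on toy tensors (shapes follow the scratch buffers created in `load_quant_scales`; all values are made up):

```python
import torch

# Toy stand-ins for the scratch buffers built above: a (num_experts, 2) table
# for the w1/w3 weight scales, a per-expert vector for the w2 weight scales,
# and per-expert activation scales.
num_experts = 4
tmp_w3_w1_weight_scale = torch.rand(num_experts, 2) * 0.05  # [:, 0]=w1, [:, 1]=w3
tmp_w2_weight_scale = torch.rand(num_experts) * 0.05
tmp_fc31_input_scale = torch.rand(num_experts, 2) * 0.2
tmp_fc2_input_scale = torch.rand(num_experts) * 0.2

# Reductions mirroring process_weights_after_loading:
max_fc31_input_scale = tmp_fc31_input_scale.max()
max_fc2_input_scale = tmp_fc2_input_scale.max()
max_w3_w1_weight_scale = tmp_w3_w1_weight_scale.max(dim=1).values  # per expert

fc31_dequant = max_w3_w1_weight_scale * max_fc31_input_scale  # (num_experts,)
fc2_quant = max_fc2_input_scale.reciprocal()                  # scalar
fc2_dequant = tmp_w2_weight_scale * max_fc2_input_scale       # (num_experts,)
fc31_input_dequant = max_fc31_input_scale

print(fc31_dequant.shape, fc2_quant.item(), fc2_dequant.shape)
```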
@@ -733,11 +807,15 @@ class FP8QDQFusedMoEMethod(FusedMoEMethodBase):
            module.w2_weight, cutlass_fp8_row_alignment,
            cutlass_fp8_row_alignment)
        if is_padded_w3_w1_weight:
            module.w3_w1_weight = nn.Parameter(padded_w3_w1_weight,
                                               requires_grad=False)
            replace_parameter_and_save_metadata(
                module, "w3_w1_weight",
                nn.Parameter(padded_w3_w1_weight, requires_grad=False),
                module.rebuild_tensor_metadata)
        if is_padded_w2_weight:
            module.w2_weight = nn.Parameter(padded_w2_weight,
                                            requires_grad=False)
            replace_parameter_and_save_metadata(
                module, "w2_weight",
                nn.Parameter(padded_w2_weight, requires_grad=False),
                module.rebuild_tensor_metadata)


class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):

@@ -778,9 +856,13 @@ class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):

        self.setup_quant_scales(module)

    def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                     weight_loading_mode: MoEWeightLoadingMode):
        super().load_weights(module, weights, weight_loading_mode)
    def load_weights(self,
                     module: torch.nn.Module,
                     weights: List[Dict],
                     weight_loading_mode: MoEWeightLoadingMode,
                     allow_partial_loading: bool = False):
        super().load_weights(module, weights, weight_loading_mode,
                             allow_partial_loading)

    def setup_quant_scales(self, module: torch.nn.Module):
        module.quant_scales = FusedMoEQuantScalesDeepSeekFP8BlockScales(

@@ -795,42 +877,53 @@ class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
        for local_slot_id, expert_id in enumerate(load_expert_ids):
            if module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
                w3_scale = weights['gate_up_proj_weight_scale'][
                    expert_id].transpose(0, 1).contiguous()
                w1_scale = None
                    expert_id].transpose(0, 1).contiguous(
                    ) if "gate_up_proj_weight_scale" in weights else None
                w2_scale = weights['down_proj_weight_scale'][
                    expert_id].transpose(0, 1).contiguous()
                    expert_id].transpose(0, 1).contiguous(
                    ) if "down_proj_weight_scale" in weights else None
                w3_w1_scale_shard = load_weight_shard(w3_scale,
                                                      module.tp_size,
                                                      module.tp_rank,
                                                      TensorParallelMode.COLUMN,
                                                      device=device)
                dst_w3_w1_weight_scale[local_slot_id].copy_(w3_w1_scale_shard)
            elif module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
                w3_scale = weights[f"{expert_id}.w3.weight_scale_inv"]
                w1_scale = weights[f"{expert_id}.w1.weight_scale_inv"]
                w2_scale = weights[f"{expert_id}.w2.weight_scale_inv"]
                w3_scale = weights[
                    f"{expert_id}.w3.weight_scale_inv"] if f"{expert_id}.w3.weight_scale_inv" in weights else None
                w1_scale = weights[
                    f"{expert_id}.w1.weight_scale_inv"] if f"{expert_id}.w1.weight_scale_inv" in weights else None
                w2_scale = weights[
                    f"{expert_id}.w2.weight_scale_inv"] if f"{expert_id}.w2.weight_scale_inv" in weights else None
                dst_w3_weight_scale, dst_w1_weight_scale = dst_w3_w1_weight_scale[
                    local_slot_id].chunk(2, dim=0)
                if w1_scale is not None:
                    w1_scale_shard = load_weight_shard(
                        w1_scale,
                        module.tp_size,
                        module.tp_rank,
                        TensorParallelMode.COLUMN,
                        device=device)
                    dst_w1_weight_scale.copy_(w1_scale_shard)
                if w3_scale is not None:
                    w3_scale_shard = load_weight_shard(
                        w3_scale,
                        module.tp_size,
                        module.tp_rank,
                        TensorParallelMode.COLUMN,
                        device=device)
                    dst_w3_weight_scale.copy_(w3_scale_shard)
            else:
                raise NotImplementedError(
                    f"Unknown weight loading mode in MoE: {module.weight_loading_mode}"
                )

            w3_w1_scale_shard = load_weight_shard(w3_scale,
                                                  module.tp_size,
                                                  module.tp_rank,
                                                  TensorParallelMode.COLUMN,
                                                  device=device)

            if w1_scale is not None:
                w1_scale_shard = load_weight_shard(w1_scale,
            if w2_scale is not None:
                w2_scale_shard = load_weight_shard(w2_scale,
                                                   module.tp_size,
                                                   module.tp_rank,
                                                   TensorParallelMode.COLUMN,
                                                   TensorParallelMode.ROW,
                                                   device=device)
                w3_w1_scale_shard = torch.cat(
                    [w3_w1_scale_shard, w1_scale_shard], dim=-2)

            dst_w3_w1_weight_scale[local_slot_id].copy_(w3_w1_scale_shard)

            w2_scale_shard = load_weight_shard(w2_scale,
                                               module.tp_size,
                                               module.tp_rank,
                                               TensorParallelMode.ROW,
                                               device=device)
            dst_w2_weight_scale[local_slot_id].copy_(w2_scale_shard)
                dst_w2_weight_scale[local_slot_id].copy_(w2_scale_shard)

    def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
        self.load_expert_all_weight_scale_fp8_block_scale(

@@ -843,16 +936,30 @@ class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
        if self.need_load_shared_weights(module):
            local_shared_load_expert_ids = module.layer_load_balancer.get_load_expert_ids(
            )
            local_shared_w3_w1_scale_tensors = torch.empty(
                (len(local_shared_load_expert_ids), ) +
                module.w3_w1_weight_scaling_factor.data.shape[1:],
                dtype=module.w3_w1_weight_scaling_factor.data.dtype,
                device='cpu')
            local_shared_w2_scale_tensors = torch.empty(
                (len(local_shared_load_expert_ids), ) +
                module.w2_weight_scaling_factor.data.shape[1:],
                dtype=module.w2_weight_scaling_factor.data.dtype,
                device='cpu')
            if getattr(module, 'local_shared_w3_w1_scale_tensors',
                       None) is not None:
                local_shared_w3_w1_scale_tensors = getattr(
                    module, 'local_shared_w3_w1_scale_tensors')
            else:
                local_shared_w3_w1_scale_tensors = torch.empty(
                    (len(local_shared_load_expert_ids), ) +
                    module.w3_w1_weight_scaling_factor.data.shape[1:],
                    dtype=module.w3_w1_weight_scaling_factor.data.dtype,
                    device='cpu')
                setattr(module, 'local_shared_w3_w1_scale_tensors',
                        local_shared_w3_w1_scale_tensors)
            if getattr(module, 'local_shared_w2_scale_tensors',
                       None) is not None:
                local_shared_w2_scale_tensors = getattr(
                    module, 'local_shared_w2_scale_tensors')
            else:
                local_shared_w2_scale_tensors = torch.empty(
                    (len(local_shared_load_expert_ids), ) +
                    module.w2_weight_scaling_factor.data.shape[1:],
                    dtype=module.w2_weight_scaling_factor.data.dtype,
                    device='cpu')
                setattr(module, 'local_shared_w2_scale_tensors',
                        local_shared_w2_scale_tensors)
            self.load_expert_all_weight_scale_fp8_block_scale(
                module,
                weights,

@@ -860,19 +967,32 @@ class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
                local_shared_w3_w1_scale_tensors,
                local_shared_w2_scale_tensors,
                device=torch.device("cpu"))
            module.register_all_parameter_slot_and_to_fix_weight_fns({
                'w3_w1_weight_scaling_factor':
                local_shared_w3_w1_scale_tensors,
                'w2_weight_scaling_factor':
                local_shared_w2_scale_tensors,
            })

    def post_load_weights(self, module: torch.nn.Module):
        if self.need_load_shared_weights(module):
            weight_fns = {}
            if hasattr(module, 'local_shared_w3_w1_scale_tensors'):
                weight_fns['w3_w1_weight_scaling_factor'] = getattr(
                    module, 'local_shared_w3_w1_scale_tensors')
                delattr(module, 'local_shared_w3_w1_scale_tensors')
            if hasattr(module, 'local_shared_w2_scale_tensors'):
                weight_fns['w2_weight_scaling_factor'] = getattr(
                    module, 'local_shared_w2_scale_tensors')
                delattr(module, 'local_shared_w2_scale_tensors')
            if weight_fns:
                module.register_all_parameter_slot_and_to_fix_weight_fns(
                    weight_fns)
        super().post_load_weights(module)


class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
        DeepSeekFP8BlockScalesFusedMoEMethod):

    def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                     weight_loading_mode: MoEWeightLoadingMode):
    def load_weights(self,
                     module: torch.nn.Module,
                     weights: List[Dict],
                     weight_loading_mode: MoEWeightLoadingMode,
                     allow_partial_loading: bool = False):
        if is_sm_100f():
            expert_ids = set(module.initial_local_expert_ids)
            if self.need_load_shared_weights(module):

@@ -888,7 +1008,8 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
                    scale = weights[name][:]
                    weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
                        weight, scale)
        super().load_weights(module, weights, weight_loading_mode)
        super().load_weights(module, weights, weight_loading_mode,
                             allow_partial_loading)

    def post_load_weights(self, module: torch.nn.Module):
        super().post_load_weights(module)

@@ -900,8 +1021,12 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
            recipe=(1, 128, 128),
            num_groups=module.w3_w1_weight.shape[0],
            is_sfa=False)
        module.w3_w1_weight_scaling_factor = nn.Parameter(
        transformed_w3_w1_weight_scaling_factor = nn.Parameter(
            transfromed_w3_w1_scale, requires_grad=False)
        replace_parameter_and_save_metadata(
            module, "w3_w1_weight_scaling_factor",
            transformed_w3_w1_weight_scaling_factor,
            module.rebuild_tensor_metadata)
        transfromed_w2_scale = transform_sf_into_required_layout(
            module.quant_scales[1],
            mn=module.w2_weight.shape[1],

@@ -909,8 +1034,12 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
            recipe=(1, 128, 128),
            num_groups=module.w3_w1_weight.shape[0],
            is_sfa=False)
        module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale,
                                                       requires_grad=False)
        transformed_w2_weight_scaling_factor = nn.Parameter(
            transfromed_w2_scale, requires_grad=False)
        replace_parameter_and_save_metadata(
            module, "w2_weight_scaling_factor",
            transformed_w2_weight_scaling_factor,
            module.rebuild_tensor_metadata)
        self.setup_quant_scales(module)


@@ -28,7 +28,8 @@ from tensorrt_llm.quantization.utils.fp8_utils import (

from ..._utils import get_sm_version, is_sm_100f
from ...models.modeling_utils import QuantConfig
from ..utils import Fp4QuantizedTensor, get_model_extra_attrs, unswizzle_sf
from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
                     replace_parameter_and_save_metadata, unswizzle_sf)


class WeightMode(str, enum.Enum):

@@ -39,6 +40,40 @@ class WeightMode(str, enum.Enum):
    # weight of a fused gate and up linear layer
    FUSED_GATE_UP_LINEAR = 'fused_gate_up_linear'

    @property
    def int_value(self) -> int:
        _INT_MAP = {
            WeightMode.VANILLA: 1,
            WeightMode.FUSED_GATE_UP_LINEAR: 2,
            WeightMode.FUSED_QKV_LINEAR: 3,
        }
        return _INT_MAP[self]

    @property
    def shard_keys(self) -> list[str] | None:
        _SHARD_KEYS_MAP = {
            WeightMode.VANILLA: None,
            WeightMode.FUSED_GATE_UP_LINEAR: ['gate', 'up'],
            WeightMode.FUSED_QKV_LINEAR: ['q', 'k', 'v'],
        }
        return _SHARD_KEYS_MAP[self]

    @property
    def shard_key_to_index(self) -> dict[str, int] | None:
        _SHARD_KEY_TO_INDEX_MAP = {
            WeightMode.VANILLA: None,
            WeightMode.FUSED_GATE_UP_LINEAR: {
                'gate': 0,
                'up': 1
            },
            WeightMode.FUSED_QKV_LINEAR: {
                'q': 0,
                'k': 1,
                'v': 2
            },
        }
        return _SHARD_KEY_TO_INDEX_MAP[self]


@dataclass(kw_only=True)
class WeightsLoadingConfig:
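The `WeightMode` properties added above give each fused layout a shard count (`int_value`), an ordered shard-key list, and a key-to-slot mapping, which the FP8 linear methods below use to place partially loaded scales into fixed positions of small scratch tensors. A hedged usage sketch (a trimmed, standalone re-declaration of the enum for illustration only, not the library class):

```python
import enum


class WeightMode(str, enum.Enum):
    # Trimmed re-declaration for illustration; mirrors the properties added above.
    VANILLA = 'vanilla'
    FUSED_QKV_LINEAR = 'fused_qkv_linear'
    FUSED_GATE_UP_LINEAR = 'fused_gate_up_linear'

    @property
    def shard_keys(self):
        return {
            WeightMode.VANILLA: None,
            WeightMode.FUSED_GATE_UP_LINEAR: ['gate', 'up'],
            WeightMode.FUSED_QKV_LINEAR: ['q', 'k', 'v'],
        }[self]

    @property
    def shard_key_to_index(self):
        keys = self.shard_keys
        return None if keys is None else {k: i for i, k in enumerate(keys)}


mode = WeightMode.FUSED_QKV_LINEAR
# Each partially loaded scale lands in a fixed slot of a scratch tensor,
# e.g. tmp_weight_scales[mode.shard_key_to_index['k']] for the K projection.
print(mode.shard_keys)          # ['q', 'k', 'v']
print(mode.shard_key_to_index)  # {'q': 0, 'k': 1, 'v': 2}
```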
@@ -327,6 +362,9 @@ class LinearMethodBase(ABC):
        else:
            raise ValueError(f'unsupported weight mode: {weight_mode}')

        if not allow_partial_loading:
            self.process_weights_after_loading(module)

    def post_load_weights(self, module: Linear):
        pass

@@ -367,6 +405,36 @@ class LinearMethodBase(ABC):
        """
        raise NotImplementedError

    def process_weights_after_loading(self, module: Linear):
        """
        Process quantization weights and scales after loading weights.
        """
        weight_mode = module.weights_loading_config.weight_mode
        if weight_mode == WeightMode.VANILLA:
            self.process_weights_after_loading_vanilla(module)
        elif weight_mode == WeightMode.FUSED_QKV_LINEAR:
            self.process_weights_after_loading_fused_qkv_linear(module)
        elif weight_mode == WeightMode.FUSED_GATE_UP_LINEAR:
            self.process_weights_after_loading_fused_gate_up_linear(module)
        else:
            raise ValueError(f'unsupported weight mode: {weight_mode}')

    def process_weights_after_loading_vanilla(self, module: Linear):
        """
        Process quantization weights and scales after loading weights for vanilla linear layer.
        """

    def process_weights_after_loading_fused_qkv_linear(self, module: Linear):
        """
        Process quantization weights and scales after loading weights for fused QKV linear layer.
        """

    def process_weights_after_loading_fused_gate_up_linear(
            self, module: Linear):
        """
        Process quantization weights and scales after loading weights for fused gate up linear layer.
        """


class UnquantizedLinearMethod(LinearMethodBase):

@@ -382,6 +450,8 @@ class UnquantizedLinearMethod(LinearMethodBase):
        else:
            module.register_parameter("bias", None)

        module.rebuild_tensor_metadata = {}

    def apply(self, module: Linear, input: torch.Tensor,
              bias: Optional[torch.Tensor]):
        if module.use_custom_cublas_mm:

@@ -440,8 +510,17 @@ class UnquantizedLinearMethod(LinearMethodBase):
                copy_weight_shard(module.weight, weight, shard_offset,
                                  shard_size)

    def pre_reload_weights(self, module: Linear):
        for param_name, metadata in module.rebuild_tensor_metadata.items():
            logger.warning(
                f"Pre-reloading weight '{param_name}' requires tensor re-creation, which will invalidate existing CUDA graphs."
            )
            param = Parameter(torch.empty_like(metadata, device="cuda"),
                              requires_grad=False)
            module.register_parameter(param_name, param)


class FP8QDQLinearMethod(LinearMethodBase):

class FP8QDQLinearMethod(UnquantizedLinearMethod):

    def create_weights(self, module: Linear, in_features: int,
                       out_features: int, bias: bool, dtype: torch.dtype):

@@ -468,6 +547,8 @@ class FP8QDQLinearMethod(LinearMethodBase):
        else:
            module.register_parameter("bias", None)

        module.rebuild_tensor_metadata = {}

    def apply(self, module: Linear, input: torch.Tensor,
              bias: Optional[torch.Tensor]):
        cur_input_scale = module.input_scale

@@ -518,93 +599,224 @@ class FP8QDQLinearMethod(LinearMethodBase):
                v_scale.append(w["v_scale"][...].reshape([]))
        return k_scale, v_scale

    def load_weight_scales(self, weights: List[Dict]):
        input_scale, weight_scale = [], []
        for w in weights:
            if "input_scale" in w:
                input_scale.append(w["input_scale"][...].reshape([]))
            if "weight_scale" in w:
                weight_scale.append(w["weight_scale"][...].reshape([]))
        return input_scale, weight_scale

    def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
        load_weights_vanilla_helper(module, weights)
        input_scale, weight_scale = self.load_weight_scales(weights)
        if len(input_scale) != 0:
            # Static quantization
            copy_weight(module.input_scale, input_scale[0])
            module.inv_input_scale.data = 1.0 / module.input_scale
    def load_weight_scales(self,
                           weights: List[Dict],
                           shard_keys: list[str] = None):
        input_scales, weight_scales = {}, {}
        if shard_keys is None:
            for w in weights:
                if "input_scale" in w:
                    input_scales[None] = w["input_scale"][...].reshape([])
                if "weight_scale" in w:
                    weight_scales[None] = w["weight_scale"][...].reshape([])
        else:
            # Dynamic quantization
            for shard_key, w in zip(shard_keys, weights):
                if "input_scale" in w:
                    input_scales[shard_key] = w["input_scale"][...].reshape([])
                if "weight_scale" in w:
                    weight_scales[shard_key] = w["weight_scale"][...].reshape(
                        [])
        return input_scales, weight_scales

    def load_weights_vanilla(self,
                             module: Linear,
                             weights: List[Dict],
                             allow_partial_loading: bool = False) -> None:
        super().load_weights_vanilla(
            module, weights, allow_partial_loading=allow_partial_loading)
        input_scale, weight_scale = self.load_weight_scales(weights)
        if input_scale:
            copy_weight(module.input_scale, input_scale[None])
            module.inv_input_scale.data = 1.0 / module.input_scale
            setattr(module, "has_static_input_scale", True)
        if weight_scale:
            copy_weight(module.weight_scale, weight_scale[None])

    def process_weights_after_loading_vanilla(self, module: Linear):
        if not getattr(module, "has_static_input_scale", False):
            module.input_scale = None
            module.inv_input_scale = None
        copy_weight(module.weight_scale, weight_scale[0])
        delattr(module, "has_static_input_scale")

    def load_weights_fused_qkv_linear(self, module: Linear,
                                      weights: List[Dict]) -> None:
        q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
            module, weights)
    def load_weights_fused_qkv_linear(
            self,
            module: Linear,
            weights: List[Dict],
            allow_partial_loading: bool = False) -> None:
        """
        Load weights for fused QKV linear layer.

        input_scale, weight_scale = self.load_weight_scales(weights)
        if len(input_scale) != 0:
            # Static quantization
            copy_weight(module.input_scale, max(input_scale))
        else:
            # Dynamic quantization
            module.input_scale = None
        In partial loading mode, only loads weights and scales to their designated positions.
        The actual rescaling is deferred to process_weights_after_loading_fused_qkv_linear.
        """
        # Parent class handles weight loading
        super().load_weights_fused_qkv_linear(
            module, weights, allow_partial_loading=allow_partial_loading)
        weight_mode = module.weights_loading_config.weight_mode
        if not hasattr(module, "tmp_input_scales"):
            module.tmp_input_scales = torch.empty(
                weight_mode.int_value,
                dtype=torch.float32,
                device=module.input_scale.device)
        if not hasattr(module, "tmp_weight_scales"):
            module.tmp_weight_scales = torch.empty(
                weight_mode.int_value,
                dtype=torch.float32,
                device=module.weight_scale.device)
        # Load input_scale and weight_scale to tmp_qkv_input_scales and tmp_qkv_weight_scales
        # q -> index 0, k -> index 1, v -> index 2
        input_scales, weight_scales = self.load_weight_scales(
            weights, shard_keys=weight_mode.shard_keys)
        shard_key_to_index = weight_mode.shard_key_to_index

        copy_weight(module.weight_scale, max(weight_scale))
        for shard_key, scale in input_scales.items():
            idx = shard_key_to_index[shard_key]
            module.tmp_input_scales[idx] = scale
            setattr(module, "has_static_input_scale", True)

        # use in-place multiplication and division to avoid extra memory allocation
        q_weight = q_weight.to(module.dtype).mul_(weight_scale[0])
        k_weight = k_weight.to(module.dtype).mul_(weight_scale[1])
        v_weight = v_weight.to(module.dtype).mul_(weight_scale[2])

        fused_weight = torch.cat((q_weight, k_weight, v_weight))
        fused_weight = fused_weight.div_(
            module.weight_scale.to(fused_weight.device)).to(torch.float8_e4m3fn)
        copy_weight(module.weight, fused_weight)
        for shard_key, scale in weight_scales.items():
            idx = shard_key_to_index[shard_key]
            module.tmp_weight_scales[idx] = scale

        # Load k and v scales, used for NVFP4 KV cache
        # Store them temporarily for post-processing
        k_scale, v_scale = self.load_kv_scales(weights)
        # NOTE: Currently the calibrated kv scales may cause overflow for certain input, disabling by default.
        if k_scale:
            if getattr(module, "tmp_k_scales", None) is None:
                module.tmp_k_scales = []
            module.tmp_k_scales.extend(k_scale)
        if v_scale:
            if getattr(module, "tmp_v_scales", None) is None:
                module.tmp_v_scales = []
            module.tmp_v_scales.extend(v_scale)

    def rescale_fused_weights(self, module: Linear):
        """
        Helper function to rescale fused weights.

        This method:
        1. Computes the max input_scale of all shards(qkv or gate/up) and update input_scale parameter to the max value
        2. Computes the max weight_scale across all shards(qkv or gate/up)
        3. Rescales each weight shard: weight * (original_scale / max_scale)
        4. Updates weight_scale parameter to the unified max value
        """
        weight_mode = module.weights_loading_config.weight_mode
        shard_key_to_index = weight_mode.shard_key_to_index

        # Handle input_scale
        if getattr(module, "has_static_input_scale", False):
            # Compute max and replace input_scale with a new parameter
            max_input_scale = module.tmp_input_scales.max()
            module.input_scale.data.copy_(max_input_scale)
            module.inv_input_scale.data = 1.0 / module.input_scale
            delattr(module, "has_static_input_scale")
        else:
            module.input_scale = None
            module.inv_input_scale = None

        # Compute max weight_scale
        max_weight_scale = module.tmp_weight_scales.max()
        module.weight_scale.data.copy_(max_weight_scale)

        # Rescale each weight shard: weight * (original_scale / max_scale)
        for shard_key in weight_mode.shard_keys:
            idx = shard_key_to_index[shard_key]
            original_scale = module.tmp_weight_scales[idx]

            # Get shard position from mapping
            shard_offset, shard_size = module.fused_weight_shard_indices_mapping[
                shard_key]

            # Rescale: FP8 -> BF16 -> multiply by ratio -> FP8
            weight_shard = module.weight.data[shard_offset:shard_offset +
                                              shard_size]
            rescaled_weight = (weight_shard.to(module.dtype).mul_(
                original_scale / max_weight_scale).to(torch.float8_e4m3fn))
            module.weight.data[shard_offset:shard_offset +
                               shard_size] = rescaled_weight

        delattr(module, "tmp_input_scales")
        delattr(module, "tmp_weight_scales")

    def process_weights_after_loading_fused_qkv_linear(self, module: Linear):
        """
        Post-process weights after all partial loads are complete.
        """
        self.rescale_fused_weights(module)

        # Handle kv_scales for NVFP4 KV cache
        if os.environ.get("TRTLLM_LOAD_KV_SCALES", "0") == "1":
            if len(k_scale) != 0:
                assert len(v_scale) != 0
            k_scales = getattr(module, "tmp_k_scales", [])
            v_scales = getattr(module, "tmp_v_scales", [])
            if k_scales:
                assert v_scales, "k_scale and v_scale must be loaded together"
                # The calibrated KV scales are amax / (6 * 448), but the requested KV scales are amax / 448,
                # to avoid overflow when dequantizing NVFP4 in attention kernels.
                copy_weight(
                    module.kv_scales,
                    torch.tensor(
                        [1.0, max(k_scale) * 6.0,
                         max(v_scale) * 6.0],
                        dtype=torch.float32))
                    torch.tensor([
                        1.0,
                        max(k_scales).item() * 6.0,
                        max(v_scales).item() * 6.0
                    ],
                                 dtype=torch.float32))
                module.inv_kv_scales.data = 1.0 / module.kv_scales

    def load_weights_fused_gate_up_linear(self, module: Linear,
                                          weights: List[Dict]) -> None:
        input_scale, weight_scale = self.load_weight_scales(weights)
        if len(input_scale) != 0:
            # Static quantization
            copy_weight(module.input_scale, max(input_scale))
        else:
            # Dynamic quantization
            module.input_scale = None
        copy_weight(module.weight_scale, max(weight_scale))
        # Clean up temporary attributes
        if hasattr(module, "tmp_k_scales"):
            delattr(module, "tmp_k_scales")
        if hasattr(module, "tmp_v_scales"):
            delattr(module, "tmp_v_scales")

        gate_weight, up_weight = load_weights_fused_gate_up_helper(
            module, weights)
    def load_weights_fused_gate_up_linear(
            self,
            module: Linear,
            weights: List[Dict],
            allow_partial_loading: bool = False) -> None:
        """
        Load weights for fused gate/up linear layer.

        # use in-place multiplication and division to avoid extra memory allocation
        gate_weight = gate_weight.to(module.dtype).mul_(weight_scale[0])
        up_weight = up_weight.to(module.dtype).mul_(weight_scale[1])
        fused_weight = torch.cat((gate_weight, up_weight))
        fused_weight = fused_weight.div_(
            module.weight_scale.to(fused_weight.device)).to(torch.float8_e4m3fn)
        copy_weight(module.weight, fused_weight)
        In partial loading mode, only loads weights and scales to their designated positions.
        The actual rescaling is deferred to process_weights_after_loading_fused_gate_up_linear.
        """
        # Parent class handles weight loading
        super().load_weights_fused_gate_up_linear(
            module, weights, allow_partial_loading=allow_partial_loading)
        weight_mode = module.weights_loading_config.weight_mode
        if not hasattr(module, "tmp_input_scales"):
            module.tmp_input_scales = torch.empty(
                weight_mode.int_value,
                dtype=torch.float32,
                device=module.input_scale.device)
        if not hasattr(module, "tmp_weight_scales"):
            module.tmp_weight_scales = torch.empty(
                weight_mode.int_value,
                dtype=torch.float32,
                device=module.weight_scale.device)
        # Load input_scale and weight_scale to their designated positions
        # gate -> index 0, up -> index 1
        input_scales, weight_scales = self.load_weight_scales(
            weights, shard_keys=weight_mode.shard_keys)
        shard_key_to_index = weight_mode.shard_key_to_index

        for shard_key, scale in input_scales.items():
            idx = shard_key_to_index[shard_key]
            module.tmp_input_scales[idx] = scale
            setattr(module, "has_static_input_scale", True)

        for shard_key, scale in weight_scales.items():
            idx = shard_key_to_index[shard_key]
            module.tmp_weight_scales[idx] = scale

    def process_weights_after_loading_fused_gate_up_linear(
            self, module: Linear):
        """
        Post-process weights after all partial loads are complete.
        """
        self.rescale_fused_weights(module)


class FP8RowwiseLinearMethod(LinearMethodBase):
class FP8RowwiseLinearMethod(UnquantizedLinearMethod):

    def create_weights(self, module: Linear, in_features: int,
                       out_features: int, bias: bool, dtype: torch.dtype):
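The `rescale_fused_weights` helper above folds per-shard FP8 scales into one shared scale for a fused weight. A hedged numeric check of step 3 from its docstring, showing that requantizing a shard by original_scale / max_scale preserves its real values under the shared scale (toy values; assumes a PyTorch build that exposes torch.float8_e4m3fn):

```python
import torch

# Hedged sketch (not library code): verify the rescaling identity used by
# rescale_fused_weights: q_new * s_max ~= q_old * s_old for each fused shard.
shard_scales = {"q": 0.04, "k": 0.01, "v": 0.02}
s_max = max(shard_scales.values())

for key, s_old in shard_scales.items():
    q_old = torch.randn(4).clamp(-1, 1).to(torch.float8_e4m3fn)
    real_before = q_old.float() * s_old
    # Step 3 of the docstring: weight * (original_scale / max_scale), then
    # re-cast to FP8 so a single max_scale dequantizes every shard.
    q_new = (q_old.float() * (s_old / s_max)).to(torch.float8_e4m3fn)
    real_after = q_new.float() * s_max
    # Difference is only the FP8 rounding of the rescaled codes.
    print(key, (real_before - real_after).abs().max().item())
```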
@ -628,6 +840,8 @@ class FP8RowwiseLinearMethod(LinearMethodBase):
|
||||
else:
|
||||
module.register_parameter("bias", None)
|
||||
|
||||
module.rebuild_tensor_metadata = {}
|
||||
|
||||
def apply(self, module: Linear, input: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]):
|
||||
# FP8 tensor inputs are from attention. Directly use ones as scale.
|
||||
@ -661,51 +875,72 @@ class FP8RowwiseLinearMethod(LinearMethodBase):
|
||||
scale_name = "weight_scale"
|
||||
return scale_name
|
||||
|
||||
def load_weights_vanilla(self, module: Linear, weights: List[Dict]):
|
||||
load_weights_vanilla_helper(module, weights)
|
||||
|
||||
def load_weights_vanilla(self,
|
||||
module: Linear,
|
||||
weights: List[Dict],
|
||||
allow_partial_loading: bool = False):
|
||||
super().load_weights_vanilla(
|
||||
module, weights, allow_partial_loading=allow_partial_loading)
|
||||
scale_name = self._get_scale_name(weights)
|
||||
weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
|
||||
module.tp_rank, module.tp_mode)
|
||||
copy_weight(module.weight_scale, weight_scale)
|
||||
if scale_name in weights[0]:
|
||||
weight_scale = load_weight_shard(weights[0][scale_name],
|
||||
module.tp_size, module.tp_rank,
|
||||
module.tp_mode)
|
||||
copy_weight(module.weight_scale, weight_scale)
|
||||
if "input_scale" in weights[0]:
|
||||
copy_weight(module.input_scale, weights[0]["input_scale"])
|
||||
module.inv_input_scale.data = 1.0 / module.input_scale
|
||||
|
||||
def load_weights_fused_qkv_linear(self, module: Linear,
|
||||
weights: List[Dict]):
|
||||
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
|
||||
module, weights)
|
||||
fused_weight = torch.cat((q_weight, k_weight, v_weight))
|
||||
copy_weight(module.weight, fused_weight)
|
||||
|
||||
def load_weights_fused_qkv_linear(self,
|
||||
module: Linear,
|
||||
weights: List[Dict],
|
||||
allow_partial_loading: bool = False):
|
||||
super().load_weights_fused_qkv_linear(
|
||||
module, weights, allow_partial_loading=allow_partial_loading)
|
||||
scale_name = self._get_scale_name(weights)
|
||||
q_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
|
||||
module.tp_rank, module.tp_mode)
|
||||
k_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
|
||||
module.tp_rank, module.tp_mode)
|
||||
v_scale = load_weight_shard(weights[2][scale_name], module.tp_size,
|
||||
module.tp_rank, module.tp_mode)
|
||||
fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))
|
||||
copy_weight(module.weight_scale, fused_fp8_block_scale)
|
||||
|
||||
def load_weights_fused_gate_up_linear(self, module: Linear,
|
||||
weights: List[Dict]):
|
||||
gate_weight, up_weight = load_weights_fused_gate_up_helper(
|
||||
module, weights)
|
||||
fused_weight = torch.cat((gate_weight, up_weight))
|
||||
copy_weight(module.weight, fused_weight)
|
||||
q_scale = load_weight_shard(
|
||||
weights[0][scale_name], module.tp_size, module.tp_rank,
|
||||
module.tp_mode) if scale_name in weights[0] else None
|
||||
k_scale = load_weight_shard(
|
||||
weights[1][scale_name], module.tp_size, module.tp_rank,
|
||||
module.tp_mode) if scale_name in weights[1] else None
|
||||
v_scale = load_weight_shard(
|
||||
weights[2][scale_name], module.tp_size, module.tp_rank,
|
||||
module.tp_mode) if scale_name in weights[2] else None
|
||||
for shard_key, scale in zip(
|
||||
module.fused_weight_shard_indices_mapping.keys(),
|
||||
[q_scale, k_scale, v_scale]):
|
||||
if scale is not None:
|
||||
shard_offset, shard_size = module.fused_weight_shard_indices_mapping[
|
||||
shard_key]
|
||||
copy_weight_shard(module.weight_scale, scale, shard_offset,
|
||||
shard_size)
|
||||
|
||||
def load_weights_fused_gate_up_linear(
|
||||
self,
|
||||
module: Linear,
|
||||
weights: List[Dict],
|
||||
allow_partial_loading: bool = False) -> None:
|
||||
super().load_weights_fused_gate_up_linear(
|
||||
module, weights, allow_partial_loading=allow_partial_loading)
|
||||
scale_name = self._get_scale_name(weights)
|
||||
left_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
|
||||
module.tp_rank, module.tp_mode)
|
||||
right_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
|
||||
module.tp_rank, module.tp_mode)
|
||||
fused_scale = torch.cat((left_scale, right_scale))
|
||||
copy_weight(module.weight_scale, fused_scale)
|
||||
gate_scale = load_weight_shard(
|
||||
weights[0][scale_name], module.tp_size, module.tp_rank,
|
||||
module.tp_mode) if scale_name in weights[0] else None
|
||||
up_scale = load_weight_shard(
|
||||
weights[1][scale_name], module.tp_size, module.tp_rank,
|
||||
module.tp_mode) if scale_name in weights[1] else None
|
||||
for shard_key, scale in zip(
|
||||
module.fused_weight_shard_indices_mapping.keys(),
|
||||
[gate_scale, up_scale]):
|
||||
if scale is not None:
|
||||
shard_offset, shard_size = module.fused_weight_shard_indices_mapping[
|
||||
shard_key]
|
||||
copy_weight_shard(module.weight_scale, scale, shard_offset,
|
||||
shard_size)
|
||||
|
||||
|
||||
class FP8BlockScalesLinearMethod(LinearMethodBase):
|
||||
class FP8BlockScalesLinearMethod(UnquantizedLinearMethod):
|
||||
|
||||
def create_weights(self, module: Linear, in_features: int,
|
||||
out_features: int, bias: bool, dtype: torch.dtype):
|
||||
@ -732,6 +967,8 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
|
||||
else:
|
||||
module.register_parameter("bias", None)
|
||||
|
||||
module.rebuild_tensor_metadata = {}
|
||||
|
||||
def apply(self, module: Linear, input: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]):
|
||||
if input.dtype == torch.float8_e4m3fn:
|
||||
@ -770,75 +1007,107 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
|
||||
    def _get_scale_name(self, weights: List[Dict]):
        # `weight_scale_inv` for DS recipe and `weight_scale` for ModelOpt recipe.
        # Actually they hold identical values of data_amax / 448.
        scale_name = "weight_scale_inv"
        if scale_name not in weights[0]:
            scale_name = "weight_scale"
        return scale_name
        for w in weights:
            if "weight_scale_inv" in w:
                return "weight_scale_inv"
        return "weight_scale"

    def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
        load_weights_vanilla_helper(module, weights)
    def load_weights_vanilla(self,
                             module: Linear,
                             weights: List[Dict],
                             allow_partial_loading: bool = False) -> None:
        super().load_weights_vanilla(
            module, weights, allow_partial_loading=allow_partial_loading)

        scale_name = self._get_scale_name(weights)
        full_weight_scale = weights[0][scale_name]
        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
        if full_weight_scale.dim() == 4:
            full_weight_scale = full_weight_scale.squeeze(1).squeeze(-1)
        weight_scale = load_weight_shard(full_weight_scale, module.tp_size,
                                         module.tp_rank, module.tp_mode)
        copy_weight(module.weight_scale, weight_scale)
        if scale_name in weights[0]:
            full_weight_scale = weights[0][scale_name]
            # modelopt fp8_pb_wo can have 2 extra singleton dimensions
            if full_weight_scale.dim() == 4:
                full_weight_scale = full_weight_scale.squeeze(1).squeeze(-1)
            weight_scale = load_weight_shard(full_weight_scale, module.tp_size,
                                             module.tp_rank, module.tp_mode)
            copy_weight(module.weight_scale, weight_scale)
        if "input_scale" in weights[0]:
            copy_weight(module.input_scale, weights[0]["input_scale"])
            module.inv_input_scale.data = 1.0 / module.input_scale

    def load_weights_fused_qkv_linear(self, module: Linear,
                                      weights: List[Dict]) -> None:
        q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
            module, weights)
        fused_weight = torch.cat((q_weight, k_weight, v_weight))
    def remap_fused_shard_indices_by_divisible_factor(self, mapping: Dict,
                                                      divisible_factor: int):
        """
        Remap fused weight shard indices to scale coordinates by dividing by divisible_factor.

        Args:
            mapping: Dict of {shard_key: (offset, size)} in weight coordinates
            divisible_factor: Block size (e.g., 128 for block-scale quantization)

        Returns:
            Dict of {shard_key: (scale_offset, scale_size)} in scale coordinates
        """
        result = {}
        for key, (offset, size) in mapping.items():
            scale_offset = math.ceil(offset / divisible_factor)
            scale_size = math.ceil(size / divisible_factor)
            result[key] = (scale_offset, scale_size)
        return result

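For illustration, a minimal, self-contained sketch of the remapping above, assuming a hypothetical fused QKV layout; the shard keys and sizes are made up for this example and are not taken from a real checkpoint.

import math

def remap(mapping, divisible_factor=128):
    # Same arithmetic as remap_fused_shard_indices_by_divisible_factor above:
    # each scale entry covers one divisible_factor-wide block of the weight dimension.
    return {
        key: (math.ceil(offset / divisible_factor),
              math.ceil(size / divisible_factor))
        for key, (offset, size) in mapping.items()
    }

# Hypothetical fused QKV layout: q rows [0, 4096), k rows [4096, 5120), v rows [5120, 6144).
print(remap({"q": (0, 4096), "k": (4096, 1024), "v": (5120, 1024)}))
# -> {'q': (0, 32), 'k': (32, 8), 'v': (40, 8)}
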
    def load_weights_fused_qkv_linear(
            self,
            module: Linear,
            weights: List[Dict],
            allow_partial_loading: bool = False) -> None:
        super().load_weights_fused_qkv_linear(
            module, weights, allow_partial_loading=allow_partial_loading)

        scale_name = self._get_scale_name(weights)
        full_q_scale = weights[0][scale_name]
        full_k_scale = weights[1][scale_name]
        full_v_scale = weights[2][scale_name]
        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
        if full_q_scale.dim() == 4:
            full_q_scale = full_q_scale.squeeze(1).squeeze(-1)
        if full_k_scale.dim() == 4:
            full_k_scale = full_k_scale.squeeze(1).squeeze(-1)
        if full_v_scale.dim() == 4:
            full_v_scale = full_v_scale.squeeze(1).squeeze(-1)
        q_scale = load_weight_shard(full_q_scale, module.tp_size,
                                    module.tp_rank, module.tp_mode)
        k_scale = load_weight_shard(full_k_scale, module.tp_size,
                                    module.tp_rank, module.tp_mode)
        v_scale = load_weight_shard(full_v_scale, module.tp_size,
                                    module.tp_rank, module.tp_mode)
        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))
        full_scales = [
            w[scale_name] if scale_name in w else None for w in weights[:3]
        ]
        full_scales_squeezed = [
            s.squeeze(1).squeeze(-1) if s is not None and s.dim() == 4 else s
            for s in full_scales
        ]

        copy_weight(module.weight, fused_weight)
        copy_weight(module.weight_scale, fused_fp8_block_scale)
        scales = [
            load_weight_shard(s, module.tp_size, module.tp_rank, module.tp_mode)
            if s is not None else None for s in full_scales_squeezed
        ]
        processed_mapping = self.remap_fused_shard_indices_by_divisible_factor(
            module.fused_weight_shard_indices_mapping, 128)
        for shard_key, scale in zip(processed_mapping.keys(), scales):
            if scale is not None:
                shard_offset, shard_size = processed_mapping[shard_key]
                copy_weight_shard(module.weight_scale, scale, shard_offset,
                                  shard_size)

    def load_weights_fused_gate_up_linear(self, module: Linear,
                                          weights: List[Dict]) -> None:
        gate_weight, up_weight = load_weights_fused_gate_up_helper(
            module, weights)
        fused_weight = torch.cat((gate_weight, up_weight))
    def load_weights_fused_gate_up_linear(
            self,
            module: Linear,
            weights: List[Dict],
            allow_partial_loading: bool = False) -> None:
        super().load_weights_fused_gate_up_linear(
            module, weights, allow_partial_loading=allow_partial_loading)

        scale_name = self._get_scale_name(weights)
        full_left_scale = weights[0][scale_name]
        full_right_scale = weights[1][scale_name]
        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
        if full_left_scale.dim() == 4:
            full_left_scale = full_left_scale.squeeze(1).squeeze(-1)
        if full_right_scale.dim() == 4:
            full_right_scale = full_right_scale.squeeze(1).squeeze(-1)
        left_scale = load_weight_shard(full_left_scale, module.tp_size,
                                       module.tp_rank, module.tp_mode)
        right_scale = load_weight_shard(full_right_scale, module.tp_size,
                                        module.tp_rank, module.tp_mode)
        fused_scale = torch.cat([left_scale, right_scale], dim=0)
        copy_weight(module.weight, fused_weight)
        copy_weight(module.weight_scale, fused_scale)
        full_scales = [
            w[scale_name] if scale_name in w else None for w in weights[:2]
        ]
        full_scales_squeezed = [
            s.squeeze(1).squeeze(-1) if s is not None and s.dim() == 4 else s
            for s in full_scales
        ]
        scales = [
            load_weight_shard(s, module.tp_size, module.tp_rank, module.tp_mode)
            if s is not None else None for s in full_scales_squeezed
        ]
        processed_mapping = self.remap_fused_shard_indices_by_divisible_factor(
            module.fused_weight_shard_indices_mapping, 128)
        for shard_key, scale in zip(processed_mapping.keys(), scales):
            if scale is not None:
                shard_offset, shard_size = processed_mapping[shard_key]
                copy_weight_shard(module.weight_scale, scale, shard_offset,
                                  shard_size)

    def post_load_weights(self, module: Linear):
        super().post_load_weights(module)
@ -847,17 +1116,19 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
                get_sm_version() == 120:
            weight, weight_scale = resmooth_to_fp8_e8m0(module.weight,
                                                        module.weight_scale)
            transfromed_scale = transform_sf_into_required_layout(
            transformed_scale = transform_sf_into_required_layout(
                weight_scale,
                mn=weight.shape[0],
                k=weight.shape[1],
                recipe=(1, 128, 128),
                is_sfa=False)
            module.weight = nn.Parameter(weight, requires_grad=False)
            module.weight_scale = nn.Parameter(
                transfromed_scale,
                requires_grad=False,
            )
            replace_parameter_and_save_metadata(
                module, "weight", nn.Parameter(weight, requires_grad=False),
                module.rebuild_tensor_metadata)
            replace_parameter_and_save_metadata(
                module, "weight_scale",
                nn.Parameter(transformed_scale, requires_grad=False),
                module.rebuild_tensor_metadata)


class NVFP4LinearMethod(LinearMethodBase):
@ -2293,5 +2564,14 @@ class Linear(nn.Module):
                               weight_mode,
                               allow_partial_loading=allow_partial_loading)

    def process_weights_after_loading(self):
        self.quant_method.process_weights_after_loading(self)

    def post_load_weights(self):
        self.quant_method.post_load_weights(self)

    def pre_reload_weights(self):
        assert hasattr(
            self.quant_method, "pre_reload_weights"
        ), "pre_reload_weights is not supported for this quant method"
        self.quant_method.pre_reload_weights(self)

@ -414,3 +414,15 @@ def maybe_compiled_copy_(dst, src):
@maybe_compile
def maybe_compiled_cat(tensors, dim):
    return torch.cat(tensors, dim)


def replace_parameter_and_save_metadata(module: torch.nn.Module,
                                        param_name: str,
                                        new_param: torch.nn.Parameter,
                                        metadata_dict: Dict):
    """
    Replace a parameter in a module and save the metadata of the original parameter.
    """
    if param_name not in metadata_dict:
        metadata_dict[param_name] = getattr(module, param_name).to("meta")
    module.register_parameter(param_name, new_param)

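A hedged usage sketch of the helper above: the meta-device copy keeps the original shape and dtype, so a later reload step (for example pre_reload_weights) can re-register an unprocessed parameter before new weights arrive. The rebuild step shown here is an assumption for illustration only, not the exact production path.

import torch
import torch.nn as nn

linear = nn.Linear(4, 4, bias=False)
rebuild_tensor_metadata = {}

# Swap in a transformed weight while remembering the original layout.
packed = nn.Parameter(torch.zeros(2, 8), requires_grad=False)
replace_parameter_and_save_metadata(linear, "weight", packed, rebuild_tensor_metadata)

# Later: the saved meta tensor still carries the original shape/dtype, so a fresh
# parameter with the pre-processing layout can be re-registered before reloading.
meta = rebuild_tensor_metadata["weight"]
linear.register_parameter(
    "weight", nn.Parameter(torch.empty(meta.shape, dtype=meta.dtype), requires_grad=False))
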
@ -50,6 +50,13 @@ class WorkerExtension:
            Exception: Re-raises any exception encountered during weight update.
        """
        try:
            if not hasattr(self.engine.model_engine.model, "first_pre_reload_weights"):
                for module in self.engine.model_engine.model.modules():
                    if hasattr(module, "pre_reload_weights") and not getattr(
                        module, "_weights_removed", False
                    ):
                        module.pre_reload_weights()
                setattr(self.engine.model_engine.model, "first_pre_reload_weights", True)
            if ipc_handles is not None:
                logger.info("Update weights from IPC handles")
                device_uuid = get_device_uuid(self.device_id)
@ -82,6 +89,10 @@ class WorkerExtension:
            else:
                logger.info("Finalize update weights")
                for module in self.engine.model_engine.model.modules():
                    if hasattr(module, "process_weights_after_loading") and not getattr(
                        module, "_weights_removed", False
                    ):
                        module.process_weights_after_loading()
                    if hasattr(module, "post_load_weights") and not getattr(
                        module, "_weights_removed", False
                    ):
@ -93,6 +104,7 @@ class WorkerExtension:
                    moe_load_balancer.finalize_model()
                    logger.info("moe_load_balancer finalize model done")
                self.engine.reset_prefix_cache()
                delattr(self.engine.model_engine.model, "first_pre_reload_weights")

        except Exception as e:
            logger.error("Encountered an error in update_weights")

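For context, the caller-side flow this change expects, mirroring the tests further down: each update_weights call carrying IPC handles stages weights, and a final call with None runs the process_weights_after_loading / post_load_weights pass above and resets the prefix cache. ipc_handle_batches and llm are placeholders here.

# Stage one or more batches of weights, then finalize once.
for ipc_handles in ipc_handle_batches:  # placeholder: iterable of per-device handle dicts
    llm._collective_rpc("update_weights", (ipc_handles,))
llm._collective_rpc("update_weights", (None,))  # finalize: post-process weights, reset prefix cache
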
@ -15,112 +15,21 @@
import asyncio
import os
from functools import partial
from typing import List, Tuple
from typing import List

import pytest
import ray
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoTokenizer
from utils.llm_data import llm_models_root
from utils.torch_ref import RefHFModel

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams


class HFModel:
    def __init__(self, model_name: str, device_id: int):
        self.device_id = device_id
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16
        ).to(f"cuda:{device_id}")

    def generate_batch_with_padding(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        responses: List[List[int]],
        prompt_max_len: int = 1024,
        micro_batch_size: int = 16,
    ):
        """
        Synchronous inference on a batch with micro-batching.
        Directly extracts response logprobs to save memory.

        Args:
            input_ids: [batch_size, seq_len]
            attention_mask: [batch_size, seq_len]
            position_ids: [batch_size, seq_len]
            responses: List of response token IDs for each sample
            prompt_max_len: Maximum prompt length (default 1024)
            micro_batch_size: Size of each micro batch to avoid OOM

        Returns:
            List of logprobs tensors, one per sample [response_len]
        """
        # Move tensors to the correct device
        input_ids = input_ids.to(f"cuda:{self.device_id}")
        attention_mask = attention_mask.to(f"cuda:{self.device_id}")
        position_ids = position_ids.to(f"cuda:{self.device_id}")

        batch_size = input_ids.shape[0]
        num_micro_batches = (batch_size + micro_batch_size - 1) // micro_batch_size

        all_response_logprobs = []

        with torch.no_grad():
            for micro_idx in range(num_micro_batches):
                start_idx = micro_idx * micro_batch_size
                end_idx = min((micro_idx + 1) * micro_batch_size, batch_size)

                # Extract micro batch
                micro_input_ids = input_ids[start_idx:end_idx]
                micro_attention_mask = attention_mask[start_idx:end_idx]
                micro_position_ids = position_ids[start_idx:end_idx]

                # Forward pass
                outputs = self.model(
                    input_ids=micro_input_ids,
                    attention_mask=micro_attention_mask,
                    position_ids=micro_position_ids,
                )

                # Extract response logprobs for each sample in this micro batch
                micro_logits = outputs.logits  # [micro_batch_size, seq_len, vocab_size]

                for i in range(micro_logits.shape[0]):
                    sample_idx = start_idx + i
                    response = responses[sample_idx]
                    response_len = len(response)

                    # Extract logits for predicting response tokens
                    # For predicting response[j], we need logits at position prompt_max_len-1+j
                    response_logits = micro_logits[
                        i, prompt_max_len - 1 : prompt_max_len - 1 + response_len, :
                    ]

                    # Convert to logprobs
                    response_logprobs = torch.log_softmax(response_logits, dim=-1)

                    # Extract logprobs for the actual generated tokens
                    response_tensor = torch.tensor(
                        response, dtype=torch.long, device=response_logprobs.device
                    )
                    ref_logprob_for_tokens = torch.gather(
                        response_logprobs, dim=-1, index=response_tensor.unsqueeze(-1)
                    ).squeeze(-1)

                    all_response_logprobs.append(ref_logprob_for_tokens)

                # Free memory immediately after processing each micro batch
                del outputs, micro_logits
                torch.cuda.empty_cache()

        return all_response_logprobs


async def generate_batch_async(
    hf_model: HFModel,
    hf_model: RefHFModel,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    position_ids: torch.Tensor,
@ -133,7 +42,7 @@ async def generate_batch_async(
    Runs the synchronous model inference in a thread pool.

    Args:
        hf_model: HFModel instance
        hf_model: RefHFModel instance
        input_ids: Input token IDs
        attention_mask: Attention mask
        position_ids: Position IDs
@ -163,89 +72,6 @@ async def generate_batch_async(
    return result


def pad_data(
    original_prompts: List[List[int]],
    generated_token_ids_list: List[List[int]],
    prompt_max_len: int = 1024,
    response_max_len: int = 1024,
    pad_token_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Pad the data to the maximum length.

    Structure:
    [left_pad | actual_prompt | actual_response | right_pad]
    |<-- prompt_max_len=1024 -->|<-- response_max_len=1024 -->|

    Args:
        original_prompts: List of prompt token IDs, len = batch_size
        generated_token_ids_list: List of response token IDs, len = batch_size
        prompt_max_len: Maximum length for prompt section (default 1024)
        response_max_len: Maximum length for response section (default 1024)
        pad_token_id: Token ID for padding (default 0)
    Returns:
        input_ids: Tensor of shape [batch_size, prompt_max_len + response_max_len]
        attention_mask: Tensor of shape [batch_size, prompt_max_len + response_max_len]
        position_ids: Tensor of shape [batch_size, prompt_max_len + response_max_len]
    """
    batch_size = len(original_prompts)
    total_len = prompt_max_len + response_max_len

    for i, (prompt, response) in enumerate(zip(original_prompts, generated_token_ids_list)):
        assert len(prompt) <= prompt_max_len, (
            f"Batch {i}: Prompt length {len(prompt)} exceeds max {prompt_max_len}"
        )
        assert len(response) <= response_max_len, (
            f"Batch {i}: Response length {len(response)} exceeds max {response_max_len}"
        )

    # Build batch tensors [batch_size, 2048]
    batch_input_ids = torch.full(
        (batch_size, total_len), pad_token_id, dtype=torch.long, device="cuda"
    )
    batch_attention_mask = torch.zeros((batch_size, total_len), dtype=torch.long, device="cuda")
    batch_position_ids = torch.zeros((batch_size, total_len), dtype=torch.long, device="cuda")

    response_lens = []

    for i in range(batch_size):
        prompt_tokens = original_prompts[i]
        response_tokens = generated_token_ids_list[i]

        prompt_len = len(prompt_tokens)
        response_len = len(response_tokens)
        response_lens.append(response_len)

        left_pad_len = prompt_max_len - prompt_len

        # Fill input_ids: [left_pad | prompt | response | right_pad]
        prompt_start = left_pad_len
        prompt_end = prompt_max_len
        response_start = prompt_max_len
        response_end = prompt_max_len + response_len

        batch_input_ids[i, prompt_start:prompt_end] = torch.tensor(
            prompt_tokens, dtype=torch.long, device="cuda"
        )
        batch_input_ids[i, response_start:response_end] = torch.tensor(
            response_tokens, dtype=torch.long, device="cuda"
        )

        # Fill attention_mask: 1 for actual tokens, 0 for padding
        batch_attention_mask[i, prompt_start:response_end] = 1

        # Fill position_ids: sequential for actual tokens
        actual_seq_len = prompt_len + response_len
        batch_position_ids[i, prompt_start:response_end] = torch.arange(
            actual_seq_len, dtype=torch.long, device="cuda"
        )
        # Right padding keeps the last position value
        if response_len < response_max_len:
            batch_position_ids[i, response_end:] = actual_seq_len - 1

    return batch_input_ids, batch_attention_mask, batch_position_ids


def compare_logprobs(logprobs_list, ref_new_token_logprobs_list):
    """
    logprobs_list: List[torch.Tensor] - LLM logprob values
@ -337,7 +163,7 @@ def test_accuracy_with_allreduce_strategy(model_dir, sampler_type, allreduce_str
    ray.shutdown()

    torch.cuda.empty_cache()
    input_ids, attention_mask, position_ids = pad_data(test_prompts, llm_responses)
    input_ids, attention_mask, position_ids = RefHFModel.pad_data(test_prompts, llm_responses)

    # Split data across GPUs
    num_gpus = 4
@ -347,7 +173,7 @@ def test_accuracy_with_allreduce_strategy(model_dir, sampler_type, allreduce_str

    dp_hf_models = []
    for device_id in range(num_gpus):
        hf_model = HFModel(model_dir, device_id)
        hf_model = RefHFModel(model_dir, device_id)
        dp_hf_models.append(hf_model)

    # Split input data and responses into chunks for each GPU
@ -367,7 +193,7 @@ def test_accuracy_with_allreduce_strategy(model_dir, sampler_type, allreduce_str
        responses_chunks.append(llm_responses[start_idx:end_idx])

    # Process each chunk on its corresponding GPU asynchronously
    async def process_all_chunks(hf_models: List[HFModel]):
    async def process_all_chunks(hf_models: List[RefHFModel]):
        tasks = []
        for i, (input_chunk, attn_chunk, pos_chunk, resp_chunk) in enumerate(
            zip(input_ids_chunks, attention_mask_chunks, position_ids_chunks, responses_chunks)

@ -1,62 +1,64 @@
from typing import Callable, List, Optional
import json
import os
import re
import shutil
from tempfile import TemporaryDirectory
from typing import Callable, List, Optional, Tuple

import pytest
import torch
from torch.multiprocessing.reductions import reduce_tensor
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from utils.llm_data import llm_models_root
from utils.torch_ref import RefHFModel

from tensorrt_llm import LLM
from tensorrt_llm._torch.utils import get_device_uuid
from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams


class HFModel:
    def __init__(self, model_name: str):
class RefHFModelWithIPCHandles(RefHFModel):
    def __init__(self, model_dir: str, device_id: int = 0, num_hidden_layers: int = 4):
        self.device_id = device_id
        config = AutoConfig.from_pretrained(model_dir)
        config.num_hidden_layers = num_hidden_layers
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16
        ).to("cuda")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.cuda_device = torch.cuda.current_device()
            model_dir, config=config, torch_dtype=torch.bfloat16, attn_implementation="eager"
        ).to(f"cuda:{device_id}")
        self.all_weights = {}
        self.device_uuid = [HFModel.get_device_uuid(i) for i in range(torch.cuda.device_count())]
        self.device_uuid = [get_device_uuid(i) for i in range(torch.cuda.device_count())]
        self._replicate_weights()

    @staticmethod
    def get_device_uuid(cuda_device: int):
        from tensorrt_llm._torch.utils import get_device_uuid

        return get_device_uuid(cuda_device)

    def _replicate_weights(self):
        model_weights = []
        for n, p in self.model.named_parameters():
            model_weights.append((n, p.detach().clone()))

        self.all_weights[self.cuda_device] = model_weights
        self.all_weights[self.device_id] = model_weights
        for i in range(torch.cuda.device_count()):
            if i != self.cuda_device:
            if i != self.device_id:
                cur_weights = []
                for n, p in self.all_weights[self.cuda_device]:
                for n, p in self.all_weights[self.device_id]:
                    cur_weights.append((n, p.to("cuda:" + str(i))))
                self.all_weights[i] = cur_weights

    def get_weight_ipc_handles(
        self,
        cuda_device: Optional[List[int]] = None,
        device_ids: Optional[List[int]] = None,
        weight_filter: Optional[Callable[[str], bool]] = None,
    ):
        """
        Get IPC handles for model weights with flexible filtering.

        Args:
            cuda_device: List of CUDA device indices to get weights from
            device_ids: List of device indices to get weights from
            weight_filter: Optional function that takes weight name and returns True if weight should be included

        Returns:
            ret: Dictionary containing weight handles
        """
        ret = {}
        device_list = list(range(torch.cuda.device_count())) if cuda_device is None else cuda_device
        device_list = list(range(torch.cuda.device_count())) if device_ids is None else device_ids

        for device in device_list:
            all_handles = []
@ -71,50 +73,6 @@ class HFModel:

        return ret

    def generate_batch_incremental(
        self, original_prompts: List[str], generated_token_ids_list: List[List[int]]
    ):
        """
        Generate tokens incrementally for each prompt in the batch: [prompt, prompt+token0, prompt+token0+token1, ...]
        """
        logits_list = []

        for i in range(len(original_prompts)):
            base_token_ids = self.tokenizer.encode(original_prompts[i], return_tensors="pt")[0].to(
                "cuda"
            )
            cur_logits = []
            for j in range(len(generated_token_ids_list[i])):
                if j > 0:
                    cur_gen_tokens = torch.tensor(generated_token_ids_list[i][:j]).to("cuda")
                    cur_token_ids = torch.cat([base_token_ids, cur_gen_tokens], dim=-1)
                else:
                    cur_token_ids = base_token_ids

                ret = self.model.generate(
                    input_ids=cur_token_ids.unsqueeze(0).cuda(),
                    max_new_tokens=1,
                    do_sample=False,
                    return_dict_in_generate=True,
                    output_scores=True,
                )

                cur_logits.append(ret["scores"][0])
            cur_logits = torch.stack(cur_logits, dim=0)
            logits_list.append(cur_logits.squeeze(1))

        return logits_list


def extract_tokens_from_outputs(outputs):
    """Extract individual tokens from LLM outputs using token IDs directly"""
    tokens_list = []
    for output in outputs:
        # Get token IDs directly from the output
        token_ids = output.outputs[0].token_ids
        tokens_list.append(token_ids)
    return tokens_list


def compare_logits(
    logits_list: List[torch.Tensor],
@ -123,7 +81,6 @@ def compare_logits(
    threshold: float = 0.9,
):
    assert len(logits_list) == len(ref_logits_list)

    for i in range(len(logits_list)):
        assert logits_list[i].shape == ref_logits_list[i].shape
        lhs_idx = torch.topk(logits_list[i], topk, dim=-1).indices
@ -142,119 +99,168 @@
    )


def run_generate(llm, hf_model, prompts, sampling_params):
    outputs = llm.generate(prompts, sampling_params)
def run_generate(
    llm: LLM,
    hf_model: RefHFModel,
    prompts: List[List[int]],
    sampling_params: SamplingParams,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    llm_responses = []
    llm_logits = []
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        llm_logits.append(output.outputs[0].generation_logits)

    generated_token_ids_list = extract_tokens_from_outputs(outputs)
    ref_logits = hf_model.generate_batch_incremental(prompts, generated_token_ids_list)
        llm_responses.append(output.outputs[0].token_ids)
    input_ids, attention_mask, position_ids = RefHFModel.pad_data(prompts, llm_responses)
    ref_logits = hf_model.generate_batch_with_padding(
        input_ids, attention_mask, position_ids, llm_responses, return_logits=True
    )
    return llm_logits, ref_logits


def process_and_copy_folder(src_folder, dst_folder, num_hidden_layers: int = 4):
    if os.path.exists(dst_folder):
        shutil.rmtree(dst_folder)
    os.makedirs(dst_folder)

    for root, dirs, files in os.walk(src_folder):
        rel_path = os.path.relpath(root, src_folder)
        dest_dir = os.path.join(dst_folder, rel_path)

        if not os.path.exists(dest_dir):
            os.makedirs(dest_dir)

        for file in files:
            src_path = os.path.join(root, file)
            dest_path = os.path.join(dest_dir, file)
            if "safetensor" in file:
                continue

            if file == "config.json":
                with open(src_path, "r", encoding="utf-8") as f:
                    config = json.load(f)
                config["num_hidden_layers"] = num_hidden_layers
                with open(dest_path, "w", encoding="utf-8") as f:
                    json.dump(config, f, indent=2, ensure_ascii=False)
            else:
                shutil.copy2(src_path, dest_path)


@pytest.mark.parametrize(
    "model_dir",
    ["Qwen2.5-0.5B-Instruct", "Qwen3/Qwen3-8B", "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"],
    [
        "llama-models-v2/TinyLlama-1.1B-Chat-v1.0",
        "Qwen2.5-0.5B-Instruct",
        "Qwen3/Qwen3-8B",
        "Qwen3/Qwen3-30B-A3B",
        "Qwen3/Qwen3-8B-FP8",
        "Qwen3/Qwen3-30B-A3B-FP8",
    ],
)
def test_llm_update_weights(model_dir):
    model_dir = str(llm_models_root() / model_dir)
    kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
    with TemporaryDirectory() as tmp_model_dir:
        num_hidden_layers = 1
        process_and_copy_folder(model_dir, tmp_model_dir, num_hidden_layers=num_hidden_layers)
        hf_model = RefHFModelWithIPCHandles(model_dir, num_hidden_layers=num_hidden_layers)
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
        llm = LLM(
            model=tmp_model_dir,
            ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
            tensor_parallel_size=1,
            load_format="dummy",
            pipeline_parallel_size=1,
            kv_cache_config=kv_cache_config,
        )

    hf_model = HFModel(model_dir)
        # Generate texts from the prompts.
        prompts_texts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]
        prompts = [tokenizer.encode(prompt) for prompt in prompts_texts]
        del tokenizer
        sampling_params = SamplingParams(
            temperature=0, return_generation_logits=True, max_tokens=1024
        )

    llm = LLM(
        model=model_dir,
        ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
        tensor_parallel_size=1,
        load_format="dummy",
        pipeline_parallel_size=1,
        kv_cache_config=kv_cache_config,
    )
        ipc_handles = hf_model.get_weight_ipc_handles([0])

    # Generate texts from the prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
        llm._collective_rpc("update_weights", (ipc_handles,))
        # Finalize the update weights
        llm._collective_rpc("update_weights", (None,))

    sampling_params = SamplingParams(temperature=0, return_generation_logits=True)

    ipc_handles = hf_model.get_weight_ipc_handles([0])

    llm._collective_rpc("update_weights", (ipc_handles,))
    # Finalize the update weights
    llm._collective_rpc("update_weights", (None,))

    llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
    compare_logits(llm_logits, ref_logits)
        llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
        compare_logits(llm_logits, ref_logits)


@pytest.mark.parametrize(
    "model_dir",
    ["Qwen2.5-0.5B-Instruct", "Qwen3/Qwen3-8B", "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"],
    [
        "llama-models-v2/TinyLlama-1.1B-Chat-v1.0",
        "Qwen2.5-0.5B-Instruct",
        "Qwen3/Qwen3-8B",
        "Qwen3/Qwen3-30B-A3B",
        "Qwen3/Qwen3-8B-FP8",
        "Qwen3/Qwen3-30B-A3B-FP8",
    ],
)
def test_llm_partial_update_weights(model_dir):
    model_dir = str(llm_models_root() / model_dir)
    kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
    with TemporaryDirectory() as tmp_model_dir:
        num_hidden_layers = 1
        process_and_copy_folder(model_dir, tmp_model_dir, num_hidden_layers=num_hidden_layers)
        hf_model = RefHFModelWithIPCHandles(model_dir, num_hidden_layers=num_hidden_layers)
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)

    hf_model = HFModel(model_dir)

    llm = LLM(
        model=model_dir,
        ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
        tensor_parallel_size=1,
        load_format="dummy",
        pipeline_parallel_size=1,
        kv_cache_config=kv_cache_config,
    )

    # Generate texts from the prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    sampling_params = SamplingParams(temperature=0, return_generation_logits=True)

    ipc_handles = hf_model.get_weight_ipc_handles([0])

    def common_filter(filter_name: str) -> Callable[[str], bool]:
        def filter_fn(name: str) -> bool:
            return filter_name in name

        return filter_fn

    filter_list = [
        "q_proj.weight",
        "k_proj.weight",
        "v_proj.weight",
        "o_proj.weight",
        "gate_proj.weight",
        "up_proj.weight",
        "down_proj.weight",
        "norm.weight",
        "embed_tokens.weight",
        "lm_head.weight",
    ]
    if "Qwen2.5" in model_dir or "Qwen2" in model_dir:
        filter_list.extend(
            [
                "q_proj.bias",
                "k_proj.bias",
                "v_proj.bias",
            ]
        llm = LLM(
            model=tmp_model_dir,
            ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
            tensor_parallel_size=1,
            load_format="dummy",
            pipeline_parallel_size=1,
            kv_cache_config=kv_cache_config,
        )
    for filter_name in filter_list:
        weight_filter = common_filter(filter_name=filter_name)
        ipc_handles = hf_model.get_weight_ipc_handles([0], weight_filter=weight_filter)
        llm._collective_rpc("update_weights", (ipc_handles,), non_block=True)
    # Finalize the update weights
    llm._collective_rpc("update_weights", (None,))

    llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
    compare_logits(llm_logits, ref_logits)
        # Generate texts from the prompts.
        prompts_texts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]
        prompts = [tokenizer.encode(prompt) for prompt in prompts_texts]
        del tokenizer

        sampling_params = SamplingParams(
            temperature=0, return_generation_logits=True, max_tokens=1024
        )

        def common_filter(filter_name: str) -> Callable[[str], bool]:
            def filter_fn(name: str) -> bool:
                return name.endswith(filter_name)

            return filter_fn

        # Generate filter_list from model weight keys by removing layer prefix
        # e.g., "model.layers.41.input_layernorm.weight" -> "input_layernorm.weight"
        layer_prefix_pattern = re.compile(r"^model\.layers\.\d+\.")
        filter_set = set()
        for name, _ in hf_model.all_weights[hf_model.device_id]:
            suffix = layer_prefix_pattern.sub("", name)
            filter_set.add(suffix)
        filter_list = list(filter_set)

        for filter_name in filter_list:
            weight_filter = common_filter(filter_name=filter_name)
            ipc_handles = hf_model.get_weight_ipc_handles([0], weight_filter=weight_filter)
            llm._collective_rpc("update_weights", (ipc_handles,))
            # Finalize the update weights
            llm._collective_rpc("update_weights", (None,))

        llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
        compare_logits(llm_logits, ref_logits)

@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers import AutoModelForCausalLM


def geglu(x):
@ -1239,3 +1240,188 @@ class recurrent_ref(nn.Module):
        x = self.linear_out(x)

        return x, conv_state, lru_state


class RefHFModel:

    def __init__(self,
                 model_dir: str,
                 device_id: int = 0,
                 additional_model_kargs: Optional[Dict[str, Any]] = None):
        self.device_id = device_id
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir, **(additional_model_kargs or {})).to(f"cuda:{device_id}")

    def generate_batch_with_padding(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        responses: List[List[int]],
        prompt_max_len: int = 1024,
        micro_batch_size: int = 16,
        return_logits: bool = False,
    ):
        """
        Synchronous inference on a batch with micro-batching.
        Directly extracts response logprobs to save memory.

        Args:
            input_ids: [batch_size, seq_len]
            attention_mask: [batch_size, seq_len]
            position_ids: [batch_size, seq_len]
            responses: List of response token IDs for each sample
            prompt_max_len: Maximum prompt length (default 1024)
            micro_batch_size: Size of each micro batch to avoid OOM
            return_logits: Whether to return logits, If True, return logits, otherwise return logprobs
        Returns:
            List of logits or logprobs tensors, one per sample [response_len]
        """
        # Move tensors to the correct device
        input_ids = input_ids.to(f"cuda:{self.device_id}")
        attention_mask = attention_mask.to(f"cuda:{self.device_id}")
        position_ids = position_ids.to(f"cuda:{self.device_id}")

        batch_size = input_ids.shape[0]
        num_micro_batches = (batch_size + micro_batch_size -
                             1) // micro_batch_size

        ref_results = []

        with torch.no_grad():
            for micro_idx in range(num_micro_batches):
                start_idx = micro_idx * micro_batch_size
                end_idx = min((micro_idx + 1) * micro_batch_size, batch_size)

                # Extract micro batch
                micro_input_ids = input_ids[start_idx:end_idx]
                micro_attention_mask = attention_mask[start_idx:end_idx]
                micro_position_ids = position_ids[start_idx:end_idx]

                # Forward pass
                outputs = self.model(
                    input_ids=micro_input_ids,
                    attention_mask=micro_attention_mask,
                    position_ids=micro_position_ids,
                )
                # Extract response logprobs for each sample in this micro batch
                for i in range(outputs.logits.shape[0]):
                    sample_idx = start_idx + i
                    response = responses[sample_idx]
                    response_len = len(response)

                    # Extract logits for predicting response tokens
                    # For predicting response[j], we need logits at position prompt_max_len-1+j
                    response_logits = outputs.logits[i, prompt_max_len -
                                                     1:prompt_max_len - 1 +
                                                     response_len, :]
                    if return_logits:
                        ref_results.append(response_logits)
                    else:
                        # Convert to logprobs
                        response_logprobs = torch.log_softmax(response_logits,
                                                              dim=-1)

                        # Extract logprobs for the actual generated tokens
                        response_tensor = torch.tensor(
                            response,
                            dtype=torch.long,
                            device=response_logprobs.device)
                        ref_logprob_for_tokens = torch.gather(
                            response_logprobs,
                            dim=-1,
                            index=response_tensor.unsqueeze(-1)).squeeze(-1)

                        ref_results.append(ref_logprob_for_tokens)

                # Free memory immediately after processing each micro batch
                del outputs
                torch.cuda.empty_cache()

        return ref_results

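A tiny worked example of the index arithmetic above, with illustrative lengths: because prompts are left-padded, the prompt's last real token sits at index prompt_max_len - 1, and the logits there predict response[0].

prompt_max_len, response_len = 4, 2  # illustrative sizes only
# Padded layout produced by pad_data: [pad, p0, p1, p2, t0, t1, ...]
# Logits at position j predict the token at position j + 1, hence the slice
# [prompt_max_len - 1 : prompt_max_len - 1 + response_len] used above.
predict_positions = list(range(prompt_max_len - 1, prompt_max_len - 1 + response_len))
assert predict_positions == [3, 4]  # logits[3] -> t0, logits[4] -> t1
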
    @staticmethod
    def pad_data(
        original_prompts: List[List[int]],
        generated_token_ids_list: List[List[int]],
        prompt_max_len: int = 1024,
        response_max_len: int = 1024,
        pad_token_id: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Pad the data to the maximum length.

        Structure:
        [left_pad | actual_prompt | actual_response | right_pad]
        |<-- prompt_max_len=1024 -->|<-- response_max_len=1024 -->|

        Args:
            original_prompts: List of prompt token IDs, len = batch_size
            generated_token_ids_list: List of response token IDs, len = batch_size
            prompt_max_len: Maximum length for prompt section (default 1024)
            response_max_len: Maximum length for response section (default 1024)
            pad_token_id: Token ID for padding (default 0)
        Returns:
            input_ids: Tensor of shape [batch_size, prompt_max_len + response_max_len]
            attention_mask: Tensor of shape [batch_size, prompt_max_len + response_max_len]
            position_ids: Tensor of shape [batch_size, prompt_max_len + response_max_len]
        """
        batch_size = len(original_prompts)
        total_len = prompt_max_len + response_max_len

        for i, (prompt, response) in enumerate(
                zip(original_prompts, generated_token_ids_list)):
            assert len(prompt) <= prompt_max_len, (
                f"Batch {i}: Prompt length {len(prompt)} exceeds max {prompt_max_len}"
            )
            assert len(response) <= response_max_len, (
                f"Batch {i}: Response length {len(response)} exceeds max {response_max_len}"
            )

        # Build batch tensors [batch_size, total_len]
        batch_input_ids = torch.full((batch_size, total_len),
                                     pad_token_id,
                                     dtype=torch.long,
                                     device="cuda")
        batch_attention_mask = torch.zeros((batch_size, total_len),
                                           dtype=torch.long,
                                           device="cuda")
        batch_position_ids = torch.zeros((batch_size, total_len),
                                         dtype=torch.long,
                                         device="cuda")

        response_lens = []

        for i in range(batch_size):
            prompt_tokens = original_prompts[i]
            response_tokens = generated_token_ids_list[i]

            prompt_len = len(prompt_tokens)
            response_len = len(response_tokens)
            response_lens.append(response_len)

            left_pad_len = prompt_max_len - prompt_len

            # Fill input_ids: [left_pad | prompt | response | right_pad]
            prompt_start = left_pad_len
            prompt_end = prompt_max_len
            response_start = prompt_max_len
            response_end = prompt_max_len + response_len

            batch_input_ids[i, prompt_start:prompt_end] = torch.tensor(
                prompt_tokens, dtype=torch.long, device="cuda")
            batch_input_ids[i, response_start:response_end] = torch.tensor(
                response_tokens, dtype=torch.long, device="cuda")

            # Fill attention_mask: 1 for actual tokens, 0 for padding
            batch_attention_mask[i, prompt_start:response_end] = 1

            # Fill position_ids: sequential for actual tokens
            actual_seq_len = prompt_len + response_len
            batch_position_ids[i, prompt_start:response_end] = torch.arange(
                actual_seq_len, dtype=torch.long, device="cuda")
            # Right padding keeps the last position value
            if response_len < response_max_len:
                batch_position_ids[i, response_end:] = actual_seq_len - 1

        return batch_input_ids, batch_attention_mask, batch_position_ids
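A small usage sketch of the reference helpers above, assuming toy token IDs and tiny lengths; the model directory is a placeholder, and the real tests tokenize prompts and use max_tokens=1024.

# Toy token IDs; pad_data left-pads prompts and right-pads responses.
prompts = [[1, 2, 3], [4, 5]]
responses = [[10, 11], [12]]
input_ids, attention_mask, position_ids = RefHFModel.pad_data(
    prompts, responses, prompt_max_len=8, response_max_len=4)

ref = RefHFModel("Qwen2.5-0.5B-Instruct")  # placeholder model_dir
logprobs = ref.generate_batch_with_padding(
    input_ids, attention_mask, position_ids, responses, prompt_max_len=8)
# One tensor per sample, each of length len(responses[i]).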