[TRTLLM-9771][feat] Make update_weights compatible with CUDA Graph (#11267)

Signed-off-by: Shuyi Xiong <219646547+shuyixiong@users.noreply.github.com>
shuyixiong 2026-02-10 14:12:49 +08:00 committed by GitHub
parent 8b2dc57823
commit c3cdc93211
6 changed files with 279 additions and 59 deletions


@ -1,7 +1,7 @@
import inspect
import math
from abc import ABC, abstractmethod
from typing import Dict, List, NamedTuple, Optional, Union
from typing import Dict, List, NamedTuple, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@ -543,10 +543,9 @@ class FusedMoEMethodBase(ABC):
def pre_reload_weights(self, module: torch.nn.Module):
for param_name, metadata in module.rebuild_tensor_metadata.items():
logger.warning(
f"Pre-reloading weight '{param_name}' requires tensor re-creation, which will invalidate existing CUDA graphs."
)
param = torch.nn.Parameter(torch.empty_like(metadata,
# Extract meta tensor from metadata dict
meta_tensor = metadata['meta']
param = torch.nn.Parameter(torch.empty_like(meta_tensor,
device="cuda"),
requires_grad=False)
module.register_parameter(param_name, param)
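
For reference, a minimal sketch of the re-creation step this hook performs (toy tensor and hypothetical names, not the module's actual weights): the saved meta tensor carries only shape and dtype, and torch.empty_like(..., device="cuda") materializes a fresh allocation, which is why any previously captured CUDA graph has to be re-captured after this path runs.

import torch

# Hypothetical meta tensor previously stashed under rebuild_tensor_metadata['meta'];
# it records shape/dtype but owns no storage.
meta_weight = torch.empty(256, 128, dtype=torch.float16, device="meta")

# pre_reload_weights materializes a brand-new, uninitialized parameter. The fresh
# allocation lives at a new address, so CUDA graphs that captured the old
# parameter are no longer valid.
device = "cuda" if torch.cuda.is_available() else "cpu"
param = torch.nn.Parameter(torch.empty_like(meta_weight, device=device),
                           requires_grad=False)
print(param.shape, param.dtype, param.device)
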
@ -999,6 +998,23 @@ class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
super().post_load_weights(module)
def resmooth_and_transform_fp8_scale(
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Resmooth weight/scale to FP8 E8M0 and transform scale to required layout for MoE."""
resmoothed_weight, resmoothed_scale = resmooth_to_fp8_e8m0(
weight, weight_scale)
transformed_scale = transform_sf_into_required_layout(
resmoothed_scale,
mn=weight.shape[1],
k=weight.shape[2],
recipe=(1, 128, 128),
num_groups=weight.shape[0],
is_sfa=False)
return resmoothed_weight, transformed_scale
class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
DeepSeekFP8BlockScalesFusedMoEMethod):
@ -1007,53 +1023,70 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
weights: List[Dict],
weight_loading_mode: MoEWeightLoadingMode,
allow_partial_loading: bool = False):
if is_sm_100f():
expert_ids = set(module.initial_local_expert_ids)
if self.need_load_shared_weights(module):
expert_ids.update(
module.layer_load_balancer.get_load_expert_ids())
for name in list(weights.keys()):
if name.endswith("weight_scale_inv"):
if int(name.split(".")[0]) not in expert_ids:
continue
weight_name = name.replace("weight_scale_inv", "weight")
logger.debug(f"Resmoothing {weight_name}")
weight = weights[weight_name][:]
scale = weights[name][:]
weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
weight, scale)
super().load_weights(module, weights, weight_loading_mode,
allow_partial_loading)
def post_load_weights(self, module: torch.nn.Module):
super().post_load_weights(module)
if is_sm_100f():
transformed_w3_w1_scale = transform_sf_into_required_layout(
module.quant_scales[0],
mn=module.w3_w1_weight.shape[1],
k=module.w3_w1_weight.shape[2],
recipe=(1, 128, 128),
num_groups=module.w3_w1_weight.shape[0],
is_sfa=False)
transformed_w3_w1_weight_scaling_factor = nn.Parameter(
transformed_w3_w1_scale, requires_grad=False)
replace_parameter_and_save_metadata(
module, "w3_w1_weight_scaling_factor",
transformed_w3_w1_weight_scaling_factor,
module.rebuild_tensor_metadata)
transformed_w2_scale = transform_sf_into_required_layout(
module.quant_scales[1],
mn=module.w2_weight.shape[1],
k=module.w2_weight.shape[2],
recipe=(1, 128, 128),
num_groups=module.w3_w1_weight.shape[0],
is_sfa=False)
transformed_w2_weight_scaling_factor = nn.Parameter(
transformed_w2_scale, requires_grad=False)
replace_parameter_and_save_metadata(
module, "w2_weight_scaling_factor",
transformed_w2_weight_scaling_factor,
module.rebuild_tensor_metadata)
# Resmooth shared experts before registering shared weights
if self.need_load_shared_weights(module):
local_shared_load_expert_ids = module.layer_load_balancer.get_load_expert_ids(
)
if getattr(module, 'local_shared_w3_w1_tensors',
None) is not None:
num_shared_experts = len(local_shared_load_expert_ids)
logger.debug(
f"Batch resmoothing {num_shared_experts} shared experts"
)
local_shared_w3_w1_tensors = getattr(
module, 'local_shared_w3_w1_tensors')
local_shared_w3_w1_scale_tensors = getattr(
module, 'local_shared_w3_w1_scale_tensors')
local_shared_w2_tensors = getattr(
module, 'local_shared_w2_tensors')
local_shared_w2_scale_tensors = getattr(
module, 'local_shared_w2_scale_tensors')
resmoothed_shared_w3_w1_weight, transformed_shared_w3_w1_scale = resmooth_and_transform_fp8_scale(
local_shared_w3_w1_tensors,
local_shared_w3_w1_scale_tensors)
setattr(module, 'local_shared_w3_w1_tensors',
resmoothed_shared_w3_w1_weight.cpu())
setattr(module, 'local_shared_w3_w1_scale_tensors',
transformed_shared_w3_w1_scale.cpu())
resmoothed_shared_w2_weight, transformed_shared_w2_scale = resmooth_and_transform_fp8_scale(
local_shared_w2_tensors, local_shared_w2_scale_tensors)
setattr(module, 'local_shared_w2_tensors',
resmoothed_shared_w2_weight.cpu())
setattr(module, 'local_shared_w2_scale_tensors',
transformed_shared_w2_scale.cpu())
# Call super() after resmoothing the shared experts (the local_shared tensors are deleted in super().post_load_weights())
super().post_load_weights(module)
if is_sm_100f():
logger.debug("Resmoothing FP8 weights in post_load_weights")
resmoothed_w3_w1_weight, transformed_w3_w1_scale = resmooth_and_transform_fp8_scale(
module.w3_w1_weight, module.w3_w1_weight_scaling_factor)
replace_parameter_and_save_metadata(module, "w3_w1_weight",
resmoothed_w3_w1_weight,
module.rebuild_tensor_metadata)
replace_parameter_and_save_metadata(module,
"w3_w1_weight_scaling_factor",
transformed_w3_w1_scale,
module.rebuild_tensor_metadata)
resmoothed_w2_weight, transformed_w2_scale = resmooth_and_transform_fp8_scale(
module.w2_weight, module.w2_weight_scaling_factor)
replace_parameter_and_save_metadata(module, "w2_weight",
resmoothed_w2_weight,
module.rebuild_tensor_metadata)
replace_parameter_and_save_metadata(module,
"w2_weight_scaling_factor",
transformed_w2_scale,
module.rebuild_tensor_metadata)
self.setup_quant_scales(module)
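
For orientation, a hypothetical driver is sketched below (the names quant_method, moe_module and new_weights are illustrative; the real caller lives in the weight-update machinery), showing the order in which these hooks are exercised during an update:

def reload_expert_weights(quant_method, moe_module, new_weights,
                          weight_loading_mode) -> None:
    """Hypothetical driver sketching the order in which the hooks above run."""
    # 1) Re-create parameters whose tensors were replaced on an earlier load,
    #    using the meta tensors stashed in module.rebuild_tensor_metadata.
    quant_method.pre_reload_weights(moe_module)
    # 2) Copy the new FP8 weights and block scales into the module.
    quant_method.load_weights(moe_module, new_weights, weight_loading_mode)
    # 3) On SM100, resmooth weights to E8M0 and transform the scales; the results
    #    are swapped in via replace_parameter_and_save_metadata so the saved
    #    metadata can be reused on the next update.
    quant_method.post_load_weights(moe_module)
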


@ -518,10 +518,9 @@ class UnquantizedLinearMethod(LinearMethodBase):
def pre_reload_weights(self, module: Linear):
for param_name, metadata in module.rebuild_tensor_metadata.items():
logger.warning(
f"Pre-reloading weight '{param_name}' requires tensor re-creation, which will invalidate existing CUDA graphs."
)
param = Parameter(torch.empty_like(metadata, device="cuda"),
# Extract meta tensor from metadata dict
meta_tensor = metadata['meta']
param = Parameter(torch.empty_like(meta_tensor, device="cuda"),
requires_grad=False)
module.register_parameter(param_name, param)


@ -426,13 +426,37 @@ def maybe_compiled_cat(tensors, dim):
return torch.cat(tensors, dim)
def replace_parameter_and_save_metadata(module: torch.nn.Module,
param_name: str,
new_param: torch.nn.Parameter,
metadata_dict: Dict):
def replace_parameter_and_save_metadata(
module: torch.nn.Module, param_name: str,
new_param: torch.nn.Parameter | torch.Tensor, metadata_dict: Dict):
"""
Replace a parameter in a module and save the metadata of the original parameter.
On the first call: saves the original param's meta tensor and the new param's tensor, then registers the new param.
On subsequent calls: copies new_param's data into the saved tensor and re-registers it.
"""
saved_param = None
if param_name not in metadata_dict:
metadata_dict[param_name] = getattr(module, param_name).to("meta")
module.register_parameter(param_name, new_param)
# First time: save original meta tensor and the new param tensor reference
original_meta = getattr(module, param_name).to("meta")
# Convert new_param to Parameter if it's a Tensor, otherwise use directly
if isinstance(new_param, torch.nn.Parameter):
saved_param = new_param
elif isinstance(new_param, torch.Tensor):
saved_param = torch.nn.Parameter(new_param, requires_grad=False)
else:
raise ValueError(f"Invalid type {type(new_param)} for new_param")
metadata_dict[param_name] = {
'meta': original_meta,
'param': saved_param
}
else:
# Subsequent calls: copy new_param into the saved tensor
saved_param = metadata_dict[param_name]['param']
if isinstance(new_param, torch.nn.Parameter):
saved_param.data.copy_(new_param.data)
elif isinstance(new_param, torch.Tensor):
saved_param.data.copy_(new_param)
else:
raise ValueError(f"Invalid type {type(new_param)} for new_param")
module.register_parameter(param_name, saved_param)
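
The property that matters for CUDA graph compatibility is that, after the first replacement, the parameter's storage address never changes across weight updates. A minimal CPU-only sketch of that pattern (toy swap helper with hypothetical names, mirroring the logic above rather than importing it):

import torch

# Toy module standing in for a Linear/MoE layer.
module = torch.nn.Linear(8, 8, bias=False)
metadata = {}

def swap(module, name, new_tensor, metadata):
    # Same idea as replace_parameter_and_save_metadata: the first call stores the
    # original meta tensor and adopts the new tensor as the parameter; later calls
    # copy into the already-registered storage so its address never changes.
    if name not in metadata:
        metadata[name] = {
            "meta": getattr(module, name).to("meta"),
            "param": torch.nn.Parameter(new_tensor, requires_grad=False),
        }
    else:
        metadata[name]["param"].data.copy_(new_tensor)
    module.register_parameter(name, metadata[name]["param"])

swap(module, "weight", torch.randn(8, 8), metadata)
first_ptr = module.weight.data_ptr()

swap(module, "weight", torch.randn(8, 8), metadata)
# Storage is reused, so a CUDA graph that captured this address would keep
# reading the right memory after new weights are copied in.
assert module.weight.data_ptr() == first_ptr
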


@ -44,6 +44,7 @@ l0_dgx_b200:
orchestrator: ray
tests:
- unittest/llmapi/test_llm_multi_gpu_pytorch.py -m "gpu4"
- unittest/_torch/ray_orchestrator/multi_gpu/test_llm_update_weights_multi_gpu.py
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
- disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_genpp2[TinyLlama-1.1B-Chat-v1.0]


@ -0,0 +1,154 @@
import re
from typing import Callable
import pytest
from _torch.ray_orchestrator.single_gpu.test_llm_update_weights import (
RefHFModelWithIPCHandles,
compare_logits,
run_generate,
)
from transformers import AutoTokenizer
from utils.llm_data import llm_models_root
from utils.util import skip_pre_blackwell
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams
@skip_pre_blackwell
@pytest.mark.parametrize(
"model_dir, fp8_model_dir",
[
("Qwen3/Qwen3-8B", "Qwen3/Qwen3-8B-FP8"),
("Qwen3/Qwen3-30B-A3B", "Qwen3/Qwen3-30B-A3B-FP8"),
],
)
def test_llm_update_weights_with_quant_config(model_dir, fp8_model_dir):
model_dir = str(llm_models_root() / model_dir)
fp8_model_dir = str(llm_models_root() / fp8_model_dir)
additional_kwargs = {}
if "Qwen3/Qwen3-30B-A3B" in model_dir:
additional_kwargs["moe_config"] = {
"backend": "DEEPGEMM",
}
num_hidden_layers = 1
hf_model = RefHFModelWithIPCHandles(fp8_model_dir, num_hidden_layers=num_hidden_layers)
tokenizer = AutoTokenizer.from_pretrained(fp8_model_dir)
kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
llm = LLM(
model=model_dir,
ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
tensor_parallel_size=2,
load_format="dummy",
pipeline_parallel_size=1,
kv_cache_config=kv_cache_config,
model_kwargs={
"num_hidden_layers": num_hidden_layers,
"quantization_config": {
"activation_scheme": "dynamic",
"fmt": "e4m3",
"quant_method": "fp8",
"weight_block_size": [128, 128],
},
},
**additional_kwargs,
)
# Generate texts from the prompts.
prompts_texts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
prompts = [tokenizer.encode(prompt) for prompt in prompts_texts]
del tokenizer
sampling_params = SamplingParams(temperature=0, return_generation_logits=True, max_tokens=1024)
ipc_handles = hf_model.get_weight_ipc_handles_serialized([0, 1])
llm._collective_rpc("update_weights", (ipc_handles,))
# Finalize the weight update
llm._collective_rpc("update_weights", (None,))
llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
compare_logits(llm_logits, ref_logits)
@skip_pre_blackwell
@pytest.mark.parametrize(
"model_dir, fp8_model_dir",
[
("Qwen3/Qwen3-8B", "Qwen3/Qwen3-8B-FP8"),
("Qwen3/Qwen3-30B-A3B", "Qwen3/Qwen3-30B-A3B-FP8"),
],
)
def test_llm_partial_update_weights_with_quant_config(model_dir, fp8_model_dir):
model_dir = str(llm_models_root() / model_dir)
fp8_model_dir = str(llm_models_root() / fp8_model_dir)
additional_kwargs = {}
if "Qwen3/Qwen3-30B-A3B" in model_dir:
additional_kwargs["moe_config"] = {
"backend": "DEEPGEMM",
}
num_hidden_layers = 1
hf_model = RefHFModelWithIPCHandles(fp8_model_dir, num_hidden_layers=num_hidden_layers)
tokenizer = AutoTokenizer.from_pretrained(fp8_model_dir)
kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
llm = LLM(
model=model_dir,
ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
tensor_parallel_size=2,
load_format="dummy",
pipeline_parallel_size=1,
kv_cache_config=kv_cache_config,
model_kwargs={
"num_hidden_layers": num_hidden_layers,
"quantization_config": {
"activation_scheme": "dynamic",
"fmt": "e4m3",
"quant_method": "fp8",
"weight_block_size": [128, 128],
},
},
**additional_kwargs,
)
# Generate texts from the prompts.
prompts_texts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
prompts = [tokenizer.encode(prompt) for prompt in prompts_texts]
del tokenizer
sampling_params = SamplingParams(temperature=0, return_generation_logits=True, max_tokens=1024)
def common_filter(filter_name: str) -> Callable[[str], bool]:
def filter_fn(name: str) -> bool:
return name.endswith(filter_name)
return filter_fn
# Generate filter_list from model weight keys by removing layer prefix
# e.g., "model.layers.41.input_layernorm.weight" -> "input_layernorm.weight"
layer_prefix_pattern = re.compile(r"^model\.layers\.\d+\.")
filter_set = set()
for name, _ in hf_model.all_weights[hf_model.device_id]:
suffix = layer_prefix_pattern.sub("", name)
filter_set.add(suffix)
filter_list = list(filter_set)
for filter_name in filter_list:
weight_filter = common_filter(filter_name=filter_name)
ipc_handles = hf_model.get_weight_ipc_handles_serialized(
[0, 1], weight_filter=weight_filter
)
llm._collective_rpc("update_weights", (ipc_handles,))
# Finalize the weight update
llm._collective_rpc("update_weights", (None,))
llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
compare_logits(llm_logits, ref_logits)


@ -9,10 +9,11 @@ from torch.multiprocessing.reductions import reduce_tensor
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from utils.llm_data import llm_models_root
from utils.torch_ref import RefHFModel
from utils.util import getSMVersion, skip_pre_hopper
from tensorrt_llm import LLM
from tensorrt_llm._torch.utils import get_device_uuid
from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig, MoeConfig, SamplingParams
class RefHFModelWithIPCHandles(RefHFModel):
@ -119,6 +120,7 @@ def run_generate(
return llm_logits, ref_logits
@skip_pre_hopper
@pytest.mark.parametrize(
"model_dir",
[
@ -136,6 +138,7 @@ def test_llm_update_weights(model_dir):
hf_model = RefHFModelWithIPCHandles(model_dir, num_hidden_layers=num_hidden_layers)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
moe_config = MoeConfig(backend="DEEPGEMM" if getSMVersion() >= 100 else "CUTLASS")
llm = LLM(
model=model_dir,
ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
@ -144,6 +147,7 @@ def test_llm_update_weights(model_dir):
pipeline_parallel_size=1,
kv_cache_config=kv_cache_config,
model_kwargs={"num_hidden_layers": num_hidden_layers},
moe_config=moe_config,
)
# Generate texts from the prompts.
@ -167,6 +171,7 @@ def test_llm_update_weights(model_dir):
compare_logits(llm_logits, ref_logits)
@skip_pre_hopper
@pytest.mark.parametrize(
"model_dir",
[
@ -184,7 +189,7 @@ def test_llm_partial_update_weights(model_dir):
hf_model = RefHFModelWithIPCHandles(model_dir, num_hidden_layers=num_hidden_layers)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
moe_config = MoeConfig(backend="DEEPGEMM" if getSMVersion() >= 100 else "CUTLASS")
llm = LLM(
model=model_dir,
ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
@ -193,6 +198,7 @@ def test_llm_partial_update_weights(model_dir):
pipeline_parallel_size=1,
kv_cache_config=kv_cache_config,
model_kwargs={"num_hidden_layers": num_hidden_layers},
moe_config=moe_config,
)
# Generate texts from the prompts.
@ -233,6 +239,7 @@ def test_llm_partial_update_weights(model_dir):
compare_logits(llm_logits, ref_logits)
@skip_pre_hopper
@pytest.mark.parametrize(
"model_dir, fp8_model_dir",
[
@ -247,6 +254,7 @@ def test_llm_update_weights_with_quant_config(model_dir, fp8_model_dir):
hf_model = RefHFModelWithIPCHandles(fp8_model_dir, num_hidden_layers=num_hidden_layers)
tokenizer = AutoTokenizer.from_pretrained(fp8_model_dir)
kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
moe_config = MoeConfig(backend="DEEPGEMM" if getSMVersion() >= 100 else "CUTLASS")
llm = LLM(
model=model_dir,
ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
@ -263,6 +271,7 @@ def test_llm_update_weights_with_quant_config(model_dir, fp8_model_dir):
"weight_block_size": [128, 128],
},
},
moe_config=moe_config,
)
# Generate texts from the prompts.