From c3cdc9321113b245172c837f6983657b501f2a3d Mon Sep 17 00:00:00 2001 From: shuyixiong <219646547+shuyixiong@users.noreply.github.com> Date: Tue, 10 Feb 2026 14:12:49 +0800 Subject: [PATCH] [TRTLLM-9771][feat] Make update_weights compatible with CUDA Graph (#11267) Signed-off-by: Shuyi Xiong <219646547+shuyixiong@users.noreply.github.com> --- .../_torch/modules/fused_moe/quantization.py | 127 +++++++++------ tensorrt_llm/_torch/modules/linear.py | 7 +- tensorrt_llm/_torch/utils.py | 36 +++- .../test_lists/test-db/l0_dgx_b200.yml | 1 + .../test_llm_update_weights_multi_gpu.py | 154 ++++++++++++++++++ .../single_gpu/test_llm_update_weights.py | 13 +- 6 files changed, 279 insertions(+), 59 deletions(-) create mode 100644 tests/unittest/_torch/ray_orchestrator/multi_gpu/test_llm_update_weights_multi_gpu.py diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index 2e9b43550b..21fd1940a3 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -1,7 +1,7 @@ import inspect import math from abc import ABC, abstractmethod -from typing import Dict, List, NamedTuple, Optional, Union +from typing import Dict, List, NamedTuple, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -543,10 +543,9 @@ class FusedMoEMethodBase(ABC): def pre_reload_weights(self, module: torch.nn.Module): for param_name, metadata in module.rebuild_tensor_metadata.items(): - logger.warning( - f"Pre-reloading weight '{param_name}' requires tensor re-creation, which will invalidate existing CUDA graphs." - ) - param = torch.nn.Parameter(torch.empty_like(metadata, + # Extract meta tensor from metadata dict + meta_tensor = metadata['meta'] + param = torch.nn.Parameter(torch.empty_like(meta_tensor, device="cuda"), requires_grad=False) module.register_parameter(param_name, param) @@ -999,6 +998,23 @@ class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase): super().post_load_weights(module) +def resmooth_and_transform_fp8_scale( + weight: torch.Tensor, + weight_scale: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Resmooth weight/scale to FP8 E8M0 and transform scale to required layout for MoE.""" + resmoothed_weight, resmoothed_scale = resmooth_to_fp8_e8m0( + weight, weight_scale) + transformed_scale = transform_sf_into_required_layout( + resmoothed_scale, + mn=weight.shape[1], + k=weight.shape[2], + recipe=(1, 128, 128), + num_groups=weight.shape[0], + is_sfa=False) + return resmoothed_weight, transformed_scale + + class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm( DeepSeekFP8BlockScalesFusedMoEMethod): @@ -1007,53 +1023,70 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm( weights: List[Dict], weight_loading_mode: MoEWeightLoadingMode, allow_partial_loading: bool = False): - if is_sm_100f(): - expert_ids = set(module.initial_local_expert_ids) - if self.need_load_shared_weights(module): - expert_ids.update( - module.layer_load_balancer.get_load_expert_ids()) - for name in list(weights.keys()): - if name.endswith("weight_scale_inv"): - if int(name.split(".")[0]) not in expert_ids: - continue - weight_name = name.replace("weight_scale_inv", "weight") - logger.debug(f"Resmoothing {weight_name}") - weight = weights[weight_name][:] - scale = weights[name][:] - weights[weight_name], weights[name] = resmooth_to_fp8_e8m0( - weight, scale) super().load_weights(module, weights, weight_loading_mode, allow_partial_loading) def post_load_weights(self, module: 
torch.nn.Module): - super().post_load_weights(module) if is_sm_100f(): - transfromed_w3_w1_scale = transform_sf_into_required_layout( - module.quant_scales[0], - mn=module.w3_w1_weight.shape[1], - k=module.w3_w1_weight.shape[2], - recipe=(1, 128, 128), - num_groups=module.w3_w1_weight.shape[0], - is_sfa=False) - transformed_w3_w1_weight_scaling_factor = nn.Parameter( - transfromed_w3_w1_scale, requires_grad=False) - replace_parameter_and_save_metadata( - module, "w3_w1_weight_scaling_factor", - transformed_w3_w1_weight_scaling_factor, - module.rebuild_tensor_metadata) - transfromed_w2_scale = transform_sf_into_required_layout( - module.quant_scales[1], - mn=module.w2_weight.shape[1], - k=module.w2_weight.shape[2], - recipe=(1, 128, 128), - num_groups=module.w3_w1_weight.shape[0], - is_sfa=False) - transformed_w2_weight_scaling_factor = nn.Parameter( - transfromed_w2_scale, requires_grad=False) - replace_parameter_and_save_metadata( - module, "w2_weight_scaling_factor", - transformed_w2_weight_scaling_factor, - module.rebuild_tensor_metadata) + # Resmooth shared experts before registering shared weights + if self.need_load_shared_weights(module): + local_shared_load_expert_ids = module.layer_load_balancer.get_load_expert_ids( + ) + if getattr(module, 'local_shared_w3_w1_tensors', + None) is not None: + num_shared_experts = len(local_shared_load_expert_ids) + logger.debug( + f"Batch resmoothing {num_shared_experts} shared experts" + ) + + local_shared_w3_w1_tensors = getattr( + module, 'local_shared_w3_w1_tensors') + local_shared_w3_w1_scale_tensors = getattr( + module, 'local_shared_w3_w1_scale_tensors') + local_shared_w2_tensors = getattr( + module, 'local_shared_w2_tensors') + local_shared_w2_scale_tensors = getattr( + module, 'local_shared_w2_scale_tensors') + + resmoothed_shared_w3_w1_weight, transformed_shared_w3_w1_scale = resmooth_and_transform_fp8_scale( + local_shared_w3_w1_tensors, + local_shared_w3_w1_scale_tensors) + setattr(module, 'local_shared_w3_w1_tensors', + resmoothed_shared_w3_w1_weight.cpu()) + setattr(module, 'local_shared_w3_w1_scale_tensors', + transformed_shared_w3_w1_scale.cpu()) + + resmoothed_shared_w2_weight, transformed_shared_w2_scale = resmooth_and_transform_fp8_scale( + local_shared_w2_tensors, local_shared_w2_scale_tensors) + setattr(module, 'local_shared_w2_tensors', + resmoothed_shared_w2_weight.cpu()) + setattr(module, 'local_shared_w2_scale_tensors', + transformed_shared_w2_scale.cpu()) + + # Call super() after resmooth shared experts (local_shared tensors will be deleted in super().post_load_weights()) + super().post_load_weights(module) + + if is_sm_100f(): + logger.debug("Resmoothing FP8 weights in post_load_weights") + resmoothed_w3_w1_weight, transformed_w3_w1_scale = resmooth_and_transform_fp8_scale( + module.w3_w1_weight, module.w3_w1_weight_scaling_factor) + replace_parameter_and_save_metadata(module, "w3_w1_weight", + resmoothed_w3_w1_weight, + module.rebuild_tensor_metadata) + replace_parameter_and_save_metadata(module, + "w3_w1_weight_scaling_factor", + transformed_w3_w1_scale, + module.rebuild_tensor_metadata) + + resmoothed_w2_weight, transformed_w2_scale = resmooth_and_transform_fp8_scale( + module.w2_weight, module.w2_weight_scaling_factor) + replace_parameter_and_save_metadata(module, "w2_weight", + resmoothed_w2_weight, + module.rebuild_tensor_metadata) + replace_parameter_and_save_metadata(module, + "w2_weight_scaling_factor", + transformed_w2_scale, + module.rebuild_tensor_metadata) self.setup_quant_scales(module) diff --git 
a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 1bf48e1821..65811569ca 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -518,10 +518,9 @@ class UnquantizedLinearMethod(LinearMethodBase): def pre_reload_weights(self, module: Linear): for param_name, metadata in module.rebuild_tensor_metadata.items(): - logger.warning( - f"Pre-reloading weight '{param_name}' requires tensor re-creation, which will invalidate existing CUDA graphs." - ) - param = Parameter(torch.empty_like(metadata, device="cuda"), + # Extract meta tensor from metadata dict + meta_tensor = metadata['meta'] + param = Parameter(torch.empty_like(meta_tensor, device="cuda"), requires_grad=False) module.register_parameter(param_name, param) diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py index d441cbd1da..876e4077a0 100644 --- a/tensorrt_llm/_torch/utils.py +++ b/tensorrt_llm/_torch/utils.py @@ -426,13 +426,37 @@ def maybe_compiled_cat(tensors, dim): return torch.cat(tensors, dim) -def replace_parameter_and_save_metadata(module: torch.nn.Module, - param_name: str, - new_param: torch.nn.Parameter, - metadata_dict: Dict): +def replace_parameter_and_save_metadata( + module: torch.nn.Module, param_name: str, + new_param: torch.nn.Parameter | torch.Tensor, metadata_dict: Dict): """ Replace a parameter in a module and save the metadata of the original parameter. + On first call: saves original param's meta tensor and new param's tensor, then replaces. + On subsequent calls: copies new_param data into the saved tensor, then registers it. """ + saved_param = None if param_name not in metadata_dict: - metadata_dict[param_name] = getattr(module, param_name).to("meta") - module.register_parameter(param_name, new_param) + # First time: save original meta tensor and the new param tensor reference + original_meta = getattr(module, param_name).to("meta") + # Convert new_param to Parameter if it's a Tensor, otherwise use directly + if isinstance(new_param, torch.nn.Parameter): + saved_param = new_param + elif isinstance(new_param, torch.Tensor): + saved_param = torch.nn.Parameter(new_param, requires_grad=False) + else: + raise ValueError(f"Invalid type {type(new_param)} for new_param") + metadata_dict[param_name] = { + 'meta': original_meta, + 'param': saved_param + } + else: + # Subsequent calls: copy new_param into the saved tensor + saved_param = metadata_dict[param_name]['param'] + if isinstance(new_param, torch.nn.Parameter): + saved_param.data.copy_(new_param.data) + elif isinstance(new_param, torch.Tensor): + saved_param.data.copy_(new_param) + else: + raise ValueError(f"Invalid type {type(new_param)} for new_param") + + module.register_parameter(param_name, saved_param) diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index 2640b2eaa3..64605d88a2 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -44,6 +44,7 @@ l0_dgx_b200: orchestrator: ray tests: - unittest/llmapi/test_llm_multi_gpu_pytorch.py -m "gpu4" + - unittest/_torch/ray_orchestrator/multi_gpu/test_llm_update_weights_multi_gpu.py - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv=False-attn_backend=TRTLLM-torch_compile=False] - 
disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_genpp2[TinyLlama-1.1B-Chat-v1.0] diff --git a/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_llm_update_weights_multi_gpu.py b/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_llm_update_weights_multi_gpu.py new file mode 100644 index 0000000000..1a59fb13ab --- /dev/null +++ b/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_llm_update_weights_multi_gpu.py @@ -0,0 +1,154 @@ +import re +from typing import Callable + +import pytest +from _torch.ray_orchestrator.single_gpu.test_llm_update_weights import ( + RefHFModelWithIPCHandles, + compare_logits, + run_generate, +) +from transformers import AutoTokenizer +from utils.llm_data import llm_models_root +from utils.util import skip_pre_blackwell + +from tensorrt_llm import LLM +from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams + + +@skip_pre_blackwell +@pytest.mark.parametrize( + "model_dir, fp8_model_dir", + [ + ("Qwen3/Qwen3-8B", "Qwen3/Qwen3-8B-FP8"), + ("Qwen3/Qwen3-30B-A3B", "Qwen3/Qwen3-30B-A3B-FP8"), + ], +) +def test_llm_update_weights_with_quant_config(model_dir, fp8_model_dir): + model_dir = str(llm_models_root() / model_dir) + fp8_model_dir = str(llm_models_root() / fp8_model_dir) + additional_kwargs = {} + if "Qwen3/Qwen3-30B-A3B" in model_dir: + additional_kwargs["moe_config"] = { + "backend": "DEEPGEMM", + } + num_hidden_layers = 1 + hf_model = RefHFModelWithIPCHandles(fp8_model_dir, num_hidden_layers=num_hidden_layers) + tokenizer = AutoTokenizer.from_pretrained(fp8_model_dir) + kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1) + llm = LLM( + model=model_dir, + ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension", + tensor_parallel_size=2, + load_format="dummy", + pipeline_parallel_size=1, + kv_cache_config=kv_cache_config, + model_kwargs={ + "num_hidden_layers": num_hidden_layers, + "quantization_config": { + "activation_scheme": "dynamic", + "fmt": "e4m3", + "quant_method": "fp8", + "weight_block_size": [128, 128], + }, + }, + **additional_kwargs, + ) + + # Generate texts from the prompts. 
+ prompts_texts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + prompts = [tokenizer.encode(prompt) for prompt in prompts_texts] + del tokenizer + sampling_params = SamplingParams(temperature=0, return_generation_logits=True, max_tokens=1024) + + ipc_handles = hf_model.get_weight_ipc_handles_serialized([0, 1]) + + llm._collective_rpc("update_weights", (ipc_handles,)) + # Finalize the update weights + llm._collective_rpc("update_weights", (None,)) + + llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params) + compare_logits(llm_logits, ref_logits) + + +@skip_pre_blackwell +@pytest.mark.parametrize( + "model_dir, fp8_model_dir", + [ + ("Qwen3/Qwen3-8B", "Qwen3/Qwen3-8B-FP8"), + ("Qwen3/Qwen3-30B-A3B", "Qwen3/Qwen3-30B-A3B-FP8"), + ], +) +def test_llm_partial_update_weights_with_quant_config(model_dir, fp8_model_dir): + model_dir = str(llm_models_root() / model_dir) + fp8_model_dir = str(llm_models_root() / fp8_model_dir) + additional_kwargs = {} + if "Qwen3/Qwen3-30B-A3B" in model_dir: + additional_kwargs["moe_config"] = { + "backend": "DEEPGEMM", + } + num_hidden_layers = 1 + hf_model = RefHFModelWithIPCHandles(fp8_model_dir, num_hidden_layers=num_hidden_layers) + tokenizer = AutoTokenizer.from_pretrained(fp8_model_dir) + kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1) + llm = LLM( + model=model_dir, + ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension", + tensor_parallel_size=2, + load_format="dummy", + pipeline_parallel_size=1, + kv_cache_config=kv_cache_config, + model_kwargs={ + "num_hidden_layers": num_hidden_layers, + "quantization_config": { + "activation_scheme": "dynamic", + "fmt": "e4m3", + "quant_method": "fp8", + "weight_block_size": [128, 128], + }, + }, + **additional_kwargs, + ) + + # Generate texts from the prompts. 
+ prompts_texts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + prompts = [tokenizer.encode(prompt) for prompt in prompts_texts] + del tokenizer + + sampling_params = SamplingParams(temperature=0, return_generation_logits=True, max_tokens=1024) + + def common_filter(filter_name: str) -> Callable[[str], bool]: + def filter_fn(name: str) -> bool: + return name.endswith(filter_name) + + return filter_fn + + # Generate filter_list from model weight keys by removing layer prefix + # e.g., "model.layers.41.input_layernorm.weight" -> "input_layernorm.weight" + layer_prefix_pattern = re.compile(r"^model\.layers\.\d+\.") + filter_set = set() + for name, _ in hf_model.all_weights[hf_model.device_id]: + suffix = layer_prefix_pattern.sub("", name) + filter_set.add(suffix) + filter_list = list(filter_set) + + for filter_name in filter_list: + weight_filter = common_filter(filter_name=filter_name) + ipc_handles = hf_model.get_weight_ipc_handles_serialized( + [0, 1], weight_filter=weight_filter + ) + llm._collective_rpc("update_weights", (ipc_handles,)) + # Finalize the update weights + llm._collective_rpc("update_weights", (None,)) + + llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params) + compare_logits(llm_logits, ref_logits) diff --git a/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py b/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py index 073067907a..11ede427e2 100644 --- a/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py +++ b/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py @@ -9,10 +9,11 @@ from torch.multiprocessing.reductions import reduce_tensor from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from utils.llm_data import llm_models_root from utils.torch_ref import RefHFModel +from utils.util import getSMVersion, skip_pre_hopper from tensorrt_llm import LLM from tensorrt_llm._torch.utils import get_device_uuid -from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams +from tensorrt_llm.llmapi import KvCacheConfig, MoeConfig, SamplingParams class RefHFModelWithIPCHandles(RefHFModel): @@ -119,6 +120,7 @@ def run_generate( return llm_logits, ref_logits +@skip_pre_hopper @pytest.mark.parametrize( "model_dir", [ @@ -136,6 +138,7 @@ def test_llm_update_weights(model_dir): hf_model = RefHFModelWithIPCHandles(model_dir, num_hidden_layers=num_hidden_layers) tokenizer = AutoTokenizer.from_pretrained(model_dir) kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1) + moe_config = MoeConfig(backend="DEEPGEMM" if getSMVersion() >= 100 else "CUTLASS") llm = LLM( model=model_dir, ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension", @@ -144,6 +147,7 @@ def test_llm_update_weights(model_dir): pipeline_parallel_size=1, kv_cache_config=kv_cache_config, model_kwargs={"num_hidden_layers": num_hidden_layers}, + moe_config=moe_config, ) # Generate texts from the prompts. 
@@ -167,6 +171,7 @@ def test_llm_update_weights(model_dir): compare_logits(llm_logits, ref_logits) +@skip_pre_hopper @pytest.mark.parametrize( "model_dir", [ @@ -184,7 +189,7 @@ def test_llm_partial_update_weights(model_dir): hf_model = RefHFModelWithIPCHandles(model_dir, num_hidden_layers=num_hidden_layers) tokenizer = AutoTokenizer.from_pretrained(model_dir) kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1) - + moe_config = MoeConfig(backend="DEEPGEMM" if getSMVersion() >= 100 else "CUTLASS") llm = LLM( model=model_dir, ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension", @@ -193,6 +198,7 @@ def test_llm_partial_update_weights(model_dir): pipeline_parallel_size=1, kv_cache_config=kv_cache_config, model_kwargs={"num_hidden_layers": num_hidden_layers}, + moe_config=moe_config, ) # Generate texts from the prompts. @@ -233,6 +239,7 @@ def test_llm_partial_update_weights(model_dir): compare_logits(llm_logits, ref_logits) +@skip_pre_hopper @pytest.mark.parametrize( "model_dir, fp8_model_dir", [ @@ -247,6 +254,7 @@ def test_llm_update_weights_with_quant_config(model_dir, fp8_model_dir): hf_model = RefHFModelWithIPCHandles(fp8_model_dir, num_hidden_layers=num_hidden_layers) tokenizer = AutoTokenizer.from_pretrained(fp8_model_dir) kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1) + moe_config = MoeConfig(backend="DEEPGEMM" if getSMVersion() >= 100 else "CUTLASS") llm = LLM( model=model_dir, ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension", @@ -263,6 +271,7 @@ def test_llm_update_weights_with_quant_config(model_dir, fp8_model_dir): "weight_block_size": [128, 128], }, }, + moe_config=moe_config, ) # Generate texts from the prompts.
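
The core mechanism this patch relies on is that replace_parameter_and_save_metadata reuses the same tensor storage on repeated weight updates (copying new data in place) rather than registering a fresh tensor each time, so device pointers captured by CUDA graphs stay valid, and pre_reload_weights can rebuild the parameter from the saved meta tensor when a true re-creation is needed. Below is a minimal sketch of that storage-reuse pattern in plain PyTorch; the helper name replace_param_keeping_storage and the toy nn.Linear module are illustrative only and not part of TensorRT-LLM:

    import torch

    def replace_param_keeping_storage(module: torch.nn.Module, name: str,
                                      new_value: torch.Tensor, cache: dict):
        # First call: remember the original shape/dtype (as a meta tensor) and
        # adopt new_value's storage as the parameter going forward.
        if name not in cache:
            cache[name] = {
                "meta": getattr(module, name).to("meta"),
                "param": torch.nn.Parameter(new_value.detach(), requires_grad=False),
            }
        else:
            # Later calls: copy in place so the data pointer never changes.
            cache[name]["param"].data.copy_(new_value)
        module.register_parameter(name, cache[name]["param"])

    # The parameter's storage stays stable across updates, which is what keeps
    # a previously captured CUDA graph usable after repeated weight updates.
    lin = torch.nn.Linear(4, 4, bias=False)
    cache = {}
    replace_param_keeping_storage(lin, "weight", torch.randn(4, 4), cache)
    ptr_before = lin.weight.data_ptr()
    replace_param_keeping_storage(lin, "weight", torch.randn(4, 4), cache)
    assert lin.weight.data_ptr() == ptr_before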