Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-16 15:55:08 +08:00)
[TRTLLM-9771][feat] Make update_weights compatible with CUDA Graph (#11267)
Signed-off-by: Shuyi Xiong <219646547+shuyixiong@users.noreply.github.com>
parent 8b2dc57823
commit c3cdc93211
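The point of this change is keeping weight updates CUDA-Graph-safe: a captured graph replays fixed device addresses, so updated weights must be copied into the tensors the graph was captured with instead of being registered as freshly allocated parameters. The snippet below is a minimal, stand-alone sketch of that distinction; the toy Linear layer and all names are invented for illustration and are not code from this diff.

import torch

# A toy layer and a CUDA graph capture; names here are illustrative only.
layer = torch.nn.Linear(16, 16, bias=False).cuda()
x = torch.randn(4, 16, device="cuda")

# Warm up on a side stream, then capture a graph that reads layer.weight at
# its current device address.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for _ in range(3):
        layer(x)
torch.cuda.current_stream().wait_stream(side)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    y = layer(x)

new_weight = torch.randn_like(layer.weight)

# Graph-safe update: same storage, new values; replay now uses the new weights.
layer.weight.data.copy_(new_weight)
graph.replay()

# Not graph-safe: a fresh Parameter lives at a new address, so the captured
# graph would keep reading the stale buffer. This is the situation the
# pre_reload_weights warning in the diff below refers to, where re-created
# tensors invalidate existing CUDA graphs.
# layer.weight = torch.nn.Parameter(new_weight, requires_grad=False)

In the diff, replace_parameter_and_save_metadata applies the same idea to the MoE weights and scaling factors: the first load registers the transformed tensor and records its metadata, and later reloads copy new values into that same storage.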
@@ -1,7 +1,7 @@
import inspect
import math
from abc import ABC, abstractmethod
from typing import Dict, List, NamedTuple, Optional, Union
from typing import Dict, List, NamedTuple, Optional, Tuple, Union

import torch
import torch.nn.functional as F
@@ -543,10 +543,9 @@ class FusedMoEMethodBase(ABC):

    def pre_reload_weights(self, module: torch.nn.Module):
        for param_name, metadata in module.rebuild_tensor_metadata.items():
            logger.warning(
                f"Pre-reloading weight '{param_name}' requires tensor re-creation, which will invalidate existing CUDA graphs."
            )
            param = torch.nn.Parameter(torch.empty_like(metadata,
            # Extract meta tensor from metadata dict
            meta_tensor = metadata['meta']
            param = torch.nn.Parameter(torch.empty_like(meta_tensor,
                                                        device="cuda"),
                                       requires_grad=False)
            module.register_parameter(param_name, param)
@@ -999,6 +998,23 @@ class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
        super().post_load_weights(module)


def resmooth_and_transform_fp8_scale(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Resmooth weight/scale to FP8 E8M0 and transform scale to required layout for MoE."""
    resmoothed_weight, resmoothed_scale = resmooth_to_fp8_e8m0(
        weight, weight_scale)
    transformed_scale = transform_sf_into_required_layout(
        resmoothed_scale,
        mn=weight.shape[1],
        k=weight.shape[2],
        recipe=(1, 128, 128),
        num_groups=weight.shape[0],
        is_sfa=False)
    return resmoothed_weight, transformed_scale


class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
        DeepSeekFP8BlockScalesFusedMoEMethod):

@@ -1007,53 +1023,70 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
                     weights: List[Dict],
                     weight_loading_mode: MoEWeightLoadingMode,
                     allow_partial_loading: bool = False):
        if is_sm_100f():
            expert_ids = set(module.initial_local_expert_ids)
            if self.need_load_shared_weights(module):
                expert_ids.update(
                    module.layer_load_balancer.get_load_expert_ids())
            for name in list(weights.keys()):
                if name.endswith("weight_scale_inv"):
                    if int(name.split(".")[0]) not in expert_ids:
                        continue
                    weight_name = name.replace("weight_scale_inv", "weight")
                    logger.debug(f"Resmoothing {weight_name}")
                    weight = weights[weight_name][:]
                    scale = weights[name][:]
                    weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
                        weight, scale)
        super().load_weights(module, weights, weight_loading_mode,
                             allow_partial_loading)

    def post_load_weights(self, module: torch.nn.Module):
        super().post_load_weights(module)
        if is_sm_100f():
            transformed_w3_w1_scale = transform_sf_into_required_layout(
                module.quant_scales[0],
                mn=module.w3_w1_weight.shape[1],
                k=module.w3_w1_weight.shape[2],
                recipe=(1, 128, 128),
                num_groups=module.w3_w1_weight.shape[0],
                is_sfa=False)
            transformed_w3_w1_weight_scaling_factor = nn.Parameter(
                transformed_w3_w1_scale, requires_grad=False)
            replace_parameter_and_save_metadata(
                module, "w3_w1_weight_scaling_factor",
                transformed_w3_w1_weight_scaling_factor,
                module.rebuild_tensor_metadata)
            transformed_w2_scale = transform_sf_into_required_layout(
                module.quant_scales[1],
                mn=module.w2_weight.shape[1],
                k=module.w2_weight.shape[2],
                recipe=(1, 128, 128),
                num_groups=module.w3_w1_weight.shape[0],
                is_sfa=False)
            transformed_w2_weight_scaling_factor = nn.Parameter(
                transformed_w2_scale, requires_grad=False)
            replace_parameter_and_save_metadata(
                module, "w2_weight_scaling_factor",
                transformed_w2_weight_scaling_factor,
                module.rebuild_tensor_metadata)
        # Resmooth shared experts before registering shared weights
        if self.need_load_shared_weights(module):
            local_shared_load_expert_ids = module.layer_load_balancer.get_load_expert_ids(
            )
            if getattr(module, 'local_shared_w3_w1_tensors',
                       None) is not None:
                num_shared_experts = len(local_shared_load_expert_ids)
                logger.debug(
                    f"Batch resmoothing {num_shared_experts} shared experts"
                )

                local_shared_w3_w1_tensors = getattr(
                    module, 'local_shared_w3_w1_tensors')
                local_shared_w3_w1_scale_tensors = getattr(
                    module, 'local_shared_w3_w1_scale_tensors')
                local_shared_w2_tensors = getattr(
                    module, 'local_shared_w2_tensors')
                local_shared_w2_scale_tensors = getattr(
                    module, 'local_shared_w2_scale_tensors')

                resmoothed_shared_w3_w1_weight, transformed_shared_w3_w1_scale = resmooth_and_transform_fp8_scale(
                    local_shared_w3_w1_tensors,
                    local_shared_w3_w1_scale_tensors)
                setattr(module, 'local_shared_w3_w1_tensors',
                        resmoothed_shared_w3_w1_weight.cpu())
                setattr(module, 'local_shared_w3_w1_scale_tensors',
                        transformed_shared_w3_w1_scale.cpu())

                resmoothed_shared_w2_weight, transformed_shared_w2_scale = resmooth_and_transform_fp8_scale(
                    local_shared_w2_tensors, local_shared_w2_scale_tensors)
                setattr(module, 'local_shared_w2_tensors',
                        resmoothed_shared_w2_weight.cpu())
                setattr(module, 'local_shared_w2_scale_tensors',
                        transformed_shared_w2_scale.cpu())

        # Call super() after resmooth shared experts (local_shared tensors will be deleted in super().post_load_weights())
        super().post_load_weights(module)

        if is_sm_100f():
            logger.debug("Resmoothing FP8 weights in post_load_weights")
            resmoothed_w3_w1_weight, transformed_w3_w1_scale = resmooth_and_transform_fp8_scale(
                module.w3_w1_weight, module.w3_w1_weight_scaling_factor)
            replace_parameter_and_save_metadata(module, "w3_w1_weight",
                                                resmoothed_w3_w1_weight,
                                                module.rebuild_tensor_metadata)
            replace_parameter_and_save_metadata(module,
                                                "w3_w1_weight_scaling_factor",
                                                transformed_w3_w1_scale,
                                                module.rebuild_tensor_metadata)

            resmoothed_w2_weight, transformed_w2_scale = resmooth_and_transform_fp8_scale(
                module.w2_weight, module.w2_weight_scaling_factor)
            replace_parameter_and_save_metadata(module, "w2_weight",
                                                resmoothed_w2_weight,
                                                module.rebuild_tensor_metadata)
            replace_parameter_and_save_metadata(module,
                                                "w2_weight_scaling_factor",
                                                transformed_w2_scale,
                                                module.rebuild_tensor_metadata)
            self.setup_quant_scales(module)

@@ -518,10 +518,9 @@ class UnquantizedLinearMethod(LinearMethodBase):

    def pre_reload_weights(self, module: Linear):
        for param_name, metadata in module.rebuild_tensor_metadata.items():
            logger.warning(
                f"Pre-reloading weight '{param_name}' requires tensor re-creation, which will invalidate existing CUDA graphs."
            )
            param = Parameter(torch.empty_like(metadata, device="cuda"),
            # Extract meta tensor from metadata dict
            meta_tensor = metadata['meta']
            param = Parameter(torch.empty_like(meta_tensor, device="cuda"),
                              requires_grad=False)
            module.register_parameter(param_name, param)

@@ -426,13 +426,37 @@ def maybe_compiled_cat(tensors, dim):
    return torch.cat(tensors, dim)


def replace_parameter_and_save_metadata(module: torch.nn.Module,
                                        param_name: str,
                                        new_param: torch.nn.Parameter,
                                        metadata_dict: Dict):
def replace_parameter_and_save_metadata(
        module: torch.nn.Module, param_name: str,
        new_param: torch.nn.Parameter | torch.Tensor, metadata_dict: Dict):
    """
    Replace a parameter in a module and save the metadata of the original parameter.
    On first call: saves original param's meta tensor and new param's tensor, then replaces.
    On subsequent calls: copies new_param data into the saved tensor, then registers it.
    """
    saved_param = None
    if param_name not in metadata_dict:
        metadata_dict[param_name] = getattr(module, param_name).to("meta")
        module.register_parameter(param_name, new_param)
        # First time: save original meta tensor and the new param tensor reference
        original_meta = getattr(module, param_name).to("meta")
        # Convert new_param to Parameter if it's a Tensor, otherwise use directly
        if isinstance(new_param, torch.nn.Parameter):
            saved_param = new_param
        elif isinstance(new_param, torch.Tensor):
            saved_param = torch.nn.Parameter(new_param, requires_grad=False)
        else:
            raise ValueError(f"Invalid type {type(new_param)} for new_param")
        metadata_dict[param_name] = {
            'meta': original_meta,
            'param': saved_param
        }
    else:
        # Subsequent calls: copy new_param into the saved tensor
        saved_param = metadata_dict[param_name]['param']
        if isinstance(new_param, torch.nn.Parameter):
            saved_param.data.copy_(new_param.data)
        elif isinstance(new_param, torch.Tensor):
            saved_param.data.copy_(new_param)
        else:
            raise ValueError(f"Invalid type {type(new_param)} for new_param")

    module.register_parameter(param_name, saved_param)
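A hypothetical usage sketch of the function above, with a dummy module, CPU tensors, an illustrative parameter name, and a local dict standing in for module.rebuild_tensor_metadata: the first call registers the new tensor and records metadata, the second copies into the already-registered storage, so the parameter's device address, and any CUDA graph that captured it, stays valid.

import torch

module = torch.nn.Module()
module.register_parameter(
    "w2_weight_scaling_factor",
    torch.nn.Parameter(torch.zeros(8), requires_grad=False))
metadata = {}

# First call: the transformed tensor becomes the registered parameter and the
# original parameter's meta tensor plus the new tensor are recorded.
first = torch.randn(8)
replace_parameter_and_save_metadata(module, "w2_weight_scaling_factor", first,
                                    metadata)
addr = module.w2_weight_scaling_factor.data_ptr()

# Reload: values are copied into the saved tensor; the storage address that a
# captured CUDA graph would replay against is unchanged.
second = torch.randn(8)
replace_parameter_and_save_metadata(module, "w2_weight_scaling_factor", second,
                                    metadata)
assert module.w2_weight_scaling_factor.data_ptr() == addr
assert torch.equal(module.w2_weight_scaling_factor, second)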
@@ -44,6 +44,7 @@ l0_dgx_b200:
      orchestrator: ray
  tests:
  - unittest/llmapi/test_llm_multi_gpu_pytorch.py -m "gpu4"
  - unittest/_torch/ray_orchestrator/multi_gpu/test_llm_update_weights_multi_gpu.py
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
  - disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_genpp2[TinyLlama-1.1B-Chat-v1.0]

@@ -0,0 +1,154 @@
import re
from typing import Callable

import pytest
from _torch.ray_orchestrator.single_gpu.test_llm_update_weights import (
    RefHFModelWithIPCHandles,
    compare_logits,
    run_generate,
)
from transformers import AutoTokenizer
from utils.llm_data import llm_models_root
from utils.util import skip_pre_blackwell

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams


@skip_pre_blackwell
@pytest.mark.parametrize(
    "model_dir, fp8_model_dir",
    [
        ("Qwen3/Qwen3-8B", "Qwen3/Qwen3-8B-FP8"),
        ("Qwen3/Qwen3-30B-A3B", "Qwen3/Qwen3-30B-A3B-FP8"),
    ],
)
def test_llm_update_weights_with_quant_config(model_dir, fp8_model_dir):
    model_dir = str(llm_models_root() / model_dir)
    fp8_model_dir = str(llm_models_root() / fp8_model_dir)
    additional_kwargs = {}
    if "Qwen3/Qwen3-30B-A3B" in model_dir:
        additional_kwargs["moe_config"] = {
            "backend": "DEEPGEMM",
        }
    num_hidden_layers = 1
    hf_model = RefHFModelWithIPCHandles(fp8_model_dir, num_hidden_layers=num_hidden_layers)
    tokenizer = AutoTokenizer.from_pretrained(fp8_model_dir)
    kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
    llm = LLM(
        model=model_dir,
        ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
        tensor_parallel_size=2,
        load_format="dummy",
        pipeline_parallel_size=1,
        kv_cache_config=kv_cache_config,
        model_kwargs={
            "num_hidden_layers": num_hidden_layers,
            "quantization_config": {
                "activation_scheme": "dynamic",
                "fmt": "e4m3",
                "quant_method": "fp8",
                "weight_block_size": [128, 128],
            },
        },
        **additional_kwargs,
    )

    # Generate texts from the prompts.
    prompts_texts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    prompts = [tokenizer.encode(prompt) for prompt in prompts_texts]
    del tokenizer
    sampling_params = SamplingParams(temperature=0, return_generation_logits=True, max_tokens=1024)

    ipc_handles = hf_model.get_weight_ipc_handles_serialized([0, 1])

    llm._collective_rpc("update_weights", (ipc_handles,))
    # Finalize the update weights
    llm._collective_rpc("update_weights", (None,))

    llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
    compare_logits(llm_logits, ref_logits)


@skip_pre_blackwell
@pytest.mark.parametrize(
    "model_dir, fp8_model_dir",
    [
        ("Qwen3/Qwen3-8B", "Qwen3/Qwen3-8B-FP8"),
        ("Qwen3/Qwen3-30B-A3B", "Qwen3/Qwen3-30B-A3B-FP8"),
    ],
)
def test_llm_partial_update_weights_with_quant_config(model_dir, fp8_model_dir):
    model_dir = str(llm_models_root() / model_dir)
    fp8_model_dir = str(llm_models_root() / fp8_model_dir)
    additional_kwargs = {}
    if "Qwen3/Qwen3-30B-A3B" in model_dir:
        additional_kwargs["moe_config"] = {
            "backend": "DEEPGEMM",
        }
    num_hidden_layers = 1
    hf_model = RefHFModelWithIPCHandles(fp8_model_dir, num_hidden_layers=num_hidden_layers)
    tokenizer = AutoTokenizer.from_pretrained(fp8_model_dir)
    kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
    llm = LLM(
        model=model_dir,
        ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
        tensor_parallel_size=2,
        load_format="dummy",
        pipeline_parallel_size=1,
        kv_cache_config=kv_cache_config,
        model_kwargs={
            "num_hidden_layers": num_hidden_layers,
            "quantization_config": {
                "activation_scheme": "dynamic",
                "fmt": "e4m3",
                "quant_method": "fp8",
                "weight_block_size": [128, 128],
            },
        },
        **additional_kwargs,
    )

    # Generate texts from the prompts.
    prompts_texts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    prompts = [tokenizer.encode(prompt) for prompt in prompts_texts]
    del tokenizer

    sampling_params = SamplingParams(temperature=0, return_generation_logits=True, max_tokens=1024)

    def common_filter(filter_name: str) -> Callable[[str], bool]:
        def filter_fn(name: str) -> bool:
            return name.endswith(filter_name)

        return filter_fn

    # Generate filter_list from model weight keys by removing layer prefix
    # e.g., "model.layers.41.input_layernorm.weight" -> "input_layernorm.weight"
    layer_prefix_pattern = re.compile(r"^model\.layers\.\d+\.")
    filter_set = set()
    for name, _ in hf_model.all_weights[hf_model.device_id]:
        suffix = layer_prefix_pattern.sub("", name)
        filter_set.add(suffix)
    filter_list = list(filter_set)

    for filter_name in filter_list:
        weight_filter = common_filter(filter_name=filter_name)
        ipc_handles = hf_model.get_weight_ipc_handles_serialized(
            [0, 1], weight_filter=weight_filter
        )
        llm._collective_rpc("update_weights", (ipc_handles,))
        # Finalize the update weights
        llm._collective_rpc("update_weights", (None,))

    llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
    compare_logits(llm_logits, ref_logits)
@@ -9,10 +9,11 @@ from torch.multiprocessing.reductions import reduce_tensor
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from utils.llm_data import llm_models_root
from utils.torch_ref import RefHFModel
from utils.util import getSMVersion, skip_pre_hopper

from tensorrt_llm import LLM
from tensorrt_llm._torch.utils import get_device_uuid
from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig, MoeConfig, SamplingParams


class RefHFModelWithIPCHandles(RefHFModel):
@@ -119,6 +120,7 @@ def run_generate(
    return llm_logits, ref_logits


@skip_pre_hopper
@pytest.mark.parametrize(
    "model_dir",
    [
@@ -136,6 +138,7 @@ def test_llm_update_weights(model_dir):
    hf_model = RefHFModelWithIPCHandles(model_dir, num_hidden_layers=num_hidden_layers)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
    moe_config = MoeConfig(backend="DEEPGEMM" if getSMVersion() >= 100 else "CUTLASS")
    llm = LLM(
        model=model_dir,
        ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
@@ -144,6 +147,7 @@ def test_llm_update_weights(model_dir):
        pipeline_parallel_size=1,
        kv_cache_config=kv_cache_config,
        model_kwargs={"num_hidden_layers": num_hidden_layers},
        moe_config=moe_config,
    )

    # Generate texts from the prompts.
@@ -167,6 +171,7 @@ def test_llm_update_weights(model_dir):
    compare_logits(llm_logits, ref_logits)


@skip_pre_hopper
@pytest.mark.parametrize(
    "model_dir",
    [
@@ -184,7 +189,7 @@ def test_llm_partial_update_weights(model_dir):
    hf_model = RefHFModelWithIPCHandles(model_dir, num_hidden_layers=num_hidden_layers)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)

    moe_config = MoeConfig(backend="DEEPGEMM" if getSMVersion() >= 100 else "CUTLASS")
    llm = LLM(
        model=model_dir,
        ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
@@ -193,6 +198,7 @@ def test_llm_partial_update_weights(model_dir):
        pipeline_parallel_size=1,
        kv_cache_config=kv_cache_config,
        model_kwargs={"num_hidden_layers": num_hidden_layers},
        moe_config=moe_config,
    )

    # Generate texts from the prompts.
@@ -233,6 +239,7 @@ def test_llm_partial_update_weights(model_dir):
    compare_logits(llm_logits, ref_logits)


@skip_pre_hopper
@pytest.mark.parametrize(
    "model_dir, fp8_model_dir",
    [
@@ -247,6 +254,7 @@ def test_llm_update_weights_with_quant_config(model_dir, fp8_model_dir):
    hf_model = RefHFModelWithIPCHandles(fp8_model_dir, num_hidden_layers=num_hidden_layers)
    tokenizer = AutoTokenizer.from_pretrained(fp8_model_dir)
    kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
    moe_config = MoeConfig(backend="DEEPGEMM" if getSMVersion() >= 100 else "CUTLASS")
    llm = LLM(
        model=model_dir,
        ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
@@ -263,6 +271,7 @@ def test_llm_update_weights_with_quant_config(model_dir, fp8_model_dir):
                "weight_block_size": [128, 128],
            },
        },
        moe_config=moe_config,
    )

    # Generate texts from the prompts.