mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
Fix pre-reload and add cuda graph warning
Signed-off-by: Shuyi Xiong <219646547+shuyixiong@users.noreply.github.com>
This commit is contained in:
parent
694eb01480
commit
c37c80d3ac
@@ -531,9 +531,13 @@ class FusedMoEMethodBase(ABC):
|
||||
|
||||
     def pre_reload_weights(self, module: torch.nn.Module):
         for param_name, metadata in module.rebuild_tensor_metadata.items():
-            param = torch.nn.Parameter(torch.empty_like(metadata),
+            logger.warning(
+                f"Pre-reloading weight '{param_name}' requires tensor re-creation, which will invalidate existing CUDA graphs."
+            )
+            param = torch.nn.Parameter(torch.empty_like(metadata,
+                                                        device="cuda"),
                                        requires_grad=False)
-            setattr(module, param_name, param)
+            module.register_parameter(param_name, param)
|
||||
|
||||
|
||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
@@ -512,7 +512,11 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
|
||||
     def pre_reload_weights(self, module: Linear):
         for param_name, metadata in module.rebuild_tensor_metadata.items():
-            param = Parameter(torch.empty_like(metadata), requires_grad=False)
+            logger.warning(
+                f"Pre-reloading weight '{param_name}' requires tensor re-creation, which will invalidate existing CUDA graphs."
+            )
+            param = Parameter(torch.empty_like(metadata, device="cuda"),
+                              requires_grad=False)
             module.register_parameter(param_name, param)
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user