Fix pre-reload and add cuda graph warning

Signed-off-by: Shuyi Xiong <219646547+shuyixiong@users.noreply.github.com>
This commit is contained in:
Shuyi Xiong 2026-01-12 19:47:01 -08:00
parent 694eb01480
commit c37c80d3ac
2 changed files with 11 additions and 3 deletions

View File

@@ -531,9 +531,13 @@ class FusedMoEMethodBase(ABC):
def pre_reload_weights(self, module: torch.nn.Module):
for param_name, metadata in module.rebuild_tensor_metadata.items():
param = torch.nn.Parameter(torch.empty_like(metadata),
logger.warning(
f"Pre-reloading weight '{param_name}' requires tensor re-creation, which will invalidate existing CUDA graphs."
)
param = torch.nn.Parameter(torch.empty_like(metadata,
device="cuda"),
requires_grad=False)
setattr(module, param_name, param)
module.register_parameter(param_name, param)
class UnquantizedFusedMoEMethod(FusedMoEMethodBase):

View File

@@ -512,7 +512,11 @@ class UnquantizedLinearMethod(LinearMethodBase):
def pre_reload_weights(self, module: Linear):
for param_name, metadata in module.rebuild_tensor_metadata.items():
param = Parameter(torch.empty_like(metadata), requires_grad=False)
logger.warning(
f"Pre-reloading weight '{param_name}' requires tensor re-creation, which will invalidate existing CUDA graphs."
)
param = Parameter(torch.empty_like(metadata, device="cuda"),
requires_grad=False)
module.register_parameter(param_name, param)