[None][fix] convert to CUDA tensor before calling _resmooth_kernel. (#10770)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
This commit is contained in:
Yuxian Qiu 2026-01-17 16:18:34 +08:00 committed by GitHub
parent b65560fc32
commit cef67b4f8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -103,6 +103,9 @@ def resmooth_to_fp8_e8m0(
assert weight.dtype == torch.float8_e4m3fn
assert weight_scale.dtype == torch.float32
weight = weight.cuda()
weight_scale = weight_scale.cuda()
orig_shape = weight.shape
M, K = orig_shape[-2:]
w_view = weight.view(-1, M, K)