diff --git a/tensorrt_llm/quantization/utils/fp8_utils.py b/tensorrt_llm/quantization/utils/fp8_utils.py index e26288b5bc..aa368b2788 100644 --- a/tensorrt_llm/quantization/utils/fp8_utils.py +++ b/tensorrt_llm/quantization/utils/fp8_utils.py @@ -103,6 +103,9 @@ def resmooth_to_fp8_e8m0( assert weight.dtype == torch.float8_e4m3fn assert weight_scale.dtype == torch.float32 + weight = weight.cuda() + weight_scale = weight_scale.cuda() + orig_shape = weight.shape M, K = orig_shape[-2:] w_view = weight.view(-1, M, K)