mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[XPU][Bugfix] Fix per_token_group_fp8_quant missing dummy args on XPU (#43930)
Signed-off-by: Chaojun,Zhang <chaojun.zhang@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
+10
-1
@@ -338,7 +338,16 @@ def _xpu_mxfp8_quantize_impl(
|
||||
shape = x.shape[:-1] + (x.shape[-1] // MXFP8_BLOCK_SIZE,)
|
||||
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
|
||||
torch.ops._C.per_token_group_fp8_quant(
|
||||
x, x_q, x_s, MXFP8_BLOCK_SIZE, eps, fp8_min, fp8_max, True
|
||||
x,
|
||||
x_q,
|
||||
x_s,
|
||||
MXFP8_BLOCK_SIZE,
|
||||
eps,
|
||||
fp8_min,
|
||||
fp8_max,
|
||||
True,
|
||||
False,
|
||||
False, # dummy_is_scale_transposed, dummy_is_tma_aligned
|
||||
)
|
||||
x_s = x_s.to(torch.float8_e8m0fnu)
|
||||
return x_q, x_s
|
||||
|
||||
@@ -575,7 +575,7 @@ def per_token_group_quant_fp8(
|
||||
|
||||
# prefer CUDA/XPU kernel if available
|
||||
# TODO(bnell): this causes some fp8 moe test to fail.
|
||||
if current_platform.is_cuda() and x.is_contiguous():
|
||||
if (current_platform.is_cuda() or current_platform.is_xpu()) and x.is_contiguous():
|
||||
torch.ops._C.per_token_group_fp8_quant(
|
||||
x,
|
||||
x_q,
|
||||
@@ -590,12 +590,6 @@ def per_token_group_quant_fp8(
|
||||
)
|
||||
return x_q, x_s
|
||||
|
||||
if current_platform.is_xpu() and x.is_contiguous():
|
||||
torch.ops._C.per_token_group_fp8_quant(
|
||||
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, use_ue8m0
|
||||
)
|
||||
return x_q, x_s
|
||||
|
||||
# TRITON FALLBACK
|
||||
M = x.numel() // group_size
|
||||
N = group_size
|
||||
|
||||
Reference in New Issue
Block a user