[XPU][Bugfix] Fix per_token_group_fp8_quant missing dummy args on XPU (#43930)

Signed-off-by: Chaojun,Zhang <chaojun.zhang@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Chaojun Zhang
2026-06-02 11:09:21 +08:00
committed by GitHub
parent 480fadab1b
commit a3a5a5ece5
2 changed files with 11 additions and 8 deletions
+10 -1
View File
@@ -338,7 +338,16 @@ def _xpu_mxfp8_quantize_impl(
shape = x.shape[:-1] + (x.shape[-1] // MXFP8_BLOCK_SIZE,)
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
torch.ops._C.per_token_group_fp8_quant(
x, x_q, x_s, MXFP8_BLOCK_SIZE, eps, fp8_min, fp8_max, True
x,
x_q,
x_s,
MXFP8_BLOCK_SIZE,
eps,
fp8_min,
fp8_max,
True,
False,
False, # dummy_is_scale_transposed, dummy_is_tma_aligned
)
x_s = x_s.to(torch.float8_e8m0fnu)
return x_q, x_s
@@ -575,7 +575,7 @@ def per_token_group_quant_fp8(
# prefer CUDA/XPU kernel if available
# TODO(bnell): this causes some fp8 moe test to fail.
if current_platform.is_cuda() and x.is_contiguous():
if (current_platform.is_cuda() or current_platform.is_xpu()) and x.is_contiguous():
torch.ops._C.per_token_group_fp8_quant(
x,
x_q,
@@ -590,12 +590,6 @@ def per_token_group_quant_fp8(
)
return x_q, x_s
if current_platform.is_xpu() and x.is_contiguous():
torch.ops._C.per_token_group_fp8_quant(
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, use_ue8m0
)
return x_q, x_s
# TRITON FALLBACK
M = x.numel() // group_size
N = group_size