[None][chore] update flashinfer to 0.6.0 (#10522)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Chenghao Zhang 2026-01-16 13:22:06 -08:00 committed by GitHub
parent b6acd96616
commit 0b748d5bba
7 changed files with 12 additions and 156 deletions

View File

@@ -5261,7 +5261,7 @@ For more information, please refer to <http://unlicense.org>
- `Tracker`: https://github.com/tox-dev/py-filelock/issues
## flashinfer-python (0.3.1.post1)
## flashinfer-python (0.6.0)
### Licenses
License: `Apache-2.0`

View File

@@ -53,7 +53,7 @@ ordered-set
peft
patchelf
einops
flashinfer-python>=0.3.0,<0.4.0
flashinfer-python~=0.6.0
opencv-python-headless
xgrammar==0.1.25
llguidance==0.7.29
@@ -74,7 +74,7 @@ nvidia-cutlass-dsl==4.3.4; python_version >= "3.10"
plotly
numexpr<2.14.0 # WAR for attempted use of nonexistent numpy.typing
partial_json_parser
apache-tvm-ffi==0.1.4 # used to reduce nvidia-cutlass-dsl host overhead
apache-tvm-ffi==0.1.6 # used to reduce nvidia-cutlass-dsl host overhead
torch-c-dlpack-ext==0.1.3 # used to reduce nvidia-cutlass-dsl host overhead, optional package for improved torch tensor calling perf
mistral-common==1.8.6
torchao>=0.14.1

View File

@@ -57,7 +57,7 @@ ordered-set = "^4.1.0"
peft = "^0.18.1"
patchelf = "^0.17.2.4"
einops = "^0.8.1"
flashinfer-python = ">=0.3.0,<0.4.0"
flashinfer-python = "^0.6.0"
xgrammar = "0.1.25"
llguidance = "0.7.29"
jsonschema = "^4.26.0"
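
Both pins now track the same compatible-release window: `~=0.6.0` in requirements.txt and Poetry's `^0.6.0` in pyproject.toml each resolve to `>=0.6.0,<0.7.0`. A minimal sketch, assuming the `packaging` library, of how the new specifier differs from the old `>=0.3.0,<0.4.0` range (the candidate version strings are illustrative):

```python
# Sketch of what the old and new flashinfer-python pins accept, using the
# `packaging` library. The candidate version strings are illustrative.
from packaging.specifiers import SpecifierSet
from packaging.version import Version

old_pin = SpecifierSet(">=0.3.0,<0.4.0")  # previous requirements.txt range
new_pin = SpecifierSet("~=0.6.0")         # compatible release: >=0.6.0,<0.7.0

for candidate in ("0.3.1.post1", "0.6.0", "0.6.2", "0.7.0"):
    version = Version(candidate)
    print(f"{candidate}: old={old_pin.contains(version)} new={new_pin.contains(version)}")
```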

View File

@@ -425,6 +425,8 @@ class FlashInferAttentionMetadata(AttentionMetadata):
paged_kv_indices_buffer=self._paged_kv_indices,
paged_kv_last_page_len_buffer=self._paged_kv_last_page_len,
use_tensor_cores=use_tensor_cores,
backend="fa2"
if torch.cuda.get_device_capability(0) == (9, 0) else "auto",
)
def decode_plan():
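
The decode-wrapper constructors in this hunk and in the planner diff below now pass a `backend` argument that pins the FA2 decode kernels on Hopper (compute capability 9.0) and otherwise lets flashinfer choose. A minimal sketch of that selection, with a hypothetical helper name and an illustrative workspace size:

```python
# Sketch of the backend selection added above. `_select_decode_backend` is a
# hypothetical helper; the workspace size is illustrative.
import torch
import flashinfer


def _select_decode_backend(device_index: int = 0) -> str:
    # Compute capability (9, 0) is Hopper / SM90: force the FA2 decode kernels
    # there and let flashinfer auto-select on every other architecture.
    return "fa2" if torch.cuda.get_device_capability(device_index) == (9, 0) else "auto"


workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer,
    "NHD",
    use_tensor_cores=True,
    backend=_select_decode_backend(),
)
```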

View File

@@ -26,156 +26,6 @@ from .attention_interface import (
)
# TODO: remove this when flashinfer version is updated to >0.5
def fast_decode_plan(
wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
indptr: torch.Tensor,
indices: torch.Tensor,
last_page_len: torch.Tensor,
num_qo_heads: int,
num_kv_heads: int,
head_dim: int,
page_size: int,
pos_encoding_mode: str = "NONE",
window_left: int = -1,
logits_soft_cap: Optional[float] = None,
q_data_type: Optional[Union[str, torch.dtype]] = None,
kv_data_type: Optional[Union[str, torch.dtype]] = None,
data_type: Optional[Union[str, torch.dtype]] = None,
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
non_blocking: bool = True,
fixed_split_size: Optional[int] = None,
disable_split_kv: bool = False,
global_override_indptr_cpu: Optional[torch.Tensor] = None,
) -> None:
"""
Copied from flashinfer.decode.fast_decode_plan in flashinfer version >0.5.
Does not exist in flashinfer version 0.3.1, hence copied here.
"""
batch_size = len(last_page_len)
if logits_soft_cap is None:
logits_soft_cap = 0.0
# Handle data types consistently
if data_type is not None:
if q_data_type is None:
q_data_type = data_type
if kv_data_type is None:
kv_data_type = data_type
elif q_data_type is None:
q_data_type = "float16"
if kv_data_type is None:
kv_data_type = q_data_type
if wrapper.use_tensor_cores:
qo_indptr_host = torch.arange(batch_size + 1, dtype=torch.int32, device="cpu")
# Here we set fixed_split_size to -1 to avoid the assertion error in flashinfer's plan function
if fixed_split_size is None:
fixed_split_size = -1
if wrapper.is_cuda_graph_enabled:
if batch_size != wrapper._fixed_batch_size:
raise ValueError(
"The batch size should be fixed in cudagraph mode, the runtime batch size {} "
" mismatches the batch size set during initialization {}".format(
batch_size, wrapper._fixed_batch_size
)
)
if len(indices) > len(wrapper._paged_kv_indices_buf):
raise ValueError(
"The size of indices should be less than or equal to the allocated buffer"
)
else:
wrapper._paged_kv_indptr_buf = indptr
wrapper._paged_kv_indices_buf = indices
wrapper._paged_kv_last_page_len_buf = last_page_len
if wrapper.use_tensor_cores:
wrapper._qo_indptr_buf = qo_indptr_host.to(wrapper.device, non_blocking=non_blocking)
# Create empty tensors for dtype info if needed
empty_q_data = torch.empty(
0,
dtype=(getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type),
device=wrapper.device,
)
empty_kv_cache = torch.empty(
0,
dtype=(getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type),
device=wrapper.device,
)
indptr_host = (
global_override_indptr_cpu if global_override_indptr_cpu is not None else indptr.cpu()
)
with torch.cuda.device(wrapper.device):
if wrapper.use_tensor_cores:
# ALSO convert last_page_len to CPU
if page_size == 1:
# When page size is 1, last_page_len is always 1.
# Directly construct the host tensor rather than executing a device-to-host copy.
last_page_len_host = torch.ones((batch_size,), dtype=torch.int32, device="cpu")
else:
last_page_len_host = last_page_len.cpu()
kv_lens_arr_host = flashinfer.get_seq_lens(indptr_host, last_page_len_host, page_size)
try:
# Make sure we pass exactly 15 arguments for tensor core version
wrapper._plan_info = wrapper._cached_module.plan(
wrapper._float_workspace_buffer,
wrapper._int_workspace_buffer,
wrapper._pin_memory_int_workspace_buffer,
qo_indptr_host,
indptr_host,
kv_lens_arr_host,
batch_size, # total_num_rows
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
wrapper.is_cuda_graph_enabled,
head_dim,
head_dim,
False, # causal
)
except Exception as e:
raise RuntimeError(f"Error in standard plan: {e}") from e
else:
try:
# Make sure we pass exactly 15 arguments for standard version
wrapper._plan_info = wrapper._cached_module.plan(
wrapper._float_workspace_buffer,
wrapper._int_workspace_buffer,
wrapper._pin_memory_int_workspace_buffer,
indptr_host,
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
wrapper.is_cuda_graph_enabled,
window_left,
logits_soft_cap,
head_dim,
head_dim,
empty_q_data,
empty_kv_cache,
)
except Exception as e:
raise RuntimeError(f"Error in standard plan: {e}") from e
wrapper._pos_encoding_mode = pos_encoding_mode
wrapper._window_left = window_left
wrapper._logits_soft_cap = logits_soft_cap
wrapper._sm_scale = sm_scale
wrapper._rope_scale = rope_scale
wrapper._rope_theta = rope_theta
@dataclass
class PlanParams:
"""Parameters that affect the flashinfer execution plan."""
@@ -233,12 +83,14 @@ class _FlashInferPlanner:
paged_kv_indices_buffer=indices,
paged_kv_last_page_len_buffer=last_page_len,
use_tensor_cores=True,
backend="fa2" if torch.cuda.get_device_capability(0) == (9, 0) else "auto",
)
else:
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
use_tensor_cores=True,
backend="fa2" if torch.cuda.get_device_capability(0) == (9, 0) else "auto",
)
def init_workspace(self, workspace_buffer: torch.Tensor):
@@ -268,7 +120,7 @@ class _FlashInferPlanner:
for plan_params in self.cached_cuda_graph_decode_wrappers:
if plan_params.num_seq == num_seq:
wrapper = self.cached_cuda_graph_decode_wrappers[plan_params]
fast_decode_plan(
flashinfer.decode.fast_decode_plan(
wrapper,
cu_num_pages,
cache_loc,

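For context on the lookup that surrounds this re-plan call: CUDA-graph decode wrappers are cached keyed by their `PlanParams` and re-planned through the upstream `flashinfer.decode.fast_decode_plan` when the runtime batch size matches. A minimal, self-contained sketch of that lookup pattern (only `num_seq` is taken from the real `PlanParams`; the extra field and the plain-dict cache are illustrative placeholders, not the planner's actual code):

```python
# Sketch of the cached CUDA-graph wrapper lookup used around the re-plan call
# above. Only `num_seq` mirrors the real PlanParams; the other field and the
# plain-dict cache are illustrative placeholders.
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass(frozen=True)  # frozen so instances are hashable dict keys
class PlanParams:
    """Parameters that affect the flashinfer execution plan."""
    num_seq: int
    page_size: int = 64  # illustrative extra field


cached_cuda_graph_decode_wrappers: Dict[PlanParams, Any] = {}


def find_cached_decode_wrapper(num_seq: int) -> Optional[Any]:
    # Linear scan mirroring the loop in _FlashInferPlanner: return the first
    # wrapper whose CUDA graph was captured for this batch size.
    for plan_params, wrapper in cached_cuda_graph_decode_wrappers.items():
        if plan_params.num_seq == num_seq:
            return wrapper
    return None
```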
View File

@@ -194,10 +194,10 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
# TODO: multi-stream MOE seems to increase the memory usage
kwargs["max_batch_size"] = 32
kwargs["free_mem_ratio"] = 0.4
sampling_params = self.get_default_sampling_params()
with AutoDeployLLM(model=self.MODEL_PATH_BF16,
tokenizer=self.MODEL_PATH_BF16,
**kwargs) as llm:
sampling_params = self.get_default_sampling_params()
task = MMLU(self.MODEL_NAME)
task.evaluate(llm, sampling_params=sampling_params)
task = GSM8K(self.MODEL_NAME)
@@ -206,6 +206,7 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
@pytest.mark.skip_less_device_memory(32000)
def test_fp8(self):
kwargs = self.get_default_kwargs()
kwargs["max_batch_size"] = 64
with AutoDeployLLM(model=self.MODEL_PATH_FP8,
tokenizer=self.MODEL_PATH_FP8,
**kwargs) as llm:

View File

@@ -383,6 +383,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backe
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-False-True] SKIP (https://nvbugs/5810980)
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency] SKIP (https://nvbugs/5814309)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] SKIP (https://nvbugs/5800646)
unittest/_torch/thop/parallel/test_fp4_swizzle.py::test_swizzle_sf SKIP (https://nvbugs/5811159)
unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_mxfp4_moe_ep.py::test_mxfp4_mlp_ep_dtypes[1-4-6] SKIP (https://nvbugs/5814247)
unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_mxfp4_moe_ep.py::test_mxfp4_mlp_ep_dtypes[1-4-8] SKIP (https://nvbugs/5814247)
unittest/_torch/auto_deploy/unit/multigpu/test_ad_allreduce_strategies.py::test_allreduce_strategies[AUTO] SKIP (https://nvbugs/5814247)