mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 14:07:21 +08:00
[TRTLLM-9989][fix] Fix tvm_ffi aarch64 issue. (#10199)
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
This commit is contained in:
parent
696f754ef4
commit
1e82ff7a0c
@@ -69,7 +69,7 @@ triton==3.5.0
|
||||
tiktoken
|
||||
blobfile
|
||||
openai-harmony==0.0.4
|
||||
nvidia-cutlass-dsl==4.3.1; python_version >= "3.10"
|
||||
nvidia-cutlass-dsl==4.3.4; python_version >= "3.10"
|
||||
plotly
|
||||
numexpr<2.14.0 # WAR for attempted use of nonexistent numpy.typing
|
||||
partial_json_parser
|
||||
|
||||
@@ -371,7 +371,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
def __init__(self,
|
||||
output_dtype: torch.dtype,
|
||||
to_userbuffers: bool = False,
|
||||
use_tvm_ffi: bool = False):
|
||||
use_tvm_ffi: bool = True):
|
||||
super().__init__()
|
||||
|
||||
if output_dtype != torch.bfloat16:
|
||||
@@ -775,7 +775,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
alpha: torch.Tensor,
|
||||
output_dtype: torch.dtype,
|
||||
to_userbuffers: bool = False,
|
||||
use_tvm_ffi: bool = False,
|
||||
use_tvm_ffi: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""CuteDSL-based NVFP4 GEMM optimized for Blackwell.
|
||||
|
||||
@@ -825,7 +825,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
alpha: torch.Tensor, # Match custom op signature
|
||||
output_dtype: torch.dtype,
|
||||
to_userbuffers: bool = False,
|
||||
use_tvm_ffi: bool = False,
|
||||
use_tvm_ffi: bool = True,
|
||||
):
|
||||
# [m, k]
|
||||
shape = list(mat_a.shape)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user