diff --git a/setup.py b/setup.py index 1748781985f..b0cca73bb91 100644 --- a/setup.py +++ b/setup.py @@ -693,6 +693,12 @@ class precompiled_wheel_utils: flash_attn_regex = re.compile( r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py" ) + # __init__.py and flash_attn_interface.py are source-controlled + # in vllm and should not be overwritten (matches cmake exclusions) + flash_attn_files_to_skip = { + "vllm/vllm_flash_attn/__init__.py", + "vllm/vllm_flash_attn/flash_attn_interface.py", + } triton_kernels_regex = re.compile( r"vllm/third_party/triton_kernels/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py" ) @@ -705,7 +711,11 @@ class precompiled_wheel_utils: filter(lambda x: x.filename in files_to_copy, wheel.filelist) ) file_members += list( - filter(lambda x: flash_attn_regex.match(x.filename), wheel.filelist) + filter( + lambda x: flash_attn_regex.match(x.filename) + and x.filename not in flash_attn_files_to_skip, + wheel.filelist, + ) ) file_members += list( filter(