TensorRT-LLMs/tests/unittest/utils/test_prebuilt_whl_cpp_extensions.py
Chang Liu 3a79d03874
[https://nvbugs/5617275][fix] Extract py files from prebuilt wheel for editable installs (#8738)
Signed-off-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com>
2025-10-30 21:40:22 -07:00

68 lines
2.4 KiB
Python

"""
Test that prebuilt wheel extraction includes all necessary Python files.
"""
from pathlib import Path
def test_cpp_extension_wrapper_files_exist():
"""Verify that C++ extension wrapper Python files were extracted from prebuilt wheel."""
import tensorrt_llm
trtllm_root = Path(tensorrt_llm.__file__).parent
# C++ extensions that have Python wrapper files generated during build
required_files = {
'deep_gemm':
['__init__.py', 'testing/__init__.py', 'utils/__init__.py'],
'deep_ep': ['__init__.py', 'buffer.py', 'utils.py'],
'flash_mla': ['__init__.py', 'flash_mla_interface.py'],
}
missing_files = []
for ext_dir, files in required_files.items():
for file in files:
file_path = trtllm_root / ext_dir / file
if not file_path.exists():
missing_files.append(str(file_path.relative_to(trtllm_root)))
assert not missing_files, (
f"Missing Python wrapper files for C++ extensions: {missing_files}\n"
f"This indicates setup.py may not be extracting Python files from prebuilt wheels.\n"
f"Check setup.py extract_from_precompiled() function.")
def test_cpp_extensions_importable():
"""Verify that C++ extension wrappers can be imported successfully."""
import_tests = [
('tensorrt_llm.deep_gemm', 'fp8_mqa_logits'),
('tensorrt_llm.deep_ep', 'Buffer'),
('tensorrt_llm.flash_mla', 'flash_mla_interface'),
]
failed_imports = []
for module_name, attr_name in import_tests:
try:
module = __import__(module_name, fromlist=[attr_name])
if not hasattr(module, attr_name):
failed_imports.append(
f"{module_name}.{attr_name} (attribute not found)")
except ImportError as e:
failed_imports.append(f"{module_name} (ImportError: {e})")
assert not failed_imports, (
f"Failed to import C++ extension wrappers: {failed_imports}\n"
f"This may indicate missing Python files or circular import issues.")
if __name__ == '__main__':
    # Run each check in order, bracketed by a banner and a success line.
    # A failing check raises AssertionError and stops the run before its
    # success line is printed.
    _steps = (
        ("Testing C++ extension wrapper files...",
         test_cpp_extension_wrapper_files_exist,
         "✅ All required Python files exist"),
        ("\nTesting C++ extension imports...",
         test_cpp_extensions_importable,
         "✅ All imports successful"),
    )
    for _banner, _check, _ok_message in _steps:
        print(_banner)
        _check()
        print(_ok_message)
    print("\n✅ All tests passed!")