diff --git a/tests/unittest/test_pip_install.py b/tests/unittest/test_pip_install.py index 346b5051ac..e0e2599eb0 100644 --- a/tests/unittest/test_pip_install.py +++ b/tests/unittest/test_pip_install.py @@ -77,12 +77,17 @@ def get_cpython_version(): return "cp{}{}".format(python_version[0], python_version[1]) -def download_wheel(args): - if not args.wheel_path.startswith(("http://", "https://")): - args.wheel_path = "https://" + args.wheel_path - res = requests.get(args.wheel_path) +def get_wheel_url(wheel_path): + """Get direct wheel URL from wheel_path (directory listing or direct URL).""" + if not wheel_path.startswith(("http://", "https://")): + wheel_path = "https://" + wheel_path + + if wheel_path.endswith(".whl"): + return wheel_path + + res = requests.get(wheel_path) if res.status_code != 200: - print(f"Fail to get the result of {args.wheel_path}") + print(f"Fail to get the result of {wheel_path}") exit(1) wheel_name = None for line in res.text.split("\n"): @@ -96,11 +101,15 @@ def download_wheel(args): wheel_name = name break if not wheel_name: - print(f"Fail to get the wheel name of {args.wheel_path}") + print(f"Fail to get the wheel name of {wheel_path}") exit(1) - if args.wheel_path[-1] == "/": - args.wheel_path = args.wheel_path[:-1] - wheel_url = f"{args.wheel_path}/{wheel_name}" + if wheel_path[-1] == "/": + wheel_path = wheel_path[:-1] + return f"{wheel_path}/{wheel_name}" + + +def download_wheel(args): + wheel_url = get_wheel_url(args.wheel_path) subprocess.check_call("rm *.whl || true", shell=True) subprocess.check_call(f"apt-get install -y wget && wget -q {wheel_url}", shell=True) @@ -168,28 +177,8 @@ def create_link_for_models(): os.symlink(src, dst, target_is_directory=True) -def test_pip_install(): - parser = argparse.ArgumentParser(description="Check Pip Install") - parser.add_argument("--wheel_path", - type=str, - required=False, - default="Default", - help="The wheel path") - args = parser.parse_args() - - print("########## Install required system libs ##########") - if not os.path.exists("/usr/local/mpi/bin/mpicc"): - subprocess.check_call("apt-get -y install libopenmpi-dev", shell=True) - - subprocess.check_call("apt-get -y install libzmq3-dev", shell=True) - subprocess.check_call("apt-get -y install python3-pip", shell=True) - subprocess.check_call("pip3 install --upgrade pip || true", shell=True) - subprocess.check_call("pip3 install --upgrade setuptools || true", - shell=True) - - download_wheel(args) - install_tensorrt_llm() - +def run_sanity_check(examples_path="../../examples"): + """Run sanity checks after installation.""" print("########## Test import tensorrt_llm ##########") subprocess.check_call( 'python3 -c "import tensorrt_llm; print(tensorrt_llm.__version__)"', @@ -202,8 +191,81 @@ def test_pip_install(): print("########## Test quickstart example ##########") subprocess.check_call( - "python3 ../../examples/llm-api/quickstart_example.py", shell=True) + f"python3 {examples_path}/llm-api/quickstart_example.py", shell=True) + + +def install_system_libs(): + """Install required system libraries for tensorrt_llm.""" + print("########## Install required system libs ##########") + if not os.path.exists("/usr/local/mpi/bin/mpicc"): + subprocess.check_call("apt-get -y install libopenmpi-dev", shell=True) + + subprocess.check_call("apt-get -y install libzmq3-dev", shell=True) + subprocess.check_call("apt-get -y install python3-pip", shell=True) + subprocess.check_call("pip3 install --upgrade pip || true", shell=True) + subprocess.check_call("pip3 install --upgrade setuptools || true", + shell=True) + + +def test_pip_install(args): + install_system_libs() + + download_wheel(args) + install_tensorrt_llm() + + run_sanity_check() + + +def test_python_builds(args): + """Test Python builds using precompiled wheel (sanity check only). + + This test verifies the TRTLLM_PRECOMPILED_LOCATION workflow: + 1. Install required system libs + 2. Use precompiled wheel URL to extract C++ bindings + 3. Build Python-only wheel (editable install with --no-deps) + 4. Verify installation works correctly + 5. Run quickstart example + 6. Clean up editable install to leave env in clean state + """ + print("########## Python Builds Test ##########") + + install_system_libs() + + wheel_url = get_wheel_url(args.wheel_path) + print(f"Using precompiled wheel: {wheel_url}") + + repo_root = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..")) + print(f"Repository root: {repo_root}") + + # Uninstall existing tensorrt_llm to test fresh editable install + subprocess.run("pip3 uninstall -y tensorrt_llm || true", + shell=True, + check=False) + + print("########## Install with TRTLLM_PRECOMPILED_LOCATION ##########") + env = os.environ.copy() + env["TRTLLM_PRECOMPILED_LOCATION"] = wheel_url + + # Use --no-deps to avoid changing torch/torchvision versions. + subprocess.check_call(["pip3", "install", "-e", ".", "--no-deps", "-v"], + cwd=repo_root, + env=env) + run_sanity_check(examples_path=f"{repo_root}/examples") + + # Clean up: uninstall editable install to leave env in clean state + print("########## Clean up editable install ##########") + subprocess.run("pip3 uninstall -y tensorrt_llm || true", + shell=True, + check=False) if __name__ == "__main__": - test_pip_install() + parser = argparse.ArgumentParser(description="Check Pip Install") + parser.add_argument("--wheel_path", + type=str, + required=True, + help="The wheel path") + args = parser.parse_args() + test_python_builds(args) + test_pip_install(args)