Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
[None][chore] Use a cached model path for Ray integration test (#8660)
Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
commit 0a02f5f25d
parent 49974eed75
llm_inference_async_ray.py (example script):

@@ -1,4 +1,5 @@
 # Generate text asynchronously with Ray orchestrator.
+import argparse
 import asyncio
 
 from tensorrt_llm import LLM, SamplingParams
@@ -6,6 +7,16 @@ from tensorrt_llm.llmapi import KvCacheConfig
 
 
 def main():
+    parser = argparse.ArgumentParser(
+        description="Generate text asynchronously with Ray orchestrator.")
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+        help=
+        "HuggingFace model name or path to local HF model (default: TinyLlama/TinyLlama-1.1B-Chat-v1.0)"
+    )
+    args = parser.parse_args()
     # Configure KV cache memory usage fraction.
     kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5,
                                     max_tokens=4096,
@@ -13,7 +24,7 @@ def main():
 
     # model could accept HF model name or a path to local HF model.
     llm = LLM(
-        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+        model=args.model,
         kv_cache_config=kv_cache_config,
         max_seq_len=1024,
         max_batch_size=1,
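Read together, these hunks simply thread a new --model flag through to the LLM constructor so the example no longer hard-codes the HuggingFace model name. For orientation, here is a minimal sketch of the updated script as a whole: the argparse block and the LLM(...) arguments come from the diff above, while the prompts, sampling parameters, and the asyncio loop are assumptions about the unchanged parts of llm_inference_async_ray.py (any Ray-orchestrator-specific setup in the real example is not shown in this excerpt and is omitted here).

# Minimal sketch assembled from the hunks above plus assumed surrounding code.
import argparse
import asyncio

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig


def main():
    parser = argparse.ArgumentParser(
        description="Generate text asynchronously with Ray orchestrator.")
    parser.add_argument(
        "--model",
        type=str,
        default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        help="HuggingFace model name or path to local HF model")
    args = parser.parse_args()

    # Configure KV cache memory usage fraction (values from the diff).
    kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5,
                                    max_tokens=4096)

    # model accepts an HF model name or a path to a local HF checkout, so the
    # integration test can point it at a pre-downloaded copy. Ray-specific LLM
    # arguments from the real example are omitted in this sketch.
    llm = LLM(model=args.model,
              kv_cache_config=kv_cache_config,
              max_seq_len=1024,
              max_batch_size=1)

    # Assumed async generation loop; the real example's prompts and sampling
    # settings may differ.
    prompts = ["Hello, my name is", "The capital of France is"]
    sampling_params = SamplingParams(max_tokens=32)

    async def generate_one(prompt: str):
        output = await llm.generate_async(prompt, sampling_params)
        print(f"{prompt!r} -> {output.outputs[0].text!r}")

    async def run_all():
        await asyncio.gather(*(generate_one(p) for p in prompts))

    asyncio.run(run_all())


if __name__ == "__main__":
    main()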
Ray integration test:

@@ -14,7 +14,8 @@ def ray_example_root(llm_root):
 
 def test_llm_inference_async_ray(ray_example_root, llm_venv):
     script_path = os.path.join(ray_example_root, "llm_inference_async_ray.py")
-    venv_check_call(llm_venv, [script_path])
+    model_path = f"{llm_models_root()}/llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
+    venv_check_call(llm_venv, [script_path, "--model", model_path])
 
 
 @pytest.mark.skip_less_device(2)
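The test now points the script at a pre-staged TinyLlama checkout under llm_models_root() instead of letting the example download it from the HuggingFace Hub, which keeps CI off the network. As a hedged illustration of the same idea outside the test harness, the helper below prefers a cached checkout and falls back to the Hub name when none is present; the LLM_MODELS_ROOT environment variable and resolve_tinyllama are illustrative names for this sketch, not part of this commit.

# Hypothetical helper (not from the repository): prefer a locally cached
# TinyLlama checkout, otherwise fall back to the HuggingFace Hub name.
import os


def resolve_tinyllama(models_root):
    if models_root:
        cached = os.path.join(models_root, "llama-models-v2",
                              "TinyLlama-1.1B-Chat-v1.0")
        if os.path.isdir(cached):
            return cached
    return "TinyLlama/TinyLlama-1.1B-Chat-v1.0"


if __name__ == "__main__":
    # e.g. LLM_MODELS_ROOT=/scratch/llm-models python resolve_model.py
    print(resolve_tinyllama(os.environ.get("LLM_MODELS_ROOT")))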