Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[None][chore] Restore asserts in pytorch flow LoRA tests (#8227)
Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com>
parent e10121345e
commit d560054e1b
@@ -268,9 +268,13 @@ def llama_7b_lora_from_dir_test_harness(**llm_kwargs) -> None:
                              max_lora_rank=8,
                              max_loras=2,
                              max_cpu_loras=2)
-    llm = LLM(model=f"{llm_models_root()}/llama-models/llama-7b-hf",
-              lora_config=lora_config,
-              **llm_kwargs)
+    llm = LLM(
+        model=f"{llm_models_root()}/llama-models/llama-7b-hf",
+        lora_config=lora_config,
+        # Disable CUDA graph
+        # TODO: remove this once we have a proper fix for CUDA graph in LoRA
+        cuda_graph_config=None,
+        **llm_kwargs)
     try:
         prompts = [
             "美国的首都在哪里? \n答案:",
@@ -286,10 +290,7 @@ def llama_7b_lora_from_dir_test_harness(**llm_kwargs) -> None:
         outputs = llm.generate(prompts,
                                sampling_params,
                                lora_request=lora_request)
-        # TODO: remove this once we have a proper fix for CUDA graph in LoRA
-        # assert similar(outputs[0].outputs[0].text, references[0])
-        print(f"lora output: {outputs[0].outputs[0].text}")
-        print(f"ref output: {references[0]}")
+        assert similar(outputs[0].outputs[0].text, references[0])
     finally:
         llm.shutdown()

@@ -305,7 +306,12 @@ def test_llama_7b_lora_default_modules() -> None:

     hf_model_dir = f"{llm_models_root()}/llama-models/llama-7b-hf"

-    llm = LLM(model=hf_model_dir, lora_config=lora_config)
+    llm = LLM(
+        model=hf_model_dir,
+        lora_config=lora_config,
+        # Disable CUDA graph
+        # TODO: remove this once we have a proper fix for CUDA graph in LoRA
+        cuda_graph_config=None)

     hf_lora_dir = f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1"
     try:
@@ -324,9 +330,7 @@ def test_llama_7b_lora_default_modules() -> None:
                                sampling_params,
                                lora_request=lora_request)

-        # assert similar(outputs[0].outputs[0].text, references[0])
-        print(f"lora output: {outputs[0].outputs[0].text}")
-        print(f"ref output: {references[0]}")
+        assert similar(outputs[0].outputs[0].text, references[0])
     finally:
         llm.shutdown()
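
The restored assertions lean on the tests' similar() helper for fuzzy text comparison. A minimal sketch of such a helper, assuming a difflib-based ratio with a 0.8 threshold (the repo's actual utility may use different parameters):

from difflib import SequenceMatcher

def similar(a: str, b: str, threshold: float = 0.8) -> bool:
    # Hypothetical sketch of the tests' similar() helper; threshold is an assumption.
    # ratio() is 1.0 for identical strings and 0.0 for fully dissimilar ones.
    return SequenceMatcher(None, a, b).ratio() >= threshold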
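
Putting the hunks together: the restored flow builds the LLM with CUDA graph disabled, generates with a LoRA adapter, and asserts on output similarity using the similar() helper sketched above. A condensed, self-contained sketch of that flow; the import locations, the LoRARequest construction, and the max_tokens value are assumptions, and prompts/references are passed in rather than hard-coded since the reference strings are elided from the hunks:

from tensorrt_llm import LLM, SamplingParams
# Assumed import paths; the tests may import these from repo-local utilities.
from tensorrt_llm.executor import LoRARequest
from tensorrt_llm.lora_manager import LoraConfig

def run_lora_similarity_check(hf_model_dir, hf_lora_dir, prompts, references):
    lora_config = LoraConfig(lora_dir=[hf_lora_dir],
                             max_lora_rank=8,
                             max_loras=2,
                             max_cpu_loras=2)
    llm = LLM(
        model=hf_model_dir,
        lora_config=lora_config,
        # CUDA graph stays disabled until LoRA + CUDA graph is fixed (see the TODOs above)
        cuda_graph_config=None)
    try:
        # "lora-task-0" is a hypothetical adapter name for illustration.
        lora_request = [LoRARequest("lora-task-0", 0, hf_lora_dir)] * len(prompts)
        outputs = llm.generate(prompts,
                               SamplingParams(max_tokens=20),
                               lora_request=lora_request)
        # The restored assert: fuzzy-match each generation against its reference.
        for output, reference in zip(outputs, references):
            assert similar(output.outputs[0].text, reference)
    finally:
        llm.shutdown()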