TensorRT-LLM/tests/unittest/scaffolding/test_scaffolding.py

# autoflake: skip_file
import asyncio

# `deepseek_distill_7b_path` and `default_prompt` are pytest fixtures defined
# in scaffolding/test_worker.py; importing them here makes them resolvable by
# the tests below.
from scaffolding.test_worker import (create_trtllm_worker,
                                     deepseek_distill_7b_path, default_prompt)

from tensorrt_llm.scaffolding import (MajorityVoteController,
                                      NativeGenerationController,
                                      ScaffoldingLlm)
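

# Helper: build a ScaffoldingLlm whose prototype controller is a plain
# NativeGenerationController backed by a TRT-LLM worker loaded from
# `deepseek_distill_7b_path`.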
def create_scaffolding_llm_with_native_generation_controller(
deepseek_distill_7b_path):
trtllm_worker = create_trtllm_worker(deepseek_distill_7b_path)
prototype_generation_controller = NativeGenerationController(
sampling_params={
"max_tokens": 8,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 50
})
return ScaffoldingLlm(
prototype_generation_controller,
{NativeGenerationController.WorkerTag.GENERATION: trtllm_worker},
)
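

# Helper: build a ScaffoldingLlm whose prototype controller is a
# MajorityVoteController that fans out `samples_num` generations through a
# nested NativeGenerationController and votes on the results.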
def create_scaffolding_llm_with_majority_vote_controller(
deepseek_distill_7b_path, samples_num):
trtllm_worker = create_trtllm_worker(deepseek_distill_7b_path)
workers = {}
prototype_generation_controller = NativeGenerationController()
workers[NativeGenerationController.WorkerTag.GENERATION] = trtllm_worker
prototype_majority_vote_controller = MajorityVoteController(
prototype_generation_controller,
default_sample_num=samples_num,
)
llm = ScaffoldingLlm(
prototype_majority_vote_controller,
workers=workers,
)
return llm
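

# Single-prompt synchronous generation should produce a non-empty string.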
def test_unbatched_scaffolding_sync(default_prompt, deepseek_distill_7b_path):
scaffolding_llm = create_scaffolding_llm_with_native_generation_controller(
deepseek_distill_7b_path)
result = scaffolding_llm.generate(default_prompt)
assert isinstance(result.outputs[0].text, str) and len(
result.outputs[0].text) > 0, "Output should be a non-empty string"
scaffolding_llm.shutdown(shutdown_workers=True)
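

# Batched synchronous generation: one result per prompt, each non-empty.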
def test_batched_scaffolding_sync(default_prompt, deepseek_distill_7b_path):
scaffolding_llm = create_scaffolding_llm_with_native_generation_controller(
deepseek_distill_7b_path)
batch_size = 3
prompts = [default_prompt] * batch_size
results = scaffolding_llm.generate(prompts)
assert len(results) == batch_size
for result in results:
assert isinstance(result.outputs[0].text, str) and len(
result.outputs[0].text) > 0, "Output should be a non-empty string"
scaffolding_llm.shutdown(shutdown_workers=True)
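

# Async generation: generate_async() returns a future whose aresult() is
# awaited inside an asyncio event loop.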
def test_async_scaffolding_generation(default_prompt, deepseek_distill_7b_path):
async def run_async_test():
scaffolding_llm = create_scaffolding_llm_with_native_generation_controller(
deepseek_distill_7b_path)
future = scaffolding_llm.generate_async(default_prompt)
result = await future.aresult()
assert isinstance(result.outputs[0].text, str) and len(
result.outputs[0].text) > 0, "Output should be a non-empty string"
scaffolding_llm.shutdown(shutdown_workers=True)
    asyncio.run(run_async_test())
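

# Majority vote: the controller samples multiple generations and selects the
# most common answer; the final output should still be a single non-empty
# string.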
def test_majority_vote(default_prompt, deepseek_distill_7b_path):
scaffolding_llm = create_scaffolding_llm_with_majority_vote_controller(
deepseek_distill_7b_path, samples_num=3)
result = scaffolding_llm.generate(default_prompt)
assert isinstance(result.outputs[0].text, str) and len(
result.outputs[0].text) > 0, "Output should be a non-empty string"
scaffolding_llm.shutdown(shutdown_workers=True)
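

# Note: run with pytest so the fixtures above are injected, e.g.:
#   pytest tests/unittest/scaffolding/test_scaffolding.py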