mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Misc] Add common random prefix option to structured-output serving benchmark (#41632)
Signed-off-by: Viktor Pus <viktorpus@tenstorrent.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -115,6 +115,39 @@ class SampleRequest:
|
||||
def sample_requests(
|
||||
tokenizer: PreTrainedTokenizerBase, args: argparse.Namespace
|
||||
) -> list[SampleRequest]:
|
||||
def _apply_random_prefix(
    tokenizer: PreTrainedTokenizerBase,
    requests: list[SampleRequest],
    prefix_len: int,
    seed: int,
) -> list[SampleRequest]:
    """Prepend one shared, randomly drawn token prefix to every prompt.

    A single sequence of ``prefix_len`` token ids is sampled (with
    replacement) from the tokenizer's vocabulary, excluding the ids listed
    in ``tokenizer.all_special_ids``. The same sequence is placed in front
    of every request's tokenized prompt, and the result is decoded back to
    text.

    Args:
        tokenizer: Tokenizer used to encode the prompts and decode the
            prefixed token sequences.
        requests: Requests whose prompts should receive the prefix.
        prefix_len: Number of prefix tokens to draw. Values ``<= 0``
            disable the feature.
        seed: Seed for the NumPy random generator, so the prefix is
            reproducible for a given seed.

    Returns:
        The original ``requests`` list, unchanged, when ``prefix_len <= 0``
        or when no non-special token id is available. Otherwise a new list
        of ``SampleRequest`` objects carrying the prefixed prompt and a
        recomputed ``prompt_len``; all other fields are copied over.
    """
    if prefix_len <= 0:
        return requests

    # A set gives O(1) membership checks while scanning the whole vocabulary.
    special_ids = set(getattr(tokenizer, "all_special_ids", None) or ())
    allowed = np.array(
        [
            token_id
            for token_id in range(tokenizer.vocab_size)
            if token_id not in special_ids
        ]
    )
    if len(allowed) == 0:
        return requests

    # Draw the prefix once so every request shares exactly the same tokens.
    rng = np.random.default_rng(seed)
    positions = rng.integers(0, len(allowed), size=prefix_len)
    prefix_token_ids = allowed[positions].tolist()

    out = []
    for req in requests:
        prompt_ids = tokenizer(req.prompt, add_special_tokens=False).input_ids
        full_prompt = tokenizer.decode(
            prefix_token_ids + prompt_ids, skip_special_tokens=False
        )
        out.append(
            SampleRequest(
                prompt=full_prompt,
                # Re-tokenize the decoded text: decoding and re-encoding may
                # not round-trip to the same number of tokens.
                prompt_len=len(tokenizer(full_prompt).input_ids),
                expected_output_len=req.expected_output_len,
                schema=req.schema,
                structure_type=req.structure_type,
                completion=req.completion,
            )
        )
    return out
|
||||
|
||||
if args.dataset == "json" or args.dataset == "json-unique":
|
||||
if args.json_schema_path is None:
|
||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
@@ -261,6 +294,9 @@ def sample_requests(
|
||||
)
|
||||
)
|
||||
|
||||
requests = _apply_random_prefix(
|
||||
tokenizer, requests, args.random_prefix_len, args.seed
|
||||
)
|
||||
return requests
|
||||
|
||||
|
||||
@@ -945,6 +981,15 @@ def create_argument_parser():
|
||||
"results in a more uniform arrival of requests.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument(
|
||||
"--random-prefix-len",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Number of prefix tokens to prepend to every prompt. "
|
||||
"The same prefix is used for all prompts to enable prefix caching."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trust-remote-code",
|
||||
action="store_true",
|
||||
|
||||
Reference in New Issue
Block a user