def _apply_random_prefix(
    tokenizer: PreTrainedTokenizerBase,
    requests: list[SampleRequest],
    prefix_len: int,
    seed: int,
) -> list[SampleRequest]:
    """Prepend one shared block of random tokens to every request's prompt.

    A single sequence of ``prefix_len`` token ids is drawn (with replacement)
    from the tokenizer's vocabulary, excluding any ids listed in
    ``tokenizer.all_special_ids``.  That same sequence is placed in front of
    each request's tokenized prompt, the result is decoded back to text, and
    a new ``SampleRequest`` is built around the decoded prompt.

    Args:
        tokenizer: Tokenizer used to encode the existing prompts, decode the
            prefixed token ids, and measure the new prompt length.
        requests: Requests whose prompts should receive the prefix.
        prefix_len: Number of random tokens to prepend.  Values ``<= 0``
            disable the feature.
        seed: Seed for the NumPy generator that draws the prefix, so the
            prefix is reproducible for a given seed and tokenizer.

    Returns:
        ``requests`` itself (unchanged) when ``prefix_len <= 0`` or when no
        non-special token id is available; otherwise a new list of
        ``SampleRequest`` objects with updated ``prompt`` and ``prompt_len``
        and all other fields copied from the originals.
    """
    if prefix_len <= 0:
        return requests

    # Use a set for the exclusion test: the comprehension below probes it once
    # per vocabulary entry, and a list would make every probe a linear scan.
    special_ids = set(getattr(tokenizer, "all_special_ids", None) or [])
    allowed = np.array(
        [tok for tok in range(tokenizer.vocab_size) if tok not in special_ids]
    )
    if allowed.size == 0:
        # Nothing to sample from; leave the requests untouched.
        return requests

    # Draw positions into ``allowed`` (not raw token ids) so that excluded ids
    # can never be selected.
    rng = np.random.default_rng(seed)
    picks = rng.integers(0, len(allowed), size=prefix_len)
    prefix_token_ids = allowed[picks].tolist()

    prefixed = []
    for req in requests:
        prompt_ids = tokenizer(req.prompt, add_special_tokens=False).input_ids
        full_prompt = tokenizer.decode(
            prefix_token_ids + prompt_ids, skip_special_tokens=False
        )
        prefixed.append(
            SampleRequest(
                prompt=full_prompt,
                # Measure the length by re-tokenizing the decoded text rather
                # than assuming it equals prefix_len + len(prompt_ids).
                prompt_len=len(tokenizer(full_prompt).input_ids),
                expected_output_len=req.expected_output_len,
                schema=req.schema,
                structure_type=req.structure_type,
                completion=req.completion,
            )
        )
    return prefixed
help=( + "Number of prefix tokens to prepend to every prompt. " + "The same prefix is used for all prompts to enable prefix caching." + ), + ) parser.add_argument( "--trust-remote-code", action="store_true",