mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Frontend] Remove frontend pooling multi task support. (#37861)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io> Signed-off-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
@@ -286,10 +286,10 @@ Pooling models now support token-wise task.
|
||||
|
||||
### Score task
|
||||
|
||||
`score` task is deprecated and will be removed in v0.20. Please use `classify` instead. Only when a
|
||||
classification model outputs num_labels equal to 1 can it be used as a scoring model and have its scoring API enabled.
|
||||
`score` task has been removed in v0.21; use `classify` instead. Only when a classification model outputs num_labels
|
||||
equal to 1 can it be used as a scoring model and have its scoring API enabled.
|
||||
|
||||
### Pooling multitask support
|
||||
|
||||
Pooling multitask support is deprecated and will be removed in v0.20. When the default pooling task is not what you want,
|
||||
Pooling multitask support has been removed in v0.21. When the default pooling task is not what you want,
|
||||
you need to manually specify it via `PoolerConfig(task=<task>)` offline or `--pooler-config.task <task>` online.
|
||||
|
||||
@@ -4,68 +4,74 @@
|
||||
import torch
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.config import PoolerConfig
|
||||
from vllm.inputs import TextPrompt
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
|
||||
# Initialize model
|
||||
model = LLM(
|
||||
model="jinaai/jina-embeddings-v4-vllm-text-matching",
|
||||
runner="pooling",
|
||||
max_model_len=1024,
|
||||
gpu_memory_utilization=0.8,
|
||||
)
|
||||
|
||||
# Create text prompts
|
||||
text1 = "Ein wunderschöner Sonnenuntergang am Strand"
|
||||
text1_prompt = TextPrompt(prompt=f"Query: {text1}")
|
||||
def main():
|
||||
# Initialize model
|
||||
model = LLM(
|
||||
model="jinaai/jina-embeddings-v4-vllm-text-matching",
|
||||
pooler_config=PoolerConfig(task="token_embed"),
|
||||
runner="pooling",
|
||||
max_model_len=1024,
|
||||
gpu_memory_utilization=0.8,
|
||||
)
|
||||
|
||||
text2 = "浜辺に沈む美しい夕日"
|
||||
text2_prompt = TextPrompt(prompt=f"Query: {text2}")
|
||||
# Create text prompts
|
||||
text1 = "Ein wunderschöner Sonnenuntergang am Strand"
|
||||
text1_prompt = TextPrompt(prompt=f"Query: {text1}")
|
||||
|
||||
# Create image prompt
|
||||
image = fetch_image(
|
||||
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/eskimo.jpg" # noqa: E501
|
||||
)
|
||||
image_prompt = TextPrompt(
|
||||
prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n", # noqa: E501
|
||||
multi_modal_data={"image": image},
|
||||
)
|
||||
text2 = "浜辺に沈む美しい夕日"
|
||||
text2_prompt = TextPrompt(prompt=f"Query: {text2}")
|
||||
|
||||
# Encode all prompts
|
||||
prompts = [text1_prompt, text2_prompt, image_prompt]
|
||||
outputs = model.encode(prompts, pooling_task="token_embed")
|
||||
# Create image prompt
|
||||
image = fetch_image(
|
||||
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/eskimo.jpg" # noqa: E501
|
||||
)
|
||||
image_prompt = TextPrompt(
|
||||
prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n", # noqa: E501
|
||||
multi_modal_data={"image": image},
|
||||
)
|
||||
|
||||
# Encode all prompts
|
||||
prompts = [text1_prompt, text2_prompt, image_prompt]
|
||||
outputs = model.encode(prompts, pooling_task="token_embed")
|
||||
|
||||
def get_embeddings(outputs):
|
||||
VISION_START_TOKEN_ID, VISION_END_TOKEN_ID = 151652, 151653
|
||||
|
||||
embeddings = []
|
||||
for output in outputs:
|
||||
if VISION_START_TOKEN_ID in output.prompt_token_ids:
|
||||
# Gather only vision tokens
|
||||
img_start_pos = torch.where(
|
||||
torch.tensor(output.prompt_token_ids) == VISION_START_TOKEN_ID
|
||||
)[0][0]
|
||||
img_end_pos = torch.where(
|
||||
torch.tensor(output.prompt_token_ids) == VISION_END_TOKEN_ID
|
||||
)[0][0]
|
||||
embeddings_tensor = output.outputs.data.detach().clone()[
|
||||
img_start_pos : img_end_pos + 1
|
||||
]
|
||||
else:
|
||||
# Use all tokens for text-only prompts
|
||||
embeddings_tensor = output.outputs.data.detach().clone()
|
||||
|
||||
# Pool and normalize embeddings
|
||||
pooled_output = (
|
||||
embeddings_tensor.sum(dim=0, dtype=torch.float32)
|
||||
/ embeddings_tensor.shape[0]
|
||||
)
|
||||
embeddings.append(torch.nn.functional.normalize(pooled_output, dim=-1))
|
||||
return embeddings
|
||||
|
||||
embeddings = get_embeddings(outputs)
|
||||
|
||||
for embedding in embeddings:
|
||||
print(embedding.shape)
|
||||
|
||||
|
||||
def get_embeddings(outputs):
|
||||
VISION_START_TOKEN_ID, VISION_END_TOKEN_ID = 151652, 151653
|
||||
|
||||
embeddings = []
|
||||
for output in outputs:
|
||||
if VISION_START_TOKEN_ID in output.prompt_token_ids:
|
||||
# Gather only vision tokens
|
||||
img_start_pos = torch.where(
|
||||
torch.tensor(output.prompt_token_ids) == VISION_START_TOKEN_ID
|
||||
)[0][0]
|
||||
img_end_pos = torch.where(
|
||||
torch.tensor(output.prompt_token_ids) == VISION_END_TOKEN_ID
|
||||
)[0][0]
|
||||
embeddings_tensor = output.outputs.data.detach().clone()[
|
||||
img_start_pos : img_end_pos + 1
|
||||
]
|
||||
else:
|
||||
# Use all tokens for text-only prompts
|
||||
embeddings_tensor = output.outputs.data.detach().clone()
|
||||
|
||||
# Pool and normalize embeddings
|
||||
pooled_output = (
|
||||
embeddings_tensor.sum(dim=0, dtype=torch.float32)
|
||||
/ embeddings_tensor.shape[0]
|
||||
)
|
||||
embeddings.append(torch.nn.functional.normalize(pooled_output, dim=-1))
|
||||
return embeddings
|
||||
|
||||
|
||||
embeddings = get_embeddings(outputs)
|
||||
|
||||
for embedding in embeddings:
|
||||
print(embedding.shape)
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.config import PoolerConfig
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
@@ -13,6 +14,7 @@ def parse_args():
|
||||
# Set example specific arguments
|
||||
parser.set_defaults(
|
||||
model="BAAI/bge-m3",
|
||||
pooler_config=PoolerConfig(task="token_embed"),
|
||||
runner="pooling",
|
||||
enforce_eager=True,
|
||||
)
|
||||
@@ -32,15 +34,6 @@ def main(args: Namespace):
|
||||
# You should pass runner="pooling" for embedding models
|
||||
llm = LLM(**vars(args))
|
||||
|
||||
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
|
||||
outputs = llm.embed(prompts)
|
||||
|
||||
# Print the outputs.
|
||||
print("\nGenerated Outputs:\n" + "-" * 60)
|
||||
for prompt, output in zip(prompts, outputs):
|
||||
embeds = output.outputs.embedding
|
||||
print(len(embeds))
|
||||
|
||||
# Generate embedding for each token. The output is a list of PoolingRequestOutput.
|
||||
outputs = llm.encode(prompts, pooling_task="token_embed")
|
||||
|
||||
@@ -50,6 +43,20 @@ def main(args: Namespace):
|
||||
multi_vector = output.outputs.data
|
||||
print(multi_vector.shape)
|
||||
|
||||
query = "What is the capital of France?"
|
||||
documents = [
|
||||
"The capital of Brazil is Brasilia.",
|
||||
"The capital of France is Paris.",
|
||||
]
|
||||
# Generate scores.
|
||||
outputs = llm.score(query, documents)
|
||||
# Print the outputs.
|
||||
print("\nGenerated Outputs:\n" + "-" * 60)
|
||||
for document, output in zip(documents, outputs):
|
||||
score = output.outputs.score
|
||||
print(f"Pair: {[query, document]!r} \nScore: {score}")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
@@ -7,10 +7,11 @@ Example online usage of Pooling API for multi vector retrieval.
|
||||
Run `vllm serve <model> --runner pooling`
|
||||
to start up the server in vLLM. e.g.
|
||||
|
||||
vllm serve BAAI/bge-m3
|
||||
vllm serve BAAI/bge-m3 --pooler-config.task token_embed
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pprint
|
||||
|
||||
import requests
|
||||
import torch
|
||||
@@ -32,7 +33,8 @@ def parse_args():
|
||||
|
||||
|
||||
def main(args):
|
||||
api_url = f"http://{args.host}:{args.port}/pooling"
|
||||
pooling_url = f"http://{args.host}:{args.port}/pooling"
|
||||
score_url = f"http://{args.host}:{args.port}/score"
|
||||
model_name = args.model
|
||||
|
||||
prompts = [
|
||||
@@ -43,11 +45,23 @@ def main(args):
|
||||
]
|
||||
prompt = {"model": model_name, "input": prompts}
|
||||
|
||||
pooling_response = post_http_request(prompt=prompt, api_url=api_url)
|
||||
pooling_response = post_http_request(prompt=prompt, api_url=pooling_url)
|
||||
for output in pooling_response.json()["data"]:
|
||||
multi_vector = torch.tensor(output["data"])
|
||||
print(multi_vector.shape)
|
||||
|
||||
queries = "What is the capital of France?"
|
||||
documents = [
|
||||
"The capital of Brazil is Brasilia.",
|
||||
"The capital of France is Paris.",
|
||||
]
|
||||
prompt = {"model": model_name, "queries": queries, "documents": documents}
|
||||
score_response = post_http_request(prompt=prompt, api_url=score_url)
|
||||
print("\nPrompt when queries is string and documents is a list:")
|
||||
pprint.pprint(prompt)
|
||||
print("\nScore Response:")
|
||||
pprint.pprint(score_response.json())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import logging
|
||||
import weakref
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.models.utils import softmax
|
||||
from vllm import LLM, ClassificationRequestOutput, PoolingParams, PoolingRequestOutput
|
||||
from vllm import LLM, ClassificationRequestOutput, PoolingParams
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.tasks import PoolingTask
|
||||
|
||||
@@ -66,18 +65,6 @@ def test_list_prompts(llm: LLM):
|
||||
assert len(outputs[i].outputs.probs) == num_labels
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_token_classify(llm: LLM, caplog_vllm):
|
||||
with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
|
||||
outputs = llm.encode(prompt, pooling_task="token_classify", use_tqdm=False)
|
||||
assert "deprecated" in caplog_vllm.text
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert isinstance(outputs[0], PoolingRequestOutput)
|
||||
assert outputs[0].prompt_token_ids == prompt_token_ids
|
||||
assert outputs[0].outputs.data.shape == (len(prompt_token_ids), num_labels)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_pooling_params(llm: LLM):
|
||||
def get_outputs(use_activation):
|
||||
@@ -110,10 +97,12 @@ def test_score_api(llm: LLM):
|
||||
llm.score("ping", "pong", use_tqdm=False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
|
||||
@pytest.mark.parametrize("task", ["embed", "token_embed", "token_classify", "plugin"])
|
||||
def test_unsupported_tasks(llm: LLM, task: PoolingTask):
|
||||
if task == "plugin":
|
||||
err_msg = "No IOProcessor plugin installed."
|
||||
elif task == "token_classify":
|
||||
err_msg = "Try switching the model's pooling_task via.+"
|
||||
else:
|
||||
err_msg = "Embedding API is not supported by this model.+"
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
|
||||
@@ -436,26 +436,7 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
|
||||
task = "token_classify"
|
||||
response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"input": input_text,
|
||||
"encoding_format": "float",
|
||||
"task": task,
|
||||
},
|
||||
)
|
||||
poolings = PoolingResponse.model_validate(response.json())
|
||||
assert len(poolings.data) == 1
|
||||
assert len(poolings.data[0].data) == 8
|
||||
assert len(poolings.data[0].data[0]) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
|
||||
@pytest.mark.parametrize("task", ["embed", "token_embed", "token_classify", "plugin"])
|
||||
async def test_pooling_not_supported(
|
||||
server: RemoteOpenAIServer, model_name: str, task: str
|
||||
):
|
||||
@@ -469,8 +450,11 @@ async def test_pooling_not_supported(
|
||||
},
|
||||
)
|
||||
assert response.json()["error"]["type"] == "BadRequestError"
|
||||
|
||||
if task == "plugin":
|
||||
err_msg = "No IOProcessor plugin installed."
|
||||
elif task == "token_classify":
|
||||
err_msg = "Try switching the model's pooling_task via"
|
||||
else:
|
||||
err_msg = f"Unsupported task: {task!r}"
|
||||
assert response.json()["error"]["message"].startswith(err_msg)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import logging
|
||||
import weakref
|
||||
|
||||
import pytest
|
||||
@@ -38,11 +37,11 @@ def llm():
|
||||
seed=0,
|
||||
attention_config=attention_config,
|
||||
)
|
||||
assert embedding_size == llm.model_config.embedding_size
|
||||
|
||||
yield weakref.proxy(llm)
|
||||
|
||||
del llm
|
||||
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@@ -74,16 +73,6 @@ def test_list_prompts(llm: LLM):
|
||||
assert len(outputs[i].outputs.embedding) == embedding_size
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_token_embed(llm: LLM, caplog_vllm):
|
||||
with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
|
||||
outputs = llm.encode(prompt, pooling_task="token_embed", use_tqdm=False)
|
||||
assert "deprecated" in caplog_vllm.text
|
||||
|
||||
multi_vector = outputs[0].outputs.data
|
||||
assert multi_vector.shape == (11, 384)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_pooling_params(llm: LLM):
|
||||
def get_outputs(normalize):
|
||||
@@ -107,10 +96,14 @@ def test_pooling_params(llm: LLM):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("task", ["token_classify", "classify", "plugin"])
|
||||
@pytest.mark.parametrize(
|
||||
"task", ["token_classify", "classify", "token_embed", "plugin"]
|
||||
)
|
||||
def test_unsupported_tasks(llm: LLM, task: PoolingTask):
|
||||
if task == "plugin":
|
||||
err_msg = "No IOProcessor plugin installed."
|
||||
elif task == "token_embed":
|
||||
err_msg = "Try switching the model's pooling_task via.+"
|
||||
else:
|
||||
err_msg = "Classification API is not supported by this model.+"
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
|
||||
@@ -732,28 +732,9 @@ async def test_pooling_embed(server: RemoteOpenAIServer, model_name: str):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str):
|
||||
task = "token_embed"
|
||||
response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"input": input_text,
|
||||
"encoding_format": "float",
|
||||
"task": task,
|
||||
},
|
||||
)
|
||||
|
||||
poolings = PoolingResponse.model_validate(response.json())
|
||||
|
||||
assert len(poolings.data) == 1
|
||||
assert len(poolings.data[0].data) == len(input_tokens)
|
||||
assert len(poolings.data[0].data[0]) == 384
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("task", ["classify", "token_classify", "plugin"])
|
||||
@pytest.mark.parametrize(
|
||||
"task", ["classify", "token_classify", "token_embed", "plugin"]
|
||||
)
|
||||
async def test_pooling_not_supported(
|
||||
server: RemoteOpenAIServer, model_name: str, task: str
|
||||
):
|
||||
@@ -769,6 +750,8 @@ async def test_pooling_not_supported(
|
||||
assert response.json()["error"]["type"] == "BadRequestError"
|
||||
if task == "plugin":
|
||||
err_msg = "No IOProcessor plugin installed."
|
||||
elif task == "token_embed":
|
||||
err_msg = "Try switching the model's pooling_task via"
|
||||
else:
|
||||
err_msg = f"Unsupported task: {task!r}"
|
||||
assert response.json()["error"]["message"].startswith(err_msg)
|
||||
|
||||
@@ -452,25 +452,6 @@ async def test_pooling_classify(server: RemoteOpenAIServer):
|
||||
assert len(poolings.data[0].data) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pooling_token_classify(server: RemoteOpenAIServer):
|
||||
response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json={
|
||||
"model": MODEL_NAME,
|
||||
"task": "token_classify",
|
||||
"input": input_text,
|
||||
"encoding_format": "float",
|
||||
},
|
||||
)
|
||||
|
||||
poolings = PoolingResponse.model_validate(response.json())
|
||||
|
||||
assert len(poolings.data) == 1
|
||||
assert len(poolings.data[0].data) == len(input_tokens)
|
||||
assert len(poolings.data[0].data[0]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rerank_max_tokens_per_doc(
|
||||
server: RemoteOpenAIServer,
|
||||
@@ -544,7 +525,7 @@ async def test_rerank_max_tokens_per_doc_validation(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
|
||||
@pytest.mark.parametrize("task", ["embed", "token_embed", "token_classify", "plugin"])
|
||||
async def test_pooling_not_supported(server: RemoteOpenAIServer, task: str):
|
||||
response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
@@ -558,6 +539,8 @@ async def test_pooling_not_supported(server: RemoteOpenAIServer, task: str):
|
||||
assert response.json()["error"]["type"] == "BadRequestError"
|
||||
if task == "plugin":
|
||||
err_msg = "No IOProcessor plugin installed."
|
||||
elif task == "token_classify":
|
||||
err_msg = "Try switching the model's pooling_task via"
|
||||
else:
|
||||
err_msg = f"Unsupported task: {task!r}"
|
||||
assert response.json()["error"]["message"].startswith(err_msg)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import logging
|
||||
import weakref
|
||||
|
||||
import pytest
|
||||
@@ -60,22 +59,19 @@ def test_token_ids_prompts(llm: LLM):
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_score_api(llm: LLM):
|
||||
err_msg = "Scoring API is only enabled for num_labels == 1."
|
||||
err_msg = "This model does not support the Scoring API."
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
llm.score("ping", "pong", use_tqdm=False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("task", ["classify", "embed", "token_embed", "plugin"])
|
||||
def test_unsupported_tasks(llm: LLM, task: PoolingTask, caplog_vllm):
|
||||
if task == "classify":
|
||||
with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
|
||||
llm.encode(prompt, pooling_task=task, use_tqdm=False)
|
||||
assert "deprecated" in caplog_vllm.text
|
||||
if task == "plugin":
|
||||
err_msg = "No IOProcessor plugin installed."
|
||||
elif task == "classify":
|
||||
err_msg = "Try switching the model's pooling_task via.+"
|
||||
else:
|
||||
if task == "plugin":
|
||||
err_msg = "No IOProcessor plugin installed."
|
||||
else:
|
||||
err_msg = "Embedding API is not supported by this model.+"
|
||||
err_msg = "Embedding API is not supported by this model.+"
|
||||
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
llm.encode(prompt, pooling_task=task, use_tqdm=False)
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
llm.encode(prompt, pooling_task=task, use_tqdm=False)
|
||||
|
||||
@@ -50,7 +50,7 @@ async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: st
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
|
||||
@pytest.mark.parametrize("task", ["classify", "embed", "token_embed", "plugin"])
|
||||
async def test_pooling_not_supported(
|
||||
server: RemoteOpenAIServer, model_name: str, task: str
|
||||
):
|
||||
@@ -63,9 +63,12 @@ async def test_pooling_not_supported(
|
||||
"task": task,
|
||||
},
|
||||
)
|
||||
assert response.json()["error"]["type"] == "BadRequestError"
|
||||
|
||||
if task == "plugin":
|
||||
err_msg = "No IOProcessor plugin installed."
|
||||
elif task == "classify":
|
||||
err_msg = "Try switching the model's pooling_task via"
|
||||
else:
|
||||
err_msg = f"Unsupported task: {task!r}"
|
||||
assert response.json()["error"]["message"].startswith(err_msg)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import logging
|
||||
import weakref
|
||||
|
||||
import pytest
|
||||
@@ -64,15 +63,12 @@ def test_token_ids_prompts(llm: LLM):
|
||||
|
||||
@pytest.mark.parametrize("task", ["embed", "classify", "token_classify", "plugin"])
|
||||
def test_unsupported_tasks(llm: LLM, task: PoolingTask, caplog_vllm):
|
||||
if task == "embed":
|
||||
with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
|
||||
llm.encode(prompt, pooling_task=task, use_tqdm=False)
|
||||
assert "deprecated" in caplog_vllm.text
|
||||
if task == "plugin":
|
||||
err_msg = "No IOProcessor plugin installed."
|
||||
elif task == "embed":
|
||||
err_msg = "Try switching the model's pooling_task via.+"
|
||||
else:
|
||||
if task == "plugin":
|
||||
err_msg = "No IOProcessor plugin installed."
|
||||
else:
|
||||
err_msg = "Classification API is not supported by this model.+"
|
||||
err_msg = "Classification API is not supported by this model.+"
|
||||
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
llm.encode(prompt, pooling_task=task, use_tqdm=False)
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
llm.encode(prompt, pooling_task=task, use_tqdm=False)
|
||||
|
||||
@@ -73,7 +73,7 @@ async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("task", ["classify", "token_classify", "plugin"])
|
||||
@pytest.mark.parametrize("task", ["embed", "classify", "token_classify", "plugin"])
|
||||
async def test_pooling_not_supported(
|
||||
server: RemoteOpenAIServer, model_name: str, task: str
|
||||
):
|
||||
@@ -86,9 +86,12 @@ async def test_pooling_not_supported(
|
||||
"task": task,
|
||||
},
|
||||
)
|
||||
assert response.json()["error"]["type"] == "BadRequestError"
|
||||
|
||||
if task == "plugin":
|
||||
err_msg = "No IOProcessor plugin installed."
|
||||
elif task == "embed":
|
||||
err_msg = "Try switching the model's pooling_task via"
|
||||
else:
|
||||
err_msg = f"Unsupported task: {task!r}"
|
||||
assert response.json()["error"]["message"].startswith(err_msg)
|
||||
|
||||
@@ -6,6 +6,7 @@ from transformers import AutoModel
|
||||
|
||||
from tests.models.utils import check_embeddings_close
|
||||
from vllm import TokensPrompt
|
||||
from vllm.config import PoolerConfig
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -21,6 +22,7 @@ def test_embed_models(hf_runner, vllm_runner, model: str):
|
||||
with vllm_runner(
|
||||
model,
|
||||
runner="pooling",
|
||||
pooler_config=PoolerConfig(task="token_embed"),
|
||||
max_model_len=128,
|
||||
max_num_batched_tokens=chunk_size,
|
||||
enforce_eager=True,
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import httpx
|
||||
import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import torch
|
||||
|
||||
from ....utils import RemoteOpenAIServer
|
||||
@@ -25,29 +24,42 @@ sentences_2 = [
|
||||
similarity_reference = [[0.6259, 0.3474], [0.3309, 0.6734]]
|
||||
lexical_score_reference = [0.19554901123046875, 0.0]
|
||||
colbert_score_reference = [0.7797, 0.4620]
|
||||
SUPPORTED_TASKS = ["embed", "token_embed", "token_classify"]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=SUPPORTED_TASKS)
|
||||
def pooling_task(request):
|
||||
yield request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
def server(pooling_task):
|
||||
args = [
|
||||
"--max-model-len",
|
||||
str(MAX_MODEL_LEN),
|
||||
"--hf-overrides",
|
||||
'{"architectures": ["BgeM3EmbeddingModel"]}',
|
||||
"--pooler-config.task",
|
||||
pooling_task,
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bge_m3_api_server_embedding(client: openai.AsyncOpenAI):
|
||||
async def test_bge_m3_api_server_embedding(server, pooling_task):
|
||||
client = server.get_async_client()
|
||||
|
||||
if pooling_task != "embed":
|
||||
with pytest.raises(openai.InternalServerError):
|
||||
await run_client_embeddings(
|
||||
client,
|
||||
MODEL_NAME,
|
||||
sentences_1,
|
||||
)
|
||||
return
|
||||
|
||||
embeddings_list_1 = await run_client_embeddings(
|
||||
client,
|
||||
MODEL_NAME,
|
||||
@@ -117,7 +129,14 @@ def compute_lexical_matching_score(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bge_m3_api_server_sparse_embedding(client: openai.AsyncOpenAI):
|
||||
async def test_bge_m3_api_server_sparse_embedding(server, pooling_task):
|
||||
client = server.get_async_client()
|
||||
|
||||
if pooling_task != "token_classify":
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
await sparse_embeddings(client, sentences_1)
|
||||
return
|
||||
|
||||
embeddings_1 = await sparse_embeddings(client, sentences_1)
|
||||
embeddings_2 = await sparse_embeddings(client, sentences_2)
|
||||
|
||||
@@ -137,9 +156,11 @@ async def test_bge_m3_api_server_sparse_embedding(client: openai.AsyncOpenAI):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bge_m3_api_server_sparse_embedding_corner_case(
|
||||
client: openai.AsyncOpenAI,
|
||||
):
|
||||
async def test_bge_m3_api_server_sparse_embedding_corner_case(server, pooling_task):
|
||||
if pooling_task != "token_classify":
|
||||
return
|
||||
|
||||
client = server.get_async_client()
|
||||
embeddings = await sparse_embeddings(client, ["Hi"])
|
||||
assert len(embeddings) == 1
|
||||
assert 2673 in embeddings[0]
|
||||
@@ -155,7 +176,18 @@ def colbert_score(q_reps: torch.Tensor, p_reps: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bge_m3_api_server_multi_vector(client: openai.AsyncOpenAI):
|
||||
async def test_bge_m3_api_server_multi_vector(server, pooling_task):
|
||||
client = server.get_async_client()
|
||||
|
||||
if pooling_task != "token_embed":
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
await client.post(
|
||||
"../pooling",
|
||||
body={"model": MODEL_NAME, "input": sentences_1, "task": "token_embed"},
|
||||
cast_to=httpx.Response,
|
||||
)
|
||||
return
|
||||
|
||||
result_1 = await client.post(
|
||||
"../pooling",
|
||||
body={"model": MODEL_NAME, "input": sentences_1, "task": "token_embed"},
|
||||
|
||||
@@ -4,6 +4,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm import TokensPrompt
|
||||
from vllm.config import PoolerConfig
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -20,6 +21,7 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
|
||||
max_model_len=128,
|
||||
enforce_eager=True,
|
||||
runner="pooling",
|
||||
pooler_config=PoolerConfig(task="token_embed"),
|
||||
enable_prefix_caching=True,
|
||||
) as vllm_model:
|
||||
pooling_outputs = vllm_model.llm.encode(
|
||||
@@ -44,14 +46,3 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
|
||||
assert len(output.prompt_token_ids) == n
|
||||
assert len(output.outputs.data) == n
|
||||
assert output.num_cached_tokens == 0
|
||||
|
||||
# skip_reading_prefix_cache can still write to cache
|
||||
# to accelerate following requests
|
||||
pooling_outputs = vllm_model.llm.encode(
|
||||
[TokensPrompt(prompt_token_ids=t) for t in token_prompts],
|
||||
pooling_task="embed",
|
||||
)
|
||||
|
||||
for n, output in zip(n_prompt_tokens, pooling_outputs):
|
||||
assert len(output.prompt_token_ids) == n
|
||||
assert output.num_cached_tokens > 0
|
||||
|
||||
@@ -5,6 +5,7 @@ import torch
|
||||
from transformers import AutoModel
|
||||
|
||||
from tests.models.utils import check_embeddings_close
|
||||
from vllm.config import PoolerConfig
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -17,6 +18,7 @@ def test_embed_models(hf_runner, vllm_runner, example_prompts, model: str, dtype
|
||||
with vllm_runner(
|
||||
model,
|
||||
runner="pooling",
|
||||
pooler_config=PoolerConfig(task="token_embed"),
|
||||
max_model_len=None,
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.token_embed(example_prompts)
|
||||
|
||||
@@ -146,7 +146,7 @@ def test_multi_vector_retrieval_models_using_normalize(
|
||||
model,
|
||||
max_model_len=512,
|
||||
dtype=dtype,
|
||||
pooler_config=PoolerConfig(use_activation=False),
|
||||
pooler_config=PoolerConfig(use_activation=False, task="token_embed"),
|
||||
) as vllm_model:
|
||||
wo_normalize = vllm_model.token_embed(example_prompts)
|
||||
|
||||
@@ -154,7 +154,7 @@ def test_multi_vector_retrieval_models_using_normalize(
|
||||
model,
|
||||
max_model_len=512,
|
||||
dtype=dtype,
|
||||
pooler_config=PoolerConfig(use_activation=True),
|
||||
pooler_config=PoolerConfig(use_activation=True, task="token_embed"),
|
||||
) as vllm_model:
|
||||
w_normalize = vllm_model.token_embed(example_prompts)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ from vllm.renderers.inputs.preprocess import (
|
||||
prompt_to_seq,
|
||||
)
|
||||
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.tasks import SCORE_TYPE_MAP, PoolingTask
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.counter import Counter
|
||||
@@ -1207,12 +1207,9 @@ class LLM:
|
||||
f"Supported tasks: {self.supported_tasks}"
|
||||
)
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Pooling multitask support is deprecated and will "
|
||||
"be removed in v0.20. When the default pooling task is "
|
||||
"not what you want, you need to manually specify it "
|
||||
'via PoolerConfig(task="%s"). ',
|
||||
pooling_task,
|
||||
raise ValueError(
|
||||
f"Try switching the model's pooling_task "
|
||||
f'via `PoolerConfig(task="{pooling_task}")`'
|
||||
)
|
||||
|
||||
if pooling_task == "plugin" and "plugin" not in self.pooling_io_processors:
|
||||
@@ -1410,7 +1407,7 @@ class LLM:
|
||||
"pooling model."
|
||||
)
|
||||
|
||||
score_type = self.model_config.score_type
|
||||
score_type: str | None = SCORE_TYPE_MAP.get(self.pooling_task, None) # type: ignore[arg-type]
|
||||
if (
|
||||
score_type == "cross-encoder"
|
||||
and getattr(self.model_config.hf_config, "num_labels", 0) != 1
|
||||
|
||||
@@ -15,10 +15,7 @@ from starlette.datastructures import Headers
|
||||
from vllm import PoolingParams, PoolingRequestOutput, envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatTemplateConfig,
|
||||
ChatTemplateContentFormatOption,
|
||||
)
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateConfig
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
@@ -48,9 +45,7 @@ class PoolingServingBase(ABC):
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
chat_template: str | None = None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
|
||||
trust_request_chat_template: bool = False,
|
||||
chat_template_config: ChatTemplateConfig,
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
):
|
||||
@@ -63,11 +58,7 @@ class PoolingServingBase(ABC):
|
||||
self.request_logger = request_logger
|
||||
self.return_tokens_as_token_ids = return_tokens_as_token_ids
|
||||
self.log_error_stack = log_error_stack
|
||||
self.chat_template_config = ChatTemplateConfig(
|
||||
chat_template=chat_template,
|
||||
chat_template_content_format=chat_template_content_format,
|
||||
trust_request_chat_template=trust_request_chat_template,
|
||||
)
|
||||
self.chat_template_config = chat_template_config
|
||||
|
||||
# Shared thread pool executor for preprocessing and postprocessing.
|
||||
self._executor: Executor = models.renderer._executor
|
||||
|
||||
@@ -10,7 +10,7 @@ from vllm.entrypoints.chat_utils import ChatTemplateConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.plugins.io_processors import has_io_processor
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.tasks import POOLING_TASKS, SupportedTask
|
||||
from vllm.tasks import POOLING_TASKS, SCORE_TYPE_MAP, SupportedTask
|
||||
|
||||
from .base.io_processor import PoolingIOProcessor
|
||||
from .utils import enable_scoring_api
|
||||
@@ -43,23 +43,24 @@ def init_pooling_io_processors(
|
||||
) -> dict[str, PoolingIOProcessor]:
|
||||
model_config = vllm_config.model_config
|
||||
processors: dict[str, type[PoolingIOProcessor]] = {}
|
||||
pooling_task = model_config.get_pooling_task(supported_tasks)
|
||||
|
||||
if "classify" in supported_tasks:
|
||||
if pooling_task == "classify":
|
||||
from .classify.io_processor import ClassifyIOProcessor
|
||||
|
||||
processors["classify"] = ClassifyIOProcessor
|
||||
|
||||
if "token_classify" in supported_tasks:
|
||||
if pooling_task == "token_classify":
|
||||
from .classify.io_processor import TokenClassifyIOProcessor
|
||||
|
||||
processors["token_classify"] = TokenClassifyIOProcessor
|
||||
|
||||
if "embed" in supported_tasks:
|
||||
if pooling_task == "embed":
|
||||
from .embed.io_processor import EmbedIOProcessor
|
||||
|
||||
processors["embed"] = EmbedIOProcessor
|
||||
|
||||
if "token_embed" in supported_tasks:
|
||||
if pooling_task == "token_embed":
|
||||
from .embed.io_processor import TokenEmbedIOProcessor
|
||||
|
||||
processors["token_embed"] = TokenEmbedIOProcessor
|
||||
@@ -71,15 +72,15 @@ def init_pooling_io_processors(
|
||||
from .pooling.io_processor import PluginWithIOProcessorPlugins
|
||||
|
||||
processors["plugin"] = PluginWithIOProcessorPlugins
|
||||
elif "plugin" in supported_tasks:
|
||||
elif pooling_task == "plugin":
|
||||
from .pooling.io_processor import PluginWithoutIOProcessorPlugins
|
||||
|
||||
processors["plugin"] = PluginWithoutIOProcessorPlugins
|
||||
|
||||
if enable_scoring_api(supported_tasks, model_config):
|
||||
score_type = model_config.score_type
|
||||
from .scoring.io_processor import ScoringIOProcessors
|
||||
|
||||
score_type: str | None = SCORE_TYPE_MAP.get(pooling_task, None) # type: ignore[arg-type]
|
||||
if score_type is not None and score_type in ScoringIOProcessors:
|
||||
processors[score_type] = ScoringIOProcessors[score_type]
|
||||
|
||||
@@ -140,6 +141,10 @@ def init_pooling_state(
|
||||
request_logger: RequestLogger | None,
|
||||
supported_tasks: tuple["SupportedTask", ...],
|
||||
):
|
||||
model_config = engine_client.model_config
|
||||
if model_config is None:
|
||||
return
|
||||
|
||||
from vllm.entrypoints.chat_utils import load_chat_template
|
||||
from vllm.tasks import POOLING_TASKS
|
||||
|
||||
@@ -148,8 +153,14 @@ def init_pooling_state(
|
||||
from .pooling.serving import ServingPooling
|
||||
from .scoring.serving import ServingScores
|
||||
|
||||
model_config = engine_client.model_config
|
||||
resolved_chat_template = load_chat_template(args.chat_template)
|
||||
pooling_task = model_config.get_pooling_task(supported_tasks)
|
||||
|
||||
chat_template_config = ChatTemplateConfig(
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
)
|
||||
|
||||
state.serving_pooling = (
|
||||
(
|
||||
@@ -158,9 +169,7 @@ def init_pooling_state(
|
||||
state.openai_serving_models,
|
||||
supported_tasks=supported_tasks,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
chat_template_config=chat_template_config,
|
||||
)
|
||||
)
|
||||
if any(t in supported_tasks for t in POOLING_TASKS)
|
||||
@@ -171,11 +180,9 @@ def init_pooling_state(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
chat_template_config=chat_template_config,
|
||||
)
|
||||
if "embed" in supported_tasks
|
||||
if pooling_task == "embed"
|
||||
else None
|
||||
)
|
||||
state.serving_classification = (
|
||||
@@ -183,21 +190,18 @@ def init_pooling_state(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
chat_template_config=chat_template_config,
|
||||
)
|
||||
if "classify" in supported_tasks
|
||||
if pooling_task == "classify"
|
||||
else None
|
||||
)
|
||||
state.serving_scores = (
|
||||
ServingScores(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
supported_tasks=supported_tasks,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
chat_template_config=chat_template_config,
|
||||
enable_flash_late_interaction=getattr(
|
||||
args, "enable_flash_late_interaction", True
|
||||
),
|
||||
@@ -214,7 +218,12 @@ def get_pooling_invocation_types(
|
||||
# NOTE: Items defined earlier take higher priority
|
||||
invocation_types: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = []
|
||||
|
||||
if "embed" in supported_tasks:
|
||||
if model_config is None:
|
||||
return invocation_types
|
||||
|
||||
pooling_task = model_config.get_pooling_task(supported_tasks)
|
||||
|
||||
if pooling_task == "embed":
|
||||
from .embed.api_router import create_embedding, embedding
|
||||
from .embed.protocol import EmbeddingRequest
|
||||
|
||||
@@ -222,7 +231,7 @@ def get_pooling_invocation_types(
|
||||
(EmbeddingRequest, (embedding, create_embedding)),
|
||||
]
|
||||
|
||||
if "classify" in supported_tasks:
|
||||
if pooling_task == "classify":
|
||||
from .classify.api_router import classify, create_classify
|
||||
from .classify.protocol import ClassificationRequest
|
||||
|
||||
|
||||
@@ -78,17 +78,15 @@ class ServingPooling(PoolingServingBase):
|
||||
|
||||
# plugin task uses io_processor.parse_request to verify inputs
|
||||
if pooling_task != "plugin" and pooling_task != self.pooling_task:
|
||||
if pooling_task not in self.io_processors:
|
||||
if pooling_task not in self.supported_tasks:
|
||||
raise ValueError(
|
||||
f"Unsupported task: {pooling_task!r} "
|
||||
f"Supported tasks: {self.supported_tasks}"
|
||||
)
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Pooling multitask support is deprecated and will be removed "
|
||||
"in v0.20. When the default pooling task is not what you want, you "
|
||||
"need to manually specify it via --pooler-config.task %s. ",
|
||||
pooling_task,
|
||||
raise ValueError(
|
||||
"Try switching the model's pooling_task "
|
||||
f"via --pooler-config.task {request.task}."
|
||||
)
|
||||
|
||||
if pooling_task == "plugin" and "plugin" not in self.io_processors:
|
||||
|
||||
@@ -8,6 +8,7 @@ from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.engine.protocol import UsageInfo
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.tasks import SCORE_TYPE_MAP, SupportedTask
|
||||
from vllm.v1.pool.late_interaction import (
|
||||
build_late_interaction_doc_params,
|
||||
build_late_interaction_query_params,
|
||||
@@ -38,10 +39,15 @@ class ServingScores(PoolingServing):
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
*args,
|
||||
supported_tasks: tuple[SupportedTask, ...],
|
||||
enable_flash_late_interaction: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
self.io_processor_name: str = engine_client.model_config.score_type
|
||||
pooling_task = engine_client.model_config.get_pooling_task(supported_tasks)
|
||||
score_type = SCORE_TYPE_MAP.get(pooling_task, None) # type: ignore[arg-type]
|
||||
assert score_type is not None
|
||||
|
||||
self.io_processor_name: str = score_type
|
||||
self.enable_flash_late_interaction = (
|
||||
self.io_processor_name == "late-interaction"
|
||||
and enable_flash_late_interaction
|
||||
|
||||
@@ -141,10 +141,14 @@ def enable_scoring_api(
|
||||
supported_tasks: tuple["SupportedTask", ...],
|
||||
model_config: ModelConfig | None = None,
|
||||
) -> bool:
|
||||
if any(t in supported_tasks for t in ("embed", "token_embed")):
|
||||
if model_config is None:
|
||||
return False
|
||||
|
||||
pooling_task = model_config.get_pooling_task(supported_tasks)
|
||||
if pooling_task in ("embed", "token_embed"):
|
||||
return True
|
||||
|
||||
if model_config is not None and "classify" in supported_tasks:
|
||||
if pooling_task == "classify":
|
||||
num_labels = getattr(model_config.hf_config, "num_labels", 0)
|
||||
if num_labels != 1:
|
||||
logger.debug_once("Scoring API is only enabled for num_labels == 1.")
|
||||
|
||||
@@ -87,13 +87,6 @@ class PoolingParams(
|
||||
return deepcopy(self)
|
||||
|
||||
def verify(self, model_config: ModelConfig) -> None:
|
||||
if self.task == "score":
|
||||
logger.warning_once(
|
||||
"`score` task is deprecated and will be removed in v0.20. "
|
||||
"Please use `classify` instead."
|
||||
)
|
||||
self.task = "classify"
|
||||
|
||||
# plugin task uses io_processor.parse_request to verify inputs,
|
||||
# skipping PoolingParams verify
|
||||
if self.task == "plugin":
|
||||
|
||||
@@ -16,6 +16,11 @@ PoolingTask = Literal[
|
||||
POOLING_TASKS: tuple[PoolingTask, ...] = get_args(PoolingTask)
|
||||
|
||||
ScoreType = Literal["bi-encoder", "cross-encoder", "late-interaction"]
|
||||
SCORE_TYPE_MAP: dict[PoolingTask, ScoreType] = {
|
||||
"embed": "bi-encoder",
|
||||
"classify": "cross-encoder",
|
||||
"token_embed": "late-interaction",
|
||||
}
|
||||
|
||||
FrontendTask = Literal["render"]
|
||||
FRONTEND_TASKS: tuple[FrontendTask, ...] = get_args(FrontendTask)
|
||||
|
||||
Reference in New Issue
Block a user