commit 30a6a40d2d

    Merge branch 'main' into spark-weekly-newcases
@@ -273,6 +273,13 @@ class OpenAIServer:
         self.app.add_api_route("/v1/responses",
                                self.openai_responses,
                                methods=["POST"])
+        self.app.add_api_route('/v1/responses/{response_id}',
+                               self.openai_responses_get_response,
+                               methods=["GET"])
+        self.app.add_api_route('/v1/responses/{response_id}',
+                               self.openai_responses_delete_response,
+                               methods=["DELETE"])
+
         # RL-only endpoints
         self.app.add_api_route("/release_memory",
                                self.release_memory,
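
Together with the existing POST route, these additions cover the retrieve/delete part of the Responses API lifecycle. A minimal client-side sketch, assuming a server already listening on http://localhost:8000/v1 with a placeholder API key and model name (all three are assumptions, not part of this commit):

import openai

# Placeholder endpoint, key and model for a locally running server; adjust to your deployment.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

resp = client.responses.create(model="Qwen3/Qwen3-0.6B", input="Hello")
fetched = client.responses.retrieve(resp.id)   # GET /v1/responses/{response_id}
assert fetched.id == resp.id
client.responses.delete(resp.id)               # DELETE /v1/responses/{response_id}
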
@@ -1065,6 +1072,38 @@ class OpenAIServer:

         return JSONResponse(content={"detail": "None"})
+
+    async def openai_responses_get_response(self, response_id: str) -> JSONResponse:
+        logger.info(f"Getting response: {response_id}")
+        if not self.enable_store:
+            return self.create_error_response(message="Response storage is disabled", err_type="InvalidRequestError")
+
+        if not response_id.startswith("resp_"):
+            return self._create_invalid_response_id_error(response_id)
+
+        response = await self.conversation_store.load_response(response_id)
+        if response is None:
+            return self._create_response_id_not_found_error(response_id)
+
+        return JSONResponse(content=response.model_dump())
+
+    async def openai_responses_delete_response(self, response_id: str) -> JSONResponse:
+        logger.info(f"Deleting response: {response_id}")
+        if not self.enable_store:
+            return self.create_error_response(message="Response storage is disabled", err_type="InvalidRequestError")
+
+        if not response_id.startswith("resp_"):
+            return self._create_invalid_response_id_error(response_id)
+
+        success = await self.conversation_store.pop_response(response_id)
+        if not success:
+            return self._create_response_id_not_found_error(response_id)
+
+        return JSONResponse(content={
+            "id": response_id,
+            "object": "response",
+            "deleted": True
+        })

     async def release_memory(self, request: MemoryUpdateRequest) -> JSONResponse:
         assert isinstance(self.llm, AsyncLLM), "/release_memory endpoint is only supported with AsyncLLM()"
         await self.llm.collective_rpc('sleep', args=(request.tags,))
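
On the wire, the GET handler returns the stored response body and the DELETE handler returns a small confirmation object; ids without the resp_ prefix map to an invalid-request error and unknown ids to a not-found error. A rough raw-HTTP sketch with requests, where the base URL and response id are placeholders:

import requests

base = "http://localhost:8000"                       # placeholder host/port
resp_id = "resp_..."                                 # id previously returned by POST /v1/responses

r = requests.get(f"{base}/v1/responses/{resp_id}")
print(r.status_code, r.json())                       # stored response, or an error body

r = requests.delete(f"{base}/v1/responses/{resp_id}")
print(r.status_code, r.json())                       # e.g. {"id": ..., "object": "response", "deleted": true}
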
@@ -198,10 +198,13 @@ class ConversationHistoryStore:
         # Map from conversation id to response id, which is the latest response in the conversation.
         self.conversation_to_response: dict[str, str] = {}

-    async def load_response(self, resp_id: str) -> ResponsesResponse:
+    async def load_response(self, resp_id: str) -> ResponsesResponse | None:
         _responses_debug_log(
             f"ConversationHistoryStore loading resp: {resp_id}")
         async with self.responses_lock:
+            if resp_id not in self.responses:
+                return None
+
             self.responses.move_to_end(resp_id)
             return self.responses.get(resp_id)

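
The return type widens to ResponsesResponse | None so the caller can distinguish "unknown id" from a stored response, and the move_to_end call keeps self.responses ordered by recency. A standalone illustration of that recency refresh on a plain OrderedDict (illustrative only, not the class itself):

from collections import OrderedDict

responses = OrderedDict(a=1, b=2, c=3)   # oldest -> newest: a, b, c
responses.move_to_end("a")               # "a" was just accessed, mark it most recent
print(list(responses))                   # ['b', 'c', 'a']
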
@@ -263,6 +266,10 @@ class ConversationHistoryStore:
             self.conversation_to_response[conversation_id] = resp_id
             self._update_visited_conversation(conversation_id)

+    async def pop_response(self, resp_id: Optional[str] = None) -> bool:
+        async with self.responses_lock:
+            return self._pop_response(resp_id)
+
     async def store_messages(self, resp_id: str,
                              msgs: Union[list[Message],
                                          list[ChatCompletionMessageParam]],
@@ -398,12 +405,24 @@ class ConversationHistoryStore:

         del conversation[start_index:end_index + 1]

-    def _pop_response(self) -> None:
-        _responses_debug_log(f"responses type: {type(self.responses)}")
-        resp_id, _ = self.responses.popitem(last=False)
+    def _pop_response(self, resp_id: Optional[str] = None) -> bool:
+        _responses_debug_log(f"pop response {resp_id}")
+
+        if not self.responses:
+            return False
+
+        if resp_id is not None:
+            if resp_id not in self.responses:
+                return False
+            self.responses.pop(resp_id)
+        else:
+            resp_id, _ = self.responses.popitem(last=False)
+
         if resp_id in self.response_to_conversation:
             self.response_to_conversation.pop(resp_id)
+
+        return True


 def _get_system_message(
     model_identity: Optional[str] = None,
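
_pop_response now supports both targeted deletion (the DELETE endpoint path) and the original oldest-first eviction when called without an id, and it reports success as a bool instead of returning nothing. The same semantics on a bare OrderedDict, as a minimal sketch:

from collections import OrderedDict
from typing import Optional

def pop_entry(store: OrderedDict, key: Optional[str] = None) -> bool:
    # Pop a specific key if one is given, otherwise evict the oldest entry.
    if not store:
        return False
    if key is not None:
        if key not in store:
            return False
        store.pop(key)
    else:
        store.popitem(last=False)
    return True

lru = OrderedDict(x="resp-x", y="resp-y")
assert pop_entry(lru, "y") is True       # targeted removal
assert pop_entry(lru) is True            # evicts the oldest remaining entry ("x")
assert pop_entry(lru) is False           # store is now empty
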
@@ -1653,6 +1653,14 @@ def test_openai_responses(llm_root, llm_venv):
         str(test_root / "_test_openai_responses.py")])


+def test_openai_responses_entrypoint(llm_root, llm_venv):
+    test_root = unittest_path() / "llmapi" / "apps"
+    llm_venv.run_cmd([
+        "-m", "pytest",
+        str(test_root / "_test_openai_responses_entrypoint.py")
+    ])
+
+
 def test_openai_health(llm_root, llm_venv):
     test_root = unittest_path() / "llmapi" / "apps"
     llm_venv.run_cmd([
@@ -66,6 +66,7 @@ l0_a10:
   - test_e2e.py::test_openai_misc_example[pytorch]
   - test_e2e.py::test_openai_reasoning[pytorch]
   - test_e2e.py::test_openai_tool_call
+  - test_e2e.py::test_openai_responses_entrypoint
   - test_e2e.py::test_openai_completions_example[pytorch]
   - test_e2e.py::test_openai_chat_example[pytorch] TIMEOUT (90)
   - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-]
@@ -0,0 +1,80 @@
+import openai
+import pytest
+
+from ..test_llm import get_model_path
+from .openai_server import RemoteOpenAIServer
+
+pytestmark = pytest.mark.threadleak(enabled=False)
+
+
+@pytest.fixture(scope="module", params=["Qwen3/Qwen3-0.6B"])
+def model(request):
+    return request.param
+
+
+@pytest.fixture(scope="module")
+def server(model: str):
+    model_path = get_model_path(model)
+
+    args = []
+    if model.startswith("Qwen3"):
+        args.extend(["--reasoning_parser", "qwen3"])
+    elif model.startswith("DeepSeek-R1"):
+        args.extend(["--reasoning_parser", "deepseek-r1"])
+
+    if not model.startswith("gpt_oss"):
+        args.extend(["--tool_parser", "qwen3"])
+
+    with RemoteOpenAIServer(model_path, args) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def client(server: RemoteOpenAIServer):
+    return server.get_async_client()
+
+
+@pytest.mark.asyncio(loop_scope="module")
+async def test_get(client: openai.AsyncOpenAI, model: str):
+    response = await client.responses.create(
+        model=model, input="Which one is larger as numeric, 9.9 or 9.11?", max_output_tokens=1024
+    )
+
+    response_get = await client.responses.retrieve(response.id)
+    assert response_get.id == response.id
+    assert response_get.model_dump() == response.model_dump()
+
+
+@pytest.mark.asyncio(loop_scope="module")
+async def test_get_invalid_response_id(client: openai.AsyncOpenAI):
+    with pytest.raises(openai.BadRequestError):
+        await client.responses.retrieve("invalid_response_id")
+
+
+@pytest.mark.asyncio(loop_scope="module")
+async def test_get_non_existent_response_id(client: openai.AsyncOpenAI):
+    with pytest.raises(openai.NotFoundError):
+        await client.responses.retrieve("resp_non_existent_response_id")
+
+
+@pytest.mark.asyncio(loop_scope="module")
+async def test_delete(client: openai.AsyncOpenAI, model: str):
+    response = await client.responses.create(
+        model=model, input="Which one is larger as numeric, 9.9 or 9.11?", max_output_tokens=1024
+    )
+
+    await client.responses.delete(response.id)
+    with pytest.raises(openai.NotFoundError):
+        await client.responses.retrieve(response.id)
+
+
+@pytest.mark.asyncio(loop_scope="module")
+async def test_delete_invalid_response_id(client: openai.AsyncOpenAI):
+    with pytest.raises(openai.BadRequestError):
+        await client.responses.delete("invalid_response_id")
+
+
+@pytest.mark.asyncio(loop_scope="module")
+async def test_delete_non_existent_response_id(client: openai.AsyncOpenAI):
+    with pytest.raises(openai.NotFoundError):
+        await client.responses.delete("resp_non_existent_response_id")
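
The new file can also be run on its own outside the test_e2e wrapper; the path below is an assumption inferred from the unittest_path() / "llmapi" / "apps" lookup in test_e2e.py, not something stated by the commit:

import pytest

# Assumed location of the new test file within the repo's unittest tree.
exit_code = pytest.main([
    "tests/unittest/llmapi/apps/_test_openai_responses_entrypoint.py",
    "-v",
])
raise SystemExit(exit_code)
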