fix: OpenAI API SSE compatibility and stream reliability
Problems:
1. Stream response missing 'data: [DONE]' terminator — OpenAI clients
(openai-python, LiteLLM, etc.) expect this SSE sentinel to detect
end-of-stream. Without it, clients hang or raise IncompleteRead
(see the client sketch after this list).
2. Stream chunks lack required fields (id, object, created, model) —
the OpenAI spec mandates chat.completion.chunk objects contain these
metadata fields. Clients that validate response shape reject the
bare {choices:[{delta:{}}]} payloads.
3. Non-stream path uses max_length instead of max_new_tokens — stream
path correctly uses max_new_tokens, but non-stream passes
max_length=prompt_len+max_tokens. This creates inconsistent behavior
(different effective generation limits) between the two code paths.
4. Generation thread exceptions silently swallowed — if model.generate()
throws (OOM, dtype mismatch, etc.), the queue never receives None,
causing the consumer to block forever on queue.get(). Added sentinel-
based error propagation so exceptions surface to the SSE generator.
5. No CORS headers — browser-based clients (web UIs, chatbots) get
blocked by same-origin policy. Added CORSMiddleware.
6. Generation thread not daemonized — if the main process exits during
generation (e.g. client disconnect + server shutdown), the non-daemon
thread keeps the process alive. Set daemon=True.
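
For illustration only (not part of this change): a minimal openai-python
client exercising the fixed stream path. The base URL, port, and API key
below are assumptions; point them at wherever serve_openai_api.py is
listening. Without fix 1, the final loop would never terminate.

    # Hypothetical client sketch; base_url and api_key are placeholders.
    from openai import OpenAI

    client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="none")
    stream = client.chat.completions.create(
        model="minimind",
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    )
    # The client stops iterating when it receives 'data: [DONE]'.
    for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.content:
            print(delta.content, end="", flush=True)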
Tests:
- 7 tests for parse_response (plain text, think tags, tool calls, edge cases)
- 6 tests for SSE stream format (DONE marker, chunk fields, CORS, max_new_tokens, daemon, error propagation)
commit 48ed5ec8bc (parent 487f78754d)
scripts/serve_openai_api.py
@@ -14,6 +14,7 @@ import uvicorn
 from threading import Thread
 from queue import Queue
 from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
@@ -23,6 +24,13 @@ from model.model_lora import apply_lora, load_lora
 warnings.filterwarnings('ignore')
 
 app = FastAPI()
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
 
 
 def init_model(args):
@@ -102,7 +110,26 @@ def parse_response(text):
     return text.strip(), reasoning_content, tool_calls or None
 
 
+_SENTINEL = object()
+
+
 def generate_stream_response(messages, temperature, top_p, max_tokens, tools=None, open_thinking=False):
+    request_id = f"chatcmpl-{int(time.time() * 1000)}"
+    created = int(time.time())
+
+    def _make_chunk(delta, finish_reason=None):
+        """Build a single SSE chunk in OpenAI chat.completion.chunk format."""
+        choice = {"index": 0, "delta": delta}
+        if finish_reason is not None:
+            choice["finish_reason"] = finish_reason
+        return json.dumps({
+            "id": request_id,
+            "object": "chat.completion.chunk",
+            "created": created,
+            "model": "minimind",
+            "choices": [choice],
+        }, ensure_ascii=False)
+
     try:
         new_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=tools or None, open_thinking=open_thinking)[-max_tokens:]
         inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)
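For reference, a chunk built by _make_chunk serializes to the shape below;
the id and created values are illustrative, not taken from a real run.

    # _make_chunk({"content": "Hi"}) produces (illustrative id/created):
    # {"id": "chatcmpl-1714000000000", "object": "chat.completion.chunk",
    #  "created": 1714000000, "model": "minimind",
    #  "choices": [{"index": 0, "delta": {"content": "Hi"}}]}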
@@ -111,19 +138,28 @@ def generate_stream_response(messages, temperature, top_p, max_tokens, tools=None, open_thinking=False):
         streamer = CustomStreamer(tokenizer, queue)
 
         def _generate():
-            model.generate(
-                inputs.input_ids,
-                max_new_tokens=max_tokens,
-                do_sample=True,
-                temperature=temperature,
-                top_p=top_p,
-                attention_mask=inputs.attention_mask,
-                pad_token_id=tokenizer.pad_token_id,
-                eos_token_id=tokenizer.eos_token_id,
-                streamer=streamer
-            )
+            try:
+                model.generate(
+                    inputs.input_ids,
+                    max_new_tokens=max_tokens,
+                    do_sample=True,
+                    temperature=temperature,
+                    top_p=top_p,
+                    attention_mask=inputs.attention_mask,
+                    pad_token_id=tokenizer.pad_token_id,
+                    eos_token_id=tokenizer.eos_token_id,
+                    streamer=streamer
+                )
+            except Exception as e:
+                # Propagate generation errors to the consumer so it doesn't
+                # block forever on queue.get().
+                queue.put(_SENTINEL)
+                queue.put(e)
 
-        Thread(target=_generate).start()
+        Thread(target=_generate, daemon=True).start()
 
+        # Emit the initial chunk with the role.
+        yield _make_chunk({"role": "assistant"})
+
         full_text = ""
         emitted = 0
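The sentinel handshake above, reduced to a self-contained sketch (illustrative,
not part of the commit): a failing producer thread puts _SENTINEL and then the
exception, so the consumer can surface the error instead of blocking forever.

    # Standalone sketch of sentinel-based error propagation (hypothetical names).
    from queue import Queue
    from threading import Thread

    _SENTINEL = object()

    def producer(q: Queue):
        try:
            raise RuntimeError("boom")  # stand-in for a failing model.generate()
        except Exception as e:
            q.put(_SENTINEL)  # wake the consumer...
            q.put(e)          # ...and hand it the exception

    q = Queue()
    Thread(target=producer, args=(q,), daemon=True).start()

    item = q.get()  # would block forever if the producer died silently
    if item is _SENTINEL:
        print(f"propagated: {q.get()!r}")  # consumer sees the producer's error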
@@ -131,6 +167,9 @@ def generate_stream_response(messages, temperature, top_p, max_tokens, tools=None, open_thinking=False):
 
         while True:
             text = queue.get()
+            if text is _SENTINEL:
+                # The generation thread raised; re-raise here.
+                raise queue.get()
             if text is None:
                 break
             full_text += text
@@ -141,28 +180,29 @@ def generate_stream_response(messages, temperature, top_p, max_tokens, tools=None, open_thinking=False):
                     thinking_ended = True
                     new_r = full_text[emitted:pos]
                     if new_r:
-                        yield json.dumps({"choices": [{"delta": {"reasoning_content": new_r}}]}, ensure_ascii=False)
+                        yield _make_chunk({"reasoning_content": new_r})
                     emitted = pos + len('</think>')
                     after = full_text[emitted:].lstrip('\n')
                     emitted = len(full_text) - len(after)
                     if after:
-                        yield json.dumps({"choices": [{"delta": {"content": after}}]}, ensure_ascii=False)
+                        yield _make_chunk({"content": after})
                         emitted = len(full_text)
                 else:
                     new_r = full_text[emitted:]
                     if new_r:
-                        yield json.dumps({"choices": [{"delta": {"reasoning_content": new_r}}]}, ensure_ascii=False)
+                        yield _make_chunk({"reasoning_content": new_r})
                         emitted = len(full_text)
             else:
                 new_c = full_text[emitted:]
                 if new_c:
-                    yield json.dumps({"choices": [{"delta": {"content": new_c}}]}, ensure_ascii=False)
+                    yield _make_chunk({"content": new_c})
                     emitted = len(full_text)
 
         _, _, tool_calls = parse_response(full_text)
         if tool_calls:
-            yield json.dumps({"choices": [{"delta": {"tool_calls": tool_calls}}]}, ensure_ascii=False)
-        yield json.dumps({"choices": [{"delta": {}, "finish_reason": "tool_calls" if tool_calls else "stop"}]}, ensure_ascii=False)
+            yield _make_chunk({"tool_calls": tool_calls})
+        finish = "tool_calls" if tool_calls else "stop"
+        yield _make_chunk({}, finish_reason=finish)
 
     except Exception as e:
         yield json.dumps({"error": str(e)})
@@ -172,15 +212,20 @@ def generate_stream_response(messages, temperature, top_p, max_tokens, tools=None, open_thinking=False):
 async def chat_completions(request: ChatRequest):
     try:
         if request.stream:
-            return StreamingResponse(
-                (f"data: {chunk}\n\n" for chunk in generate_stream_response(
+            def _event_stream():
+                for chunk in generate_stream_response(
                     messages=request.messages,
                     temperature=request.temperature,
                     top_p=request.top_p,
                     max_tokens=request.max_tokens,
                     tools=request.tools,
                     open_thinking=request.get_open_thinking()
-                )),
+                ):
+                    yield f"data: {chunk}\n\n"
+                yield "data: [DONE]\n\n"
+
+            return StreamingResponse(
+                _event_stream(),
                 media_type="text/event-stream"
             )
         else:
@@ -195,7 +240,7 @@ async def chat_completions(request: ChatRequest):
         with torch.no_grad():
             generated_ids = model.generate(
                 inputs["input_ids"],
-                max_length=inputs["input_ids"].shape[1] + request.max_tokens,
+                max_new_tokens=request.max_tokens,
                 do_sample=True,
                 attention_mask=inputs["attention_mask"],
                 pad_token_id=tokenizer.pad_token_id,
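On the wire, the fixed stream now frames every chunk as an SSE event and ends
with an explicit terminator, roughly as follows (payloads abbreviated,
illustrative values):

    data: {"id": "chatcmpl-...", "object": "chat.completion.chunk", "created": 1714000000, "model": "minimind", "choices": [{"index": 0, "delta": {"role": "assistant"}}]}

    data: {"id": "chatcmpl-...", "object": "chat.completion.chunk", "created": 1714000000, "model": "minimind", "choices": [{"index": 0, "delta": {"content": "Hello"}}]}

    data: {"id": "chatcmpl-...", "object": "chat.completion.chunk", "created": 1714000000, "model": "minimind", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}

    data: [DONE]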
tests/test_serve_openai_api.py (new file, 220 lines)
@@ -0,0 +1,220 @@
"""
Regression tests for scripts/serve_openai_api.py
These tests verify OpenAI API compatibility without loading a real model.
"""
import json
import re
import sys
import os
import types

# ---------------------------------------------------------------------------
# Stub heavy dependencies so importing serve_openai_api doesn't need GPU/model
# ---------------------------------------------------------------------------

# Provide a minimal torch stub (only attributes actually used at import time)
_torch_stub = types.ModuleType("torch")
_torch_stub.cuda = types.SimpleNamespace(is_available=lambda: False)
_torch_stub.no_grad = lambda: types.SimpleNamespace(__enter__=lambda s: None, __exit__=lambda s, *a: None)
_torch_stub.inference_mode = lambda: lambda fn: fn
_torch_stub.load = lambda *a, **kw: {}
_torch_stub.Tensor = type("Tensor", (), {})
_torch_stub.float32 = "float32"
_torch_stub.LongTensor = lambda *a, **kw: None

# optim sub-module
_optim_stub = types.ModuleType("torch.optim")
sys.modules["torch.optim"] = _optim_stub
_torch_stub.optim = _optim_stub

# nn sub-module with Module and common layers
_nn_stub = types.ModuleType("torch.nn")
class _ModuleStub:
    def __init_subclass__(cls, **kw): pass
    def __init__(self, *a, **kw): pass
_nn_stub.Module = _ModuleStub
_nn_stub.Linear = type("Linear", (_ModuleStub,), {})
_nn_stub.Embedding = type("Embedding", (_ModuleStub,), {})
_nn_stub.Dropout = type("Dropout", (_ModuleStub,), {})
class _ParameterStub:
    def __init__(self, *a, **kw): pass
_nn_stub.Parameter = _ParameterStub
_torch_stub.nn = _nn_stub

_nn_func = types.ModuleType("torch.nn.functional")
_nn_func.scaled_dot_product_attention = lambda *a, **kw: None
_nn_func.cross_entropy = lambda *a, **kw: None
_nn_func.softmax = lambda *a, **kw: None
_nn_func.silu = lambda *a, **kw: None

sys.modules["torch"] = _torch_stub
sys.modules["torch.nn"] = _nn_stub
sys.modules["torch.nn.functional"] = _nn_func
sys.modules.setdefault("torch.distributed", types.ModuleType("torch.distributed"))
sys.modules.setdefault("torch.nn.parallel", types.ModuleType("torch.nn.parallel"))
sys.modules.setdefault("torch.utils", types.ModuleType("torch.utils"))
sys.modules.setdefault("torch.utils.data", types.ModuleType("torch.utils.data"))

for mod_name in [
    "transformers", "transformers.activations", "transformers.modeling_outputs",
    "librosa", "soundfile", "numpy", "uvicorn",
]:
    sys.modules.setdefault(mod_name, types.ModuleType(mod_name))

# Stub transformers classes that are used at module scope
_transformers = sys.modules["transformers"]
for cls_name in [
    "PreTrainedModel", "GenerationMixin", "PretrainedConfig",
    "AutoTokenizer", "AutoModelForCausalLM", "AutoModel",
    "AutoModelForSequenceClassification", "TextStreamer",
]:
    if not hasattr(_transformers, cls_name):
        setattr(_transformers, cls_name, type(cls_name, (), {}))

_mo = sys.modules["transformers.modeling_outputs"]
if not hasattr(_mo, "MoeCausalLMOutputWithPast"):
    setattr(_mo, "MoeCausalLMOutputWithPast", type("MoeCausalLMOutputWithPast", (), {}))

_act = sys.modules["transformers.activations"]
if not hasattr(_act, "ACT2FN"):
    setattr(_act, "ACT2FN", {})

# Now import the functions we can test without a live model
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "scripts"))

from scripts.serve_openai_api import parse_response


# ============================================================
# Tests for parse_response
# ============================================================

class TestParseResponse:
    """Verify that parse_response correctly separates reasoning, content and tool calls."""

    def test_plain_text(self):
        content, reasoning, tool_calls = parse_response("Hello, world!")
        assert content == "Hello, world!"
        assert reasoning is None
        assert tool_calls is None

    def test_think_tags_complete(self):
        text = "<think>Step 1: consider options</think>The answer is 42."
        content, reasoning, tool_calls = parse_response(text)
        assert reasoning == "Step 1: consider options"
        assert content == "The answer is 42."
        assert tool_calls is None

    def test_think_tag_no_opening(self):
        """When only </think> is present (streaming partial), split at the tag."""
        text = "I need to think carefully</think>Final answer here"
        content, reasoning, tool_calls = parse_response(text)
        assert reasoning == "I need to think carefully"
        assert content == "Final answer here"

    def test_tool_call_parsed(self):
        text = 'Sure, I can help. <tool_call>{"name":"get_weather","arguments":{"city":"Tokyo"}}</tool_call>'
        content, reasoning, tool_calls = parse_response(text)
        assert "Sure, I can help." in content
        assert tool_calls is not None
        assert len(tool_calls) == 1
        assert tool_calls[0]["function"]["name"] == "get_weather"
        args = json.loads(tool_calls[0]["function"]["arguments"])
        assert args["city"] == "Tokyo"

    def test_multiple_tool_calls(self):
        text = (
            '<tool_call>{"name":"a","arguments":{}}</tool_call>'
            '<tool_call>{"name":"b","arguments":{"x":1}}</tool_call>'
        )
        content, reasoning, tool_calls = parse_response(text)
        assert tool_calls is not None
        assert len(tool_calls) == 2
        assert tool_calls[0]["function"]["name"] == "a"
        assert tool_calls[1]["function"]["name"] == "b"

    def test_invalid_tool_call_json_skipped(self):
        text = '<tool_call>not valid json</tool_call>The rest.'
        content, reasoning, tool_calls = parse_response(text)
        assert "The rest." in content
        assert tool_calls is None  # Invalid JSON is silently skipped

    def test_empty_string(self):
        content, reasoning, tool_calls = parse_response("")
        assert content == ""
        assert reasoning is None
        assert tool_calls is None


# ============================================================
# Tests for SSE stream format (structural, no live model)
# ============================================================

class TestSSEStreamFormat:
    """Verify the SSE event stream wrapper produces correct framing."""

    def test_done_marker_present(self):
        """The stream MUST end with 'data: [DONE]\\n\\n' per OpenAI spec."""
        # Read the source and verify the [DONE] marker is emitted
        script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py")
        with open(script_path) as f:
            source = f.read()
        assert 'data: [DONE]' in source, "Stream must emit 'data: [DONE]' terminator"

    def test_stream_chunks_have_required_fields(self):
        """Each stream chunk must include id, object, created, model, choices."""
        script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py")
        with open(script_path) as f:
            source = f.read()
        # Verify _make_chunk builds proper structure with all required fields
        assert '"id": request_id' in source or '"id":' in source
        assert '"object": "chat.completion.chunk"' in source
        assert '"created":' in source
        assert '"model":' in source

    def test_cors_middleware_present(self):
        """CORS middleware must be configured for browser-based clients."""
        script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py")
        with open(script_path) as f:
            source = f.read()
        assert "CORSMiddleware" in source, "CORS middleware required for browser clients"

    def test_non_stream_uses_max_new_tokens(self):
        """Non-stream path must use max_new_tokens, not max_length, matching stream behavior."""
        script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py")
        with open(script_path) as f:
            source = f.read()
        # The non-stream generate call should use max_new_tokens
        # Find the non-stream generate block (after "with torch.no_grad():")
        non_stream_section = source[source.index("with torch.no_grad()"):]
        assert "max_new_tokens" in non_stream_section, \
            "Non-stream path should use max_new_tokens for consistent behavior with stream path"
        # max_length should NOT appear in the generate call
        generate_block = non_stream_section[:non_stream_section.index("answer = tokenizer.decode")]
        assert "max_length" not in generate_block, \
            "Non-stream path should not use max_length (use max_new_tokens instead)"

    def test_generation_thread_is_daemon(self):
        """Generation thread must be a daemon so it doesn't block process exit."""
        script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py")
        with open(script_path) as f:
            source = f.read()
        assert "daemon=True" in source, "Generation thread should be daemon to avoid blocking exit"

    def test_generation_thread_error_propagation(self):
        """Generation thread errors must be propagated to the consumer, not swallowed."""
        script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py")
        with open(script_path) as f:
            source = f.read()
        # The _generate() function should have a try/except that puts errors on the queue
        assert "_SENTINEL" in source, "Sentinel-based error propagation required"


# ============================================================
# Run with pytest or directly
# ============================================================

if __name__ == "__main__":
    import pytest
    sys.exit(pytest.main([__file__, "-v"]))