fix: OpenAI API SSE compatibility and stream reliability

Problems:
1. Stream response missing 'data: [DONE]' terminator — OpenAI clients
   (openai-python, LiteLLM, etc.) expect this SSE sentinel to detect
   end-of-stream. Without it, clients hang or raise IncompleteRead
   (see the consumption sketch after this list).

2. Stream chunks lack required fields (id, object, created, model) —
   the OpenAI spec mandates chat.completion.chunk objects contain these
   metadata fields. Clients that validate response shape reject the
   bare {choices:[{delta:{}}]} payloads.

3. Non-stream path uses max_length instead of max_new_tokens — stream
   path correctly uses max_new_tokens, but non-stream passes
   max_length=prompt_len+max_tokens. This creates inconsistent behavior
   (different effective generation limits) between the two code paths.

4. Generation thread exceptions silently swallowed — if model.generate()
   throws (OOM, dtype mismatch, etc.), the queue never receives None,
   causing the consumer to block forever on queue.get(). Added sentinel-
   based error propagation so exceptions surface to the SSE generator
   (see the standalone sketch after this list).

5. No CORS headers — browser-based clients (web UIs, chatbots) get
   blocked by same-origin policy. Added CORSMiddleware.

6. Generation thread not daemonized — if the main process exits during
   generation (e.g. client disconnect + server shutdown), the non-daemon
   thread keeps the process alive. Set daemon=True.
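
For reference, a minimal consumption sketch for problems 1 and 2
(openai-python >= 1.0 assumed; base_url, port and api_key below are
placeholders, not values from this repo):

from openai import OpenAI

# Placeholder endpoint; point base_url at wherever serve_openai_api.py is running.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")

stream = client.chat.completions.create(
    model="minimind",
    messages=[{"role": "user", "content": "hello"}],
    stream=True,
)

# Each event must parse as a chat.completion.chunk (id/object/created/model/choices);
# iteration only terminates cleanly once the server emits 'data: [DONE]'.
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="", flush=True)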

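Items 4 and 6 combine into the following pattern. This is a standalone
sketch with a toy producer standing in for model.generate() + CustomStreamer,
not the literal server code:

from queue import Queue
from threading import Thread

_SENTINEL = object()

def produce(queue):
    try:
        for token in ["Hel", "lo"]:   # stand-in for model.generate(), which can raise (OOM, dtype, ...)
            queue.put(token)
        queue.put(None)               # normal end-of-stream
    except Exception as e:
        queue.put(_SENTINEL)          # error path: sentinel first,
        queue.put(e)                  # then the exception itself

def consume(queue):
    while True:
        item = queue.get()
        if item is _SENTINEL:
            raise queue.get()         # surface the producer's exception in the consumer
        if item is None:
            break
        yield item

q = Queue()
# daemon=True so a stuck generation can't keep the process alive on shutdown.
Thread(target=produce, args=(q,), daemon=True).start()
print("".join(consume(q)))
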
Tests:
- 7 tests for parse_response (plain text, think tags, tool calls, edge cases)
- 6 tests for SSE stream format (DONE marker, chunk fields, CORS, max_new_tokens, daemon, error propagation)
voidborne-d 2026-04-17 00:56:36 +00:00
parent 487f78754d
commit 48ed5ec8bc
2 changed files with 287 additions and 22 deletions

scripts/serve_openai_api.py

@@ -14,6 +14,7 @@ import uvicorn
from threading import Thread
from queue import Queue
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
@@ -23,6 +24,13 @@ from model.model_lora import apply_lora, load_lora
warnings.filterwarnings('ignore')

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

def init_model(args):
@@ -102,7 +110,26 @@ def parse_response(text):
    return text.strip(), reasoning_content, tool_calls or None

_SENTINEL = object()

def generate_stream_response(messages, temperature, top_p, max_tokens, tools=None, open_thinking=False):
    request_id = f"chatcmpl-{int(time.time() * 1000)}"
    created = int(time.time())

    def _make_chunk(delta, finish_reason=None):
        """Build a single SSE chunk in OpenAI chat.completion.chunk format."""
        choice = {"index": 0, "delta": delta}
        if finish_reason is not None:
            choice["finish_reason"] = finish_reason
        return json.dumps({
            "id": request_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": "minimind",
            "choices": [choice],
        }, ensure_ascii=False)

    try:
        new_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=tools or None, open_thinking=open_thinking)[-max_tokens:]
        inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)
@@ -111,19 +138,28 @@ def generate_stream_response(messages, temperature, top_p, max_tokens, tools=Non
        streamer = CustomStreamer(tokenizer, queue)

        def _generate():
            model.generate(
                inputs.input_ids,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                attention_mask=inputs.attention_mask,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                streamer=streamer
            )
            try:
                model.generate(
                    inputs.input_ids,
                    max_new_tokens=max_tokens,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    attention_mask=inputs.attention_mask,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    streamer=streamer
                )
            except Exception as e:
                # Propagate generation errors to the consumer so it doesn't
                # block forever on queue.get().
                queue.put(_SENTINEL)
                queue.put(e)

        Thread(target=_generate).start()
        Thread(target=_generate, daemon=True).start()

        # Emit the initial chunk with the role.
        yield _make_chunk({"role": "assistant"})

        full_text = ""
        emitted = 0
@@ -131,6 +167,9 @@ def generate_stream_response(messages, temperature, top_p, max_tokens, tools=Non
        while True:
            text = queue.get()
            if text is _SENTINEL:
                # The generation thread raised; re-raise here.
                raise queue.get()
            if text is None:
                break
            full_text += text
@@ -141,28 +180,29 @@ def generate_stream_response(messages, temperature, top_p, max_tokens, tools=Non
                    thinking_ended = True
                    new_r = full_text[emitted:pos]
                    if new_r:
                        yield json.dumps({"choices": [{"delta": {"reasoning_content": new_r}}]}, ensure_ascii=False)
                        yield _make_chunk({"reasoning_content": new_r})
                    emitted = pos + len('</think>')
                    after = full_text[emitted:].lstrip('\n')
                    emitted = len(full_text) - len(after)
                    if after:
                        yield json.dumps({"choices": [{"delta": {"content": after}}]}, ensure_ascii=False)
                        yield _make_chunk({"content": after})
                        emitted = len(full_text)
                else:
                    new_r = full_text[emitted:]
                    if new_r:
                        yield json.dumps({"choices": [{"delta": {"reasoning_content": new_r}}]}, ensure_ascii=False)
                        yield _make_chunk({"reasoning_content": new_r})
                        emitted = len(full_text)
            else:
                new_c = full_text[emitted:]
                if new_c:
                    yield json.dumps({"choices": [{"delta": {"content": new_c}}]}, ensure_ascii=False)
                    yield _make_chunk({"content": new_c})
                    emitted = len(full_text)

        _, _, tool_calls = parse_response(full_text)
        if tool_calls:
            yield json.dumps({"choices": [{"delta": {"tool_calls": tool_calls}}]}, ensure_ascii=False)
        yield json.dumps({"choices": [{"delta": {}, "finish_reason": "tool_calls" if tool_calls else "stop"}]}, ensure_ascii=False)
            yield _make_chunk({"tool_calls": tool_calls})
        finish = "tool_calls" if tool_calls else "stop"
        yield _make_chunk({}, finish_reason=finish)
    except Exception as e:
        yield json.dumps({"error": str(e)})
@@ -172,15 +212,20 @@ def generate_stream_response(messages, temperature, top_p, max_tokens, tools=Non
async def chat_completions(request: ChatRequest):
    try:
        if request.stream:
            return StreamingResponse(
                (f"data: {chunk}\n\n" for chunk in generate_stream_response(
            def _event_stream():
                for chunk in generate_stream_response(
                    messages=request.messages,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    max_tokens=request.max_tokens,
                    tools=request.tools,
                    open_thinking=request.get_open_thinking()
                )),
                ):
                    yield f"data: {chunk}\n\n"
                yield "data: [DONE]\n\n"
            return StreamingResponse(
                _event_stream(),
                media_type="text/event-stream"
            )
        else:
@@ -195,7 +240,7 @@ async def chat_completions(request: ChatRequest):
            with torch.no_grad():
                generated_ids = model.generate(
                    inputs["input_ids"],
                    max_length=inputs["input_ids"].shape[1] + request.max_tokens,
                    max_new_tokens=request.max_tokens,
                    do_sample=True,
                    attention_mask=inputs["attention_mask"],
                    pad_token_id=tokenizer.pad_token_id,


@@ -0,0 +1,220 @@
"""
Regression tests for scripts/serve_openai_api.py
These tests verify OpenAI API compatibility without loading a real model.
"""
import json
import re
import sys
import os
import types
# ---------------------------------------------------------------------------
# Stub heavy dependencies so importing serve_openai_api doesn't need GPU/model
# ---------------------------------------------------------------------------
# Provide a minimal torch stub (only attributes actually used at import time)
_torch_stub = types.ModuleType("torch")
_torch_stub.cuda = types.SimpleNamespace(is_available=lambda: False)
_torch_stub.no_grad = lambda: types.SimpleNamespace(__enter__=lambda s: None, __exit__=lambda s, *a: None)
_torch_stub.inference_mode = lambda: lambda fn: fn
_torch_stub.load = lambda *a, **kw: {}
_torch_stub.Tensor = type("Tensor", (), {})
_torch_stub.float32 = "float32"
_torch_stub.LongTensor = lambda *a, **kw: None
# optim sub-module
_optim_stub = types.ModuleType("torch.optim")
sys.modules["torch.optim"] = _optim_stub
_torch_stub.optim = _optim_stub
# nn sub-module with Module and common layers
_nn_stub = types.ModuleType("torch.nn")
class _ModuleStub:
    def __init_subclass__(cls, **kw): pass
    def __init__(self, *a, **kw): pass
_nn_stub.Module = _ModuleStub
_nn_stub.Linear = type("Linear", (_ModuleStub,), {})
_nn_stub.Embedding = type("Embedding", (_ModuleStub,), {})
_nn_stub.Dropout = type("Dropout", (_ModuleStub,), {})
class _ParameterStub:
    def __init__(self, *a, **kw): pass
_nn_stub.Parameter = _ParameterStub
_torch_stub.nn = _nn_stub
_nn_func = types.ModuleType("torch.nn.functional")
_nn_func.scaled_dot_product_attention = lambda *a, **kw: None
_nn_func.cross_entropy = lambda *a, **kw: None
_nn_func.softmax = lambda *a, **kw: None
_nn_func.silu = lambda *a, **kw: None
sys.modules["torch"] = _torch_stub
sys.modules["torch.nn"] = _nn_stub
sys.modules["torch.nn.functional"] = _nn_func
sys.modules.setdefault("torch.distributed", types.ModuleType("torch.distributed"))
sys.modules.setdefault("torch.nn.parallel", types.ModuleType("torch.nn.parallel"))
sys.modules.setdefault("torch.utils", types.ModuleType("torch.utils"))
sys.modules.setdefault("torch.utils.data", types.ModuleType("torch.utils.data"))
for mod_name in [
    "transformers", "transformers.activations", "transformers.modeling_outputs",
    "librosa", "soundfile", "numpy", "uvicorn",
]:
    sys.modules.setdefault(mod_name, types.ModuleType(mod_name))
# Stub transformers classes that are used at module scope
_transformers = sys.modules["transformers"]
for cls_name in [
    "PreTrainedModel", "GenerationMixin", "PretrainedConfig",
    "AutoTokenizer", "AutoModelForCausalLM", "AutoModel",
    "AutoModelForSequenceClassification", "TextStreamer",
]:
    if not hasattr(_transformers, cls_name):
        setattr(_transformers, cls_name, type(cls_name, (), {}))
_mo = sys.modules["transformers.modeling_outputs"]
if not hasattr(_mo, "MoeCausalLMOutputWithPast"):
    setattr(_mo, "MoeCausalLMOutputWithPast", type("MoeCausalLMOutputWithPast", (), {}))
_act = sys.modules["transformers.activations"]
if not hasattr(_act, "ACT2FN"):
    setattr(_act, "ACT2FN", {})
# Now import the functions we can test without a live model
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "scripts"))
from scripts.serve_openai_api import parse_response
# ============================================================
# Tests for parse_response
# ============================================================
class TestParseResponse:
    """Verify that parse_response correctly separates reasoning, content and tool calls."""

    def test_plain_text(self):
        content, reasoning, tool_calls = parse_response("Hello, world!")
        assert content == "Hello, world!"
        assert reasoning is None
        assert tool_calls is None

    def test_think_tags_complete(self):
        text = "<think>Step 1: consider options</think>The answer is 42."
        content, reasoning, tool_calls = parse_response(text)
        assert reasoning == "Step 1: consider options"
        assert content == "The answer is 42."
        assert tool_calls is None

    def test_think_tag_no_opening(self):
        """When only </think> is present (streaming partial), split at the tag."""
        text = "I need to think carefully</think>Final answer here"
        content, reasoning, tool_calls = parse_response(text)
        assert reasoning == "I need to think carefully"
        assert content == "Final answer here"

    def test_tool_call_parsed(self):
        text = 'Sure, I can help. <tool_call>{"name":"get_weather","arguments":{"city":"Tokyo"}}</tool_call>'
        content, reasoning, tool_calls = parse_response(text)
        assert "Sure, I can help." in content
        assert tool_calls is not None
        assert len(tool_calls) == 1
        assert tool_calls[0]["function"]["name"] == "get_weather"
        args = json.loads(tool_calls[0]["function"]["arguments"])
        assert args["city"] == "Tokyo"

    def test_multiple_tool_calls(self):
        text = (
            '<tool_call>{"name":"a","arguments":{}}</tool_call>'
            '<tool_call>{"name":"b","arguments":{"x":1}}</tool_call>'
        )
        content, reasoning, tool_calls = parse_response(text)
        assert tool_calls is not None
        assert len(tool_calls) == 2
        assert tool_calls[0]["function"]["name"] == "a"
        assert tool_calls[1]["function"]["name"] == "b"

    def test_invalid_tool_call_json_skipped(self):
        text = '<tool_call>not valid json</tool_call>The rest.'
        content, reasoning, tool_calls = parse_response(text)
        assert "The rest." in content
        assert tool_calls is None  # Invalid JSON is silently skipped

    def test_empty_string(self):
        content, reasoning, tool_calls = parse_response("")
        assert content == ""
        assert reasoning is None
        assert tool_calls is None
# ============================================================
# Tests for SSE stream format (structural, no live model)
# ============================================================
class TestSSEStreamFormat:
    """Verify the SSE event stream wrapper produces correct framing."""

    def test_done_marker_present(self):
        """The stream MUST end with 'data: [DONE]\\n\\n' per OpenAI spec."""
        # Read the source and verify the [DONE] marker is emitted
        script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py")
        with open(script_path) as f:
            source = f.read()
        assert 'data: [DONE]' in source, "Stream must emit 'data: [DONE]' terminator"

    def test_stream_chunks_have_required_fields(self):
        """Each stream chunk must include id, object, created, model, choices."""
        script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py")
        with open(script_path) as f:
            source = f.read()
        # Verify _make_chunk builds proper structure with all required fields
        assert '"id": request_id' in source or '"id":' in source
        assert '"object": "chat.completion.chunk"' in source
        assert '"created":' in source
        assert '"model":' in source

    def test_cors_middleware_present(self):
        """CORS middleware must be configured for browser-based clients."""
        script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py")
        with open(script_path) as f:
            source = f.read()
        assert "CORSMiddleware" in source, "CORS middleware required for browser clients"

    def test_non_stream_uses_max_new_tokens(self):
        """Non-stream path must use max_new_tokens, not max_length, matching stream behavior."""
        script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py")
        with open(script_path) as f:
            source = f.read()
        # The non-stream generate call should use max_new_tokens
        # Find the non-stream generate block (after "with torch.no_grad():")
        non_stream_section = source[source.index("with torch.no_grad()"):]
        assert "max_new_tokens" in non_stream_section, \
            "Non-stream path should use max_new_tokens for consistent behavior with stream path"
        # max_length should NOT appear in the generate call
        generate_block = non_stream_section[:non_stream_section.index("answer = tokenizer.decode")]
        assert "max_length" not in generate_block, \
            "Non-stream path should not use max_length (use max_new_tokens instead)"

    def test_generation_thread_is_daemon(self):
        """Generation thread must be a daemon so it doesn't block process exit."""
        script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py")
        with open(script_path) as f:
            source = f.read()
        assert "daemon=True" in source, "Generation thread should be daemon to avoid blocking exit"

    def test_generation_thread_error_propagation(self):
        """Generation thread errors must be propagated to the consumer, not swallowed."""
        script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py")
        with open(script_path) as f:
            source = f.read()
        # The _generate() function should have a try/except that puts errors on the queue
        assert "_SENTINEL" in source, "Sentinel-based error propagation required"
# ============================================================
# Run with pytest or directly
# ============================================================
if __name__ == "__main__":
    import pytest
    sys.exit(pytest.main([__file__, "-v"]))