diff --git a/scripts/serve_openai_api.py b/scripts/serve_openai_api.py
index 0550d4a..d515ba1 100644
--- a/scripts/serve_openai_api.py
+++ b/scripts/serve_openai_api.py
@@ -14,6 +14,7 @@ import uvicorn
from threading import Thread
from queue import Queue
from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
@@ -23,6 +24,13 @@ from model.model_lora import apply_lora, load_lora
warnings.filterwarnings('ignore')
app = FastAPI()
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
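+# NOTE: allow_origins=["*"] together with allow_credentials=True is a
+# permissive, development-friendly default; pin allow_origins for production.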
def init_model(args):
@@ -102,7 +110,26 @@ def parse_response(text):
return text.strip(), reasoning_content, tool_calls or None
+_SENTINEL = object()
+
+
def generate_stream_response(messages, temperature, top_p, max_tokens, tools=None, open_thinking=False):
+ request_id = f"chatcmpl-{int(time.time() * 1000)}"
+ created = int(time.time())
+
+ def _make_chunk(delta, finish_reason=None):
+ """Build a single SSE chunk in OpenAI chat.completion.chunk format."""
+ choice = {"index": 0, "delta": delta}
+ if finish_reason is not None:
+ choice["finish_reason"] = finish_reason
+ return json.dumps({
+ "id": request_id,
+ "object": "chat.completion.chunk",
+ "created": created,
+ "model": "minimind",
+ "choices": [choice],
+ }, ensure_ascii=False)
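+
+    # Example of one serialized chunk (illustrative values only):
+    #   {"id": "chatcmpl-1712000000000", "object": "chat.completion.chunk",
+    #    "created": 1712000000, "model": "minimind",
+    #    "choices": [{"index": 0, "delta": {"content": "Hi"}}]}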
+
try:
new_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=tools or None, open_thinking=open_thinking)[-max_tokens:]
inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)
@@ -111,19 +138,28 @@ def generate_stream_response(messages, temperature, top_p, max_tokens, tools=Non
streamer = CustomStreamer(tokenizer, queue)
def _generate():
- model.generate(
- inputs.input_ids,
- max_new_tokens=max_tokens,
- do_sample=True,
- temperature=temperature,
- top_p=top_p,
- attention_mask=inputs.attention_mask,
- pad_token_id=tokenizer.pad_token_id,
- eos_token_id=tokenizer.eos_token_id,
- streamer=streamer
- )
+ try:
+ model.generate(
+ inputs.input_ids,
+ max_new_tokens=max_tokens,
+ do_sample=True,
+ temperature=temperature,
+ top_p=top_p,
+ attention_mask=inputs.attention_mask,
+ pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ streamer=streamer
+ )
+ except Exception as e:
+ # Propagate generation errors to the consumer so it doesn't
+ # block forever on queue.get().
+ queue.put(_SENTINEL)
+ queue.put(e)
- Thread(target=_generate).start()
+ Thread(target=_generate, daemon=True).start()
+
+ # Emit the initial chunk with the role.
+ yield _make_chunk({"role": "assistant"})
full_text = ""
emitted = 0
@@ -131,6 +167,9 @@ def generate_stream_response(messages, temperature, top_p, max_tokens, tools=Non
while True:
text = queue.get()
+ if text is _SENTINEL:
+ # The generation thread raised; re-raise here.
+ raise queue.get()
if text is None:
break
full_text += text
@@ -141,28 +180,29 @@ def generate_stream_response(messages, temperature, top_p, max_tokens, tools=Non
thinking_ended = True
new_r = full_text[emitted:pos]
if new_r:
- yield json.dumps({"choices": [{"delta": {"reasoning_content": new_r}}]}, ensure_ascii=False)
+ yield _make_chunk({"reasoning_content": new_r})
                    emitted = pos + len('</think>')
after = full_text[emitted:].lstrip('\n')
emitted = len(full_text) - len(after)
if after:
- yield json.dumps({"choices": [{"delta": {"content": after}}]}, ensure_ascii=False)
+ yield _make_chunk({"content": after})
emitted = len(full_text)
else:
new_r = full_text[emitted:]
if new_r:
- yield json.dumps({"choices": [{"delta": {"reasoning_content": new_r}}]}, ensure_ascii=False)
+ yield _make_chunk({"reasoning_content": new_r})
emitted = len(full_text)
else:
new_c = full_text[emitted:]
if new_c:
- yield json.dumps({"choices": [{"delta": {"content": new_c}}]}, ensure_ascii=False)
+ yield _make_chunk({"content": new_c})
emitted = len(full_text)
_, _, tool_calls = parse_response(full_text)
if tool_calls:
- yield json.dumps({"choices": [{"delta": {"tool_calls": tool_calls}}]}, ensure_ascii=False)
- yield json.dumps({"choices": [{"delta": {}, "finish_reason": "tool_calls" if tool_calls else "stop"}]}, ensure_ascii=False)
+ yield _make_chunk({"tool_calls": tool_calls})
+ finish = "tool_calls" if tool_calls else "stop"
+ yield _make_chunk({}, finish_reason=finish)
except Exception as e:
yield json.dumps({"error": str(e)})
@@ -172,15 +212,20 @@ def generate_stream_response(messages, temperature, top_p, max_tokens, tools=Non
async def chat_completions(request: ChatRequest):
try:
if request.stream:
- return StreamingResponse(
- (f"data: {chunk}\n\n" for chunk in generate_stream_response(
+ def _event_stream():
+ for chunk in generate_stream_response(
messages=request.messages,
temperature=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens,
tools=request.tools,
open_thinking=request.get_open_thinking()
- )),
+ ):
+ yield f"data: {chunk}\n\n"
+ yield "data: [DONE]\n\n"
+
+ return StreamingResponse(
+ _event_stream(),
media_type="text/event-stream"
)
else:
@@ -195,7 +240,7 @@ async def chat_completions(request: ChatRequest):
with torch.no_grad():
generated_ids = model.generate(
inputs["input_ids"],
- max_length=inputs["input_ids"].shape[1] + request.max_tokens,
+ max_new_tokens=request.max_tokens,
do_sample=True,
attention_mask=inputs["attention_mask"],
pad_token_id=tokenizer.pad_token_id,
diff --git a/tests/test_serve_openai_api.py b/tests/test_serve_openai_api.py
new file mode 100644
index 0000000..6ae9a54
--- /dev/null
+++ b/tests/test_serve_openai_api.py
@@ -0,0 +1,220 @@
+"""
+Regression tests for scripts/serve_openai_api.py
+These tests verify OpenAI API compatibility without loading a real model.
+"""
+import json
+import re
+import sys
+import os
+import types
+import contextlib
+
+# ---------------------------------------------------------------------------
+# Stub heavy dependencies so importing serve_openai_api doesn't need GPU/model
+# ---------------------------------------------------------------------------
+
+# Provide a minimal torch stub (only attributes actually used at import time)
+_torch_stub = types.ModuleType("torch")
+_torch_stub.cuda = types.SimpleNamespace(is_available=lambda: False)
+# `with` looks up __enter__/__exit__ on the type, so a SimpleNamespace with
+# instance attributes can't act as a context manager; use nullcontext instead.
+_torch_stub.no_grad = lambda: contextlib.nullcontext()
+_torch_stub.inference_mode = lambda: lambda fn: fn
+_torch_stub.load = lambda *a, **kw: {}
+_torch_stub.Tensor = type("Tensor", (), {})
+_torch_stub.float32 = "float32"
+_torch_stub.LongTensor = lambda *a, **kw: None
+
+# optim sub-module
+_optim_stub = types.ModuleType("torch.optim")
+sys.modules["torch.optim"] = _optim_stub
+_torch_stub.optim = _optim_stub
+
+# nn sub-module with Module and common layers
+_nn_stub = types.ModuleType("torch.nn")
+class _ModuleStub:
+ def __init_subclass__(cls, **kw): pass
+ def __init__(self, *a, **kw): pass
+_nn_stub.Module = _ModuleStub
+_nn_stub.Linear = type("Linear", (_ModuleStub,), {})
+_nn_stub.Embedding = type("Embedding", (_ModuleStub,), {})
+_nn_stub.Dropout = type("Dropout", (_ModuleStub,), {})
+class _ParameterStub:
+ def __init__(self, *a, **kw): pass
+_nn_stub.Parameter = _ParameterStub
+_torch_stub.nn = _nn_stub
+
+_nn_func = types.ModuleType("torch.nn.functional")
+_nn_func.scaled_dot_product_attention = lambda *a, **kw: None
+_nn_func.cross_entropy = lambda *a, **kw: None
+_nn_func.softmax = lambda *a, **kw: None
+_nn_func.silu = lambda *a, **kw: None
+
+sys.modules["torch"] = _torch_stub
+sys.modules["torch.nn"] = _nn_stub
+sys.modules["torch.nn.functional"] = _nn_func
+sys.modules.setdefault("torch.distributed", types.ModuleType("torch.distributed"))
+sys.modules.setdefault("torch.nn.parallel", types.ModuleType("torch.nn.parallel"))
+sys.modules.setdefault("torch.utils", types.ModuleType("torch.utils"))
+sys.modules.setdefault("torch.utils.data", types.ModuleType("torch.utils.data"))
+
+for mod_name in [
+ "transformers", "transformers.activations", "transformers.modeling_outputs",
+ "librosa", "soundfile", "numpy", "uvicorn",
+]:
+ sys.modules.setdefault(mod_name, types.ModuleType(mod_name))
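+# setdefault keeps any real package that happens to be installed; the empty
+# module is only a fallback so the import below needs no heavy dependencies.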
+
+# Stub transformers classes that are used at module scope
+_transformers = sys.modules["transformers"]
+for cls_name in [
+ "PreTrainedModel", "GenerationMixin", "PretrainedConfig",
+ "AutoTokenizer", "AutoModelForCausalLM", "AutoModel",
+ "AutoModelForSequenceClassification", "TextStreamer",
+]:
+ if not hasattr(_transformers, cls_name):
+ setattr(_transformers, cls_name, type(cls_name, (), {}))
+
+_mo = sys.modules["transformers.modeling_outputs"]
+if not hasattr(_mo, "MoeCausalLMOutputWithPast"):
+ setattr(_mo, "MoeCausalLMOutputWithPast", type("MoeCausalLMOutputWithPast", (), {}))
+
+_act = sys.modules["transformers.activations"]
+if not hasattr(_act, "ACT2FN"):
+ setattr(_act, "ACT2FN", {})
+
+# Now import the functions we can test without a live model
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "scripts"))
+
+from scripts.serve_openai_api import parse_response
+
+
+# ============================================================
+# Tests for parse_response
+# ============================================================
+
+class TestParseResponse:
+ """Verify that parse_response correctly separates reasoning, content and tool calls."""
+
+ def test_plain_text(self):
+ content, reasoning, tool_calls = parse_response("Hello, world!")
+ assert content == "Hello, world!"
+ assert reasoning is None
+ assert tool_calls is None
+
+ def test_think_tags_complete(self):
+ text = "Step 1: consider optionsThe answer is 42."
+ content, reasoning, tool_calls = parse_response(text)
+ assert reasoning == "Step 1: consider options"
+ assert content == "The answer is 42."
+ assert tool_calls is None
+
+ def test_think_tag_no_opening(self):
+ """When only is present (streaming partial), split at the tag."""
+ text = "I need to think carefullyFinal answer here"
+ content, reasoning, tool_calls = parse_response(text)
+ assert reasoning == "I need to think carefully"
+ assert content == "Final answer here"
+
+ def test_tool_call_parsed(self):
        text = 'Sure, I can help. <tool_call>{"name":"get_weather","arguments":{"city":"Tokyo"}}</tool_call>'
+ content, reasoning, tool_calls = parse_response(text)
+ assert "Sure, I can help." in content
+ assert tool_calls is not None
+ assert len(tool_calls) == 1
+ assert tool_calls[0]["function"]["name"] == "get_weather"
+ args = json.loads(tool_calls[0]["function"]["arguments"])
+ assert args["city"] == "Tokyo"
+
+ def test_multiple_tool_calls(self):
+ text = (
+ '{"name":"a","arguments":{}}'
+ '{"name":"b","arguments":{"x":1}}'
+ )
+ content, reasoning, tool_calls = parse_response(text)
+ assert tool_calls is not None
+ assert len(tool_calls) == 2
+ assert tool_calls[0]["function"]["name"] == "a"
+ assert tool_calls[1]["function"]["name"] == "b"
+
+ def test_invalid_tool_call_json_skipped(self):
        text = '<tool_call>not valid json</tool_call>The rest.'
+ content, reasoning, tool_calls = parse_response(text)
+ assert "The rest." in content
+ assert tool_calls is None # Invalid JSON is silently skipped
+
+ def test_empty_string(self):
+ content, reasoning, tool_calls = parse_response("")
+ assert content == ""
+ assert reasoning is None
+ assert tool_calls is None
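+
+    def test_surrounding_whitespace_stripped(self):
+        """parse_response returns text.strip(), so outer whitespace is dropped."""
+        content, reasoning, tool_calls = parse_response("  Hello  \n")
+        assert content == "Hello"
+        assert reasoning is None
+        assert tool_calls is None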
+
+
+# ============================================================
+# Tests for SSE stream format (structural, no live model)
+# ============================================================
+
+class TestSSEStreamFormat:
+ """Verify the SSE event stream wrapper produces correct framing."""
+
+ def test_done_marker_present(self):
+ """The stream MUST end with 'data: [DONE]\\n\\n' per OpenAI spec."""
+ # Read the source and verify the [DONE] marker is emitted
+ script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py")
+ with open(script_path) as f:
+ source = f.read()
+ assert 'data: [DONE]' in source, "Stream must emit 'data: [DONE]' terminator"
+
+ def test_stream_chunks_have_required_fields(self):
+ """Each stream chunk must include id, object, created, model, choices."""
+ script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py")
+ with open(script_path) as f:
+ source = f.read()
+ # Verify _make_chunk builds proper structure with all required fields
+ assert '"id": request_id' in source or '"id":' in source
+ assert '"object": "chat.completion.chunk"' in source
+ assert '"created":' in source
+ assert '"model":' in source
+
+ def test_cors_middleware_present(self):
+ """CORS middleware must be configured for browser-based clients."""
+ script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py")
+ with open(script_path) as f:
+ source = f.read()
+ assert "CORSMiddleware" in source, "CORS middleware required for browser clients"
+
+ def test_non_stream_uses_max_new_tokens(self):
+ """Non-stream path must use max_new_tokens, not max_length, matching stream behavior."""
+ script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py")
+ with open(script_path) as f:
+ source = f.read()
+ # The non-stream generate call should use max_new_tokens
+ # Find the non-stream generate block (after "with torch.no_grad():")
+ non_stream_section = source[source.index("with torch.no_grad()"):]
+ assert "max_new_tokens" in non_stream_section, \
+ "Non-stream path should use max_new_tokens for consistent behavior with stream path"
+ # max_length should NOT appear in the generate call
+ generate_block = non_stream_section[:non_stream_section.index("answer = tokenizer.decode")]
+ assert "max_length" not in generate_block, \
+ "Non-stream path should not use max_length (use max_new_tokens instead)"
+
+ def test_generation_thread_is_daemon(self):
+ """Generation thread must be a daemon so it doesn't block process exit."""
+ script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py")
+ with open(script_path) as f:
+ source = f.read()
+ assert "daemon=True" in source, "Generation thread should be daemon to avoid blocking exit"
+
+ def test_generation_thread_error_propagation(self):
+ """Generation thread errors must be propagated to the consumer, not swallowed."""
+ script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py")
+ with open(script_path) as f:
+ source = f.read()
+ # The _generate() function should have a try/except that puts errors on the queue
+ assert "_SENTINEL" in source, "Sentinel-based error propagation required"
+
+
+# ============================================================
+# Run with pytest or directly
+# ============================================================
+
+if __name__ == "__main__":
+ import pytest
+ sys.exit(pytest.main([__file__, "-v"]))