diff --git a/scripts/serve_openai_api.py b/scripts/serve_openai_api.py
index 0550d4a..d515ba1 100644
--- a/scripts/serve_openai_api.py
+++ b/scripts/serve_openai_api.py
@@ -14,6 +14,7 @@ import uvicorn
 from threading import Thread
 from queue import Queue
 from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
@@ -23,6 +24,13 @@ from model.model_lora import apply_lora, load_lora
 warnings.filterwarnings('ignore')
 
 app = FastAPI()
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
 
 
 def init_model(args):
@@ -102,7 +110,26 @@ def parse_response(text):
     return text.strip(), reasoning_content, tool_calls or None
 
 
+_SENTINEL = object()
+
+
 def generate_stream_response(messages, temperature, top_p, max_tokens, tools=None, open_thinking=False):
+    request_id = f"chatcmpl-{int(time.time() * 1000)}"
+    created = int(time.time())
+
+    def _make_chunk(delta, finish_reason=None):
+        """Build a single SSE chunk in OpenAI chat.completion.chunk format."""
+        choice = {"index": 0, "delta": delta}
+        if finish_reason is not None:
+            choice["finish_reason"] = finish_reason
+        return json.dumps({
+            "id": request_id,
+            "object": "chat.completion.chunk",
+            "created": created,
+            "model": "minimind",
+            "choices": [choice],
+        }, ensure_ascii=False)
+
     try:
         new_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, tools=tools or None, open_thinking=open_thinking)[-max_tokens:]
         inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)
@@ -111,19 +138,28 @@ def generate_stream_response(messages, temperature, top_p, max_tokens, tools=Non
         streamer = CustomStreamer(tokenizer, queue)
 
         def _generate():
-            model.generate(
-                inputs.input_ids,
-                max_new_tokens=max_tokens,
-                do_sample=True,
-                temperature=temperature,
-                top_p=top_p,
-                attention_mask=inputs.attention_mask,
-                pad_token_id=tokenizer.pad_token_id,
-                eos_token_id=tokenizer.eos_token_id,
-                streamer=streamer
-            )
+            try:
+                model.generate(
+                    inputs.input_ids,
+                    max_new_tokens=max_tokens,
+                    do_sample=True,
+                    temperature=temperature,
+                    top_p=top_p,
+                    attention_mask=inputs.attention_mask,
+                    pad_token_id=tokenizer.pad_token_id,
+                    eos_token_id=tokenizer.eos_token_id,
+                    streamer=streamer
+                )
+            except Exception as e:
+                # Propagate generation errors to the consumer so it doesn't
+                # block forever on queue.get().
+                queue.put(_SENTINEL)
+                queue.put(e)
 
-        Thread(target=_generate).start()
+        Thread(target=_generate, daemon=True).start()
+
+        # Emit the initial chunk with the role.
+        yield _make_chunk({"role": "assistant"})
 
         full_text = ""
         emitted = 0
@@ -131,6 +167,9 @@ def generate_stream_response(messages, temperature, top_p, max_tokens, tools=Non
 
         while True:
             text = queue.get()
+            if text is _SENTINEL:
+                # The generation thread raised; re-raise here.
+                raise queue.get()
             if text is None:
                 break
             full_text += text
@@ -141,28 +180,29 @@
                     thinking_ended = True
                     new_r = full_text[emitted:pos]
                     if new_r:
-                        yield json.dumps({"choices": [{"delta": {"reasoning_content": new_r}}]}, ensure_ascii=False)
+                        yield _make_chunk({"reasoning_content": new_r})
                     emitted = pos + len('</think>')
                     after = full_text[emitted:].lstrip('\n')
                     emitted = len(full_text) - len(after)
                     if after:
-                        yield json.dumps({"choices": [{"delta": {"content": after}}]}, ensure_ascii=False)
+                        yield _make_chunk({"content": after})
                         emitted = len(full_text)
                 else:
                     new_r = full_text[emitted:]
                     if new_r:
-                        yield json.dumps({"choices": [{"delta": {"reasoning_content": new_r}}]}, ensure_ascii=False)
+                        yield _make_chunk({"reasoning_content": new_r})
                         emitted = len(full_text)
             else:
                 new_c = full_text[emitted:]
                 if new_c:
-                    yield json.dumps({"choices": [{"delta": {"content": new_c}}]}, ensure_ascii=False)
+                    yield _make_chunk({"content": new_c})
                     emitted = len(full_text)
 
         _, _, tool_calls = parse_response(full_text)
         if tool_calls:
-            yield json.dumps({"choices": [{"delta": {"tool_calls": tool_calls}}]}, ensure_ascii=False)
-        yield json.dumps({"choices": [{"delta": {}, "finish_reason": "tool_calls" if tool_calls else "stop"}]}, ensure_ascii=False)
+            yield _make_chunk({"tool_calls": tool_calls})
+        finish = "tool_calls" if tool_calls else "stop"
+        yield _make_chunk({}, finish_reason=finish)
     except Exception as e:
         yield json.dumps({"error": str(e)})
 
@@ -172,15 +212,20 @@ def generate_stream_response(messages, temperature, top_p, max_tokens, tools=Non
 async def chat_completions(request: ChatRequest):
     try:
         if request.stream:
-            return StreamingResponse(
-                (f"data: {chunk}\n\n" for chunk in generate_stream_response(
+            def _event_stream():
+                for chunk in generate_stream_response(
                     messages=request.messages,
                     temperature=request.temperature,
                     top_p=request.top_p,
                     max_tokens=request.max_tokens,
                     tools=request.tools,
                     open_thinking=request.get_open_thinking()
-                )),
+                ):
+                    yield f"data: {chunk}\n\n"
+                yield "data: [DONE]\n\n"
+
+            return StreamingResponse(
+                _event_stream(),
                 media_type="text/event-stream"
             )
         else:
@@ -195,7 +240,7 @@ async def chat_completions(request: ChatRequest):
         with torch.no_grad():
             generated_ids = model.generate(
                 inputs["input_ids"],
-                max_length=inputs["input_ids"].shape[1] + request.max_tokens,
+                max_new_tokens=request.max_tokens,
                 do_sample=True,
                 attention_mask=inputs["attention_mask"],
                 pad_token_id=tokenizer.pad_token_id,
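For reference, with the _make_chunk helper and the _event_stream wrapper introduced above, a streaming client now receives standard OpenAI chat.completion.chunk events over SSE. A short exchange would look roughly like this on the wire (the id and created values are illustrative, and the reasoning_content delta only appears when thinking is enabled); the new test file below checks this framing statically:

data: {"id": "chatcmpl-1700000000000", "object": "chat.completion.chunk", "created": 1700000000, "model": "minimind", "choices": [{"index": 0, "delta": {"role": "assistant"}}]}

data: {"id": "chatcmpl-1700000000000", "object": "chat.completion.chunk", "created": 1700000000, "model": "minimind", "choices": [{"index": 0, "delta": {"reasoning_content": "Thinking it through"}}]}

data: {"id": "chatcmpl-1700000000000", "object": "chat.completion.chunk", "created": 1700000000, "model": "minimind", "choices": [{"index": 0, "delta": {"content": "Hello!"}}]}

data: {"id": "chatcmpl-1700000000000", "object": "chat.completion.chunk", "created": 1700000000, "model": "minimind", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}

data: [DONE]
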
+""" +import json +import re +import sys +import os +import types + +# --------------------------------------------------------------------------- +# Stub heavy dependencies so importing serve_openai_api doesn't need GPU/model +# --------------------------------------------------------------------------- + +# Provide a minimal torch stub (only attributes actually used at import time) +_torch_stub = types.ModuleType("torch") +_torch_stub.cuda = types.SimpleNamespace(is_available=lambda: False) +_torch_stub.no_grad = lambda: types.SimpleNamespace(__enter__=lambda s: None, __exit__=lambda s, *a: None) +_torch_stub.inference_mode = lambda: lambda fn: fn +_torch_stub.load = lambda *a, **kw: {} +_torch_stub.Tensor = type("Tensor", (), {}) +_torch_stub.float32 = "float32" +_torch_stub.LongTensor = lambda *a, **kw: None + +# optim sub-module +_optim_stub = types.ModuleType("torch.optim") +sys.modules["torch.optim"] = _optim_stub +_torch_stub.optim = _optim_stub + +# nn sub-module with Module and common layers +_nn_stub = types.ModuleType("torch.nn") +class _ModuleStub: + def __init_subclass__(cls, **kw): pass + def __init__(self, *a, **kw): pass +_nn_stub.Module = _ModuleStub +_nn_stub.Linear = type("Linear", (_ModuleStub,), {}) +_nn_stub.Embedding = type("Embedding", (_ModuleStub,), {}) +_nn_stub.Dropout = type("Dropout", (_ModuleStub,), {}) +class _ParameterStub: + def __init__(self, *a, **kw): pass +_nn_stub.Parameter = _ParameterStub +_torch_stub.nn = _nn_stub + +_nn_func = types.ModuleType("torch.nn.functional") +_nn_func.scaled_dot_product_attention = lambda *a, **kw: None +_nn_func.cross_entropy = lambda *a, **kw: None +_nn_func.softmax = lambda *a, **kw: None +_nn_func.silu = lambda *a, **kw: None + +sys.modules["torch"] = _torch_stub +sys.modules["torch.nn"] = _nn_stub +sys.modules["torch.nn.functional"] = _nn_func +sys.modules.setdefault("torch.distributed", types.ModuleType("torch.distributed")) +sys.modules.setdefault("torch.nn.parallel", types.ModuleType("torch.nn.parallel")) +sys.modules.setdefault("torch.utils", types.ModuleType("torch.utils")) +sys.modules.setdefault("torch.utils.data", types.ModuleType("torch.utils.data")) + +for mod_name in [ + "transformers", "transformers.activations", "transformers.modeling_outputs", + "librosa", "soundfile", "numpy", "uvicorn", +]: + sys.modules.setdefault(mod_name, types.ModuleType(mod_name)) + +# Stub transformers classes that are used at module scope +_transformers = sys.modules["transformers"] +for cls_name in [ + "PreTrainedModel", "GenerationMixin", "PretrainedConfig", + "AutoTokenizer", "AutoModelForCausalLM", "AutoModel", + "AutoModelForSequenceClassification", "TextStreamer", +]: + if not hasattr(_transformers, cls_name): + setattr(_transformers, cls_name, type(cls_name, (), {})) + +_mo = sys.modules["transformers.modeling_outputs"] +if not hasattr(_mo, "MoeCausalLMOutputWithPast"): + setattr(_mo, "MoeCausalLMOutputWithPast", type("MoeCausalLMOutputWithPast", (), {})) + +_act = sys.modules["transformers.activations"] +if not hasattr(_act, "ACT2FN"): + setattr(_act, "ACT2FN", {}) + +# Now import the functions we can test without a live model +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "scripts")) + +from scripts.serve_openai_api import parse_response + + +# ============================================================ +# Tests for parse_response +# ============================================================ + +class TestParseResponse: + """Verify 
+class TestParseResponse:
+    """Verify that parse_response correctly separates reasoning, content and tool calls."""
+
+    def test_plain_text(self):
+        content, reasoning, tool_calls = parse_response("Hello, world!")
+        assert content == "Hello, world!"
+        assert reasoning is None
+        assert tool_calls is None
+
+    def test_think_tags_complete(self):
+        text = "<think>Step 1: consider options</think>The answer is 42."
+        content, reasoning, tool_calls = parse_response(text)
+        assert reasoning == "Step 1: consider options"
+        assert content == "The answer is 42."
+        assert tool_calls is None
+
+    def test_think_tag_no_opening(self):
+        """When only </think> is present (streaming partial), split at the tag."""
+        text = "I need to think carefully</think>Final answer here"
+        content, reasoning, tool_calls = parse_response(text)
+        assert reasoning == "I need to think carefully"
+        assert content == "Final answer here"
+
+    def test_tool_call_parsed(self):
+        text = 'Sure, I can help. <tool_call>{"name":"get_weather","arguments":{"city":"Tokyo"}}</tool_call>'
+        content, reasoning, tool_calls = parse_response(text)
+        assert "Sure, I can help." in content
+        assert tool_calls is not None
+        assert len(tool_calls) == 1
+        assert tool_calls[0]["function"]["name"] == "get_weather"
+        args = json.loads(tool_calls[0]["function"]["arguments"])
+        assert args["city"] == "Tokyo"
+
+    def test_multiple_tool_calls(self):
+        text = (
+            '<tool_call>{"name":"a","arguments":{}}</tool_call>'
+            '<tool_call>{"name":"b","arguments":{"x":1}}</tool_call>'
+        )
+        content, reasoning, tool_calls = parse_response(text)
+        assert tool_calls is not None
+        assert len(tool_calls) == 2
+        assert tool_calls[0]["function"]["name"] == "a"
+        assert tool_calls[1]["function"]["name"] == "b"
+
+    def test_invalid_tool_call_json_skipped(self):
+        text = '<tool_call>not valid json</tool_call>The rest.'
+        content, reasoning, tool_calls = parse_response(text)
+        assert "The rest." in content
+        assert tool_calls is None  # Invalid JSON is silently skipped
+
+    def test_empty_string(self):
+        content, reasoning, tool_calls = parse_response("")
+        assert content == ""
+        assert reasoning is None
+        assert tool_calls is None
+
+
+# ============================================================
+# Tests for SSE stream format (structural, no live model)
+# ============================================================
+
+class TestSSEStreamFormat:
+    """Verify the SSE event stream wrapper produces correct framing."""
+
+    def test_done_marker_present(self):
+        """The stream MUST end with 'data: [DONE]\\n\\n' per OpenAI spec."""
+        # Read the source and verify the [DONE] marker is emitted
+        script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py")
+        with open(script_path) as f:
+            source = f.read()
+        assert 'data: [DONE]' in source, "Stream must emit 'data: [DONE]' terminator"
+
+    def test_stream_chunks_have_required_fields(self):
+        """Each stream chunk must include id, object, created, model, choices."""
+        script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py")
+        with open(script_path) as f:
+            source = f.read()
+        # Verify _make_chunk builds proper structure with all required fields
+        assert '"id": request_id' in source or '"id":' in source
+        assert '"object": "chat.completion.chunk"' in source
+        assert '"created":' in source
+        assert '"model":' in source
+
+    def test_cors_middleware_present(self):
+        """CORS middleware must be configured for browser-based clients."""
+        script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py")
+        with open(script_path) as f:
+            source = f.read()
+        assert "CORSMiddleware" in source, "CORS middleware required for browser clients"
for browser clients" + + def test_non_stream_uses_max_new_tokens(self): + """Non-stream path must use max_new_tokens, not max_length, matching stream behavior.""" + script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py") + with open(script_path) as f: + source = f.read() + # The non-stream generate call should use max_new_tokens + # Find the non-stream generate block (after "with torch.no_grad():") + non_stream_section = source[source.index("with torch.no_grad()"):] + assert "max_new_tokens" in non_stream_section, \ + "Non-stream path should use max_new_tokens for consistent behavior with stream path" + # max_length should NOT appear in the generate call + generate_block = non_stream_section[:non_stream_section.index("answer = tokenizer.decode")] + assert "max_length" not in generate_block, \ + "Non-stream path should not use max_length (use max_new_tokens instead)" + + def test_generation_thread_is_daemon(self): + """Generation thread must be a daemon so it doesn't block process exit.""" + script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py") + with open(script_path) as f: + source = f.read() + assert "daemon=True" in source, "Generation thread should be daemon to avoid blocking exit" + + def test_generation_thread_error_propagation(self): + """Generation thread errors must be propagated to the consumer, not swallowed.""" + script_path = os.path.join(os.path.dirname(__file__), "..", "scripts", "serve_openai_api.py") + with open(script_path) as f: + source = f.read() + # The _generate() function should have a try/except that puts errors on the queue + assert "_SENTINEL" in source, "Sentinel-based error propagation required" + + +# ============================================================ +# Run with pytest or directly +# ============================================================ + +if __name__ == "__main__": + import pytest + sys.exit(pytest.main([__file__, "-v"]))