Spaces:
Running
Running
Kyryll Kochkin commited on
Commit ·
88c0e85
1
Parent(s): 0207551
AI added tests
Browse files- tests/test_core_helpers.py +94 -0
- tests/test_live_api.py +26 -9
- tests/test_live_more_models.py +22 -7
- tests/test_main_behavior.py +80 -0
- tests/test_model_registry.py +58 -0
- tests/test_openai_compat.py +0 -6
- tests/test_router_error_paths.py +177 -0
- tests/test_settings.py +29 -0
- tests/test_streaming_contracts.py +225 -0
tests/test_core_helpers.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit tests for prompt/token/engine helper utilities."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import types
|
| 5 |
+
|
| 6 |
+
from app.core import engine, tokens
|
| 7 |
+
from app.core.prompting import DEFAULT_SYSTEM_PROMPT, render_chat_prompt
|
| 8 |
+
from app.schemas.chat import ChatMessage
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DummyTokenizer:
|
| 12 |
+
def __init__(self) -> None:
|
| 13 |
+
self.called_with: tuple[str, bool] | None = None
|
| 14 |
+
|
| 15 |
+
def encode(self, text: str, add_special_tokens: bool = False) -> list[int]:
|
| 16 |
+
self.called_with = (text, add_special_tokens)
|
| 17 |
+
return [1, 2, 3]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class DummyEncoding:
|
| 21 |
+
def __init__(self, size: int) -> None:
|
| 22 |
+
self._size = size
|
| 23 |
+
|
| 24 |
+
def encode(self, _: str) -> list[int]:
|
| 25 |
+
return list(range(self._size))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class DummyTikToken:
|
| 29 |
+
def __init__(self, size: int) -> None:
|
| 30 |
+
self._size = size
|
| 31 |
+
|
| 32 |
+
def encoding_for_model(self, _: str) -> DummyEncoding:
|
| 33 |
+
return DummyEncoding(self._size)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def test_render_chat_prompt_uses_default_system_prompt() -> None:
|
| 37 |
+
prompt = render_chat_prompt([ChatMessage(role="user", content="Hello")])
|
| 38 |
+
assert prompt.startswith(f"System: {DEFAULT_SYSTEM_PROMPT}\n\n")
|
| 39 |
+
assert prompt.endswith("Assistant:")
|
| 40 |
+
assert "User: Hello" in prompt
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def test_render_chat_prompt_overrides_system_prompt_when_present() -> None:
|
| 44 |
+
prompt = render_chat_prompt(
|
| 45 |
+
[
|
| 46 |
+
ChatMessage(role="system", content="Custom system"),
|
| 47 |
+
ChatMessage(role="user", content="Hello"),
|
| 48 |
+
ChatMessage(role="assistant", content="Hi"),
|
| 49 |
+
]
|
| 50 |
+
)
|
| 51 |
+
assert prompt.startswith("System: Custom system\n\n")
|
| 52 |
+
assert "User: Hello" in prompt
|
| 53 |
+
assert "Assistant: Hi" in prompt
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def test_count_tokens_returns_zero_for_empty_text() -> None:
|
| 57 |
+
assert tokens.count_tokens("", "GPT3-dev") == 0
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def test_count_tokens_uses_tiktoken_when_available(monkeypatch) -> None:
|
| 61 |
+
monkeypatch.setattr(tokens, "tiktoken", DummyTikToken(size=4))
|
| 62 |
+
assert tokens.count_tokens("hello", "GPT3-dev") == 4
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def test_count_tokens_falls_back_to_tokenizer_encode(monkeypatch) -> None:
|
| 66 |
+
monkeypatch.setattr(tokens, "tiktoken", None)
|
| 67 |
+
tokenizer = DummyTokenizer()
|
| 68 |
+
assert tokens.count_tokens("hello", "GPT3-dev", tokenizer=tokenizer) == 3
|
| 69 |
+
assert tokenizer.called_with == ("hello", False)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def test_apply_stop_sequences_returns_earliest_stop_index() -> None:
|
| 73 |
+
text, reason = engine._apply_stop_sequences(
|
| 74 |
+
"abc<END>xyz<STOP>",
|
| 75 |
+
["<STOP>", "<END>"],
|
| 76 |
+
)
|
| 77 |
+
assert text == "abc"
|
| 78 |
+
assert reason == "stop"
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def test_normalize_stop_handles_none_string_and_iterable() -> None:
|
| 82 |
+
assert engine._normalize_stop(None) == ()
|
| 83 |
+
assert engine._normalize_stop("stop") == ("stop",)
|
| 84 |
+
assert engine._normalize_stop(["a", "b"]) == ("a", "b")
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def test_pad_token_id_prefers_pad_then_eos_then_zero() -> None:
|
| 88 |
+
with_pad = types.SimpleNamespace(pad_token_id=9, eos_token_id=7)
|
| 89 |
+
with_eos_only = types.SimpleNamespace(pad_token_id=None, eos_token_id=7)
|
| 90 |
+
with_none = types.SimpleNamespace(pad_token_id=None, eos_token_id=None)
|
| 91 |
+
|
| 92 |
+
assert engine._pad_token_id_or_default(with_pad) == 9
|
| 93 |
+
assert engine._pad_token_id_or_default(with_eos_only) == 7
|
| 94 |
+
assert engine._pad_token_id_or_default(with_none) == 0
|
tests/test_live_api.py
CHANGED
|
@@ -1,28 +1,45 @@
|
|
| 1 |
"""Live API smoke tests hitting a running server.
|
| 2 |
|
| 3 |
Skipped by default; set RUN_LIVE_API_TESTS=1 to enable.
|
| 4 |
-
Configure API base via API_BASE_URL (default:
|
| 5 |
"""
|
| 6 |
from __future__ import annotations
|
| 7 |
|
| 8 |
import os
|
| 9 |
-
from typing import
|
| 10 |
|
| 11 |
import pytest
|
| 12 |
import httpx
|
| 13 |
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
RUN_LIVE = os.environ.get("RUN_LIVE_API_TESTS") == "1"
|
| 16 |
-
BASE_URL = os.environ.get("API_BASE_URL",
|
|
|
|
| 17 |
PROMPT = "he is a doctor. His main goal is"
|
| 18 |
|
| 19 |
|
| 20 |
def _get_models(timeout: float = 10.0) -> Set[str]:
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
|
| 28 |
@pytest.mark.skipif(not RUN_LIVE, reason="set RUN_LIVE_API_TESTS=1 to run live API tests")
|
|
@@ -53,7 +70,7 @@ def test_completion_basic(model: str) -> None:
|
|
| 53 |
}
|
| 54 |
# Allow generous timeout for first-run weight downloads
|
| 55 |
timeout = httpx.Timeout(connect=10.0, read=600.0, write=30.0, pool=10.0)
|
| 56 |
-
with httpx.Client(timeout=timeout) as client:
|
| 57 |
resp = client.post(f"{BASE_URL}/v1/completions", json=payload)
|
| 58 |
resp.raise_for_status()
|
| 59 |
body = resp.json()
|
|
|
|
| 1 |
"""Live API smoke tests hitting a running server.
|
| 2 |
|
| 3 |
Skipped by default; set RUN_LIVE_API_TESTS=1 to enable.
|
| 4 |
+
Configure API base via API_BASE_URL (default: https://k050506koch-gpt3-dev-api.hf.space).
|
| 5 |
"""
|
| 6 |
from __future__ import annotations
|
| 7 |
|
| 8 |
import os
|
| 9 |
+
from typing import Set
|
| 10 |
|
| 11 |
import pytest
|
| 12 |
import httpx
|
| 13 |
|
| 14 |
|
| 15 |
+
DEFAULT_BASE_URL = "https://k050506koch-gpt3-dev-api.hf.space"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _normalize_base_url(raw_base_url: str) -> str:
|
| 19 |
+
base_url = raw_base_url.rstrip("/")
|
| 20 |
+
if base_url.endswith("/v1"):
|
| 21 |
+
base_url = base_url[:-3]
|
| 22 |
+
return base_url
|
| 23 |
+
|
| 24 |
+
|
| 25 |
RUN_LIVE = os.environ.get("RUN_LIVE_API_TESTS") == "1"
|
| 26 |
+
BASE_URL = _normalize_base_url(os.environ.get("API_BASE_URL", DEFAULT_BASE_URL))
|
| 27 |
+
VERIFY_SSL = os.environ.get("API_VERIFY_SSL", "1") != "0"
|
| 28 |
PROMPT = "he is a doctor. His main goal is"
|
| 29 |
|
| 30 |
|
| 31 |
def _get_models(timeout: float = 10.0) -> Set[str]:
|
| 32 |
+
try:
|
| 33 |
+
with httpx.Client(timeout=timeout, verify=VERIFY_SSL) as client:
|
| 34 |
+
resp = client.get(f"{BASE_URL}/v1/models")
|
| 35 |
+
resp.raise_for_status()
|
| 36 |
+
data = resp.json()
|
| 37 |
+
return {item["id"] for item in data.get("data", [])}
|
| 38 |
+
except httpx.HTTPError as exc:
|
| 39 |
+
pytest.fail(
|
| 40 |
+
f"Unable to reach live API at {BASE_URL}/v1/models: {exc}. "
|
| 41 |
+
"Set API_BASE_URL to your server root URL (with or without '/v1')."
|
| 42 |
+
)
|
| 43 |
|
| 44 |
|
| 45 |
@pytest.mark.skipif(not RUN_LIVE, reason="set RUN_LIVE_API_TESTS=1 to run live API tests")
|
|
|
|
| 70 |
}
|
| 71 |
# Allow generous timeout for first-run weight downloads
|
| 72 |
timeout = httpx.Timeout(connect=10.0, read=600.0, write=30.0, pool=10.0)
|
| 73 |
+
with httpx.Client(timeout=timeout, verify=VERIFY_SSL) as client:
|
| 74 |
resp = client.post(f"{BASE_URL}/v1/completions", json=payload)
|
| 75 |
resp.raise_for_status()
|
| 76 |
body = resp.json()
|
tests/test_live_more_models.py
CHANGED
|
@@ -15,8 +15,18 @@ import pytest
|
|
| 15 |
import httpx
|
| 16 |
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
RUN_LIVE = os.environ.get("RUN_LIVE_API_TESTS") == "1"
|
| 19 |
-
BASE_URL = os.environ.get("API_BASE_URL",
|
| 20 |
VERIFY_SSL = os.environ.get("API_VERIFY_SSL", "1") != "0"
|
| 21 |
PROMPT = "he is a doctor. His main goal is"
|
| 22 |
|
|
@@ -35,11 +45,17 @@ CANDIDATES = [
|
|
| 35 |
|
| 36 |
@lru_cache(maxsize=1)
|
| 37 |
def _get_models(timeout: float = 10.0) -> Set[str]:
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
if not models:
|
| 45 |
pytest.fail(f"/v1/models returned no data from {BASE_URL}")
|
|
@@ -95,4 +111,3 @@ def test_completion_for_models(model: str) -> None:
|
|
| 95 |
warnings.warn(message, stacklevel=1)
|
| 96 |
usage = body.get("usage") or {}
|
| 97 |
assert "total_tokens" in usage
|
| 98 |
-
|
|
|
|
| 15 |
import httpx
|
| 16 |
|
| 17 |
|
| 18 |
+
DEFAULT_BASE_URL = "https://k050506koch-gpt3-dev-api.hf.space"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _normalize_base_url(raw_base_url: str) -> str:
|
| 22 |
+
base_url = raw_base_url.rstrip("/")
|
| 23 |
+
if base_url.endswith("/v1"):
|
| 24 |
+
base_url = base_url[:-3]
|
| 25 |
+
return base_url
|
| 26 |
+
|
| 27 |
+
|
| 28 |
RUN_LIVE = os.environ.get("RUN_LIVE_API_TESTS") == "1"
|
| 29 |
+
BASE_URL = _normalize_base_url(os.environ.get("API_BASE_URL", DEFAULT_BASE_URL))
|
| 30 |
VERIFY_SSL = os.environ.get("API_VERIFY_SSL", "1") != "0"
|
| 31 |
PROMPT = "he is a doctor. His main goal is"
|
| 32 |
|
|
|
|
| 45 |
|
| 46 |
@lru_cache(maxsize=1)
|
| 47 |
def _get_models(timeout: float = 10.0) -> Set[str]:
|
| 48 |
+
try:
|
| 49 |
+
with httpx.Client(timeout=timeout, verify=VERIFY_SSL) as client:
|
| 50 |
+
resp = client.get(f"{BASE_URL}/v1/models")
|
| 51 |
+
resp.raise_for_status()
|
| 52 |
+
data = resp.json()
|
| 53 |
+
models = {item.get("id") for item in (data.get("data") or [])}
|
| 54 |
+
except httpx.HTTPError as exc:
|
| 55 |
+
pytest.fail(
|
| 56 |
+
f"Unable to reach live API at {BASE_URL}/v1/models: {exc}. "
|
| 57 |
+
"Set API_BASE_URL to your server root URL (with or without '/v1')."
|
| 58 |
+
)
|
| 59 |
|
| 60 |
if not models:
|
| 61 |
pytest.fail(f"/v1/models returned no data from {BASE_URL}")
|
|
|
|
| 111 |
warnings.warn(message, stacklevel=1)
|
| 112 |
usage = body.get("usage") or {}
|
| 113 |
assert "total_tokens" in usage
|
|
|
tests/test_main_behavior.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Behavioral tests for app.main helpers and handlers."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import asyncio
|
| 5 |
+
import json
|
| 6 |
+
|
| 7 |
+
from fastapi import HTTPException
|
| 8 |
+
from starlette.requests import Request
|
| 9 |
+
|
| 10 |
+
from app import main
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _dummy_request() -> Request:
|
| 14 |
+
scope = {
|
| 15 |
+
"type": "http",
|
| 16 |
+
"http_version": "1.1",
|
| 17 |
+
"method": "GET",
|
| 18 |
+
"path": "/",
|
| 19 |
+
"headers": [],
|
| 20 |
+
"query_string": b"",
|
| 21 |
+
}
|
| 22 |
+
return Request(scope)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def test_root_returns_ok_when_no_failures(monkeypatch) -> None:
|
| 26 |
+
monkeypatch.setattr(main, "_endpoint_status", {"failures": {}, "last_checked": None})
|
| 27 |
+
payload = asyncio.run(main.root())
|
| 28 |
+
assert payload == {"status": "ok", "message": "GPT3dev API is running"}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def test_root_returns_degraded_with_sorted_issues_and_last_checked(monkeypatch) -> None:
|
| 32 |
+
monkeypatch.setattr(
|
| 33 |
+
main,
|
| 34 |
+
"_endpoint_status",
|
| 35 |
+
{
|
| 36 |
+
"failures": {
|
| 37 |
+
"/v1/zeta": {"status_code": 503},
|
| 38 |
+
"/v1/alpha": {"status_code": 500, "detail": "boom"},
|
| 39 |
+
},
|
| 40 |
+
"last_checked": "2026-02-05T12:00:00+00:00",
|
| 41 |
+
},
|
| 42 |
+
)
|
| 43 |
+
payload = asyncio.run(main.root())
|
| 44 |
+
|
| 45 |
+
assert payload["status"] == "degraded"
|
| 46 |
+
assert payload["message"] == "GPT3dev API is running"
|
| 47 |
+
assert payload["issues"][0]["endpoint"] == "/v1/alpha"
|
| 48 |
+
assert payload["issues"][1]["endpoint"] == "/v1/zeta"
|
| 49 |
+
assert payload["last_checked"] == "2026-02-05T12:00:00+00:00"
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def test_openai_exception_handler_wraps_error_payload() -> None:
|
| 53 |
+
exc = HTTPException(
|
| 54 |
+
status_code=400,
|
| 55 |
+
detail={
|
| 56 |
+
"message": "bad request",
|
| 57 |
+
"type": "invalid_request_error",
|
| 58 |
+
"param": "model",
|
| 59 |
+
"code": "bad_model",
|
| 60 |
+
},
|
| 61 |
+
)
|
| 62 |
+
response = asyncio.run(main.openai_http_exception_handler(_dummy_request(), exc))
|
| 63 |
+
|
| 64 |
+
assert response.status_code == 400
|
| 65 |
+
assert json.loads(response.body.decode("utf-8")) == {
|
| 66 |
+
"error": {
|
| 67 |
+
"message": "bad request",
|
| 68 |
+
"type": "invalid_request_error",
|
| 69 |
+
"param": "model",
|
| 70 |
+
"code": "bad_model",
|
| 71 |
+
}
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def test_openai_exception_handler_preserves_generic_detail() -> None:
|
| 76 |
+
exc = HTTPException(status_code=422, detail="unprocessable")
|
| 77 |
+
response = asyncio.run(main.openai_http_exception_handler(_dummy_request(), exc))
|
| 78 |
+
|
| 79 |
+
assert response.status_code == 422
|
| 80 |
+
assert json.loads(response.body.decode("utf-8")) == {"detail": "unprocessable"}
|
tests/test_model_registry.py
CHANGED
|
@@ -115,3 +115,61 @@ def test_custom_registry_can_extend_defaults(reset_registry, tmp_path: Path):
|
|
| 115 |
names = {spec.name for spec in model_registry.list_models()}
|
| 116 |
assert "Tiny" in names
|
| 117 |
assert "GPT3-dev" in names
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
names = {spec.name for spec in model_registry.list_models()}
|
| 116 |
assert "Tiny" in names
|
| 117 |
assert "GPT3-dev" in names
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def test_registry_loads_yaml_when_json_parse_fails(
|
| 121 |
+
reset_registry,
|
| 122 |
+
monkeypatch: pytest.MonkeyPatch,
|
| 123 |
+
tmp_path: Path,
|
| 124 |
+
):
|
| 125 |
+
registry_path = tmp_path / "registry.yaml"
|
| 126 |
+
registry_path.write_text("- name: Tiny\n hf_repo: dummy/tiny\n")
|
| 127 |
+
|
| 128 |
+
def fake_safe_load(data: str) -> list[dict[str, str]]:
|
| 129 |
+
assert "name: Tiny" in data
|
| 130 |
+
return [{"name": "Tiny", "hf_repo": "dummy/tiny"}]
|
| 131 |
+
|
| 132 |
+
monkeypatch.setattr(model_registry.yaml, "safe_load", fake_safe_load)
|
| 133 |
+
|
| 134 |
+
reset_registry(registry_path=str(registry_path))
|
| 135 |
+
names = {spec.name for spec in model_registry.list_models()}
|
| 136 |
+
assert names == {"Tiny"}
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def test_registry_rejects_non_list_file_payload(reset_registry, tmp_path: Path):
|
| 140 |
+
registry_path = tmp_path / "registry.json"
|
| 141 |
+
registry_path.write_text(json.dumps({"name": "Tiny", "hf_repo": "dummy/tiny"}))
|
| 142 |
+
|
| 143 |
+
reset_registry(registry_path=str(registry_path))
|
| 144 |
+
with pytest.raises(ValueError, match="must contain a list"):
|
| 145 |
+
model_registry.list_models()
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def test_registry_rejects_non_object_entries(reset_registry, tmp_path: Path):
|
| 149 |
+
registry_path = tmp_path / "registry.json"
|
| 150 |
+
registry_path.write_text(json.dumps(["not-an-object"]))
|
| 151 |
+
|
| 152 |
+
reset_registry(registry_path=str(registry_path))
|
| 153 |
+
with pytest.raises(ValueError, match="entries must be objects"):
|
| 154 |
+
model_registry.list_models()
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def test_registry_path_missing_raises_file_not_found(reset_registry, tmp_path: Path):
|
| 158 |
+
reset_registry(registry_path=str(tmp_path / "missing.json"))
|
| 159 |
+
|
| 160 |
+
with pytest.raises(FileNotFoundError, match="MODEL_REGISTRY_PATH not found"):
|
| 161 |
+
model_registry.list_models()
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def test_file_registry_overrides_default_model_with_same_name(reset_registry, tmp_path: Path):
|
| 165 |
+
registry_path = tmp_path / "registry.json"
|
| 166 |
+
registry_path.write_text(
|
| 167 |
+
json.dumps([{"name": "GPT3-dev", "hf_repo": "custom/override"}])
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
reset_registry(registry_path=str(registry_path), include_defaults=True)
|
| 171 |
+
|
| 172 |
+
spec = model_registry.get_model_spec("GPT3-dev")
|
| 173 |
+
names = {item.name for item in model_registry.list_models()}
|
| 174 |
+
assert "GPT-2" in names
|
| 175 |
+
assert spec.hf_repo == "custom/override"
|
tests/test_openai_compat.py
CHANGED
|
@@ -302,12 +302,6 @@ def test_responses_instruct_messages(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
| 302 |
assert body["usage"]["total_tokens"] == 4
|
| 303 |
|
| 304 |
|
| 305 |
-
def test_openai_client_responses_create(monkeypatch: pytest.MonkeyPatch) -> None:
|
| 306 |
-
openai_module = pytest.importorskip("openai")
|
| 307 |
-
OpenAI = openai_module.OpenAI
|
| 308 |
-
pytest.skip("OpenAI client test moved to live API coverage.")
|
| 309 |
-
|
| 310 |
-
|
| 311 |
def test_embeddings_not_implemented() -> None:
|
| 312 |
with pytest.raises(HTTPException) as exc:
|
| 313 |
asyncio.run(embeddings.create_embeddings())
|
|
|
|
| 302 |
assert body["usage"]["total_tokens"] == 4
|
| 303 |
|
| 304 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
def test_embeddings_not_implemented() -> None:
|
| 306 |
with pytest.raises(HTTPException) as exc:
|
| 307 |
asyncio.run(embeddings.create_embeddings())
|
tests/test_router_error_paths.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Router-level error path tests for OpenAI-compatible payloads."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import asyncio
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
from fastapi import HTTPException
|
| 8 |
+
|
| 9 |
+
from app.core.model_registry import ModelSpec
|
| 10 |
+
from app.routers import chat, completions, embeddings, responses
|
| 11 |
+
from app.schemas.chat import ChatCompletionRequest
|
| 12 |
+
from app.schemas.completions import CompletionRequest
|
| 13 |
+
from app.schemas.responses import ResponseRequest
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _raise_key_error(_: str) -> None:
|
| 17 |
+
raise KeyError("unknown")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def test_completions_unknown_model_returns_404_openai_error(
|
| 21 |
+
monkeypatch: pytest.MonkeyPatch,
|
| 22 |
+
) -> None:
|
| 23 |
+
monkeypatch.setattr("app.routers.completions.get_model_spec", _raise_key_error)
|
| 24 |
+
payload = CompletionRequest.model_validate({"model": "missing", "prompt": "Hi"})
|
| 25 |
+
|
| 26 |
+
with pytest.raises(HTTPException) as exc:
|
| 27 |
+
asyncio.run(completions.create_completion(payload))
|
| 28 |
+
|
| 29 |
+
assert exc.value.status_code == 404
|
| 30 |
+
assert exc.value.detail["type"] == "model_not_found"
|
| 31 |
+
assert exc.value.detail["param"] == "model"
|
| 32 |
+
assert exc.value.detail["code"] == "model_not_found"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test_chat_unknown_model_returns_404_openai_error(
|
| 36 |
+
monkeypatch: pytest.MonkeyPatch,
|
| 37 |
+
) -> None:
|
| 38 |
+
monkeypatch.setattr("app.routers.chat.get_model_spec", _raise_key_error)
|
| 39 |
+
payload = ChatCompletionRequest.model_validate(
|
| 40 |
+
{
|
| 41 |
+
"model": "missing",
|
| 42 |
+
"messages": [{"role": "user", "content": "Hi"}],
|
| 43 |
+
}
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
with pytest.raises(HTTPException) as exc:
|
| 47 |
+
asyncio.run(chat.create_chat_completion(payload))
|
| 48 |
+
|
| 49 |
+
assert exc.value.status_code == 404
|
| 50 |
+
assert exc.value.detail["type"] == "model_not_found"
|
| 51 |
+
assert exc.value.detail["param"] == "model"
|
| 52 |
+
assert exc.value.detail["code"] == "model_not_found"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def test_responses_unknown_model_returns_404_openai_error(
|
| 56 |
+
monkeypatch: pytest.MonkeyPatch,
|
| 57 |
+
) -> None:
|
| 58 |
+
monkeypatch.setattr("app.routers.responses.get_model_spec", _raise_key_error)
|
| 59 |
+
payload = ResponseRequest.model_validate({"model": "missing", "input": "Hi"})
|
| 60 |
+
|
| 61 |
+
with pytest.raises(HTTPException) as exc:
|
| 62 |
+
asyncio.run(responses.create_response(payload))
|
| 63 |
+
|
| 64 |
+
assert exc.value.status_code == 404
|
| 65 |
+
assert exc.value.detail["type"] == "model_not_found"
|
| 66 |
+
assert exc.value.detail["param"] == "model"
|
| 67 |
+
assert exc.value.detail["code"] == "model_not_found"
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def test_completions_generation_exception_returns_generation_error(
|
| 71 |
+
monkeypatch: pytest.MonkeyPatch,
|
| 72 |
+
) -> None:
|
| 73 |
+
def boom(*_: object, **__: object) -> None:
|
| 74 |
+
raise RuntimeError("boom")
|
| 75 |
+
|
| 76 |
+
monkeypatch.setattr("app.routers.completions.get_model_spec", lambda _: None)
|
| 77 |
+
monkeypatch.setattr("app.routers.completions.engine.generate", boom)
|
| 78 |
+
payload = CompletionRequest.model_validate({"model": "GPT3-dev", "prompt": "Hi"})
|
| 79 |
+
|
| 80 |
+
with pytest.raises(HTTPException) as exc:
|
| 81 |
+
asyncio.run(completions.create_completion(payload))
|
| 82 |
+
|
| 83 |
+
assert exc.value.status_code == 500
|
| 84 |
+
assert exc.value.detail["type"] == "server_error"
|
| 85 |
+
assert exc.value.detail["code"] == "generation_error"
|
| 86 |
+
assert "Generation error:" in exc.value.detail["message"]
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def test_chat_generation_exception_returns_generation_error(
|
| 90 |
+
monkeypatch: pytest.MonkeyPatch,
|
| 91 |
+
) -> None:
|
| 92 |
+
def boom(*_: object, **__: object) -> None:
|
| 93 |
+
raise RuntimeError("boom")
|
| 94 |
+
|
| 95 |
+
monkeypatch.setattr(
|
| 96 |
+
"app.routers.chat.get_model_spec",
|
| 97 |
+
lambda model: ModelSpec(name=model, hf_repo="dummy/instruct", is_instruct=True),
|
| 98 |
+
)
|
| 99 |
+
monkeypatch.setattr("app.routers.chat.engine.apply_chat_template", lambda *_: "prompt")
|
| 100 |
+
monkeypatch.setattr("app.routers.chat.engine.generate", boom)
|
| 101 |
+
|
| 102 |
+
payload = ChatCompletionRequest.model_validate(
|
| 103 |
+
{
|
| 104 |
+
"model": "GPT4-dev-177M-1511-Instruct",
|
| 105 |
+
"messages": [{"role": "user", "content": "Hi"}],
|
| 106 |
+
}
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
with pytest.raises(HTTPException) as exc:
|
| 110 |
+
asyncio.run(chat.create_chat_completion(payload))
|
| 111 |
+
|
| 112 |
+
assert exc.value.status_code == 500
|
| 113 |
+
assert exc.value.detail["type"] == "server_error"
|
| 114 |
+
assert exc.value.detail["code"] == "generation_error"
|
| 115 |
+
assert "Generation error:" in exc.value.detail["message"]
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def test_responses_generation_exception_returns_generation_error(
|
| 119 |
+
monkeypatch: pytest.MonkeyPatch,
|
| 120 |
+
) -> None:
|
| 121 |
+
def boom(*_: object, **__: object) -> None:
|
| 122 |
+
raise RuntimeError("boom")
|
| 123 |
+
|
| 124 |
+
monkeypatch.setattr(
|
| 125 |
+
"app.routers.responses.get_model_spec",
|
| 126 |
+
lambda model: ModelSpec(name=model, hf_repo="dummy/base", is_instruct=False),
|
| 127 |
+
)
|
| 128 |
+
monkeypatch.setattr("app.routers.responses.engine.generate", boom)
|
| 129 |
+
payload = ResponseRequest.model_validate({"model": "GPT3-dev", "input": "Hi"})
|
| 130 |
+
|
| 131 |
+
with pytest.raises(HTTPException) as exc:
|
| 132 |
+
asyncio.run(responses.create_response(payload))
|
| 133 |
+
|
| 134 |
+
assert exc.value.status_code == 500
|
| 135 |
+
assert exc.value.detail["type"] == "server_error"
|
| 136 |
+
assert exc.value.detail["code"] == "generation_error"
|
| 137 |
+
assert "Generation error:" in exc.value.detail["message"]
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def test_responses_structured_input_with_non_instruct_model_returns_400(
|
| 141 |
+
monkeypatch: pytest.MonkeyPatch,
|
| 142 |
+
) -> None:
|
| 143 |
+
monkeypatch.setattr(
|
| 144 |
+
"app.routers.responses.get_model_spec",
|
| 145 |
+
lambda model: ModelSpec(name=model, hf_repo="dummy/base", is_instruct=False),
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
payload = ResponseRequest.model_validate(
|
| 149 |
+
{
|
| 150 |
+
"model": "GPT3-dev",
|
| 151 |
+
"input": [{"role": "user", "content": "Hi"}],
|
| 152 |
+
}
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
with pytest.raises(HTTPException) as exc:
|
| 156 |
+
asyncio.run(responses.create_response(payload))
|
| 157 |
+
|
| 158 |
+
assert exc.value.status_code == 400
|
| 159 |
+
assert exc.value.detail["type"] == "invalid_request_error"
|
| 160 |
+
assert exc.value.detail["param"] == "model"
|
| 161 |
+
assert "not an instruct model" in exc.value.detail["message"]
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def test_embeddings_enabled_backend_returns_pending_code(
|
| 165 |
+
monkeypatch: pytest.MonkeyPatch,
|
| 166 |
+
) -> None:
|
| 167 |
+
class DummySettings:
|
| 168 |
+
enable_embeddings_backend = True
|
| 169 |
+
|
| 170 |
+
monkeypatch.setattr("app.routers.embeddings.get_settings", lambda: DummySettings())
|
| 171 |
+
|
| 172 |
+
with pytest.raises(HTTPException) as exc:
|
| 173 |
+
asyncio.run(embeddings.create_embeddings())
|
| 174 |
+
|
| 175 |
+
assert exc.value.status_code == 501
|
| 176 |
+
assert exc.value.detail["type"] == "not_implemented_error"
|
| 177 |
+
assert exc.value.detail["code"] == "embeddings_backend_pending"
|
tests/test_settings.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for environment-driven settings parsing validators."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import pytest
|
| 5 |
+
from pydantic import ValidationError
|
| 6 |
+
|
| 7 |
+
from app.core.settings import Settings
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def test_cors_allow_origins_parses_comma_separated_values() -> None:
|
| 11 |
+
settings = Settings.model_validate(
|
| 12 |
+
{"CORS_ALLOW_ORIGINS": "https://a.example, https://b.example"}
|
| 13 |
+
)
|
| 14 |
+
assert settings.cors_allow_origins == ["https://a.example", "https://b.example"]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def test_cors_allow_origins_rejects_invalid_type() -> None:
|
| 18 |
+
with pytest.raises(ValidationError):
|
| 19 |
+
Settings.model_validate({"CORS_ALLOW_ORIGINS": 123})
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def test_model_allow_list_parses_comma_separated_values() -> None:
|
| 23 |
+
settings = Settings.model_validate({"MODEL_ALLOW_LIST": "GPT3-dev, GPT-2"})
|
| 24 |
+
assert settings.model_allow_list == ["GPT3-dev", "GPT-2"]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def test_model_allow_list_rejects_invalid_type() -> None:
|
| 28 |
+
with pytest.raises(ValidationError):
|
| 29 |
+
Settings.model_validate({"MODEL_ALLOW_LIST": 123})
|
tests/test_streaming_contracts.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Streaming contract tests for OpenAI-compatible SSE endpoints."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import asyncio
|
| 5 |
+
import json
|
| 6 |
+
from collections import deque
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
from fastapi.responses import StreamingResponse
|
| 10 |
+
|
| 11 |
+
from app.core.model_registry import ModelSpec
|
| 12 |
+
from app.routers import chat, completions, responses
|
| 13 |
+
from app.schemas.chat import ChatCompletionRequest
|
| 14 |
+
from app.schemas.completions import CompletionRequest
|
| 15 |
+
from app.schemas.responses import ResponseRequest
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DummyStream:
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
*,
|
| 22 |
+
tokens: list[str],
|
| 23 |
+
prompt_tokens: int,
|
| 24 |
+
completion_tokens: int,
|
| 25 |
+
finish_reason: str = "stop",
|
| 26 |
+
) -> None:
|
| 27 |
+
self._tokens = tokens
|
| 28 |
+
self.prompt_tokens = prompt_tokens
|
| 29 |
+
self.completion_tokens = completion_tokens
|
| 30 |
+
self.finish_reason = finish_reason
|
| 31 |
+
|
| 32 |
+
def iter_tokens(self):
|
| 33 |
+
for token in self._tokens:
|
| 34 |
+
yield token
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
async def _read_stream_body(response: StreamingResponse) -> str:
|
| 38 |
+
chunks: list[str] = []
|
| 39 |
+
async for chunk in response.body_iterator:
|
| 40 |
+
if isinstance(chunk, bytes):
|
| 41 |
+
chunks.append(chunk.decode("utf-8"))
|
| 42 |
+
else:
|
| 43 |
+
chunks.append(chunk)
|
| 44 |
+
return "".join(chunks)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _parse_sse_data_frames(raw_body: str) -> list[str]:
|
| 48 |
+
frames = [frame.strip() for frame in raw_body.split("\n\n") if frame.strip()]
|
| 49 |
+
data_frames: list[str] = []
|
| 50 |
+
for frame in frames:
|
| 51 |
+
assert frame.startswith("data: ")
|
| 52 |
+
data_frames.append(frame[len("data: ") :])
|
| 53 |
+
return data_frames
|
| 54 |
+
|
| 55 |
+
|
def test_completions_stream_emits_sse_chunks_usage_and_done(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A streamed completion yields token chunks, a usage tail, then [DONE]."""
    monkeypatch.setattr("app.routers.completions.get_model_spec", lambda _: None)
    monkeypatch.setattr(
        "app.routers.completions.engine.create_stream",
        lambda *_, **__: DummyStream(
            tokens=["Hel", "lo"],
            prompt_tokens=3,
            completion_tokens=2,
            finish_reason="stop",
        ),
    )

    request = CompletionRequest.model_validate(
        {"model": "GPT3-dev", "prompt": "Hello", "stream": True}
    )
    result = asyncio.run(completions.create_completion(request))
    assert isinstance(result, StreamingResponse)

    frames = _parse_sse_data_frames(asyncio.run(_read_stream_body(result)))
    assert frames[-1] == "[DONE]"

    parsed = [json.loads(frame) for frame in frames[:-1]]
    assert parsed[0]["object"] == "text_completion.chunk"
    assert parsed[0]["choices"][0]["text"] == "Hel"
    assert parsed[1]["choices"][0]["text"] == "lo"
    assert parsed[2]["choices"][0]["finish_reason"] == "stop"

    # The final chunk carries aggregated usage and an empty choices list.
    usage_chunk = parsed[-1]
    assert usage_chunk["choices"] == []
    assert usage_chunk["usage"] == {
        "prompt_tokens": 3,
        "completion_tokens": 2,
        "total_tokens": 5,
    }
|
def test_chat_stream_emits_initial_role_delta_and_done(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A streamed chat completion opens with a role delta and ends with [DONE]."""
    monkeypatch.setattr(
        "app.routers.chat.get_model_spec",
        lambda model: ModelSpec(name=model, hf_repo="dummy/instruct", is_instruct=True),
    )
    monkeypatch.setattr("app.routers.chat.engine.apply_chat_template", lambda *_: "formatted")
    monkeypatch.setattr(
        "app.routers.chat.engine.create_stream",
        lambda *_, **__: DummyStream(
            tokens=["Hi", " there"],
            prompt_tokens=4,
            completion_tokens=2,
            finish_reason="stop",
        ),
    )

    request = ChatCompletionRequest.model_validate(
        {
            "model": "GPT4-dev-177M-1511-Instruct",
            "messages": [{"role": "user", "content": "hello"}],
            "stream": True,
        }
    )
    result = asyncio.run(chat.create_chat_completion(request))
    assert isinstance(result, StreamingResponse)

    frames = _parse_sse_data_frames(asyncio.run(_read_stream_body(result)))
    assert frames[-1] == "[DONE]"

    # First delta announces the assistant role, then content, then finish.
    parsed = [json.loads(frame) for frame in frames[:-1]]
    assert parsed[0]["choices"][0]["delta"]["role"] == "assistant"
    assert parsed[1]["choices"][0]["delta"]["content"] == "Hi"
    assert parsed[2]["choices"][0]["delta"]["content"] == " there"
    assert parsed[3]["choices"][0]["finish_reason"] == "stop"
def test_responses_stream_emits_created_delta_completed_done(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A streamed response emits created/delta/completed events, then [DONE]."""
    monkeypatch.setattr(
        "app.routers.responses.get_model_spec",
        lambda model: ModelSpec(name=model, hf_repo="dummy/base", is_instruct=False),
    )
    monkeypatch.setattr(
        "app.routers.responses.engine.create_stream",
        lambda *_, **__: DummyStream(
            tokens=["Hi", " there"],
            prompt_tokens=5,
            completion_tokens=2,
            finish_reason="stop",
        ),
    )

    request = ResponseRequest.model_validate(
        {"model": "GPT3-dev", "input": "Say hi", "stream": True}
    )
    result = asyncio.run(responses.create_response(request))
    assert isinstance(result, StreamingResponse)

    frames = _parse_sse_data_frames(asyncio.run(_read_stream_body(result)))
    assert frames[-1] == "[DONE]"

    events = [json.loads(frame) for frame in frames[:-1]]
    assert events[0]["type"] == "response.created"
    assert events[1]["type"] == "response.output_text.delta"
    assert events[1]["delta"] == "Hi"
    assert events[2]["type"] == "response.output_text.delta"
    assert events[2]["delta"] == " there"

    # The completed event carries the full concatenated text and usage totals.
    assert events[3]["type"] == "response.completed"
    assert events[3]["response"]["output"][0]["content"][0]["text"] == "Hi there"
    assert events[3]["response"]["usage"] == {
        "input_tokens": 5,
        "output_tokens": 2,
        "total_tokens": 7,
    }
|
def test_completions_stream_usage_aggregates_prompt_and_completion_tokens(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Usage sums prompt tokens once per prompt and completion tokens per choice."""
    seen_prompts: list[str] = []
    queued = deque(
        [
            DummyStream(tokens=["a1"], prompt_tokens=10, completion_tokens=1),
            DummyStream(tokens=["a2"], prompt_tokens=999, completion_tokens=2),
            DummyStream(tokens=["b1"], prompt_tokens=20, completion_tokens=3),
            DummyStream(tokens=["b2"], prompt_tokens=888, completion_tokens=4),
        ]
    )

    def fake_create_stream(model: str, prompt: str, **_: object) -> DummyStream:
        # Record the prompt for each generated choice and hand out the next stream.
        seen_prompts.append(prompt)
        return queued.popleft()

    monkeypatch.setattr("app.routers.completions.get_model_spec", lambda _: None)
    monkeypatch.setattr("app.routers.completions.engine.create_stream", fake_create_stream)

    request = CompletionRequest.model_validate(
        {"model": "GPT3-dev", "prompt": ["alpha", "beta"], "n": 2, "stream": True}
    )
    result = asyncio.run(completions.create_completion(request))
    frames = _parse_sse_data_frames(asyncio.run(_read_stream_body(result)))
    assert frames[-1] == "[DONE]"

    usage_chunk = [json.loads(frame) for frame in frames[:-1]][-1]

    # Two prompts x n=2 choices each; prompt_tokens counted once per prompt
    # (the 999/888 values from the duplicate streams must be ignored).
    assert seen_prompts == ["alpha", "alpha", "beta", "beta"]
    assert usage_chunk["usage"] == {
        "prompt_tokens": 30,
        "completion_tokens": 10,
        "total_tokens": 40,
    }