"""Unit tests for prompt/token/engine helper utilities.""" from __future__ import annotations import sys import types from pathlib import Path import pytest ROOT = Path(__file__).resolve().parents[1] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) from app.core import engine, tokens from app.core.prompting import DEFAULT_SYSTEM_PROMPT, render_chat_prompt from app.schemas.chat import ChatMessage class DummyTokenizer: def __init__(self) -> None: self.called_with: tuple[str, bool] | None = None def encode(self, text: str, add_special_tokens: bool = False) -> list[int]: self.called_with = (text, add_special_tokens) return [1, 2, 3] class DummyEncoding: def __init__(self, size: int) -> None: self._size = size def encode(self, _: str) -> list[int]: return list(range(self._size)) class DummyTikToken: def __init__(self, size: int) -> None: self._size = size def encoding_for_model(self, _: str) -> DummyEncoding: return DummyEncoding(self._size) def test_render_chat_prompt_uses_default_system_prompt() -> None: prompt = render_chat_prompt([ChatMessage(role="user", content="Hello")]) assert prompt.startswith(f"System: {DEFAULT_SYSTEM_PROMPT}\n\n") assert prompt.endswith("Assistant:") assert "User: Hello" in prompt def test_render_chat_prompt_overrides_system_prompt_when_present() -> None: prompt = render_chat_prompt( [ ChatMessage(role="system", content="Custom system"), ChatMessage(role="user", content="Hello"), ChatMessage(role="assistant", content="Hi"), ] ) assert prompt.startswith("System: Custom system\n\n") assert "User: Hello" in prompt assert "Assistant: Hi" in prompt def test_count_tokens_returns_zero_for_empty_text() -> None: assert tokens.count_tokens("", "GPT3-dev") == 0 def test_count_tokens_uses_tiktoken_when_available(monkeypatch) -> None: monkeypatch.setattr(tokens, "tiktoken", DummyTikToken(size=4)) assert tokens.count_tokens("hello", "GPT3-dev") == 4 def test_count_tokens_falls_back_to_tokenizer_encode(monkeypatch) -> None: monkeypatch.setattr(tokens, "tiktoken", None) tokenizer = DummyTokenizer() assert tokens.count_tokens("hello", "GPT3-dev", tokenizer=tokenizer) == 3 assert tokenizer.called_with == ("hello", False) def test_apply_stop_sequences_returns_earliest_stop_index() -> None: text, reason = engine._apply_stop_sequences( "abcxyz", ["", ""], ) assert text == "abc" assert reason == "stop" def test_normalize_stop_handles_none_string_and_iterable() -> None: assert engine._normalize_stop(None) == () assert engine._normalize_stop("stop") == ("stop",) assert engine._normalize_stop(["a", "b"]) == ("a", "b") def test_pad_token_id_prefers_pad_then_eos_then_zero() -> None: with_pad = types.SimpleNamespace(pad_token_id=9, eos_token_id=7) with_eos_only = types.SimpleNamespace(pad_token_id=None, eos_token_id=7) with_none = types.SimpleNamespace(pad_token_id=None, eos_token_id=None) assert engine._pad_token_id_or_default(with_pad) == 9 assert engine._pad_token_id_or_default(with_eos_only) == 7 assert engine._pad_token_id_or_default(with_none) == 0 def test_unwrap_bound_callable_returns_plain_function() -> None: def loader(cls, model): # pragma: no cover - shape-only test return cls, model assert engine._unwrap_bound_callable(loader) is loader def test_unwrap_bound_callable_extracts_underlying_method_function() -> None: class Demo: @classmethod def loader(cls, model): # pragma: no cover - shape-only test return cls, model unwrapped = engine._unwrap_bound_callable(Demo.loader) assert callable(unwrapped) assert unwrapped.__name__ == "loader" def test_install_tie_weights_compat_patch_strips_unexpected_kwargs() -> None: class Demo: def tie_weights(self): # pragma: no cover - shape-only test return "ok" demo = Demo() with pytest.raises(TypeError): demo.tie_weights(missing_keys=set(), recompute_mapping=False) restore = engine._install_tie_weights_compat_patch(demo) try: assert demo.tie_weights(missing_keys=set(), recompute_mapping=False) == "ok" finally: restore() with pytest.raises(TypeError): demo.tie_weights(missing_keys=set(), recompute_mapping=False) def test_install_tie_weights_compat_patch_preserves_supported_kwargs() -> None: class Demo: def tie_weights(self, keep=None): # pragma: no cover - shape-only test return keep demo = Demo() restore = engine._install_tie_weights_compat_patch(demo) try: assert demo.tie_weights(keep="ok", missing_keys=set(), recompute_mapping=False) == "ok" finally: restore() def test_install_tie_weights_compat_patch_covers_class_dispatch() -> None: class Base: def tie_weights(self): # pragma: no cover - shape-only test return "ok" class Child(Base): pass instance = Child() restore = engine._install_tie_weights_compat_patch(instance, extra_classes=(Base,)) try: assert Base.tie_weights(instance, missing_keys=set(), recompute_mapping=False) == "ok" finally: restore() def test_install_loader_tie_weights_patch_handles_plain_function_descriptor() -> None: class DemoModel: def tie_weights(self): # pragma: no cover - shape-only test return "ok" class DummyLoader: def _load_pretrained_model(model, state, files, name): # noqa: N805 return model.tie_weights(missing_keys=set(), recompute_mapping=False), name model = DemoModel() with pytest.raises(TypeError): model.tie_weights(missing_keys=set(), recompute_mapping=False) restore_loader = engine._install_loader_tie_weights_patch(DummyLoader) try: result = DummyLoader._load_pretrained_model(model, None, None, "demo") assert result == ("ok", "demo") finally: restore_loader() with pytest.raises(TypeError): model.tie_weights(missing_keys=set(), recompute_mapping=False) def test_install_loader_tie_weights_patch_handles_classmethod_descriptor() -> None: class DemoModel: def tie_weights(self): # pragma: no cover - shape-only test return "ok" class DummyLoader: @classmethod def _load_pretrained_model(cls, model, state, files, name): # noqa: N805 return cls.__name__, model.tie_weights(recompute_mapping=False), name model = DemoModel() restore_loader = engine._install_loader_tie_weights_patch(DummyLoader) try: result = DummyLoader._load_pretrained_model(model, None, None, "demo") assert result == ("DummyLoader", "ok", "demo") finally: restore_loader()