20q / modeling_twentyq.py
david-ar's picture
Upload folder using huggingface_hub
00cfc63 verified
"""TwentyQ: The world's smallest chat model.
2-bit quantized neural network (1988), 156 attention heads, 1200 output classes.
Trained on ~75 million conversations. Context window: 20 questions.
"""
import hashlib
import random
import torch
import torch.nn as nn
from transformers import PreTrainedModel, GenerationMixin
from .configuration_twentyq import TwentyQConfig
# Answer codes (the index into the three scoring tables below):
#   1=No, 2=Yes, 3=Probably, 4=Doubtful, 5=Maybe, 6=Unknown.
# Index 0 is a placeholder; code 0 means "not recognized" and never reaches
# scoring. Code -1 ("close") is only accepted as a reply to a guess, where it
# is handled like No.
#
# NOTE(review): _score compares POLARITY[code] with bit 0 of the weight cell,
# so code 3 ("probably") shares No's polarity and code 4 ("doubtful") shares
# Yes's. The Animal/Vegetable/Mineral scoring relies on code 4 acting as a
# weak "yes" for the chosen category. That looks inverted relative to the
# word labels in ANSWER_WORDS -- confirm against the trained weight data.
_ANSWER_TABLE = (
    # (polarity, match_bonus, miss_penalty)
    (0, 0, 0),  # 0: placeholder
    (0, 4, 4),  # 1: No
    (1, 4, 4),  # 2: Yes
    (0, 3, 1),  # 3: Probably
    (1, 3, 1),  # 4: Doubtful
    (0, 1, 0),  # 5: Maybe
    (0, 0, 0),  # 6: Unknown
)
POLARITY = [row[0] for row in _ANSWER_TABLE]
MATCH_BONUS = [row[1] for row in _ANSWER_TABLE]
MISS_PENALTY = [row[2] for row in _ANSWER_TABLE]

# Accepted (lower-cased) reply words, grouped by the answer code they map to.
_ANSWER_SYNONYMS = (
    (2, ("yes", "y", "yeah", "yep", "usually")),
    (1, ("no", "n", "nope", "nah")),
    (3, ("probably", "prob", "likely")),
    (4, ("doubtful", "doubt", "rarely")),
    (5, ("maybe", "sometimes", "perhaps", "partly")),
    (6, ("unknown", "dunno", "idk", "irrelevant", "skip")),
    (-1, ("close",)),
)
ANSWER_WORDS = {word: code for code, words in _ANSWER_SYNONYMS for word in words}

# Replies to the opening "Animal, Vegetable, Mineral, or Other?" question.
AVM_WORDS = {
    name: code
    for code, name in enumerate(("animal", "vegetable", "mineral", "other"), start=1)
}
class TwentyQForCausalLM(PreTrainedModel, GenerationMixin):
    """Twenty-questions game engine behind the ``transformers`` model API.

    Game knowledge lives in the ``weight_matrix`` buffer (one uint8 cell per
    question/target pair) plus the question/target strings supplied through
    :meth:`set_vocab`. :meth:`generate` replays the ``[A]``/``[U]`` transcript
    encoded in ``input_ids`` and appends the next assistant message (a
    question, a guess, or an end-of-game message). :meth:`forward` is only a
    placeholder.
    """

    config_class = TwentyQConfig
    _tied_weights_keys = []

    # Leading text of the reply sent when the user's answer cannot be parsed.
    # The transcript parser also uses it to recognise such re-prompts.
    _REPROMPT_PREFIX = "I didn't understand that."

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): presumably read by newer transformers versions; this
        # model ties no weights.
        self.all_tied_weights_keys = {}
        # Frozen placeholder parameter. NOTE(review): presumably present so the
        # module always owns at least one parameter; confirm.
        self._dummy = nn.Parameter(torch.zeros(1), requires_grad=False)
        # One cell per (question, target). Bit 0 is compared with the answer's
        # polarity and bit 1 doubles the score (see _score).
        self.register_buffer(
            "weight_matrix",
            torch.zeros(config.num_questions, config.num_targets, dtype=torch.uint8),
        )
        self._vocab_loaded = False

    def set_vocab(self, questions, targets):
        """Set question and target strings (called by tokenizer or manually).

        Also builds case-insensitive lookups from text to index, which are
        used when replaying a transcript.
        """
        self.questions_str = list(questions)
        self.targets_str = list(targets)
        self._q_lookup = {q.lower(): i for i, q in enumerate(self.questions_str)}
        self._t_lookup = {t.lower(): i for i, t in enumerate(self.targets_str)}
        self._vocab_loaded = True

    def _ensure_strings(self):
        """Raise ``RuntimeError`` unless :meth:`set_vocab` has been called."""
        if self._vocab_loaded:
            return
        raise RuntimeError(
            "Model vocabulary not loaded. Call model.set_vocab(questions, targets) "
            "or load a tokenizer with vocab.json alongside the model."
        )

    def forward(self, input_ids=None, **kwargs):
        """Return zero logits of shape ``(batch, 1, vocab_size)``.

        This is a placeholder; the real work happens in :meth:`generate`. The
        logits are allocated on ``input_ids``' device when it is given.
        """
        if input_ids is None:
            return {"logits": torch.zeros(1, 1, self.config.vocab_size)}
        return {
            "logits": torch.zeros(
                input_ids.shape[0], 1, self.config.vocab_size, device=input_ids.device
            )
        }

    def generate(self, input_ids=None, attention_mask=None, **kwargs):
        """Append the next assistant message to a chat-templated prompt.

        Args:
            input_ids: Tensor of shape ``(1, seq_len)`` with byte-level token
                ids; ids >= 256 (special tokens) are ignored when decoding.
            attention_mask: Accepted for API compatibility; unused.
            **kwargs: Accepted for API compatibility; unused.

        Returns:
            ``input_ids`` with the UTF-8 bytes of the reply appended along
            dim 1, keeping the input's dtype and device.

        Raises:
            RuntimeError: If the vocabulary has not been loaded.
            ValueError: If ``input_ids`` is missing or its batch size is not 1.
        """
        self._ensure_strings()
        if input_ids is None:
            raise ValueError("generate() requires input_ids")
        if input_ids.shape[0] != 1:
            raise ValueError(
                f"generate() supports a batch size of 1, got {input_ids.shape[0]}"
            )
        raw_bytes = bytes(b for b in input_ids[0].tolist() if b < 256)
        text = raw_bytes.decode("utf-8", errors="replace")

        answers, qnum, last_was_guess, game_over_msg, unrecognized = (
            self._parse_conversation(text)
        )
        if unrecognized:
            response = f"{self._REPROMPT_PREFIX} Please answer: {unrecognized}"
        elif game_over_msg:
            response = game_over_msg
        else:
            # Seed from the transcript so the same conversation always yields
            # the same move.
            seed = int(hashlib.md5(text.encode()).hexdigest()[:8], 16)
            self._rng = random.Random(seed)
            response = self._next_move(answers, qnum, last_was_guess)

        response_tensor = torch.tensor(
            [list(response.encode("utf-8"))],
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        return torch.cat([input_ids, response_tensor], dim=1)

    @staticmethod
    def _split_turns(text):
        """Return ``[(role, message)]`` for every ``[A] ``/``[U] `` line in *text*."""
        turns = []
        for line in text.replace("\r", "").split("\n"):
            line = line.strip()
            if line.startswith("[A] "):
                turns.append(("a", line[4:].strip()))
            elif line.startswith("[U] "):
                turns.append(("u", line[4:].strip()))
        return turns

    def _parse_conversation(self, text):
        """Replay a chat-templated transcript and rebuild the game state.

        Each ``[A] <message>`` line followed by a ``[U] <reply>`` line is one
        exchange. Parsing stops at the first assistant turn that has no user
        reply after it, which is the prompt being generated for.

        A re-prompt (an assistant message starting with ``_REPROMPT_PREFIX``)
        is treated as a repeat of the most recent real question or guess. The
        corrected reply is applied to that prompt instead of being discarded
        as an answer to an unknown question.

        Args:
            text: Decoded prompt text.

        Returns:
            ``(answers, qnum, last_was_guess, game_over_msg, unrecognized)``.

            ``answers`` is a list of ``(index, ans_code, is_guess)``:

            - For a regular question, ``index`` is the question index.
            - For the Animal/Vegetable/Mineral question, ``index`` is 0 and
              ``ans_code`` is the ``AVM_WORDS`` code.
            - For a rejected guess, ``index`` is the target index and
              ``ans_code`` is 0.

            The other fields:

            - ``qnum`` counts the recorded exchanges.
            - ``last_was_guess`` tells whether the latest recorded exchange
              was a rejected guess.
            - ``game_over_msg`` is set once the game has ended.
            - ``unrecognized`` holds the accepted-answer hint when the latest
              reply could not be parsed, otherwise ``None``.
        """
        answers = []
        qnum = 0
        last_was_guess = False
        game_over_msg = None
        unrecognized = None
        reprompt = self._REPROMPT_PREFIX.lower()
        last_prompt = None  # most recent question/guess, for re-prompts

        turns = self._split_turns(text)
        i = 0
        while i < len(turns):
            role, a_msg = turns[i]
            if role != "a":
                # Stray user line without a preceding assistant turn.
                i += 1
                continue
            if i + 1 >= len(turns) or turns[i + 1][0] != "u":
                # This is the generation prompt — no user response yet
                break
            u_msg = turns[i + 1][1]
            i += 2

            if a_msg.lower().startswith(reprompt):
                if last_prompt is not None:
                    a_msg = last_prompt
            else:
                last_prompt = a_msg

            lowered = a_msg.lower()
            # Tolerate trailing punctuation such as "Yes!" or "no."
            reply = u_msg.strip().lower().rstrip(".!").strip()

            if "animal, vegetable, mineral" in lowered:
                # Opening category question.
                avm_code = AVM_WORDS.get(reply, 0)
                if avm_code:
                    answers.append((0, avm_code, False))
                    qnum += 1
                    last_was_guess = False
                    unrecognized = None
                else:
                    unrecognized = "Animal, Vegetable, Mineral, or Other"
            elif lowered.startswith("i'm guessing"):
                target_name = a_msg.split("...")[-1].strip().rstrip("?").strip()
                t_idx = self._t_lookup.get(target_name.lower(), -1)
                ans_code = ANSWER_WORDS.get(reply, 0)
                if ans_code == 2:  # Yes — correct guess
                    game_over_msg = f"I win! Got it in {qnum + 1} questions."
                    unrecognized = None
                elif ans_code in (1, -1):  # No or Close — rule the target out
                    if t_idx >= 0:
                        answers.append((t_idx, 0, True))
                    qnum += 1
                    last_was_guess = True
                    unrecognized = None
                else:
                    unrecognized = "Yes, No, or Close"
            elif lowered.startswith(("i win", "i'm stumped")):
                # Game already over
                game_over_msg = a_msg
            else:
                # Regular question. "Close" is only valid for guesses.
                q_idx = self._q_lookup.get(a_msg.rstrip("?").strip().lower(), -1)
                ans_code = ANSWER_WORDS.get(reply, 0)
                if ans_code in (0, -1):
                    unrecognized = "Yes, No, Probably, Doubtful, Maybe, or Unknown"
                else:
                    if q_idx >= 0:
                        answers.append((q_idx, ans_code, False))
                    qnum += 1
                    last_was_guess = False
                    unrecognized = None
        return answers, qnum, last_was_guess, game_over_msg, unrecognized

    def _next_move(self, answers, qnum, last_was_guess):
        """Choose the next assistant message from the replayed game state.

        The moves are, in order:

        - Open with the Animal/Vegetable/Mineral question.
        - Give up at 30 exchanges or when no candidate target remains.
        - Guess the top-ranked target when one candidate is left, at
          exchanges 20 and 24, or from exchange 18 on with at most two
          candidates.
        - Otherwise ask the question chosen by :meth:`_select_question`.

        ``last_was_guess`` is accepted for interface compatibility and is
        currently unused.
        """
        stumped = "I'm stumped! I can't figure out what you're thinking of."
        if qnum == 0:
            return "Is it Animal, Vegetable, Mineral, or Other?"
        if qnum >= 30:
            return stumped
        nc, best_t, _best_s, cidx, _cscores = self._rank_targets(answers)
        if nc == 0:
            return stumped
        guess = f"I'm guessing... {self.targets_str[best_t]}?"
        if nc == 1 or qnum in (20, 24) or (qnum >= 18 and nc <= 2):
            return guess
        q = self._select_question(answers, nc, cidx)
        if q < 0:
            return guess
        return f"{self.questions_str[q]}?"

    def _score(self, answer_code, target, question):
        """Return the score contribution of one answer for one target.

        Bit 0 of ``weight_matrix[question, target]`` is compared with
        ``POLARITY[answer_code]``. When they agree the contribution is
        ``MATCH_BONUS``, otherwise ``-MISS_PENALTY``. When bit 1 is set the
        result is doubled.
        """
        w = int(self.weight_matrix[question, target])
        if (POLARITY[answer_code] ^ w) & 1:
            s = -MISS_PENALTY[answer_code]
        else:
            s = MATCH_BONUS[answer_code]
        if w & 2:
            s *= 2
        return s

    def _rank_targets(self, answers):
        """Score all targets against the answers and keep the strongest ones.

        Targets that were already guessed, or that end with a negative score,
        are dropped. With more than 7 answers, a target is dropped as soon as
        its running score goes negative. Surviving scores get a random bonus
        of 0-7 from ``self._rng``.

        At most 16, 8 or 5 candidates are kept (for <=10, <=12 or more
        answers). Of those, only the ones within a threshold of the best
        score are returned. The threshold is ``best // 4``, clamped to 5..20.

        Returns:
            ``(nc, best_t, best_s, cidx, cscores)``: the number of returned
            candidates, the index and score of the top target, and parallel
            lists of candidate target indices and scores.
            ``(0, 0, 0, [], [])`` when no target survives.
        """
        n_answers = len(answers)
        max_c = 16 if n_answers <= 10 else (8 if n_answers <= 12 else 5)
        prune_early = n_answers > 7
        guessed = {qi for qi, _, ig in answers if ig}
        scored = [(qi, ac) for qi, ac, ig in answers if not ig and ac != 0]

        c_scores = [0] * max_c
        c_indices = [0] * max_c
        nc = 0
        # Start below any achievable score so best_t is always a real
        # candidate (a 0-scored winner must not leave best_t stuck at 0).
        best_t, best_s = 0, -1
        for t in range(self.config.num_targets):
            if t in guessed:
                continue
            score = 0
            for qi, ac in scored:
                if qi != 0:
                    score += self._score(ac, t, qi)
                else:
                    # Category answer: questions 0-3 stand for the four
                    # categories. The chosen one is scored with code 4, the
                    # rest with code 3.
                    for k in range(4):
                        score += self._score(4 if k + 1 == ac else 3, t, k)
                if prune_early and score < 0:
                    break
            if score < 0:
                continue
            score += self._rng.randint(0, 7)

            if nc < max_c:
                slot = nc
                nc += 1
            else:
                # Pool full: replace the weakest entry if this one beats it.
                min_s, slot = min((c_scores[j], j) for j in range(max_c))
                if min_s >= score:
                    continue
            c_scores[slot] = score
            c_indices[slot] = t
            if score > best_s:
                best_t, best_s = t, score

        if nc == 0:
            return 0, 0, 0, [], []
        cutoff = best_s - max(5, min(20, best_s // 4))
        kept = [j for j in range(nc) if c_scores[j] > cutoff]
        cidx = [c_indices[j] for j in kept]
        cscores = [c_scores[j] for j in kept]
        return len(kept), best_t, best_s, cidx, cscores

    def _select_question(self, answers, nc, cidx):
        """Pick the unasked question that best splits the candidate targets.

        Questions 0-3 (the category questions) are never selected. Each
        candidate votes on one side according to bit 0 of its weight cell,
        with weight 3 if bit 1 is set and 1 otherwise. The split is scored as
        ``2 * minority - majority`` plus a random bonus of 0-7 from
        ``self._rng``.

        Returns:
            The chosen question index, or -1 when no question is left.
        """
        best_s, best_q = -1000, -1
        asked = {qi for qi, _, ig in answers if not ig}
        for q in range(4, self.config.num_questions):
            if q in asked:
                continue
            pos, neg = 0, 0
            for t in cidx:
                w = int(self.weight_matrix[q, t])
                wt = 3 if (w & 2) else 1
                if w & 1:
                    neg += wt
                else:
                    pos += wt
            s = (pos * 2 - neg) if pos <= neg else (neg * 2 - pos)
            s += self._rng.randint(0, 7)
            if s > best_s:
                best_s, best_q = s, q
        return best_q

    @staticmethod
    def _hint_for(response):
        """Return the accepted-answers hint shown after an assistant message."""
        if "Animal, Vegetable, Mineral" in response:
            return "(Animal/Vegetable/Mineral/Other)"
        if "guessing" in response.lower():
            return "(Yes/No/Close)"
        return "(Yes/No/Probably/Doubtful/Maybe/Unknown)"

    def play(self, tokenizer=None):
        """Interactive CLI mode. Pass the tokenizer for proper chat template formatting.

        Args:
            tokenizer: Tokenizer used to apply the chat template and to
                encode/decode. When omitted, a ``TwentyQTokenizer`` from the
                sibling ``tokenization_twentyq`` module is created with an
                inline ``[A]``/``[U]`` chat template.

        The game ends when the model wins or gives up, when the user submits
        an empty reply, or when stdin reaches EOF.
        """
        self._ensure_strings()
        if tokenizer is None:
            # Minimal fallback — construct chat text directly
            from .tokenization_twentyq import TwentyQTokenizer

            tokenizer = TwentyQTokenizer()
            tokenizer.chat_template = (
                "{% if messages[0]['role'] == 'system' %}{{ messages[0]['content'] }}\n"
                "{% set loop_messages = messages[1:] %}{% else %}"
                "{% set loop_messages = messages %}{% endif %}"
                "{% for message in loop_messages %}"
                "{% if message['role'] == 'assistant' %}[A] {{ message['content'] }}\n"
                "{% elif message['role'] == 'user' %}[U] {{ message['content'] }}\n"
                "{% endif %}{% endfor %}"
                "{% if add_generation_prompt %}[A] {% endif %}"
            )
        messages = [
            {"role": "system", "content": "Think of something and I'll try to guess it in 20 questions."},
        ]
        print("\n Think of something...\n")
        try:
            input(" Press Enter when ready... ")
        except EOFError:
            return
        hint = "(Yes/No/Probably/Doubtful/Maybe/Unknown)"
        while True:
            text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            ids = tokenizer.encode(text, return_tensors="pt")
            out = self.generate(ids)
            response = tokenizer.decode(out[0, ids.shape[1]:].tolist())
            messages.append({"role": "assistant", "content": response})
            print(f"\n > {response}")
            if "I win" in response or "stumped" in response:
                return
            # A re-prompt repeats the previous question or guess, so its hint
            # stays the same.
            if not response.startswith(self._REPROMPT_PREFIX):
                hint = self._hint_for(response)
            try:
                reply = input(f" {hint}: ").strip()
            except EOFError:
                return
            if not reply:
                return
            messages.append({"role": "user", "content": reply})