Spaces:
Running
Running
Upload 17 files
Browse files- Dockerfile +40 -0
- README.md +5 -9
- nanochat/__init__.py +0 -0
- nanochat/checkpoint_manager.py +196 -0
- nanochat/common.py +278 -0
- nanochat/engine.py +360 -0
- nanochat/flash_attention.py +187 -0
- nanochat/gpt.py +465 -0
- nanochat/logo.svg +8 -0
- nanochat/optim.py +533 -0
- nanochat/tokenizer.py +14 -0
- nanochat/ui.html +566 -0
- scripts/__init__.py +0 -0
- scripts/chat_web.py +421 -0
- start.sh +27 -0
- tokenizer.json +0 -0
- tokenizer_wrapper.py +282 -0
Dockerfile
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install system dependencies
|
| 6 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 7 |
+
build-essential \
|
| 8 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 9 |
+
|
| 10 |
+
# Install Python dependencies
|
| 11 |
+
RUN pip install --no-cache-dir \
|
| 12 |
+
torch --index-url https://download.pytorch.org/whl/cpu
|
| 13 |
+
RUN pip install --no-cache-dir \
|
| 14 |
+
tokenizers \
|
| 15 |
+
fastapi \
|
| 16 |
+
uvicorn[standard] \
|
| 17 |
+
pydantic \
|
| 18 |
+
httpx \
|
| 19 |
+
filelock \
|
| 20 |
+
huggingface_hub
|
| 21 |
+
|
| 22 |
+
# Copy application code
|
| 23 |
+
COPY nanochat/ nanochat/
|
| 24 |
+
COPY scripts/ scripts/
|
| 25 |
+
COPY tokenizer_wrapper.py .
|
| 26 |
+
COPY tokenizer.json .
|
| 27 |
+
COPY start.sh .
|
| 28 |
+
RUN chmod +x start.sh
|
| 29 |
+
|
| 30 |
+
# Create model directory
|
| 31 |
+
RUN mkdir -p /app/nanochat_cache/chatsft_checkpoints/d18
|
| 32 |
+
|
| 33 |
+
# Set environment variables
|
| 34 |
+
ENV NANOCHAT_BASE_DIR=/app/nanochat_cache
|
| 35 |
+
ENV PYTHONPATH=/app
|
| 36 |
+
|
| 37 |
+
# HuggingFace Spaces expects port 7860
|
| 38 |
+
EXPOSE 7860
|
| 39 |
+
|
| 40 |
+
CMD ["./start.sh"]
|
README.md
CHANGED
|
@@ -1,12 +1,8 @@
|
|
| 1 |
---
|
| 2 |
-
title: Mr Chatterbox
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
-
|
| 8 |
-
license: mit
|
| 9 |
-
short_description: The Victorian Gentleman Chatbot
|
| 10 |
---
|
| 11 |
-
|
| 12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Mr. Chatterbox
|
| 3 |
+
emoji: 🎩
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: yellow
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 7860
|
|
|
|
|
|
|
| 8 |
---
|
|
|
|
|
|
nanochat/__init__.py
ADDED
|
File without changes
|
nanochat/checkpoint_manager.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utilities for saving and loading model/optim/state checkpoints.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
import glob
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from nanochat.common import get_base_dir
|
| 12 |
+
from nanochat.gpt import GPT, GPTConfig
|
| 13 |
+
from nanochat.tokenizer import get_tokenizer
|
| 14 |
+
from nanochat.common import setup_default_logging
|
| 15 |
+
|
| 16 |
+
# Set up logging
|
| 17 |
+
setup_default_logging()
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
def log0(message):
    """Log `message` at INFO level, but only when the RANK env var is 0 (or unset)."""
    rank = int(os.environ.get('RANK', 0))
    if rank != 0:
        return
    logger.info(message)
|
| 22 |
+
|
| 23 |
+
def _patch_missing_config_keys(model_config_kwargs):
|
| 24 |
+
"""Add default values for new config keys missing in old checkpoints."""
|
| 25 |
+
# Old models were trained with full context (no sliding window)
|
| 26 |
+
if "window_pattern" not in model_config_kwargs:
|
| 27 |
+
model_config_kwargs["window_pattern"] = "L"
|
| 28 |
+
log0(f"Patching missing window_pattern in model config to 'L'")
|
| 29 |
+
|
| 30 |
+
def _patch_missing_keys(model_data, model_config):
|
| 31 |
+
"""Add default values for new parameters that may be missing in old checkpoints."""
|
| 32 |
+
n_layer = model_config.n_layer
|
| 33 |
+
# resid_lambdas defaults to 1.0 (identity scaling)
|
| 34 |
+
if "resid_lambdas" not in model_data:
|
| 35 |
+
model_data["resid_lambdas"] = torch.ones(n_layer)
|
| 36 |
+
log0(f"Patching missing resid_lambdas in model data to 1.0")
|
| 37 |
+
# x0_lambdas defaults to 0.0 (disabled)
|
| 38 |
+
if "x0_lambdas" not in model_data:
|
| 39 |
+
model_data["x0_lambdas"] = torch.zeros(n_layer)
|
| 40 |
+
log0(f"Patching missing x0_lambdas in model data to 0.0")
|
| 41 |
+
|
| 42 |
+
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
    """
    Write a checkpoint for `step` into `checkpoint_dir`.

    Rank 0 writes `model_<step>.pt` (via torch.save) and `meta_<step>.json`.
    Any rank that passes a non-None `optimizer_data` additionally writes
    `optim_<step>_rank<rank>.pt`. Step numbers are zero-padded to 6 digits.
    """
    is_master = rank == 0
    if is_master:
        os.makedirs(checkpoint_dir, exist_ok=True)
        # Model state parameters
        model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
        torch.save(model_data, model_path)
        logger.info(f"Saved model parameters to: {model_path}")
        # Metadata dict, stored as json
        meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
        with open(meta_path, "w", encoding="utf-8") as meta_file:
            json.dump(meta_data, meta_file, indent=2)
        logger.info(f"Saved metadata to: {meta_path}")
    if optimizer_data is None:
        return
    # Note that optimizer state is sharded across ranks, so each rank must save its own.
    os.makedirs(checkpoint_dir, exist_ok=True)
    optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
    torch.save(optimizer_data, optimizer_path)
    logger.info(f"Saved optimizer state to: {optimizer_path}")
|
| 60 |
+
|
| 61 |
+
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
    """
    Read the checkpoint files for `step` from `checkpoint_dir`.

    Returns a tuple `(model_data, optimizer_data, meta_data)`. `optimizer_data`
    is None unless `load_optimizer` is true, in which case this rank's
    `optim_<step>_rank<rank>.pt` is loaded. Tensors are mapped to `device`.
    """
    model_data = torch.load(
        os.path.join(checkpoint_dir, f"model_{step:06d}.pt"),
        map_location=device,
    )
    if load_optimizer:
        optimizer_data = torch.load(
            os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt"),
            map_location=device,
        )
    else:
        optimizer_data = None
    with open(os.path.join(checkpoint_dir, f"meta_{step:06d}.json"), "r", encoding="utf-8") as meta_file:
        meta_data = json.load(meta_file)
    return model_data, optimizer_data, meta_data
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def build_model(checkpoint_dir, step, device, phase):
    """
    Reconstruct a model from the checkpoint at `checkpoint_dir` / `step`.

    Returns a tuple of:
    - the base model (uncompiled, not wrapped in DDP), in eval or train mode per `phase`
    - the tokenizer
    - the meta data dict stored alongside the checkpoint
    """
    assert phase in ["train", "eval"], f"Invalid phase: {phase}"
    state_dict, _, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
    if device.type in {"cpu", "mps"}:
        # Convert bfloat16 tensors to float for CPU inference
        state_dict = {
            name: (tensor.float() if tensor.dtype == torch.bfloat16 else tensor)
            for name, tensor in state_dict.items()
        }
    # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
    state_dict = {name.removeprefix("_orig_mod."): tensor for name, tensor in state_dict.items()}

    # Build the config, patching keys that old checkpoints do not have
    model_config_kwargs = meta_data["model_config"]
    _patch_missing_config_keys(model_config_kwargs)
    log0(f"Building model with config: {model_config_kwargs}")
    model_config = GPTConfig(**model_config_kwargs)
    _patch_missing_keys(state_dict, model_config)

    # Construct on the meta device first, then allocate on the target device
    with torch.device("meta"):
        model = GPT(model_config)
    model.to_empty(device=device)
    model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
    load_result = model.load_state_dict(state_dict, strict=False, assign=True)
    if load_result.unexpected_keys:
        log0(f"Ignoring unexpected checkpoint keys: {load_result.unexpected_keys}")

    # Put the model in the right training phase / mode
    if phase != "eval":
        model.train()
    else:
        model.eval()

    # Load the tokenizer and sanity-check its compatibility with the model
    tokenizer = get_tokenizer()
    assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"], f"Tokenizer vocab size {tokenizer.get_vocab_size()} does not match model config vocab size {model_config_kwargs['vocab_size']}"
    return model, tokenizer, meta_data
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def find_largest_model(checkpoints_dir):
    """
    Guess a model tag from the sub-directories of `checkpoints_dir`.

    Tags that start with `d<number>` are preferred and the one with the
    largest number wins. If no tag has that form, the most recently modified
    sub-directory is returned. Raises FileNotFoundError when there are no
    sub-directories at all.
    """
    model_tags = [
        entry
        for entry in os.listdir(checkpoints_dir)
        if os.path.isdir(os.path.join(checkpoints_dir, entry))
    ]
    if not model_tags:
        raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}")

    # 1) normally all model tags are of the form d<number>, try that first:
    depth_tagged = []
    for tag in model_tags:
        found = re.match(r"d(\d+)", tag)
        if found:
            depth_tagged.append((int(found.group(1)), tag))
    if depth_tagged:
        _, deepest_tag = max(depth_tagged, key=lambda pair: pair[0])
        return deepest_tag

    # 2) if that failed, take the most recently updated model:
    return max(model_tags, key=lambda tag: os.path.getmtime(os.path.join(checkpoints_dir, tag)))
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def find_last_step(checkpoint_dir):
    """
    Return the highest step among the `model_<step>.pt` files in `checkpoint_dir`.

    Steps are compared as integers, so a step with more digits than the usual
    6-digit zero padding (e.g. `model_1000000.pt`) still sorts after
    `model_999999.pt`. Files matching `model_*.pt` whose step part is not a
    decimal number are ignored.

    Raises:
        FileNotFoundError: if no usable `model_<step>.pt` file exists.
    """
    steps = []
    for path in glob.glob(os.path.join(checkpoint_dir, "model_*.pt")):
        # "model_000123.pt" -> "000123"
        step_text = os.path.basename(path).split("_")[-1].split(".")[0]
        if step_text.isdecimal():
            steps.append(int(step_text))
    if not steps:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    return max(steps)
|
| 147 |
+
|
| 148 |
+
# -----------------------------------------------------------------------------
|
| 149 |
+
# convenience functions that take into account nanochat's directory structure
|
| 150 |
+
|
| 151 |
+
def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
    """
    Load a model from `checkpoints_dir/<model_tag>` at `step`.

    When `model_tag` is None it is guessed with find_largest_model; when `step`
    is None it is guessed with find_last_step. Returns what build_model
    returns: `(model, tokenizer, meta_data)`.
    """
    if model_tag is None:
        # guess the model tag by defaulting to the largest model
        model_tag = find_largest_model(checkpoints_dir)
        log0(f"No model tag provided, guessing model tag: {model_tag}")
    checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
    if step is None:
        # guess the step by defaulting to the last step
        step = find_last_step(checkpoint_dir)
    assert step is not None, f"No checkpoints found in {checkpoint_dir}"
    log0(f"Loading model from {checkpoint_dir} with step {step}")
    model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
    return model, tokenizer, meta_data
|
| 165 |
+
|
| 166 |
+
def load_model(source, *args, **kwargs):
    """
    Load a model by training stage: `source` is one of "base", "sft" or "rl".

    The stage selects a sub-directory of the nanochat base dir; remaining
    arguments are forwarded to load_model_from_dir. An unknown `source`
    raises KeyError.
    """
    subdir_by_source = {
        "base": "base_checkpoints",
        "sft": "chatsft_checkpoints",
        "rl": "chatrl_checkpoints",
    }
    model_dir = subdir_by_source[source]
    checkpoints_dir = os.path.join(get_base_dir(), model_dir)
    return load_model_from_dir(checkpoints_dir, *args, **kwargs)
|
| 175 |
+
|
| 176 |
+
def load_optimizer_state(source, device, rank, model_tag=None, step=None):
    """
    Load only this rank's optimizer shard; the model itself is not loaded.

    `source` is one of "base", "sft" or "rl" (KeyError otherwise). A missing
    `model_tag` / `step` is guessed the same way as in load_model_from_dir.
    Returns the loaded optimizer data, or None when the shard file does not exist.
    """
    subdir_by_source = {
        "base": "base_checkpoints",
        "sft": "chatsft_checkpoints",
        "rl": "chatrl_checkpoints",
    }
    model_dir = subdir_by_source[source]
    checkpoints_dir = os.path.join(get_base_dir(), model_dir)
    if model_tag is None:
        model_tag = find_largest_model(checkpoints_dir)
    checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
    if step is None:
        step = find_last_step(checkpoint_dir)
    optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
    if os.path.exists(optimizer_path):
        log0(f"Loading optimizer state from {optimizer_path}")
        return torch.load(optimizer_path, map_location=device)
    log0(f"Optimizer checkpoint not found: {optimizer_path}")
    return None
|
nanochat/common.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Common utilities for nanochat.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
import logging
|
| 8 |
+
import urllib.request
|
| 9 |
+
import torch
|
| 10 |
+
import torch.distributed as dist
|
| 11 |
+
from filelock import FileLock
|
| 12 |
+
|
| 13 |
+
# The dtype used for compute (matmuls, activations). Master weights stay fp32 for optimizer precision.
|
| 14 |
+
# Linear layers cast their weights to this dtype in forward, replacing torch.amp.autocast.
|
| 15 |
+
# Override with NANOCHAT_DTYPE env var: "bfloat16", "float16", "float32"
|
| 16 |
+
_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
|
| 17 |
+
def _detect_compute_dtype():
|
| 18 |
+
env = os.environ.get("NANOCHAT_DTYPE")
|
| 19 |
+
if env is not None:
|
| 20 |
+
return _DTYPE_MAP[env], f"set via NANOCHAT_DTYPE={env}"
|
| 21 |
+
if torch.cuda.is_available():
|
| 22 |
+
# bf16 requires SM 80+ (Ampere: A100, A10, etc.)
|
| 23 |
+
# Older GPUs like V100 (SM 70) and T4 (SM 75) only have fp16 tensor cores
|
| 24 |
+
capability = torch.cuda.get_device_capability()
|
| 25 |
+
if capability >= (8, 0):
|
| 26 |
+
return torch.bfloat16, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (bf16 supported)"
|
| 27 |
+
# fp16 training requires GradScaler (not yet implemented), so fall back to fp32.
|
| 28 |
+
# Users can still force fp16 via NANOCHAT_DTYPE=float16 if they know what they're doing.
|
| 29 |
+
return torch.float32, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (pre-Ampere, bf16 not supported, using fp32)"
|
| 30 |
+
return torch.float32, "auto-detected: no CUDA (CPU/MPS)"
|
| 31 |
+
COMPUTE_DTYPE, COMPUTE_DTYPE_REASON = _detect_compute_dtype()
|
| 32 |
+
|
| 33 |
+
class ColoredFormatter(logging.Formatter):
    """
    Log formatter that decorates output with ANSI color codes.

    The level name is wrapped in a per-level color plus bold. For INFO records,
    quantities followed by GB / MB / % / docs are additionally bolded and
    "Shard <n>" is highlighted in the INFO color.
    """
    # ANSI color codes
    COLORS = {
        'DEBUG': '\033[36m',     # Cyan
        'INFO': '\033[32m',      # Green
        'WARNING': '\033[33m',   # Yellow
        'ERROR': '\033[31m',     # Red
        'CRITICAL': '\033[35m',  # Magenta
    }
    RESET = '\033[0m'
    BOLD = '\033[1m'

    def format(self, record):
        """Return the formatted, colorized text for `record`.

        The record's `levelname` is only changed for the duration of the base
        class formatting call and is restored afterwards, so other handlers
        that receive the same record (and repeated calls on it) see the plain
        level name.
        """
        levelname = record.levelname
        color = self.COLORS.get(levelname)
        if color is not None:
            # Add color to the level name
            record.levelname = f"{color}{self.BOLD}{levelname}{self.RESET}"
        try:
            message = super().format(record)
        finally:
            # LogRecord objects are shared between handlers: undo the mutation.
            record.levelname = levelname
        if levelname == 'INFO':
            # Highlight numbers and percentages
            message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
            message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
        return message
|
| 58 |
+
|
| 59 |
+
def setup_default_logging():
    """Configure root logging at INFO level with a single stream handler using ColoredFormatter."""
    formatter = ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logging.basicConfig(level=logging.INFO, handlers=[stream_handler])
|
| 66 |
+
|
| 67 |
+
setup_default_logging()
|
| 68 |
+
logger = logging.getLogger(__name__)
|
| 69 |
+
|
| 70 |
+
def get_base_dir():
    """
    Return the nanochat base directory, creating it if needed.

    Uses the NANOCHAT_BASE_DIR env var when it is set to a non-empty value,
    otherwise `~/.cache/nanochat`.
    """
    nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR")
    if not nanochat_dir:
        # co-locate nanochat intermediates with other cached data in ~/.cache (by default)
        nanochat_dir = os.path.join(os.path.expanduser("~"), ".cache", "nanochat")
    os.makedirs(nanochat_dir, exist_ok=True)
    return nanochat_dir
|
| 80 |
+
|
| 81 |
+
def download_file_with_lock(url, filename, postprocess_fn=None):
    """
    Download `url` to `<base_dir>/<filename>` unless that file already exists.

    A lock file (`<file>.lock`) serializes concurrent callers: one downloads,
    the others block on the lock and then find the file present.

    The content is first written to `<file>.tmp` and then moved into place
    with os.replace, so the final path never refers to a partially written
    file. This matters because the first existence check below runs without
    holding the lock.

    Args:
        url: address to fetch with urllib.request.urlopen.
        filename: path relative to the base directory.
        postprocess_fn: optional callable invoked with the final file path
            after the download. NOTE: it runs after the file has become
            visible at its final path.

    Returns:
        The local file path.
    """
    base_dir = get_base_dir()
    file_path = os.path.join(base_dir, filename)
    lock_path = file_path + ".lock"

    if os.path.exists(file_path):
        return file_path

    with FileLock(lock_path):
        # Only a single rank can acquire this lock
        # All other ranks block until it is released

        # Recheck after acquiring lock
        if os.path.exists(file_path):
            return file_path

        # Download the content as bytes
        print(f"Downloading {url}...")
        with urllib.request.urlopen(url) as response:
            content = response.read() # bytes

        # Write to a temp file, then atomically move it to the final path
        tmp_path = file_path + ".tmp"
        try:
            with open(tmp_path, 'wb') as f:
                f.write(content)
            os.replace(tmp_path, file_path)
        finally:
            # Only still present if the write or the rename failed
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        print(f"Downloaded to {file_path}")

        # Run the postprocess function if provided
        if postprocess_fn is not None:
            postprocess_fn(file_path)

    return file_path
|
| 116 |
+
|
| 117 |
+
def print0(s="",**kwargs):
    """print() that only produces output when the RANK env var is 0 (or unset)."""
    if int(os.environ.get('RANK', 0)) != 0:
        return
    print(s, **kwargs)
|
| 121 |
+
|
| 122 |
+
def print_banner():
    """Print the nanochat ASCII-art banner via print0 (so only rank 0 prints it)."""
    # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
    # NOTE(review): a mis-decoded block character in the third banner row was
    # repaired. The spacing inside the banner is kept exactly as found in this
    # copy of the file; verify the alignment against the upstream source.
    banner = """
█████ █████
░░███ ░░███
████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████
░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███░░░███░
░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███
░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███
████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░███████ ░░█████
░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░
"""
    print0(banner)
|
| 135 |
+
|
| 136 |
+
def is_ddp_requested() -> bool:
    """
    Report whether RANK, LOCAL_RANK and WORLD_SIZE are all present in the
    environment (as set up by torchrun), even before any process group exists.
    Used to decide whether we *should* initialize a PG.
    """
    required_vars = ("RANK", "LOCAL_RANK", "WORLD_SIZE")
    for var in required_vars:
        if var not in os.environ:
            return False
    return True
|
| 142 |
+
|
| 143 |
+
def is_ddp_initialized() -> bool:
    """
    Report whether torch.distributed is available and its process group has
    been initialized. Used at cleanup to avoid destroying a non-existent PG.
    """
    if not dist.is_available():
        return False
    return dist.is_initialized()
|
| 149 |
+
|
| 150 |
+
def get_dist_info():
    """
    Return `(ddp_requested, rank, local_rank, world_size)`.

    When the torchrun environment is absent the result is `(False, 0, 0, 1)`;
    otherwise the three values are read from RANK, LOCAL_RANK and WORLD_SIZE.
    """
    if not is_ddp_requested():
        return False, 0, 0, 1
    # We rely on torchrun's env to decide if we SHOULD init.
    # (Initialization itself happens in compute init.)
    assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
    env = os.environ
    ddp_rank = int(env['RANK'])
    ddp_local_rank = int(env['LOCAL_RANK'])
    ddp_world_size = int(env['WORLD_SIZE'])
    return True, ddp_rank, ddp_local_rank, ddp_world_size
|
| 161 |
+
|
| 162 |
+
def autodetect_device_type():
    """Pick "cuda" if available, else "mps" if available, else "cpu"; announce and return it."""
    # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
    device_type = "cpu"
    if torch.cuda.is_available():
        device_type = "cuda"
    elif torch.backends.mps.is_available():
        device_type = "mps"
    print0(f"Autodetected device type: {device_type}")
    return device_type
|
| 172 |
+
|
| 173 |
+
def compute_init(device_type="cuda"): # cuda|cpu|mps
    """
    Shared start-up for scripts: validate the device, seed RNGs, set matmul
    precision and, when requested on CUDA, initialize the process group.

    Returns `(ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device)`.
    """
    assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
    on_cuda = device_type == "cuda"
    if on_cuda:
        assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
    if device_type == "mps":
        assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"

    # Reproducibility
    # Note that we set the global seeds here, but most of the code uses explicit rng objects.
    # The only place where global rng might be used is nn.Module initialization of the model weights.
    torch.manual_seed(42)
    if on_cuda:
        torch.cuda.manual_seed(42)
        # skipping full reproducibility for now, possibly investigate slowdown later
        # torch.use_deterministic_algorithms(True)

        # Precision
        torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls, see https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html

    # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
    ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    if ddp_requested and on_cuda:
        device = torch.device("cuda", ddp_local_rank)
        torch.cuda.set_device(device) # make "cuda" default to this device
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    else:
        device = torch.device(device_type) # mps|cpu

    if ddp_rank == 0:
        logger.info(f"Distributed world size: {ddp_world_size}")

    return ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device
|
| 209 |
+
|
| 210 |
+
def compute_cleanup():
    """Counterpart of compute_init: destroy the process group if one was initialized."""
    if not is_ddp_initialized():
        return
    dist.destroy_process_group()
|
| 214 |
+
|
| 215 |
+
class DummyWandb:
    """No-op stand-in exposing the `log` / `finish` signatures, for running without wandb."""

    def __init__(self):
        """Nothing to set up."""

    def log(self, *args, **kwargs):
        """Accept and discard any arguments."""

    def finish(self):
        """Do nothing."""
|
| 223 |
+
|
| 224 |
+
# hardcoded BF16 peak flops for various GPUs
|
| 225 |
+
# inspired by torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py
|
| 226 |
+
# and PR: https://github.com/karpathy/nanochat/pull/147
|
| 227 |
+
def get_peak_flops(device_name: str) -> float:
    """
    Look up the hard-coded peak FLOPS figure for a device name.

    Matching is by case-insensitive substring: the first table row whose
    patterns ALL occur in the name wins. Unknown devices log a warning and
    yield `inf`.
    """
    name = device_name.lower()

    # Table order matters: more specific patterns first.
    known_devices = (
        # NVIDIA Blackwell
        (["gb200"], 2.5e15),
        (["grace blackwell"], 2.5e15),
        (["b200"], 2.25e15),
        (["b100"], 1.8e15),
        # NVIDIA Hopper
        (["h200", "nvl"], 836e12),
        (["h200", "pcie"], 836e12),
        (["h200"], 989e12),
        (["h100", "nvl"], 835e12),
        (["h100", "pcie"], 756e12),
        (["h100"], 989e12),
        (["h800", "nvl"], 989e12),
        (["h800"], 756e12),
        # NVIDIA Ampere data center
        (["a100"], 312e12),
        (["a800"], 312e12),
        (["a40"], 149.7e12),
        (["a30"], 165e12),
        # NVIDIA Ada data center
        (["l40s"], 362e12),
        (["l40-s"], 362e12),
        (["l40 s"], 362e12),
        (["l4"], 121e12),
        # AMD CDNA accelerators
        (["mi355"], 2.5e15),
        (["mi325"], 1.3074e15),
        (["mi300x"], 1.3074e15),
        (["mi300a"], 980.6e12),
        (["mi250x"], 383e12),
        (["mi250"], 362.1e12),
        # Consumer RTX
        (["5090"], 209.5e12),
        (["4090"], 165.2e12),
        (["3090"], 71e12),
    )
    first_hit = next(
        (flops for patterns, flops in known_devices if all(p in name for p in patterns)),
        None,
    )
    if first_hit is not None:
        return first_hit

    if "data center gpu max 1550" in name:
        # Ponte Vecchio (PVC) - dynamic based on compute units
        max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
        return 512 * max_comp_units * 1300 * 10**6

    # Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess
    logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%")
    return float('inf')
|
nanochat/engine.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Engine for efficient inference of our models.
|
| 3 |
+
|
| 4 |
+
Everything works around token sequences:
|
| 5 |
+
- The user can send token sequences to the engine
|
| 6 |
+
- The engine returns the next token
|
| 7 |
+
|
| 8 |
+
Notes:
|
| 9 |
+
- The engine knows nothing about tokenization, it's purely token id sequences.
|
| 10 |
+
|
| 11 |
+
The whole thing is made as efficient as possible.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
import signal
|
| 17 |
+
import warnings
|
| 18 |
+
from contextlib import contextmanager
|
| 19 |
+
from collections import deque
|
| 20 |
+
from nanochat.common import compute_init, autodetect_device_type
|
| 21 |
+
from nanochat.checkpoint_manager import load_model
|
| 22 |
+
|
| 23 |
+
# -----------------------------------------------------------------------------
|
| 24 |
+
# Calculator tool helpers
|
| 25 |
+
@contextmanager
def timeout(duration, formula):
    """
    Context manager that raises inside the `with` body after `duration` seconds.

    Implemented with SIGALRM: a handler that raises
    `Exception("'<formula>': timed out after <duration> seconds")` is installed
    and an alarm is armed. On exit, whether the body finished normally or
    raised, the alarm is cancelled and the previously installed SIGALRM
    handler is put back.

    Args:
        duration: whole seconds passed to signal.alarm.
        formula: text included in the timeout error message.
    """
    def timeout_handler(signum, frame):
        raise Exception(f"'{formula}': timed out after {duration} seconds")

    previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(duration)
    try:
        yield
    finally:
        # Always disarm, otherwise a pending alarm could fire later in unrelated code.
        signal.alarm(0)
        # signal.signal() returns None when the old handler was not installed
        # from Python; that value cannot be passed back in.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
|
| 34 |
+
|
| 35 |
+
def eval_with_timeout(formula, max_time=3):
    """
    Evaluate the expression string `formula` under a time limit.

    Returns the value of the expression, or None if evaluation raised any
    Exception (including the timeout). SyntaxWarnings emitted while
    evaluating are suppressed.
    """
    try:
        with timeout(max_time, formula), warnings.catch_warnings():
            warnings.simplefilter("ignore", SyntaxWarning)
            # NOTE(review): eval() with emptied builtins is not a real sandbox;
            # callers are expected to have filtered `formula` beforehand.
            return eval(formula, {"__builtins__": {}}, {})
    except Exception:
        # Wrong calculator usage is expected and deliberately ignored.
        # Make sure no alarm stays armed after a failure.
        signal.alarm(0)
        return None
|
| 45 |
+
|
| 46 |
+
def use_calculator(expr):
    """
    Evaluate a restricted Python expression and return its value, or None.

    Two kinds of input are accepted:
    - pure arithmetic made of digits, `* + - / . ( )` and spaces (the power
      operator `**` is rejected);
    - string expressions that use the `.count(` method, restricted to
      letters, digits, quotes, parens, dot, underscore and space.
    Anything else, or anything that fails to evaluate, yields None.
    """
    # Commas are dropped first, e.g. "1,000" becomes "1000"
    expr = expr.replace(",", "")

    # Pure math expression (old behavior)
    math_chars = "0123456789*+-/.() "
    if all(ch in math_chars for ch in expr):
        # disallow power operator
        return None if "**" in expr else eval_with_timeout(expr)

    # String operations: only a conservative character set is allowed
    allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
    if any(ch not in allowed_chars for ch in expr):
        return None

    # Reject anything that contains a dangerous substring (case-insensitive)
    dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
                          'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
                          'getattr', 'setattr', 'delattr', 'hasattr']
    lowered = expr.lower()
    for pattern in dangerous_patterns:
        if pattern in lowered:
            return None

    # Only the .count() method is supported for now (can expand later)
    if '.count(' not in expr:
        return None

    return eval_with_timeout(expr)
|
| 80 |
+
|
| 81 |
+
# -----------------------------------------------------------------------------
|
| 82 |
+
class KVCache:
    """
    KV Cache designed for Flash Attention 3's flash_attn_with_kvcache API.

    Key differences from FA2-style cache:
    - Tensors are (B, T, H, D) not (B, H, T, D)
    - FA3 updates the cache in-place during flash_attn_with_kvcache
    - Position tracked per batch element via cache_seqlens tensor
    """

    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
        self.batch_size = batch_size
        self.max_seq_len = seq_len
        self.n_layers = num_layers
        self.n_heads = num_heads
        self.head_dim = head_dim
        # Keys and values are pre-allocated with layout (n_layers, B, T, H, D)
        cache_shape = (num_layers, batch_size, seq_len, num_heads, head_dim)
        self.k_cache = torch.zeros(cache_shape, device=device, dtype=dtype)
        self.v_cache = torch.zeros(cache_shape, device=device, dtype=dtype)
        # Number of valid positions per batch element (FA3 needs int32)
        self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)

    def reset(self):
        """Mark the cache as empty; the K/V tensors themselves are not cleared."""
        self.cache_seqlens.zero_()

    def get_pos(self):
        """Return the position of batch element 0 (all rows are assumed to be in lockstep)."""
        return self.cache_seqlens[0].item()

    def get_layer_cache(self, layer_idx):
        """Return the (k_cache, v_cache) views belonging to layer `layer_idx`."""
        return self.k_cache[layer_idx], self.v_cache[layer_idx]

    def advance(self, num_tokens):
        """Move every row's position forward by `num_tokens`."""
        self.cache_seqlens += num_tokens

    def prefill(self, other):
        """
        Copy the filled prefix of `other` into this (empty) cache.
        Used when we do batch=1 prefill and then want to generate multiple samples in parallel.
        """
        assert self.get_pos() == 0, "Cannot prefill a non-empty KV cache"
        assert self.n_layers == other.n_layers and self.n_heads == other.n_heads and self.head_dim == other.head_dim
        assert self.max_seq_len >= other.max_seq_len
        filled = other.get_pos()
        # The batch dimension of `other` broadcasts across this cache's rows
        self.k_cache[:, :, :filled] = other.k_cache[:, :, :filled]
        self.v_cache[:, :, :filled] = other.v_cache[:, :, :filled]
        self.cache_seqlens.fill_(filled)
|
| 132 |
+
|
| 133 |
+
# -----------------------------------------------------------------------------
|
| 134 |
+
@torch.inference_mode()
def sample_next_token(logits, rng, temperature=1.0, top_k=None):
    """
    Draw one token id per row from `logits` of shape (B, vocab_size). Returns (B, 1).

    temperature == 0 selects the argmax. A positive `top_k` restricts sampling
    to the k highest logits of each row; otherwise the full vocabulary is used.
    `rng` is the torch.Generator handed to torch.multinomial.
    """
    assert temperature >= 0.0, "temperature must be non-negative"
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)

    restrict = top_k is not None and top_k > 0
    if restrict:
        candidates, candidate_ids = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1)
    else:
        candidates, candidate_ids = logits, None

    probs = F.softmax(candidates / temperature, dim=-1)
    picked = torch.multinomial(probs, num_samples=1, generator=rng)
    # With top-k, `picked` indexes into the candidate list and must be mapped back to vocab ids
    return candidate_ids.gather(1, picked) if restrict else picked
|
| 151 |
+
|
| 152 |
+
# -----------------------------------------------------------------------------
|
| 153 |
+
|
| 154 |
+
class RowState:
    """Mutable bookkeeping for a single row (sample) during generation."""

    def __init__(self, current_tokens=None):
        # Token sequence of this row; a falsy argument (None or empty) is replaced by a fresh list
        self.current_tokens = current_tokens if current_tokens else []
        # Set once the row has finished generating
        self.completed = False
        # Tool-use state: are we inside a python block, and which tokens form its expression
        self.in_python_block = False
        self.python_expr_tokens = []
        # Tokens queued to be injected instead of sampled
        self.forced_tokens = deque()
|
| 162 |
+
|
| 163 |
+
class Engine:
    """
    Token-in / token-out generation engine.

    Holds a model and a tokenizer. The tokenizer is only used to look up
    special token ids and to decode/encode calculator expressions and results
    for tool use.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer # needed for tool use

    @staticmethod
    def _apply_repetition_penalty(logits, row_states, repetition_penalty, repetition_window):
        """
        Return a copy of `logits` (B, vocab_size) with recently used tokens penalized.

        For every row that is not completed, each distinct token among the last
        `repetition_window` entries of its `current_tokens` has its logit divided
        by `repetition_penalty` when positive and multiplied by it otherwise.
        The input tensor is not modified.

        NOTE: the window is taken with `current_tokens[-repetition_window:]`, so a
        window of 0 selects the whole history.
        """
        penalized = logits.clone()
        for row, state in enumerate(row_states):
            if state.completed:
                continue
            recent = sorted(set(state.current_tokens[-repetition_window:]))
            if not recent:
                continue
            # One gather/scatter per row instead of one scalar read+write per token
            token_ids = torch.tensor(recent, dtype=torch.long, device=logits.device)
            values = penalized[row, token_ids]
            penalized[row, token_ids] = torch.where(
                values > 0, values / repetition_penalty, values * repetition_penalty
            )
        return penalized

    @torch.inference_mode()
    def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42, repetition_penalty=1.0, repetition_window=64):
        """
        Stream `num_samples` continuations of the prompt `tokens`.

        Runs a single batch-1 prefill of the prompt, clones the resulting KV
        cache for every sample, and then decodes one token per row per step.

        Args:
            tokens: non-empty list of int token ids (the prompt).
            num_samples: number of rows generated in parallel.
            max_tokens: maximum number of decode steps, or None for no step limit.
            temperature, top_k: forwarded to sample_next_token.
            seed: seed of the torch.Generator used for sampling.
            repetition_penalty: 1.0 disables the penalty; see _apply_repetition_penalty.
            repetition_window: number of trailing tokens considered for the penalty.

        Yields:
            (token_column, token_masks): two lists of length num_samples with the
            next token id of each row and a mask that is 1 if the token was
            sampled and 0 if it was forced (calculator output).
        """
        assert isinstance(tokens, list) and len(tokens) > 0 and isinstance(tokens[0], int), "expecting list of ints"
        device = self.model.get_device()
        # NOTE: setting the dtype here and in this way is an ugly hack.
        # Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
        # We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors.
        # As a quick hack, we're making generate() function inherit and know about this repo-wise assumption.
        # I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
        # In particular, the KVCache should allocate its tensors lazily
        dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
        rng = torch.Generator(device=device)
        rng.manual_seed(seed)

        # Get the special tokens we need to coordinate the tool use state machine
        encode_special = self.tokenizer.encode_special
        python_start = encode_special("<|python_start|>")
        python_end = encode_special("<|python_end|>")
        output_start = encode_special("<|output_start|>")
        output_end = encode_special("<|output_end|>")
        assistant_end = encode_special("<|assistant_end|>") # if sampled, ends row
        bos = self.tokenizer.get_bos_token_id() # if sampled, ends row

        # 1) Run a batch 1 prefill of the prompt tokens
        m = self.model.config
        kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
        kv_cache_prefill = KVCache(
            batch_size=1,
            seq_len=len(tokens),
            device=device,
            dtype=dtype,
            **kv_model_kwargs,
        )
        ids = torch.tensor([tokens], dtype=torch.long, device=device)
        logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
        logits = logits[:, -1, :].expand(num_samples, -1) # (num_samples, vocab_size)

        # 2) Replicate the KV cache for each sample/row
        kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
        kv_cache_decode = KVCache(
            batch_size=num_samples,
            seq_len=kv_length_hint,
            device=device,
            dtype=dtype,
            **kv_model_kwargs,
        )
        kv_cache_decode.prefill(kv_cache_prefill)
        del kv_cache_prefill # no need to keep this memory around

        # 3) Initialize states for each sample
        row_states = [RowState(tokens.copy()) for _ in range(num_samples)]

        # 4) Main generation loop
        num_generated = 0
        while True:
            # Stop condition: we've reached max tokens
            if max_tokens is not None and num_generated >= max_tokens:
                break
            # Stop condition: all rows are completed
            if all(state.completed for state in row_states):
                break

            # Sample the next token for each row (optionally with the repetition penalty applied)
            if repetition_penalty != 1.0:
                sampling_logits = self._apply_repetition_penalty(logits, row_states, repetition_penalty, repetition_window)
            else:
                sampling_logits = logits
            next_ids = sample_next_token(sampling_logits, rng, temperature, top_k) # (B, 1)
            sampled_tokens = next_ids[:, 0].tolist()

            # Process each row: choose the next token, update state, optional tool use
            token_column = [] # contains the next token id along each row
            token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
            for i, state in enumerate(row_states):
                # Select the next token in this row
                is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
                token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled
                next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
                token_column.append(next_token)
                # Update the state of this row to include the next token
                state.current_tokens.append(next_token)
                # On <|assistant_end|> or <|bos|>, mark the row as completed
                if next_token == assistant_end or next_token == bos:
                    state.completed = True
                # Handle tool logic
                if next_token == python_start:
                    state.in_python_block = True
                    state.python_expr_tokens = []
                elif next_token == python_end and state.in_python_block:
                    state.in_python_block = False
                    if state.python_expr_tokens:
                        expr = self.tokenizer.decode(state.python_expr_tokens)
                        result = use_calculator(expr)
                        if result is not None:
                            # Queue the calculator result so it is forced into the row on the next steps
                            result_tokens = self.tokenizer.encode(str(result))
                            state.forced_tokens.append(output_start)
                            state.forced_tokens.extend(result_tokens)
                            state.forced_tokens.append(output_end)
                    state.python_expr_tokens = []
                elif state.in_python_block:
                    state.python_expr_tokens.append(next_token)

            # Yield the token column
            yield token_column, token_masks
            num_generated += 1

            # Prepare logits for next iteration
            ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
            logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :] # (B, vocab_size)

    def generate_batch(self, tokens, num_samples=1, **kwargs):
        """
        Non-streaming batch generation that just returns the final token sequences.

        Returns (results, masks): `results` is a list of token sequences (list of
        lists of ints), each starting with the prompt; `masks` has the same shape,
        with 0 for prompt and forced tokens and 1 for sampled tokens.
        Terminal tokens (assistant_end, bos) are not included in the results.
        Extra keyword arguments are forwarded to generate().
        """
        assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
        bos = self.tokenizer.get_bos_token_id()
        results = [tokens.copy() for _ in range(num_samples)]
        masks = [[0] * len(tokens) for _ in range(num_samples)]
        completed = [False] * num_samples
        for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
            for i, (token, mask) in enumerate(zip(token_column, token_masks)):
                if not completed[i]:
                    if token == assistant_end or token == bos:
                        completed[i] = True
                    else:
                        results[i].append(token)
                        masks[i].append(mask)
            # Stop if all rows are completed
            if all(completed):
                break
        return results, masks
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
if __name__ == "__main__":
    # Quick inline test to make sure that the naive/slow model.generate function
    # is equivalent to the faster Engine.generate function here.
    import time

    def _sync():
        # torch.cuda.synchronize() raises on CPU-only builds, so only call it when CUDA exists
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    # init compute
    device_type = autodetect_device_type()
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
    # load the model and tokenizer
    model, tokenizer, meta = load_model("base", device, phase="eval")
    bos_token_id = tokenizer.get_bos_token_id()
    # common hyperparameters
    kwargs = dict(max_tokens=64, temperature=0.0)
    # set the starting prompt
    prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)

    # generate the reference sequence using the model.generate() function
    generated_tokens = []
    _sync()
    t0 = time.time()
    stream = model.generate(prompt_tokens, **kwargs)
    for token in stream:
        generated_tokens.append(token)
        chunk = tokenizer.decode([token])
        print(chunk, end="", flush=True)
    print()
    _sync()
    t1 = time.time()
    print(f"Reference time: {t1 - t0:.2f}s")
    reference_ids = generated_tokens

    # generate tokens with Engine
    generated_tokens = []
    engine = Engine(model, tokenizer)
    stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
    _sync()
    t0 = time.time()
    for token_column, token_masks in stream:
        token = token_column[0] # only print out the first row
        generated_tokens.append(token)
        chunk = tokenizer.decode([token])
        print(chunk, end="", flush=True)
    print()
    _sync()
    t1 = time.time()
    print(f"Engine time: {t1 - t0:.2f}s")

    # compare the two sequences; zip() stops at the shorter one, so differing
    # lengths cannot raise IndexError (the final Match line still reports them)
    for i, (ref_token, engine_token) in enumerate(zip(reference_ids, generated_tokens)):
        if ref_token != engine_token:
            print(f"Mismatch at {i}: {ref_token} != {engine_token}")
            break
    print(f"Match: {reference_ids == generated_tokens}")
|
nanochat/flash_attention.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified Flash Attention interface with automatic FA3/SDPA switching.
|
| 3 |
+
|
| 4 |
+
Exports `flash_attn` module that matches the FA3 API exactly, but falls back
|
| 5 |
+
to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
|
| 6 |
+
|
| 7 |
+
Usage (drop-in replacement for FA3):
|
| 8 |
+
from nanochat.flash_attention import flash_attn
|
| 9 |
+
|
| 10 |
+
# Training (no KV cache)
|
| 11 |
+
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
|
| 12 |
+
|
| 13 |
+
# Inference (with KV cache)
|
| 14 |
+
y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, ...)
|
| 15 |
+
"""
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# =============================================================================
|
| 21 |
+
# Detection: Try to load FA3 on Hopper+ GPUs
|
| 22 |
+
# =============================================================================
|
| 23 |
+
def _load_flash_attention_3():
|
| 24 |
+
"""Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
|
| 25 |
+
if not torch.cuda.is_available():
|
| 26 |
+
return None
|
| 27 |
+
try:
|
| 28 |
+
major, _ = torch.cuda.get_device_capability()
|
| 29 |
+
# FA3 kernels are compiled for Hopper (sm90) only
|
| 30 |
+
# Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled
|
| 31 |
+
if major != 9:
|
| 32 |
+
return None
|
| 33 |
+
import os
|
| 34 |
+
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
| 35 |
+
from kernels import get_kernel
|
| 36 |
+
return get_kernel('varunneal/flash-attention-3').flash_attn_interface
|
| 37 |
+
except Exception:
|
| 38 |
+
return None
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Loaded once at import time: the FA3 interface module, or None when FA3 could not be loaded.
_fa3 = _load_flash_attention_3()
|
| 42 |
+
# True when `_fa3` holds a loaded FA3 interface (i.e. the kernels are available on this machine).
HAS_FA3 = _fa3 is not None
|
| 43 |
+
|
| 44 |
+
# Override for testing: set to 'fa3', 'sdpa', or None (auto)
|
| 45 |
+
_override_impl = None
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _resolve_use_fa3():
    """
    Decide whether attention calls should go through FA3.

    Precedence: an explicit `_override_impl` ('sdpa' or 'fa3') wins; otherwise
    FA3 is used only when it is available and COMPUTE_DTYPE is bfloat16.
    Forcing 'fa3' without FA3 being available trips an assertion.
    """
    if _override_impl == 'sdpa':
        return False
    if _override_impl == 'fa3':
        assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
        return True
    if not HAS_FA3:
        return False
    # FA3 Hopper kernels only support bf16 and fp8; fp16/fp32 must use SDPA fallback
    from nanochat.common import COMPUTE_DTYPE
    return COMPUTE_DTYPE == torch.bfloat16
|
| 62 |
+
|
| 63 |
+
# Evaluated once at import time and read by the public attention functions.
# NOTE(review): changing `_override_impl` afterwards has no effect unless USE_FA3 is recomputed.
USE_FA3 = _resolve_use_fa3()
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# =============================================================================
|
| 67 |
+
# SDPA helpers
|
| 68 |
+
# =============================================================================
|
| 69 |
+
def _sdpa_attention(q, k, v, window_size, enable_gqa):
|
| 70 |
+
"""
|
| 71 |
+
SDPA attention with sliding window support.
|
| 72 |
+
q, k, v are (B, H, T, D) format.
|
| 73 |
+
"""
|
| 74 |
+
Tq = q.size(2)
|
| 75 |
+
Tk = k.size(2)
|
| 76 |
+
window = window_size[0]
|
| 77 |
+
|
| 78 |
+
# Full context, same length
|
| 79 |
+
if (window < 0 or window >= Tq) and Tq == Tk:
|
| 80 |
+
return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
|
| 81 |
+
|
| 82 |
+
# Single token generation
|
| 83 |
+
if Tq == 1:
|
| 84 |
+
if window >= 0 and window < Tk:
|
| 85 |
+
# window is "left" tokens we need to include (window + 1) keys total
|
| 86 |
+
start = max(0, Tk - (window + 1))
|
| 87 |
+
k = k[:, :, start:, :]
|
| 88 |
+
v = v[:, :, start:, :]
|
| 89 |
+
return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
|
| 90 |
+
|
| 91 |
+
# Need explicit mask for sliding window/chunk inference
|
| 92 |
+
device = q.device
|
| 93 |
+
# For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
|
| 94 |
+
row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
|
| 95 |
+
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
|
| 96 |
+
mask = col_idx <= row_idx
|
| 97 |
+
|
| 98 |
+
# sliding window (left)
|
| 99 |
+
if window >= 0 and window < Tk:
|
| 100 |
+
mask = mask & ((row_idx - col_idx) <= window)
|
| 101 |
+
|
| 102 |
+
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
|
| 103 |
+
|
| 104 |
+
# =============================================================================
|
| 105 |
+
# Public API: Same interface as FA3
|
| 106 |
+
# =============================================================================
|
| 107 |
+
def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
    """
    Flash Attention for training (no KV cache).

    Args:
        q, k, v: Tensors of shape (B, T, H, D)
        causal: Whether to use causal masking. Forwarded to FA3 only; the SDPA
            fallback does not receive it and always masks causally.
        window_size: (left, right) sliding window. -1 means unlimited.

    Returns:
        Output tensor of shape (B, T, H, D)
    """
    if USE_FA3:
        return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)

    # SDPA fallback works in (B, H, T, D), so swap the T and H axes on the way in and out
    q_t, k_t, v_t = (t.transpose(1, 2) for t in (q, k, v))
    # Fewer K/V heads than query heads means grouped-query attention
    needs_gqa = q_t.size(1) != k_t.size(1)
    out = _sdpa_attention(q_t, k_t, v_t, window_size, needs_gqa)
    return out.transpose(1, 2)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None,
                            causal=False, window_size=(-1, -1)):
    """
    Flash Attention with KV cache for inference.

    FA3 updates k_cache/v_cache in-place. Our SDPA fallback does the same.
    The fallback reads the position from cache_seqlens[0] only (uniform
    position across the batch is assumed), does not modify cache_seqlens, and
    does not use `causal`.

    Args:
        q: Queries, shape (B, T_new, H, D)
        k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D)
        k, v: New keys/values to insert, shape (B, T_new, H_kv, D)
        cache_seqlens: Current position in cache, shape (B,) int32
        causal: Whether to use causal masking
        window_size: (left, right) sliding window. -1 means unlimited.

    Returns:
        Output tensor of shape (B, T_new, H, D)
    """
    if USE_FA3:
        return _fa3.flash_attn_with_kvcache(
            q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
            causal=causal, window_size=window_size
        )

    # SDPA fallback: the cache is managed by hand
    _, n_new, _, _ = q.shape
    start = cache_seqlens[0].item()
    stop = start + n_new

    # Write the new keys/values into their slots (in-place, matching FA3 behavior)
    if k is not None and v is not None:
        k_cache[:, start:stop, :, :] = k
        v_cache[:, start:stop, :, :] = v

    # Attend over everything cached so far, in SDPA layout (B, H, T, D)
    q_sdpa = q.transpose(1, 2)
    k_sdpa = k_cache[:, :stop, :, :].transpose(1, 2)
    v_sdpa = v_cache[:, :stop, :, :].transpose(1, 2)
    needs_gqa = q_sdpa.size(1) != k_sdpa.size(1)
    out = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, needs_gqa)

    return out.transpose(1, 2)  # back to (B, T, H, D)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# =============================================================================
|
| 181 |
+
# Export: flash_attn module interface (drop-in replacement for FA3)
|
| 182 |
+
# =============================================================================
|
| 183 |
+
from types import SimpleNamespace
|
| 184 |
+
# Namespace object exposing the two public functions under the same attribute names
# as the FA3 interface, so `flash_attn.flash_attn_func(...)` works for callers.
flash_attn = SimpleNamespace(flash_attn_func=flash_attn_func, flash_attn_with_kvcache=flash_attn_with_kvcache)
|
nanochat/gpt.py
ADDED
|
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GPT model (rewrite, a lot simpler)
|
| 3 |
+
Notable features:
|
| 4 |
+
- rotary embeddings (and no positional embeddings)
|
| 5 |
+
- QK norm
|
| 6 |
+
- untied weights for token embedding and lm_head
|
| 7 |
+
- relu^2 activation in MLP
|
| 8 |
+
- norm after token embedding
|
| 9 |
+
- no learnable params in rmsnorm
|
| 10 |
+
- no bias in linear layers
|
| 11 |
+
- Group-Query Attention (GQA) support for more efficient inference
|
| 12 |
+
- Flash Attention 3 integration
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from functools import partial
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
|
| 22 |
+
from nanochat.common import get_dist_info, print0, COMPUTE_DTYPE
|
| 23 |
+
from nanochat.optim import MuonAdamW, DistMuonAdamW
|
| 24 |
+
|
| 25 |
+
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
|
| 26 |
+
from nanochat.flash_attention import flash_attn
|
| 27 |
+
|
| 28 |
+
@dataclass
class GPTConfig:
    # Model hyperparameters with their default values.
    sequence_len: int = 2048
    vocab_size: int = 32768
    n_layer: int = 12
    n_head: int = 6  # query heads
    n_kv_head: int = 6  # key/value heads (GQA)
    n_embd: int = 768
    # Sliding window attention pattern, given as a string that is tiled across
    # the layers; the final layer is always L.
    #   L = long (full context), S = short (half context)
    #   e.g. "L" = all full context, "SL" = alternating, "SSL" = two short then one long
    window_pattern: str = "SSSL"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def norm(x):
|
| 43 |
+
return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok
|
| 44 |
+
|
| 45 |
+
class Linear(nn.Linear):
    """nn.Linear that casts its parameters to match the input dtype in forward.

    Replaces autocast: master weights stay fp32 for optimizer precision,
    but matmuls run in the activation dtype (typically bf16 from embeddings).
    The stored parameters are not modified; only per-call copies are cast.
    A bias, if the layer was built with one, is cast and applied as well
    (it was previously dropped silently).
    """
    def forward(self, x):
        weight = self.weight.to(dtype=x.dtype)
        bias = None if self.bias is None else self.bias.to(dtype=x.dtype)
        return F.linear(x, weight, bias)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def has_ve(layer_idx, n_layer):
    """Tell whether layer `layer_idx` gets a Value Embedding: every other layer, counted so that the last layer always has one."""
    distance_from_last = (n_layer - 1) - layer_idx
    return distance_from_last % 2 == 0
|
| 56 |
+
|
| 57 |
+
def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of the last dimension of a 4-D tensor `x` by the angles given through `cos` and `sin`."""
    assert x.ndim == 4  # multihead attention
    half = x.shape[3] // 2
    # The last dim is split in the middle; element i of the first half pairs with element i of the second
    first_half = x[..., :half]
    second_half = x[..., half:]
    rotated_first = first_half * cos + second_half * sin
    rotated_second = first_half * (-sin) + second_half * cos
    return torch.cat((rotated_first, rotated_second), dim=3)
|
| 64 |
+
|
| 65 |
+
class CausalSelfAttention(nn.Module):
    """Causal self-attention layer with rotary embeddings, QK norm and an optional gated value embedding.

    Queries use ``n_head`` heads while keys/values use ``n_kv_head`` heads
    (``n_head`` must be a multiple of ``n_kv_head``). Layers selected by
    ``has_ve`` additionally own a small ``ve_gate`` projection that scales a
    per-token value embedding before it is added to the values.
    """

    def __init__(self, config, layer_idx):
        """Build the projections for layer ``layer_idx``.

        Reads ``n_head``, ``n_kv_head``, ``n_embd`` and ``n_layer`` from ``config``.
        Raises AssertionError if ``n_embd`` is not divisible by ``n_head`` or
        ``n_head`` is not a multiple of ``n_kv_head``.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
        self.c_q = Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = Linear(self.n_embd, self.n_embd, bias=False)
        # Number of leading input channels fed to the value-embedding gate.
        # NOTE(review): "Victorian checkpoint patch" is the original remark; presumably this
        # width must match a specific saved checkpoint - confirm before changing it.
        self.ve_gate_channels = 32 # Victorian checkpoint patch
        # Only layers selected by has_ve() get a gate; all others store None.
        self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None

    def forward(self, x, ve, cos_sin, window_size, kv_cache):
        """Run attention over ``x`` and return the projected result.

        Args:
            x: input of size (B, T, C).
            ve: value embedding viewable as (B, T, n_kv_head, head_dim), or None to skip the value residual.
            cos_sin: ``(cos, sin)`` pair handed to ``apply_rotary_emb``.
            window_size: ``(left, right)`` tuple forwarded to the flash-attention call.
            kv_cache: None for the cache-free path; otherwise an object providing
                ``get_layer_cache``, ``cache_seqlens``, ``n_layers`` and ``advance``.

        Returns:
            Tensor of size (B, T, n_embd).
        """
        B, T, C = x.size()

        # Project the input to get queries, keys, and values
        # Shape: (B, T, H, D) - FA3's native layout, no transpose needed!
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)

        # Value residual (ResFormer): mix in value embedding with input-dependent gate per head
        if ve is not None:
            ve = ve.view(B, T, self.n_kv_head, self.head_dim)
            # The gate only looks at the first ve_gate_channels channels of x.
            gate = 3 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 3)
            v = v + gate.unsqueeze(-1) * ve

        # Apply Rotary Embeddings to queries and keys to get relative positional encoding
        cos, sin = cos_sin
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        q, k = norm(q), norm(k) # QK norm
        q = q * 1.15 # sharper attention (split scale between Q and K), TODO think through better
        k = k * 1.15

        # Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
        # window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
        if kv_cache is None:
            # Training: causal attention with optional sliding window
            y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
        else:
            # Inference: use flash_attn_with_kvcache which handles cache management
            k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
            y = flash_attn.flash_attn_with_kvcache(
                q, k_cache, v_cache,
                k=k, v=v,
                cache_seqlens=kv_cache.cache_seqlens,
                causal=True,
                window_size=window_size,
            )
            # Advance position after last layer processes
            # (only the final layer advances, so the position moves once per forward pass).
            if self.layer_idx == kv_cache.n_layers - 1:
                kv_cache.advance(T)

        # Re-assemble the heads and project back to residual stream
        y = y.contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class MLP(nn.Module):
    """Bias-free feed-forward block: expand to 4x width, squared ReLU, project back."""

    def __init__(self, config):
        """Create the up- and down-projection sized from ``config.n_embd``."""
        super().__init__()
        width = config.n_embd
        hidden_width = 4 * config.n_embd
        self.c_fc = Linear(width, hidden_width, bias=False)
        self.c_proj = Linear(hidden_width, width, bias=False)

    def forward(self, x):
        """Return ``c_proj(relu(c_fc(x)) ** 2)``."""
        activated = F.relu(self.c_fc(x)).square()
        return self.c_proj(activated)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class Block(nn.Module):
    """One Transformer layer: attention then MLP, each applied to a normed input and added residually."""

    def __init__(self, config, layer_idx):
        """Create the attention and MLP sub-layers for layer ``layer_idx``."""
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, ve, cos_sin, window_size, kv_cache):
        """Return ``x`` after the attention and MLP residual updates.

        ``ve``, ``cos_sin``, ``window_size`` and ``kv_cache`` are passed
        straight through to the attention sub-layer.
        """
        attn_out = self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
        x = x + attn_out
        mlp_out = self.mlp(norm(x))
        return x + mlp_out
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class GPT(nn.Module):
|
| 155 |
+
def __init__(self, config, pad_vocab_size_to=64):
    """
    Build the module structure (embeddings, blocks, head, scalars, rotary buffers).

    NOTE a major footgun: this __init__ function runs in meta device context (!!)
    Therefore, any calculations inside here are shapes and dtypes only, no actual data.
    => We actually initialize all data (parameters, buffers, etc.) in init_weights() instead.

    Args:
        config: model configuration; this method reads vocab_size, n_embd, n_head,
            n_kv_head, n_layer and sequence_len (and window_pattern via
            _compute_window_sizes).
        pad_vocab_size_to: the vocabulary dimension of the embedding tables and
            lm_head is rounded up to a multiple of this value.
    """
    super().__init__()
    self.config = config
    # Compute per-layer window sizes for sliding window attention
    # window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
    self.window_sizes = self._compute_window_sizes(config)
    # Pad vocab for efficiency (DDP, tensor cores). This is just an optimization - outputs are cropped in forward().
    # https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
    # Round vocab_size up to the next multiple of pad_vocab_size_to.
    padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
    if padded_vocab_size != config.vocab_size:
        print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")
    self.transformer = nn.ModuleDict({
        "wte": nn.Embedding(padded_vocab_size, config.n_embd),
        "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
    })
    self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False)
    # Per-layer learnable scalars (inspired by modded-nanogpt)
    # resid_lambdas: scales the residual stream at each layer (real init 1.0 = neutral)
    # x0_lambdas: blends initial embedding back in at each layer (real init in init_weights() is 0.1)
    # Separate parameters so they can have different optimizer treatment
    self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
    self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
    # Value embeddings (ResFormer-style): alternating layers, last layer always included
    # Keys of value_embeds are the layer indices as strings, only for layers where has_ve() is True.
    head_dim = config.n_embd // config.n_head
    kv_dim = config.n_kv_head * head_dim
    self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
    # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
    # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
    # so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
    # In the future we can dynamically grow the cache, for now it's fine.
    self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
    head_dim = config.n_embd // config.n_head
    cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
    self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
    self.register_buffer("sin", sin, persistent=False)
|
| 195 |
+
|
| 196 |
+
@torch.no_grad()
def init_weights(self):
    """
    Initialize the full model in this one function for maximum clarity.

    wte (embedding):     normal, std=0.8
    lm_head:             normal, std=0.001
    for each block:
        attn.c_q:        uniform, std=1/sqrt(n_embd)
        attn.c_k:        uniform, std=1/sqrt(n_embd)
        attn.c_v:        uniform, std=1/sqrt(n_embd)
        attn.c_proj:     zeros
        mlp.c_fc:        uniform, std=0.5/sqrt(n_embd)
        mlp.c_proj:      zeros
        attn.ve_gate:    uniform in [0, 0.02] (only on layers that have a gate)
    resid_lambdas:       1.0
    x0_lambdas:          0.1
    value_embeds:        uniform, std=1/sqrt(n_embd)

    Also recomputes the rotary cos/sin buffers and, unless COMPUTE_DTYPE is
    fp16, casts the token and value embeddings to COMPUTE_DTYPE.
    """

    # Embedding and unembedding
    torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=0.8)
    torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)

    # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
    n_embd = self.config.n_embd
    s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
    for block in self.transformer.h:
        torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
        torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
        torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
        torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
        torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.5, s * 0.5) # 0.5x init scale for c_fc
        torch.nn.init.zeros_(block.mlp.c_proj.weight)

    # Per-layer scalars
    self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
    self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding

    # Value embeddings (init like c_v: uniform with same std)
    for ve in self.value_embeds.values():
        torch.nn.init.uniform_(ve.weight, -s, s)

    # Gate weights init with small positive values so gates start slightly above neutral
    for block in self.transformer.h:
        if block.attn.ve_gate is not None:
            torch.nn.init.uniform_(block.attn.ve_gate.weight, 0.0, 0.02)

    # Rotary embeddings
    # (recomputed here; __init__ describes its own cos/sin as "fake" meta tensors)
    head_dim = self.config.n_embd // self.config.n_head
    cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
    self.cos, self.sin = cos, sin

    # Cast embeddings to COMPUTE_DTYPE: optimizer can tolerate reduced-precision
    # embeddings and it saves memory. Exception: fp16 requires fp32 embeddings
    # because GradScaler cannot unscale fp16 gradients.
    if COMPUTE_DTYPE != torch.float16:
        self.transformer.wte.to(dtype=COMPUTE_DTYPE)
        for ve in self.value_embeds.values():
            ve.to(dtype=COMPUTE_DTYPE)
|
| 252 |
+
|
| 253 |
+
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None):
    """Build the rotary cos/sin tables for positions ``0 .. seq_len - 1``.

    Args:
        seq_len: number of positions to tabulate.
        head_dim: per-head width; one frequency is produced per pair of channels.
        base: rotary base (theta) the inverse frequencies are derived from.
        device: where to create the tables; defaults to the device of the
            token-embedding weight.

    Returns:
        ``(cos, sin)``, each of shape (1, seq_len, 1, head_dim // 2) in
        COMPUTE_DTYPE, with singleton batch and head dims for broadcasting.
    """
    if device is None:
        # autodetect the device from model embeddings
        device = self.transformer.wte.weight.device
    # one frequency per even channel index
    even_channels = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
    inv_freq = 1.0 / (base ** (even_channels / head_dim))
    positions = torch.arange(seq_len, dtype=torch.float32, device=device)
    # rotation angle for every (position, frequency) pair
    angles = torch.outer(positions, inv_freq)
    cos = angles.cos().to(COMPUTE_DTYPE)[None, :, None, :]
    sin = angles.sin().to(COMPUTE_DTYPE)[None, :, None, :]
    return cos, sin
|
| 269 |
+
|
| 270 |
+
def _compute_window_sizes(self, config):
|
| 271 |
+
"""
|
| 272 |
+
Compute per-layer window sizes for sliding window attention.
|
| 273 |
+
|
| 274 |
+
Returns list of (left, right) tuples for FA3's window_size parameter:
|
| 275 |
+
- left: how many tokens before current position to attend to (-1 = unlimited)
|
| 276 |
+
- right: how many tokens after current position to attend to (0 for causal)
|
| 277 |
+
|
| 278 |
+
Pattern string is tiled across layers. Final layer always gets L (full context).
|
| 279 |
+
Characters: L=long (full context), S=short (half context)
|
| 280 |
+
"""
|
| 281 |
+
pattern = config.window_pattern.upper()
|
| 282 |
+
assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
|
| 283 |
+
# Map characters to window sizes
|
| 284 |
+
long_window = config.sequence_len
|
| 285 |
+
short_window = -(-long_window // 3 // 128) * 128 # ceil to FA3 tile size (2048 -> 768)
|
| 286 |
+
char_to_window = {
|
| 287 |
+
"L": (long_window, 0),
|
| 288 |
+
"S": (short_window, 0),
|
| 289 |
+
}
|
| 290 |
+
# Tile pattern across layers
|
| 291 |
+
window_sizes = []
|
| 292 |
+
for layer_idx in range(config.n_layer):
|
| 293 |
+
char = pattern[layer_idx % len(pattern)]
|
| 294 |
+
window_sizes.append(char_to_window[char])
|
| 295 |
+
# Final layer always gets full context
|
| 296 |
+
window_sizes[-1] = (long_window, 0)
|
| 297 |
+
return window_sizes
|
| 298 |
+
|
| 299 |
+
def get_device(self):
    """Return the device on which the token-embedding weight lives."""
    embedding_weight = self.transformer.wte.weight
    return embedding_weight.device
|
| 301 |
+
|
| 302 |
+
def estimate_flops(self):
    """
    Estimate the FLOPs spent per token (forward + backward).

    Every matmul weight costs 2 FLOPs in the forward pass (multiply, accumulate)
    and twice that in the backward pass, hence the factor 6 = 2 + 4.
    Cleanest explanation of this: https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4
    The attention term adds 12 * h * q * effective_seq_len per layer for the
    key @ query matmuls, where effective_seq_len is the sequence length capped
    by that layer's (left) window size; a negative window means uncapped.
    Ref: https://arxiv.org/abs/2204.02311 (PaLM paper).
    This is ~1% off from the exact formulas of the Chinchilla paper, because:
    - Chinchilla counts the embedding layer as flops (it is a lookup => ignored here)
    - Chinchilla counts exp/sum/divide in attention softmax as flops (tiny => ignored here)

    Token embeddings, value embeddings and the per-layer scalars are not matmul
    weights and are excluded from the parameter term.
    """
    cfg = self.config
    n_head = cfg.n_head
    head_dim = cfg.n_embd // cfg.n_head
    seq_len = cfg.sequence_len

    total_params = sum(p.numel() for p in self.parameters())
    non_matmul_params = self.transformer.wte.weight.numel()
    non_matmul_params += sum(ve.weight.numel() for ve in self.value_embeds.values())
    non_matmul_params += self.resid_lambdas.numel() + self.x0_lambdas.numel()

    # window[0] is the left extent of the (left, right) tuple
    attended_tokens = sum(
        seq_len if window[0] < 0 else min(window[0], seq_len)
        for window in self.window_sizes
    )
    attn_flops = 12 * n_head * head_dim * attended_tokens
    return 6 * (total_params - non_matmul_params) + attn_flops
|
| 328 |
+
|
| 329 |
+
def num_scaling_params(self):
    """
    Report parameter counts per group for scaling-law analysis.

    Papers disagree on what to count:
    - Kaplan et al. excluded embedding parameters
    - Chinchilla included all parameters
    Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper)
    Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper)

    Returns a dict with one count per group ('wte', 'value_embeds', 'lm_head',
    'transformer_matrices', 'scalars') plus their 'total', so downstream
    analysis can try whichever combination gives the cleanest scaling laws.
    Raises AssertionError if the groups do not add up to all parameters.
    """
    def count(module):
        return sum(p.numel() for p in module.parameters())

    # Same grouping as used when building the optimizer
    counts = {
        'wte': count(self.transformer.wte),
        'value_embeds': count(self.value_embeds),
        'lm_head': count(self.lm_head),
        'transformer_matrices': count(self.transformer.h),
        'scalars': self.resid_lambdas.numel() + self.x0_lambdas.numel(),
    }
    counts['total'] = sum(counts.values())
    assert counts['total'] == count(self), "Parameter count mismatch"
    return counts
|
| 357 |
+
|
| 358 |
+
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5):
    """Build the combined AdamW/Muon optimizer over all model parameters.

    Args:
        unembedding_lr: base learning rate for lm_head (AdamW).
        embedding_lr: base learning rate for the token embedding (AdamW); the
            value embeddings use half of it.
        matrix_lr: learning rate for every parameter inside the Transformer blocks (Muon).
        weight_decay: weight decay for the Muon groups only; the AdamW groups
            use the fixed decays written in their dicts below.
        scalar_lr: learning rate for x0_lambdas; resid_lambdas use 1% of it.

    Returns:
        A DistMuonAdamW when get_dist_info() reports ddp, otherwise a MuonAdamW.
        Every param group gets an extra "initial_lr" entry equal to its "lr".
    """
    model_dim = self.config.n_embd
    ddp, rank, local_rank, world_size = get_dist_info()

    # Separate out all parameters into groups
    # (everything under transformer.h, including any ve_gate weights, counts as a matrix param)
    matrix_params = list(self.transformer.h.parameters())
    value_embeds_params = list(self.value_embeds.parameters())
    embedding_params = list(self.transformer.wte.parameters())
    lm_head_params = list(self.lm_head.parameters())
    resid_params = [self.resid_lambdas]
    x0_params = [self.x0_lambdas]
    # Sanity check: the groups above account for every parameter tensor of the model.
    assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)

    # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
    # (applied to lm_head and the embeddings below, not to the scalar groups)
    dmodel_lr_scale = (model_dim / 768) ** -0.5
    print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")

    # Build param_groups with all required fields explicit
    param_groups = [
        # AdamW groups (embeddings, lm_head, scalars)
        dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=(0.8, 0.96), eps=1e-10, weight_decay=0.01),
        dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.001),
        dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01),
        dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05),
        dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
    ]
    # Muon groups (matrix params, grouped by shape for stacking)
    # sorted() makes the group order deterministic across runs.
    for shape in sorted({p.shape for p in matrix_params}):
        group_params = [p for p in matrix_params if p.shape == shape]
        param_groups.append(dict(
            kind='muon', params=group_params, lr=matrix_lr,
            momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay,
        ))

    Factory = DistMuonAdamW if ddp else MuonAdamW
    optimizer = Factory(param_groups)
    # Remember the starting LR of each group.
    # NOTE(review): presumably read by an LR schedule elsewhere - confirm.
    for group in optimizer.param_groups:
        group["initial_lr"] = group["lr"]
    return optimizer
|
| 397 |
+
|
| 398 |
+
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
    """Run the model on token ids and return either the loss or the logits.

    Args:
        idx: integer token ids of size (B, T).
        targets: optional target ids; when given, the cross-entropy loss is
            returned (targets equal to -1 are ignored).
        kv_cache: optional inference cache. Its current position offsets the
            rotary embeddings, and it is passed down to every block.
        loss_reduction: ``reduction`` argument forwarded to F.cross_entropy.

    Returns:
        The loss when ``targets`` is given; otherwise fp32 logits of size
        (B, T, vocab_size), soft-capped to the range (-15, 15).

    Raises:
        AssertionError: if the positions ``T0 .. T0 + T`` do not fit in the
            precomputed rotary cache, or the rotary buffers are on the wrong
            device or in the wrong dtype.
    """
    B, T = idx.size()

    # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
    T0 = 0 if kv_cache is None else kv_cache.get_pos()

    # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
    # The check must include the cache offset: slicing past the end would silently return a
    # short (or empty) slice that then broadcasts incorrectly or fails far from the cause.
    assert T0 + T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T0 + T} > {self.cos.size(1)}"
    assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
    assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}"
    cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length

    # Forward the trunk of the Transformer
    x = self.transformer.wte(idx) # embed current token
    x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
    x = norm(x)
    x0 = x # save initial normalized embedding for x0 residual
    for i, block in enumerate(self.transformer.h):
        # per-layer learnable mix of the running residual stream and the initial embedding
        x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
        # only layers registered in value_embeds (keyed by str(layer index)) get a value embedding
        ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
        x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
    x = norm(x)

    # Forward the lm_head (compute logits)
    softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
    logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
    logits = logits[..., :self.config.vocab_size] # slice to remove padding
    logits = logits.float() # switch to fp32 for logit softcap and loss computation
    logits = softcap * torch.tanh(logits / softcap) # squash the logits

    if targets is not None:
        # training: given the targets, compute and return the loss
        # TODO experiment with chunked cross-entropy?
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
        return loss
    else:
        # inference: just return the logits directly
        return logits
|
| 435 |
+
|
| 436 |
+
@torch.inference_mode()
def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
    """
    Naive autoregressive streaming inference.

    Kept deliberately simple:
    - batch size is 1
    - ``tokens`` is a plain Python list of ints and each yielded token is an int
    - the whole sequence is re-run through forward() for every new token

    With ``temperature > 0`` the next token is sampled from a generator seeded
    with ``seed``; otherwise the arg-max token is taken. A positive ``top_k``
    restricts the choice to the k highest-scoring tokens.
    Yields ``max_tokens`` tokens, one at a time.
    """
    assert isinstance(tokens, list)
    device = self.get_device()
    sampling = temperature > 0
    use_top_k = top_k is not None and top_k > 0

    rng = None
    if sampling:
        rng = torch.Generator(device=device)
        rng.manual_seed(seed)

    ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
    for _ in range(max_tokens):
        last_logits = self.forward(ids)[:, -1, :] # (B, vocab_size)
        if use_top_k:
            kept, _ = torch.topk(last_logits, min(top_k, last_logits.size(-1)))
            # everything scoring below the k-th best value is ruled out
            last_logits[last_logits < kept[:, [-1]]] = -float('Inf')
        if sampling:
            probs = F.softmax(last_logits / temperature, dim=-1)
            next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
        else:
            next_ids = torch.argmax(last_logits, dim=-1, keepdim=True)
        ids = torch.cat((ids, next_ids), dim=1)
        yield next_ids.item()
|
nanochat/logo.svg
ADDED
|
|
nanochat/optim.py
ADDED
|
@@ -0,0 +1,533 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A nice and efficient mixed AdamW/Muon Combined Optimizer.
|
| 3 |
+
Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon.
|
| 4 |
+
Two versions are provided (MuonAdamW, DistMuonAdamW), for single GPU and distributed.
|
| 5 |
+
|
| 6 |
+
Adapted from: https://github.com/KellerJordan/modded-nanogpt
|
| 7 |
+
Further contributions from @karpathy and @chrisjmccormick.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.distributed as dist
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
|
| 14 |
+
# -----------------------------------------------------------------------------
|
| 15 |
+
"""
|
| 16 |
+
Good old AdamW optimizer, fused kernel.
|
| 17 |
+
https://arxiv.org/abs/1711.05101
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
@torch.compile(dynamic=False, fullgraph=True)
def adamw_step_fused(
    p: Tensor,              # (32768, 768) - parameter tensor (example shape), updated in place
    grad: Tensor,           # (32768, 768) - gradient, same shape as p
    exp_avg: Tensor,        # (32768, 768) - first moment, same shape as p, updated in place
    exp_avg_sq: Tensor,     # (32768, 768) - second moment, same shape as p, updated in place
    step_t: Tensor,         # () - 0-D CPU tensor, step count
    lr_t: Tensor,           # () - 0-D CPU tensor, learning rate
    beta1_t: Tensor,        # () - 0-D CPU tensor, beta1
    beta2_t: Tensor,        # () - 0-D CPU tensor, beta2
    eps_t: Tensor,          # () - 0-D CPU tensor, epsilon
    wd_t: Tensor,           # () - 0-D CPU tensor, weight decay
) -> None:
    """
    Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
    All in one compiled graph to eliminate Python overhead between ops.
    The 0-D CPU tensors avoid recompilation when hyperparameter values change.

    Mutates p, exp_avg and exp_avg_sq in place and returns nothing.
    step_t is used as the exponent of the bias corrections.
    NOTE(review): presumably the caller increments the step count before
    calling, so it is >= 1 here - confirm (a step of 0 would make both
    corrections zero).
    """
    # Weight decay (decoupled, applied before the update)
    p.mul_(1 - lr_t * wd_t)
    # Update running averages (lerp_ is cleaner and fuses well)
    # lerp_(target, w) computes self + w * (target - self), i.e. an exponential moving average
    exp_avg.lerp_(grad, 1 - beta1_t)
    exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
    # Bias corrections
    bias1 = 1 - beta1_t ** step_t
    bias2 = 1 - beta2_t ** step_t
    # Compute update and apply
    denom = (exp_avg_sq / bias2).sqrt() + eps_t
    step_size = lr_t / bias1
    p.add_(exp_avg / denom, alpha=-step_size)
|
| 50 |
+
|
| 51 |
+
# -----------------------------------------------------------------------------
|
| 52 |
+
"""
|
| 53 |
+
Muon optimizer adapted and simplified from modded-nanogpt.
|
| 54 |
+
https://github.com/KellerJordan/modded-nanogpt
|
| 55 |
+
|
| 56 |
+
Background:
|
| 57 |
+
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
|
| 58 |
+
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
|
| 59 |
+
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
| 60 |
+
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
| 61 |
+
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
| 62 |
+
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
| 63 |
+
performance at all relative to UV^T, where USV^T = G is the SVD.
|
| 64 |
+
|
| 65 |
+
Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
|
| 66 |
+
Polar Express Sign Method for orthogonalization.
|
| 67 |
+
https://arxiv.org/pdf/2505.16932
|
| 68 |
+
by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
|
| 69 |
+
|
| 70 |
+
NorMuon variance reduction: per-neuron/column adaptive learning rate that normalizes
|
| 71 |
+
update scales after orthogonalization (Muon's output has non-uniform scales across neurons).
|
| 72 |
+
https://arxiv.org/pdf/2510.05491
|
| 73 |
+
|
| 74 |
+
Some of the changes in nanochat implementation:
|
| 75 |
+
- Uses a simpler, more general approach to parameter grouping and stacking
|
| 76 |
+
- Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
|
| 77 |
+
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
# From https://arxiv.org/pdf/2505.16932
# One (a, b, c) tuple per iteration, consumed in order by muon_step_fused via
# polar_express_coeffs[:ns_steps]. Each iteration forms A from X and its transpose,
# B = b * A + c * (A @ A), and updates X to a * X + X @ B (tall) or a * X + B @ X (wide).
polar_express_coeffs = [
    (8.156554524902461, -22.48329292557795, 15.878769915207462),
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]
|
| 89 |
+
|
| 90 |
+
@torch.compile(dynamic=False, fullgraph=True)
def muon_step_fused(
    stacked_grads: Tensor,           # e.g. (12, 768, 3072) - stacked gradients (overwritten in place)
    stacked_params: Tensor,          # e.g. (12, 768, 3072) - stacked parameters (updated in place)
    momentum_buffer: Tensor,         # e.g. (12, 768, 3072) - first moment buffer (updated in place)
    second_momentum_buffer: Tensor,  # e.g. (12, 768, 1) or (12, 1, 3072) - factored second moment
    momentum_t: Tensor,              # () - 0-D CPU tensor, momentum coefficient
    lr_t: Tensor,                    # () - 0-D CPU tensor, learning rate
    wd_t: Tensor,                    # () - 0-D CPU tensor, weight decay
    beta2_t: Tensor,                 # () - 0-D CPU tensor, beta2 for second moment
    ns_steps: int,                   # e.g. 5 - number of Polar Express iterations
    red_dim: int,                    # -1 or -2 - reduction dimension for variance
) -> None:
    """
    Run one fused Muon update on a stack of same-shaped matrices.

    Pipeline: Nesterov momentum -> Polar Express orthogonalization ->
    variance reduction -> cautious weight decay + parameter update.
    Everything lives in a single compiled graph to avoid Python overhead between ops.
    Hyperparameters arrive as 0-D CPU tensors so that changing their values does not
    trigger a recompilation.

    Mutates in place: `momentum_buffer`, `second_momentum_buffer`, `stacked_params`,
    and also `stacked_grads` (it is reused as scratch for the Nesterov blend).
    Returns nothing.
    """

    # ---- Nesterov momentum -------------------------------------------------
    mu = momentum_t.to(stacked_grads.dtype)
    momentum_buffer.lerp_(stacked_grads, 1 - mu)
    # In-place blend: stacked_grads now holds the Nesterov look-ahead direction.
    update = stacked_grads.lerp_(momentum_buffer, mu)

    # ---- Polar Express orthogonalization (run in bfloat16) ------------------
    is_tall = update.size(-2) > update.size(-1)
    ortho = update.bfloat16()
    # Normalize by a slightly inflated Frobenius norm (per matrix) before iterating.
    fro_norm = ortho.norm(dim=(-2, -1), keepdim=True)
    ortho = ortho / (fro_norm * 1.01 + 1e-6)
    for a, b, c in polar_express_coeffs[:ns_steps]:
        if is_tall:
            # Tall matrix: form the Gram matrix on the smaller (column) side.
            gram = ortho.mT @ ortho
            poly = b * gram + c * (gram @ gram)
            ortho = a * ortho + ortho @ poly
        else:
            # Wide (or square) matrix: original formulation.
            gram = ortho @ ortho.mT
            poly = b * gram + c * (gram @ gram)
            ortho = a * ortho + poly @ ortho
    update = ortho

    # ---- Variance reduction -------------------------------------------------
    beta2 = beta2_t.to(update.dtype)
    mean_sq = update.float().square().mean(dim=red_dim, keepdim=True)
    n_reduced = update.size(red_dim)
    # Frobenius norm of each matrix before rescaling (mean * count == sum of squares).
    norm_before_sq = mean_sq.sum(dim=(-2, -1), keepdim=True) * n_reduced
    norm_before = norm_before_sq.sqrt()
    second_momentum_buffer.lerp_(mean_sq.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    inv_rms = second_momentum_buffer.clamp_min(1e-10).rsqrt()
    rescaled_sq = (mean_sq * n_reduced) * inv_rms.float().square()
    norm_after = rescaled_sq.sum(dim=(-2, -1), keepdim=True).sqrt()
    # Rescale so the overall per-matrix norm is the same as before the adaptive scaling.
    scale = inv_rms * (norm_before / norm_after.clamp_min(1e-10))
    update = update * scale.to(update.dtype)

    # ---- Cautious weight decay + parameter update ---------------------------
    lr = lr_t.to(update.dtype)
    wd = wd_t.to(update.dtype)
    # Decay only the entries where the update and the parameter have the same sign.
    aligned = (update * stacked_params) >= 0
    stacked_params.sub_(lr * update + lr * wd * stacked_params * aligned)
|
| 147 |
+
|
| 148 |
+
# -----------------------------------------------------------------------------
|
| 149 |
+
# Single GPU version of the MuonAdamW optimizer.
|
| 150 |
+
# Used mostly for reference, debugging and testing.
|
| 151 |
+
|
| 152 |
+
class MuonAdamW(torch.optim.Optimizer):
    """
    Combined optimizer: Muon for 2D matrix params, AdamW for others, single GPU version.

    AdamW - Fused AdamW optimizer step.

    Muon - MomentUm Orthogonalized by Newton-schulz
    https://kellerjordan.github.io/posts/muon/

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - The Muon optimizer should not be used for the embedding layer, the final fully connected layer,
      or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.

    Arguments:
        param_groups: List of dicts, each containing:
            - 'params': List of parameters
            - 'kind': 'adamw' or 'muon'
            - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
            - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
              ('beta2' may be None, in which case 0.0 is used)
    """

    def __init__(self, param_groups: list[dict]):
        super().__init__(param_groups, defaults={})
        # 0-D CPU tensors to avoid torch.compile recompilation when values change
        # AdamW tensors
        self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        # Muon tensors
        self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
        self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")

    def _step_adamw(self, group: dict) -> None:
        """
        AdamW update for each param in the group individually.
        Lazy init the state, fill in all 0-D tensors, call the fused kernel.
        Params whose .grad is None are skipped.
        """
        params_with_grad = [p for p in group['params'] if p.grad is not None]
        if not params_with_grad:
            return  # nothing to update; leave the shared 0-D tensors untouched

        # Group-level hyperparameters are the same for every param in the group,
        # so fill their 0-D tensors once instead of once per parameter.
        self._adamw_lr_t.fill_(group['lr'])
        self._adamw_beta1_t.fill_(group['betas'][0])
        self._adamw_beta2_t.fill_(group['betas'][1])
        self._adamw_eps_t.fill_(group['eps'])
        self._adamw_wd_t.fill_(group['weight_decay'])

        for p in params_with_grad:
            state = self.state[p]

            # State init
            if not state:
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
            state['step'] += 1

            # The step count is tracked per parameter, so it is refreshed every iteration.
            self._adamw_step_t.fill_(state['step'])

            # Fused update: weight_decay -> momentum -> bias_correction -> param_update
            adamw_step_fused(
                p, p.grad, state['exp_avg'], state['exp_avg_sq'],
                self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
                self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
            )

    def _step_muon(self, group: dict) -> None:
        """
        Muon update for all params in the group (stacked for efficiency).
        Lazy init the state, fill in all 0-D tensors, call the fused kernel.

        NOTE: assumes every param in the group has the same shape and a non-None .grad
        (torch.stack is applied to both the params and their grads).
        """
        params: list[Tensor] = group['params']
        if not params:
            return

        # Get or create group-level buffers (stored in first param's state for convenience)
        p = params[0]
        state = self.state[p]
        num_params = len(params)
        shape, device, dtype = p.shape, p.device, p.dtype
        # Tall-or-square matrices keep one second-moment entry per row (reduce over
        # the last dim); wide matrices keep one per column (reduce over the rows).
        tall_or_square = shape[-2] >= shape[-1]

        # Momentum for every individual parameter
        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
        momentum_buffer = state["momentum_buffer"]

        # Second momentum buffer is factored, either per-row or per-column
        if "second_momentum_buffer" not in state:
            state_shape = (num_params, shape[-2], 1) if tall_or_square else (num_params, 1, shape[-1])
            state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
        second_momentum_buffer = state["second_momentum_buffer"]
        red_dim = -1 if tall_or_square else -2

        # Stack grads and params (NOTE: this assumes all params have the same shape)
        stacked_grads = torch.stack([p.grad for p in params])
        stacked_params = torch.stack(params)

        # Fill all the 0-D tensors with current values
        self._muon_momentum_t.fill_(group["momentum"])
        self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
        # Learning rate is scaled up by sqrt(rows / cols) for tall matrices.
        self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
        self._muon_wd_t.fill_(group["weight_decay"])

        # Single fused kernel: momentum -> polar_express -> variance_reduction -> update
        muon_step_fused(
            stacked_grads,
            stacked_params,
            momentum_buffer,
            second_momentum_buffer,
            self._muon_momentum_t,
            self._muon_lr_t,
            self._muon_wd_t,
            self._muon_beta2_t,
            group["ns_steps"],
            red_dim,
        )

        # Copy back to original params (stacked_params is a copy, not a view)
        torch._foreach_copy_(params, list(stacked_params.unbind(0)))

    @torch.no_grad()
    def step(self):
        """Run one optimizer step, dispatching each param group on its 'kind'.

        Raises:
            ValueError: if a group's 'kind' is neither 'adamw' nor 'muon'.
        """
        for group in self.param_groups:
            if group['kind'] == 'adamw':
                self._step_adamw(group)
            elif group['kind'] == 'muon':
                self._step_muon(group)
            else:
                raise ValueError(f"Unknown optimizer kind: {group['kind']}")
|
| 292 |
+
|
| 293 |
+
# -----------------------------------------------------------------------------
|
| 294 |
+
# Distributed version of the MuonAdamW optimizer.
|
| 295 |
+
# Used for training on multiple GPUs.
|
| 296 |
+
|
| 297 |
+
class DistMuonAdamW(torch.optim.Optimizer):
|
| 298 |
+
"""
|
| 299 |
+
Combined distributed optimizer: Muon for 2D matrix params, AdamW for others.
|
| 300 |
+
|
| 301 |
+
See MuonAdamW for the algorithmic details of each optimizer. This class adds
|
| 302 |
+
distributed communication to enable multi-GPU training without PyTorch DDP.
|
| 303 |
+
|
| 304 |
+
Design Goals:
|
| 305 |
+
- Overlap communication with computation (async ops)
|
| 306 |
+
- Minimize memory by sharding optimizer states across ranks (ZeRO-2 style)
|
| 307 |
+
- Batch small tensors into single comm ops where possible
|
| 308 |
+
|
| 309 |
+
Communication Pattern (3-phase async):
|
| 310 |
+
We use a 3-phase structure to maximize overlap between communication and compute:
|
| 311 |
+
|
| 312 |
+
Phase 1: Launch all async reduce ops
|
| 313 |
+
- Kick off all reduce_scatter/all_reduce operations
|
| 314 |
+
- Don't wait - let them run in background while we continue
|
| 315 |
+
|
| 316 |
+
Phase 2: Wait for reduces, compute updates, launch gathers
|
| 317 |
+
- For each group: wait for its reduce, compute the update, launch gather
|
| 318 |
+
- By processing groups in order, earlier gathers run while later computes happen
|
| 319 |
+
|
| 320 |
+
Phase 3: Wait for gathers, copy back
|
| 321 |
+
- Wait for all gathers to complete
|
| 322 |
+
- Copy updated params back to original tensors (Muon only)
|
| 323 |
+
|
| 324 |
+
AdamW Communication (ZeRO-2 style):
|
| 325 |
+
- Small params (<1024 elements): all_reduce gradients, update full param on each rank.
|
| 326 |
+
Optimizer state is replicated but these params are tiny (scalars, biases).
|
| 327 |
+
- Large params: reduce_scatter gradients so each rank gets 1/N of the grad, update
|
| 328 |
+
only that slice, then all_gather the updated slices. Optimizer state (exp_avg,
|
| 329 |
+
exp_avg_sq) is sharded - each rank only stores state for its slice.
|
| 330 |
+
Requires param.shape[0] divisible by world_size.
|
| 331 |
+
|
| 332 |
+
Muon Communication (stacked + chunked):
|
| 333 |
+
- All params in a Muon group must have the same shape (caller's responsibility).
|
| 334 |
+
- Stack all K params into a single (K, *shape) tensor for efficient comm.
|
| 335 |
+
- Divide K params across N ranks: each rank "owns" ceil(K/N) params.
|
| 336 |
+
- reduce_scatter the stacked grads so each rank gets its chunk.
|
| 337 |
+
- Each rank computes Muon update only for params it owns.
|
| 338 |
+
- all_gather the updated params back to all ranks.
|
| 339 |
+
- Optimizer state (momentum_buffer, second_momentum_buffer) is sharded by chunk.
|
| 340 |
+
- Padding: if K doesn't divide evenly, we zero-pad to (ceil(K/N) * N) for comm,
|
| 341 |
+
then ignore the padding when copying back.
|
| 342 |
+
|
| 343 |
+
Buffer Reuse:
|
| 344 |
+
- For Muon, we allocate stacked_grads for reduce_scatter input, then reuse the
|
| 345 |
+
same buffer as the output for all_gather (stacked_params). This saves memory
|
| 346 |
+
since we don't need both buffers simultaneously.
|
| 347 |
+
|
| 348 |
+
Arguments:
|
| 349 |
+
param_groups: List of dicts, each containing:
|
| 350 |
+
- 'params': List of parameters
|
| 351 |
+
- 'kind': 'adamw' or 'muon'
|
| 352 |
+
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
|
| 353 |
+
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
|
| 354 |
+
"""
|
| 355 |
+
def __init__(self, param_groups: list[dict]):
|
| 356 |
+
super().__init__(param_groups, defaults={})
|
| 357 |
+
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
| 358 |
+
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 359 |
+
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 360 |
+
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 361 |
+
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 362 |
+
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 363 |
+
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 364 |
+
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 365 |
+
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 366 |
+
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 367 |
+
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 368 |
+
|
| 369 |
+
def _reduce_adamw(self, group: dict, world_size: int) -> dict:
|
| 370 |
+
"""Launch async reduce ops for AdamW group. Returns info dict with per-param infos."""
|
| 371 |
+
param_infos = {}
|
| 372 |
+
for p in group['params']:
|
| 373 |
+
grad = p.grad
|
| 374 |
+
if p.numel() < 1024:
|
| 375 |
+
# Small params: all_reduce (no scatter/gather needed)
|
| 376 |
+
future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
| 377 |
+
param_infos[p] = dict(future=future, grad_slice=grad, is_small=True)
|
| 378 |
+
else:
|
| 379 |
+
# Large params: reduce_scatter
|
| 380 |
+
assert grad.shape[0] % world_size == 0, f"AdamW reduce_scatter requires shape[0] ({grad.shape[0]}) divisible by world_size ({world_size})"
|
| 381 |
+
rank_size = grad.shape[0] // world_size
|
| 382 |
+
grad_slice = torch.empty_like(grad[:rank_size])
|
| 383 |
+
future = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
| 384 |
+
param_infos[p] = dict(future=future, grad_slice=grad_slice, is_small=False)
|
| 385 |
+
return dict(param_infos=param_infos)
|
| 386 |
+
|
| 387 |
+
def _reduce_muon(self, group: dict, world_size: int) -> dict:
|
| 388 |
+
"""Launch async reduce op for Muon group. Returns info dict."""
|
| 389 |
+
params = group['params']
|
| 390 |
+
chunk_size = (len(params) + world_size - 1) // world_size
|
| 391 |
+
padded_num_params = chunk_size * world_size
|
| 392 |
+
p = params[0]
|
| 393 |
+
shape, device, dtype = p.shape, p.device, p.dtype
|
| 394 |
+
|
| 395 |
+
# Stack grads and zero-pad to padded_num_params
|
| 396 |
+
grad_stack = torch.stack([p.grad for p in params])
|
| 397 |
+
stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
|
| 398 |
+
stacked_grads[:len(params)].copy_(grad_stack)
|
| 399 |
+
if len(params) < padded_num_params:
|
| 400 |
+
stacked_grads[len(params):].zero_()
|
| 401 |
+
|
| 402 |
+
# Reduce_scatter to get this rank's chunk
|
| 403 |
+
grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
| 404 |
+
future = dist.reduce_scatter_tensor(grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
| 405 |
+
|
| 406 |
+
return dict(future=future, grad_chunk=grad_chunk, stacked_grads=stacked_grads, chunk_size=chunk_size)
|
| 407 |
+
|
| 408 |
+
def _compute_adamw(self, group: dict, info: dict, gather_list: list, rank: int, world_size: int) -> None:
|
| 409 |
+
"""Wait for reduce, compute AdamW updates, launch gathers for large params."""
|
| 410 |
+
param_infos = info['param_infos']
|
| 411 |
+
for p in group['params']:
|
| 412 |
+
pinfo = param_infos[p]
|
| 413 |
+
pinfo['future'].wait()
|
| 414 |
+
grad_slice = pinfo['grad_slice']
|
| 415 |
+
state = self.state[p]
|
| 416 |
+
|
| 417 |
+
# For small params, operate on full param; for large, operate on slice
|
| 418 |
+
if pinfo['is_small']:
|
| 419 |
+
p_slice = p
|
| 420 |
+
else:
|
| 421 |
+
rank_size = p.shape[0] // world_size
|
| 422 |
+
p_slice = p[rank * rank_size:(rank + 1) * rank_size]
|
| 423 |
+
|
| 424 |
+
# State init
|
| 425 |
+
if not state:
|
| 426 |
+
state['step'] = 0
|
| 427 |
+
state['exp_avg'] = torch.zeros_like(p_slice)
|
| 428 |
+
state['exp_avg_sq'] = torch.zeros_like(p_slice)
|
| 429 |
+
state['step'] += 1
|
| 430 |
+
|
| 431 |
+
# Fill 0-D tensors and run fused kernel
|
| 432 |
+
self._adamw_step_t.fill_(state['step'])
|
| 433 |
+
self._adamw_lr_t.fill_(group['lr'])
|
| 434 |
+
self._adamw_beta1_t.fill_(group['betas'][0])
|
| 435 |
+
self._adamw_beta2_t.fill_(group['betas'][1])
|
| 436 |
+
self._adamw_eps_t.fill_(group['eps'])
|
| 437 |
+
self._adamw_wd_t.fill_(group['weight_decay'])
|
| 438 |
+
adamw_step_fused(
|
| 439 |
+
p_slice, grad_slice, state['exp_avg'], state['exp_avg_sq'],
|
| 440 |
+
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
|
| 441 |
+
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
# Large params need all_gather
|
| 445 |
+
if not pinfo['is_small']:
|
| 446 |
+
future = dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()
|
| 447 |
+
gather_list.append(dict(future=future, params=None))
|
| 448 |
+
|
| 449 |
+
def _compute_muon(self, group: dict, info: dict, gather_list: list, rank: int) -> None:
|
| 450 |
+
"""Wait for reduce, compute Muon updates, launch gather."""
|
| 451 |
+
info['future'].wait()
|
| 452 |
+
params = group['params']
|
| 453 |
+
chunk_size = info['chunk_size']
|
| 454 |
+
grad_chunk = info['grad_chunk']
|
| 455 |
+
p = params[0]
|
| 456 |
+
shape, device, dtype = p.shape, p.device, p.dtype
|
| 457 |
+
|
| 458 |
+
# How many params does this rank own?
|
| 459 |
+
start_idx = rank * chunk_size
|
| 460 |
+
num_owned = min(chunk_size, max(0, len(params) - start_idx))
|
| 461 |
+
|
| 462 |
+
# Get or create group-level state
|
| 463 |
+
state = self.state[p]
|
| 464 |
+
if "momentum_buffer" not in state:
|
| 465 |
+
state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
|
| 466 |
+
if "second_momentum_buffer" not in state:
|
| 467 |
+
state_shape = (chunk_size, shape[-2], 1) if shape[-2] >= shape[-1] else (chunk_size, 1, shape[-1])
|
| 468 |
+
state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
|
| 469 |
+
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
| 470 |
+
|
| 471 |
+
# Build output buffer for all_gather
|
| 472 |
+
updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
|
| 473 |
+
|
| 474 |
+
if num_owned > 0:
|
| 475 |
+
owned_params = [params[start_idx + i] for i in range(num_owned)]
|
| 476 |
+
stacked_owned = torch.stack(owned_params)
|
| 477 |
+
|
| 478 |
+
# Fill 0-D tensors and run fused kernel
|
| 479 |
+
self._muon_momentum_t.fill_(group["momentum"])
|
| 480 |
+
self._muon_beta2_t.fill_(group["beta2"])
|
| 481 |
+
self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
| 482 |
+
self._muon_wd_t.fill_(group["weight_decay"])
|
| 483 |
+
muon_step_fused(
|
| 484 |
+
grad_chunk[:num_owned], stacked_owned,
|
| 485 |
+
state["momentum_buffer"][:num_owned], state["second_momentum_buffer"][:num_owned],
|
| 486 |
+
self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t, self._muon_beta2_t,
|
| 487 |
+
group["ns_steps"], red_dim,
|
| 488 |
+
)
|
| 489 |
+
updated_params[:num_owned].copy_(stacked_owned)
|
| 490 |
+
|
| 491 |
+
if num_owned < chunk_size:
|
| 492 |
+
updated_params[num_owned:].zero_()
|
| 493 |
+
|
| 494 |
+
# Reuse stacked_grads buffer for all_gather output
|
| 495 |
+
stacked_params = info["stacked_grads"]
|
| 496 |
+
future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future()
|
| 497 |
+
gather_list.append(dict(future=future, stacked_params=stacked_params, params=params))
|
| 498 |
+
|
| 499 |
+
def _finish_gathers(self, gather_list: list) -> None:
|
| 500 |
+
"""Wait for all gathers and copy Muon params back."""
|
| 501 |
+
for info in gather_list:
|
| 502 |
+
info["future"].wait()
|
| 503 |
+
if info["params"] is not None:
|
| 504 |
+
# Muon: copy from stacked buffer back to individual params
|
| 505 |
+
torch._foreach_copy_(info["params"], list(info["stacked_params"][:len(info["params"])].unbind(0)))
|
| 506 |
+
|
| 507 |
+
@torch.no_grad()
|
| 508 |
+
def step(self):
|
| 509 |
+
rank = dist.get_rank()
|
| 510 |
+
world_size = dist.get_world_size()
|
| 511 |
+
|
| 512 |
+
# Phase 1: launch all async reduce ops
|
| 513 |
+
reduce_infos: list[dict] = []
|
| 514 |
+
for group in self.param_groups:
|
| 515 |
+
if group['kind'] == 'adamw':
|
| 516 |
+
reduce_infos.append(self._reduce_adamw(group, world_size))
|
| 517 |
+
elif group['kind'] == 'muon':
|
| 518 |
+
reduce_infos.append(self._reduce_muon(group, world_size))
|
| 519 |
+
else:
|
| 520 |
+
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
|
| 521 |
+
|
| 522 |
+
# Phase 2: wait for reduces, compute updates, launch gathers
|
| 523 |
+
gather_list: list[dict] = []
|
| 524 |
+
for group, info in zip(self.param_groups, reduce_infos):
|
| 525 |
+
if group['kind'] == 'adamw':
|
| 526 |
+
self._compute_adamw(group, info, gather_list, rank, world_size)
|
| 527 |
+
elif group['kind'] == 'muon':
|
| 528 |
+
self._compute_muon(group, info, gather_list, rank)
|
| 529 |
+
else:
|
| 530 |
+
raise ValueError(f"Unknown optimizer kind: {group['kind']}")
|
| 531 |
+
|
| 532 |
+
# Phase 3: wait for gathers, copy back
|
| 533 |
+
self._finish_gathers(gather_list)
|
nanochat/tokenizer.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tokenizer module — patched for Victorian LLM HuggingFace Space.
|
| 3 |
+
|
| 4 |
+
Delegates to tokenizer_wrapper.py which provides the VictorianTokenizer class.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
# Ensure the app root is on the path so tokenizer_wrapper can be imported.
# The app root is taken to be the parent of this file's directory; it is inserted at
# index 0 so it is searched before every other sys.path entry.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 12 |
+
|
| 13 |
+
from tokenizer_wrapper import get_tokenizer, get_token_bytes
|
| 14 |
+
__all__ = ["get_tokenizer", "get_token_bytes"]
|
nanochat/ui.html
ADDED
|
@@ -0,0 +1,566 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0, viewport-fit=cover">
|
| 6 |
+
<title>NanoChat</title>
|
| 7 |
+
<link rel="icon" type="image/svg+xml" href="/logo.svg">
|
| 8 |
+
<style>
|
| 9 |
+
:root {
|
| 10 |
+
color-scheme: light;
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
* {
|
| 14 |
+
box-sizing: border-box;
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
html, body{
|
| 18 |
+
height: 100%;
|
| 19 |
+
margin: 0;
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
body {
|
| 23 |
+
font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol";
|
| 24 |
+
background-color: #ffffff;
|
| 25 |
+
color: #111827;
|
| 26 |
+
min-height: 100dvh;
|
| 27 |
+
margin: 0;
|
| 28 |
+
display: flex;
|
| 29 |
+
flex-direction: column;
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
.header {
|
| 33 |
+
background-color: #ffffff;
|
| 34 |
+
padding: 1.25rem 1.5rem;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
.header-left {
|
| 38 |
+
display: flex;
|
| 39 |
+
align-items: center;
|
| 40 |
+
gap: 0.75rem;
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
.header-logo {
|
| 44 |
+
height: 32px;
|
| 45 |
+
width: auto;
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
.header h1 {
|
| 49 |
+
font-size: 1.25rem;
|
| 50 |
+
font-weight: 600;
|
| 51 |
+
margin: 0;
|
| 52 |
+
color: #111827;
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
.new-conversation-btn {
|
| 56 |
+
width: 32px;
|
| 57 |
+
height: 32px;
|
| 58 |
+
padding: 0;
|
| 59 |
+
border: 1px solid #e5e7eb;
|
| 60 |
+
border-radius: 0.5rem;
|
| 61 |
+
background-color: #ffffff;
|
| 62 |
+
color: #6b7280;
|
| 63 |
+
cursor: pointer;
|
| 64 |
+
display: flex;
|
| 65 |
+
align-items: center;
|
| 66 |
+
justify-content: center;
|
| 67 |
+
transition: all 0.2s ease;
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
.new-conversation-btn:hover {
|
| 71 |
+
background-color: #f3f4f6;
|
| 72 |
+
border-color: #d1d5db;
|
| 73 |
+
color: #374151;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
.chat-container {
|
| 77 |
+
flex: 1;
|
| 78 |
+
overflow-y: auto;
|
| 79 |
+
background-color: #ffffff;
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
.chat-wrapper {
|
| 83 |
+
max-width: 48rem;
|
| 84 |
+
margin: 0 auto;
|
| 85 |
+
padding: 2rem 1.5rem 3rem;
|
| 86 |
+
display: flex;
|
| 87 |
+
flex-direction: column;
|
| 88 |
+
gap: 0.75rem;
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
.message {
|
| 92 |
+
display: flex;
|
| 93 |
+
justify-content: flex-start;
|
| 94 |
+
margin-bottom: 0.5rem;
|
| 95 |
+
color: #0d0d0d;
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
.message.assistant {
|
| 99 |
+
justify-content: flex-start;
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
.message.user {
|
| 103 |
+
justify-content: flex-end;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
.message-content {
|
| 107 |
+
white-space: pre-wrap;
|
| 108 |
+
line-height: 1.6;
|
| 109 |
+
max-width: 100%;
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
.message.assistant .message-content {
|
| 113 |
+
background: transparent;
|
| 114 |
+
border: none;
|
| 115 |
+
cursor: pointer;
|
| 116 |
+
border-radius: 0.5rem;
|
| 117 |
+
padding: 0.5rem;
|
| 118 |
+
margin-left: -0.5rem;
|
| 119 |
+
transition: background-color 0.2s ease;
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
.message.assistant .message-content:hover {
|
| 123 |
+
background-color: #f9fafb;
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
.message.user .message-content {
|
| 127 |
+
background-color: #f3f4f6;
|
| 128 |
+
border-radius: 1.25rem;
|
| 129 |
+
padding: 0.8rem 1rem;
|
| 130 |
+
max-width: 65%;
|
| 131 |
+
cursor: pointer;
|
| 132 |
+
transition: background-color 0.2s ease;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
.message.user .message-content:hover {
|
| 136 |
+
background-color: #e5e7eb;
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
.message.console .message-content {
|
| 140 |
+
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', 'Courier New', monospace;
|
| 141 |
+
font-size: 0.875rem;
|
| 142 |
+
background-color: #fafafa;
|
| 143 |
+
padding: 0.75rem 1rem;
|
| 144 |
+
color: #374151;
|
| 145 |
+
max-width: 80%;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
.input-container {
|
| 149 |
+
background-color: #ffffff;
|
| 150 |
+
padding: 1rem;
|
| 151 |
+
padding-bottom: calc(1rem + env(safe-area-inset-bottom))
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
.input-wrapper {
|
| 155 |
+
max-width: 48rem;
|
| 156 |
+
margin: 0 auto;
|
| 157 |
+
display: flex;
|
| 158 |
+
gap: 0.75rem;
|
| 159 |
+
align-items: flex-end;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
.chat-input {
|
| 163 |
+
flex: 1;
|
| 164 |
+
padding: 0.8rem 1rem;
|
| 165 |
+
border: 1px solid #d1d5db;
|
| 166 |
+
border-radius: 0.75rem;
|
| 167 |
+
background-color: #ffffff;
|
| 168 |
+
color: #111827;
|
| 169 |
+
font-size: 1rem;
|
| 170 |
+
line-height: 1.5;
|
| 171 |
+
resize: none;
|
| 172 |
+
outline: none;
|
| 173 |
+
min-height: 54px;
|
| 174 |
+
max-height: 200px;
|
| 175 |
+
transition: border-color 0.2s ease, box-shadow 0.2s ease;
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
.chat-input::placeholder {
|
| 179 |
+
color: #9ca3af;
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
.chat-input:focus {
|
| 183 |
+
border-color: #2563eb;
|
| 184 |
+
box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
.send-button {
|
| 188 |
+
flex-shrink: 0;
|
| 189 |
+
padding: 0;
|
| 190 |
+
width: 54px;
|
| 191 |
+
height: 54px;
|
| 192 |
+
border: 1px solid #111827;
|
| 193 |
+
border-radius: 0.75rem;
|
| 194 |
+
background-color: #111827;
|
| 195 |
+
color: #ffffff;
|
| 196 |
+
display: flex;
|
| 197 |
+
align-items: center;
|
| 198 |
+
justify-content: center;
|
| 199 |
+
cursor: pointer;
|
| 200 |
+
transition: background-color 0.2s ease, border-color 0.2s ease, color 0.2s ease;
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
.send-button:hover:not(:disabled) {
|
| 204 |
+
background-color: #2563eb;
|
| 205 |
+
border-color: #2563eb;
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
.send-button:disabled {
|
| 209 |
+
cursor: not-allowed;
|
| 210 |
+
border-color: #d1d5db;
|
| 211 |
+
background-color: #e5e7eb;
|
| 212 |
+
color: #9ca3af;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
.typing-indicator {
|
| 216 |
+
display: inline-block;
|
| 217 |
+
color: #6b7280;
|
| 218 |
+
letter-spacing: 0.15em;
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
.typing-indicator::after {
|
| 222 |
+
content: '···';
|
| 223 |
+
animation: typing 1.4s infinite;
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
@keyframes typing {
|
| 227 |
+
0%, 60%, 100% { opacity: 0.2; }
|
| 228 |
+
30% { opacity: 1; }
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
.error-message {
|
| 232 |
+
background-color: #fee2e2;
|
| 233 |
+
border: 1px solid #fecaca;
|
| 234 |
+
color: #b91c1c;
|
| 235 |
+
padding: 0.75rem 1rem;
|
| 236 |
+
border-radius: 0.75rem;
|
| 237 |
+
margin-top: 0.5rem;
|
| 238 |
+
}
|
| 239 |
+
</style>
|
| 240 |
+
</head>
|
| 241 |
+
<body>
|
| 242 |
+
<div class="header">
|
| 243 |
+
<div class="header-left">
|
| 244 |
+
<button class="new-conversation-btn" onclick="newConversation()" title="New Conversation (Ctrl+Shift+N)">
|
| 245 |
+
<svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
| 246 |
+
<path d="M12 5v14"></path>
|
| 247 |
+
<path d="M5 12h14"></path>
|
| 248 |
+
</svg>
|
| 249 |
+
</button>
|
| 250 |
+
<h1>nanochat</h1>
|
| 251 |
+
</div>
|
| 252 |
+
</div>
|
| 253 |
+
|
| 254 |
+
<div class="chat-container" id="chatContainer">
|
| 255 |
+
<div class="chat-wrapper" id="chatWrapper">
|
| 256 |
+
<!-- Messages will be added here -->
|
| 257 |
+
</div>
|
| 258 |
+
</div>
|
| 259 |
+
|
| 260 |
+
<div class="input-container">
|
| 261 |
+
<div class="input-wrapper">
|
| 262 |
+
<textarea
|
| 263 |
+
id="chatInput"
|
| 264 |
+
class="chat-input"
|
| 265 |
+
placeholder="Ask anything"
|
| 266 |
+
rows="1"
|
| 267 |
+
onkeydown="handleKeyDown(event)"
|
| 268 |
+
></textarea>
|
| 269 |
+
<button id="sendButton" class="send-button" onclick="sendMessage()" disabled>
|
| 270 |
+
<svg width="22" height="22" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
| 271 |
+
<path d="M22 2L11 13"></path>
|
| 272 |
+
<path d="M22 2l-7 20-4-9-9-4 20-7z"></path>
|
| 273 |
+
</svg>
|
| 274 |
+
</button>
|
| 275 |
+
</div>
|
| 276 |
+
</div>
|
| 277 |
+
|
| 278 |
+
<script>
|
| 279 |
+
const API_URL = '';
|
| 280 |
+
const chatContainer = document.getElementById('chatContainer');
|
| 281 |
+
const chatWrapper = document.getElementById('chatWrapper');
|
| 282 |
+
const chatInput = document.getElementById('chatInput');
|
| 283 |
+
const sendButton = document.getElementById('sendButton');
|
| 284 |
+
|
| 285 |
+
let messages = [];
|
| 286 |
+
let isGenerating = false;
|
| 287 |
+
let currentTemperature = 0.8;
|
| 288 |
+
let currentTopK = 50;
|
| 289 |
+
|
| 290 |
+
// Auto-grow the textarea with its content (capped at 200px) and keep the
// send button in sync with whether there is text that could be sent.
chatInput.addEventListener('input', function() {
    const field = this;

    // Reset first so scrollHeight reflects the current content, not the old height.
    field.style.height = 'auto';
    const cappedHeight = Math.min(field.scrollHeight, 200);
    field.style.height = cappedHeight + 'px';

    const hasText = Boolean(field.value.trim());
    sendButton.disabled = isGenerating || !hasText;
});
|
| 295 |
+
|
| 296 |
+
// Keydown handler for the chat textarea: a plain Enter submits the message;
// any other key (including Shift+Enter) is left to the browser's default handling.
function handleKeyDown(event) {
    const isPlainEnter = event.key === 'Enter' && !event.shiftKey;
    if (!isPlainEnter) {
        return;
    }
    event.preventDefault();
    sendMessage();
}
|
| 302 |
+
|
| 303 |
+
// Global shortcut: Ctrl+Shift+N starts a new conversation.
document.addEventListener('keydown', function(event) {
    // Compare case-insensitively: event.key reflects the effective character,
    // so with Caps Lock active Shift+N can be reported as 'n' instead of 'N'.
    const isShortcut = event.ctrlKey
        && event.shiftKey
        && typeof event.key === 'string'
        && event.key.toLowerCase() === 'n';
    if (!isShortcut) {
        return;
    }

    event.preventDefault();
    // Never reset the conversation while a response is still streaming.
    if (!isGenerating) {
        newConversation();
    }
});
|
| 312 |
+
|
| 313 |
+
// Reset the UI and the in-memory history to start a fresh conversation.
//
// Does nothing while a response is streaming: the running stream would
// otherwise keep writing and finally push its reply into the freshly
// cleared `messages` array, corrupting the new conversation. This matches
// the guard already applied by the keyboard shortcut and the click handlers.
function newConversation() {
    if (isGenerating) {
        return;
    }

    messages = [];
    chatWrapper.innerHTML = '';
    chatInput.value = '';
    chatInput.style.height = 'auto';
    sendButton.disabled = false;
    isGenerating = false;
    chatInput.focus();
}
|
| 322 |
+
|
| 323 |
+
// Append a message bubble to the chat and return its content element.
//
// role         - used as a CSS class on the wrapper ('user', 'assistant', 'console', ...)
// content      - plain text; assigned via textContent, so it is never parsed as HTML
// messageIndex - position of this message in `messages`, or null for bubbles
//                that are not clickable
function addMessage(role, content, messageIndex = null) {
    const bubble = document.createElement('div');
    bubble.className = 'message-content';
    bubble.textContent = content;

    // Tags the bubble with its index and attaches a click action that is
    // ignored while a response is being generated.
    const makeClickable = (tooltip, action) => {
        bubble.setAttribute('data-message-index', messageIndex);
        bubble.setAttribute('title', tooltip);
        bubble.addEventListener('click', function() {
            if (isGenerating) {
                return;
            }
            action();
        });
    };

    if (messageIndex !== null) {
        if (role === 'user') {
            // User bubbles: click to edit and restart from that point
            makeClickable('Click to edit and restart from here', () => editMessage(messageIndex));
        }
        if (role === 'assistant') {
            // Assistant bubbles: click to regenerate that response
            makeClickable('Click to regenerate this response', () => regenerateMessage(messageIndex));
        }
    }

    const row = document.createElement('div');
    row.className = `message ${role}`;
    row.appendChild(bubble);
    chatWrapper.appendChild(row);

    // Keep the newest message in view
    chatContainer.scrollTop = chatContainer.scrollHeight;
    return bubble;
}
|
| 359 |
+
|
| 360 |
+
// Put a previous user message back into the input box and discard it and
// everything after it, so the conversation can be restarted from that point.
//
// messageIndex - index of the user message in `messages`; out-of-range
//                indices and non-user messages are ignored.
function editMessage(messageIndex) {
    if (messageIndex < 0 || messageIndex >= messages.length) return;

    const messageToEdit = messages[messageIndex];
    if (messageToEdit.role !== 'user') return;

    // Copy message content to input and resize the textarea to fit it
    chatInput.value = messageToEdit.content;
    chatInput.style.height = 'auto';
    chatInput.style.height = Math.min(chatInput.scrollHeight, 200) + 'px';

    // Remove this message and all subsequent messages from the array
    messages = messages.slice(0, messageIndex);

    // Remove the matching bubble and every bubble after it from the DOM.
    // The DOM also contains bubbles that are not in `messages` (console
    // output, failed responses), so a `messages` index is not a DOM position:
    // locate the bubble through its data-message-index attribute instead and
    // only fall back to the raw index when no tagged bubble is found.
    const allMessages = Array.from(chatWrapper.querySelectorAll('.message'));
    const tagged = chatWrapper.querySelector(
        `.message-content[data-message-index="${messageIndex}"]`
    );
    const taggedRow = tagged ? tagged.closest('.message') : null;
    const taggedPosition = taggedRow ? allMessages.indexOf(taggedRow) : -1;
    const firstToRemove = taggedPosition !== -1 ? taggedPosition : messageIndex;
    for (let i = firstToRemove; i < allMessages.length; i++) {
        allMessages[i].remove();
    }

    // Enable send button and focus input
    sendButton.disabled = false;
    chatInput.focus();
}
|
| 385 |
+
|
| 386 |
+
// Request a completion for the current `messages` history and stream the
// reply into a new assistant bubble.
//
// The response body is read as a stream of "data: <json>" lines. On success
// the full reply is appended to `messages` and the bubble becomes clickable
// for regeneration; on failure an error box is shown inside the bubble and
// `messages` is left unchanged.
async function generateAssistantResponse() {
    isGenerating = true;
    sendButton.disabled = true;

    const assistantContent = addMessage('assistant', '');
    assistantContent.innerHTML = '<span class="typing-indicator"></span>';

    try {
        const response = await fetch(`${API_URL}/chat/completions`, {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
            },
            body: JSON.stringify({
                messages: messages,
                temperature: currentTemperature,
                top_k: currentTopK,
                max_tokens: 512
            }),
        });

        if (!response.ok) {
            throw new Error(`HTTP error! status: ${response.status}`);
        }

        const reader = response.body.getReader();
        const decoder = new TextDecoder();
        let fullResponse = '';
        // Text received so far that does not yet end in a newline. Network
        // chunks can end in the middle of a line, so an incomplete trailing
        // line is kept here until the rest of it arrives.
        let pending = '';
        assistantContent.textContent = '';

        // Handle one complete line of the event stream.
        const consumeLine = (line) => {
            if (!line.startsWith('data: ')) return;
            let data;
            try {
                data = JSON.parse(line.slice(6));
            } catch (e) {
                // Best effort: skip the malformed event but leave a trace.
                console.warn('Skipping malformed stream event:', line);
                return;
            }
            if (data.token) {
                fullResponse += data.token;
                assistantContent.textContent = fullResponse;
                chatContainer.scrollTop = chatContainer.scrollHeight;
            }
        };

        while (true) {
            const { done, value } = await reader.read();
            if (done) break;

            // stream: true keeps a multi-byte UTF-8 character that is split
            // across two chunks intact instead of decoding it to U+FFFD.
            pending += decoder.decode(value, { stream: true });
            const lines = pending.split('\n');
            pending = lines.pop();
            lines.forEach(consumeLine);
        }

        // Flush the decoder and process a final line without a trailing newline.
        pending += decoder.decode();
        if (pending) {
            consumeLine(pending);
        }

        const assistantMessageIndex = messages.length;
        messages.push({ role: 'assistant', content: fullResponse });

        // Add click handler to regenerate this assistant message
        assistantContent.setAttribute('data-message-index', assistantMessageIndex);
        assistantContent.setAttribute('title', 'Click to regenerate this response');
        assistantContent.addEventListener('click', function() {
            if (!isGenerating) {
                regenerateMessage(assistantMessageIndex);
            }
        });

    } catch (error) {
        console.error('Error:', error);
        // Build the error box with textContent so the message is never parsed as HTML.
        const errorBox = document.createElement('div');
        errorBox.className = 'error-message';
        errorBox.textContent = `Error: ${error.message}`;
        assistantContent.textContent = '';
        assistantContent.appendChild(errorBox);
    } finally {
        isGenerating = false;
        sendButton.disabled = !chatInput.value.trim();
    }
}
|
| 458 |
+
|
| 459 |
+
// Discard an assistant message and everything after it, then request a new
// response for the remaining history.
//
// messageIndex - index of the assistant message in `messages`; out-of-range
//                indices and non-assistant messages are ignored.
async function regenerateMessage(messageIndex) {
    if (messageIndex < 0 || messageIndex >= messages.length) return;

    const messageToRegenerate = messages[messageIndex];
    if (messageToRegenerate.role !== 'assistant') return;

    // Remove this message and all subsequent messages from the array
    messages = messages.slice(0, messageIndex);

    // Remove the matching bubble and every bubble after it from the DOM.
    // The DOM also contains bubbles that are not in `messages` (console
    // output, failed responses), so a `messages` index is not a DOM position:
    // locate the bubble through its data-message-index attribute instead and
    // only fall back to the raw index when no tagged bubble is found.
    const allMessages = Array.from(chatWrapper.querySelectorAll('.message'));
    const tagged = chatWrapper.querySelector(
        `.message-content[data-message-index="${messageIndex}"]`
    );
    const taggedRow = tagged ? tagged.closest('.message') : null;
    const taggedPosition = taggedRow ? allMessages.indexOf(taggedRow) : -1;
    const firstToRemove = taggedPosition !== -1 ? taggedPosition : messageIndex;
    for (let i = firstToRemove; i < allMessages.length; i++) {
        allMessages[i].remove();
    }

    // Regenerate the assistant response
    await generateAssistantResponse();
}
|
| 478 |
+
|
| 479 |
+
// Execute a slash command typed into the chat input.
//
// Recognised commands: /temperature [value], /topk [value], /clear, /help.
// Feedback is shown as a 'console' bubble. Returns true when the command was
// recognised (even if its argument was invalid) and false otherwise.
function handleSlashCommand(command) {
    const [rawName, arg] = command.trim().split(/\s+/);
    const cmd = rawName.toLowerCase();
    const say = (text) => addMessage('console', text);

    switch (cmd) {
        case '/temperature': {
            if (arg === undefined) {
                say(`Current temperature: ${currentTemperature}`);
                return true;
            }
            const temp = parseFloat(arg);
            const inRange = !isNaN(temp) && temp >= 0 && temp <= 2;
            if (inRange) {
                currentTemperature = temp;
                say(`Temperature set to ${currentTemperature}`);
            } else {
                say('Invalid temperature. Must be between 0.0 and 2.0');
            }
            return true;
        }

        case '/topk': {
            if (arg === undefined) {
                say(`Current top-k: ${currentTopK}`);
                return true;
            }
            const topk = parseInt(arg);
            const inRange = !isNaN(topk) && topk >= 1 && topk <= 200;
            if (inRange) {
                currentTopK = topk;
                say(`Top-k set to ${currentTopK}`);
            } else {
                say('Invalid top-k. Must be between 1 and 200');
            }
            return true;
        }

        case '/clear':
            newConversation();
            return true;

        case '/help': {
            const helpLines = [
                'Available commands:',
                '/temperature - Show current temperature',
                '/temperature <value> - Set temperature (0.0-2.0)',
                '/topk - Show current top-k',
                '/topk <value> - Set top-k (1-200)',
                '/clear - Clear conversation',
                '/help - Show this help message',
            ];
            say(helpLines.join('\n'));
            return true;
        }

        default:
            return false;
    }
}
|
| 527 |
+
|
| 528 |
+
// Submit the text in the chat input.
//
// Input starting with '/' is treated as a slash command and is never sent to
// the model; anything else is appended to the history as a user message and
// a response is requested. Does nothing for empty input or while a response
// is being generated.
async function sendMessage() {
    const message = chatInput.value.trim();
    if (!message || isGenerating) return;

    // The input is consumed in every case below
    chatInput.value = '';
    chatInput.style.height = 'auto';

    // Handle slash commands
    if (message.startsWith('/')) {
        const recognised = handleSlashCommand(message);
        if (!recognised) {
            // Report unknown commands instead of dropping the input silently.
            const commandName = message.split(/\s+/)[0];
            addMessage('console', `Unknown command: ${commandName}. Type /help for available commands.`);
        }
        return;
    }

    const userMessageIndex = messages.length;
    messages.push({ role: 'user', content: message });
    addMessage('user', message, userMessageIndex);

    await generateAssistantResponse();
}
|
| 549 |
+
|
| 550 |
+
// Page start-up: enable the send button and put the cursor in the chat input.
sendButton.disabled = false;
chatInput.focus();

// Probe the backend once on load; log its status, or replace the chat area
// with an error notice when the request or JSON parsing fails.
(async function probeEngine() {
    try {
        const response = await fetch(`${API_URL}/health`);
        const data = await response.json();
        console.log('Engine status:', data);
    } catch (error) {
        console.error('Engine not available:', error);
        chatWrapper.innerHTML = '<div class="error-message">Engine not running. Please start engine.py first.</div>';
    }
})();
|
| 564 |
+
</script>
|
| 565 |
+
</body>
|
| 566 |
+
</html>
|
scripts/__init__.py
ADDED
|
File without changes
|
scripts/chat_web.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Unified web chat server - serves both UI and API from a single FastAPI instance.
|
| 4 |
+
|
| 5 |
+
Uses data parallelism to distribute requests across multiple GPUs. Each GPU loads
|
| 6 |
+
a full copy of the model, and incoming requests are distributed to available workers.
|
| 7 |
+
|
| 8 |
+
Launch examples:
|
| 9 |
+
|
| 10 |
+
- single available GPU (default)
|
| 11 |
+
python -m scripts.chat_web
|
| 12 |
+
|
| 13 |
+
- 4 GPUs
|
| 14 |
+
python -m scripts.chat_web --num-gpus 4
|
| 15 |
+
|
| 16 |
+
To chat, open the URL printed in the console. (If on cloud box, make sure to use public IP)
|
| 17 |
+
|
| 18 |
+
Endpoints:
|
| 19 |
+
GET / - Chat UI
|
| 20 |
+
POST /chat/completions - Chat API (streaming only)
|
| 21 |
+
GET /health - Health check with worker pool status
|
| 22 |
+
GET /stats - Worker pool statistics and GPU utilization
|
| 23 |
+
|
| 24 |
+
Abuse Prevention:
|
| 25 |
+
- Maximum 500 messages per request
|
| 26 |
+
- Maximum 8000 characters per message
|
| 27 |
+
- Maximum 32000 characters total conversation length
|
| 28 |
+
- Temperature must be within 0.0-2.0 (out-of-range values are rejected with HTTP 400)
|
| 29 |
+
- Top-k must be within 0-200 (0 disables top-k filtering, using full vocabulary)
|
| 30 |
+
- Max tokens must be within 1-4096 (out-of-range values are rejected with HTTP 400)
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
import argparse
|
| 34 |
+
import json
|
| 35 |
+
import os
|
| 36 |
+
import torch
|
| 37 |
+
import asyncio
|
| 38 |
+
import logging
|
| 39 |
+
import random
|
| 40 |
+
from contextlib import asynccontextmanager
|
| 41 |
+
from fastapi import FastAPI, HTTPException
|
| 42 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 43 |
+
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
|
| 44 |
+
from pydantic import BaseModel
|
| 45 |
+
from typing import List, Optional, AsyncGenerator
|
| 46 |
+
from dataclasses import dataclass
|
| 47 |
+
from nanochat.common import compute_init, autodetect_device_type
|
| 48 |
+
from nanochat.checkpoint_manager import load_model
|
| 49 |
+
from nanochat.engine import Engine
|
| 50 |
+
|
| 51 |
+
# Victorian system prompt — prepended to first user turn during inference
|
| 52 |
+
SYSTEM_PREFIX = (
|
| 53 |
+
"[You are a learned Victorian gentleman in conversation. "
|
| 54 |
+
"Address the question or remark put to you directly.]\n\n"
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# Abuse prevention limits
|
| 58 |
+
MAX_MESSAGES_PER_REQUEST = 500
|
| 59 |
+
MAX_MESSAGE_LENGTH = 8000
|
| 60 |
+
MAX_TOTAL_CONVERSATION_LENGTH = 32000
|
| 61 |
+
MIN_TEMPERATURE = 0.0
|
| 62 |
+
MAX_TEMPERATURE = 2.0
|
| 63 |
+
MIN_TOP_K = 0 # 0 disables top-k filtering, using full vocabulary
|
| 64 |
+
MAX_TOP_K = 200
|
| 65 |
+
MIN_MAX_TOKENS = 1
|
| 66 |
+
MAX_MAX_TOKENS = 4096
|
| 67 |
+
|
| 68 |
+
parser = argparse.ArgumentParser(description='NanoChat Web Server')
|
| 69 |
+
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
|
| 70 |
+
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
|
| 71 |
+
parser.add_argument('-t', '--temperature', type=float, default=0.7, help='Default temperature for generation')
|
| 72 |
+
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
|
| 73 |
+
parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
|
| 74 |
+
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
| 75 |
+
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
| 76 |
+
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
|
| 77 |
+
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
| 78 |
+
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
|
| 79 |
+
args = parser.parse_args()
|
| 80 |
+
|
| 81 |
+
# Configure logging for conversation traffic
|
| 82 |
+
logging.basicConfig(
|
| 83 |
+
level=logging.INFO,
|
| 84 |
+
format='%(asctime)s - %(message)s',
|
| 85 |
+
datefmt='%Y-%m-%d %H:%M:%S'
|
| 86 |
+
)
|
| 87 |
+
logger = logging.getLogger(__name__)
|
| 88 |
+
|
| 89 |
+
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
| 90 |
+
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
| 91 |
+
|
| 92 |
+
@dataclass
class Worker:
    """One model replica bound to a single device, as handed out by the pool."""

    # Index of this worker in the pool (the CUDA device ordinal when running on cuda)
    gpu_id: int
    # Device the model was loaded onto
    device: torch.device
    # Generation engine wrapping the loaded model and tokenizer
    engine: Engine
    # Tokenizer returned by load_model together with the model
    tokenizer: object
|
| 99 |
+
|
| 100 |
+
class WorkerPool:
    """Set of model replicas handed out one at a time through an asyncio queue."""

    def __init__(self, num_gpus: Optional[int] = None):
        # No explicit count: use every visible CUDA device, or a single
        # worker on non-CUDA device types (e.g. cpu|mps).
        if num_gpus is None:
            num_gpus = torch.cuda.device_count() if device_type == "cuda" else 1
        self.num_gpus = num_gpus
        self.workers: List[Worker] = []
        self.available_workers: asyncio.Queue = asyncio.Queue()

    async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
        """Load one model replica per worker and mark each as available."""
        print(f"Initializing worker pool with {self.num_gpus} GPUs...")
        if self.num_gpus > 1:
            assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."

        on_cuda = device_type == "cuda"
        for gpu_id in range(self.num_gpus):
            if on_cuda:
                target = torch.device(f"cuda:{gpu_id}")
                print(f"Loading model on GPU {gpu_id}...")
            else:
                target = torch.device(device_type)  # e.g. cpu|mps
                print(f"Loading model on {device_type}...")

            model, tokenizer, _ = load_model(source, target, phase="eval", model_tag=model_tag, step=step)
            worker = Worker(
                gpu_id=gpu_id,
                device=target,
                engine=Engine(model, tokenizer),
                tokenizer=tokenizer,
            )
            self.workers.append(worker)
            await self.available_workers.put(worker)

        print(f"All {self.num_gpus} workers initialized!")

    async def acquire_worker(self) -> Worker:
        """Take a worker out of the queue, waiting until one is free."""
        worker = await self.available_workers.get()
        return worker

    async def release_worker(self, worker: Worker):
        """Put a worker back into the queue so another request can use it."""
        await self.available_workers.put(worker)
|
| 148 |
+
|
| 149 |
+
class ChatMessage(BaseModel):
    # Speaker of the turn; validate_chat_request only accepts "user" or "assistant"
    role: str
    # Text of the turn
    content: str
|
| 152 |
+
|
| 153 |
+
class ChatRequest(BaseModel):
    # Conversation so far, oldest turn first
    messages: List[ChatMessage]
    # Optional sampling overrides; None falls back to the server's command-line defaults
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_k: Optional[int] = None
|
| 158 |
+
|
| 159 |
+
def validate_chat_request(request: ChatRequest):
    """Validate chat request to prevent abuse.

    Checks, in order: message count, per-message and total content length,
    message roles, and the optional sampling parameters.

    Raises:
        HTTPException: with status 400 and a descriptive detail for the
            first check that fails.
    """
    # Check number of messages
    if not request.messages:
        raise HTTPException(status_code=400, detail="At least one message is required")
    if len(request.messages) > MAX_MESSAGES_PER_REQUEST:
        raise HTTPException(
            status_code=400,
            detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request"
        )

    # Check individual message lengths and total conversation length
    total_length = 0
    for i, message in enumerate(request.messages):
        if not message.content:
            raise HTTPException(status_code=400, detail=f"Message {i} has empty content")

        msg_length = len(message.content)
        if msg_length > MAX_MESSAGE_LENGTH:
            raise HTTPException(
                status_code=400,
                detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message"
            )
        total_length += msg_length

    if total_length > MAX_TOTAL_CONVERSATION_LENGTH:
        raise HTTPException(
            status_code=400,
            detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed"
        )

    # Validate role values. Only "user" and "assistant" are accepted, so the
    # error message must not advertise "system" as a valid role.
    for i, message in enumerate(request.messages):
        if message.role not in ("user", "assistant"):
            raise HTTPException(
                status_code=400,
                detail=f"Message {i} has invalid role. Must be 'user' or 'assistant'"
            )

    # Validate temperature
    if request.temperature is not None:
        if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
            raise HTTPException(
                status_code=400,
                detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
            )

    # Validate top_k
    if request.top_k is not None:
        if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
            raise HTTPException(
                status_code=400,
                detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}"
            )

    # Validate max_tokens
    if request.max_tokens is not None:
        if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
            raise HTTPException(
                status_code=400,
                detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
            )
|
| 221 |
+
|
| 222 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the worker pool and load its models before the app starts serving."""
    print("Loading nanochat models across GPUs...")
    pool = WorkerPool(num_gpus=args.num_gpus)
    # Publish the pool on app.state so the request handlers can reach it
    app.state.worker_pool = pool
    await pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
    print(f"Server ready at http://localhost:{args.port}")
    yield
|
| 230 |
+
|
| 231 |
+
app = FastAPI(lifespan=lifespan)
|
| 232 |
+
|
| 233 |
+
app.add_middleware(
|
| 234 |
+
CORSMiddleware,
|
| 235 |
+
allow_origins=["*"],
|
| 236 |
+
allow_credentials=True,
|
| 237 |
+
allow_methods=["*"],
|
| 238 |
+
allow_headers=["*"],
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
@app.get("/")
async def root():
    """Return the chat UI page (nanochat/ui.html, resolved against the working directory)."""
    # A UI file that hard-codes the API on port 8000 is rewritten to use the
    # page's own origin; a file that already uses '' is served unchanged.
    hardcoded_api_url = "const API_URL = `http://${window.location.hostname}:8000`;"
    same_origin_api_url = "const API_URL = '';"

    ui_html_path = os.path.join("nanochat", "ui.html")
    with open(ui_html_path, "r", encoding="utf-8") as ui_file:
        page = ui_file.read()

    page = page.replace(hardcoded_api_url, same_origin_api_url)
    return HTMLResponse(content=page)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
@app.get("/logo.svg")
async def logo():
    """Return nanochat/logo.svg (resolved against the working directory) as an SVG image."""
    svg_path = os.path.join("nanochat", "logo.svg")
    response = FileResponse(svg_path, media_type="image/svg+xml")
    return response
|
| 260 |
+
|
| 261 |
+
async def generate_stream(
    worker: Worker,
    tokens,
    temperature=None,
    max_new_tokens=None,
    top_k=None
) -> AsyncGenerator[str, None]:
    """Yield the assistant reply as server-sent-event strings.

    Each text event is ``data: {"token": <new text>, "gpu": <worker id>}``;
    the stream always ends with ``data: {"done": true}``. Sampling parameters
    left as None fall back to the command-line defaults in ``args``.
    """
    if temperature is None:
        temperature = args.temperature
    if max_new_tokens is None:
        max_new_tokens = args.max_tokens
    if top_k is None:
        top_k = args.top_k

    tokenizer = worker.tokenizer
    assistant_end = tokenizer.encode_special("<|assistant_end|>")
    bos = tokenizer.get_bos_token_id()

    # All tokens generated so far. The whole list is re-decoded on every step
    # so that characters spanning several tokens (e.g. emojis) come out intact.
    generated = []
    # Decoded text as of the last decode that did not end in a replacement character
    emitted_text = ""

    sampler = worker.engine.generate(
        tokens,
        num_samples=1,
        max_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        repetition_penalty=1.3,
        repetition_window=64,
        seed=random.randint(0, 2**31 - 1)
    )
    for token_column, _token_masks in sampler:
        token = token_column[0]

        # Stop on end-of-assistant-turn or on a new BOS
        if token == assistant_end or token == bos:
            break

        generated.append(token)
        decoded = tokenizer.decode(generated)

        # A trailing replacement character means the last token ended in the
        # middle of a UTF-8 sequence; hold the text back until it completes.
        if decoded.endswith('�'):
            continue

        # Emit only what was added since the last clean decode
        delta = decoded[len(emitted_text):]
        if delta:
            event = json.dumps({'token': delta, 'gpu': worker.gpu_id}, ensure_ascii=False)
            yield f"data: {event}\n\n"
        emitted_text = decoded

    done_event = json.dumps({'done': True})
    yield f"data: {done_event}\n\n"
|
| 312 |
+
|
| 313 |
+
@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
    """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU.

    Validates the request, logs the conversation, takes a worker from the pool,
    renders the conversation into tokens and streams the reply as
    ``text/event-stream``. The worker goes back to the pool when the stream
    finishes, or immediately if preparing the response fails.

    Raises:
        HTTPException: status 400 when the request fails validation.
    """
    # Basic validation to prevent abuse
    validate_chat_request(request)

    # Log incoming conversation to console
    logger.info("="*20)
    for message in request.messages:
        logger.info("[%s]: %s", message.role.upper(), message.content)
    logger.info("-"*20)

    # Acquire a worker from the pool (will wait if all are busy)
    worker_pool = app.state.worker_pool
    worker = await worker_pool.acquire_worker()

    try:
        # Build conversation tokens
        bos = worker.tokenizer.get_bos_token_id()
        user_start = worker.tokenizer.encode_special("<|user_start|>")
        user_end = worker.tokenizer.encode_special("<|user_end|>")
        assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
        assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")

        conversation_tokens = [bos]
        turn_count = 0  # number of user turns rendered so far
        for message in request.messages:
            if message.role == "user":
                content = message.content
                # Prepend system prompt to the first user turn
                if turn_count == 0:
                    content = SYSTEM_PREFIX + content
                conversation_tokens.append(user_start)
                conversation_tokens.extend(worker.tokenizer.encode(content))
                conversation_tokens.append(user_end)
                turn_count += 1
            elif message.role == "assistant":
                conversation_tokens.append(assistant_start)
                conversation_tokens.extend(worker.tokenizer.encode(message.content))
                conversation_tokens.append(assistant_end)

        # Open the assistant turn that the model is asked to complete
        conversation_tokens.append(assistant_start)

        # Streaming response with worker release after completion
        response_tokens = []

        async def stream_and_release():
            try:
                async for chunk in generate_stream(
                    worker,
                    conversation_tokens,
                    temperature=request.temperature,
                    max_new_tokens=request.max_tokens,
                    top_k=request.top_k
                ):
                    # Accumulate response for logging. Strip only the leading
                    # SSE prefix: removing every "data: " occurrence would also
                    # delete that text from inside the generated token.
                    chunk_data = json.loads(chunk.replace("data: ", "", 1).strip())
                    if "token" in chunk_data:
                        response_tokens.append(chunk_data["token"])
                    yield chunk
            finally:
                # Log the assistant response to console
                full_response = "".join(response_tokens)
                logger.info("[ASSISTANT] (GPU %s): %s", worker.gpu_id, full_response)
                logger.info("="*20)
                # Release worker back to pool after streaming is done
                await worker_pool.release_worker(worker)

        return StreamingResponse(
            stream_and_release(),
            media_type="text/event-stream"
        )
    except Exception:
        # Make sure to release worker even on error, then re-raise with the
        # original traceback.
        await worker_pool.release_worker(worker)
        raise
|
| 389 |
+
|
| 390 |
+
@app.get("/health")
async def health():
    """Report liveness and a summary of the worker pool.

    The pool is looked up defensively with ``getattr`` so this endpoint
    still answers when ``app.state`` has no ``worker_pool`` attribute; in
    that case ``ready`` is False and the counters are 0.
    """
    pool = getattr(app.state, 'worker_pool', None)

    # "ready" requires a pool object that holds at least one worker.
    is_ready = pool is not None and len(pool.workers) > 0

    if pool:
        gpu_count = pool.num_gpus
        idle_count = pool.available_workers.qsize()
    else:
        gpu_count = 0
        idle_count = 0

    return {
        "status": "ok",
        "ready": is_ready,
        "num_gpus": gpu_count,
        "available_workers": idle_count,
    }
|
| 400 |
+
|
| 401 |
+
@app.get("/stats")
async def stats():
    """Get worker pool statistics.

    Returns a dict with the total, available and busy worker counts plus a
    per-worker list of ``gpu_id`` and ``device``. When ``app.state`` has no
    ``worker_pool`` attribute the counters are 0 and the list is empty,
    mirroring the defensive lookup used by the ``/health`` endpoint.
    """
    worker_pool = getattr(app.state, 'worker_pool', None)
    if worker_pool is None:
        return {
            "total_workers": 0,
            "available_workers": 0,
            "busy_workers": 0,
            "workers": [],
        }

    # Read each quantity once so the three counters in one response agree.
    total = len(worker_pool.workers)
    available = worker_pool.available_workers.qsize()

    return {
        "total_workers": total,
        "available_workers": available,
        "busy_workers": total - available,
        "workers": [
            {
                "gpu_id": w.gpu_id,
                "device": str(w.device),
            }
            for w in worker_pool.workers
        ],
    }
|
| 416 |
+
|
| 417 |
+
if __name__ == "__main__":
    # Script entry point: announce the generation defaults, then serve the app.
    import uvicorn

    # Plain string: there is nothing to interpolate in this message.
    print("Starting NanoChat Web Server")
    print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
    uvicorn.run(app, host=args.host, port=args.port)
|
start.sh
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Container entry point: fetch the model checkpoint if needed, then start the
# chat web server on the port the hosting platform expects.
set -e

MODEL_DIR="/app/nanochat_cache/chatsft_checkpoints/d18"
MODEL_REPO="tventurella/mr_chatterbox_model"

# Download the checkpoint unless BOTH files are already present. Checking only
# the weights would skip the download forever if a previous run fetched the
# .pt file but failed before the metadata file arrived.
if [ ! -f "$MODEL_DIR/model_000050.pt" ] || [ ! -f "$MODEL_DIR/meta_000050.json" ]; then
    echo "Downloading model checkpoint..."
    python -c "
from huggingface_hub import hf_hub_download
hf_hub_download('$MODEL_REPO', 'model_000050.pt', local_dir='$MODEL_DIR')
hf_hub_download('$MODEL_REPO', 'meta_000050.json', local_dir='$MODEL_DIR')
print('Model downloaded successfully.')
"
else
    echo "Model checkpoint already present."
fi

# Start the server (exec replaces the shell so signals reach Python directly)
exec python -m scripts.chat_web \
    --model-tag d18 \
    --device-type cpu \
    --port 7860 \
    --temperature 0.7 \
    --top-k 50 \
    --max-tokens 256
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_wrapper.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
tokenizer_wrapper.py — nanochat-compatible wrapper for the Victorian BPE tokenizer
|
| 3 |
+
|
| 4 |
+
nanochat's base_train.py imports:
|
| 5 |
+
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
| 6 |
+
|
| 7 |
+
This wrapper provides a VictorianTokenizer class that satisfies nanochat's full
|
| 8 |
+
interface, plus get_tokenizer() and get_token_bytes() drop-in replacements.
|
| 9 |
+
|
| 10 |
+
Special token mapping:
|
| 11 |
+
<|endoftext|> → bos (document boundary, prepended to every document)
|
| 12 |
+
<|pad|> → pad
|
| 13 |
+
<human> → user_start (replaces nanochat's <|user_start|>)
|
| 14 |
+
<victorian> → assistant_start (replaces nanochat's <|assistant_start|>)
|
| 15 |
+
|
| 16 |
+
Usage — patch nanochat/tokenizer.py by adding at the bottom:
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
import sys
|
| 19 |
+
sys.path.insert(0, "/path/to/victorian")
|
| 20 |
+
from tokenizer_wrapper import get_tokenizer, get_token_bytes
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
import torch
|
| 25 |
+
from tokenizers import Tokenizer
|
| 26 |
+
|
| 27 |
+
TOKENIZER_PATH = Path(__file__).parent / "tokenizer.json"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class VictorianTokenizer:
    """
    Wraps our HuggingFace BPE tokenizer to match nanochat's expected interface.

    The wrapped ``tokenizers.Tokenizer`` is stored on ``self._tok`` with
    padding and truncation switched off, so every encode call returns the
    full, unpadded token sequence.
    """

    def __init__(self, tokenizer_path: str | Path = TOKENIZER_PATH):
        """Load the tokenizer definition from ``tokenizer_path``."""
        self._tok = Tokenizer.from_file(str(tokenizer_path))
        self._tok.no_padding()
        self._tok.no_truncation()

    # ------------------------------------------------------------------
    # Core nanochat interface (used by dataloader and base_train.py)
    # ------------------------------------------------------------------

    def get_vocab_size(self) -> int:
        """Return the vocabulary size reported by the wrapped tokenizer."""
        return self._tok.get_vocab_size()

    def get_bos_token_id(self) -> int:
        """Prepended to every document by nanochat's dataloader."""
        return self._tok.token_to_id("<|endoftext|>")

    def encode(
        self,
        texts: list[str] | str,
        prepend: int | str | None = None,
        append: int | str | None = None,
        num_threads: int = 4,
    ) -> list[int] | list[list[int]]:
        """
        Encode strings → token ID list(s).

        Matches nanochat's native tokenizer behaviour exactly:
          - Single string  → list[int]
          - List of strings → list[list[int]]

        prepend/append may be an int token ID or a special-token string
        (e.g. prepend="<|bos|>"), matching nanochat's _encode_one interface.
        A string that encode_special() cannot resolve becomes None and is
        therefore not added.

        num_threads is accepted for signature compatibility; this
        implementation does not use it.
        """
        single = isinstance(texts, str)
        if single:
            texts = [texts]

        # Resolve string prepend/append to token IDs (e.g. "<|bos|>" → 0)
        if isinstance(prepend, str):
            prepend = self.encode_special(prepend)
        if isinstance(append, str):
            append = self.encode_special(append)

        encodings = self._tok.encode_batch(texts, is_pretokenized=False)
        ids = [enc.ids for enc in encodings]

        if prepend is not None:
            ids = [[prepend] + seq for seq in ids]
        if append is not None:
            ids = [seq + [append] for seq in ids]

        # Single string → flat list[int] to match nanochat's native encode()
        return ids[0] if single else ids

    def decode(self, ids: list[int]) -> str:
        """Decode a list of token IDs back to text via the wrapped tokenizer."""
        return self._tok.decode(ids)

    # ------------------------------------------------------------------
    # Special token accessors
    # ------------------------------------------------------------------

    def encode_special(self, token: str) -> int | None:
        """
        Look up a special token ID by exact match.
        Maps nanochat's native special tokens to Victorian equivalents where needed.
        Required by nanochat's engine.py for sample generation.

        Returns None when the token is neither in the vocabulary nor one of
        the nanochat names listed in the mapping below.
        """
        # Try exact match first (covers our own special tokens)
        result = self._tok.token_to_id(token)
        if result is not None:
            return result
        # Map nanochat's native chat tokens to Victorian equivalents.
        # Both *_end markers, bos and eos all resolve to <|endoftext|>.
        _map = {
            "<|assistant_start|>": "<victorian>",
            "<|assistant_end|>": "<|endoftext|>",
            "<|user_start|>": "<human>",
            "<|user_end|>": "<|endoftext|>",
            "<|bos|>": "<|endoftext|>",
            "<|eos|>": "<|endoftext|>",
        }
        mapped = _map.get(token)
        if mapped is not None:
            return self._tok.token_to_id(mapped)
        return None

    def get_pad_token_id(self) -> int:
        """Return the ID of the <|pad|> token."""
        return self._tok.token_to_id("<|pad|>")

    def get_user_start_id(self) -> int:
        """Maps to nanochat's <|user_start|> role."""
        return self._tok.token_to_id("<human>")

    def get_assistant_start_id(self) -> int:
        """Maps to nanochat's <|assistant_start|> role."""
        return self._tok.token_to_id("<victorian>")

    # ------------------------------------------------------------------
    # Chat / fine-tuning interface (used by chat_sft.py)
    # ------------------------------------------------------------------

    def render_conversation(
        self,
        conversation: list[dict],
        max_tokens: int = 2048,
    ) -> tuple[list[int], list[int]]:
        """
        Encode a conversation into token IDs and a loss mask.

        conversation: list of {"role": "user"|"assistant", "content": str}
        Returns: (token_ids, loss_mask). The mask is 1 for every token of an
        assistant turn (the <victorian> marker, the content and the closing
        <|endoftext|>) and 0 for the leading BOS and all user-turn tokens.

        Victorian mapping:
          "user"      → <human> ...
          "assistant" → <victorian> ... <|endoftext|>  (end token trains model to stop)

        Once the running length reaches max_tokens, both lists are cut to
        max_tokens and any remaining turns are dropped.

        Raises:
            ValueError: if a turn's role is neither "user" nor "assistant".
                Such a turn used to be rendered as an assistant turn with the
                loss applied, which silently trained on the wrong text.
        """
        human_id = self.get_user_start_id()
        victorian_id = self.get_assistant_start_id()
        bos_id = self.get_bos_token_id()

        tokens: list[int] = [bos_id]
        mask: list[int] = [0]

        for turn in conversation:
            role = turn["role"]
            content = turn["content"]
            content_ids = self.encode(content)

            if role == "user":
                turn_tokens = [human_id] + content_ids
                turn_mask = [0] * len(turn_tokens)
            elif role == "assistant":
                turn_tokens = [victorian_id] + content_ids + [bos_id]
                turn_mask = [1] * len(turn_tokens)
            else:
                raise ValueError(
                    f"Unsupported role {role!r}; expected 'user' or 'assistant'"
                )

            tokens.extend(turn_tokens)
            mask.extend(turn_mask)

            if len(tokens) >= max_tokens:
                tokens = tokens[:max_tokens]
                mask = mask[:max_tokens]
                break

        return tokens, mask

    # ------------------------------------------------------------------

    def __call__(self, texts, **kwargs):
        """Allow tokenizer(texts, ...) as an alias for encode() — required by nanochat's core_eval."""
        return self.encode(texts, **kwargs)

    @property
    def vocab_size(self) -> int:
        """Attribute-style alias for get_vocab_size()."""
        return self.get_vocab_size()

    def __repr__(self) -> str:
        return (
            f"VictorianTokenizer(vocab_size={self.vocab_size}, "
            f"bos={self.get_bos_token_id()}, "
            f"human={self.get_user_start_id()}, "
            f"victorian={self.get_assistant_start_id()})"
        )
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# ---------------------------------------------------------------------------
|
| 199 |
+
# nanochat drop-in functions
|
| 200 |
+
# ---------------------------------------------------------------------------
|
| 201 |
+
|
| 202 |
+
_tokenizer_singleton: VictorianTokenizer | None = None
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def get_tokenizer(tokenizer_path: str | Path = TOKENIZER_PATH) -> VictorianTokenizer:
    """Drop-in replacement for nanochat's get_tokenizer().

    The tokenizer is built lazily on the first call and cached in the
    module-level ``_tokenizer_singleton``; later calls return that cached
    instance, so ``tokenizer_path`` only has an effect on the first call.
    """
    global _tokenizer_singleton
    if _tokenizer_singleton is not None:
        return _tokenizer_singleton
    _tokenizer_singleton = VictorianTokenizer(tokenizer_path)
    return _tokenizer_singleton
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def _byte_level_alphabet() -> frozenset[str]:
    """Return the 256 characters used by GPT-2 style ByteLevel BPE, one per byte."""
    # Bytes that are displayed as themselves in a byte-level vocabulary.
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    chars = [chr(b) for b in printable]
    # Every remaining byte is shifted to 256 + n in order of appearance
    # (this is where Ġ for space, Ċ for newline and ĉ for tab come from).
    shifted = 0
    for b in range(256):
        if b not in printable:
            chars.append(chr(256 + shifted))
            shifted += 1
    return frozenset(chars)


def get_token_bytes(device: str | torch.device = "cpu") -> torch.Tensor:
    """
    Drop-in replacement for nanochat's get_token_bytes().

    Returns a 1D tensor of shape [vocab_size] where each entry is the
    UTF-8 byte length of that token. Used by base_train.py to convert
    loss from nats/token → bits/byte (the BPB evaluation metric).

    Assumes the vocabulary uses the ByteLevel display alphabet, in which
    every character of a token string stands for exactly one raw byte.
    UTF-8 encoding the display string instead would count each non-ASCII
    byte twice (e.g. "é" is shown as two characters that each encode to
    two bytes), overstating the length of such tokens.
    """
    tok = get_tokenizer()
    vocab = tok._tok.get_vocab()  # {token_str: id}
    vocab_size = tok.get_vocab_size()

    # Build id → token string mapping
    id_to_token = {v: k for k, v in vocab.items()}
    alphabet = _byte_level_alphabet()

    byte_lengths = []
    for i in range(vocab_size):
        token_str = id_to_token.get(i, "")
        if all(ch in alphabet for ch in token_str):
            # One display character per byte.
            n_bytes = len(token_str)
        else:
            # Not a pure byte-level string: fall back to its UTF-8 length.
            try:
                n_bytes = len(token_str.encode("utf-8"))
            except UnicodeEncodeError:
                n_bytes = 1
        byte_lengths.append(max(1, n_bytes))  # floor at 1 to avoid div-by-zero

    return torch.tensor(byte_lengths, dtype=torch.long, device=device)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# ---------------------------------------------------------------------------
|
| 245 |
+
# Sanity check
|
| 246 |
+
# ---------------------------------------------------------------------------
|
| 247 |
+
|
| 248 |
+
if __name__ == "__main__":
    # Sanity check: load the tokenizer and exercise encode/decode,
    # render_conversation and get_token_bytes, printing a short report.
    import sys

    if not TOKENIZER_PATH.exists():
        print(f"Tokenizer not found at {TOKENIZER_PATH}")
        sys.exit(1)

    tok = get_tokenizer()
    print(tok)
    print(f" pad={tok.get_pad_token_id()}")

    texts = [
        "It is a truth universally acknowledged.",
        "The phrenological examination was most illuminating, dear fellow.",
    ]
    ids = tok.encode(texts, prepend=tok.get_bos_token_id())
    for text, seq in zip(texts, ids):
        # Skip the prepended BOS before checking the round trip.
        decoded = tok.decode(seq[1:])
        ok = "✓" if decoded == text else "✗"
        print(f" {ok} {len(seq):3d} tokens {text!r}")

    # Test render_conversation
    conv = [
        {"role": "user", "content": "What is your opinion on the railways?"},
        {"role": "assistant", "content": "The railways are a most alarming development, yet undeniably useful."},
    ]
    token_ids, loss_mask = tok.render_conversation(conv)
    print(f"\n render_conversation: {len(token_ids)} tokens, "
          f"{sum(loss_mask)} assistant tokens in loss mask")

    # Test get_token_bytes
    tb = get_token_bytes()
    # The tensor is integer-typed and torch cannot take the mean of an
    # integer tensor, so convert to float for the average only.
    print(f"\n get_token_bytes: shape={tuple(tb.shape)}, "
          f"mean={tb.float().mean():.2f} bytes/token, "
          f"min={tb.min():.0f}, max={tb.max():.0f}")
|