tventurella commited on
Commit
59856b4
·
verified ·
1 Parent(s): 5a27571

Upload 17 files

Browse files
Dockerfile ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install system dependencies
6
+ RUN apt-get update && apt-get install -y --no-install-recommends \
7
+ build-essential \
8
+ && rm -rf /var/lib/apt/lists/*
9
+
10
+ # Install Python dependencies
11
+ RUN pip install --no-cache-dir \
12
+ torch --index-url https://download.pytorch.org/whl/cpu
13
+ RUN pip install --no-cache-dir \
14
+ tokenizers \
15
+ fastapi \
16
+ uvicorn[standard] \
17
+ pydantic \
18
+ httpx \
19
+ filelock \
20
+ huggingface_hub
21
+
22
+ # Copy application code
23
+ COPY nanochat/ nanochat/
24
+ COPY scripts/ scripts/
25
+ COPY tokenizer_wrapper.py .
26
+ COPY tokenizer.json .
27
+ COPY start.sh .
28
+ RUN chmod +x start.sh
29
+
30
+ # Create model directory
31
+ RUN mkdir -p /app/nanochat_cache/chatsft_checkpoints/d18
32
+
33
+ # Set environment variables
34
+ ENV NANOCHAT_BASE_DIR=/app/nanochat_cache
35
+ ENV PYTHONPATH=/app
36
+
37
+ # HuggingFace Spaces expects port 7860
38
+ EXPOSE 7860
39
+
40
+ CMD ["./start.sh"]
README.md CHANGED
@@ -1,12 +1,8 @@
1
  ---
2
- title: Mr Chatterbox
3
- emoji:
4
- colorFrom: purple
5
- colorTo: indigo
6
  sdk: docker
7
- pinned: false
8
- license: mit
9
- short_description: The Victorian Gentleman Chatbot
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Mr. Chatterbox
3
+ emoji: 🎩
4
+ colorFrom: red
5
+ colorTo: yellow
6
  sdk: docker
7
+ app_port: 7860
 
 
8
  ---
 
 
nanochat/__init__.py ADDED
File without changes
nanochat/checkpoint_manager.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for saving and loading model/optim/state checkpoints.
3
+ """
4
+ import os
5
+ import re
6
+ import glob
7
+ import json
8
+ import logging
9
+ import torch
10
+
11
+ from nanochat.common import get_base_dir
12
+ from nanochat.gpt import GPT, GPTConfig
13
+ from nanochat.tokenizer import get_tokenizer
14
+ from nanochat.common import setup_default_logging
15
+
16
+ # Set up logging
17
+ setup_default_logging()
18
+ logger = logging.getLogger(__name__)
19
+ def log0(message):
20
+ if int(os.environ.get('RANK', 0)) == 0:
21
+ logger.info(message)
22
+
23
+ def _patch_missing_config_keys(model_config_kwargs):
24
+ """Add default values for new config keys missing in old checkpoints."""
25
+ # Old models were trained with full context (no sliding window)
26
+ if "window_pattern" not in model_config_kwargs:
27
+ model_config_kwargs["window_pattern"] = "L"
28
+ log0(f"Patching missing window_pattern in model config to 'L'")
29
+
30
+ def _patch_missing_keys(model_data, model_config):
31
+ """Add default values for new parameters that may be missing in old checkpoints."""
32
+ n_layer = model_config.n_layer
33
+ # resid_lambdas defaults to 1.0 (identity scaling)
34
+ if "resid_lambdas" not in model_data:
35
+ model_data["resid_lambdas"] = torch.ones(n_layer)
36
+ log0(f"Patching missing resid_lambdas in model data to 1.0")
37
+ # x0_lambdas defaults to 0.0 (disabled)
38
+ if "x0_lambdas" not in model_data:
39
+ model_data["x0_lambdas"] = torch.zeros(n_layer)
40
+ log0(f"Patching missing x0_lambdas in model data to 0.0")
41
+
42
+ def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
43
+ if rank == 0:
44
+ os.makedirs(checkpoint_dir, exist_ok=True)
45
+ # Save the model state parameters
46
+ model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
47
+ torch.save(model_data, model_path)
48
+ logger.info(f"Saved model parameters to: {model_path}")
49
+ # Save the metadata dict as json
50
+ meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
51
+ with open(meta_path, "w", encoding="utf-8") as f:
52
+ json.dump(meta_data, f, indent=2)
53
+ logger.info(f"Saved metadata to: {meta_path}")
54
+ # Note that optimizer state is sharded across ranks, so each rank must save its own.
55
+ if optimizer_data is not None:
56
+ os.makedirs(checkpoint_dir, exist_ok=True)
57
+ optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
58
+ torch.save(optimizer_data, optimizer_path)
59
+ logger.info(f"Saved optimizer state to: {optimizer_path}")
60
+
61
+ def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
62
+ # Load the model state
63
+ model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
64
+ model_data = torch.load(model_path, map_location=device)
65
+ # Load the optimizer state if requested
66
+ optimizer_data = None
67
+ if load_optimizer:
68
+ optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
69
+ optimizer_data = torch.load(optimizer_path, map_location=device)
70
+ # Load the metadata
71
+ meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
72
+ with open(meta_path, "r", encoding="utf-8") as f:
73
+ meta_data = json.load(f)
74
+ return model_data, optimizer_data, meta_data
75
+
76
+
77
+ def build_model(checkpoint_dir, step, device, phase):
78
+ """
79
+ A bunch of repetitive code to build a model from a given checkpoint.
80
+ Returns:
81
+ - base model - uncompiled, not wrapped in DDP
82
+ - tokenizer
83
+ - meta data saved during base model training
84
+ """
85
+ assert phase in ["train", "eval"], f"Invalid phase: {phase}"
86
+ model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
87
+ if device.type in {"cpu", "mps"}:
88
+ # Convert bfloat16 tensors to float for CPU inference
89
+ model_data = {
90
+ k: v.float() if v.dtype == torch.bfloat16 else v
91
+ for k, v in model_data.items()
92
+ }
93
+ # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
94
+ model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
95
+ model_config_kwargs = meta_data["model_config"]
96
+ _patch_missing_config_keys(model_config_kwargs)
97
+ log0(f"Building model with config: {model_config_kwargs}")
98
+ model_config = GPTConfig(**model_config_kwargs)
99
+ _patch_missing_keys(model_data, model_config)
100
+ with torch.device("meta"):
101
+ model = GPT(model_config)
102
+ # Load the model state
103
+ model.to_empty(device=device)
104
+ model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
105
+ result = model.load_state_dict(model_data, strict=False, assign=True)
106
+ if result.unexpected_keys:
107
+ log0(f"Ignoring unexpected checkpoint keys: {result.unexpected_keys}")
108
+ # Put the model in the right training phase / mode
109
+ if phase == "eval":
110
+ model.eval()
111
+ else:
112
+ model.train()
113
+ # Load the Tokenizer
114
+ tokenizer = get_tokenizer()
115
+ # Sanity check: compatibility between model and tokenizer
116
+ assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"], f"Tokenizer vocab size {tokenizer.get_vocab_size()} does not match model config vocab size {model_config_kwargs['vocab_size']}"
117
+ return model, tokenizer, meta_data
118
+
119
+
120
+ def find_largest_model(checkpoints_dir):
121
+ # attempt to guess the model tag: take the biggest model available
122
+ model_tags = [f for f in os.listdir(checkpoints_dir) if os.path.isdir(os.path.join(checkpoints_dir, f))]
123
+ if not model_tags:
124
+ raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}")
125
+ # 1) normally all model tags are of the form d<number>, try that first:
126
+ candidates = []
127
+ for model_tag in model_tags:
128
+ match = re.match(r"d(\d+)", model_tag)
129
+ if match:
130
+ model_depth = int(match.group(1))
131
+ candidates.append((model_depth, model_tag))
132
+ if candidates:
133
+ candidates.sort(key=lambda x: x[0], reverse=True)
134
+ return candidates[0][1]
135
+ # 2) if that failed, take the most recently updated model:
136
+ model_tags.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoints_dir, x)), reverse=True)
137
+ return model_tags[0]
138
+
139
+
140
+ def find_last_step(checkpoint_dir):
141
+ # Look into checkpoint_dir and find model_<step>.pt with the highest step
142
+ checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "model_*.pt"))
143
+ if not checkpoint_files:
144
+ raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
145
+ last_step = int(max(os.path.basename(f).split("_")[-1].split(".")[0] for f in checkpoint_files))
146
+ return last_step
147
+
148
+ # -----------------------------------------------------------------------------
149
+ # convenience functions that take into account nanochat's directory structure
150
+
151
+ def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
152
+ if model_tag is None:
153
+ # guess the model tag by defaulting to the largest model
154
+ model_tag = find_largest_model(checkpoints_dir)
155
+ log0(f"No model tag provided, guessing model tag: {model_tag}")
156
+ checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
157
+ if step is None:
158
+ # guess the step by defaulting to the last step
159
+ step = find_last_step(checkpoint_dir)
160
+ assert step is not None, f"No checkpoints found in {checkpoint_dir}"
161
+ # build the model
162
+ log0(f"Loading model from {checkpoint_dir} with step {step}")
163
+ model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
164
+ return model, tokenizer, meta_data
165
+
166
+ def load_model(source, *args, **kwargs):
167
+ model_dir = {
168
+ "base": "base_checkpoints",
169
+ "sft": "chatsft_checkpoints",
170
+ "rl": "chatrl_checkpoints",
171
+ }[source]
172
+ base_dir = get_base_dir()
173
+ checkpoints_dir = os.path.join(base_dir, model_dir)
174
+ return load_model_from_dir(checkpoints_dir, *args, **kwargs)
175
+
176
+ def load_optimizer_state(source, device, rank, model_tag=None, step=None):
177
+ """Load just the optimizer shard for a given rank, without re-loading the model."""
178
+ model_dir = {
179
+ "base": "base_checkpoints",
180
+ "sft": "chatsft_checkpoints",
181
+ "rl": "chatrl_checkpoints",
182
+ }[source]
183
+ base_dir = get_base_dir()
184
+ checkpoints_dir = os.path.join(base_dir, model_dir)
185
+ if model_tag is None:
186
+ model_tag = find_largest_model(checkpoints_dir)
187
+ checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
188
+ if step is None:
189
+ step = find_last_step(checkpoint_dir)
190
+ optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
191
+ if not os.path.exists(optimizer_path):
192
+ log0(f"Optimizer checkpoint not found: {optimizer_path}")
193
+ return None
194
+ log0(f"Loading optimizer state from {optimizer_path}")
195
+ optimizer_data = torch.load(optimizer_path, map_location=device)
196
+ return optimizer_data
nanochat/common.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Common utilities for nanochat.
3
+ """
4
+
5
+ import os
6
+ import re
7
+ import logging
8
+ import urllib.request
9
+ import torch
10
+ import torch.distributed as dist
11
+ from filelock import FileLock
12
+
13
+ # The dtype used for compute (matmuls, activations). Master weights stay fp32 for optimizer precision.
14
+ # Linear layers cast their weights to this dtype in forward, replacing torch.amp.autocast.
15
+ # Override with NANOCHAT_DTYPE env var: "bfloat16", "float16", "float32"
16
+ _DTYPE_MAP = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
17
+ def _detect_compute_dtype():
18
+ env = os.environ.get("NANOCHAT_DTYPE")
19
+ if env is not None:
20
+ return _DTYPE_MAP[env], f"set via NANOCHAT_DTYPE={env}"
21
+ if torch.cuda.is_available():
22
+ # bf16 requires SM 80+ (Ampere: A100, A10, etc.)
23
+ # Older GPUs like V100 (SM 70) and T4 (SM 75) only have fp16 tensor cores
24
+ capability = torch.cuda.get_device_capability()
25
+ if capability >= (8, 0):
26
+ return torch.bfloat16, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (bf16 supported)"
27
+ # fp16 training requires GradScaler (not yet implemented), so fall back to fp32.
28
+ # Users can still force fp16 via NANOCHAT_DTYPE=float16 if they know what they're doing.
29
+ return torch.float32, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (pre-Ampere, bf16 not supported, using fp32)"
30
+ return torch.float32, "auto-detected: no CUDA (CPU/MPS)"
31
+ COMPUTE_DTYPE, COMPUTE_DTYPE_REASON = _detect_compute_dtype()
32
+
33
+ class ColoredFormatter(logging.Formatter):
34
+ """Custom formatter that adds colors to log messages."""
35
+ # ANSI color codes
36
+ COLORS = {
37
+ 'DEBUG': '\033[36m', # Cyan
38
+ 'INFO': '\033[32m', # Green
39
+ 'WARNING': '\033[33m', # Yellow
40
+ 'ERROR': '\033[31m', # Red
41
+ 'CRITICAL': '\033[35m', # Magenta
42
+ }
43
+ RESET = '\033[0m'
44
+ BOLD = '\033[1m'
45
+ def format(self, record):
46
+ # Add color to the level name
47
+ levelname = record.levelname
48
+ if levelname in self.COLORS:
49
+ record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
50
+ # Format the message
51
+ message = super().format(record)
52
+ # Add color to specific parts of the message
53
+ if levelname == 'INFO':
54
+ # Highlight numbers and percentages
55
+ message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
56
+ message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
57
+ return message
58
+
59
+ def setup_default_logging():
60
+ handler = logging.StreamHandler()
61
+ handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
62
+ logging.basicConfig(
63
+ level=logging.INFO,
64
+ handlers=[handler]
65
+ )
66
+
67
+ setup_default_logging()
68
+ logger = logging.getLogger(__name__)
69
+
70
+ def get_base_dir():
71
+ # co-locate nanochat intermediates with other cached data in ~/.cache (by default)
72
+ if os.environ.get("NANOCHAT_BASE_DIR"):
73
+ nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR")
74
+ else:
75
+ home_dir = os.path.expanduser("~")
76
+ cache_dir = os.path.join(home_dir, ".cache")
77
+ nanochat_dir = os.path.join(cache_dir, "nanochat")
78
+ os.makedirs(nanochat_dir, exist_ok=True)
79
+ return nanochat_dir
80
+
81
+ def download_file_with_lock(url, filename, postprocess_fn=None):
82
+ """
83
+ Downloads a file from a URL to a local path in the base directory.
84
+ Uses a lock file to prevent concurrent downloads among multiple ranks.
85
+ """
86
+ base_dir = get_base_dir()
87
+ file_path = os.path.join(base_dir, filename)
88
+ lock_path = file_path + ".lock"
89
+
90
+ if os.path.exists(file_path):
91
+ return file_path
92
+
93
+ with FileLock(lock_path):
94
+ # Only a single rank can acquire this lock
95
+ # All other ranks block until it is released
96
+
97
+ # Recheck after acquiring lock
98
+ if os.path.exists(file_path):
99
+ return file_path
100
+
101
+ # Download the content as bytes
102
+ print(f"Downloading {url}...")
103
+ with urllib.request.urlopen(url) as response:
104
+ content = response.read() # bytes
105
+
106
+ # Write to local file
107
+ with open(file_path, 'wb') as f:
108
+ f.write(content)
109
+ print(f"Downloaded to {file_path}")
110
+
111
+ # Run the postprocess function if provided
112
+ if postprocess_fn is not None:
113
+ postprocess_fn(file_path)
114
+
115
+ return file_path
116
+
117
+ def print0(s="",**kwargs):
118
+ ddp_rank = int(os.environ.get('RANK', 0))
119
+ if ddp_rank == 0:
120
+ print(s, **kwargs)
121
+
122
+ def print_banner():
123
+ # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
124
+ banner = """
125
+ █████ █████
126
+ ░░███ ░░███
127
+ ████████ ██████ ██��█████ ██████ ██████ ░███████ ██████ ███████
128
+ ░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███░░░███░
129
+ ░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███
130
+ ░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███
131
+ ████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░███████ ░░█████
132
+ ░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░
133
+ """
134
+ print0(banner)
135
+
136
+ def is_ddp_requested() -> bool:
137
+ """
138
+ True if launched by torchrun (env present), even before init.
139
+ Used to decide whether we *should* initialize a PG.
140
+ """
141
+ return all(k in os.environ for k in ("RANK", "LOCAL_RANK", "WORLD_SIZE"))
142
+
143
+ def is_ddp_initialized() -> bool:
144
+ """
145
+ True if torch.distributed is available and the process group is initialized.
146
+ Used at cleanup to avoid destroying a non-existent PG.
147
+ """
148
+ return dist.is_available() and dist.is_initialized()
149
+
150
+ def get_dist_info():
151
+ if is_ddp_requested():
152
+ # We rely on torchrun's env to decide if we SHOULD init.
153
+ # (Initialization itself happens in compute init.)
154
+ assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
155
+ ddp_rank = int(os.environ['RANK'])
156
+ ddp_local_rank = int(os.environ['LOCAL_RANK'])
157
+ ddp_world_size = int(os.environ['WORLD_SIZE'])
158
+ return True, ddp_rank, ddp_local_rank, ddp_world_size
159
+ else:
160
+ return False, 0, 0, 1
161
+
162
+ def autodetect_device_type():
163
+ # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
164
+ if torch.cuda.is_available():
165
+ device_type = "cuda"
166
+ elif torch.backends.mps.is_available():
167
+ device_type = "mps"
168
+ else:
169
+ device_type = "cpu"
170
+ print0(f"Autodetected device type: {device_type}")
171
+ return device_type
172
+
173
+ def compute_init(device_type="cuda"): # cuda|cpu|mps
174
+ """Basic initialization that we keep doing over and over, so make common."""
175
+
176
+ assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
177
+ if device_type == "cuda":
178
+ assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
179
+ if device_type == "mps":
180
+ assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
181
+
182
+ # Reproducibility
183
+ # Note that we set the global seeds here, but most of the code uses explicit rng objects.
184
+ # The only place where global rng might be used is nn.Module initialization of the model weights.
185
+ torch.manual_seed(42)
186
+ if device_type == "cuda":
187
+ torch.cuda.manual_seed(42)
188
+ # skipping full reproducibility for now, possibly investigate slowdown later
189
+ # torch.use_deterministic_algorithms(True)
190
+
191
+ # Precision
192
+ if device_type == "cuda":
193
+ torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls, see https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
194
+
195
+ # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
196
+ is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
197
+ if is_ddp_requested and device_type == "cuda":
198
+ device = torch.device("cuda", ddp_local_rank)
199
+ torch.cuda.set_device(device) # make "cuda" default to this device
200
+ dist.init_process_group(backend="nccl", device_id=device)
201
+ dist.barrier()
202
+ else:
203
+ device = torch.device(device_type) # mps|cpu
204
+
205
+ if ddp_rank == 0:
206
+ logger.info(f"Distributed world size: {ddp_world_size}")
207
+
208
+ return is_ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device
209
+
210
+ def compute_cleanup():
211
+ """Companion function to compute_init, to clean things up before script exit"""
212
+ if is_ddp_initialized():
213
+ dist.destroy_process_group()
214
+
215
+ class DummyWandb:
216
+ """Useful if we wish to not use wandb but have all the same signatures"""
217
+ def __init__(self):
218
+ pass
219
+ def log(self, *args, **kwargs):
220
+ pass
221
+ def finish(self):
222
+ pass
223
+
224
+ # hardcoded BF16 peak flops for various GPUs
225
+ # inspired by torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py
226
+ # and PR: https://github.com/karpathy/nanochat/pull/147
227
+ def get_peak_flops(device_name: str) -> float:
228
+ name = device_name.lower()
229
+
230
+ # Table order matters: more specific patterns first.
231
+ _PEAK_FLOPS_TABLE = (
232
+ # NVIDIA Blackwell
233
+ (["gb200"], 2.5e15),
234
+ (["grace blackwell"], 2.5e15),
235
+ (["b200"], 2.25e15),
236
+ (["b100"], 1.8e15),
237
+ # NVIDIA Hopper
238
+ (["h200", "nvl"], 836e12),
239
+ (["h200", "pcie"], 836e12),
240
+ (["h200"], 989e12),
241
+ (["h100", "nvl"], 835e12),
242
+ (["h100", "pcie"], 756e12),
243
+ (["h100"], 989e12),
244
+ (["h800", "nvl"], 989e12),
245
+ (["h800"], 756e12),
246
+ # NVIDIA Ampere data center
247
+ (["a100"], 312e12),
248
+ (["a800"], 312e12),
249
+ (["a40"], 149.7e12),
250
+ (["a30"], 165e12),
251
+ # NVIDIA Ada data center
252
+ (["l40s"], 362e12),
253
+ (["l40-s"], 362e12),
254
+ (["l40 s"], 362e12),
255
+ (["l4"], 121e12),
256
+ # AMD CDNA accelerators
257
+ (["mi355"], 2.5e15),
258
+ (["mi325"], 1.3074e15),
259
+ (["mi300x"], 1.3074e15),
260
+ (["mi300a"], 980.6e12),
261
+ (["mi250x"], 383e12),
262
+ (["mi250"], 362.1e12),
263
+ # Consumer RTX
264
+ (["5090"], 209.5e12),
265
+ (["4090"], 165.2e12),
266
+ (["3090"], 71e12),
267
+ )
268
+ for patterns, flops in _PEAK_FLOPS_TABLE:
269
+ if all(p in name for p in patterns):
270
+ return flops
271
+ if "data center gpu max 1550" in name:
272
+ # Ponte Vecchio (PVC) - dynamic based on compute units
273
+ max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
274
+ return 512 * max_comp_units * 1300 * 10**6
275
+
276
+ # Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess
277
+ logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%")
278
+ return float('inf')
nanochat/engine.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Engine for efficient inference of our models.
3
+
4
+ Everything works around token sequences:
5
+ - The user can send token sequences to the engine
6
+ - The engine returns the next token
7
+
8
+ Notes:
9
+ - The engine knows nothing about tokenization, it's purely token id sequences.
10
+
11
+ The whole thing is made as efficient as possible.
12
+ """
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import signal
17
+ import warnings
18
+ from contextlib import contextmanager
19
+ from collections import deque
20
+ from nanochat.common import compute_init, autodetect_device_type
21
+ from nanochat.checkpoint_manager import load_model
22
+
23
+ # -----------------------------------------------------------------------------
24
+ # Calculator tool helpers
25
+ @contextmanager
26
+ def timeout(duration, formula):
27
+ def timeout_handler(signum, frame):
28
+ raise Exception(f"'{formula}': timed out after {duration} seconds")
29
+
30
+ signal.signal(signal.SIGALRM, timeout_handler)
31
+ signal.alarm(duration)
32
+ yield
33
+ signal.alarm(0)
34
+
35
+ def eval_with_timeout(formula, max_time=3):
36
+ try:
37
+ with timeout(max_time, formula):
38
+ with warnings.catch_warnings():
39
+ warnings.simplefilter("ignore", SyntaxWarning)
40
+ return eval(formula, {"__builtins__": {}}, {})
41
+ except Exception as e:
42
+ signal.alarm(0)
43
+ # print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage
44
+ return None
45
+
46
+ def use_calculator(expr):
47
+ """
48
+ Evaluate a Python expression safely.
49
+ Supports both math expressions and string operations like .count()
50
+ """
51
+ # Remove commas from numbers
52
+ expr = expr.replace(",", "")
53
+
54
+ # Check if it's a pure math expression (old behavior)
55
+ if all([x in "0123456789*+-/.() " for x in expr]):
56
+ if "**" in expr: # disallow power operator
57
+ return None
58
+ return eval_with_timeout(expr)
59
+
60
+ # Check if it's a string operation we support
61
+ # Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
62
+ allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
63
+ if not all([x in allowed_chars for x in expr]):
64
+ return None
65
+
66
+ # Disallow dangerous patterns
67
+ dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
68
+ 'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
69
+ 'getattr', 'setattr', 'delattr', 'hasattr']
70
+ expr_lower = expr.lower()
71
+ if any(pattern in expr_lower for pattern in dangerous_patterns):
72
+ return None
73
+
74
+ # Only allow .count() method for now (can expand later)
75
+ if '.count(' not in expr:
76
+ return None
77
+
78
+ # Evaluate with timeout
79
+ return eval_with_timeout(expr)
80
+
81
+ # -----------------------------------------------------------------------------
82
+ class KVCache:
83
+ """
84
+ KV Cache designed for Flash Attention 3's flash_attn_with_kvcache API.
85
+
86
+ Key differences from FA2-style cache:
87
+ - Tensors are (B, T, H, D) not (B, H, T, D)
88
+ - FA3 updates the cache in-place during flash_attn_with_kvcache
89
+ - Position tracked per batch element via cache_seqlens tensor
90
+ """
91
+
92
+ def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
93
+ self.batch_size = batch_size
94
+ self.max_seq_len = seq_len
95
+ self.n_layers = num_layers
96
+ self.n_heads = num_heads
97
+ self.head_dim = head_dim
98
+ # Pre-allocate cache tensors: (n_layers, B, T, H, D)
99
+ self.k_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
100
+ self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
101
+ # Current sequence length per batch element (FA3 needs int32)
102
+ self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
103
+
104
+ def reset(self):
105
+ """Reset cache to empty state."""
106
+ self.cache_seqlens.zero_()
107
+
108
+ def get_pos(self):
109
+ """Get current position (assumes all batch elements at same position)."""
110
+ return self.cache_seqlens[0].item()
111
+
112
+ def get_layer_cache(self, layer_idx):
113
+ """Return (k_cache, v_cache) views for a specific layer."""
114
+ return self.k_cache[layer_idx], self.v_cache[layer_idx]
115
+
116
+ def advance(self, num_tokens):
117
+ """Advance the cache position by num_tokens."""
118
+ self.cache_seqlens += num_tokens
119
+
120
+ def prefill(self, other):
121
+ """
122
+ Copy cached KV from another cache into this one.
123
+ Used when we do batch=1 prefill and then want to generate multiple samples in parallel.
124
+ """
125
+ assert self.get_pos() == 0, "Cannot prefill a non-empty KV cache"
126
+ assert self.n_layers == other.n_layers and self.n_heads == other.n_heads and self.head_dim == other.head_dim
127
+ assert self.max_seq_len >= other.max_seq_len
128
+ other_pos = other.get_pos()
129
+ self.k_cache[:, :, :other_pos, :, :] = other.k_cache[:, :, :other_pos, :, :]
130
+ self.v_cache[:, :, :other_pos, :, :] = other.v_cache[:, :, :other_pos, :, :]
131
+ self.cache_seqlens.fill_(other_pos)
132
+
133
+ # -----------------------------------------------------------------------------
134
+ @torch.inference_mode()
135
+ def sample_next_token(logits, rng, temperature=1.0, top_k=None):
136
+ """Sample a single next token from given logits of shape (B, vocab_size). Returns (B, 1)."""
137
+ assert temperature >= 0.0, "temperature must be non-negative"
138
+ if temperature == 0.0:
139
+ return torch.argmax(logits, dim=-1, keepdim=True)
140
+ if top_k is not None and top_k > 0:
141
+ k = min(top_k, logits.size(-1))
142
+ vals, idx = torch.topk(logits, k, dim=-1)
143
+ vals = vals / temperature
144
+ probs = F.softmax(vals, dim=-1)
145
+ choice = torch.multinomial(probs, num_samples=1, generator=rng)
146
+ return idx.gather(1, choice)
147
+ else:
148
+ logits = logits / temperature
149
+ probs = F.softmax(logits, dim=-1)
150
+ return torch.multinomial(probs, num_samples=1, generator=rng)
151
+
152
+ # -----------------------------------------------------------------------------
153
+
154
+ class RowState:
155
+ # Per-row state tracking during generation
156
+ def __init__(self, current_tokens=None):
157
+ self.current_tokens = current_tokens or [] # Current token sequence for this row
158
+ self.forced_tokens = deque() # Queue of tokens to force inject
159
+ self.in_python_block = False # Whether we are inside a python block
160
+ self.python_expr_tokens = [] # Tokens of the current python expression
161
+ self.completed = False # Whether this row has completed generation
162
+
163
+ class Engine:
164
+
165
+ def __init__(self, model, tokenizer):
166
+ self.model = model
167
+ self.tokenizer = tokenizer # needed for tool use
168
+
169
+ @torch.inference_mode()
170
+ def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42, repetition_penalty=1.0, repetition_window=64):
171
+ """Same as generate, but does single prefill and then clones the KV cache."""
172
+ assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
173
+ device = self.model.get_device()
174
+ # NOTE: setting the dtype here and in this way is an ugly hack.
175
+ # Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
176
+ # We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors.
177
+ # As a quick hack, we're making generate() function inherit and know about this repo-wise assumption.
178
+ # I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
179
+ # In particular, the KVCache should allocate its tensors lazily
180
+ dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
181
+ rng = torch.Generator(device=device)
182
+ rng.manual_seed(seed)
183
+
184
+ # Get the special tokens we need to coordinate the tool use state machine
185
+ get_special = lambda s: self.tokenizer.encode_special(s)
186
+ python_start = get_special("<|python_start|>")
187
+ python_end = get_special("<|python_end|>")
188
+ output_start = get_special("<|output_start|>")
189
+ output_end = get_special("<|output_end|>")
190
+ assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
191
+ bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
192
+
193
+ # 1) Run a batch 1 prefill of the prompt tokens
194
+ m = self.model.config
195
+ kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
196
+ kv_cache_prefill = KVCache(
197
+ batch_size=1,
198
+ seq_len=len(tokens),
199
+ device=device,
200
+ dtype=dtype,
201
+ **kv_model_kwargs,
202
+ )
203
+ ids = torch.tensor([tokens], dtype=torch.long, device=device)
204
+ logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
205
+ logits = logits[:, -1, :].expand(num_samples, -1) # (num_samples, vocab_size)
206
+
207
+ # 2) Replicate the KV cache for each sample/row
208
+ kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
209
+ kv_cache_decode = KVCache(
210
+ batch_size=num_samples,
211
+ seq_len=kv_length_hint,
212
+ device=device,
213
+ dtype=dtype,
214
+ **kv_model_kwargs,
215
+ )
216
+ kv_cache_decode.prefill(kv_cache_prefill)
217
+ del kv_cache_prefill # no need to keep this memory around
218
+
219
+ # 3) Initialize states for each sample
220
+ row_states = [RowState(tokens.copy()) for _ in range(num_samples)]
221
+
222
+ # 4) Main generation loop
223
+ num_generated = 0
224
+ while True:
225
+ # Stop condition: we've reached max tokens
226
+ if max_tokens is not None and num_generated >= max_tokens:
227
+ break
228
+ # Stop condition: all rows are completed
229
+ if all(state.completed for state in row_states):
230
+ break
231
+
232
+ # Sample the next token for each row
233
+ if repetition_penalty != 1.0: # Victorian repetition penalty patch
234
+ _pen = logits.clone()
235
+ for _i, _s in enumerate(row_states):
236
+ if not _s.completed:
237
+ for _t in set(_s.current_tokens[-repetition_window:]):
238
+ _pen[_i, _t] = (_pen[_i, _t] / repetition_penalty
239
+ if _pen[_i, _t] > 0 else _pen[_i, _t] * repetition_penalty)
240
+ next_ids = sample_next_token(_pen, rng, temperature, top_k)
241
+ else:
242
+ next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
243
+ sampled_tokens = next_ids[:, 0].tolist()
244
+
245
+ # Process each row: choose the next token, update state, optional tool use
246
+ token_column = [] # contains the next token id along each row
247
+ token_masks = [] # contains the mask (was it sampled (1) or forced (0)?) along each row
248
+ for i, state in enumerate(row_states):
249
+ # Select the next token in this row
250
+ is_forced = len(state.forced_tokens) > 0 # are there tokens waiting to be forced in deque?
251
+ token_masks.append(0 if is_forced else 1) # mask is 0 if forced, 1 if sampled
252
+ next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
253
+ token_column.append(next_token)
254
+ # Update the state of this row to include the next token
255
+ state.current_tokens.append(next_token)
256
+ # On <|assistant_end|> or <|bos|>, mark the row as completed
257
+ if next_token == assistant_end or next_token == bos:
258
+ state.completed = True
259
+ # Handle tool logic
260
+ if next_token == python_start:
261
+ state.in_python_block = True
262
+ state.python_expr_tokens = []
263
+ elif next_token == python_end and state.in_python_block:
264
+ state.in_python_block = False
265
+ if state.python_expr_tokens:
266
+ expr = self.tokenizer.decode(state.python_expr_tokens)
267
+ result = use_calculator(expr)
268
+ if result is not None:
269
+ result_tokens = self.tokenizer.encode(str(result))
270
+ state.forced_tokens.append(output_start)
271
+ state.forced_tokens.extend(result_tokens)
272
+ state.forced_tokens.append(output_end)
273
+ state.python_expr_tokens = []
274
+ elif state.in_python_block:
275
+ state.python_expr_tokens.append(next_token)
276
+
277
+ # Yield the token column
278
+ yield token_column, token_masks
279
+ num_generated += 1
280
+
281
+ # Prepare logits for next iteration
282
+ ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
283
+ logits = self.model.forward(ids, kv_cache=kv_cache_decode)[:, -1, :] # (B, vocab_size)
284
+
285
+ def generate_batch(self, tokens, num_samples=1, **kwargs):
286
+ """
287
+ Non-streaming batch generation that just returns the final token sequences.
288
+ Returns a list of token sequences (list of lists of ints).
289
+ Terminal tokens (assistant_end, bos) are not included in the results.
290
+ """
291
+ assistant_end = self.tokenizer.encode_special("<|assistant_end|>")
292
+ bos = self.tokenizer.get_bos_token_id()
293
+ results = [tokens.copy() for _ in range(num_samples)]
294
+ masks = [[0] * len(tokens) for _ in range(num_samples)]
295
+ completed = [False] * num_samples
296
+ for token_column, token_masks in self.generate(tokens, num_samples, **kwargs):
297
+ for i, (token, mask) in enumerate(zip(token_column, token_masks)):
298
+ if not completed[i]:
299
+ if token == assistant_end or token == bos:
300
+ completed[i] = True
301
+ else:
302
+ results[i].append(token)
303
+ masks[i].append(mask)
304
+ # Stop if all rows are completed
305
+ if all(completed):
306
+ break
307
+ return results, masks
308
+
309
+
310
+ if __name__ == "__main__":
311
+ """
312
+ Quick inline test to make sure that the naive/slow model.generate function
313
+ is equivalent to the faster Engine.generate function here.
314
+ """
315
+ import time
316
+ # init compute
317
+ device_type = autodetect_device_type()
318
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
319
+ # load the model and tokenizer
320
+ model, tokenizer, meta = load_model("base", device, phase="eval")
321
+ bos_token_id = tokenizer.get_bos_token_id()
322
+ # common hyperparameters
323
+ kwargs = dict(max_tokens=64, temperature=0.0)
324
+ # set the starting prompt
325
+ prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
326
+ # generate the reference sequence using the model.generate() function
327
+ generated_tokens = []
328
+ torch.cuda.synchronize()
329
+ t0 = time.time()
330
+ stream = model.generate(prompt_tokens, **kwargs)
331
+ for token in stream:
332
+ generated_tokens.append(token)
333
+ chunk = tokenizer.decode([token])
334
+ print(chunk, end="", flush=True)
335
+ print()
336
+ torch.cuda.synchronize()
337
+ t1 = time.time()
338
+ print(f"Reference time: {t1 - t0:.2f}s")
339
+ reference_ids = generated_tokens
340
+ # generate tokens with Engine
341
+ generated_tokens = []
342
+ engine = Engine(model, tokenizer)
343
+ stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
344
+ torch.cuda.synchronize()
345
+ t0 = time.time()
346
+ for token_column, token_masks in stream:
347
+ token = token_column[0] # only print out the first row
348
+ generated_tokens.append(token)
349
+ chunk = tokenizer.decode([token])
350
+ print(chunk, end="", flush=True)
351
+ print()
352
+ torch.cuda.synchronize()
353
+ t1 = time.time()
354
+ print(f"Engine time: {t1 - t0:.2f}s")
355
+ # compare the two sequences
356
+ for i in range(len(reference_ids)):
357
+ if reference_ids[i] != generated_tokens[i]:
358
+ print(f"Mismatch at {i}: {reference_ids[i]} != {generated_tokens[i]}")
359
+ break
360
+ print(f"Match: {reference_ids == generated_tokens}")
nanochat/flash_attention.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified Flash Attention interface with automatic FA3/SDPA switching.
3
+
4
+ Exports `flash_attn` module that matches the FA3 API exactly, but falls back
5
+ to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
6
+
7
+ Usage (drop-in replacement for FA3):
8
+ from nanochat.flash_attention import flash_attn
9
+
10
+ # Training (no KV cache)
11
+ y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
12
+
13
+ # Inference (with KV cache)
14
+ y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, ...)
15
+ """
16
+ import torch
17
+ import torch.nn.functional as F
18
+
19
+
20
+ # =============================================================================
21
+ # Detection: Try to load FA3 on Hopper+ GPUs
22
+ # =============================================================================
23
+ def _load_flash_attention_3():
24
+ """Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
25
+ if not torch.cuda.is_available():
26
+ return None
27
+ try:
28
+ major, _ = torch.cuda.get_device_capability()
29
+ # FA3 kernels are compiled for Hopper (sm90) only
30
+ # Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled
31
+ if major != 9:
32
+ return None
33
+ import os
34
+ os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
35
+ from kernels import get_kernel
36
+ return get_kernel('varunneal/flash-attention-3').flash_attn_interface
37
+ except Exception:
38
+ return None
39
+
40
+
41
+ _fa3 = _load_flash_attention_3()
42
+ HAS_FA3 = _fa3 is not None
43
+
44
+ # Override for testing: set to 'fa3', 'sdpa', or None (auto)
45
+ _override_impl = None
46
+
47
+
48
+ def _resolve_use_fa3():
49
+ """Decide once whether to use FA3, based on availability, override, and dtype."""
50
+ if _override_impl == 'fa3':
51
+ assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
52
+ return True
53
+ if _override_impl == 'sdpa':
54
+ return False
55
+ if HAS_FA3:
56
+ # FA3 Hopper kernels only support bf16 and fp8; fp16/fp32 must use SDPA fallback
57
+ from nanochat.common import COMPUTE_DTYPE
58
+ if COMPUTE_DTYPE == torch.bfloat16:
59
+ return True
60
+ return False
61
+ return False
62
+
63
+ USE_FA3 = _resolve_use_fa3()
64
+
65
+
66
+ # =============================================================================
67
+ # SDPA helpers
68
+ # =============================================================================
69
+ def _sdpa_attention(q, k, v, window_size, enable_gqa):
70
+ """
71
+ SDPA attention with sliding window support.
72
+ q, k, v are (B, H, T, D) format.
73
+ """
74
+ Tq = q.size(2)
75
+ Tk = k.size(2)
76
+ window = window_size[0]
77
+
78
+ # Full context, same length
79
+ if (window < 0 or window >= Tq) and Tq == Tk:
80
+ return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
81
+
82
+ # Single token generation
83
+ if Tq == 1:
84
+ if window >= 0 and window < Tk:
85
+ # window is "left" tokens we need to include (window + 1) keys total
86
+ start = max(0, Tk - (window + 1))
87
+ k = k[:, :, start:, :]
88
+ v = v[:, :, start:, :]
89
+ return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
90
+
91
+ # Need explicit mask for sliding window/chunk inference
92
+ device = q.device
93
+ # For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
94
+ row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
95
+ col_idx = torch.arange(Tk, device=device).unsqueeze(0)
96
+ mask = col_idx <= row_idx
97
+
98
+ # sliding window (left)
99
+ if window >= 0 and window < Tk:
100
+ mask = mask & ((row_idx - col_idx) <= window)
101
+
102
+ return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
103
+
104
+ # =============================================================================
105
+ # Public API: Same interface as FA3
106
+ # =============================================================================
107
+ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
108
+ """
109
+ Flash Attention for training (no KV cache).
110
+
111
+ Args:
112
+ q, k, v: Tensors of shape (B, T, H, D)
113
+ causal: Whether to use causal masking
114
+ window_size: (left, right) sliding window. -1 means unlimited.
115
+
116
+ Returns:
117
+ Output tensor of shape (B, T, H, D)
118
+ """
119
+ if USE_FA3:
120
+ return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
121
+
122
+ # SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
123
+ q = q.transpose(1, 2)
124
+ k = k.transpose(1, 2)
125
+ v = v.transpose(1, 2)
126
+ enable_gqa = q.size(1) != k.size(1)
127
+ y = _sdpa_attention(q, k, v, window_size, enable_gqa)
128
+ return y.transpose(1, 2) # back to (B, T, H, D)
129
+
130
+
131
+ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None,
132
+ causal=False, window_size=(-1, -1)):
133
+ """
134
+ Flash Attention with KV cache for inference.
135
+
136
+ FA3 updates k_cache/v_cache in-place. Our SDPA fallback does the same.
137
+
138
+ Args:
139
+ q: Queries, shape (B, T_new, H, D)
140
+ k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D)
141
+ k, v: New keys/values to insert, shape (B, T_new, H_kv, D)
142
+ cache_seqlens: Current position in cache, shape (B,) int32
143
+ causal: Whether to use causal masking
144
+ window_size: (left, right) sliding window. -1 means unlimited.
145
+
146
+ Returns:
147
+ Output tensor of shape (B, T_new, H, D)
148
+ """
149
+ if USE_FA3:
150
+ return _fa3.flash_attn_with_kvcache(
151
+ q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
152
+ causal=causal, window_size=window_size
153
+ )
154
+
155
+ # SDPA fallback: manually manage KV cache
156
+ B, T_new, H, D = q.shape
157
+ pos = cache_seqlens[0].item() # assume uniform position across batch
158
+
159
+ # Insert new k, v into cache (in-place, matching FA3 behavior)
160
+ if k is not None and v is not None:
161
+ k_cache[:, pos:pos+T_new, :, :] = k
162
+ v_cache[:, pos:pos+T_new, :, :] = v
163
+
164
+ # Get full cache up to current position + new tokens
165
+ end_pos = pos + T_new
166
+ k_full = k_cache[:, :end_pos, :, :]
167
+ v_full = v_cache[:, :end_pos, :, :]
168
+
169
+ # Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D)
170
+ q_sdpa = q.transpose(1, 2)
171
+ k_sdpa = k_full.transpose(1, 2)
172
+ v_sdpa = v_full.transpose(1, 2)
173
+
174
+ enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
175
+ y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
176
+
177
+ return y_sdpa.transpose(1, 2) # back to (B, T, H, D)
178
+
179
+
180
+ # =============================================================================
181
+ # Export: flash_attn module interface (drop-in replacement for FA3)
182
+ # =============================================================================
183
+ from types import SimpleNamespace
184
+ flash_attn = SimpleNamespace(
185
+ flash_attn_func=flash_attn_func,
186
+ flash_attn_with_kvcache=flash_attn_with_kvcache,
187
+ )
nanochat/gpt.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GPT model (rewrite, a lot simpler)
3
+ Notable features:
4
+ - rotary embeddings (and no positional embeddings)
5
+ - QK norm
6
+ - untied weights for token embedding and lm_head
7
+ - relu^2 activation in MLP
8
+ - norm after token embedding
9
+ - no learnable params in rmsnorm
10
+ - no bias in linear layers
11
+ - Group-Query Attention (GQA) support for more efficient inference
12
+ - Flash Attention 3 integration
13
+ """
14
+
15
+ from functools import partial
16
+ from dataclasses import dataclass
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from nanochat.common import get_dist_info, print0, COMPUTE_DTYPE
23
+ from nanochat.optim import MuonAdamW, DistMuonAdamW
24
+
25
+ # Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
26
+ from nanochat.flash_attention import flash_attn
27
+
28
+ @dataclass
29
+ class GPTConfig:
30
+ sequence_len: int = 2048
31
+ vocab_size: int = 32768
32
+ n_layer: int = 12
33
+ n_head: int = 6 # number of query heads
34
+ n_kv_head: int = 6 # number of key/value heads (GQA)
35
+ n_embd: int = 768
36
+ # Sliding window attention pattern string, tiled across layers. Final layer always L.
37
+ # Characters: L=long (full context), S=short (half context)
38
+ # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
39
+ window_pattern: str = "SSSL"
40
+
41
+
42
+ def norm(x):
43
+ return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok
44
+
45
+ class Linear(nn.Linear):
46
+ """nn.Linear that casts weights to match input dtype in forward.
47
+ Replaces autocast: master weights stay fp32 for optimizer precision,
48
+ but matmuls run in the activation dtype (typically bf16 from embeddings)."""
49
+ def forward(self, x):
50
+ return F.linear(x, self.weight.to(dtype=x.dtype))
51
+
52
+
53
+ def has_ve(layer_idx, n_layer):
54
+ """Returns True if GPT layer should have Value Embedding (alternating, last layer always included)."""
55
+ return layer_idx % 2 == (n_layer - 1) % 2
56
+
57
+ def apply_rotary_emb(x, cos, sin):
58
+ assert x.ndim == 4 # multihead attention
59
+ d = x.shape[3] // 2
60
+ x1, x2 = x[..., :d], x[..., d:] # split up last dim into two halves
61
+ y1 = x1 * cos + x2 * sin # rotate pairs of dims
62
+ y2 = x1 * (-sin) + x2 * cos
63
+ return torch.cat([y1, y2], 3)
64
+
65
+ class CausalSelfAttention(nn.Module):
66
+ def __init__(self, config, layer_idx):
67
+ super().__init__()
68
+ self.layer_idx = layer_idx
69
+ self.n_head = config.n_head
70
+ self.n_kv_head = config.n_kv_head
71
+ self.n_embd = config.n_embd
72
+ self.head_dim = self.n_embd // self.n_head
73
+ assert self.n_embd % self.n_head == 0
74
+ assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
75
+ self.c_q = Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
76
+ self.c_k = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
77
+ self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
78
+ self.c_proj = Linear(self.n_embd, self.n_embd, bias=False)
79
+ self.ve_gate_channels = 32 # Victorian checkpoint patch
80
+ self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
81
+
82
+ def forward(self, x, ve, cos_sin, window_size, kv_cache):
83
+ B, T, C = x.size()
84
+
85
+ # Project the input to get queries, keys, and values
86
+ # Shape: (B, T, H, D) - FA3's native layout, no transpose needed!
87
+ q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
88
+ k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
89
+ v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
90
+
91
+ # Value residual (ResFormer): mix in value embedding with input-dependent gate per head
92
+ if ve is not None:
93
+ ve = ve.view(B, T, self.n_kv_head, self.head_dim)
94
+ gate = 3 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 3)
95
+ v = v + gate.unsqueeze(-1) * ve
96
+
97
+ # Apply Rotary Embeddings to queries and keys to get relative positional encoding
98
+ cos, sin = cos_sin
99
+ q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
100
+ q, k = norm(q), norm(k) # QK norm
101
+ q = q * 1.15 # sharper attention (split scale between Q and K), TODO think through better
102
+ k = k * 1.15
103
+
104
+ # Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
105
+ # window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
106
+ if kv_cache is None:
107
+ # Training: causal attention with optional sliding window
108
+ y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
109
+ else:
110
+ # Inference: use flash_attn_with_kvcache which handles cache management
111
+ k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
112
+ y = flash_attn.flash_attn_with_kvcache(
113
+ q, k_cache, v_cache,
114
+ k=k, v=v,
115
+ cache_seqlens=kv_cache.cache_seqlens,
116
+ causal=True,
117
+ window_size=window_size,
118
+ )
119
+ # Advance position after last layer processes
120
+ if self.layer_idx == kv_cache.n_layers - 1:
121
+ kv_cache.advance(T)
122
+
123
+ # Re-assemble the heads and project back to residual stream
124
+ y = y.contiguous().view(B, T, -1)
125
+ y = self.c_proj(y)
126
+ return y
127
+
128
+
129
+ class MLP(nn.Module):
130
+ def __init__(self, config):
131
+ super().__init__()
132
+ self.c_fc = Linear(config.n_embd, 4 * config.n_embd, bias=False)
133
+ self.c_proj = Linear(4 * config.n_embd, config.n_embd, bias=False)
134
+
135
+ def forward(self, x):
136
+ x = self.c_fc(x)
137
+ x = F.relu(x).square()
138
+ x = self.c_proj(x)
139
+ return x
140
+
141
+
142
+ class Block(nn.Module):
143
+ def __init__(self, config, layer_idx):
144
+ super().__init__()
145
+ self.attn = CausalSelfAttention(config, layer_idx)
146
+ self.mlp = MLP(config)
147
+
148
+ def forward(self, x, ve, cos_sin, window_size, kv_cache):
149
+ x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
150
+ x = x + self.mlp(norm(x))
151
+ return x
152
+
153
+
154
+ class GPT(nn.Module):
155
+ def __init__(self, config, pad_vocab_size_to=64):
156
+ """
157
+ NOTE a major footgun: this __init__ function runs in meta device context (!!)
158
+ Therefore, any calculations inside here are shapes and dtypes only, no actual data.
159
+ => We actually initialize all data (parameters, buffers, etc.) in init_weights() instead.
160
+ """
161
+ super().__init__()
162
+ self.config = config
163
+ # Compute per-layer window sizes for sliding window attention
164
+ # window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
165
+ self.window_sizes = self._compute_window_sizes(config)
166
+ # Pad vocab for efficiency (DDP, tensor cores). This is just an optimization - outputs are cropped in forward().
167
+ # https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
168
+ padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
169
+ if padded_vocab_size != config.vocab_size:
170
+ print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")
171
+ self.transformer = nn.ModuleDict({
172
+ "wte": nn.Embedding(padded_vocab_size, config.n_embd),
173
+ "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
174
+ })
175
+ self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False)
176
+ # Per-layer learnable scalars (inspired by modded-nanogpt)
177
+ # resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
178
+ # x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
179
+ # Separate parameters so they can have different optimizer treatment
180
+ self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
181
+ self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
182
+ # Value embeddings (ResFormer-style): alternating layers, last layer always included
183
+ head_dim = config.n_embd // config.n_head
184
+ kv_dim = config.n_kv_head * head_dim
185
+ self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
186
+ # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
187
+ # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
188
+ # so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
189
+ # In the future we can dynamically grow the cache, for now it's fine.
190
+ self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
191
+ head_dim = config.n_embd // config.n_head
192
+ cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
193
+ self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
194
+ self.register_buffer("sin", sin, persistent=False)
195
+
196
+ @torch.no_grad()
197
+ def init_weights(self):
198
+ """
199
+ Initialize the full model in this one function for maximum clarity.
200
+
201
+ wte (embedding): normal, std=1.0
202
+ lm_head: normal, std=0.001
203
+ for each block:
204
+ attn.c_q: uniform, std=1/sqrt(n_embd)
205
+ attn.c_k: uniform, std=1/sqrt(n_embd)
206
+ attn.c_v: uniform, std=1/sqrt(n_embd)
207
+ attn.c_proj: zeros
208
+ mlp.c_fc: uniform, std=1/sqrt(n_embd)
209
+ mlp.c_proj: zeros
210
+ """
211
+
212
+ # Embedding and unembedding
213
+ torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=0.8)
214
+ torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
215
+
216
+ # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
217
+ n_embd = self.config.n_embd
218
+ s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
219
+ for block in self.transformer.h:
220
+ torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
221
+ torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
222
+ torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
223
+ torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
224
+ torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.5, s * 0.5) # 0.5x init scale for c_fc
225
+ torch.nn.init.zeros_(block.mlp.c_proj.weight)
226
+
227
+ # Per-layer scalars
228
+ self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
229
+ self.x0_lambdas.fill_(0.1) # 0.1 => small initial weight for skip connection to input embedding
230
+
231
+ # Value embeddings (init like c_v: uniform with same std)
232
+ for ve in self.value_embeds.values():
233
+ torch.nn.init.uniform_(ve.weight, -s, s)
234
+
235
+ # Gate weights init with small positive values so gates start slightly above neutral
236
+ for block in self.transformer.h:
237
+ if block.attn.ve_gate is not None:
238
+ torch.nn.init.uniform_(block.attn.ve_gate.weight, 0.0, 0.02)
239
+
240
+ # Rotary embeddings
241
+ head_dim = self.config.n_embd // self.config.n_head
242
+ cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
243
+ self.cos, self.sin = cos, sin
244
+
245
+ # Cast embeddings to COMPUTE_DTYPE: optimizer can tolerate reduced-precision
246
+ # embeddings and it saves memory. Exception: fp16 requires fp32 embeddings
247
+ # because GradScaler cannot unscale fp16 gradients.
248
+ if COMPUTE_DTYPE != torch.float16:
249
+ self.transformer.wte.to(dtype=COMPUTE_DTYPE)
250
+ for ve in self.value_embeds.values():
251
+ ve.to(dtype=COMPUTE_DTYPE)
252
+
253
+ def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None):
254
+ # TODO: bump base theta more? e.g. 100K is more common more recently
255
+ # autodetect the device from model embeddings
256
+ if device is None:
257
+ device = self.transformer.wte.weight.device
258
+ # stride the channels
259
+ channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
260
+ inv_freq = 1.0 / (base ** (channel_range / head_dim))
261
+ # stride the time steps
262
+ t = torch.arange(seq_len, dtype=torch.float32, device=device)
263
+ # calculate the rotation frequencies at each (time, channel) pair
264
+ freqs = torch.outer(t, inv_freq)
265
+ cos, sin = freqs.cos(), freqs.sin()
266
+ cos, sin = cos.to(COMPUTE_DTYPE), sin.to(COMPUTE_DTYPE)
267
+ cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
268
+ return cos, sin
269
+
270
+ def _compute_window_sizes(self, config):
271
+ """
272
+ Compute per-layer window sizes for sliding window attention.
273
+
274
+ Returns list of (left, right) tuples for FA3's window_size parameter:
275
+ - left: how many tokens before current position to attend to (-1 = unlimited)
276
+ - right: how many tokens after current position to attend to (0 for causal)
277
+
278
+ Pattern string is tiled across layers. Final layer always gets L (full context).
279
+ Characters: L=long (full context), S=short (half context)
280
+ """
281
+ pattern = config.window_pattern.upper()
282
+ assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
283
+ # Map characters to window sizes
284
+ long_window = config.sequence_len
285
+ short_window = -(-long_window // 3 // 128) * 128 # ceil to FA3 tile size (2048 -> 768)
286
+ char_to_window = {
287
+ "L": (long_window, 0),
288
+ "S": (short_window, 0),
289
+ }
290
+ # Tile pattern across layers
291
+ window_sizes = []
292
+ for layer_idx in range(config.n_layer):
293
+ char = pattern[layer_idx % len(pattern)]
294
+ window_sizes.append(char_to_window[char])
295
+ # Final layer always gets full context
296
+ window_sizes[-1] = (long_window, 0)
297
+ return window_sizes
298
+
299
+ def get_device(self):
300
+ return self.transformer.wte.weight.device
301
+
302
+ def estimate_flops(self):
303
+ """
304
+ Return the estimated FLOPs per token for the model (forward + backward).
305
+ Each matmul weight parameter contributes 2 FLOPs (multiply *, accumulate +) in forward, and 2X that in backward => 2+4=6.
306
+ Cleanest explanation of this: https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4
307
+ On top of that, 12 * h * q * effective_seq_len accounts for key @ query matmul flops inside attention.
308
+ With sliding windows, effective_seq_len varies per layer (capped by window size).
309
+ Ref: https://arxiv.org/abs/2204.02311 (PaLM paper).
310
+ This is ~1% off from the exact formulas of Chinchilla paper, the difference is:
311
+ - Chinchilla counts the embedding layer as flops (? weird, it's just a lookup => we ignore)
312
+ - Chinchilla counts exp/sum/divide in attention softmax as flops (a little sus and very tiny => we ignore)
313
+ """
314
+ nparams = sum(p.numel() for p in self.parameters())
315
+ # Exclude non-matmul params: embeddings and per-layer scalars
316
+ value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
317
+ nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
318
+ self.resid_lambdas.numel() + self.x0_lambdas.numel())
319
+ h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
320
+ # Sum attention FLOPs per layer, accounting for sliding window
321
+ attn_flops = 0
322
+ for window_size in self.window_sizes:
323
+ window = window_size[0] # (left, right) tuple, we use left
324
+ effective_seq = t if window < 0 else min(window, t)
325
+ attn_flops += 12 * h * q * effective_seq
326
+ num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops
327
+ return num_flops_per_token
328
+
329
+ def num_scaling_params(self):
330
+ """
331
+ Return detailed parameter counts for scaling law analysis.
332
+ Different papers use different conventions:
333
+ - Kaplan et al. excluded embedding parameters
334
+ - Chinchilla included all parameters
335
+ Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper)
336
+ Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper)
337
+
338
+ Returns a dict with counts for each parameter group, so downstream analysis
339
+ can experiment with which combination gives the cleanest scaling laws.
340
+ """
341
+ # Count each group separately (mirrors the grouping in setup_optimizers)
342
+ wte = sum(p.numel() for p in self.transformer.wte.parameters())
343
+ value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
344
+ lm_head = sum(p.numel() for p in self.lm_head.parameters())
345
+ transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
346
+ scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
347
+ total = wte + value_embeds + lm_head + transformer_matrices + scalars
348
+ assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
349
+ return {
350
+ 'wte': wte,
351
+ 'value_embeds': value_embeds,
352
+ 'lm_head': lm_head,
353
+ 'transformer_matrices': transformer_matrices,
354
+ 'scalars': scalars,
355
+ 'total': total,
356
+ }
357
+
358
+ def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5):
359
+ model_dim = self.config.n_embd
360
+ ddp, rank, local_rank, world_size = get_dist_info()
361
+
362
+ # Separate out all parameters into groups
363
+ matrix_params = list(self.transformer.h.parameters())
364
+ value_embeds_params = list(self.value_embeds.parameters())
365
+ embedding_params = list(self.transformer.wte.parameters())
366
+ lm_head_params = list(self.lm_head.parameters())
367
+ resid_params = [self.resid_lambdas]
368
+ x0_params = [self.x0_lambdas]
369
+ assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
370
+
371
+ # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
372
+ dmodel_lr_scale = (model_dim / 768) ** -0.5
373
+ print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
374
+
375
+ # Build param_groups with all required fields explicit
376
+ param_groups = [
377
+ # AdamW groups (embeddings, lm_head, scalars)
378
+ dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=(0.8, 0.96), eps=1e-10, weight_decay=0.01),
379
+ dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.001),
380
+ dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale * 0.5, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.01),
381
+ dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05),
382
+ dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
383
+ ]
384
+ # Muon groups (matrix params, grouped by shape for stacking)
385
+ for shape in sorted({p.shape for p in matrix_params}):
386
+ group_params = [p for p in matrix_params if p.shape == shape]
387
+ param_groups.append(dict(
388
+ kind='muon', params=group_params, lr=matrix_lr,
389
+ momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay,
390
+ ))
391
+
392
+ Factory = DistMuonAdamW if ddp else MuonAdamW
393
+ optimizer = Factory(param_groups)
394
+ for group in optimizer.param_groups:
395
+ group["initial_lr"] = group["lr"]
396
+ return optimizer
397
+
398
+ def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
399
+ B, T = idx.size()
400
+
401
+ # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
402
+ assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
403
+ assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
404
+ assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}"
405
+ # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
406
+ T0 = 0 if kv_cache is None else kv_cache.get_pos()
407
+ cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
408
+
409
+ # Forward the trunk of the Transformer
410
+ x = self.transformer.wte(idx) # embed current token
411
+ x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
412
+ x = norm(x)
413
+ x0 = x # save initial normalized embedding for x0 residual
414
+ for i, block in enumerate(self.transformer.h):
415
+ x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
416
+ ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
417
+ x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
418
+ x = norm(x)
419
+
420
+ # Forward the lm_head (compute logits)
421
+ softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
422
+ logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
423
+ logits = logits[..., :self.config.vocab_size] # slice to remove padding
424
+ logits = logits.float() # switch to fp32 for logit softcap and loss computation
425
+ logits = softcap * torch.tanh(logits / softcap) # squash the logits
426
+
427
+ if targets is not None:
428
+ # training: given the targets, compute and return the loss
429
+ # TODO experiment with chunked cross-entropy?
430
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
431
+ return loss
432
+ else:
433
+ # inference: just return the logits directly
434
+ return logits
435
+
436
+ @torch.inference_mode()
437
+ def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
438
+ """
439
+ Naive autoregressive streaming inference.
440
+ To make it super simple, let's assume:
441
+ - batch size is 1
442
+ - ids and the yielded tokens are simple Python lists and ints
443
+ """
444
+ assert isinstance(tokens, list)
445
+ device = self.get_device()
446
+ rng = None
447
+ if temperature > 0:
448
+ rng = torch.Generator(device=device)
449
+ rng.manual_seed(seed)
450
+ ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
451
+ for _ in range(max_tokens):
452
+ logits = self.forward(ids) # (B, T, vocab_size)
453
+ logits = logits[:, -1, :] # (B, vocab_size)
454
+ if top_k is not None and top_k > 0:
455
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
456
+ logits[logits < v[:, [-1]]] = -float('Inf')
457
+ if temperature > 0:
458
+ logits = logits / temperature
459
+ probs = F.softmax(logits, dim=-1)
460
+ next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
461
+ else:
462
+ next_ids = torch.argmax(logits, dim=-1, keepdim=True)
463
+ ids = torch.cat((ids, next_ids), dim=1)
464
+ token = next_ids.item()
465
+ yield token
nanochat/logo.svg ADDED
nanochat/optim.py ADDED
@@ -0,0 +1,533 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A nice and efficient mixed AdamW/Muon Combined Optimizer.
3
+ Usually the embeddings and scalars go into AdamW, and the matrix parameters go into Muon.
4
+ Two versions are provided (MuonAdamW, DistMuonAdamW), for single GPU and distributed.
5
+
6
+ Addapted from: https://github.com/KellerJordan/modded-nanogpt
7
+ Further contributions from @karpathy and @chrisjmccormick.
8
+ """
9
+
10
+ import torch
11
+ import torch.distributed as dist
12
+ from torch import Tensor
13
+
14
+ # -----------------------------------------------------------------------------
15
+ """
16
+ Good old AdamW optimizer, fused kernel.
17
+ https://arxiv.org/abs/1711.05101
18
+ """
19
+
20
+ @torch.compile(dynamic=False, fullgraph=True)
21
+ def adamw_step_fused(
22
+ p: Tensor, # (32768, 768) - parameter tensor
23
+ grad: Tensor, # (32768, 768) - gradient, same shape as p
24
+ exp_avg: Tensor, # (32768, 768) - first moment, same shape as p
25
+ exp_avg_sq: Tensor, # (32768, 768) - second moment, same shape as p
26
+ step_t: Tensor, # () - 0-D CPU tensor, step count
27
+ lr_t: Tensor, # () - 0-D CPU tensor, learning rate
28
+ beta1_t: Tensor, # () - 0-D CPU tensor, beta1
29
+ beta2_t: Tensor, # () - 0-D CPU tensor, beta2
30
+ eps_t: Tensor, # () - 0-D CPU tensor, epsilon
31
+ wd_t: Tensor, # () - 0-D CPU tensor, weight decay
32
+ ) -> None:
33
+ """
34
+ Fused AdamW step: weight_decay -> momentum_update -> bias_correction -> param_update
35
+ All in one compiled graph to eliminate Python overhead between ops.
36
+ The 0-D CPU tensors avoid recompilation when hyperparameter values change.
37
+ """
38
+ # Weight decay (decoupled, applied before the update)
39
+ p.mul_(1 - lr_t * wd_t)
40
+ # Update running averages (lerp_ is cleaner and fuses well)
41
+ exp_avg.lerp_(grad, 1 - beta1_t)
42
+ exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
43
+ # Bias corrections
44
+ bias1 = 1 - beta1_t ** step_t
45
+ bias2 = 1 - beta2_t ** step_t
46
+ # Compute update and apply
47
+ denom = (exp_avg_sq / bias2).sqrt() + eps_t
48
+ step_size = lr_t / bias1
49
+ p.add_(exp_avg / denom, alpha=-step_size)
50
+
51
+ # -----------------------------------------------------------------------------
52
+ """
53
+ Muon optimizer adapted and simplified from modded-nanogpt.
54
+ https://github.com/KellerJordan/modded-nanogpt
55
+
56
+ Background:
57
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
58
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
59
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
60
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
61
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
62
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
63
+ performance at all relative to UV^T, where USV^T = G is the SVD.
64
+
65
+ Here, an alternative to Newton-Schulz iteration with potentially better convergence properties:
66
+ Polar Express Sign Method for orthogonalization.
67
+ https://arxiv.org/pdf/2505.16932
68
+ by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.
69
+
70
+ NorMuon variance reduction: per-neuron/column adaptive learning rate that normalizes
71
+ update scales after orthogonalization (Muon's output has non-uniform scales across neurons).
72
+ https://arxiv.org/pdf/2510.05491
73
+
74
+ Some of the changes in nanochat implementation:
75
+ - Uses a simpler, more general approach to parameter grouping and stacking
76
+ - Uses a single fused kernel for the momentum -> polar_express -> variance_reduction -> update step
77
+ - Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
78
+ """
79
+
80
+ # Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
81
+ # From https://arxiv.org/pdf/2505.16932
82
+ polar_express_coeffs = [
83
+ (8.156554524902461, -22.48329292557795, 15.878769915207462),
84
+ (4.042929935166739, -2.808917465908714, 0.5000178451051316),
85
+ (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
86
+ (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
87
+ (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
88
+ ]
89
+
90
+ @torch.compile(dynamic=False, fullgraph=True)
91
+ def muon_step_fused(
92
+ stacked_grads: Tensor, # (12, 768, 3072) - stacked gradients
93
+ stacked_params: Tensor, # (12, 768, 3072) - stacked parameters
94
+ momentum_buffer: Tensor, # (12, 768, 3072) - first moment buffer
95
+ second_momentum_buffer: Tensor, # (12, 768, 1) or (12, 1, 3072) - factored second moment
96
+ momentum_t: Tensor, # () - 0-D CPU tensor, momentum coefficient
97
+ lr_t: Tensor, # () - 0-D CPU tensor, learning rate
98
+ wd_t: Tensor, # () - 0-D CPU tensor, weight decay
99
+ beta2_t: Tensor, # () - 0-D CPU tensor, beta2 for second moment
100
+ ns_steps: int, # 5 - number of Newton-Schulz/Polar Express iterations
101
+ red_dim: int, # -1 or -2 - reduction dimension for variance
102
+ ) -> None:
103
+ """
104
+ Fused Muon step: momentum -> polar_express -> variance_reduction -> cautious_update
105
+ All in one compiled graph to eliminate Python overhead between ops.
106
+ Some of the constants are 0-D CPU tensors to avoid recompilation when values change.
107
+ """
108
+
109
+ # Nesterov momentum
110
+ momentum = momentum_t.to(stacked_grads.dtype)
111
+ momentum_buffer.lerp_(stacked_grads, 1 - momentum)
112
+ g = stacked_grads.lerp_(momentum_buffer, momentum)
113
+
114
+ # Polar express
115
+ X = g.bfloat16()
116
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.01 + 1e-6)
117
+ if g.size(-2) > g.size(-1): # Tall matrix
118
+ for a, b, c in polar_express_coeffs[:ns_steps]:
119
+ A = X.mT @ X
120
+ B = b * A + c * (A @ A)
121
+ X = a * X + X @ B
122
+ else: # Wide matrix (original math)
123
+ for a, b, c in polar_express_coeffs[:ns_steps]:
124
+ A = X @ X.mT
125
+ B = b * A + c * (A @ A)
126
+ X = a * X + B @ X
127
+ g = X
128
+
129
+ # Variance reduction
130
+ beta2 = beta2_t.to(g.dtype)
131
+ v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
132
+ red_dim_size = g.size(red_dim)
133
+ v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
134
+ v_norm = v_norm_sq.sqrt()
135
+ second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
136
+ step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
137
+ scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
138
+ v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
139
+ final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
140
+ g = g * final_scale.to(g.dtype)
141
+
142
+ # Cautious weight decay + parameter update
143
+ lr = lr_t.to(g.dtype)
144
+ wd = wd_t.to(g.dtype)
145
+ mask = (g * stacked_params) >= 0
146
+ stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
147
+
148
+ # -----------------------------------------------------------------------------
149
+ # Single GPU version of the MuonAdamW optimizer.
150
+ # Used mostly for reference, debugging and testing.
151
+
152
+ class MuonAdamW(torch.optim.Optimizer):
153
+ """
154
+ Combined optimizer: Muon for 2D matrix params, AdamW for others, single GPU version.
155
+
156
+ AdamW - Fused AdamW optimizer step.
157
+
158
+ Muon - MomentUm Orthogonalized by Newton-schulz
159
+ https://kellerjordan.github.io/posts/muon/
160
+
161
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
162
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
163
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
164
+ the advantage that it can be stably run in bfloat16 on the GPU.
165
+
166
+ Some warnings:
167
+ - The Muon optimizer should not be used for the embedding layer, the final fully connected layer,
168
+ or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
169
+ - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
170
+
171
+ Arguments:
172
+ param_groups: List of dicts, each containing:
173
+ - 'params': List of parameters
174
+ - 'kind': 'adamw' or 'muon'
175
+ - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
176
+ - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
177
+ """
178
+ def __init__(self, param_groups: list[dict]):
179
+ super().__init__(param_groups, defaults={})
180
+ # 0-D CPU tensors to avoid torch.compile recompilation when values change
181
+ # AdamW tensors
182
+ self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
183
+ self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
184
+ self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
185
+ self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
186
+ self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
187
+ self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
188
+ # Muon tensors
189
+ self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
190
+ self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
191
+ self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
192
+ self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
193
+
194
+ def _step_adamw(self, group: dict) -> None:
195
+ """
196
+ AdamW update for each param in the group individually.
197
+ Lazy init the state, fill in all 0-D tensors, call the fused kernel.
198
+ """
199
+ for p in group['params']:
200
+ if p.grad is None:
201
+ continue
202
+ grad = p.grad
203
+ state = self.state[p]
204
+
205
+ # State init
206
+ if not state:
207
+ state['step'] = 0
208
+ state['exp_avg'] = torch.zeros_like(p)
209
+ state['exp_avg_sq'] = torch.zeros_like(p)
210
+ exp_avg = state['exp_avg']
211
+ exp_avg_sq = state['exp_avg_sq']
212
+ state['step'] += 1
213
+
214
+ # Fill 0-D tensors with current values
215
+ self._adamw_step_t.fill_(state['step'])
216
+ self._adamw_lr_t.fill_(group['lr'])
217
+ self._adamw_beta1_t.fill_(group['betas'][0])
218
+ self._adamw_beta2_t.fill_(group['betas'][1])
219
+ self._adamw_eps_t.fill_(group['eps'])
220
+ self._adamw_wd_t.fill_(group['weight_decay'])
221
+
222
+ # Fused update: weight_decay -> momentum -> bias_correction -> param_update
223
+ adamw_step_fused(
224
+ p, grad, exp_avg, exp_avg_sq,
225
+ self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
226
+ self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
227
+ )
228
+
229
+ def _step_muon(self, group: dict) -> None:
230
+ """
231
+ Muon update for all params in the group (stacked for efficiency).
232
+ Lazy init the state, fill in all 0-D tensors, call the fused kernel.
233
+ """
234
+ params: list[Tensor] = group['params']
235
+ if not params:
236
+ return
237
+
238
+ # Get or create group-level buffers (stored in first param's state for convenience)
239
+ p = params[0]
240
+ state = self.state[p]
241
+ num_params = len(params)
242
+ shape, device, dtype = p.shape, p.device, p.dtype
243
+
244
+ # Momentum for every individual parameter
245
+ if "momentum_buffer" not in state:
246
+ state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
247
+ momentum_buffer = state["momentum_buffer"]
248
+
249
+ # Second momentum buffer is factored, either per-row or per-column
250
+ if "second_momentum_buffer" not in state:
251
+ state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
252
+ state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
253
+ second_momentum_buffer = state["second_momentum_buffer"]
254
+ red_dim = -1 if shape[-2] >= shape[-1] else -2
255
+
256
+ # Stack grads and params (NOTE: this assumes all params have the same shape)
257
+ stacked_grads = torch.stack([p.grad for p in params])
258
+ stacked_params = torch.stack(params)
259
+
260
+ # Fill all the 0-D tensors with current values
261
+ self._muon_momentum_t.fill_(group["momentum"])
262
+ self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
263
+ self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
264
+ self._muon_wd_t.fill_(group["weight_decay"])
265
+
266
+ # Single fused kernel: momentum -> polar_express -> variance_reduction -> update
267
+ muon_step_fused(
268
+ stacked_grads,
269
+ stacked_params,
270
+ momentum_buffer,
271
+ second_momentum_buffer,
272
+ self._muon_momentum_t,
273
+ self._muon_lr_t,
274
+ self._muon_wd_t,
275
+ self._muon_beta2_t,
276
+ group["ns_steps"],
277
+ red_dim,
278
+ )
279
+
280
+ # Copy back to original params
281
+ torch._foreach_copy_(params, list(stacked_params.unbind(0)))
282
+
283
+ @torch.no_grad()
284
+ def step(self):
285
+ for group in self.param_groups:
286
+ if group['kind'] == 'adamw':
287
+ self._step_adamw(group)
288
+ elif group['kind'] == 'muon':
289
+ self._step_muon(group)
290
+ else:
291
+ raise ValueError(f"Unknown optimizer kind: {group['kind']}")
292
+
293
+ # -----------------------------------------------------------------------------
294
+ # Distributed version of the MuonAdamW optimizer.
295
+ # Used for training on multiple GPUs.
296
+
297
+ class DistMuonAdamW(torch.optim.Optimizer):
298
+ """
299
+ Combined distributed optimizer: Muon for 2D matrix params, AdamW for others.
300
+
301
+ See MuonAdamW for the algorithmic details of each optimizer. This class adds
302
+ distributed communication to enable multi-GPU training without PyTorch DDP.
303
+
304
+ Design Goals:
305
+ - Overlap communication with computation (async ops)
306
+ - Minimize memory by sharding optimizer states across ranks (ZeRO-2 style)
307
+ - Batch small tensors into single comm ops where possible
308
+
309
+ Communication Pattern (3-phase async):
310
+ We use a 3-phase structure to maximize overlap between communication and compute:
311
+
312
+ Phase 1: Launch all async reduce ops
313
+ - Kick off all reduce_scatter/all_reduce operations
314
+ - Don't wait - let them run in background while we continue
315
+
316
+ Phase 2: Wait for reduces, compute updates, launch gathers
317
+ - For each group: wait for its reduce, compute the update, launch gather
318
+ - By processing groups in order, earlier gathers run while later computes happen
319
+
320
+ Phase 3: Wait for gathers, copy back
321
+ - Wait for all gathers to complete
322
+ - Copy updated params back to original tensors (Muon only)
323
+
324
+ AdamW Communication (ZeRO-2 style):
325
+ - Small params (<1024 elements): all_reduce gradients, update full param on each rank.
326
+ Optimizer state is replicated but these params are tiny (scalars, biases).
327
+ - Large params: reduce_scatter gradients so each rank gets 1/N of the grad, update
328
+ only that slice, then all_gather the updated slices. Optimizer state (exp_avg,
329
+ exp_avg_sq) is sharded - each rank only stores state for its slice.
330
+ Requires param.shape[0] divisible by world_size.
331
+
332
+ Muon Communication (stacked + chunked):
333
+ - All params in a Muon group must have the same shape (caller's responsibility).
334
+ - Stack all K params into a single (K, *shape) tensor for efficient comm.
335
+ - Divide K params across N ranks: each rank "owns" ceil(K/N) params.
336
+ - reduce_scatter the stacked grads so each rank gets its chunk.
337
+ - Each rank computes Muon update only for params it owns.
338
+ - all_gather the updated params back to all ranks.
339
+ - Optimizer state (momentum_buffer, second_momentum_buffer) is sharded by chunk.
340
+ - Padding: if K doesn't divide evenly, we zero-pad to (ceil(K/N) * N) for comm,
341
+ then ignore the padding when copying back.
342
+
343
+ Buffer Reuse:
344
+ - For Muon, we allocate stacked_grads for reduce_scatter input, then reuse the
345
+ same buffer as the output for all_gather (stacked_params). This saves memory
346
+ since we don't need both buffers simultaneously.
347
+
348
+ Arguments:
349
+ param_groups: List of dicts, each containing:
350
+ - 'params': List of parameters
351
+ - 'kind': 'adamw' or 'muon'
352
+ - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
353
+ - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
354
+ """
355
+ def __init__(self, param_groups: list[dict]):
356
+ super().__init__(param_groups, defaults={})
357
+ # 0-D CPU tensors to avoid torch.compile recompilation when values change
358
+ self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
359
+ self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
360
+ self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
361
+ self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
362
+ self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
363
+ self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
364
+ self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
365
+ self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
366
+ self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
367
+ self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
368
+
369
+ def _reduce_adamw(self, group: dict, world_size: int) -> dict:
370
+ """Launch async reduce ops for AdamW group. Returns info dict with per-param infos."""
371
+ param_infos = {}
372
+ for p in group['params']:
373
+ grad = p.grad
374
+ if p.numel() < 1024:
375
+ # Small params: all_reduce (no scatter/gather needed)
376
+ future = dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
377
+ param_infos[p] = dict(future=future, grad_slice=grad, is_small=True)
378
+ else:
379
+ # Large params: reduce_scatter
380
+ assert grad.shape[0] % world_size == 0, f"AdamW reduce_scatter requires shape[0] ({grad.shape[0]}) divisible by world_size ({world_size})"
381
+ rank_size = grad.shape[0] // world_size
382
+ grad_slice = torch.empty_like(grad[:rank_size])
383
+ future = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
384
+ param_infos[p] = dict(future=future, grad_slice=grad_slice, is_small=False)
385
+ return dict(param_infos=param_infos)
386
+
387
+ def _reduce_muon(self, group: dict, world_size: int) -> dict:
388
+ """Launch async reduce op for Muon group. Returns info dict."""
389
+ params = group['params']
390
+ chunk_size = (len(params) + world_size - 1) // world_size
391
+ padded_num_params = chunk_size * world_size
392
+ p = params[0]
393
+ shape, device, dtype = p.shape, p.device, p.dtype
394
+
395
+ # Stack grads and zero-pad to padded_num_params
396
+ grad_stack = torch.stack([p.grad for p in params])
397
+ stacked_grads = torch.empty(padded_num_params, *shape, dtype=dtype, device=device)
398
+ stacked_grads[:len(params)].copy_(grad_stack)
399
+ if len(params) < padded_num_params:
400
+ stacked_grads[len(params):].zero_()
401
+
402
+ # Reduce_scatter to get this rank's chunk
403
+ grad_chunk = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
404
+ future = dist.reduce_scatter_tensor(grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True).get_future()
405
+
406
+ return dict(future=future, grad_chunk=grad_chunk, stacked_grads=stacked_grads, chunk_size=chunk_size)
407
+
408
+ def _compute_adamw(self, group: dict, info: dict, gather_list: list, rank: int, world_size: int) -> None:
409
+ """Wait for reduce, compute AdamW updates, launch gathers for large params."""
410
+ param_infos = info['param_infos']
411
+ for p in group['params']:
412
+ pinfo = param_infos[p]
413
+ pinfo['future'].wait()
414
+ grad_slice = pinfo['grad_slice']
415
+ state = self.state[p]
416
+
417
+ # For small params, operate on full param; for large, operate on slice
418
+ if pinfo['is_small']:
419
+ p_slice = p
420
+ else:
421
+ rank_size = p.shape[0] // world_size
422
+ p_slice = p[rank * rank_size:(rank + 1) * rank_size]
423
+
424
+ # State init
425
+ if not state:
426
+ state['step'] = 0
427
+ state['exp_avg'] = torch.zeros_like(p_slice)
428
+ state['exp_avg_sq'] = torch.zeros_like(p_slice)
429
+ state['step'] += 1
430
+
431
+ # Fill 0-D tensors and run fused kernel
432
+ self._adamw_step_t.fill_(state['step'])
433
+ self._adamw_lr_t.fill_(group['lr'])
434
+ self._adamw_beta1_t.fill_(group['betas'][0])
435
+ self._adamw_beta2_t.fill_(group['betas'][1])
436
+ self._adamw_eps_t.fill_(group['eps'])
437
+ self._adamw_wd_t.fill_(group['weight_decay'])
438
+ adamw_step_fused(
439
+ p_slice, grad_slice, state['exp_avg'], state['exp_avg_sq'],
440
+ self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
441
+ self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t,
442
+ )
443
+
444
+ # Large params need all_gather
445
+ if not pinfo['is_small']:
446
+ future = dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()
447
+ gather_list.append(dict(future=future, params=None))
448
+
449
+ def _compute_muon(self, group: dict, info: dict, gather_list: list, rank: int) -> None:
450
+ """Wait for reduce, compute Muon updates, launch gather."""
451
+ info['future'].wait()
452
+ params = group['params']
453
+ chunk_size = info['chunk_size']
454
+ grad_chunk = info['grad_chunk']
455
+ p = params[0]
456
+ shape, device, dtype = p.shape, p.device, p.dtype
457
+
458
+ # How many params does this rank own?
459
+ start_idx = rank * chunk_size
460
+ num_owned = min(chunk_size, max(0, len(params) - start_idx))
461
+
462
+ # Get or create group-level state
463
+ state = self.state[p]
464
+ if "momentum_buffer" not in state:
465
+ state["momentum_buffer"] = torch.zeros(chunk_size, *shape, dtype=dtype, device=device)
466
+ if "second_momentum_buffer" not in state:
467
+ state_shape = (chunk_size, shape[-2], 1) if shape[-2] >= shape[-1] else (chunk_size, 1, shape[-1])
468
+ state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
469
+ red_dim = -1 if shape[-2] >= shape[-1] else -2
470
+
471
+ # Build output buffer for all_gather
472
+ updated_params = torch.empty(chunk_size, *shape, dtype=dtype, device=device)
473
+
474
+ if num_owned > 0:
475
+ owned_params = [params[start_idx + i] for i in range(num_owned)]
476
+ stacked_owned = torch.stack(owned_params)
477
+
478
+ # Fill 0-D tensors and run fused kernel
479
+ self._muon_momentum_t.fill_(group["momentum"])
480
+ self._muon_beta2_t.fill_(group["beta2"])
481
+ self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
482
+ self._muon_wd_t.fill_(group["weight_decay"])
483
+ muon_step_fused(
484
+ grad_chunk[:num_owned], stacked_owned,
485
+ state["momentum_buffer"][:num_owned], state["second_momentum_buffer"][:num_owned],
486
+ self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t, self._muon_beta2_t,
487
+ group["ns_steps"], red_dim,
488
+ )
489
+ updated_params[:num_owned].copy_(stacked_owned)
490
+
491
+ if num_owned < chunk_size:
492
+ updated_params[num_owned:].zero_()
493
+
494
+ # Reuse stacked_grads buffer for all_gather output
495
+ stacked_params = info["stacked_grads"]
496
+ future = dist.all_gather_into_tensor(stacked_params, updated_params, async_op=True).get_future()
497
+ gather_list.append(dict(future=future, stacked_params=stacked_params, params=params))
498
+
499
+ def _finish_gathers(self, gather_list: list) -> None:
500
+ """Wait for all gathers and copy Muon params back."""
501
+ for info in gather_list:
502
+ info["future"].wait()
503
+ if info["params"] is not None:
504
+ # Muon: copy from stacked buffer back to individual params
505
+ torch._foreach_copy_(info["params"], list(info["stacked_params"][:len(info["params"])].unbind(0)))
506
+
507
+ @torch.no_grad()
508
+ def step(self):
509
+ rank = dist.get_rank()
510
+ world_size = dist.get_world_size()
511
+
512
+ # Phase 1: launch all async reduce ops
513
+ reduce_infos: list[dict] = []
514
+ for group in self.param_groups:
515
+ if group['kind'] == 'adamw':
516
+ reduce_infos.append(self._reduce_adamw(group, world_size))
517
+ elif group['kind'] == 'muon':
518
+ reduce_infos.append(self._reduce_muon(group, world_size))
519
+ else:
520
+ raise ValueError(f"Unknown optimizer kind: {group['kind']}")
521
+
522
+ # Phase 2: wait for reduces, compute updates, launch gathers
523
+ gather_list: list[dict] = []
524
+ for group, info in zip(self.param_groups, reduce_infos):
525
+ if group['kind'] == 'adamw':
526
+ self._compute_adamw(group, info, gather_list, rank, world_size)
527
+ elif group['kind'] == 'muon':
528
+ self._compute_muon(group, info, gather_list, rank)
529
+ else:
530
+ raise ValueError(f"Unknown optimizer kind: {group['kind']}")
531
+
532
+ # Phase 3: wait for gathers, copy back
533
+ self._finish_gathers(gather_list)
nanochat/tokenizer.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tokenizer module — patched for Victorian LLM HuggingFace Space.
3
+
4
+ Delegates to tokenizer_wrapper.py which provides the VictorianTokenizer class.
5
+ """
6
+
7
+ import sys
8
+ import os
9
+
10
+ # Ensure the app root is on the path so tokenizer_wrapper can be imported
11
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
12
+
13
+ from tokenizer_wrapper import get_tokenizer, get_token_bytes
14
+ __all__ = ["get_tokenizer", "get_token_bytes"]
nanochat/ui.html ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0, viewport-fit=cover">
6
+ <title>NanoChat</title>
7
+ <link rel="icon" type="image/svg+xml" href="/logo.svg">
8
+ <style>
9
+ :root {
10
+ color-scheme: light;
11
+ }
12
+
13
+ * {
14
+ box-sizing: border-box;
15
+ }
16
+
17
+ html, body{
18
+ height: 100%;
19
+ margin: 0;
20
+ }
21
+
22
+ body {
23
+ font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, "Apple Color Emoji", Arial, sans-serif, "Segoe UI Emoji", "Segoe UI Symbol";
24
+ background-color: #ffffff;
25
+ color: #111827;
26
+ min-height: 100dvh;
27
+ margin: 0;
28
+ display: flex;
29
+ flex-direction: column;
30
+ }
31
+
32
+ .header {
33
+ background-color: #ffffff;
34
+ padding: 1.25rem 1.5rem;
35
+ }
36
+
37
+ .header-left {
38
+ display: flex;
39
+ align-items: center;
40
+ gap: 0.75rem;
41
+ }
42
+
43
+ .header-logo {
44
+ height: 32px;
45
+ width: auto;
46
+ }
47
+
48
+ .header h1 {
49
+ font-size: 1.25rem;
50
+ font-weight: 600;
51
+ margin: 0;
52
+ color: #111827;
53
+ }
54
+
55
+ .new-conversation-btn {
56
+ width: 32px;
57
+ height: 32px;
58
+ padding: 0;
59
+ border: 1px solid #e5e7eb;
60
+ border-radius: 0.5rem;
61
+ background-color: #ffffff;
62
+ color: #6b7280;
63
+ cursor: pointer;
64
+ display: flex;
65
+ align-items: center;
66
+ justify-content: center;
67
+ transition: all 0.2s ease;
68
+ }
69
+
70
+ .new-conversation-btn:hover {
71
+ background-color: #f3f4f6;
72
+ border-color: #d1d5db;
73
+ color: #374151;
74
+ }
75
+
76
+ .chat-container {
77
+ flex: 1;
78
+ overflow-y: auto;
79
+ background-color: #ffffff;
80
+ }
81
+
82
+ .chat-wrapper {
83
+ max-width: 48rem;
84
+ margin: 0 auto;
85
+ padding: 2rem 1.5rem 3rem;
86
+ display: flex;
87
+ flex-direction: column;
88
+ gap: 0.75rem;
89
+ }
90
+
91
+ .message {
92
+ display: flex;
93
+ justify-content: flex-start;
94
+ margin-bottom: 0.5rem;
95
+ color: #0d0d0d;
96
+ }
97
+
98
+ .message.assistant {
99
+ justify-content: flex-start;
100
+ }
101
+
102
+ .message.user {
103
+ justify-content: flex-end;
104
+ }
105
+
106
+ .message-content {
107
+ white-space: pre-wrap;
108
+ line-height: 1.6;
109
+ max-width: 100%;
110
+ }
111
+
112
+ .message.assistant .message-content {
113
+ background: transparent;
114
+ border: none;
115
+ cursor: pointer;
116
+ border-radius: 0.5rem;
117
+ padding: 0.5rem;
118
+ margin-left: -0.5rem;
119
+ transition: background-color 0.2s ease;
120
+ }
121
+
122
+ .message.assistant .message-content:hover {
123
+ background-color: #f9fafb;
124
+ }
125
+
126
+ .message.user .message-content {
127
+ background-color: #f3f4f6;
128
+ border-radius: 1.25rem;
129
+ padding: 0.8rem 1rem;
130
+ max-width: 65%;
131
+ cursor: pointer;
132
+ transition: background-color 0.2s ease;
133
+ }
134
+
135
+ .message.user .message-content:hover {
136
+ background-color: #e5e7eb;
137
+ }
138
+
139
+ .message.console .message-content {
140
+ font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', 'Courier New', monospace;
141
+ font-size: 0.875rem;
142
+ background-color: #fafafa;
143
+ padding: 0.75rem 1rem;
144
+ color: #374151;
145
+ max-width: 80%;
146
+ }
147
+
148
+ .input-container {
149
+ background-color: #ffffff;
150
+ padding: 1rem;
151
+ padding-bottom: calc(1rem + env(safe-area-inset-bottom))
152
+ }
153
+
154
+ .input-wrapper {
155
+ max-width: 48rem;
156
+ margin: 0 auto;
157
+ display: flex;
158
+ gap: 0.75rem;
159
+ align-items: flex-end;
160
+ }
161
+
162
+ .chat-input {
163
+ flex: 1;
164
+ padding: 0.8rem 1rem;
165
+ border: 1px solid #d1d5db;
166
+ border-radius: 0.75rem;
167
+ background-color: #ffffff;
168
+ color: #111827;
169
+ font-size: 1rem;
170
+ line-height: 1.5;
171
+ resize: none;
172
+ outline: none;
173
+ min-height: 54px;
174
+ max-height: 200px;
175
+ transition: border-color 0.2s ease, box-shadow 0.2s ease;
176
+ }
177
+
178
+ .chat-input::placeholder {
179
+ color: #9ca3af;
180
+ }
181
+
182
+ .chat-input:focus {
183
+ border-color: #2563eb;
184
+ box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
185
+ }
186
+
187
+ .send-button {
188
+ flex-shrink: 0;
189
+ padding: 0;
190
+ width: 54px;
191
+ height: 54px;
192
+ border: 1px solid #111827;
193
+ border-radius: 0.75rem;
194
+ background-color: #111827;
195
+ color: #ffffff;
196
+ display: flex;
197
+ align-items: center;
198
+ justify-content: center;
199
+ cursor: pointer;
200
+ transition: background-color 0.2s ease, border-color 0.2s ease, color 0.2s ease;
201
+ }
202
+
203
+ .send-button:hover:not(:disabled) {
204
+ background-color: #2563eb;
205
+ border-color: #2563eb;
206
+ }
207
+
208
+ .send-button:disabled {
209
+ cursor: not-allowed;
210
+ border-color: #d1d5db;
211
+ background-color: #e5e7eb;
212
+ color: #9ca3af;
213
+ }
214
+
215
+ .typing-indicator {
216
+ display: inline-block;
217
+ color: #6b7280;
218
+ letter-spacing: 0.15em;
219
+ }
220
+
221
+ .typing-indicator::after {
222
+ content: '···';
223
+ animation: typing 1.4s infinite;
224
+ }
225
+
226
+ @keyframes typing {
227
+ 0%, 60%, 100% { opacity: 0.2; }
228
+ 30% { opacity: 1; }
229
+ }
230
+
231
+ .error-message {
232
+ background-color: #fee2e2;
233
+ border: 1px solid #fecaca;
234
+ color: #b91c1c;
235
+ padding: 0.75rem 1rem;
236
+ border-radius: 0.75rem;
237
+ margin-top: 0.5rem;
238
+ }
239
+ </style>
240
+ </head>
241
+ <body>
242
+ <div class="header">
243
+ <div class="header-left">
244
+ <button class="new-conversation-btn" onclick="newConversation()" title="New Conversation (Ctrl+Shift+N)">
245
+ <svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
246
+ <path d="M12 5v14"></path>
247
+ <path d="M5 12h14"></path>
248
+ </svg>
249
+ </button>
250
+ <h1>nanochat</h1>
251
+ </div>
252
+ </div>
253
+
254
+ <div class="chat-container" id="chatContainer">
255
+ <div class="chat-wrapper" id="chatWrapper">
256
+ <!-- Messages will be added here -->
257
+ </div>
258
+ </div>
259
+
260
+ <div class="input-container">
261
+ <div class="input-wrapper">
262
+ <textarea
263
+ id="chatInput"
264
+ class="chat-input"
265
+ placeholder="Ask anything"
266
+ rows="1"
267
+ onkeydown="handleKeyDown(event)"
268
+ ></textarea>
269
+ <button id="sendButton" class="send-button" onclick="sendMessage()" disabled>
270
+ <svg width="22" height="22" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
271
+ <path d="M22 2L11 13"></path>
272
+ <path d="M22 2l-7 20-4-9-9-4 20-7z"></path>
273
+ </svg>
274
+ </button>
275
+ </div>
276
+ </div>
277
+
278
+ <script>
279
+ const API_URL = '';
280
+ const chatContainer = document.getElementById('chatContainer');
281
+ const chatWrapper = document.getElementById('chatWrapper');
282
+ const chatInput = document.getElementById('chatInput');
283
+ const sendButton = document.getElementById('sendButton');
284
+
285
+ let messages = [];
286
+ let isGenerating = false;
287
+ let currentTemperature = 0.8;
288
+ let currentTopK = 50;
289
+
290
+ chatInput.addEventListener('input', function() {
291
+ this.style.height = 'auto';
292
+ this.style.height = Math.min(this.scrollHeight, 200) + 'px';
293
+ sendButton.disabled = !this.value.trim() || isGenerating;
294
+ });
295
+
296
+ function handleKeyDown(event) {
297
+ if (event.key === 'Enter' && !event.shiftKey) {
298
+ event.preventDefault();
299
+ sendMessage();
300
+ }
301
+ }
302
+
303
+ document.addEventListener('keydown', function(event) {
304
+ // Ctrl+Shift+N for new conversation
305
+ if (event.ctrlKey && event.shiftKey && event.key === 'N') {
306
+ event.preventDefault();
307
+ if (!isGenerating) {
308
+ newConversation();
309
+ }
310
+ }
311
+ });
312
+
313
+ function newConversation() {
314
+ messages = [];
315
+ chatWrapper.innerHTML = '';
316
+ chatInput.value = '';
317
+ chatInput.style.height = 'auto';
318
+ sendButton.disabled = false;
319
+ isGenerating = false;
320
+ chatInput.focus();
321
+ }
322
+
323
+ function addMessage(role, content, messageIndex = null) {
324
+ const messageDiv = document.createElement('div');
325
+ messageDiv.className = `message ${role}`;
326
+
327
+ const contentDiv = document.createElement('div');
328
+ contentDiv.className = 'message-content';
329
+ contentDiv.textContent = content;
330
+
331
+ // Add click handler for user messages to enable editing
332
+ if (role === 'user' && messageIndex !== null) {
333
+ contentDiv.setAttribute('data-message-index', messageIndex);
334
+ contentDiv.setAttribute('title', 'Click to edit and restart from here');
335
+ contentDiv.addEventListener('click', function() {
336
+ if (!isGenerating) {
337
+ editMessage(messageIndex);
338
+ }
339
+ });
340
+ }
341
+
342
+ // Add click handler for assistant messages to enable regeneration
343
+ if (role === 'assistant' && messageIndex !== null) {
344
+ contentDiv.setAttribute('data-message-index', messageIndex);
345
+ contentDiv.setAttribute('title', 'Click to regenerate this response');
346
+ contentDiv.addEventListener('click', function() {
347
+ if (!isGenerating) {
348
+ regenerateMessage(messageIndex);
349
+ }
350
+ });
351
+ }
352
+
353
+ messageDiv.appendChild(contentDiv);
354
+ chatWrapper.appendChild(messageDiv);
355
+
356
+ chatContainer.scrollTop = chatContainer.scrollHeight;
357
+ return contentDiv;
358
+ }
359
+
360
+ function editMessage(messageIndex) {
361
+ // Find the message in the messages array
362
+ if (messageIndex < 0 || messageIndex >= messages.length) return;
363
+
364
+ const messageToEdit = messages[messageIndex];
365
+ if (messageToEdit.role !== 'user') return;
366
+
367
+ // Copy message content to input
368
+ chatInput.value = messageToEdit.content;
369
+ chatInput.style.height = 'auto';
370
+ chatInput.style.height = Math.min(chatInput.scrollHeight, 200) + 'px';
371
+
372
+ // Remove this message and all subsequent messages from the array
373
+ messages = messages.slice(0, messageIndex);
374
+
375
+ // Remove message elements from DOM starting from messageIndex
376
+ const allMessages = chatWrapper.querySelectorAll('.message');
377
+ for (let i = messageIndex; i < allMessages.length; i++) {
378
+ allMessages[i].remove();
379
+ }
380
+
381
+ // Enable send button and focus input
382
+ sendButton.disabled = false;
383
+ chatInput.focus();
384
+ }
385
+
386
+ async function generateAssistantResponse() {
387
+ isGenerating = true;
388
+ sendButton.disabled = true;
389
+
390
+ const assistantContent = addMessage('assistant', '');
391
+ assistantContent.innerHTML = '<span class="typing-indicator"></span>';
392
+
393
+ try {
394
+ const response = await fetch(`${API_URL}/chat/completions`, {
395
+ method: 'POST',
396
+ headers: {
397
+ 'Content-Type': 'application/json',
398
+ },
399
+ body: JSON.stringify({
400
+ messages: messages,
401
+ temperature: currentTemperature,
402
+ top_k: currentTopK,
403
+ max_tokens: 512
404
+ }),
405
+ });
406
+
407
+ if (!response.ok) {
408
+ throw new Error(`HTTP error! status: ${response.status}`);
409
+ }
410
+
411
+ const reader = response.body.getReader();
412
+ const decoder = new TextDecoder();
413
+ let fullResponse = '';
414
+ assistantContent.textContent = '';
415
+
416
+ while (true) {
417
+ const { done, value } = await reader.read();
418
+ if (done) break;
419
+
420
+ const chunk = decoder.decode(value);
421
+ const lines = chunk.split('\n');
422
+
423
+ for (const line of lines) {
424
+ if (line.startsWith('data: ')) {
425
+ try {
426
+ const data = JSON.parse(line.slice(6));
427
+ if (data.token) {
428
+ fullResponse += data.token;
429
+ assistantContent.textContent = fullResponse;
430
+ chatContainer.scrollTop = chatContainer.scrollHeight;
431
+ }
432
+ } catch (e) {
433
+ }
434
+ }
435
+ }
436
+ }
437
+
438
+ const assistantMessageIndex = messages.length;
439
+ messages.push({ role: 'assistant', content: fullResponse });
440
+
441
+ // Add click handler to regenerate this assistant message
442
+ assistantContent.setAttribute('data-message-index', assistantMessageIndex);
443
+ assistantContent.setAttribute('title', 'Click to regenerate this response');
444
+ assistantContent.addEventListener('click', function() {
445
+ if (!isGenerating) {
446
+ regenerateMessage(assistantMessageIndex);
447
+ }
448
+ });
449
+
450
+ } catch (error) {
451
+ console.error('Error:', error);
452
+ assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
453
+ } finally {
454
+ isGenerating = false;
455
+ sendButton.disabled = !chatInput.value.trim();
456
+ }
457
+ }
458
+
459
+ async function regenerateMessage(messageIndex) {
460
+ // Find the message in the messages array
461
+ if (messageIndex < 0 || messageIndex >= messages.length) return;
462
+
463
+ const messageToRegenerate = messages[messageIndex];
464
+ if (messageToRegenerate.role !== 'assistant') return;
465
+
466
+ // Remove this message and all subsequent messages from the array
467
+ messages = messages.slice(0, messageIndex);
468
+
469
+ // Remove message elements from DOM starting from messageIndex
470
+ const allMessages = chatWrapper.querySelectorAll('.message');
471
+ for (let i = messageIndex; i < allMessages.length; i++) {
472
+ allMessages[i].remove();
473
+ }
474
+
475
+ // Regenerate the assistant response
476
+ await generateAssistantResponse();
477
+ }
478
+
479
+ function handleSlashCommand(command) {
480
+ const parts = command.trim().split(/\s+/);
481
+ const cmd = parts[0].toLowerCase();
482
+ const arg = parts[1];
483
+
484
+ if (cmd === '/temperature') {
485
+ if (arg === undefined) {
486
+ addMessage('console', `Current temperature: ${currentTemperature}`);
487
+ } else {
488
+ const temp = parseFloat(arg);
489
+ if (isNaN(temp) || temp < 0 || temp > 2) {
490
+ addMessage('console', 'Invalid temperature. Must be between 0.0 and 2.0');
491
+ } else {
492
+ currentTemperature = temp;
493
+ addMessage('console', `Temperature set to ${currentTemperature}`);
494
+ }
495
+ }
496
+ return true;
497
+ } else if (cmd === '/topk') {
498
+ if (arg === undefined) {
499
+ addMessage('console', `Current top-k: ${currentTopK}`);
500
+ } else {
501
+ const topk = parseInt(arg);
502
+ if (isNaN(topk) || topk < 1 || topk > 200) {
503
+ addMessage('console', 'Invalid top-k. Must be between 1 and 200');
504
+ } else {
505
+ currentTopK = topk;
506
+ addMessage('console', `Top-k set to ${currentTopK}`);
507
+ }
508
+ }
509
+ return true;
510
+ } else if (cmd === '/clear') {
511
+ newConversation();
512
+ return true;
513
+ } else if (cmd === '/help') {
514
+ addMessage('console',
515
+ 'Available commands:\n' +
516
+ '/temperature - Show current temperature\n' +
517
+ '/temperature <value> - Set temperature (0.0-2.0)\n' +
518
+ '/topk - Show current top-k\n' +
519
+ '/topk <value> - Set top-k (1-200)\n' +
520
+ '/clear - Clear conversation\n' +
521
+ '/help - Show this help message'
522
+ );
523
+ return true;
524
+ }
525
+ return false;
526
+ }
527
+
528
+ async function sendMessage() {
529
+ const message = chatInput.value.trim();
530
+ if (!message || isGenerating) return;
531
+
532
+ // Handle slash commands
533
+ if (message.startsWith('/')) {
534
+ chatInput.value = '';
535
+ chatInput.style.height = 'auto';
536
+ handleSlashCommand(message);
537
+ return;
538
+ }
539
+
540
+ chatInput.value = '';
541
+ chatInput.style.height = 'auto';
542
+
543
+ const userMessageIndex = messages.length;
544
+ messages.push({ role: 'user', content: message });
545
+ addMessage('user', message, userMessageIndex);
546
+
547
+ await generateAssistantResponse();
548
+ }
549
+
550
+ sendButton.disabled = false;
551
+
552
+ // Autofocus the chat input on page load
553
+ chatInput.focus();
554
+
555
+ fetch(`${API_URL}/health`)
556
+ .then(response => response.json())
557
+ .then(data => {
558
+ console.log('Engine status:', data);
559
+ })
560
+ .catch(error => {
561
+ console.error('Engine not available:', error);
562
+ chatWrapper.innerHTML = '<div class="error-message">Engine not running. Please start engine.py first.</div>';
563
+ });
564
+ </script>
565
+ </body>
566
+ </html>
scripts/__init__.py ADDED
File without changes
scripts/chat_web.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Unified web chat server - serves both UI and API from a single FastAPI instance.
4
+
5
+ Uses data parallelism to distribute requests across multiple GPUs. Each GPU loads
6
+ a full copy of the model, and incoming requests are distributed to available workers.
7
+
8
+ Launch examples:
9
+
10
+ - single available GPU (default)
11
+ python -m scripts.chat_web
12
+
13
+ - 4 GPUs
14
+ python -m scripts.chat_web --num-gpus 4
15
+
16
+ To chat, open the URL printed in the console. (If on cloud box, make sure to use public IP)
17
+
18
+ Endpoints:
19
+ GET / - Chat UI
20
+ POST /chat/completions - Chat API (streaming only)
21
+ GET /health - Health check with worker pool status
22
+ GET /stats - Worker pool statistics and GPU utilization
23
+
24
+ Abuse Prevention:
25
+ - Maximum 500 messages per request
26
+ - Maximum 8000 characters per message
27
+ - Maximum 32000 characters total conversation length
28
+ - Temperature clamped to 0.0-2.0
29
+ - Top-k clamped to 0-200 (0 disables top-k filtering, using full vocabulary)
30
+ - Max tokens clamped to 1-4096
31
+ """
32
+
33
+ import argparse
34
+ import json
35
+ import os
36
+ import torch
37
+ import asyncio
38
+ import logging
39
+ import random
40
+ from contextlib import asynccontextmanager
41
+ from fastapi import FastAPI, HTTPException
42
+ from fastapi.middleware.cors import CORSMiddleware
43
+ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
44
+ from pydantic import BaseModel
45
+ from typing import List, Optional, AsyncGenerator
46
+ from dataclasses import dataclass
47
+ from nanochat.common import compute_init, autodetect_device_type
48
+ from nanochat.checkpoint_manager import load_model
49
+ from nanochat.engine import Engine
50
+
51
+ # Victorian system prompt — prepended to first user turn during inference
52
+ SYSTEM_PREFIX = (
53
+ "[You are a learned Victorian gentleman in conversation. "
54
+ "Address the question or remark put to you directly.]\n\n"
55
+ )
56
+
57
+ # Abuse prevention limits
58
+ MAX_MESSAGES_PER_REQUEST = 500
59
+ MAX_MESSAGE_LENGTH = 8000
60
+ MAX_TOTAL_CONVERSATION_LENGTH = 32000
61
+ MIN_TEMPERATURE = 0.0
62
+ MAX_TEMPERATURE = 2.0
63
+ MIN_TOP_K = 0 # 0 disables top-k filtering, using full vocabulary
64
+ MAX_TOP_K = 200
65
+ MIN_MAX_TOKENS = 1
66
+ MAX_MAX_TOKENS = 4096
67
+
68
+ parser = argparse.ArgumentParser(description='NanoChat Web Server')
69
+ parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
70
+ parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
71
+ parser.add_argument('-t', '--temperature', type=float, default=0.7, help='Default temperature for generation')
72
+ parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
73
+ parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
74
+ parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
75
+ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
76
+ parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
77
+ parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
78
+ parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
79
+ args = parser.parse_args()
80
+
81
+ # Configure logging for conversation traffic
82
+ logging.basicConfig(
83
+ level=logging.INFO,
84
+ format='%(asctime)s - %(message)s',
85
+ datefmt='%Y-%m-%d %H:%M:%S'
86
+ )
87
+ logger = logging.getLogger(__name__)
88
+
89
+ device_type = autodetect_device_type() if args.device_type == "" else args.device_type
90
+ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
91
+
92
+ @dataclass
93
+ class Worker:
94
+ """A worker with a model loaded on a specific GPU."""
95
+ gpu_id: int
96
+ device: torch.device
97
+ engine: Engine
98
+ tokenizer: object
99
+
100
+ class WorkerPool:
101
+ """Pool of workers, each with a model replica on a different GPU."""
102
+
103
+ def __init__(self, num_gpus: Optional[int] = None):
104
+ if num_gpus is None:
105
+ if device_type == "cuda":
106
+ num_gpus = torch.cuda.device_count()
107
+ else:
108
+ num_gpus = 1 # e.g. cpu|mps
109
+ self.num_gpus = num_gpus
110
+ self.workers: List[Worker] = []
111
+ self.available_workers: asyncio.Queue = asyncio.Queue()
112
+
113
+ async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
114
+ """Load model on each GPU."""
115
+ print(f"Initializing worker pool with {self.num_gpus} GPUs...")
116
+ if self.num_gpus > 1:
117
+ assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
118
+
119
+ for gpu_id in range(self.num_gpus):
120
+
121
+ if device_type == "cuda":
122
+ device = torch.device(f"cuda:{gpu_id}")
123
+ print(f"Loading model on GPU {gpu_id}...")
124
+ else:
125
+ device = torch.device(device_type) # e.g. cpu|mps
126
+ print(f"Loading model on {device_type}...")
127
+
128
+ model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
129
+ engine = Engine(model, tokenizer)
130
+ worker = Worker(
131
+ gpu_id=gpu_id,
132
+ device=device,
133
+ engine=engine,
134
+ tokenizer=tokenizer,
135
+ )
136
+ self.workers.append(worker)
137
+ await self.available_workers.put(worker)
138
+
139
+ print(f"All {self.num_gpus} workers initialized!")
140
+
141
+ async def acquire_worker(self) -> Worker:
142
+ """Get an available worker from the pool."""
143
+ return await self.available_workers.get()
144
+
145
+ async def release_worker(self, worker: Worker):
146
+ """Return a worker to the pool."""
147
+ await self.available_workers.put(worker)
148
+
149
+ class ChatMessage(BaseModel):
150
+ role: str
151
+ content: str
152
+
153
+ class ChatRequest(BaseModel):
154
+ messages: List[ChatMessage]
155
+ temperature: Optional[float] = None
156
+ max_tokens: Optional[int] = None
157
+ top_k: Optional[int] = None
158
+
159
+ def validate_chat_request(request: ChatRequest):
160
+ """Validate chat request to prevent abuse."""
161
+ # Check number of messages
162
+ if len(request.messages) == 0:
163
+ raise HTTPException(status_code=400, detail="At least one message is required")
164
+ if len(request.messages) > MAX_MESSAGES_PER_REQUEST:
165
+ raise HTTPException(
166
+ status_code=400,
167
+ detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request"
168
+ )
169
+
170
+ # Check individual message lengths and total conversation length
171
+ total_length = 0
172
+ for i, message in enumerate(request.messages):
173
+ if not message.content:
174
+ raise HTTPException(status_code=400, detail=f"Message {i} has empty content")
175
+
176
+ msg_length = len(message.content)
177
+ if msg_length > MAX_MESSAGE_LENGTH:
178
+ raise HTTPException(
179
+ status_code=400,
180
+ detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message"
181
+ )
182
+ total_length += msg_length
183
+
184
+ if total_length > MAX_TOTAL_CONVERSATION_LENGTH:
185
+ raise HTTPException(
186
+ status_code=400,
187
+ detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed"
188
+ )
189
+
190
+ # Validate role values
191
+ for i, message in enumerate(request.messages):
192
+ if message.role not in ["user", "assistant"]:
193
+ raise HTTPException(
194
+ status_code=400,
195
+ detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
196
+ )
197
+
198
+ # Validate temperature
199
+ if request.temperature is not None:
200
+ if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
201
+ raise HTTPException(
202
+ status_code=400,
203
+ detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
204
+ )
205
+
206
+ # Validate top_k
207
+ if request.top_k is not None:
208
+ if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
209
+ raise HTTPException(
210
+ status_code=400,
211
+ detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}"
212
+ )
213
+
214
+ # Validate max_tokens
215
+ if request.max_tokens is not None:
216
+ if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
217
+ raise HTTPException(
218
+ status_code=400,
219
+ detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
220
+ )
221
+
222
+ @asynccontextmanager
223
+ async def lifespan(app: FastAPI):
224
+ """Load models on all GPUs on startup."""
225
+ print("Loading nanochat models across GPUs...")
226
+ app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
227
+ await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
228
+ print(f"Server ready at http://localhost:{args.port}")
229
+ yield
230
+
231
+ app = FastAPI(lifespan=lifespan)
232
+
233
+ app.add_middleware(
234
+ CORSMiddleware,
235
+ allow_origins=["*"],
236
+ allow_credentials=True,
237
+ allow_methods=["*"],
238
+ allow_headers=["*"],
239
+ )
240
+
241
+ @app.get("/")
242
+ async def root():
243
+ """Serve the chat UI."""
244
+ ui_html_path = os.path.join("nanochat", "ui.html")
245
+ with open(ui_html_path, "r", encoding="utf-8") as f:
246
+ html_content = f.read()
247
+ # Replace the API_URL to use the same origin
248
+ html_content = html_content.replace(
249
+ "const API_URL = `http://${window.location.hostname}:8000`;",
250
+ "const API_URL = '';"
251
+ )
252
+ return HTMLResponse(content=html_content)
253
+
254
+
255
+ @app.get("/logo.svg")
256
+ async def logo():
257
+ """Serve the NanoChat logo for favicon and header."""
258
+ logo_path = os.path.join("nanochat", "logo.svg")
259
+ return FileResponse(logo_path, media_type="image/svg+xml")
260
+
261
+ async def generate_stream(
262
+ worker: Worker,
263
+ tokens,
264
+ temperature=None,
265
+ max_new_tokens=None,
266
+ top_k=None
267
+ ) -> AsyncGenerator[str, None]:
268
+ """Generate assistant response with streaming."""
269
+ temperature = temperature if temperature is not None else args.temperature
270
+ max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
271
+ top_k = top_k if top_k is not None else args.top_k
272
+
273
+ assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
274
+ bos = worker.tokenizer.get_bos_token_id()
275
+
276
+ # Accumulate tokens to properly handle multi-byte UTF-8 characters (like emojis)
277
+ accumulated_tokens = []
278
+ # Track the last complete UTF-8 string (without replacement characters)
279
+ last_clean_text = ""
280
+
281
+ for token_column, token_masks in worker.engine.generate(
282
+ tokens,
283
+ num_samples=1,
284
+ max_tokens=max_new_tokens,
285
+ temperature=temperature,
286
+ top_k=top_k,
287
+ repetition_penalty=1.3,
288
+ repetition_window=64,
289
+ seed=random.randint(0, 2**31 - 1)
290
+ ):
291
+ token = token_column[0]
292
+
293
+ # Stopping criteria
294
+ if token == assistant_end or token == bos:
295
+ break
296
+
297
+ # Append the token to sequence
298
+ accumulated_tokens.append(token)
299
+ # Decode all accumulated tokens to get proper UTF-8 handling
300
+ # Note that decode is a quite efficient operation, basically table lookup and string concat
301
+ current_text = worker.tokenizer.decode(accumulated_tokens)
302
+ # Only emit text if it doesn't end with a replacement character
303
+ # This ensures we don't emit incomplete UTF-8 sequences
304
+ if not current_text.endswith('�'):
305
+ # Extract only the new text since last clean decode
306
+ new_text = current_text[len(last_clean_text):]
307
+ if new_text: # Only yield if there's new content
308
+ yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
309
+ last_clean_text = current_text
310
+
311
+ yield f"data: {json.dumps({'done': True})}\n\n"
312
+
313
+ @app.post("/chat/completions")
314
+ async def chat_completions(request: ChatRequest):
315
+ """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
316
+
317
+ # Basic validation to prevent abuse
318
+ validate_chat_request(request)
319
+
320
+ # Log incoming conversation to console
321
+ logger.info("="*20)
322
+ for i, message in enumerate(request.messages):
323
+ logger.info(f"[{message.role.upper()}]: {message.content}")
324
+ logger.info("-"*20)
325
+
326
+ # Acquire a worker from the pool (will wait if all are busy)
327
+ worker_pool = app.state.worker_pool
328
+ worker = await worker_pool.acquire_worker()
329
+
330
+ try:
331
+ # Build conversation tokens
332
+ bos = worker.tokenizer.get_bos_token_id()
333
+ user_start = worker.tokenizer.encode_special("<|user_start|>")
334
+ user_end = worker.tokenizer.encode_special("<|user_end|>")
335
+ assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
336
+ assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
337
+
338
+ conversation_tokens = [bos]
339
+ turn_count = 0
340
+ for message in request.messages:
341
+ if message.role == "user":
342
+ content = message.content
343
+ # Prepend system prompt to the first user turn
344
+ if turn_count == 0:
345
+ content = SYSTEM_PREFIX + content
346
+ conversation_tokens.append(user_start)
347
+ conversation_tokens.extend(worker.tokenizer.encode(content))
348
+ conversation_tokens.append(user_end)
349
+ turn_count += 1
350
+ elif message.role == "assistant":
351
+ conversation_tokens.append(assistant_start)
352
+ conversation_tokens.extend(worker.tokenizer.encode(message.content))
353
+ conversation_tokens.append(assistant_end)
354
+
355
+ conversation_tokens.append(assistant_start)
356
+
357
+ # Streaming response with worker release after completion
358
+ response_tokens = []
359
+ async def stream_and_release():
360
+ try:
361
+ async for chunk in generate_stream(
362
+ worker,
363
+ conversation_tokens,
364
+ temperature=request.temperature,
365
+ max_new_tokens=request.max_tokens,
366
+ top_k=request.top_k
367
+ ):
368
+ # Accumulate response for logging
369
+ chunk_data = json.loads(chunk.replace("data: ", "").strip())
370
+ if "token" in chunk_data:
371
+ response_tokens.append(chunk_data["token"])
372
+ yield chunk
373
+ finally:
374
+ # Log the assistant response to console
375
+ full_response = "".join(response_tokens)
376
+ logger.info(f"[ASSISTANT] (GPU {worker.gpu_id}): {full_response}")
377
+ logger.info("="*20)
378
+ # Release worker back to pool after streaming is done
379
+ await worker_pool.release_worker(worker)
380
+
381
+ return StreamingResponse(
382
+ stream_and_release(),
383
+ media_type="text/event-stream"
384
+ )
385
+ except Exception as e:
386
+ # Make sure to release worker even on error
387
+ await worker_pool.release_worker(worker)
388
+ raise e
389
+
390
+ @app.get("/health")
391
+ async def health():
392
+ """Health check endpoint."""
393
+ worker_pool = getattr(app.state, 'worker_pool', None)
394
+ return {
395
+ "status": "ok",
396
+ "ready": worker_pool is not None and len(worker_pool.workers) > 0,
397
+ "num_gpus": worker_pool.num_gpus if worker_pool else 0,
398
+ "available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
399
+ }
400
+
401
+ @app.get("/stats")
402
+ async def stats():
403
+ """Get worker pool statistics."""
404
+ worker_pool = app.state.worker_pool
405
+ return {
406
+ "total_workers": len(worker_pool.workers),
407
+ "available_workers": worker_pool.available_workers.qsize(),
408
+ "busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
409
+ "workers": [
410
+ {
411
+ "gpu_id": w.gpu_id,
412
+ "device": str(w.device)
413
+ } for w in worker_pool.workers
414
+ ]
415
+ }
416
+
417
+ if __name__ == "__main__":
418
+ import uvicorn
419
+ print(f"Starting NanoChat Web Server")
420
+ print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
421
+ uvicorn.run(app, host=args.host, port=args.port)
start.sh ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -e
3
+
4
+ MODEL_DIR="/app/nanochat_cache/chatsft_checkpoints/d18"
5
+ MODEL_REPO="tventurella/mr_chatterbox_model"
6
+
7
+ # Download model checkpoint if not already present
8
+ if [ ! -f "$MODEL_DIR/model_000050.pt" ]; then
9
+ echo "Downloading model checkpoint..."
10
+ python -c "
11
+ from huggingface_hub import hf_hub_download
12
+ hf_hub_download('$MODEL_REPO', 'model_000050.pt', local_dir='$MODEL_DIR')
13
+ hf_hub_download('$MODEL_REPO', 'meta_000050.json', local_dir='$MODEL_DIR')
14
+ print('Model downloaded successfully.')
15
+ "
16
+ else
17
+ echo "Model checkpoint already present."
18
+ fi
19
+
20
+ # Start the server
21
+ exec python -m scripts.chat_web \
22
+ --model-tag d18 \
23
+ --device-type cpu \
24
+ --port 7860 \
25
+ --temperature 0.7 \
26
+ --top-k 50 \
27
+ --max-tokens 256
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_wrapper.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ tokenizer_wrapper.py — nanochat-compatible wrapper for the Victorian BPE tokenizer
3
+
4
+ nanochat's base_train.py imports:
5
+ from nanochat.tokenizer import get_tokenizer, get_token_bytes
6
+
7
+ This wrapper provides a VictorianTokenizer class that satisfies nanochat's full
8
+ interface, plus get_tokenizer() and get_token_bytes() drop-in replacements.
9
+
10
+ Special token mapping:
11
+ <|endoftext|> → bos (document boundary, prepended to every document)
12
+ <|pad|> → pad
13
+ <human> → user_start (replaces nanochat's <|user_start|>)
14
+ <victorian> → assistant_start (replaces nanochat's <|assistant_start|>)
15
+
16
+ Usage — patch nanochat/tokenizer.py by adding at the bottom:
17
+ from pathlib import Path
18
+ import sys
19
+ sys.path.insert(0, "/path/to/victorian")
20
+ from tokenizer_wrapper import get_tokenizer, get_token_bytes
21
+ """
22
+
23
+ from pathlib import Path
24
+ import torch
25
+ from tokenizers import Tokenizer
26
+
27
+ TOKENIZER_PATH = Path(__file__).parent / "tokenizer.json"
28
+
29
+
30
+ class VictorianTokenizer:
31
+ """
32
+ Wraps our HuggingFace BPE tokenizer to match nanochat's expected interface.
33
+ """
34
+
35
+ def __init__(self, tokenizer_path: str | Path = TOKENIZER_PATH):
36
+ self._tok = Tokenizer.from_file(str(tokenizer_path))
37
+ self._tok.no_padding()
38
+ self._tok.no_truncation()
39
+
40
+ # ------------------------------------------------------------------
41
+ # Core nanochat interface (used by dataloader and base_train.py)
42
+ # ------------------------------------------------------------------
43
+
44
+ def get_vocab_size(self) -> int:
45
+ return self._tok.get_vocab_size()
46
+
47
+ def get_bos_token_id(self) -> int:
48
+ """Prepended to every document by nanochat's dataloader."""
49
+ return self._tok.token_to_id("<|endoftext|>")
50
+
51
+ def encode(
52
+ self,
53
+ texts: list[str] | str,
54
+ prepend: int | str | None = None,
55
+ append: int | str | None = None,
56
+ num_threads: int = 4,
57
+ ) -> list[int] | list[list[int]]:
58
+ """
59
+ Encode strings → token ID list(s).
60
+
61
+ Matches nanochat's native tokenizer behaviour exactly:
62
+ - Single string → list[int]
63
+ - List of strings → list[list[int]]
64
+
65
+ prepend/append may be an int token ID or a special-token string
66
+ (e.g. prepend="<|bos|>"), matching nanochat's _encode_one interface.
67
+ """
68
+ single = isinstance(texts, str)
69
+ if single:
70
+ texts = [texts]
71
+
72
+ # Resolve string prepend/append to token IDs (e.g. "<|bos|>" → 0)
73
+ if isinstance(prepend, str):
74
+ prepend = self.encode_special(prepend)
75
+ if isinstance(append, str):
76
+ append = self.encode_special(append)
77
+
78
+ encodings = self._tok.encode_batch(texts, is_pretokenized=False)
79
+ ids = [enc.ids for enc in encodings]
80
+
81
+ if prepend is not None:
82
+ ids = [[prepend] + seq for seq in ids]
83
+ if append is not None:
84
+ ids = [seq + [append] for seq in ids]
85
+
86
+ # Single string → flat list[int] to match nanochat's native encode()
87
+ return ids[0] if single else ids
88
+
89
+ def decode(self, ids: list[int]) -> str:
90
+ return self._tok.decode(ids)
91
+
92
+ # ------------------------------------------------------------------
93
+ # Special token accessors
94
+ # ------------------------------------------------------------------
95
+
96
+ def encode_special(self, token: str) -> int | None:
97
+ """
98
+ Look up a special token ID by exact match.
99
+ Maps nanochat's native special tokens to Victorian equivalents where needed.
100
+ Required by nanochat's engine.py for sample generation.
101
+ """
102
+ # Try exact match first (covers our own special tokens)
103
+ result = self._tok.token_to_id(token)
104
+ if result is not None:
105
+ return result
106
+ # Map nanochat's native chat tokens to Victorian equivalents
107
+ _map = {
108
+ "<|assistant_start|>": "<victorian>",
109
+ "<|assistant_end|>": "<|endoftext|>",
110
+ "<|user_start|>": "<human>",
111
+ "<|user_end|>": "<|endoftext|>",
112
+ "<|bos|>": "<|endoftext|>",
113
+ "<|eos|>": "<|endoftext|>",
114
+ }
115
+ mapped = _map.get(token)
116
+ if mapped:
117
+ return self._tok.token_to_id(mapped)
118
+ return None
119
+
120
+ def get_pad_token_id(self) -> int:
121
+ return self._tok.token_to_id("<|pad|>")
122
+
123
+ def get_user_start_id(self) -> int:
124
+ """Maps to nanochat's <|user_start|> role."""
125
+ return self._tok.token_to_id("<human>")
126
+
127
+ def get_assistant_start_id(self) -> int:
128
+ """Maps to nanochat's <|assistant_start|> role."""
129
+ return self._tok.token_to_id("<victorian>")
130
+
131
+ # ------------------------------------------------------------------
132
+ # Chat / fine-tuning interface (used by chat_sft.py)
133
+ # ------------------------------------------------------------------
134
+
135
+ def render_conversation(
136
+ self,
137
+ conversation: list[dict],
138
+ max_tokens: int = 2048,
139
+ ) -> tuple[list[int], list[int]]:
140
+ """
141
+ Encode a conversation into token IDs and a loss mask.
142
+
143
+ conversation: list of {"role": "user"|"assistant", "content": str}
144
+ Returns: (token_ids, loss_mask) — loss_mask is 1 for assistant tokens, 0 otherwise.
145
+
146
+ Victorian mapping:
147
+ "user" → <human> ...
148
+ "assistant" → <victorian> ... <|endoftext|> (end token trains model to stop)
149
+ """
150
+ human_id = self.get_user_start_id()
151
+ victorian_id = self.get_assistant_start_id()
152
+ bos_id = self.get_bos_token_id()
153
+
154
+ tokens: list[int] = [bos_id]
155
+ mask: list[int] = [0]
156
+
157
+ for turn in conversation:
158
+ role = turn["role"]
159
+ content = turn["content"]
160
+ content_ids = self.encode(content)
161
+
162
+ if role == "user":
163
+ turn_tokens = [human_id] + content_ids
164
+ turn_mask = [0] * len(turn_tokens)
165
+ else: # assistant
166
+ turn_tokens = [victorian_id] + content_ids + [bos_id]
167
+ turn_mask = [1] * len(turn_tokens)
168
+
169
+ tokens.extend(turn_tokens)
170
+ mask.extend(turn_mask)
171
+
172
+ if len(tokens) >= max_tokens:
173
+ tokens = tokens[:max_tokens]
174
+ mask = mask[:max_tokens]
175
+ break
176
+
177
+ return tokens, mask
178
+
179
+ # ------------------------------------------------------------------
180
+
181
+ def __call__(self, texts, **kwargs):
182
+ """Allow tokenizer(texts, ...) as an alias for encode() — required by nanochat's core_eval."""
183
+ return self.encode(texts, **kwargs)
184
+
185
+ @property
186
+ def vocab_size(self) -> int:
187
+ return self.get_vocab_size()
188
+
189
+ def __repr__(self) -> str:
190
+ return (
191
+ f"VictorianTokenizer(vocab_size={self.vocab_size}, "
192
+ f"bos={self.get_bos_token_id()}, "
193
+ f"human={self.get_user_start_id()}, "
194
+ f"victorian={self.get_assistant_start_id()})"
195
+ )
196
+
197
+
198
+ # ---------------------------------------------------------------------------
199
+ # nanochat drop-in functions
200
+ # ---------------------------------------------------------------------------
201
+
202
+ _tokenizer_singleton: VictorianTokenizer | None = None
203
+
204
+
205
+ def get_tokenizer(tokenizer_path: str | Path = TOKENIZER_PATH) -> VictorianTokenizer:
206
+ """Drop-in replacement for nanochat's get_tokenizer()."""
207
+ global _tokenizer_singleton
208
+ if _tokenizer_singleton is None:
209
+ _tokenizer_singleton = VictorianTokenizer(tokenizer_path)
210
+ return _tokenizer_singleton
211
+
212
+
213
+ def get_token_bytes(device: str | torch.device = "cpu") -> torch.Tensor:
214
+ """
215
+ Drop-in replacement for nanochat's get_token_bytes().
216
+
217
+ Returns a 1D tensor of shape [vocab_size] where each entry is the
218
+ UTF-8 byte length of that token. Used by base_train.py to convert
219
+ loss from nats/token → bits/byte (the BPB evaluation metric).
220
+ """
221
+ tok = get_tokenizer()
222
+ vocab = tok._tok.get_vocab() # {token_str: id}
223
+ vocab_size = tok.get_vocab_size()
224
+
225
+ # Build id → token string mapping
226
+ id_to_token = {v: k for k, v in vocab.items()}
227
+
228
+ byte_lengths = []
229
+ for i in range(vocab_size):
230
+ token_str = id_to_token.get(i, "")
231
+ # ByteLevel BPE: Ġ represents a leading space (0x20).
232
+ # Decode the display string back to actual bytes for a correct byte count.
233
+ try:
234
+ # Replace Ġ with space, then encode to UTF-8
235
+ actual = token_str.replace("Ġ", " ").replace("Ċ", "\n").replace("ĉ", "\t")
236
+ n_bytes = len(actual.encode("utf-8"))
237
+ except Exception:
238
+ n_bytes = 1
239
+ byte_lengths.append(max(1, n_bytes)) # floor at 1 to avoid div-by-zero
240
+
241
+ return torch.tensor(byte_lengths, dtype=torch.long, device=device)
242
+
243
+
244
+ # ---------------------------------------------------------------------------
245
+ # Sanity check
246
+ # ---------------------------------------------------------------------------
247
+
248
+ if __name__ == "__main__":
249
+ import sys
250
+
251
+ if not TOKENIZER_PATH.exists():
252
+ print(f"Tokenizer not found at {TOKENIZER_PATH}")
253
+ sys.exit(1)
254
+
255
+ tok = get_tokenizer()
256
+ print(tok)
257
+ print(f" pad={tok.get_pad_token_id()}")
258
+
259
+ texts = [
260
+ "It is a truth universally acknowledged.",
261
+ "The phrenological examination was most illuminating, dear fellow.",
262
+ ]
263
+ ids = tok.encode(texts, prepend=tok.get_bos_token_id())
264
+ for text, seq in zip(texts, ids):
265
+ decoded = tok.decode(seq[1:])
266
+ ok = "✓" if decoded == text else "✗"
267
+ print(f" {ok} {len(seq):3d} tokens {text!r}")
268
+
269
+ # Test render_conversation
270
+ conv = [
271
+ {"role": "user", "content": "What is your opinion on the railways?"},
272
+ {"role": "assistant", "content": "The railways are a most alarming development, yet undeniably useful."},
273
+ ]
274
+ token_ids, loss_mask = tok.render_conversation(conv)
275
+ print(f"\n render_conversation: {len(token_ids)} tokens, "
276
+ f"{sum(loss_mask)} assistant tokens in loss mask")
277
+
278
+ # Test get_token_bytes
279
+ tb = get_token_bytes()
280
+ print(f"\n get_token_bytes: shape={tuple(tb.shape)}, "
281
+ f"mean={tb.mean():.2f} bytes/token, "
282
+ f"min={tb.min():.0f}, max={tb.max():.0f}")