# GPT Diffusion (SEDD) – 350M, PG-19 Byte-Level
A 355M parameter discrete diffusion transformer trained on raw UTF-8 bytes from PG-19 (Project Gutenberg books), using the Score Entropy Discrete Diffusion (SEDD) framework with a nanoGPT-style architecture.
Inference code: github.com/JustKitting/byte-diffusion-inference
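Since the model works on raw bytes, there is no learned tokenizer: the 256-entry vocabulary is simply the possible byte values. A quick illustration of what byte-level "tokens" look like:

```python
text = "Diffusion over bytes."
tokens = list(text.encode("utf-8"))   # each UTF-8 byte (0-255) is one token id
print(tokens[:8])                     # [68, 105, 102, 102, 117, 115, 105, 111]
print(bytes(tokens).decode("utf-8"))  # round-trips back to the original string
```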
## Quickstart

```bash
git clone https://github.com/JustKitting/byte-diffusion-inference.git
cd byte-diffusion-inference
pip install torch huggingface_hub
python sample.py
```
The checkpoint downloads automatically on first run.
## Usage

```python
from huggingface_hub import hf_hub_download
from sample import load_model, sample

# Download checkpoint (~4GB, cached after first run)
ckpt_path = hf_hub_download(
    repo_id="justkitting/gpt-diffusion-nanogptsettings",
    filename="pg19_step_50000_final.pth",
)

# Load model (uses EMA weights automatically)
device = "cuda"
model, config = load_model(ckpt_path, device)

# Generate text
text = sample(model, config, steps=128, length=512, device=device)
print(text)
```
## Results
| Metric | Value |
|---|---|
| Val VLB | 1.63 bits/byte |
| Test VLB | 1.72 bits/byte |
| Val perplexity | 3.09 |
| Final train loss | 1.36 |
| Final val loss | 1.80 |
| Total tokens seen | 13.1B |
Benchmark target: MambaByte (353M) at 0.87 bits/byte on PG-19.
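The perplexity and bits-per-byte rows are two views of the same quantity: byte-level perplexity is 2 raised to the bits per byte. A quick check against the table above:

```python
val_bpb = 1.63            # validation VLB in bits per byte (table above)
val_ppl = 2 ** val_bpb    # byte-level perplexity
print(f"{val_ppl:.2f}")   # ~3.1, consistent with the reported val perplexity of 3.09
```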
## Architecture

| Component | Value |
|---|---|
| Parameters | ~355M |
| Type | DDiT (transformer with adaptive LayerNorm, adaLN-Zero, conditioned on the noise level) |
| Layers | 25 |
| Heads | 16 |
| Embedding dim | 1024 |
| Conditioning dim | 256 |
| Context length | 8192 bytes |
| Vocab | 256 (raw UTF-8 bytes) |
| Activation | SwiGLU |
| Position embedding | RoPE |
| Normalization | RMSNorm |
| Bias | None |
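The adaLN-Zero conditioning means each block receives per-block shift/scale/gate vectors computed from the 256-dim noise embedding, with the final projection initialized to zero so every block starts as the identity. Below is a minimal, illustrative sketch of such a block; names and module choices are assumptions rather than the actual training code, and RoPE / RMSNorm are omitted for brevity:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DDiTBlock(nn.Module):
    """Illustrative DDiT-style block: self-attention + SwiGLU MLP with adaLN-Zero."""

    def __init__(self, dim=1024, n_heads=16, cond_dim=256):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True, bias=False)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        hidden = 4 * dim
        # SwiGLU MLP: gated hidden projection followed by a down-projection
        self.w_gate = nn.Linear(dim, hidden, bias=False)
        self.w_val = nn.Linear(dim, hidden, bias=False)
        self.w_out = nn.Linear(hidden, dim, bias=False)
        # adaLN-Zero: map the noise embedding to per-block shift/scale/gate,
        # zero-initialized so each block starts out as the identity map.
        self.ada = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, 6 * dim))
        nn.init.zeros_(self.ada[1].weight)
        nn.init.zeros_(self.ada[1].bias)

    def forward(self, x, cond):
        # x: (batch, seq, dim) byte embeddings; cond: (batch, cond_dim) noise embedding
        shift1, scale1, gate1, shift2, scale2, gate2 = self.ada(cond).chunk(6, dim=-1)
        h = self.norm1(x) * (1 + scale1.unsqueeze(1)) + shift1.unsqueeze(1)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + gate1.unsqueeze(1) * attn_out
        h = self.norm2(x) * (1 + scale2.unsqueeze(1)) + shift2.unsqueeze(1)
        x = x + gate2.unsqueeze(1) * self.w_out(F.silu(self.w_gate(h)) * self.w_val(h))
        return x
```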
## Training

Hyperparameters were selected via a Bayesian sweep (10 runs of 750 steps each) and then used for the full 50k-step run.
Optimizer: Split Muon + AdamW (see the parameter-partitioning sketch at the end of this section)
- Muon (hidden 2D weights): lr=0.029, momentum=0.95, nesterov=True, weight_decay=0.1
- AdamW (embeddings, biases, 1D params): peak_lr=7e-5, min_lr=7e-6, betas=(0.9, 0.98), weight_decay=0.01
Schedule: Linear warmup (11,000 steps) then cosine decay to min_lr
Batch size: micro_batch=4 x grad_accum=8 x 8192 context = 262K tokens/step
Other:
- Precision: bfloat16 (torch.autocast)
- EMA decay: 0.999
- Gradient clipping: max_norm=1.0
- Dropout: 0.1
Noise: Geometric schedule, sigma_min=1e-4, sigma_max=25
Data: PG-19 (Project Gutenberg books pre-1919), ~11GB raw bytes split into 6 training chunks + validation + test sets. Each chunk is memory-mapped for efficient streaming. Training saw ~1.3 epochs over the full dataset.
Hardware: 2x NVIDIA RTX PRO 6000 Blackwell (96GB each), DDP with gloo backend. Total training time ~3.5 days.
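For concreteness, here is a minimal sketch of how a Muon/AdamW split like the one above is typically wired up. The `Muon` import is a placeholder for whichever implementation was used, its constructor arguments may differ between versions, and the exact parameter grouping in the training run may not match this heuristic:

```python
import torch
from muon import Muon  # placeholder import; not part of torch

matrix_params, other_params = [], []
for name, p in model.named_parameters():
    # Hidden 2-D weight matrices go to Muon; embeddings, biases, and other
    # 1-D parameters (norm gains, etc.) go to AdamW.
    if p.ndim >= 2 and "embed" not in name:
        matrix_params.append(p)
    else:
        other_params.append(p)

muon_opt = Muon(matrix_params, lr=0.029, momentum=0.95,
                nesterov=True, weight_decay=0.1)
adamw_opt = torch.optim.AdamW(other_params, lr=7e-5,
                              betas=(0.9, 0.98), weight_decay=0.01)
```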
## Checkpoint Contents

The `.pth` file contains:

- `model` – main model state dict
- `ema_model` – EMA weights (use these for inference)
- `optimizer` – Muon optimizer state
- `adamw_optimizer` – AdamW optimizer state
- `step` – training step (50000)
- `config` – model config dict
- `chunk_idx` – data chunk index
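A quick sketch of inspecting the checkpoint by hand, reusing `ckpt_path` from the Usage snippet (the bundled `load_model` already selects the EMA weights for you); key names are as listed above:

```python
import torch

# On newer PyTorch you may need weights_only=False because the checkpoint
# stores plain Python objects (config dict, optimizer states) alongside tensors.
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

print(ckpt["step"])            # 50000
print(ckpt["config"])          # model config dict used to build the network
ema_state = ckpt["ema_model"]  # prefer these weights for inference
# model.load_state_dict(ema_state)  # assuming `model` was built from ckpt["config"]
```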
## License
MIT