GPT Diffusion (SEDD): 355M, PG-19 Byte-Level

A 355M-parameter discrete diffusion transformer trained on raw UTF-8 bytes from PG-19 (Project Gutenberg books), using the Score Entropy Discrete Diffusion (SEDD) framework with a nanoGPT-style architecture.

Inference code: github.com/JustKitting/byte-diffusion-inference

Quickstart

git clone https://github.com/JustKitting/byte-diffusion-inference.git
cd byte-diffusion-inference
pip install torch huggingface_hub
python sample.py

The checkpoint downloads automatically on first run.

Usage

from huggingface_hub import hf_hub_download
from sample import load_model, sample

# Download checkpoint (~4GB, cached after first run)
ckpt_path = hf_hub_download(
    repo_id="justkitting/gpt-diffusion-nanogptsettings",
    filename="pg19_step_50000_final.pth",
)

# Load model (uses EMA weights automatically)
device = "cuda"
model, config = load_model(ckpt_path, device)

# Generate text
text = sample(model, config, steps=128, length=512, device=device)
print(text)

Results

  • Val VLB: 1.63 bits/byte
  • Test VLB: 1.72 bits/byte
  • Val perplexity: 3.09
  • Final train loss: 1.36
  • Final val loss: 1.80
  • Total tokens seen: 13.1B

Benchmark target: MambaByte (353M) at 0.87 bits/byte on PG-19.
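The perplexity figure follows directly from the VLB: byte-level perplexity is 2 raised to the bits-per-byte loss. A quick sanity check against the numbers above:

```python
import math

def bpb_to_ppl(bits_per_byte: float) -> float:
    """Byte-level perplexity from a bits-per-byte loss."""
    return 2.0 ** bits_per_byte

def ppl_to_bpb(ppl: float) -> float:
    """Bits-per-byte loss from a byte-level perplexity."""
    return math.log2(ppl)

# 1.63 bits/byte (val VLB) recovers the ~3.09 val perplexity, up to rounding.
print(bpb_to_ppl(1.63))
```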

Architecture

  • Parameters: ~355M
  • Type: DDiT (transformer with adaptive LayerNorm, adaLN-Zero, conditioned on the noise level)
  • Layers: 25
  • Heads: 16
  • Embedding dim: 1024
  • Conditioning dim: 256
  • Context length: 8192 bytes
  • Vocab: 256 (raw UTF-8 bytes)
  • Activation: SwiGLU
  • Position embedding: RoPE
  • Normalization: RMSNorm
  • Bias: None
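The ~355M figure is consistent with a back-of-envelope count. The sketch below assumes a SwiGLU hidden width of floor(8d/3), an adaLN-Zero modulation head projecting the 256-dim conditioning vector to 6d per block, and untied input/output byte embeddings; none of these details are confirmed by the card.

```python
# Rough parameter count under the assumptions stated above.
d, layers, cond, vocab = 1024, 25, 256, 256
h = (8 * d) // 3                       # assumed SwiGLU hidden width
attn = 4 * d * d                       # Q, K, V, O projections
mlp = 3 * d * h                        # gate, up, down projections
adaln = cond * 6 * d + 6 * d           # assumed per-block modulation (shift/scale/gate x 2)
total = layers * (attn + mlp + adaln) + 2 * vocab * d  # + embedding and output head
print(f"{total / 1e6:.1f}M")           # prints 354.5M
```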

Training

Hyperparameters were selected via a Bayesian sweep (10 runs of 750 steps each), then reused for the full 50k-step run.

Optimizer: Split Muon + AdamW

  • Muon (hidden 2D weights): lr=0.029, momentum=0.95, nesterov=True, weight_decay=0.1
  • AdamW (embeddings, biases, 1D params): peak_lr=7e-5, min_lr=7e-6, betas=(0.9, 0.98), weight_decay=0.01
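"Split" here means parameters are routed by shape: Muon updates the 2-D hidden weight matrices, while AdamW takes embeddings and 1-D parameters. A routing sketch, with illustrative parameter names (not the repo's actual names); sending the output head to AdamW alongside the embeddings is an assumption, though it is the common Muon convention:

```python
def split_param_groups(named_shapes):
    """Route parameters: 2-D hidden weights to Muon, the rest to AdamW."""
    muon, adamw = [], []
    for name, shape in named_shapes:
        if len(shape) == 2 and "embed" not in name and "head" not in name:
            muon.append(name)
        else:
            adamw.append(name)
    return muon, adamw

params = [
    ("tok_embed.weight", (256, 1024)),           # embedding -> AdamW
    ("blocks.0.attn.qkv.weight", (3072, 1024)),  # hidden 2-D weight -> Muon
    ("blocks.0.norm.weight", (1024,)),           # 1-D gain -> AdamW
    ("head.weight", (256, 1024)),                # output head -> AdamW (assumed)
]
muon, adamw = split_param_groups(params)
print(muon)   # ['blocks.0.attn.qkv.weight']
```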

Schedule: Linear warmup (11,000 steps) then cosine decay to min_lr
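The AdamW schedule above can be sketched as a step-indexed function (whether warmup starts at exactly zero is an assumption):

```python
import math

PEAK_LR, MIN_LR = 7e-5, 7e-6
WARMUP_STEPS, TOTAL_STEPS = 11_000, 50_000

def lr_at(step: int) -> float:
    """Linear warmup to PEAK_LR, then cosine decay to MIN_LR."""
    if step < WARMUP_STEPS:
        return PEAK_LR * step / WARMUP_STEPS
    progress = (step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
    return MIN_LR + 0.5 * (PEAK_LR - MIN_LR) * (1.0 + math.cos(math.pi * progress))

# 0 at step 0, peak right after warmup, min_lr at the final step.
print(lr_at(0), lr_at(11_000), lr_at(50_000))
```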

Batch size: micro_batch=4 x grad_accum=8 x 8192 context = 262K tokens/step
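Taking those numbers at face value, the per-step and total token counts are self-consistent with the 13.1B figure in Results:

```python
micro_batch, grad_accum, context = 4, 8, 8192
tokens_per_step = micro_batch * grad_accum * context
print(tokens_per_step)                 # 262144 (~262K tokens/step)
print(tokens_per_step * 50_000 / 1e9)  # ~13.1 (billion tokens over 50k steps)
```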

Other:

  • Precision: bfloat16 (torch.autocast)
  • EMA decay: 0.999
  • Gradient clipping: max_norm=1.0
  • Dropout: 0.1

Noise: Geometric schedule, sigma_min=1e-4, sigma_max=25
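Following the geometric schedule form used in the SEDD reference implementation, the noise level interpolates log-linearly between the two endpoints (the exact parameterization in this repo may differ):

```python
SIGMA_MIN, SIGMA_MAX = 1e-4, 25.0

def sigma(t: float) -> float:
    """Geometric noise level for t in [0, 1]: log-linear interpolation."""
    return SIGMA_MIN ** (1.0 - t) * SIGMA_MAX ** t

# 1e-4 at t=0, 0.05 at the geometric midpoint, 25.0 at t=1.
print(sigma(0.0), sigma(0.5), sigma(1.0))
```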

Data: PG-19 (Project Gutenberg books pre-1919), ~11GB raw bytes split into 6 training chunks + validation + test sets. Each chunk is memory-mapped for efficient streaming. Training saw ~1.3 epochs over the full dataset.
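Memory-mapping lets a loader slice fixed-length byte windows out of an ~11GB corpus without reading it into RAM. A stdlib sketch of the idea (the chunk filename and random sampling are illustrative, not the repo's actual loader):

```python
import mmap, os, random

def sample_window(path: str, context: int = 8192) -> bytes:
    """Memory-map a raw-byte chunk and slice one random training window."""
    with open(path, "rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        start = random.randrange(len(mm) - context)
        return mm[start : start + context]

# Tiny self-contained demo on a throwaway file.
with open("demo_chunk.bin", "wb") as f:
    f.write(bytes(range(256)) * 64)          # 16 KiB of raw bytes
print(len(sample_window("demo_chunk.bin", context=512)))  # 512
os.remove("demo_chunk.bin")
```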

Hardware: 2x NVIDIA RTX PRO 6000 Blackwell (96GB each), DDP with gloo backend. Total training time ~3.5 days.

Checkpoint Contents

The .pth file contains:

  • model: main model state dict
  • ema_model: EMA weights (use these for inference)
  • optimizer: Muon optimizer state
  • adamw_optimizer: AdamW optimizer state
  • step: training step (50000)
  • config: model config dict
  • chunk_idx: data chunk index

License

MIT
