protein-structure-mlx-esm2_t33_650M_UR50D

ESM-2 650M fine-tuned for protein structure prediction via inter-residue distance distograms, optimized for Apple Silicon with MLX.

Latest Finetuning

HF Model Path: gccmorgoth/protein-structure-mlx-esm2_t33_650M_UR50D

Model Details

  • Base Model: facebook/esm2_t33_650M_UR50D
  • Architecture: ESM-2 650M (frozen encoder) + 5-layer MLP prediction head (768 hidden dim)
  • Task: Protein inter-residue distance prediction (distogram)
  • Framework: MLX (Apple Silicon optimized)
  • Parameters: 651M (base) + 2.4M (trainable head) = 653.4M total
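The head described above (frozen encoder, 5-layer MLP, 768 hidden units, 64 output bins) can be sketched as a shape-flow in NumPy. This is a minimal illustration, not the repo's MLX implementation: the pair-feature construction (concatenating residue i and j embeddings), the activation choice, and the function names are assumptions; only the layer count, hidden width, ESM-2 650M embedding size (1280), and bin count come from the card.

```python
import numpy as np

def pairwise_distogram_head(emb, weights):
    """Shape-flow sketch: per-residue embeddings -> pairwise distogram logits.

    emb:     (L, D) residue embeddings from the frozen ESM-2 encoder
    weights: list of (W, b) pairs, one per MLP layer
    """
    L, D = emb.shape
    # Build pair features by concatenating embeddings of residues i and j
    # (an assumption; the real pair representation may differ).
    pair = np.concatenate(
        [np.repeat(emb[:, None, :], L, axis=1),   # (L, L, D) residue i
         np.repeat(emb[None, :, :], L, axis=0)],  # (L, L, D) residue j
        axis=-1)                                  # (L, L, 2D)
    x = pair
    for k, (W, b) in enumerate(weights):
        x = x @ W + b
        if k < len(weights) - 1:
            x = np.maximum(x, 0.0)  # ReLU on hidden layers (assumed)
    return x  # (L, L, num_bins)

# ESM-2 650M embeddings are 1280-d; the card's head uses 768 hidden units,
# 5 layers, and 64 output bins.
rng = np.random.default_rng(0)
dims = [2 * 1280, 768, 768, 768, 768, 64]
weights = [(rng.standard_normal((a, b)) * 0.01, np.zeros(b))
           for a, b in zip(dims[:-1], dims[1:])]
emb = rng.standard_normal((10, 1280))
logits = pairwise_distogram_head(emb, weights)
print(logits.shape)  # (10, 10, 64)
```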

Training Details

  • Method: Supervised Fine-Tuning (SFT) with distogram loss
  • Loss Function: AlphaFold2-style cross-entropy on 64 distance bins (2.16-21.84 Å)
  • Dataset: 196 high-resolution PDB structures (<2.5 Å resolution)
    • Training: 3,004 examples
    • Validation: 752 examples
    • Negative examples: Gaussian noise perturbation
  • Validation Loss: 2.79 (cross-entropy on 64 bins)
  • Training Loss: 2.81
  • Hardware: Mac mini M4 Pro (24GB RAM, 16-core GPU)
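The distogram loss above can be sketched in a few lines of NumPy: discretize true inter-residue distances into 64 classes and take the cross-entropy against the predicted bin logits. The exact bin-edge placement and any masking/weighting used in training are assumptions; only the 64-class cross-entropy objective and the 2.16-21.84 Å range come from the card.

```python
import numpy as np

def distogram_loss(logits, dist, edges):
    """AlphaFold2-style distogram cross-entropy (NumPy sketch).

    logits: (L, L, 64) predicted bin logits
    dist:   (L, L) true inter-residue distances in Angstroms
    edges:  (63,) interior bin edges; digitize yields 64 classes
    """
    target = np.digitize(dist, edges)              # (L, L) indices 0..63
    # Numerically stable log-softmax over the bin dimension.
    z = logits - logits.max(axis=-1, keepdims=True)
    logp = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    nll = -np.take_along_axis(logp, target[..., None], axis=-1)[..., 0]
    return nll.mean()

# 63 interior edges give 64 classes spanning roughly 2.16-21.84 A
# (edge placement here is an assumption).
edges = np.linspace(2.16, 21.84, 63)
rng = np.random.default_rng(0)
L = 8
logits = rng.standard_normal((L, L, 64))
dist = rng.uniform(2.0, 22.0, size=(L, L))
loss = distogram_loss(logits, dist, edges)
print(round(float(loss), 2))
```

With random logits the loss sits near ln(64) ≈ 4.16, which is why that value serves as the random baseline in the Performance section below.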

Training Configuration

  • Optimizer: Adam
  • Learning Rate: 5e-4 with cosine decay to 1e-5
  • Warmup Steps: 200
  • Batch Size: 8 (effective 16 with gradient accumulation)
  • Epochs: 20 with early stopping (patience=5)
  • Best Checkpoint: Epoch 19
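The schedule above (200-step linear warmup to 5e-4, then cosine decay to 1e-5) can be written as a plain step-to-learning-rate function. The total step count below is illustrative (roughly 3,004 examples / effective batch 16 x 20 epochs); the actual value depends on the training loop.

```python
import math

def lr_schedule(step, base_lr=5e-4, final_lr=1e-5,
                warmup_steps=200, total_steps=3760):
    """Linear warmup to base_lr, then cosine decay to final_lr.

    total_steps=3760 is an estimate (~188 updates/epoch x 20 epochs).
    """
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    t = min((step - warmup_steps) / max(1, total_steps - warmup_steps), 1.0)
    return final_lr + 0.5 * (base_lr - final_lr) * (1 + math.cos(math.pi * t))

print(lr_schedule(0))     # small warmup value
print(lr_schedule(199))   # 5e-4 at the end of warmup
print(lr_schedule(3760))  # 1e-5 at the end of training
```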

Performance

  • Random Baseline: 4.16 (ln(64) for 64-class classification)
  • Final Val Loss: 2.79 (33% improvement over random)
  • Convergence: Validation loss largely plateaued around epochs 8-10
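The baseline and improvement figures above follow directly from the 64-class setup:

```python
import math

# Random-guess cross-entropy over 64 equiprobable bins, and the relative
# improvement of the reported validation loss of 2.79.
baseline = math.log(64)
val_loss = 2.79
improvement = (baseline - val_loss) / baseline
print(round(baseline, 2), f"{improvement:.0%}")  # 4.16 33%
```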

Distance Binning

Predicts 64 distance bins covering 2.16-21.84 Å at 0.3125 Å per bin, following AlphaFold2's distogram approach.
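The stated range and spacing are consistent if the 64 values are bin centers 0.3125 Å apart (63 gaps of 0.3125 Å span 19.6875 Å, so 2.16 + 19.6875 ≈ 21.84). A sketch of that mapping, under this bin-center assumption:

```python
import numpy as np

# Assumption: 64 bin centers spaced 0.3125 A, starting at 2.16 A.
centers = 2.16 + 0.3125 * np.arange(64)

def bin_to_distance(bin_idx):
    """Map a predicted bin index (0..63) to a distance in Angstroms."""
    return centers[bin_idx]

print(round(float(centers[0]), 2), round(float(centers[-1]), 2))  # 2.16 21.85
```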

Usage

import mlx.core as mx
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

# Download model weights
model_path = hf_hub_download(
    repo_id="gccmorgoth/protein-structure-mlx-esm2_t33_650M_UR50D",
    filename="best_model.safetensors"
)

# Load model architecture (requires model code from repo)
from src.models.esm2_mlx import ESM2StructurePredictor

model = ESM2StructurePredictor(
    model_name="facebook/esm2_t33_650M_UR50D",
    output_type="distance",
    hidden_dim=768,
    num_layers=5
)

# Load trained weights
model.prediction_head.load_weights(model_path)

# Tokenize protein sequence
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEKAVQVKVKALPDAQFEVVHSLAKWKRQTLGQHDFSAGEGLYTHMKALRPDEDRLSPLHSVYVDQWDWERVMGDGERQFSTLKSTVEAIWAGIKATEAAVSEEFGLAPFLPDQIHFVHSQELLSRYPDLDAKGRERAIAKDLGAVFLVGIGGKLSDGHRHDVRAPDYDDWSTPSELGHAGLNGDILVWNPVLEDAFELSSMGIRVDADTLKHQLALTGDEDRLELEWHQALLRGEMPQTIGGGIGQSRLTMLLLQLPHIGQVQAGVWPAAVRESVPSLL"

inputs = tokenizer(sequence, return_tensors="np")
input_ids = mx.array(inputs["input_ids"])
attention_mask = mx.array(inputs["attention_mask"])

# Predict distance bins (logits)
_, distance_logits = model(input_ids, attention_mask)
# Shape: (1, seq_len, seq_len, 64)

# Get most likely distance bin per residue pair
predicted_bins = mx.argmax(distance_logits, axis=-1)
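A common follow-up step (a sketch, not part of the repo) is to turn the logits into an expected distance map and an 8 Å contact map. The bin centers below assume the 2.16-21.84 Å / 0.3125 Å binning described above; NumPy is used here, so convert the MLX logits first with `np.array(distance_logits)`.

```python
import numpy as np

# Assumed binning: 64 centers spaced 0.3125 A starting at 2.16 A.
centers = 2.16 + 0.3125 * np.arange(64)

def expected_distance_map(distance_logits):
    """(1, L, L, 64) bin logits -> (L, L) expected distances in Angstroms."""
    logits = np.asarray(distance_logits)[0]
    # Softmax over bins, then probability-weighted mean distance.
    z = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    dmap = (probs * centers).sum(axis=-1)
    return 0.5 * (dmap + dmap.T)  # symmetrize the i-j vs j-i predictions

# Toy check with random logits (replace with np.array(distance_logits)):
rng = np.random.default_rng(0)
dmap = expected_distance_map(rng.standard_normal((1, 6, 6, 64)))
contacts = dmap < 8.0  # 8 A is a common contact threshold
print(dmap.shape, contacts.dtype)  # (6, 6) bool
```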

Limitations

  • Trained on small proteins (<100 residues)
  • Limited to 196 PDB structures (small dataset)
  • Does not use MSA (multiple sequence alignments) like AlphaFold
  • Simple MLP architecture (no attention mechanisms)
  • Distance prediction only (no angles, torsions, or 3D coordinates)

Citation

If you use this model, please cite:

@software{protein-structure-mlx-esm2_t33_650M_UR50D,
  title={ESM-2 Protein Structure Prediction with MLX},
  author={Mohammad Huzefa Shaikh},
  year={2025},
  url={https://huggingface.co/gccmorgoth/protein-structure-mlx-esm2_t33_650M_UR50D}
}
