protein-structure-mlx-esm2_t33_650M_UR50D
ESM-2 650M fine-tuned for protein structure prediction via inter-residue distance distograms, optimized for Apple Silicon with MLX.
Latest Fine-Tuning
HF Model Path: gccmorgoth/protein-structure-mlx-esm2_t33_650M_UR50D
Model Details
- Base Model: facebook/esm2_t33_650M_UR50D
- Architecture: ESM-2 650M (frozen encoder) + 5-layer MLP prediction head (768 hidden dim)
- Task: Protein inter-residue distance prediction (distogram)
- Framework: MLX (Apple Silicon optimized)
- Parameters: 651M (base) + 2.4M (trainable head) = 653.4M total
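To make the architecture concrete, here is a framework-agnostic numpy sketch of a pairwise MLP distogram head of the kind described above (the real head is the MLX `ESM2StructurePredictor` in the repo; the weights, single hidden layer, and function name here are illustrative — only the 1280-dim ESM-2 650M embeddings and the 64 output bins come from the card):

```python
import numpy as np

rng = np.random.default_rng(0)

def mlp_distogram_head(residue_emb, hidden_dim=768, num_bins=64):
    """Toy pairwise distogram head: concatenate the embeddings of
    residues i and j, pass them through an MLP, and emit one logit
    per distance bin. Weights are random here; the real head is trained."""
    L, d = residue_emb.shape
    # Pairwise features: (L, L, 2d) by broadcasting i- and j-embeddings
    pair = np.concatenate(
        [np.broadcast_to(residue_emb[:, None, :], (L, L, d)),
         np.broadcast_to(residue_emb[None, :, :], (L, L, d))], axis=-1)
    W1 = rng.standard_normal((2 * d, hidden_dim)) * 0.02
    W2 = rng.standard_normal((hidden_dim, num_bins)) * 0.02
    h = np.maximum(pair @ W1, 0.0)   # ReLU hidden layer
    return h @ W2                    # (L, L, num_bins) logits

emb = rng.standard_normal((5, 1280))  # 5 residues, ESM-2 650M width
logits = mlp_distogram_head(emb)
print(logits.shape)  # (5, 5, 64)
```

The frozen 651M-parameter encoder produces the residue embeddings; only the ~2.4M-parameter head is trained.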
Training Details
- Method: Supervised Fine-Tuning (SFT) with distogram loss
- Loss Function: AlphaFold2-style cross-entropy on 64 distance bins (2.16–21.84 Å)
- Dataset: 196 high-resolution PDB structures (<2.5 Å resolution)
- Training: 3,004 examples
- Validation: 752 examples
- Negative examples: Gaussian noise perturbation
- Validation Loss: 2.79 (cross-entropy on 64 bins)
- Training Loss: 2.81
- Hardware: Mac mini M4 Pro (24GB RAM, 16-core GPU)
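A distogram cross-entropy of this kind discretizes true pairwise distances into the 64 bins and scores the predicted logits against the resulting bin index. The sketch below is an assumption-laden reconstruction, not the repo's code: the even bin spacing over 2.16–21.84 Å and the `distogram_loss` name are illustrative.

```python
import numpy as np

def distogram_loss(logits, true_dist, d_min=2.16, d_max=21.84, num_bins=64):
    """AlphaFold2-style distogram cross-entropy (sketch).
    logits: (L, L, num_bins); true_dist: (L, L) pairwise distances in Å.
    Even bin spacing over [d_min, d_max] is an assumption."""
    edges = np.linspace(d_min, d_max, num_bins - 1)  # 63 interior edges
    target = np.digitize(true_dist, edges)           # (L, L) bin indices 0..63
    # Numerically stable log-softmax over the bin axis
    z = logits - logits.max(axis=-1, keepdims=True)
    log_p = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    # Negative log-likelihood of the true bin, averaged over residue pairs
    nll = -np.take_along_axis(log_p, target[..., None], axis=-1)
    return float(nll.mean())

rng = np.random.default_rng(1)
L = 4
loss = distogram_loss(rng.standard_normal((L, L, 64)),
                      np.abs(rng.standard_normal((L, L))) * 10)
```

With untrained (random) logits this loss sits near ln(64) ≈ 4.16, which is exactly the random baseline quoted under Performance below.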
Training Configuration
- Optimizer: Adam
- Learning Rate: 5e-4 with cosine decay to 1e-5
- Warmup Steps: 200
- Batch Size: 8 (effective 16 with gradient accumulation)
- Epochs: 20 with early stopping (patience=5)
- Best Checkpoint: Epoch 19
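A warmup-plus-cosine schedule matching the values listed above can be sketched as follows (the exact schedule shape used in training is an assumption; only the 5e-4 peak, 1e-5 floor, and 200 warmup steps come from the configuration):

```python
import math

def lr_schedule(step, total_steps, base_lr=5e-4, final_lr=1e-5, warmup=200):
    """Linear warmup to base_lr, then cosine decay down to final_lr."""
    if step < warmup:
        return base_lr * (step + 1) / warmup
    progress = (step - warmup) / max(1, total_steps - warmup)
    cos = 0.5 * (1.0 + math.cos(math.pi * progress))
    return final_lr + (base_lr - final_lr) * cos

lr_schedule(0, 10_000)       # warmup start, ~2.5e-06
lr_schedule(200, 10_000)     # peak, 5e-4
lr_schedule(10_000, 10_000)  # end of decay, 1e-5
```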
Performance
- Random Baseline: 4.16 (ln(64) for 64-class classification)
- Final Val Loss: 2.79 (33% improvement over random)
- Convergence: Optimal performance reached at epoch 8-10
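The baseline and the improvement figure follow directly from the 64-way classification setup: a uniform guess over 64 bins costs ln(64) nats of cross-entropy.

```python
import math

# Uniform guessing over 64 bins gives cross-entropy ln(64)
baseline = math.log(64)
improvement = (baseline - 2.79) / baseline
print(round(baseline, 2), f"{improvement:.0%}")  # 4.16 33%
```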
Distance Binning
Predicts 64 distance bins covering 2.16–21.84 Å with ~0.3125 Å resolution per bin, following AlphaFold2's distogram approach.
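Under the assumption that the 64 bins are evenly spaced centers spanning that range (which reproduces the ~0.3125 Å per-bin resolution quoted above), converting a predicted bin index back to a distance is a table lookup:

```python
import numpy as np

# Assumption: 64 evenly spaced bin centers from 2.16 to 21.84 Å
# (spacing 19.68 / 63 ≈ 0.3125 Å, matching the stated resolution)
centers = np.linspace(2.16, 21.84, 64)

def bin_to_distance(bin_idx):
    """Map a predicted bin index (0–63) to its distance center in Å."""
    return float(centers[bin_idx])

print(round(float(centers[1] - centers[0]), 4))  # 0.3124
```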
Usage
```python
import mlx.core as mx
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

# Download the trained prediction-head weights
model_path = hf_hub_download(
    repo_id="gccmorgoth/protein-structure-mlx-esm2_t33_650M_UR50D",
    filename="best_model.safetensors",
)

# Load model architecture (requires model code from repo)
from src.models.esm2_mlx import ESM2StructurePredictor

model = ESM2StructurePredictor(
    model_name="facebook/esm2_t33_650M_UR50D",
    output_type="distance",
    hidden_dim=768,
    num_layers=5,
)

# Load trained weights
model.prediction_head.load_weights(model_path)

# Tokenize protein sequence
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEKAVQVKVKALPDAQFEVVHSLAKWKRQTLGQHDFSAGEGLYTHMKALRPDEDRLSPLHSVYVDQWDWERVMGDGERQFSTLKSTVEAIWAGIKATEAAVSEEFGLAPFLPDQIHFVHSQELLSRYPDLDAKGRERAIAKDLGAVFLVGIGGKLSDGHRHDVRAPDYDDWSTPSELGHAGLNGDILVWNPVLEDAFELSSMGIRVDADTLKHQLALTGDEDRLELEWHQALLRGEMPQTIGGGIGQSRLTMLLLQLPHIGQVQAGVWPAAVRESVPSLL"
inputs = tokenizer(sequence, return_tensors="np")
input_ids = mx.array(inputs["input_ids"])
attention_mask = mx.array(inputs["attention_mask"])

# Predict distance-bin logits; shape: (1, seq_len, seq_len, 64)
_, distance_logits = model(input_ids, attention_mask)

# Get most likely distance bin per residue pair
predicted_bins = mx.argmax(distance_logits, axis=-1)
```
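Beyond the argmax bin, a probability-weighted average over bin centers gives a smooth distance estimate per residue pair. The sketch below uses numpy for portability (MLX logits can be converted with `np.array(distance_logits)` first); the even bin-center layout is an assumption consistent with the binning described above.

```python
import numpy as np

def expected_distance(distance_logits, d_min=2.16, d_max=21.84):
    """Softmax-weighted average of bin centers: a soft distance
    estimate per residue pair, often smoother than argmax."""
    num_bins = distance_logits.shape[-1]
    centers = np.linspace(d_min, d_max, num_bins)
    z = distance_logits - distance_logits.max(axis=-1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    return (probs * centers).sum(axis=-1)  # (..., L, L) distances in Å

rng = np.random.default_rng(0)
toy_logits = rng.standard_normal((1, 6, 6, 64))
dist = expected_distance(toy_logits)
print(dist.shape)  # (1, 6, 6)
```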
Limitations
- Trained on small proteins (<100 residues)
- Limited to 196 PDB structures (small dataset)
- Does not use MSA (multiple sequence alignments) like AlphaFold
- Simple MLP architecture (no attention mechanisms)
- Distance prediction only (no angles, torsions, or 3D coordinates)
Citation
If you use this model, please cite:
```bibtex
@software{protein-structure-mlx-esm2_t33_650M_UR50D,
  title={ESM-2 Protein Structure Prediction with MLX},
  author={Mohammad Huzefa Shaikh},
  year={2025},
  url={https://huggingface.co/gccmorgoth/protein-structure-mlx-esm2_t33_650M_UR50D}
}
```
Acknowledgments
- ESM-2: Lin et al. 2022 - Meta AI Research
- AlphaFold2: Jumper et al. 2021 - DeepMind (distogram loss design)
- MLX: Apple ML Research