### Embedding Mixin + Pooler import os import sqlite3 import networkx as nx import numpy as np import torch from tqdm.auto import tqdm from typing import Callable, List, Optional from torch.utils.data import DataLoader from torch.utils.data import Dataset as TorchDataset from transformers import PreTrainedTokenizerBase class Pooler: def __init__(self, pooling_types: List[str]): self.pooling_types = pooling_types self.pooling_options = { 'mean': self.mean_pooling, 'max': self.max_pooling, 'norm': self.norm_pooling, 'median': self.median_pooling, 'std': self.std_pooling, 'var': self.var_pooling, 'cls': self.cls_pooling, 'parti': self._pool_parti, } def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor: maxed_attentions = torch.max(attentions, dim=1)[0] return maxed_attentions def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"): # Run PageRank on the attention matrix converted to a graph. # Raises exceptions if the graph doesn't match the token sequence or has no edges. # Returns the PageRank scores for each token node. G = self._convert_to_graph(attention_matrix) if G.number_of_nodes() != attention_matrix.shape[0]: raise Exception( f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.") if G.number_of_edges() == 0: raise Exception(f"You don't seem to have any attention edges left in the graph.") return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100) def _convert_to_graph(self, matrix): # Convert a matrix (e.g., attention scores) to a directed graph using networkx. # Each element in the matrix represents a directed edge with a weight. G = nx.from_numpy_array(matrix, create_using=nx.DiGraph) return G def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None): # Remove keys where attention_mask is 0 if attention_mask is not None: for k in list(dict_importance.keys()): if attention_mask[k] == 0: del dict_importance[k] #dict_importance[0] # remove cls #dict_importance[-1] # remove eos total = sum(dict_importance.values()) return np.array([v / total for _, v in dict_importance.items()]) def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d) maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy() # emb is (b, L, d), maxed_attentions is (b, L, L) emb_pooled = [] for e, a, mask in zip(emb, maxed_attentions, attention_mask): dict_importance = self._page_rank(a) importance_weights = self._calculate_importance_weights(dict_importance, mask) num_tokens = int(mask.sum().item()) emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0)) pooled = torch.tensor(np.array(emb_pooled)) return pooled def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) if attention_mask is None: return emb.mean(dim=1) else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) if attention_mask is None: return emb.max(dim=1).values else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).max(dim=1).values def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) if attention_mask is None: return emb.norm(dim=1, p=2) else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).norm(dim=1, p=2) def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) if attention_mask is None: return emb.median(dim=1).values else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).median(dim=1).values def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) if attention_mask is None: return emb.std(dim=1) else: # Compute variance correctly over non-masked positions, then take sqrt var = self.var_pooling(emb, attention_mask, **kwargs) return torch.sqrt(var) def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) if attention_mask is None: return emb.var(dim=1) else: # Correctly compute variance over only non-masked positions attention_mask = attention_mask.unsqueeze(-1) # (b, L, 1) # Compute mean over non-masked positions mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d) mean = mean.unsqueeze(1) # (b, 1, d) # Compute squared differences from mean, only over non-masked positions squared_diff = (emb - mean) ** 2 # (b, L, d) # Sum squared differences over non-masked positions and divide by count var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d) return var def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) return emb[:, 0, :] def __call__( self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, attentions: Optional[torch.Tensor] = None ): # [mean, max] final_emb = [] for pooling_type in self.pooling_types: final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) # (b, d) return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d) class ProteinDataset(TorchDataset): """Simple dataset for protein sequences.""" def __init__(self, sequences: list[str]): self.sequences = sequences def __len__(self) -> int: return len(self.sequences) def __getitem__(self, idx: int) -> str: return self.sequences[idx] def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[list[str]], dict[str, torch.Tensor]]: def _collate_fn(sequences: list[str]) -> dict[str, torch.Tensor]: return tokenizer(sequences, return_tensors="pt", padding='longest') return _collate_fn def parse_fasta(fasta_path: str) -> List[str]: assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}" sequences = [] current_seq = [] with open(fasta_path, 'r') as f: for line in f: line = line.strip() if not line: continue if line.startswith('>'): if current_seq: sequences.append(''.join(current_seq)) current_seq = [] else: current_seq.append(line) if current_seq: sequences.append(''.join(current_seq)) return sequences class EmbeddingMixin: def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: raise NotImplementedError @property def device(self) -> torch.device: """Get the device of the model.""" return next(self.parameters()).device def _read_sequences_from_db(self, db_path: str) -> set[str]: """Read sequences from SQLite database.""" sequences = [] with sqlite3.connect(db_path) as conn: c = conn.cursor() c.execute("SELECT sequence FROM embeddings") while True: row = c.fetchone() if row is None: break sequences.append(row[0]) return set(sequences) def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None: cursor = conn.cursor() cursor.execute( "CREATE TABLE IF NOT EXISTS embeddings (" "sequence TEXT PRIMARY KEY, " "embedding BLOB NOT NULL, " "shape TEXT, " "dtype TEXT" ")" ) cursor.execute("PRAGMA table_info(embeddings)") rows = cursor.fetchall() column_names = [row[1] for row in rows] if "shape" not in column_names: cursor.execute("ALTER TABLE embeddings ADD COLUMN shape TEXT") if "dtype" not in column_names: cursor.execute("ALTER TABLE embeddings ADD COLUMN dtype TEXT") conn.commit() def load_embeddings_from_pth(self, save_path: str) -> dict[str, torch.Tensor]: assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}" payload = torch.load(save_path, map_location="cpu", weights_only=True) assert isinstance(payload, dict), "Expected .pth embeddings file to contain a dictionary." for sequence, tensor in payload.items(): assert isinstance(sequence, str), "Expected embedding dictionary keys to be sequences (str)." assert isinstance(tensor, torch.Tensor), "Expected embedding dictionary values to be tensors." return payload def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> dict[str, torch.Tensor]: assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}" loaded: dict[str, torch.Tensor] = {} with sqlite3.connect(db_path) as conn: self._ensure_embeddings_table(conn) cursor = conn.cursor() if sequences is None: cursor.execute("SELECT sequence, embedding, shape, dtype FROM embeddings") else: if len(sequences) == 0: return loaded placeholders = ",".join(["?"] * len(sequences)) cursor.execute( f"SELECT sequence, embedding, shape, dtype FROM embeddings WHERE sequence IN ({placeholders})", tuple(sequences), ) rows = cursor.fetchall() for row in rows: sequence = row[0] embedding_bytes = row[1] shape_text = row[2] dtype_text = row[3] assert shape_text is not None, "Missing shape metadata in embeddings table." assert dtype_text is not None, "Missing dtype metadata in embeddings table." shape_values = [int(value) for value in shape_text.split(",") if len(value) > 0] assert len(shape_values) > 0, f"Invalid shape metadata for sequence: {sequence}" expected_size = int(np.prod(shape_values)) np_dtype = np.dtype(dtype_text) array = np.frombuffer(embedding_bytes, dtype=np_dtype) assert array.size == expected_size, f"Shape mismatch while reading sequence: {sequence}" reshaped = array.copy().reshape(tuple(shape_values)) loaded[sequence] = torch.from_numpy(reshaped) return loaded def embed_dataset( self, sequences: Optional[List[str]] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None, batch_size: int = 2, max_len: int = 512, truncate: bool = True, full_embeddings: bool = False, embed_dtype: torch.dtype = torch.float32, pooling_types: List[str] = ['mean'], num_workers: int = 0, sql: bool = False, save: bool = True, sql_db_path: str = 'embeddings.db', save_path: str = 'embeddings.pth', fasta_path: Optional[str] = None, **kwargs, ) -> Optional[dict[str, torch.Tensor]]: """ Embed a dataset of protein sequences. Supports two modes: - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used. - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used. Sequences can be supplied as a list via `sequences`, parsed from a FASTA file via `fasta_path`, or both (the two sources are combined). At least one must be provided. """ if fasta_path is not None: fasta_sequences = parse_fasta(fasta_path) sequences = list(sequences or []) + fasta_sequences assert sequences is not None and len(sequences) > 0, \ "Must provide at least one sequence via `sequences` or `fasta_path`." sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences])) sequences = sorted(sequences, key=len, reverse=True) hidden_size = self.config.hidden_size pooler = Pooler(pooling_types) if not full_embeddings else None tokenizer_mode = tokenizer is not None if tokenizer_mode: collate_fn = build_collator(tokenizer) device = self.device else: collate_fn = None device = None def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: if full_embeddings or residue_embeddings.ndim == 2: return residue_embeddings return pooler(residue_embeddings, attention_mask) def iter_batches(to_embed: List[str]): if tokenizer_mode: assert collate_fn is not None assert device is not None dataset = ProteinDataset(to_embed) dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False) for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'): seqs = to_embed[i * batch_size:(i + 1) * batch_size] input_ids = batch['input_ids'].to(device) attention_mask = batch['attention_mask'].to(device) residue_embeddings = self._embed(input_ids, attention_mask) yield seqs, residue_embeddings, attention_mask else: for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'): seqs = to_embed[batch_start:batch_start + batch_size] batch_output = self._embed(seqs, return_attention_mask=True, **kwargs) assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)." assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values." residue_embeddings, attention_mask = batch_output assert isinstance(attention_mask, torch.Tensor), "Sequence mode _embed must return attention_mask as a torch.Tensor." yield seqs, residue_embeddings, attention_mask if sql: conn = sqlite3.connect(sql_db_path) self._ensure_embeddings_table(conn) c = conn.cursor() already_embedded = self._read_sequences_from_db(sql_db_path) to_embed = [seq for seq in sequences if seq not in already_embedded] print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}") print(f"Embedding {len(to_embed)} new sequences") if len(to_embed) > 0: with torch.no_grad(): for i, (seqs, residue_embeddings, attention_mask) in enumerate(iter_batches(to_embed)): embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) for seq, emb, mask in zip(seqs, embeddings, attention_mask): if full_embeddings: emb = emb[mask.bool()].reshape(-1, hidden_size) emb_np = emb.cpu().numpy() emb_shape = ",".join([str(dim) for dim in emb_np.shape]) emb_dtype = str(emb_np.dtype) c.execute( "INSERT OR REPLACE INTO embeddings (sequence, embedding, shape, dtype) VALUES (?, ?, ?, ?)", (seq, emb_np.tobytes(), emb_shape, emb_dtype), ) if tokenizer_mode and (i + 1) % 100 == 0: conn.commit() conn.commit() conn.close() return None embeddings_dict = {} if os.path.exists(save_path): embeddings_dict = self.load_embeddings_from_pth(save_path) to_embed = [seq for seq in sequences if seq not in embeddings_dict] print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}") print(f"Embedding {len(to_embed)} new sequences") else: to_embed = sequences print(f"Embedding {len(to_embed)} new sequences") if len(to_embed) > 0: with torch.no_grad(): for seqs, residue_embeddings, attention_mask in iter_batches(to_embed): embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) for seq, emb, mask in zip(seqs, embeddings, attention_mask): if full_embeddings: emb = emb[mask.bool()].reshape(-1, hidden_size) embeddings_dict[seq] = emb.cpu() if save: torch.save(embeddings_dict, save_path) return embeddings_dict import torch import torch.nn as nn from torch.nn import functional as F from typing import Optional, Tuple, Dict, Any from einops import rearrange from dataclasses import dataclass from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer from transformers.modeling_outputs import ModelOutput from transformers.models.esm.modeling_esm import ( EsmIntermediate, EsmOutput, EsmPooler, EsmLMHead, EsmSelfOutput, EsmClassificationHead, EsmContactPredictionHead, EsmEmbeddings, RotaryEmbedding, ) try: from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask except ImportError: create_block_mask = None flex_attention = None BlockMask = None from enum import Enum ### Kernels Flash Attention Detection def _infer_kernels_flash_variant(kernel) -> str | None: if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"): return "flash_attn2" if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"): return "flash_attn3" return None def _try_get_kernels_flash(): try: from kernels import get_kernel except ImportError: return None, None flash_kernel = None flash_kernel_variant = None try: flash_kernel = get_kernel("kernels-community/flash-attn3") flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel) assert flash_kernel_variant is not None, "Loaded flash-attn3 kernel does not expose a supported API." except Exception: try: flash_kernel = get_kernel("kernels-community/flash-attn2") flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel) assert flash_kernel_variant is not None, "Loaded flash-attn2 kernel does not expose a supported API." except Exception: flash_kernel = None flash_kernel_variant = None return flash_kernel, flash_kernel_variant FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash() def _kernels_flash_forward( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, causal: bool = False, ) -> torch.Tensor: assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment." if FLASH_KERNEL_VARIANT == "flash_attn2": return FLASH_KERNEL.fwd(q=query_states, k=key_states, v=value_states, is_causal=causal)[0] if FLASH_KERNEL_VARIANT == "flash_attn3": try: output = FLASH_KERNEL.flash_attn_func(q=query_states, k=key_states, v=value_states, causal=causal) except TypeError: output = FLASH_KERNEL.flash_attn_func(query_states, key_states, value_states, 0.0, None, causal) if isinstance(output, tuple): return output[0] return output raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}") def _kernels_flash_varlen_forward( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_in_batch_q: int, max_seqlen_in_batch_k: int, causal: bool = False, ) -> torch.Tensor: assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment." if FLASH_KERNEL_VARIANT == "flash_attn2": return FLASH_KERNEL.varlen_fwd( q=query_states, k=key_states, v=value_states, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, is_causal=causal, )[0] if FLASH_KERNEL_VARIANT == "flash_attn3": try: output = FLASH_KERNEL.flash_attn_varlen_func( q=query_states, k=key_states, v=value_states, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, causal=causal, ) except TypeError: output = FLASH_KERNEL.flash_attn_varlen_func( query_states, key_states, value_states, cu_seqlens_q, cu_seqlens_k, max_seqlen_in_batch_q, max_seqlen_in_batch_k, 0.0, None, causal, ) if isinstance(output, tuple): return output[0] return output raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}") ### Unpad / Pad helpers for varlen flash attention class IndexFirstAxis(torch.autograd.Function): @staticmethod def forward(ctx, input, indices) -> torch.Tensor: ctx.save_for_backward(indices) assert input.ndim >= 2 ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] second_dim = other_shape.numel() return torch.gather( rearrange(input, "b ... -> b (...)"), 0, indices.unsqueeze(1).expand(-1, second_dim) ).reshape(-1, *other_shape) @staticmethod def backward(ctx, grad_output) -> tuple[torch.Tensor, None]: (indices,) = ctx.saved_tensors assert grad_output.ndim >= 2 other_shape = grad_output.shape[1:] grad_output = rearrange(grad_output, "b ... -> b (...)") grad_input = torch.zeros( [ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype ) grad_input.scatter_(0, indices.unsqueeze(1).expand(-1, grad_output.shape[1]), grad_output) return grad_input.reshape(ctx.first_axis_dim, *other_shape), None class IndexPutFirstAxis(torch.autograd.Function): @staticmethod def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor: ctx.save_for_backward(indices) assert indices.ndim == 1 assert values.ndim >= 2 output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) output[indices] = values return output @staticmethod def backward(ctx, grad_output) -> tuple[torch.Tensor, None, None]: (indices,) = ctx.saved_tensors return grad_output[indices], None, None index_first_axis = IndexFirstAxis.apply index_put_first_axis = IndexPutFirstAxis.apply def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor: output = index_put_first_axis(hidden_states, indices, batch * seqlen) return rearrange(output, "(b s) ... -> b s ...", b=batch) def _unpad_input( query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor, attention_mask_2d: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[int, int]]: batch_size, seq_len, num_heads, head_dim = query_layer.shape seqlens = attention_mask_2d.sum(dim=1).int() cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0)) max_seqlen = int(seqlens.max().item()) indices = attention_mask_2d.flatten().nonzero(as_tuple=False).flatten() query_layer = index_first_axis(query_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices) key_layer = index_first_axis(key_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices) value_layer = index_first_axis(value_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices) return query_layer, key_layer, value_layer, indices, (cu_seqlens, cu_seqlens), (max_seqlen, max_seqlen) def kernels_flash_attention_func( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, attention_mask_2d: torch.Tensor | None = None, causal: bool = False, ) -> torch.Tensor: assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment." if not causal and attention_mask_2d is not None: batch_size, q_len = query_states.shape[:2] ( query_states, key_states, value_states, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k), ) = _unpad_input(query_states, key_states, value_states, attention_mask_2d) attn_output_unpad = _kernels_flash_varlen_forward( query_states=query_states, key_states=key_states, value_states=value_states, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_in_batch_q=max_seqlen_q, max_seqlen_in_batch_k=max_seqlen_k, ) return pad_input(attn_output_unpad, indices_q, batch_size, q_len) else: return _kernels_flash_forward( query_states=query_states, key_states=key_states, value_states=value_states, causal=causal, ) ### Attention Backend Enum & Resolution class AttentionBackend(Enum): AUTO = "auto" KERNELS_FLASH = "kernels_flash" FLEX = "flex" SDPA = "sdpa" VALID_ATTENTION_BACKENDS = tuple(b.value for b in AttentionBackend) _BACKEND_CONFIRMED = False def resolve_attention_backend(requested_backend: str) -> AttentionBackend: global _BACKEND_CONFIRMED assert requested_backend in VALID_ATTENTION_BACKENDS, ( f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}." ) if requested_backend == AttentionBackend.AUTO.value: if FLASH_KERNEL is not None: resolved = AttentionBackend.KERNELS_FLASH elif flex_attention is not None: resolved = AttentionBackend.FLEX else: resolved = AttentionBackend.SDPA elif requested_backend == AttentionBackend.KERNELS_FLASH.value: assert FLASH_KERNEL is not None, "Kernels Flash Attention is not available in this environment." resolved = AttentionBackend.KERNELS_FLASH elif requested_backend == AttentionBackend.FLEX.value: assert flex_attention is not None, "Flex Attention is not available in this environment." resolved = AttentionBackend.FLEX elif requested_backend == AttentionBackend.SDPA.value: resolved = AttentionBackend.SDPA else: raise AssertionError(f"Unsupported attention backend: {requested_backend}") if not _BACKEND_CONFIRMED: print(f"Attention backend: config='{requested_backend}' -> resolved='{resolved.value}'") _BACKEND_CONFIRMED = True return resolved def get_attention_mask( effective_backend: AttentionBackend, batch_size: int, seq_len: int, device: torch.device, attention_mask: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor | None, torch.Tensor | None, "BlockMask | None"]: """Build padding masks once for all encoder layers. Returns (attention_mask_2d, attention_mask_4d, flex_block_mask). """ if attention_mask is None: return None, None, None attention_mask_2d = attention_mask.bool() if effective_backend == AttentionBackend.KERNELS_FLASH: return attention_mask_2d, None, None if effective_backend == AttentionBackend.FLEX: assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable." valid_lens = attention_mask_2d.sum(dim=-1) def mask_mod(batch_idx, head_idx, q_idx, kv_idx): return (q_idx < valid_lens[batch_idx]) & (kv_idx < valid_lens[batch_idx]) flex_block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device) return attention_mask_2d, None, flex_block_mask # SDPA / manual — only mask the key dimension so padding query positions attend to # real keys and produce valid (non-NaN) outputs instead of NaN from softmax(-inf,...,-inf). attention_mask_4d = attention_mask_2d[:, None, None, :] return attention_mask_2d, attention_mask_4d, None @dataclass class FastEsmEncoderOutput(ModelOutput): last_hidden_state: Optional[torch.Tensor] = None hidden_states: Optional[Tuple[torch.Tensor, ...]] = None attentions: Optional[Tuple[torch.Tensor, ...]] = None s_max: Optional[Tuple[list[torch.Tensor], ...]] = None @dataclass class EsmMaskedLMOutput(ModelOutput): loss: Optional[torch.Tensor] = None logits: Optional[torch.Tensor] = None last_hidden_state: Optional[torch.Tensor] = None hidden_states: Optional[Tuple[torch.Tensor, ...]] = None attentions: Optional[Tuple[torch.Tensor, ...]] = None s_max: Optional[Tuple[list[torch.Tensor], ...]] = None class FastEsmConfig(PretrainedConfig): model_type = "fast_esm" def __init__( self, vocab_size: int = None, mask_token_id: int = None, pad_token_id: int = None, hidden_size: int = 768, num_hidden_layers: int = 12, num_attention_heads: int = 12, intermediate_size: int = 3072, hidden_dropout_prob: float = 0.1, attention_probs_dropout_prob: float = 0.1, max_position_embeddings: int = 1026, initializer_range: float = 0.02, layer_norm_eps: float = 1e-12, position_embedding_type: str = "rotary", emb_layer_norm_before: bool = None, token_dropout: bool = True, attn_backend: str = "sdpa", **kwargs, ): super().__init__( pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs, ) self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.intermediate_size = intermediate_size self.hidden_dropout_prob = hidden_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps self.position_embedding_type = position_embedding_type self.emb_layer_norm_before = emb_layer_norm_before self.tie_word_embeddings = False self.token_dropout = token_dropout self.attn_backend = attn_backend def to_dict(self) -> Dict[str, Any]: """ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns: `Dict[str, any]`: Dictionar y of all the attributes that make up this configuration instance, """ output = super().to_dict() return output class EsmSelfAttention(nn.Module): def __init__(self, config, position_embedding_type: Optional[str] = None): super().__init__() assert config.hidden_size % config.num_attention_heads == 0, ( f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " f"heads ({config.num_attention_heads})" ) self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.query = nn.Linear(config.hidden_size, self.all_head_size) self.key = nn.Linear(config.hidden_size, self.all_head_size) self.value = nn.Linear(config.hidden_size, self.all_head_size) self.scale = self.attention_head_size**-0.5 self.dropout_prob = config.attention_probs_dropout_prob self.config = config self.attn_backend = resolve_attention_backend(config.attn_backend) self.position_embedding_type = position_embedding_type or config.position_embedding_type self.rotary_embeddings = None if self.position_embedding_type == "rotary": self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size) def forward( self, hidden_states: torch.Tensor, attention_mask_2d: torch.Tensor | None = None, attention_mask_4d: torch.Tensor | None = None, flex_block_mask: "BlockMask | None" = None, output_attentions: bool = False, output_s_max: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]: batch_size, seq_length = hidden_states.shape[:-1] hidden_shape = (batch_size, seq_length, -1, self.attention_head_size) query_BHLD = self.query(hidden_states).view(hidden_shape).transpose(1, 2) key_BHLD = self.key(hidden_states).view(hidden_shape).transpose(1, 2) value_BHLD = self.value(hidden_states).view(hidden_shape).transpose(1, 2) query_BHLD = query_BHLD * self.scale if self.position_embedding_type == "rotary": query_BHLD, key_BHLD = self.rotary_embeddings(query_BHLD, key_BHLD) attn_output, attn_weights, s_max = self._attn( query_BHLD, key_BHLD, value_BHLD, attention_mask_2d=attention_mask_2d, attention_mask_4d=attention_mask_4d, flex_block_mask=flex_block_mask, output_attentions=output_attentions, output_s_max=output_s_max, ) return attn_output, attn_weights, s_max def _attn( self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor, value_BHLD: torch.Tensor, attention_mask_2d: torch.Tensor | None = None, attention_mask_4d: torch.Tensor | None = None, flex_block_mask: "BlockMask | None" = None, output_attentions: bool = False, output_s_max: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]: if output_attentions: return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d, output_s_max) if self.attn_backend == AttentionBackend.KERNELS_FLASH: attn_output, attn_weights = self._kernels_flash_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_2d) elif self.attn_backend == AttentionBackend.FLEX: attn_output, attn_weights = self._flex_attn(query_BHLD, key_BHLD, value_BHLD, flex_block_mask) elif self.attn_backend == AttentionBackend.SDPA: attn_output, attn_weights = self._sdpa_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d) else: raise AssertionError(f"Unsupported resolved backend: {self.attn_backend}") s_max = self._compute_s_max(query_BHLD, key_BHLD) if output_s_max else None return attn_output, attn_weights, s_max @torch.no_grad() def _compute_s_max(self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor) -> list[torch.Tensor]: q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1) k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1) s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values return [s_max_bound[h] for h in range(self.num_attention_heads)] def _manual_attn( self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor, value_BHLD: torch.Tensor, attention_mask_4d: torch.Tensor | None = None, output_s_max: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]: attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2)) if attention_mask_4d is not None: attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf")) attn_weights = F.softmax(attn_weights, dim=-1) if self.dropout_prob > 0 and self.training: attn_weights = F.dropout(attn_weights, p=self.dropout_prob, training=self.training) context_BHLD = torch.matmul(attn_weights, value_BHLD) attn_output = rearrange(context_BHLD, "b h s d -> b s (h d)") s_max = self._compute_s_max(query_BHLD, key_BHLD) if output_s_max else None return attn_output, attn_weights, s_max def _kernels_flash_attn( self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor, value_BHLD: torch.Tensor, attention_mask_2d: torch.Tensor | None = None, ) -> tuple[torch.Tensor, None]: query_BLHD = query_BHLD.transpose(1, 2).contiguous() key_BLHD = key_BHLD.transpose(1, 2).contiguous() value_BLHD = value_BHLD.transpose(1, 2).contiguous() attn_output = kernels_flash_attention_func( query_states=query_BLHD, key_states=key_BLHD, value_states=value_BLHD, attention_mask_2d=attention_mask_2d, causal=False, ) return rearrange(attn_output, "b s h d -> b s (h d)"), None def _flex_attn( self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor, value_BHLD: torch.Tensor, flex_block_mask: "BlockMask | None" = None, ) -> tuple[torch.Tensor, None]: assert flex_attention is not None, "Flex attention is not available in this environment." assert query_BHLD.dtype in (torch.float16, torch.bfloat16), ( f"Flex attention requires float16 or bfloat16, got {query_BHLD.dtype}." ) context_BHLD = flex_attention(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0) return rearrange(context_BHLD, "b h s d -> b s (h d)"), None def _sdpa_attn( self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor, value_BHLD: torch.Tensor, attention_mask_4d: torch.Tensor | None = None, ) -> tuple[torch.Tensor, None]: context_BHLD = F.scaled_dot_product_attention( query_BHLD, key_BHLD, value_BHLD, attn_mask=attention_mask_4d, dropout_p=self.dropout_prob if self.training else 0.0, scale=1.0, ) return rearrange(context_BHLD, "b h s d -> b s (h d)"), None class EsmAttention(nn.Module): def __init__(self, config): super().__init__() self.self = EsmSelfAttention(config) self.output = EsmSelfOutput(config) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask_2d: torch.Tensor | None = None, attention_mask_4d: torch.Tensor | None = None, flex_block_mask: "BlockMask | None" = None, output_attentions: bool = False, output_s_max: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]: hidden_states_ln = self.LayerNorm(hidden_states) attn_output, attn_weights, s_max = self.self( hidden_states_ln, attention_mask_2d=attention_mask_2d, attention_mask_4d=attention_mask_4d, flex_block_mask=flex_block_mask, output_attentions=output_attentions, output_s_max=output_s_max, ) attention_output = self.output(attn_output, hidden_states) return attention_output, attn_weights, s_max class EsmLayer(nn.Module): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = EsmAttention(config) self.intermediate = EsmIntermediate(config) self.output = EsmOutput(config) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask_2d: torch.Tensor | None = None, attention_mask_4d: torch.Tensor | None = None, flex_block_mask: "BlockMask | None" = None, output_attentions: bool = False, output_s_max: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None, list[torch.Tensor] | None]: attention_output, attn_weights, s_max = self.attention( hidden_states, attention_mask_2d=attention_mask_2d, attention_mask_4d=attention_mask_4d, flex_block_mask=flex_block_mask, output_attentions=output_attentions, output_s_max=output_s_max, ) layer_output = self.feed_forward_chunk(attention_output) return layer_output, attn_weights, s_max def feed_forward_chunk(self, attention_output): attention_output_ln = self.LayerNorm(attention_output) intermediate_output = self.intermediate(attention_output_ln) layer_output = self.output(intermediate_output, attention_output) return layer_output class EsmEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config self.attention_backend = resolve_attention_backend(config.attn_backend) self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)]) self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.gradient_checkpointing = False def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_hidden_states: bool = False, output_attentions: bool = False, output_s_max: bool = False, ) -> FastEsmEncoderOutput: all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None full_s_max = () if output_s_max else None attention_mask_2d, attention_mask_4d, flex_block_mask = get_attention_mask( effective_backend=self.attention_backend, batch_size=hidden_states.shape[0], seq_len=hidden_states.shape[1], device=hidden_states.device, attention_mask=attention_mask, ) for layer_module in self.layer: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: hidden_states, attn_weights, s_max = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask_2d, attention_mask_4d, flex_block_mask, output_attentions, output_s_max, ) else: hidden_states, attn_weights, s_max = layer_module( hidden_states, attention_mask_2d=attention_mask_2d, attention_mask_4d=attention_mask_4d, flex_block_mask=flex_block_mask, output_attentions=output_attentions, output_s_max=output_s_max, ) if all_attentions is not None: all_attentions = all_attentions + (attn_weights,) if full_s_max is not None: full_s_max = full_s_max + (s_max,) if self.emb_layer_norm_after: hidden_states = self.emb_layer_norm_after(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) return FastEsmEncoderOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions, s_max=full_s_max, ) class FastEsmPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = FastEsmConfig base_model_prefix = "fastesm" supports_gradient_checkpointing = True tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D") all_tied_weights_keys = {} @classmethod def is_remote_code(cls) -> bool: # Prevent post-load reinitialization of tensors already loaded from checkpoints. return True @torch.no_grad() def _init_weights(self, module: nn.Module) -> None: std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() def post_init(self) -> None: super().post_init() def get_output_embeddings(self): # NOTE: get_output_embeddings() must return None to prevent accidental weight tying. # See e.g. https://github.com/huggingface/transformers/pull/39339#discussion_r2219126400 return None @property def attn_backend(self) -> str: return self.config.attn_backend @attn_backend.setter def attn_backend(self, backend: str) -> None: assert backend in VALID_ATTENTION_BACKENDS, f"Unsupported attn_backend: {backend}. Expected one of {VALID_ATTENTION_BACKENDS}." self.config.attn_backend = backend resolved = resolve_attention_backend(backend) for module in self.modules(): if isinstance(module, EsmEncoder): module.attention_backend = resolved elif isinstance(module, EsmSelfAttention): module.attn_backend = resolved class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin): def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs): FastEsmPreTrainedModel.__init__(self, config, **kwargs) self.config = config self.embeddings = EsmEmbeddings(config) self.encoder = EsmEncoder(config) self.contact_head = EsmContactPredictionHead( in_features=config.num_hidden_layers * config.num_attention_heads, bias=True ) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embeddings.word_embeddings def set_input_embeddings(self, value): self.embeddings.word_embeddings = value def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: token_embedding_output = self.embeddings(input_ids, attention_mask=attention_mask) encoder_outputs = self.encoder( token_embedding_output, attention_mask=attention_mask, output_hidden_states=False, output_attentions=False, ) return encoder_outputs.last_hidden_state def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: attns = self(input_ids, attention_mask=attention_mask, output_attentions=True).attentions attns = torch.stack(attns, dim=1) attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3) attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4) return self.contact_head(input_ids, attns) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_s_max: Optional[bool] = False, return_dict: Optional[bool] = None, ) -> FastEsmEncoderOutput: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) elif inputs_embeds is None: raise ValueError("You have to specify either input_ids or inputs_embeds") token_embedding_output = self.embeddings( input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, ) encoder_outputs = self.encoder( token_embedding_output, attention_mask=attention_mask, output_hidden_states=output_hidden_states, output_attentions=output_attentions, output_s_max=output_s_max, ) return FastEsmEncoderOutput( last_hidden_state=encoder_outputs.last_hidden_state, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, s_max=encoder_outputs.s_max, ) class FastEsmModel(FastEsmPreTrainedModel, EmbeddingMixin): def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs): FastEsmPreTrainedModel.__init__(self, config, **kwargs) self.config = config self.esm = FAST_ESM_ENCODER(config) self.pooler = EsmPooler(config) if add_pooling_layer else None self.post_init() def get_input_embeddings(self): return self.esm.embeddings.word_embeddings def set_input_embeddings(self, value): self.esm.embeddings.word_embeddings = value def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: return self.esm._embed(input_ids, attention_mask) def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: return self.esm.predict_contacts(input_ids, attention_mask=attention_mask) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_s_max: Optional[bool] = False, return_dict: Optional[bool] = None, **kwargs, ) -> FastEsmEncoderOutput: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states outputs = self.esm( input_ids, attention_mask=attention_mask, position_ids=position_ids, inputs_embeds=inputs_embeds, output_hidden_states=output_hidden_states, output_attentions=output_attentions, output_s_max=output_s_max, ) sequence_output = outputs.last_hidden_state pooled_output = self.pooler(sequence_output) if self.pooler is not None else None return FastEsmEncoderOutput( last_hidden_state=sequence_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions, s_max=outputs.s_max, ) class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin): def __init__(self, config, **kwargs): FastEsmPreTrainedModel.__init__(self, config, **kwargs) self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False) self.lm_head = EsmLMHead(config) self.loss_fct = nn.CrossEntropyLoss() self.post_init() def get_input_embeddings(self): return self.esm.embeddings.word_embeddings def get_output_embeddings(self): return self.lm_head.decoder def set_output_embeddings(self, new_embeddings): self.lm_head.decoder = new_embeddings def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: return self.esm._embed(input_ids, attention_mask) def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: return self.esm.predict_contacts(input_ids, attention_mask=attention_mask) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_s_max: Optional[bool] = False, return_dict: Optional[bool] = None, **kwargs, ) -> EsmMaskedLMOutput: outputs = self.esm( input_ids, attention_mask=attention_mask, position_ids=position_ids, inputs_embeds=inputs_embeds, output_hidden_states=output_hidden_states, output_attentions=output_attentions, output_s_max=output_s_max, ) sequence_output = outputs.last_hidden_state prediction_scores = self.lm_head(sequence_output) loss = None if labels is not None: labels = labels.to(prediction_scores.device) loss = self.loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) return EsmMaskedLMOutput( loss=loss, logits=prediction_scores, last_hidden_state=sequence_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions, s_max=outputs.s_max, ) class FastEsmForSequenceClassification(FastEsmPreTrainedModel, EmbeddingMixin): def __init__(self, config, **kwargs): FastEsmPreTrainedModel.__init__(self, config, **kwargs) self.num_labels = config.num_labels self.config = config self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False) self.classifier = EsmClassificationHead(config) self.mse = nn.MSELoss() self.ce = nn.CrossEntropyLoss() self.bce = nn.BCEWithLogitsLoss() self.post_init() def get_input_embeddings(self): return self.esm.embeddings.word_embeddings def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: return self.esm._embed(input_ids, attention_mask) def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: return self.esm.predict_contacts(input_ids, attention_mask=attention_mask) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_s_max: Optional[bool] = False, return_dict: Optional[bool] = None, **kwargs, ) -> EsmMaskedLMOutput: outputs = self.esm( input_ids, attention_mask=attention_mask, position_ids=position_ids, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_s_max=output_s_max, ) sequence_output = outputs.last_hidden_state logits = self.classifier(sequence_output) loss = None if labels is not None: labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": if self.num_labels == 1: loss = self.mse(logits.squeeze(), labels.squeeze()) else: loss = self.mse(logits, labels) elif self.config.problem_type == "single_label_classification": loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss = self.bce(logits, labels) return EsmMaskedLMOutput( loss=loss, logits=logits, last_hidden_state=sequence_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions, s_max=outputs.s_max, ) class FastEsmForTokenClassification(FastEsmPreTrainedModel, EmbeddingMixin): def __init__(self, config, **kwargs): FastEsmPreTrainedModel.__init__(self, config, **kwargs) self.num_labels = config.num_labels self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.loss_fct = nn.CrossEntropyLoss() self.post_init() def get_input_embeddings(self): return self.esm.embeddings.word_embeddings def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: return self.esm._embed(input_ids, attention_mask) def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: return self.esm.predict_contacts(input_ids, attention_mask=attention_mask) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_s_max: Optional[bool] = False, return_dict: Optional[bool] = None, **kwargs, ) -> EsmMaskedLMOutput: outputs = self.esm( input_ids, attention_mask=attention_mask, position_ids=position_ids, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_s_max=output_s_max, ) sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.classifier(sequence_output) loss = None if labels is not None: labels = labels.to(logits.device) loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) return EsmMaskedLMOutput( loss=loss, logits=logits, last_hidden_state=sequence_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions, s_max=outputs.s_max, ) if __name__ == "__main__": import random import torch from torch import Tensor from transformers import EsmTokenizer def print_tensor_shapes(prefix: str, obj): if isinstance(obj, Tensor): print(f"{prefix}{obj.shape}") elif isinstance(obj, dict): for name, value in obj.items(): print_tensor_shapes(f"{prefix}{name}.", value) elif isinstance(obj, list): for idx, value in enumerate(obj): print_tensor_shapes(f"{prefix}[{idx}].", value) elif isinstance(obj, tuple): for idx, value in enumerate(obj): print_tensor_shapes(f"{prefix}[{idx}].", value) elif hasattr(obj, "__dict__"): for name, value in vars(obj).items(): if name.startswith("_"): continue print_tensor_shapes(f"{prefix}{name}.", value) else: print(f"{prefix}{type(obj)}") random.seed(0) torch.manual_seed(0) tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D") num_attention_heads = random.choice([2, 4]) config = FastEsmConfig( vocab_size=tokenizer.vocab_size, hidden_size=16 * num_attention_heads, num_attention_heads=num_attention_heads, num_hidden_layers=random.choice([1, 2]), intermediate_size=64 * num_attention_heads, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, mask_token_id=tokenizer.mask_token_id, pad_token_id=tokenizer.pad_token_id, max_position_embeddings=256, emb_layer_norm_before=False, position_embedding_type="rotary", attn_backend="sdpa", ) batch = tokenizer(["ACDEFG", "MKTW"], return_tensors="pt", padding="longest") batch["labels"] = batch["input_ids"].clone() model = FastEsmForMaskedLM(config=config).eval() with torch.no_grad(): output = model(**batch, return_dict=True) print("Batch shape:") print_tensor_shapes("", batch) print("Output shape:") print_tensor_shapes("", output)