import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel, AutoConfig
from transformers.modeling_outputs import MaskedLMOutput

from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn

try:
    from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None

from einops import rearrange

def convert_hf_config_to_mamba(hf_config) -> MambaConfig:
    """Map a Hugging Face config onto a mamba_ssm MambaConfig, filling gaps with defaults."""
    return MambaConfig(
        d_model=hf_config.d_model,
        d_intermediate=getattr(hf_config, "intermediate_size", 4 * hf_config.d_model),
        n_layer=getattr(hf_config, "n_layer", getattr(hf_config, "num_hidden_layers", 12)),
        vocab_size=hf_config.vocab_size,
        ssm_cfg=getattr(hf_config, "ssm_cfg", {}),
        attn_layer_idx=getattr(hf_config, "attn_layer_idx", []),
        attn_cfg=getattr(hf_config, "attn_cfg", {}),
        rms_norm=getattr(hf_config, "rms_norm", True),
        residual_in_fp32=getattr(hf_config, "residual_in_fp32", True),
        fused_add_norm=getattr(hf_config, "fused_add_norm", False),
        pad_vocab_size_multiple=getattr(hf_config, "pad_vocab_size_multiple", 8),
        tie_embeddings=getattr(hf_config, "tie_embeddings", False),
    )

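
# --- Illustrative only, not part of the original module: a minimal sketch of how an HF-style
# config with the expected attribute names flows through convert_hf_config_to_mamba. The
# SimpleNamespace and its values are assumptions chosen purely for this example.
#
#   from types import SimpleNamespace
#   hf_cfg = SimpleNamespace(d_model=256, n_layer=4, vocab_size=4096)
#   mamba_cfg = convert_hf_config_to_mamba(hf_cfg)
#   # -> MambaConfig(d_model=256, d_intermediate=1024, n_layer=4, vocab_size=4096, ...)
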

def patch_mixer_forward_to_accept_embeddings(model):
    """
    Injects a new forward method into a MixerModel instance,
    allowing it to accept either input_ids or inputs_embeds.
    """

    def new_forward(self, input_ids=None, inputs_embeds=None, inference_params=None,
                    attention_mask=None, **mixer_kwargs):
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        elif input_ids is not None:
            hidden_states = self.embedding(input_ids)
        else:
            raise ValueError("You must provide either input_ids or inputs_embeds.")

        residual = None

        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params, **mixer_kwargs
            )

        # Zero out padded positions before the final norm; skip when no mask is given.
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
            hidden_states = hidden_states * mask
            if residual is not None:
                residual = residual * mask

        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Fused residual-add + norm path (requires the Triton kernels imported above).
            hidden_states = layer_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm_f, RMSNorm),
            )
        return hidden_states

    model.backbone.forward = new_forward.__get__(model.backbone, model.backbone.__class__)

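
# --- Illustrative only, a hedged usage sketch: `hf_cfg`, `embeds` (batch, seq, d_model) and
# `mask` (batch, seq) are assumed stand-ins, not names defined above. Once patched, the
# backbone of a MambaLMHeadModel can be driven by embeddings instead of token ids:
#
#   lm = MambaLMHeadModel(convert_hf_config_to_mamba(hf_cfg))
#   patch_mixer_forward_to_accept_embeddings(lm)
#   hidden = lm.backbone(inputs_embeds=embeds, attention_mask=mask)  # (batch, seq, d_model)
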

class BiMambaForMaskedLM(PreTrainedModel):
    """
    Bidirectional Mamba for masked language modeling: one Mamba stack reads the sequence
    left-to-right, a second reads it right-to-left, and their outputs are concatenated,
    projected back to d_model, and decoded with a head tied to the input embeddings.
    """

    config_class = AutoConfig
    base_model_prefix = "bimamba"

    def __init__(self, config):
        super().__init__(config)
        mamba_cfg = convert_hf_config_to_mamba(config)

        # A shared token embedding feeds both directions; each direction has its own Mamba stack.
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_token_id)
        self.mamba_forward = MambaLMHeadModel(mamba_cfg)
        self.mamba_backward = MambaLMHeadModel(mamba_cfg)
        self.lm_head_proj = nn.Linear(config.d_model * 2, config.d_model, bias=False)

        # Let both backbones consume inputs_embeds directly.
        patch_mixer_forward_to_accept_embeddings(self.mamba_forward)
        patch_mixer_forward_to_accept_embeddings(self.mamba_backward)

    def get_input_embeddings(self):
        return self.token_embedding

    def set_input_embeddings(self, new_emb):
        self.token_embedding = new_emb

    def get_output_embeddings(self):
        # The LM head is weight-tied to token_embedding (see forward), so there is no separate
        # output embedding matrix; returning None keeps Hugging Face's resize/tie logic away
        # from lm_head_proj, which is a 2*d_model -> d_model projection, not an embedding.
        return None

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        # Flag every block in both backbones; this only takes effect if the block's forward
        # consults the flag and wraps itself in torch.utils.checkpoint.
        for backbone in (self.mamba_forward.backbone, self.mamba_backward.backbone):
            for block in backbone.layers:
                block.gradient_checkpointing = True

    def forward(
        self,
        input_ids=None,
        inputs_embeds=None,
        attention_mask=None,
        labels=None,
        return_dict=True,
    ):
        if inputs_embeds is None:
            input_ids = input_ids.long()
            inputs_embeds = self.token_embedding(input_ids)

        # Default to an all-ones mask so the flip below (and the masking inside the patched
        # backbones) works when no attention_mask is provided.
        if attention_mask is None:
            attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device)

        # Forward direction.
        hid_fwd = self.mamba_forward.backbone(inputs_embeds=inputs_embeds, attention_mask=attention_mask)

        # Backward direction: run the second stack on the reversed sequence, then flip back
        # so both directions are aligned per position.
        rev_emb = torch.flip(inputs_embeds, dims=[1])
        rev_mask = torch.flip(attention_mask, dims=[1])
        hid_bwd = self.mamba_backward.backbone(inputs_embeds=rev_emb, attention_mask=rev_mask)
        hid_bwd = torch.flip(hid_bwd, dims=[1])

        # Concatenate both directions, project back to d_model, and decode with a head
        # tied to the input embedding matrix.
        combined = torch.cat([hid_fwd, hid_bwd], dim=-1)
        projected = self.lm_head_proj(combined)
        logits = F.linear(projected, self.token_embedding.weight)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))

        if not return_dict:
            out = (logits, combined)
            return (loss,) + out if loss is not None else out

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=projected,
        )
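
# --- Illustrative smoke test, not part of the original module. It assumes a CUDA device with
# the mamba_ssm kernels installed (the selective-scan ops are GPU-only) and uses a throwaway
# PretrainedConfig whose sizes (d_model=64, n_layer=2, vocab_size=128) are arbitrary values
# chosen only for this example.
if __name__ == "__main__":
    from transformers import PretrainedConfig

    if torch.cuda.is_available():
        cfg = PretrainedConfig(d_model=64, n_layer=2, vocab_size=128, pad_token_id=0)
        model = BiMambaForMaskedLM(cfg).cuda()

        input_ids = torch.randint(1, 128, (2, 16), device="cuda")
        attention_mask = torch.ones_like(input_ids)
        labels = input_ids.clone()

        out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        # Expected: logits of shape (2, 16, 128) and a scalar masked-LM loss.
        print(out.logits.shape, out.loss.item())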