Load the model:

import torch
from transformers import AutoModel, AutoTokenizer

# Hub identifier of the checkpoint used throughout the examples.
model_name = "rnalm/144M_MS_MM_best"

# trust_remote_code=True allows transformers to run the tokenizer code
# hosted alongside the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Load the model (same trust_remote_code setting) and place it on the GPU
# in a single expression.
model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()

Inference without using the track prediction head:

# Turn the track head off so the forward pass does not need metadata.
model.model.predict_tracks = False
inputs = tokenizer("ACGTACGT", return_tensors="pt")

# Always add taxonomy information in the multispecies model.
# Plain truthiness check instead of comparing against True with `==`.
assert model.model.use_taxonomy

# Use human taxonomy.
# For a full list of taxonomies check 'rnalm/tokenizers/taxonomy_mappings/processed_taxonomy.json'
human_taxonomy = torch.tensor([2317, 2318, 2319, 2266, 2248, 2072, 2053, 1875])

# Inference only: no gradients are needed.
with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"].cuda(),
        masked_taxonomy=human_taxonomy.cuda(),
    )

last_hidden_state_w_taxonomy = outputs.last_hidden_state
# Drop index 0 along the second axis; going by the variable names, that
# position carries the taxonomy and the rest are the sequence positions.
last_hidden_state_wo_taxonomy = outputs.last_hidden_state[:, 1:, :]

# Expected shapes for the example input above:
last_hidden_state_w_taxonomy.shape
# torch.Size([1, 9, 768])

last_hidden_state_wo_taxonomy.shape
# torch.Size([1, 8, 768])

outputs.seq_logits.shape
# torch.Size([1, 8, 11])

Predict tracks using given metadata:

# Placeholder: replace the path with the location of your metadata tensor.
# NOTE(review): assumes the tensor was saved with torch.save — adjust the
# loading call if your metadata is stored in a different format.
metadata = torch.load("path/to/metadata.pt")

# Enable track prediction mode (the forward pass now takes `metadata`).
model.model.predict_tracks = True

# Forward pass; inference only, so gradients are disabled.
with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"].cuda(),
        metadata=metadata.cuda(),
        masked_taxonomy=human_taxonomy.cuda()
    )

# Track predictions produced by the track head.
outputs.track_yhat

Get metadata-dependent embeddings:

# Placeholder: replace the path with the location of your metadata tensor.
# NOTE(review): assumes the tensor was saved with torch.save — adjust the
# loading call if your metadata is stored in a different format.
metadata = torch.load("path/to/metadata.pt")

# Enable track prediction mode (the forward pass now takes `metadata`).
model.model.predict_tracks = True

# Forward pass; inference only, so gradients are disabled.
with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"].cuda(),
        metadata=metadata.cuda(),
        masked_taxonomy=human_taxonomy.cuda()
    )

# Hidden states from the track branch; expected shape for the 8-token
# example input:
outputs.last_hidden_state_track.shape
# torch.Size([1, 8, 768])
Downloads last month
3
Safetensors
Model size
0.1B params
Tensor type
F32
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support