Load the model:
import torch
from transformers import AutoModel, AutoTokenizer

# Hugging Face Hub repository id of the multispecies checkpoint.
model_name = "rnalm/446M_MS_MM_best"

# trust_remote_code=True lets transformers run the custom tokenizer/model
# code hosted in the repository.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)

# Load the weights, then move the whole model onto the default CUDA device
# (all snippets below feed it CUDA tensors).
model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
).cuda()
Inference without using the track prediction head:
# Disable the track head so the forward pass does not need the `metadata`
# input.
model.model.predict_tracks = False

inputs = tokenizer("ACGTACGT", return_tensors="pt")

# The multispecies model must always be given taxonomy information.
assert model.model.use_taxonomy, "multispecies model expects use_taxonomy=True"

# Human taxonomy ids. For the full list of taxonomies see
# 'rnalm/tokenizers/taxonomy_mappings/processed_taxonomy.json'.
human_taxonomy = torch.tensor([2317, 2318, 2319, 2266, 2248, 2072, 2053, 1875])

with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"].cuda(),
        masked_taxonomy=human_taxonomy.cuda(),
    )

# Per the shapes below, the output carries one extra leading position
# (the taxonomy slot) in front of the 8 nucleotide positions; drop it along
# dim 1 to keep only the per-nucleotide embeddings.
last_hidden_state_w_taxonomy = outputs.last_hidden_state
last_hidden_state_wo_taxonomy = outputs.last_hidden_state[:, 1:, :]

print(last_hidden_state_w_taxonomy.shape)
# torch.Size([1, 9, 1024])
print(last_hidden_state_wo_taxonomy.shape)
# torch.Size([1, 8, 1024])
print(outputs.seq_logits.shape)
# torch.Size([1, 8, 11])
Predict tracks using the given metadata:
# Path to your metadata tensor (placeholder; replace with a real file).
# NOTE(review): assumes the metadata was saved with torch.save as a single
# tensor; confirm the expected shape/dtype against the model's code.
metadata_path = "path/to/metadata.pt"
# Load onto CPU first so the file opens regardless of the device it was
# saved from; it is moved to the GPU in the forward call below.
metadata = torch.load(metadata_path, map_location="cpu")

# Enable track prediction mode (the track head consumes `metadata`).
model.model.predict_tracks = True

# Forward pass
with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"].cuda(),
        metadata=metadata.cuda(),
        masked_taxonomy=human_taxonomy.cuda(),
    )

# Track predictions for the input sequence.
print(outputs.track_yhat)
Get metadata-dependent embeddings:
# Path to your metadata tensor (placeholder; replace with a real file).
# NOTE(review): assumes the metadata was saved with torch.save as a single
# tensor; confirm the expected shape/dtype against the model's code.
metadata_path = "path/to/metadata.pt"
# Load onto CPU first so the file opens regardless of the device it was
# saved from; it is moved to the GPU in the forward call below.
metadata = torch.load(metadata_path, map_location="cpu")

# Enable track prediction mode; the metadata-dependent hidden state is read
# from the same forward pass that uses the track head.
model.model.predict_tracks = True

# Forward pass
with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"].cuda(),
        metadata=metadata.cuda(),
        masked_taxonomy=human_taxonomy.cuda(),
    )

print(outputs.last_hidden_state_track.shape)
# torch.Size([1, 8, 1024])
- Downloads last month: 4
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
🙋 Ask for provider support