Commit 758a536 (parent: c97ecfa) — Create model.py
model.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
from typing import Tuple
|
| 6 |
+
|
| 7 |
+
from torchvision.models import resnet18, resnet50
|
| 8 |
+
from torchvision.models import ResNet18_Weights, ResNet50_Weights
|
| 9 |
+
|
| 10 |
+
class DistMult(nn.Module):
    """DistMult-style scoring model over multi-modal head entities.

    Heads may be entity uids, GPS locations, timestamp feature vectors, or
    images; each modality is encoded into a shared embedding space and scored
    against candidate tails with the DistMult bilinear-diagonal product
    (h * r) . t.
    """

    def __init__(self, args, num_ent_uid, target_list, device, all_locs=None,
                 num_habitat=None, all_timestamps=None):
        """
        Args:
            args: config namespace; reads embedding_dim, location_input_dim,
                time_input_dim, mlp_location_numlayer, mlp_time_numlayer,
                img_embed_model.
            num_ent_uid: size of the entity-uid vocabulary.
            target_list: tensor of candidate target entity ids used for
                ('image', 'id') scoring.
            device: device that all_locs / all_timestamps are moved to.
            all_locs: optional [n_loc, location_input_dim] tensor of every
                candidate location.
            num_habitat: unused; kept for backward compatibility with callers.
            all_timestamps: optional [n_time, time_input_dim] tensor of every
                candidate timestamp.
        """
        super().__init__()
        self.args = args
        self.num_ent_uid = num_ent_uid

        # Fixed relation vocabulary (4 relation types scored by forward_ce).
        self.num_relations = 4

        self.ent_embedding = torch.nn.Embedding(self.num_ent_uid, args.embedding_dim, sparse=False)
        self.rel_embedding = torch.nn.Embedding(self.num_relations, args.embedding_dim, sparse=False)

        # Modality encoders mapping raw features into the shared embedding space.
        self.location_embedding = MLP(args.location_input_dim, args.embedding_dim, args.mlp_location_numlayer)
        self.time_embedding = MLP(args.time_input_dim, args.embedding_dim, args.mlp_time_numlayer)

        # ImageNet-pretrained backbone; the classifier head is replaced so the
        # network emits embedding_dim-sized vectors instead of class logits.
        if self.args.img_embed_model == 'resnet50':
            self.image_embedding = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
            self.image_embedding.fc = nn.Linear(2048, args.embedding_dim)
        else:
            self.image_embedding = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
            self.image_embedding.fc = nn.Linear(512, args.embedding_dim)

        self.target_list = target_list

        if all_locs is not None:
            self.all_locs = all_locs.to(device)
        if all_timestamps is not None:
            self.all_timestamps = all_timestamps.to(device)

        # NOTE: the original assigned self.args twice; the duplicate is removed.
        self.device = device

        self.init()

    def init(self):
        """Xavier-initialize the embedding tables and the new image head."""
        nn.init.xavier_uniform_(self.ent_embedding.weight.data)
        nn.init.xavier_uniform_(self.rel_embedding.weight.data)
        nn.init.xavier_uniform_(self.image_embedding.fc.weight.data)

    def forward_ce(self, h, r, triple_type=None):
        """Score a batch of (head, relation) pairs against all candidate tails.

        Args:
            h: batch of heads; modality inferred by batch_embedding_concat_h.
            r: [batch] or [batch, 1] relation ids.
            triple_type: (head_modality, tail_modality) tuple selecting the
                candidate-tail set.

        Returns:
            [batch, n_candidates] score matrix.

        Raises:
            NotImplementedError: for an unsupported triple_type.
        """
        emb_h = self.batch_embedding_concat_h(h)   # [batch, hid]
        emb_r = self.rel_embedding(r.squeeze(-1))  # [batch, hid]
        emb_hr = emb_h * emb_r                     # [batch, hid]

        if triple_type == ('image', 'id'):
            # Restrict candidates to the known target subset.
            score = torch.mm(emb_hr, self.ent_embedding.weight[self.target_list.squeeze(-1)].T)
        elif triple_type == ('id', 'id'):
            score = torch.mm(emb_hr, self.ent_embedding.weight.T)  # [batch, n_ent]
        elif triple_type == ('image', 'location'):
            loc_emb = self.location_embedding(self.all_locs)  # recomputed per batch
            score = torch.mm(emb_hr, loc_emb.T)
        elif triple_type == ('image', 'time'):
            time_emb = self.time_embedding(self.all_timestamps)
            score = torch.mm(emb_hr, time_emb.T)
        else:
            raise NotImplementedError(f'unsupported triple_type: {triple_type!r}')

        return score

    def batch_embedding_concat_h(self, e1):
        """Embed a batch of heads, dispatching on the input's second dimension.

        Width 1 (or a 1-D tensor) -> entity uid; 15 -> time feature vector;
        2 -> GPS (lat, lon); 3 -> image (channels-first).

        Raises:
            ValueError: if the modality cannot be inferred from e1's shape.
        """
        if len(e1.size()) == 1 or e1.size(1) == 1:  # entity uid
            return self.ent_embedding(e1.squeeze(-1))
        if e1.size(1) == 15:  # time feature vector
            return self.time_embedding(e1)
        if e1.size(1) == 2:  # GPS (lat, lon)
            return self.location_embedding(e1)
        if e1.size(1) == 3:  # image tensor, presumably [batch, 3, H, W]
            return self.image_embedding(e1)
        # The original silently returned None here, which surfaced later as a
        # cryptic failure in the caller; fail fast instead.
        raise ValueError(f'cannot infer head modality from input of shape {tuple(e1.size())}')
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class MLP(nn.Module):
    """Feed-forward encoder whose width shrinks linearly from input to output.

    Each stage is Linear -> Dropout -> PReLU; the stage widths interpolate
    between input_dim and output_dim in num_layers equal steps, with the final
    stage emitting exactly output_dim features.
    """

    def __init__(self,
                 input_dim,
                 output_dim,
                 num_layers=3,
                 p_dropout=0.0,
                 bias=True):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.p_dropout = p_dropout

        # Widths shrink toward output_dim; the last entry is output_dim itself.
        shrink = (input_dim - output_dim) // num_layers
        widths = [output_dim + i * shrink for i in range(num_layers - 1, -1, -1)]

        stages = []
        in_features = input_dim
        for width in widths:
            stages.append(nn.Linear(in_features, width, bias))
            stages.append(nn.Dropout(p=self.p_dropout, inplace=True))
            stages.append(nn.PReLU())
            in_features = width

        self.mlp = nn.Sequential(*stages)

        # Randomize all parameters (weights, biases, PReLU slopes).
        self.init()

    def forward(self, x):
        return self.mlp(x)

    def init(self):
        """Uniformly re-initialize every parameter of the network."""
        for param in self.parameters():
            nn.init.uniform_(param)
|