Spaces · Build error

Commit 77c8482 (parent: 6692ae2): Upload 17 files

Files changed:
- .gitattributes +4 -0
- app.py +209 -0
- assets/attn_plot.png +3 -0
- assets/examples.png +3 -0
- assets/model_transformer.png +0 -0
- checkpoints/RATCHET.tf/keras_metadata.pb +3 -0
- checkpoints/RATCHET.tf/saved_model.pb +3 -0
- checkpoints/RATCHET.tf/variables/variables.data-00000-of-00001 +3 -0
- checkpoints/RATCHET.tf/variables/variables.index +0 -0
- checkpoints/cxr_validator_model.tf/fingerprint.pb +3 -0
- checkpoints/cxr_validator_model.tf/keras_metadata.pb +3 -0
- checkpoints/cxr_validator_model.tf/saved_model.pb +3 -0
- checkpoints/cxr_validator_model.tf/variables/variables.data-00000-of-00001 +3 -0
- checkpoints/cxr_validator_model.tf/variables/variables.index +0 -0
- mimic/mimic-merges.txt +0 -0
- mimic/mimic-vocab.json +0 -0
- requirements.txt +6 -0
- transformer.py +263 -0
.gitattributes
CHANGED

@@ -32,3 +32,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/attn_plot.png filter=lfs diff=lfs merge=lfs -text
+assets/examples.png filter=lfs diff=lfs merge=lfs -text
+checkpoints/cxr_validator_model.tf/variables/variables.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
+checkpoints/RATCHET.tf/variables/variables.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
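Because of these patterns, the large binaries in this commit are stored as Git LFS pointer files; each ADDED hunk below therefore shows a three-line pointer (version, oid, size) rather than the binary content itself.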
app.py
ADDED
@@ -0,0 +1,209 @@
import tqdm
import datetime

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import tensorflow as tf

from skimage import io
from transformer import Transformer, default_hparams
from tokenizers import ByteLevelBPETokenizer


@st.cache_resource
def load_validator():
    validator_model = tf.keras.models.load_model('checkpoints/cxr_validator_model.tf')
    print('Validator Model Loaded!')
    return validator_model


@st.cache_resource
def load_model():

    # Load tokenizer
    tokenizer = ByteLevelBPETokenizer(
        'mimic/mimic-vocab.json',
        'mimic/mimic-merges.txt',
    )

    # Load model
    hparams = default_hparams()
    transformer = Transformer(
        num_layers=hparams['num_layers'],
        d_model=hparams['d_model'],
        num_heads=hparams['num_heads'],
        dff=hparams['dff'],
        target_vocab_size=tokenizer.get_vocab_size(),
        dropout_rate=hparams['dropout_rate'])
    transformer.load_weights('checkpoints/RATCHET.tf')
    print('Model Loaded! Checkpoint file: checkpoints/RATCHET.tf')

    return transformer, tokenizer


def top_k_logits(logits, k):
    if k == 0:
        # no truncation
        return logits

    def _top_k():
        values, _ = tf.nn.top_k(logits, k=k)
        min_values = values[:, -1, tf.newaxis]
        return tf.where(
            logits < min_values,
            tf.ones_like(logits, dtype=logits.dtype) * -1e10,
            logits,
        )

    return tf.cond(
        tf.equal(k, 0),
        lambda: logits,
        lambda: _top_k(),
    )


def top_p_logits(logits, p):
    """Nucleus sampling"""
    batch, _ = logits.shape.as_list()
    sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
    cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
    indices = tf.stack([
        tf.range(0, batch),
        # number of indices to include
        tf.maximum(tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0),
    ], axis=-1)
    min_values = tf.gather_nd(sorted_logits, indices)
    return tf.where(
        logits < min_values,
        tf.ones_like(logits) * -1e10,
        logits,
    )


def evaluate(inp_img, tokenizer, transformer, temperature, top_k, top_p, options, seed, MAX_LENGTH=128):

    # The first token fed to the transformer is the start token
    output = tf.convert_to_tensor([[tokenizer.token_to_id('<s>')]])

    my_bar = st.progress(0)
    for i in tqdm.tqdm(range(MAX_LENGTH)):
        my_bar.progress(i / MAX_LENGTH)

        # predictions.shape == (batch_size, seq_len, vocab_size)
        predictions = transformer([inp_img, output], training=False)

        # Select the last token from the seq_len dimension
        predictions = predictions[:, -1, :] / temperature  # (batch_size, vocab_size)
        predictions = top_k_logits(predictions, k=top_k)
        predictions = top_p_logits(predictions, p=top_p)

        if options == 'Greedy':
            predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)[:, tf.newaxis]
        elif options == 'Sampling':
            predicted_id = tf.random.categorical(predictions, num_samples=1, dtype=tf.int32, seed=seed)
        else:
            st.write('SHOULD NOT HAPPEN')

        # Stop and return the result if predicted_id is the </s> end token (id 2)
        if predicted_id == 2:
            my_bar.empty()
            break

        # Concatenate predicted_id to the output, which is fed back
        # to the decoder as its input.
        output = tf.concat([output, predicted_id], axis=-1)

    my_bar.empty()

    # Drop the start token; also return the cached cross-attention scores for plotting
    return tf.squeeze(output, axis=0)[1:], transformer.decoder.last_attn_scores


def main():

    st.title('Chest X-ray AI Diagnosis Demo')
    st.text('Made with Streamlit and a Transformer')

    transformer, tokenizer = load_model()
    cxr_validator_model = load_validator()

    st.sidebar.title('Configuration')
    options = st.sidebar.selectbox('Generation Method', ('Greedy', 'Sampling'))
    seed = st.sidebar.number_input('Sampling Seed:', value=42)
    temperature = st.sidebar.number_input('Temperature', value=1.)
    top_k = st.sidebar.slider('top_k', min_value=0, max_value=tokenizer.get_vocab_size(), value=6, step=1)
    top_p = st.sidebar.slider('top_p', min_value=0., max_value=1., value=1., step=0.01)
    attention_head = st.sidebar.slider('attention_head', min_value=-1, max_value=7, value=-1, step=1)

    st.sidebar.info('PRIVACY POLICY: Uploaded images are never stored on disk.')

    # Legacy Streamlit option; recent versions may ignore or reject it
    st.set_option('deprecation.showfileUploaderEncoding', False)
    uploaded_file = st.file_uploader('Choose an image...', type=('png', 'jpg', 'jpeg'))

    if uploaded_file:

        # Read input image with shape [1, H, W, 1] and range (0, 255)
        img_array = io.imread(uploaded_file, as_gray=True)[None, ..., None]

        # Convert image to float values in (0, 1)
        img_array = tf.image.convert_image_dtype(img_array, tf.float32)

        # Resize image with padding to [1, 224, 224, 1]
        img_array = tf.image.resize_with_pad(img_array, 224, 224, method=tf.image.ResizeMethod.BILINEAR)

        # Display input image
        st.image(np.squeeze(img_array.numpy()), caption='Uploaded Image')

        # Check that the image is a chest X-ray
        valid = tf.nn.sigmoid(cxr_validator_model(img_array))
        if valid < 0.1:
            st.info('Image is not a Chest X-ray')
            return

        # Log datetime
        print('[{}] Running Analysis...'
              .format(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))

        # Generate radiology report
        with st.spinner('Generating report... Do not refresh or close window.'):
            result, attention_weights = evaluate(img_array, tokenizer, transformer,
                                                 temperature, top_k, top_p,
                                                 options, seed)
            predicted_sentence = tokenizer.decode(result)

        # Display generated text
        st.subheader('Generated Report:')
        st.write(predicted_sentence)

        st.subheader('Attention Plot:')

        attn_map = attention_weights[0]  # squeeze the batch dimension
        if attention_head == -1:  # average attention heads
            attn_map = tf.reduce_mean(attn_map, axis=0)
        else:  # select a single attention head
            attn_map = attn_map[attention_head]
        attn_map = attn_map / attn_map.numpy().max() * 255

        fig = plt.figure(figsize=(40, 80))

        for i in range(attn_map.shape[0] - 1):
            attn_token = attn_map[i, ...]
            attn_token = tf.reshape(attn_token, [7, 7])

            ax = fig.add_subplot(16, 8, i + 1)
            ax.set_title(tokenizer.decode([result.numpy()[i]]))
            img = ax.imshow(np.squeeze(img_array))
            ax.imshow(attn_token, cmap='gray', alpha=0.6, extent=img.get_extent())

        st.pyplot(plt)

        # Run again?
        st.button('Regenerate Report')


if __name__ == '__main__':

    tf.config.set_visible_devices([], 'GPU')

    main()
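The two logit filters above are easy to sanity-check outside Streamlit. A minimal sketch, assuming it is run from the repo root so that app is importable (the example values are made up, not taken from the Space):

import tensorflow as tf

from app import top_k_logits, top_p_logits

logits = tf.constant([[1.0, 3.0, 0.5, 2.0]])  # (batch=1, vocab=4)

# top-k keeps the k largest logits and pushes the rest to -1e10
print(top_k_logits(logits, k=2).numpy())    # [[-1e10, 3.0, -1e10, 2.0]]

# top-p keeps the shortest prefix of the probability-sorted tokens; the top
# token alone carries ~0.63 of the mass here, so only it survives p=0.7
print(top_p_logits(logits, p=0.7).numpy())  # [[-1e10, 3.0, -1e10, -1e10]]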
assets/attn_plot.png
ADDED (Git LFS)

assets/examples.png
ADDED (Git LFS)

assets/model_transformer.png
ADDED

checkpoints/RATCHET.tf/keras_metadata.pb
ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a8fa018ac83d10617e20e3f03de3718d9d3d6e1b89673707cb510318fd3198b3
size 1065144

checkpoints/RATCHET.tf/saved_model.pb
ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:84e9d837b881c58edee113c7bbdc793159e6e57c2ddcf9d2a3e4da7c5104a7db
size 26013311

checkpoints/RATCHET.tf/variables/variables.data-00000-of-00001
ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ae18face6fa821f8c6c62923ef5533fca681e01b6bb8ae511a9c94844f618c8e
size 1669994429

checkpoints/RATCHET.tf/variables/variables.index
ADDED

Binary file (121 kB)

checkpoints/cxr_validator_model.tf/fingerprint.pb
ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:21a31ac72a46d124de283ecbd75c35efc8ac0c5f597efd3040ed8dd00d071ef2
size 53

checkpoints/cxr_validator_model.tf/keras_metadata.pb
ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:19106ee698a03e8b9ec11b0092fd65c32654380171a3c55a7976d56313e4438a
size 2538679

checkpoints/cxr_validator_model.tf/saved_model.pb
ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:16e7434007981626733e6f925cd0b226e1f4130cfaec7e79ba81ffd16d7ab1cb
size 14320368

checkpoints/cxr_validator_model.tf/variables/variables.data-00000-of-00001
ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e2edd5cef46c1624f31464e13f3b5fb8c0ceb4ce8a1d834a6cde9c2e71dd509e
size 224256098

checkpoints/cxr_validator_model.tf/variables/variables.index
ADDED

Binary file (51.9 kB)
mimic/mimic-merges.txt
ADDED

The diff for this file is too large to render.

mimic/mimic-vocab.json
ADDED

The diff for this file is too large to render.
requirements.txt
ADDED

@@ -0,0 +1,6 @@
matplotlib
numpy
scikit-image
tensorflow
tokenizers
tqdm
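The list is unpinned, and streamlit itself is absent (Spaces built with the Streamlit SDK provide it, so the app still builds there). Two minimum versions are implied by the code in this commit: st.cache_resource first appeared in Streamlit 1.18, and use_causal_mask=True on MultiHeadAttention (used in transformer.py) first appeared in TensorFlow 2.10. A pinned sketch (these exact pins are an assumption, not taken from the Space):

matplotlib
numpy
scikit-image
streamlit>=1.18
tensorflow>=2.10
tokenizers
tqdm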
transformer.py
ADDED
@@ -0,0 +1,263 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import datetime

import numpy as np
import tensorflow as tf


def default_hparams():
    return {
        'img_x': 224,
        'img_y': 224,
        'img_ch': 1,
        'd_model': 512,
        'dff': 2048,
        'num_heads': 8,
        'num_layers': 6,
        'dropout_rate': 0.1
    }


def positional_encoding(length, depth):
    depth = depth / 2

    positions = np.arange(length)[:, np.newaxis]      # (seq, 1)
    depths = np.arange(depth)[np.newaxis, :] / depth  # (1, depth)

    angle_rates = 1 / (10000 ** depths)               # (1, depth)
    angle_rads = positions * angle_rates              # (pos, depth)

    pos_encoding = np.concatenate(
        [np.sin(angle_rads), np.cos(angle_rads)],
        axis=-1)

    return tf.cast(pos_encoding, dtype=tf.float32)


class PositionalEmbedding(tf.keras.layers.Layer):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.d_model = d_model
        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True)
        self.pos_encoding = positional_encoding(length=2048, depth=d_model)

    def compute_mask(self, *args, **kwargs):
        return self.embedding.compute_mask(*args, **kwargs)

    def call(self, x):
        length = tf.shape(x)[1]
        x = self.embedding(x)
        # This factor sets the relative scale of the embedding and positional_encoding.
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x = x + self.pos_encoding[tf.newaxis, :length, :]
        return x


class BaseAttention(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
        self.layernorm = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()


class CrossAttention(BaseAttention):
    def call(self, x, context):
        attn_output, attn_scores = self.mha(
            query=x,
            key=context,
            value=context,
            return_attention_scores=True)

        # Cache the attention scores for plotting later.
        self.last_attn_scores = attn_scores

        x = self.add([x, attn_output])
        x = self.layernorm(x)

        return x


class CausalSelfAttention(BaseAttention):
    def call(self, x):
        attn_output = self.mha(
            query=x,
            value=x,
            key=x,
            use_causal_mask=True)
        x = self.add([x, attn_output])
        x = self.layernorm(x)
        return x


class FeedForward(tf.keras.layers.Layer):
    def __init__(self, d_model, dff, dropout_rate=0.1):
        super().__init__()
        self.seq = tf.keras.Sequential([
            tf.keras.layers.Dense(dff, activation='relu'),
            tf.keras.layers.Dense(d_model),
            tf.keras.layers.Dropout(dropout_rate)
        ])
        self.add = tf.keras.layers.Add()
        self.layer_norm = tf.keras.layers.LayerNormalization()

    def call(self, x):
        x = self.add([x, self.seq(x)])
        x = self.layer_norm(x)
        return x


class DecoderLayer(tf.keras.layers.Layer):
    def __init__(self,
                 *,
                 d_model,
                 num_heads,
                 dff,
                 dropout_rate=0.1):
        super(DecoderLayer, self).__init__()

        self.causal_self_attention = CausalSelfAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate)

        self.cross_attention = CrossAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate)

        self.ffn = FeedForward(d_model, dff)

    def call(self, x, context):
        x = self.causal_self_attention(x=x)
        x = self.cross_attention(x=x, context=context)

        # Cache the last attention scores for plotting later
        self.last_attn_scores = self.cross_attention.last_attn_scores

        x = self.ffn(x)  # Shape `(batch_size, seq_len, d_model)`.
        return x


class Encoder(tf.keras.layers.Layer):
    def __init__(self, embedding_dim, input_shape, pretrain_weights=None):
        super(Encoder, self).__init__()

        # shape after fc == (batch_size, nf * nf, embedding_dim)
        self.fc = tf.keras.layers.Dense(embedding_dim, activation='relu')

        # Use DenseNet-121 as the feature extraction model
        self.base_model = tf.keras.applications.DenseNet121(
            include_top=False, weights=None, input_shape=input_shape)

        # Load pre-trained weights if present
        if pretrain_weights:
            print(f'{datetime.datetime.now()}: I Loading Pretrained DenseNet-121 weights: {pretrain_weights}')
            self.base_model.load_weights(pretrain_weights)
        else:
            print(f'{datetime.datetime.now()}: I No Pretrained DenseNet-121 weights specified')

    def call(self, x, **kwargs):
        x = self.base_model(x)
        # DenseNet-121 output is (batch_size, ?, ?, 1024); flatten the spatial grid
        s = tf.shape(x)
        x = tf.reshape(x, (s[0], s[1] * s[2], x.shape[3]))
        x = self.fc(x)
        return x


class Decoder(tf.keras.layers.Layer):
    def __init__(self, *, num_layers, d_model, num_heads, dff, vocab_size,
                 dropout_rate=0.1):
        super(Decoder, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size,
                                                 d_model=d_model)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.dec_layers = [
            DecoderLayer(d_model=d_model, num_heads=num_heads,
                         dff=dff, dropout_rate=dropout_rate)
            for _ in range(num_layers)]

        self.last_attn_scores = None

    def call(self, x, context):
        # `x` is token-IDs shape (batch, target_seq_len)
        x = self.pos_embedding(x)  # (batch_size, target_seq_len, d_model)

        x = self.dropout(x)

        for i in range(self.num_layers):
            x = self.dec_layers[i](x, context)

        self.last_attn_scores = self.dec_layers[-1].last_attn_scores

        # The shape of x is (batch_size, target_seq_len, d_model).
        return x


class Transformer(tf.keras.Model):
    def __init__(self, num_layers, d_model, num_heads, dff,
                 target_vocab_size, dropout_rate=0.1, input_shape=(224, 224, 1),
                 classifier_weights=None):
        super(Transformer, self).__init__()

        self.encoder = Encoder(d_model, input_shape,
                               pretrain_weights=classifier_weights)

        self.decoder = Decoder(num_layers=num_layers, d_model=d_model,
                               num_heads=num_heads, dff=dff,
                               vocab_size=target_vocab_size,
                               dropout_rate=dropout_rate)

        self.final_layer = tf.keras.layers.Dense(target_vocab_size)

    def call(self, inputs):
        # To use a Keras model with `.fit` you must pass all your inputs in the
        # first argument.
        context, x = inputs

        context = self.encoder(context)  # (batch_size, context_len, d_model)

        x = self.decoder(x, context)  # (batch_size, target_len, d_model)

        # Final linear layer output.
        logits = self.final_layer(x)  # (batch_size, target_len, target_vocab_size)

        try:
            # Drop the keras mask, so it doesn't scale the losses/metrics.
            # b/250038731
            del logits._keras_mask
        except AttributeError:
            pass

        # Return the final output logits.
        return logits


if __name__ == "__main__":

    # Smoke test: build the model and run a random image/text pair through it.
    hparams = default_hparams()

    transformer = Transformer(
        num_layers=hparams['num_layers'],
        d_model=hparams['d_model'],
        num_heads=hparams['num_heads'],
        dff=hparams['dff'],
        target_vocab_size=2048,
        dropout_rate=hparams['dropout_rate'])

    image = np.random.rand(1, 224, 224, 1).astype('float32')
    text = np.random.randint(0, 2048, size=(1, 27))

    output = transformer((image, text))
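For orientation, the smoke test at the end of transformer.py exercises the full pipeline: DenseNet-121 reduces the 224x224 image to a 7x7 grid (49 context tokens, which is why app.py reshapes attention maps to [7, 7]), and the decoder emits one row of vocabulary logits per input token. A couple of shape checks one could append to the smoke test (a sketch; these lines are not part of the commit):

print(output.shape)
# (1, 27, 2048): (batch, target_len, target_vocab_size)

print(transformer.decoder.last_attn_scores.shape)
# (1, 8, 27, 49): (batch, num_heads, target_len, context_len)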