Spaces:
Running
Running
| import os | |
| import glob | |
| import torch | |
| import torch.jit | |
| import torch.nn as nn | |
class Model(torch.jit.ScriptModule):
    """Convolutional network with one length head and five per-position digit heads.

    A batch of 3-channel images passes through eight convolutional blocks
    (``_hidden1`` .. ``_hidden8``) and two fully connected layers
    (``_hidden9``, ``_hidden10``). Six independent linear heads are then applied:

    * ``_digit_length`` -> 7 logits (presumably sequence length classes;
      confirm against the training code),
    * ``_digit1`` .. ``_digit5`` -> 11 logits each (presumably 10 digits plus
      a "no digit" class; confirm against the training code).

    Checkpoints are stored as ``CHECKPOINT_FILENAME_PATTERN.format(step)``
    inside a directory; see ``store`` / ``restore``.
    """

    CHECKPOINT_FILENAME_PATTERN = "model-{}.pth"
    # NOTE(review): every name listed here is a submodule, not a constant.
    # Newer PyTorch versions may warn that such entries are non-constant and
    # skip them; confirm whether this list is still needed before removing it.
    __constants__ = [
        "_hidden1",
        "_hidden2",
        "_hidden3",
        "_hidden4",
        "_hidden5",
        "_hidden6",
        "_hidden7",
        "_hidden8",
        "_hidden9",
        "_hidden10",
        "_features",
        "_classifier",
        "_digit_length",
        "_digit1",
        "_digit2",
        "_digit3",
        "_digit4",
        "_digit5",
    ]

    def __init__(self):
        super(Model, self).__init__()

        def conv_block(in_channels, out_channels, pool_stride):
            # Conv 5x5 (padding 2 keeps H/W) -> BatchNorm -> ReLU ->
            # MaxPool 2x2 (padding 1) -> Dropout(0.2).
            # Keep this layer order: nn.Sequential names children by index, so
            # state_dict keys such as "_hidden1.0.weight" depend on it.
            return nn.Sequential(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=5,
                    padding=2,
                ),
                nn.BatchNorm2d(num_features=out_channels),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=pool_stride, padding=1),
                nn.Dropout(0.2),
            )

        # Each pool maps a side of length s to floor(s / stride) + 1. For a
        # 54x54 input the side length goes
        # 54 -> 28 -> 29 -> 15 -> 16 -> 9 -> 10 -> 6 -> 7,
        # which matches the 192 * 7 * 7 flatten in ``forward``.
        self._hidden1 = conv_block(3, 48, pool_stride=2)
        self._hidden2 = conv_block(48, 64, pool_stride=1)
        self._hidden3 = conv_block(64, 128, pool_stride=2)
        self._hidden4 = conv_block(128, 160, pool_stride=1)
        self._hidden5 = conv_block(160, 192, pool_stride=2)
        self._hidden6 = conv_block(192, 192, pool_stride=1)
        self._hidden7 = conv_block(192, 192, pool_stride=2)
        self._hidden8 = conv_block(192, 192, pool_stride=1)
        self._hidden9 = nn.Sequential(nn.Linear(192 * 7 * 7, 3072), nn.ReLU())
        self._hidden10 = nn.Sequential(nn.Linear(3072, 3072), nn.ReLU())
        self._digit_length = nn.Sequential(nn.Linear(3072, 7))
        self._digit1 = nn.Sequential(nn.Linear(3072, 11))
        self._digit2 = nn.Sequential(nn.Linear(3072, 11))
        self._digit3 = nn.Sequential(nn.Linear(3072, 11))
        self._digit4 = nn.Sequential(nn.Linear(3072, 11))
        self._digit5 = nn.Sequential(nn.Linear(3072, 11))

    def forward(self, x):
        """Run the network and return the six logit tensors.

        Args:
            x: Batched image tensor ``(batch, 3, H, W)``. H and W must shrink
                to 7 after the eight conv blocks (e.g. 54x54); otherwise the
                ``view`` below raises.

        Returns:
            ``(length_logits, digit1_logits, ..., digit5_logits)``: the first
            has shape ``(batch, 7)``, the other five ``(batch, 11)``.
        """
        x = self._hidden1(x)
        x = self._hidden2(x)
        x = self._hidden3(x)
        x = self._hidden4(x)
        x = self._hidden5(x)
        x = self._hidden6(x)
        x = self._hidden7(x)
        x = self._hidden8(x)
        # Flatten the (batch, 192, 7, 7) feature map for the first Linear layer.
        x = x.view(x.size(0), 192 * 7 * 7)
        x = self._hidden9(x)
        x = self._hidden10(x)
        length_logits = self._digit_length(x)
        digit1_logits = self._digit1(x)
        digit2_logits = self._digit2(x)
        digit3_logits = self._digit3(x)
        digit4_logits = self._digit4(x)
        digit5_logits = self._digit5(x)
        return (
            length_logits,
            digit1_logits,
            digit2_logits,
            digit3_logits,
            digit4_logits,
            digit5_logits,
        )

    @staticmethod
    def _step_from_path(path_to_checkpoint_file):
        """Return the integer step encoded in a checkpoint file name.

        ``".../model-123.pth"`` -> ``123``. The directory part is dropped with
        ``os.path.basename`` rather than by splitting on a hard-coded
        separator. The extension is dropped with ``os.path.splitext``. The
        step is the text after the last ``"model-"`` in what remains.

        Raises:
            ValueError: if that text is not an integer.
        """
        prefix = Model.CHECKPOINT_FILENAME_PATTERN.partition("{}")[0]
        stem = os.path.splitext(os.path.basename(path_to_checkpoint_file))[0]
        return int(stem.rpartition(prefix)[2])

    def store(self, path_to_dir, step, maximum=5):
        """Save ``self.state_dict()`` as ``model-<step>.pth`` in ``path_to_dir``.

        Before saving, the lowest-step existing checkpoints are deleted so
        that at most ``maximum`` numbered checkpoints, including the new one,
        remain. Files that match the pattern but carry no integer step
        (e.g. ``model-best.pth``) are neither counted nor deleted.

        Args:
            path_to_dir: Existing directory for the checkpoints (not created here).
            step: Step number written into the filename.
            maximum: Number of numbered checkpoints to keep, including the new
                one. Values below 1 behave like 1.

        Returns:
            Path of the checkpoint file that was written.
        """
        # Escape the directory so "[", "*" or "?" in its name match literally.
        pattern = os.path.join(
            glob.escape(path_to_dir), Model.CHECKPOINT_FILENAME_PATTERN.format("*")
        )
        numbered = []
        for path_to_model in glob.glob(pattern):
            try:
                numbered.append((Model._step_from_path(path_to_model), path_to_model))
            except ValueError:
                # Matches the glob but is not a numbered checkpoint; leave it alone.
                continue
        numbered.sort()

        # Remove every surplus file, not just one. If more than ``maximum``
        # checkpoints already exist (older runs, smaller ``maximum``), an
        # equality test would never prune. Delete the matched path itself
        # rather than a name rebuilt from the parsed int.
        surplus = len(numbered) - maximum + 1
        for _, path_to_old_model in numbered[:max(surplus, 0)]:
            os.remove(path_to_old_model)

        path_to_checkpoint_file = os.path.join(
            path_to_dir, Model.CHECKPOINT_FILENAME_PATTERN.format(step)
        )
        torch.save(self.state_dict(), path_to_checkpoint_file)
        return path_to_checkpoint_file

    def restore(self, path_to_checkpoint_file):
        """Load weights from ``path_to_checkpoint_file`` and return its step.

        Tensors are mapped to CPU. ``torch.load`` is pickle-based, so only
        restore checkpoints from trusted sources.

        Returns:
            The integer step parsed from the file name (``model-<step>.pth``).

        Raises:
            ValueError: if the file name carries no integer step.
        """
        self.load_state_dict(
            torch.load(
                path_to_checkpoint_file,
                map_location=torch.device("cpu"),
            )
        )
        return Model._step_from_path(path_to_checkpoint_file)