|
|
import torch.nn as nn |
|
|
|
|
|
from torch.nn import functional as F |
|
|
from timm import create_model |
|
|
|
|
|
|
|
|
__all__ = ['NoiseTransformer'] |
|
|
|
|
|
class NoiseTransformer(nn.Module): |
|
|
def __init__(self, resolution=(128,96)): |
|
|
super().__init__() |
|
|
self.upsample = lambda x: F.interpolate(x, [224,224]) |
|
|
self.downsample = lambda x: F.interpolate(x, [resolution[0],resolution[1]]) |
|
|
self.upconv = nn.Conv2d(7,4,(1,1),(1,1),(0,0)) |
|
|
self.downconv = nn.Conv2d(4,3,(1,1),(1,1),(0,0)) |
|
|
|
|
|
self.swin = create_model("swin_tiny_patch4_window7_224",pretrained=True) |
|
|
|
|
|
|
|
|
def forward(self, x, residual=False): |
|
|
if residual: |
|
|
x = self.upconv(self.downsample(self.swin.forward_features(self.downconv(self.upsample(x))))) + x |
|
|
else: |
|
|
x = self.upconv(self.downsample(self.swin.forward_features(self.downconv(self.upsample(x))))) |
|
|
|
|
|
return x |