195 lines
6.3 KiB
Python
195 lines
6.3 KiB
Python
"""Modified from https://github.com/chaofengc/PSFRGAN
|
|
"""
|
|
import numpy as np
|
|
import torch.nn as nn
|
|
from torch.nn import functional as F
|
|
|
|
|
|
class NormLayer(nn.Module):
    """Normalization layer selected by name.

    Args:
        channels: number of input channels; used by batch norm ('bn'),
            instance norm ('in') and group norm ('gn').
        normalize_shape: input shape without the batch dimension; only
            used by layer norm ('layer'), where it is handed to
            ``nn.LayerNorm``.
        norm_type: case-insensitive name of the normalization, one of
            'bn', 'in', 'gn', 'pixel', 'layer' or 'none'.

    Raises:
        AssertionError: if ``norm_type`` is not one of the names above.
    """

    def __init__(self, channels, normalize_shape=None, norm_type='bn'):
        super(NormLayer, self).__init__()
        norm_type = norm_type.lower()
        self.norm_type = norm_type
        if norm_type == 'bn':
            self.norm = nn.BatchNorm2d(channels, affine=True)
        elif norm_type == 'in':
            self.norm = nn.InstanceNorm2d(channels, affine=False)
        elif norm_type == 'gn':
            # The group count is fixed at 32; nn.GroupNorm requires
            # `channels` to be divisible by the number of groups.
            self.norm = nn.GroupNorm(32, channels, affine=True)
        elif norm_type == 'pixel':
            # L2-normalize every spatial position across the channel axis.
            self.norm = lambda x: F.normalize(x, p=2, dim=1)
        elif norm_type == 'layer':
            self.norm = nn.LayerNorm(normalize_shape)
        elif norm_type == 'none':
            # Multiplying by 1.0 returns a new tensor rather than the input
            # object itself.
            self.norm = lambda x: x * 1.0
        else:
            # Raised explicitly instead of through `assert` so the check is
            # not stripped under `python -O`; the exception type and message
            # are the same ones the assert produced.
            raise AssertionError(f'Norm type {norm_type} not support.')

    def forward(self, x, ref=None):
        """Normalize ``x``.

        ``ref`` is only forwarded to ``self.norm`` when ``self.norm_type``
        is 'spade'. ``__init__`` never builds a 'spade' norm, so that branch
        is only reached if other code sets both attributes.
        """
        if self.norm_type == 'spade':
            return self.norm(x, ref)
        return self.norm(x)
|
|
|
|
|
|
class ReluLayer(nn.Module):
    """Activation layer selected by name.

    Args:
        channels: number of channels; only used by 'prelu', where it is
            the number of learnable slopes given to ``nn.PReLU``.
        relu_type: case-insensitive name of the activation, one of
            - 'relu': ReLU (in place)
            - 'leakyrelu': LeakyReLU with negative slope 0.2 (in place)
            - 'prelu': PReLU
            - 'selu': SELU (in place)
            - 'none': direct pass

    Raises:
        AssertionError: if ``relu_type`` is not one of the names above.
    """

    def __init__(self, channels, relu_type='relu'):
        super(ReluLayer, self).__init__()
        relu_type = relu_type.lower()
        if relu_type == 'relu':
            self.func = nn.ReLU(True)
        elif relu_type == 'leakyrelu':
            self.func = nn.LeakyReLU(0.2, inplace=True)
        elif relu_type == 'prelu':
            self.func = nn.PReLU(channels)
        elif relu_type == 'selu':
            self.func = nn.SELU(True)
        elif relu_type == 'none':
            # Multiplying by 1.0 returns a new tensor rather than the input
            # object itself.
            self.func = lambda x: x * 1.0
        else:
            # Raised explicitly instead of through `assert` so the check is
            # not stripped under `python -O`; the exception type and message
            # are the same ones the assert produced.
            raise AssertionError(f'Relu type {relu_type} not support.')

    def forward(self, x):
        """Apply the configured activation to ``x``."""
        return self.func(x)
|
|
|
|
|
|
class ConvLayer(nn.Module):
    """Optional rescale -> reflection pad -> conv -> norm -> activation.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size: convolution kernel size.
        scale: 'down' gives the convolution a stride of 2; 'up' doubles the
            input with nearest-neighbour interpolation before the
            convolution; any other value leaves the resolution handling
            untouched (stride 1, no interpolation).
        norm_type: name handed to ``NormLayer``.
        relu_type: name handed to ``ReluLayer``.
        use_pad: when true, the input is reflection padded before the
            convolution.
        bias: whether the convolution has a bias term; forced to ``False``
            when ``norm_type`` is exactly 'bn'.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 scale='none',
                 norm_type='none',
                 relu_type='none',
                 use_pad=True,
                 bias=True):
        super(ConvLayer, self).__init__()
        self.norm_type = norm_type
        self.use_pad = use_pad

        # The convolution bias is dropped when batch norm follows
        # (presumably because the norm's own shift makes it redundant).
        if norm_type == 'bn':
            bias = False

        if scale == 'up':
            self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest')
        else:
            self.scale_func = lambda x: x

        conv_stride = 2 if scale == 'down' else 1
        pad_width = int(np.ceil((kernel_size - 1.) / 2))

        self.reflection_pad = nn.ReflectionPad2d(pad_width)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, conv_stride, bias=bias)

        self.relu = ReluLayer(out_channels, relu_type)
        self.norm = NormLayer(out_channels, norm_type=norm_type)

    def forward(self, x):
        """Run rescale, optional padding, convolution, norm and activation."""
        feat = self.scale_func(x)
        if self.use_pad:
            feat = self.reflection_pad(feat)
        return self.relu(self.norm(self.conv2d(feat)))
|
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Residual block made of two ``ConvLayer`` stages plus a shortcut.

    Layout recommended in: http://torch.ch/blog/2016/02/04/resnets.html

    Args:
        c_in: input channels.
        c_out: output channels.
        relu_type: activation name for the first convolution; the second
            convolution always uses 'none'.
        norm_type: normalization name for both convolutions.
        scale: 'down', 'up' or 'none'. 'down' is applied by the second
            convolution and 'up' by the first. Any other value raises
            ``KeyError``.
    """

    def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'):
        super(ResidualBlock, self).__init__()

        # The shortcut is a plain pass-through only when neither the
        # resolution nor the channel count changes; otherwise a bare
        # ConvLayer (default norm and activation) matches the shapes.
        if scale != 'none' or c_in != c_out:
            self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)
        else:
            self.shortcut_func = lambda x: x

        first_scale, second_scale = {
            'down': ['none', 'down'],
            'up': ['up', 'none'],
            'none': ['none', 'none'],
        }[scale]

        self.conv1 = ConvLayer(c_in, c_out, 3, first_scale, norm_type=norm_type, relu_type=relu_type)
        self.conv2 = ConvLayer(c_out, c_out, 3, second_scale, norm_type=norm_type, relu_type='none')

    def forward(self, x):
        """Return ``shortcut(x) + conv2(conv1(x))``."""
        shortcut = self.shortcut_func(x)
        residual = self.conv2(self.conv1(x))
        return shortcut + residual
|
|
|
|
|
|
class ParseNet(nn.Module):
    """Encoder / residual body / decoder network with an image and a mask head.

    Args:
        in_size: input resolution; with ``min_feat_size`` it determines the
            number of downsampling blocks,
            ``int(log2(in_size // min_feat_size))``.
        out_size: output resolution; with ``min_feat_size`` it determines
            the number of upsampling blocks,
            ``int(log2(out_size // min_feat_size))``.
        min_feat_size: target resolution of the bottleneck features; capped
            at ``in_size``.
        base_ch: output channels of the first convolution.
        parsing_ch: output channels of the mask head.
        res_depth: number of residual blocks in the body.
        relu_type: activation name used by the residual blocks.
        norm_type: normalization name used by the residual blocks.
        ch_range: ``(min_ch, max_ch)`` pair that every intermediate channel
            count is clamped to.
    """

    def __init__(self,
                 in_size=128,
                 out_size=128,
                 min_feat_size=32,
                 base_ch=64,
                 parsing_ch=19,
                 res_depth=10,
                 relu_type='LeakyReLU',
                 norm_type='bn',
                 ch_range=(32, 256)):
        super().__init__()
        self.res_depth = res_depth
        act_args = {'norm_type': norm_type, 'relu_type': relu_type}
        min_ch, max_ch = ch_range

        def ch_clip(ch):
            """Clamp a channel count into [min_ch, max_ch]."""
            return max(min_ch, min(ch, max_ch))

        min_feat_size = min(in_size, min_feat_size)

        down_steps = int(np.log2(in_size // min_feat_size))
        up_steps = int(np.log2(out_size // min_feat_size))

        # =============== define encoder-body-decoder ====================
        # Stem: 3-channel input, no rescaling.
        encoder = [ConvLayer(3, base_ch, 3, 'none')]
        # head_ch tracks the unclamped channel count; it is clamped only at
        # the point where a layer is built.
        head_ch = base_ch
        for _ in range(down_steps):
            cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
            encoder.append(ResidualBlock(cin, cout, scale='down', **act_args))
            head_ch = head_ch * 2

        body = [
            ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args)
            for _ in range(res_depth)
        ]

        decoder = []
        for _ in range(up_steps):
            cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
            decoder.append(ResidualBlock(cin, cout, scale='up', **act_args))
            head_ch = head_ch // 2

        self.encoder = nn.Sequential(*encoder)
        self.body = nn.Sequential(*body)
        self.decoder = nn.Sequential(*decoder)
        self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
        self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)

    def forward(self, x):
        """Run the network on ``x``.

        Returns:
            A tuple ``(out_mask, out_img)``: the ``parsing_ch``-channel mask
            head output followed by the 3-channel image head output.
        """
        feat = self.encoder(x)
        # Skip connection around the whole residual body.
        x = feat + self.body(feat)
        x = self.decoder(x)
        out_img = self.out_img_conv(x)
        out_mask = self.out_mask_conv(x)
        return out_mask, out_img
|