remove gfpgan dep

Qing 2024-08-12 11:02:55 +08:00
parent ffdf5e06e1
commit 60b1411d6b
27 changed files with 4745 additions and 37 deletions


@@ -8,4 +8,3 @@ def install(package):
def install_plugins_package():
install("rembg")
install("gfpgan")


@@ -0,0 +1,172 @@
import cv2
import math
import numpy as np
import os
import torch
from torchvision.utils import make_grid
def img2tensor(imgs, bgr2rgb=True, float32=True):
"""Numpy array to tensor.
Args:
imgs (list[ndarray] | ndarray): Input images.
bgr2rgb (bool): Whether to change bgr to rgb.
float32 (bool): Whether to change to float32.
Returns:
list[tensor] | tensor: Tensor images. If returned results only have
one element, just return tensor.
"""
def _totensor(img, bgr2rgb, float32):
if img.shape[2] == 3 and bgr2rgb:
if img.dtype == 'float64':
img = img.astype('float32')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.transpose(2, 0, 1))
if float32:
img = img.float()
return img
if isinstance(imgs, list):
return [_totensor(img, bgr2rgb, float32) for img in imgs]
else:
return _totensor(imgs, bgr2rgb, float32)
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
"""Convert torch Tensors into image numpy arrays.
After clamping to [min, max], values will be normalized to [0, 1].
Args:
tensor (Tensor or list[Tensor]): Accept shapes:
1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
2) 3D Tensor of shape (3/1 x H x W);
3) 2D Tensor of shape (H x W).
Tensor channel should be in RGB order.
rgb2bgr (bool): Whether to change rgb to bgr.
out_type (numpy type): output types. If ``np.uint8``, transform outputs
to uint8 type with range [0, 255]; otherwise, float type with
range [0, 1]. Default: ``np.uint8``.
min_max (tuple[int]): min and max values for clamp.
    Returns:
        (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) or 2D ndarray
            of shape (H x W). The channel order is BGR.
"""
if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
if torch.is_tensor(tensor):
tensor = [tensor]
result = []
for _tensor in tensor:
_tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
_tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
n_dim = _tensor.dim()
if n_dim == 4:
img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
img_np = img_np.transpose(1, 2, 0)
if rgb2bgr:
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
elif n_dim == 3:
img_np = _tensor.numpy()
img_np = img_np.transpose(1, 2, 0)
if img_np.shape[2] == 1: # gray image
img_np = np.squeeze(img_np, axis=2)
else:
if rgb2bgr:
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
elif n_dim == 2:
img_np = _tensor.numpy()
else:
raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
if out_type == np.uint8:
            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
img_np = (img_np * 255.0).round()
img_np = img_np.astype(out_type)
result.append(img_np)
if len(result) == 1:
result = result[0]
return result
def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
"""This implementation is slightly faster than tensor2img.
It now only supports torch tensor with shape (1, c, h, w).
Args:
tensor (Tensor): Now only support torch tensor with (1, c, h, w).
rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
min_max (tuple[int]): min and max values for clamp.
"""
output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
output = output.type(torch.uint8).cpu().numpy()
if rgb2bgr:
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
def imfrombytes(content, flag='color', float32=False):
"""Read an image from bytes.
Args:
content (bytes): Image bytes got from files or other streams.
flag (str): Flags specifying the color type of a loaded image,
candidates are `color`, `grayscale` and `unchanged`.
        float32 (bool): Whether to change to float32. If True, will also
            normalize values to [0, 1]. Default: False.
Returns:
ndarray: Loaded image array.
"""
img_np = np.frombuffer(content, np.uint8)
imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
img = cv2.imdecode(img_np, imread_flags[flag])
if float32:
img = img.astype(np.float32) / 255.
return img
def imwrite(img, file_path, params=None, auto_mkdir=True):
"""Write image to file.
Args:
img (ndarray): Image array to be written.
file_path (str): Image file path.
params (None or list): Same as opencv's :func:`imwrite` interface.
auto_mkdir (bool): If the parent folder of `file_path` does not exist,
whether to create it automatically.
    Raises:
        IOError: If the image cannot be written.
"""
if auto_mkdir:
dir_name = os.path.abspath(os.path.dirname(file_path))
os.makedirs(dir_name, exist_ok=True)
ok = cv2.imwrite(file_path, img, params)
if not ok:
raise IOError('Failed in writing images.')
def crop_border(imgs, crop_border):
"""Crop borders of images.
Args:
imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
        crop_border (int): Crop border for each end of height and width.
Returns:
list[ndarray]: Cropped images.
"""
if crop_border == 0:
return imgs
else:
if isinstance(imgs, list):
return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
else:
return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
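
A short usage sketch for the conversion helpers above; the import path is assumed from the vendored layout (e.g. iopaint/plugins/facexlib/utils/misc.py).

import numpy as np
from iopaint.plugins.facexlib.utils.misc import img2tensor, tensor2img  # import path assumed

bgr = np.random.rand(64, 64, 3).astype(np.float32)       # HWC, BGR, values in [0, 1]
t = img2tensor(bgr, bgr2rgb=True, float32=True)          # CHW float tensor, RGB order
restored = tensor2img(t, rgb2bgr=True, min_max=(0, 1))   # back to HWC uint8, BGR order
assert restored.shape == (64, 64, 3) and restored.dtype == np.uint8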

iopaint/plugins/facexlib/.gitignore (vendored, 135 additions)

@@ -0,0 +1,135 @@
.vscode
*.pth
*.png
*.jpg
version.py
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/


@@ -0,0 +1,3 @@
# flake8: noqa
from .detection import *
from .utils import *


@@ -0,0 +1,31 @@
import torch
from copy import deepcopy
from ..utils import load_file_from_url
from .retinaface import RetinaFace
def init_detection_model(model_name, half=False, device='cuda', model_rootpath=None):
if model_name == 'retinaface_resnet50':
model = RetinaFace(network_name='resnet50', half=half, device=device)
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth'
elif model_name == 'retinaface_mobile0.25':
model = RetinaFace(network_name='mobile0.25', half=half, device=device)
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth'
else:
raise NotImplementedError(f'{model_name} is not implemented.')
model_path = load_file_from_url(
url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
# TODO: clean pretrained model
load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
# remove unnecessary 'module.'
for k, v in deepcopy(load_net).items():
if k.startswith('module.'):
load_net[k[7:]] = v
load_net.pop(k)
model.load_state_dict(load_net, strict=True)
model.eval()
model = model.to(device)
return model
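
A usage sketch for the detection factory above (import path assumed; the checkpoint is downloaded on first use, and the random image is only a stand-in for a real photo).

import numpy as np
import torch
from iopaint.plugins.facexlib.detection import init_detection_model  # import path assumed

det_net = init_detection_model("retinaface_resnet50", half=False, device="cpu")
img = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)  # stand-in for a real BGR photo, e.g. cv2.imread(...)
with torch.no_grad():
    faces = det_net.detect_faces(img, conf_threshold=0.8)
# faces: ndarray [n, 15] per detection = 4 box coords + 1 score + 5 (x, y) landmarks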


@@ -0,0 +1,219 @@
import cv2
import numpy as np
from .matlab_cp2tform import get_similarity_transform_for_cv2
# reference facial points, a list of coordinates (x,y)
REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278],
[33.54930115, 92.3655014], [62.72990036, 92.20410156]]
DEFAULT_CROP_SIZE = (96, 112)
class FaceWarpException(Exception):
def __str__(self):
        return 'In File {}:{}'.format(__file__, super().__str__())
def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False):
"""
Function:
----------
get reference 5 key points according to crop settings:
0. Set default crop_size:
if default_square:
crop_size = (112, 112)
else:
crop_size = (96, 112)
1. Pad the crop_size by inner_padding_factor in each side;
2. Resize crop_size into (output_size - outer_padding*2),
pad into output_size with outer_padding;
3. Output reference_5point;
Parameters:
----------
@output_size: (w, h) or None
size of aligned face image
        @inner_padding_factor: float
            padding factor for the inner (w, h) region
        @outer_padding: (w_pad, h_pad)
            padding in pixels added around the resized crop
@default_square: True or False
if True:
default crop_size = (112, 112)
else:
default crop_size = (96, 112);
!!! make sure, if output_size is not None:
(output_size - outer_padding)
= some_scale * (default crop_size * (1.0 +
inner_padding_factor))
Returns:
----------
@reference_5point: 5x2 np.array
each row is a pair of transformed coordinates (x, y)
"""
tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
# 0) make the inner region a square
if default_square:
size_diff = max(tmp_crop_size) - tmp_crop_size
tmp_5pts += size_diff / 2
tmp_crop_size += size_diff
if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]):
return tmp_5pts
if (inner_padding_factor == 0 and outer_padding == (0, 0)):
if output_size is None:
return tmp_5pts
else:
raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
# check output size
if not (0 <= inner_padding_factor <= 1.0):
raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None):
        output_size = (tmp_crop_size * (1 + inner_padding_factor * 2)).astype(np.int32)
output_size += np.array(outer_padding)
if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]):
raise FaceWarpException('Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])')
# 1) pad the inner region according inner_padding_factor
if inner_padding_factor > 0:
size_diff = tmp_crop_size * inner_padding_factor * 2
tmp_5pts += size_diff / 2
tmp_crop_size += np.round(size_diff).astype(np.int32)
# 2) resize the padded inner region
size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
        raise FaceWarpException('Must have (output_size - outer_padding) '
                                '= some_scale * (crop_size * (1.0 + inner_padding_factor))')
scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
tmp_5pts = tmp_5pts * scale_factor
# size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
# tmp_5pts = tmp_5pts + size_diff / 2
tmp_crop_size = size_bf_outer_pad
# 3) add outer_padding to make output_size
reference_5point = tmp_5pts + np.array(outer_padding)
tmp_crop_size = output_size
return reference_5point
def get_affine_transform_matrix(src_pts, dst_pts):
"""
Function:
----------
get affine transform matrix 'tfm' from src_pts to dst_pts
Parameters:
----------
@src_pts: Kx2 np.array
source points matrix, each row is a pair of coordinates (x, y)
@dst_pts: Kx2 np.array
destination points matrix, each row is a pair of coordinates (x, y)
Returns:
----------
@tfm: 2x3 np.array
transform matrix from src_pts to dst_pts
"""
tfm = np.float32([[1, 0, 0], [0, 1, 0]])
n_pts = src_pts.shape[0]
ones = np.ones((n_pts, 1), src_pts.dtype)
src_pts_ = np.hstack([src_pts, ones])
dst_pts_ = np.hstack([dst_pts, ones])
A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
if rank == 3:
tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]])
elif rank == 2:
tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]])
return tfm
def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type='similarity'):
"""
Function:
----------
        warp and crop a face region out of src_img, aligning the given
        facial landmarks to the reference points
    Parameters:
    ----------
        @src_img: HxWxC np.array
            input image
@facial_pts: could be
1)a list of K coordinates (x,y)
or
2) Kx2 or 2xK np.array
each row or col is a pair of coordinates (x, y)
@reference_pts: could be
1) a list of K coordinates (x,y)
or
2) Kx2 or 2xK np.array
each row or col is a pair of coordinates (x, y)
or
3) None
if None, use default reference facial points
@crop_size: (w, h)
output face image size
@align_type: transform type, could be one of
1) 'similarity': use similarity transform
2) 'cv2_affine': use the first 3 points to do affine transform,
by calling cv2.getAffineTransform()
3) 'affine': use all points to do affine transform
Returns:
----------
@face_img: output face image with size (w, h) = @crop_size
"""
if reference_pts is None:
if crop_size[0] == 96 and crop_size[1] == 112:
reference_pts = REFERENCE_FACIAL_POINTS
else:
default_square = False
inner_padding_factor = 0
outer_padding = (0, 0)
output_size = crop_size
reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding,
default_square)
ref_pts = np.float32(reference_pts)
ref_pts_shp = ref_pts.shape
if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2')
if ref_pts_shp[0] == 2:
ref_pts = ref_pts.T
src_pts = np.float32(facial_pts)
src_pts_shp = src_pts.shape
if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2')
if src_pts_shp[0] == 2:
src_pts = src_pts.T
if src_pts.shape != ref_pts.shape:
raise FaceWarpException('facial_pts and reference_pts must have the same shape')
if align_type == 'cv2_affine':
tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
elif align_type == 'affine':
tfm = get_affine_transform_matrix(src_pts, ref_pts)
else:
tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)
face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))
return face_img
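
A sketch of aligning a single face with the helpers above; the landmark values are illustrative and the import path is assumed.

import numpy as np
from iopaint.plugins.facexlib.detection.align_trans import (  # import path assumed
    get_reference_facial_points, warp_and_crop_face)

img = np.zeros((480, 640, 3), dtype=np.uint8)                   # stand-in for a real BGR photo
landmarks5 = np.array([[212.0, 180.0], [278.0, 178.0], [246.0, 220.0],
                       [218.0, 255.0], [274.0, 252.0]])         # eyes, nose tip, mouth corners (illustrative)
reference = get_reference_facial_points(default_square=True)    # reference points for a square 112x112 crop
aligned = warp_and_crop_face(img, landmarks5, reference_pts=reference, crop_size=(112, 112))
print(aligned.shape)                                            # (112, 112, 3)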


@@ -0,0 +1,317 @@
import numpy as np
from numpy.linalg import inv, lstsq
from numpy.linalg import matrix_rank as rank
from numpy.linalg import norm
class MatlabCp2tormException(Exception):
def __str__(self):
        return 'In File {}:{}'.format(__file__, super().__str__())
def tformfwd(trans, uv):
"""
Function:
----------
apply affine transform 'trans' to uv
Parameters:
----------
@trans: 3x3 np.array
transform matrix
@uv: Kx2 np.array
each row is a pair of coordinates (x, y)
Returns:
----------
@xy: Kx2 np.array
each row is a pair of transformed coordinates (x, y)
"""
uv = np.hstack((uv, np.ones((uv.shape[0], 1))))
xy = np.dot(uv, trans)
xy = xy[:, 0:-1]
return xy
def tforminv(trans, uv):
"""
Function:
----------
apply the inverse of affine transform 'trans' to uv
Parameters:
----------
@trans: 3x3 np.array
transform matrix
@uv: Kx2 np.array
each row is a pair of coordinates (x, y)
Returns:
----------
@xy: Kx2 np.array
each row is a pair of inverse-transformed coordinates (x, y)
"""
Tinv = inv(trans)
xy = tformfwd(Tinv, uv)
return xy
def findNonreflectiveSimilarity(uv, xy, options=None):
options = {'K': 2}
K = options['K']
M = xy.shape[0]
x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
X = np.vstack((tmp1, tmp2))
u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
U = np.vstack((u, v))
# We know that X * r = U
if rank(X) >= 2 * K:
r, _, _, _ = lstsq(X, U, rcond=-1)
r = np.squeeze(r)
else:
raise Exception('cp2tform:twoUniquePointsReq')
sc = r[0]
ss = r[1]
tx = r[2]
ty = r[3]
Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]])
T = inv(Tinv)
T[:, 2] = np.array([0, 0, 1])
return T, Tinv
def findSimilarity(uv, xy, options=None):
options = {'K': 2}
# uv = np.array(uv)
# xy = np.array(xy)
# Solve for trans1
trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
# Solve for trans2
# manually reflect the xy data across the Y-axis
    xyR = xy.copy()  # reflect a copy so the original xy is left untouched
    xyR[:, 0] = -1 * xyR[:, 0]
trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
# manually reflect the tform to undo the reflection done on xyR
TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
trans2 = np.dot(trans2r, TreflectY)
# Figure out if trans1 or trans2 is better
xy1 = tformfwd(trans1, uv)
norm1 = norm(xy1 - xy)
xy2 = tformfwd(trans2, uv)
norm2 = norm(xy2 - xy)
if norm1 <= norm2:
return trans1, trans1_inv
else:
trans2_inv = inv(trans2)
return trans2, trans2_inv
def get_similarity_transform(src_pts, dst_pts, reflective=True):
"""
Function:
----------
Find Similarity Transform Matrix 'trans':
u = src_pts[:, 0]
v = src_pts[:, 1]
x = dst_pts[:, 0]
y = dst_pts[:, 1]
[x, y, 1] = [u, v, 1] * trans
Parameters:
----------
@src_pts: Kx2 np.array
source points, each row is a pair of coordinates (x, y)
@dst_pts: Kx2 np.array
destination points, each row is a pair of transformed
coordinates (x, y)
@reflective: True or False
if True:
use reflective similarity transform
else:
use non-reflective similarity transform
Returns:
----------
@trans: 3x3 np.array
transform matrix from uv to xy
trans_inv: 3x3 np.array
inverse of trans, transform matrix from xy to uv
"""
if reflective:
trans, trans_inv = findSimilarity(src_pts, dst_pts)
else:
trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
return trans, trans_inv
def cvt_tform_mat_for_cv2(trans):
"""
Function:
----------
Convert Transform Matrix 'trans' into 'cv2_trans' which could be
directly used by cv2.warpAffine():
u = src_pts[:, 0]
v = src_pts[:, 1]
x = dst_pts[:, 0]
y = dst_pts[:, 1]
[x, y].T = cv_trans * [u, v, 1].T
Parameters:
----------
@trans: 3x3 np.array
transform matrix from uv to xy
Returns:
----------
@cv2_trans: 2x3 np.array
transform matrix from src_pts to dst_pts, could be directly used
for cv2.warpAffine()
"""
cv2_trans = trans[:, 0:2].T
return cv2_trans
def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
"""
Function:
----------
Find Similarity Transform Matrix 'cv2_trans' which could be
directly used by cv2.warpAffine():
u = src_pts[:, 0]
v = src_pts[:, 1]
x = dst_pts[:, 0]
y = dst_pts[:, 1]
[x, y].T = cv_trans * [u, v, 1].T
Parameters:
----------
@src_pts: Kx2 np.array
source points, each row is a pair of coordinates (x, y)
@dst_pts: Kx2 np.array
destination points, each row is a pair of transformed
coordinates (x, y)
reflective: True or False
if True:
use reflective similarity transform
else:
use non-reflective similarity transform
Returns:
----------
@cv2_trans: 2x3 np.array
transform matrix from src_pts to dst_pts, could be directly used
for cv2.warpAffine()
"""
trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
cv2_trans = cvt_tform_mat_for_cv2(trans)
return cv2_trans
if __name__ == '__main__':
"""
u = [0, 6, -2]
v = [0, 3, 5]
x = [-1, 0, 4]
y = [-1, -10, 4]
# In Matlab, run:
#
# uv = [u'; v'];
# xy = [x'; y'];
# tform_sim=cp2tform(uv,xy,'similarity');
#
# trans = tform_sim.tdata.T
# ans =
# -0.0764 -1.6190 0
# 1.6190 -0.0764 0
# -3.2156 0.0290 1.0000
# trans_inv = tform_sim.tdata.Tinv
# ans =
#
# -0.0291 0.6163 0
# -0.6163 -0.0291 0
# -0.0756 1.9826 1.0000
# xy_m=tformfwd(tform_sim, u,v)
#
# xy_m =
#
# -3.2156 0.0290
# 1.1833 -9.9143
# 5.0323 2.8853
# uv_m=tforminv(tform_sim, x,y)
#
# uv_m =
#
# 0.5698 1.3953
# 6.0872 2.2733
# -2.6570 4.3314
"""
u = [0, 6, -2]
v = [0, 3, 5]
x = [-1, 0, 4]
y = [-1, -10, 4]
uv = np.array((u, v)).T
xy = np.array((x, y)).T
print('\n--->uv:')
print(uv)
print('\n--->xy:')
print(xy)
trans, trans_inv = get_similarity_transform(uv, xy)
print('\n--->trans matrix:')
print(trans)
print('\n--->trans_inv matrix:')
print(trans_inv)
print('\n---> apply transform to uv')
print('\nxy_m = uv_augmented * trans')
uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1))))
xy_m = np.dot(uv_aug, trans)
print(xy_m)
print('\nxy_m = tformfwd(trans, uv)')
xy_m = tformfwd(trans, uv)
print(xy_m)
print('\n---> apply inverse transform to xy')
print('\nuv_m = xy_augmented * trans_inv')
xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1))))
uv_m = np.dot(xy_aug, trans_inv)
print(uv_m)
print('\nuv_m = tformfwd(trans_inv, xy)')
uv_m = tformfwd(trans_inv, xy)
print(uv_m)
uv_m = tforminv(trans, xy)
print('\nuv_m = tforminv(trans, xy)')
print(uv_m)
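
A sketch of the cv2-facing entry point above, which is what align_trans.warp_and_crop_face relies on; the point values mirror the __main__ demo and the import path is assumed.

import cv2
import numpy as np
from iopaint.plugins.facexlib.detection.matlab_cp2tform import get_similarity_transform_for_cv2  # import path assumed

src = np.array([[0.0, 0.0], [6.0, 3.0], [-2.0, 5.0]], dtype=np.float32)
dst = np.array([[-1.0, -1.0], [0.0, -10.0], [4.0, 4.0]], dtype=np.float32)
cv2_trans = get_similarity_transform_for_cv2(src, dst)   # 2x3 matrix usable by cv2.warpAffine
warped = cv2.warpAffine(np.zeros((64, 64, 3), dtype=np.uint8), cv2_trans, (64, 64))
print(cv2_trans.shape, warped.shape)                     # (2, 3) (64, 64, 3)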


@@ -0,0 +1,419 @@
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision.models._utils import IntermediateLayerGetter as IntermediateLayerGetter
from .align_trans import get_reference_facial_points, warp_and_crop_face
from .retinaface_net import (
FPN,
SSH,
MobileNetV1,
make_bbox_head,
make_class_head,
make_landmark_head,
)
from .retinaface_utils import (
PriorBox,
batched_decode,
batched_decode_landm,
decode,
decode_landm,
py_cpu_nms,
)
def generate_config(network_name):
cfg_mnet = {
"name": "mobilenet0.25",
"min_sizes": [[16, 32], [64, 128], [256, 512]],
"steps": [8, 16, 32],
"variance": [0.1, 0.2],
"clip": False,
"loc_weight": 2.0,
"gpu_train": True,
"batch_size": 32,
"ngpu": 1,
"epoch": 250,
"decay1": 190,
"decay2": 220,
"image_size": 640,
"return_layers": {"stage1": 1, "stage2": 2, "stage3": 3},
"in_channel": 32,
"out_channel": 64,
}
cfg_re50 = {
"name": "Resnet50",
"min_sizes": [[16, 32], [64, 128], [256, 512]],
"steps": [8, 16, 32],
"variance": [0.1, 0.2],
"clip": False,
"loc_weight": 2.0,
"gpu_train": True,
"batch_size": 24,
"ngpu": 4,
"epoch": 100,
"decay1": 70,
"decay2": 90,
"image_size": 840,
"return_layers": {"layer2": 1, "layer3": 2, "layer4": 3},
"in_channel": 256,
"out_channel": 256,
}
if network_name == "mobile0.25":
return cfg_mnet
elif network_name == "resnet50":
return cfg_re50
else:
raise NotImplementedError(f"network_name={network_name}")
class RetinaFace(nn.Module):
def __init__(self, network_name="resnet50", half=False, phase="test", device=None):
self.device = (
torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device is None
else device
)
super(RetinaFace, self).__init__()
self.half_inference = half
cfg = generate_config(network_name)
self.backbone = cfg["name"]
self.model_name = f"retinaface_{network_name}"
self.cfg = cfg
self.phase = phase
self.target_size, self.max_size = 1600, 2150
self.resize, self.scale, self.scale1 = 1.0, None, None
self.mean_tensor = torch.tensor(
[[[[104.0]], [[117.0]], [[123.0]]]], device=self.device
)
self.reference = get_reference_facial_points(default_square=True)
# Build network.
backbone = None
if cfg["name"] == "mobilenet0.25":
backbone = MobileNetV1()
self.body = IntermediateLayerGetter(backbone, cfg["return_layers"])
elif cfg["name"] == "Resnet50":
import torchvision.models as models
backbone = models.resnet50(pretrained=False)
self.body = IntermediateLayerGetter(backbone, cfg["return_layers"])
in_channels_stage2 = cfg["in_channel"]
in_channels_list = [
in_channels_stage2 * 2,
in_channels_stage2 * 4,
in_channels_stage2 * 8,
]
out_channels = cfg["out_channel"]
self.fpn = FPN(in_channels_list, out_channels)
self.ssh1 = SSH(out_channels, out_channels)
self.ssh2 = SSH(out_channels, out_channels)
self.ssh3 = SSH(out_channels, out_channels)
self.ClassHead = make_class_head(fpn_num=3, inchannels=cfg["out_channel"])
self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg["out_channel"])
self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg["out_channel"])
self.to(self.device)
self.eval()
if self.half_inference:
self.half()
def forward(self, inputs):
out = self.body(inputs)
if self.backbone == "mobilenet0.25" or self.backbone == "Resnet50":
out = list(out.values())
# FPN
fpn = self.fpn(out)
# SSH
feature1 = self.ssh1(fpn[0])
feature2 = self.ssh2(fpn[1])
feature3 = self.ssh3(fpn[2])
features = [feature1, feature2, feature3]
bbox_regressions = torch.cat(
[self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1
)
classifications = torch.cat(
[self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1
)
tmp = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)]
ldm_regressions = torch.cat(tmp, dim=1)
if self.phase == "train":
output = (bbox_regressions, classifications, ldm_regressions)
else:
output = (
bbox_regressions,
F.softmax(classifications, dim=-1),
ldm_regressions,
)
return output
def __detect_faces(self, inputs):
# get scale
height, width = inputs.shape[2:]
self.scale = torch.tensor(
[width, height, width, height], dtype=torch.float32, device=self.device
)
tmp = [
width,
height,
width,
height,
width,
height,
width,
height,
width,
height,
]
self.scale1 = torch.tensor(tmp, dtype=torch.float32, device=self.device)
        # forward
inputs = inputs.to(self.device)
if self.half_inference:
inputs = inputs.half()
loc, conf, landmarks = self(inputs)
# get priorbox
priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:])
priors = priorbox.forward().to(self.device)
return loc, conf, landmarks, priors
# single image detection
def transform(self, image, use_origin_size):
# convert to opencv format
if isinstance(image, Image.Image):
image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
image = image.astype(np.float32)
# testing scale
im_size_min = np.min(image.shape[0:2])
im_size_max = np.max(image.shape[0:2])
resize = float(self.target_size) / float(im_size_min)
# prevent bigger axis from being more than max_size
if np.round(resize * im_size_max) > self.max_size:
resize = float(self.max_size) / float(im_size_max)
resize = 1 if use_origin_size else resize
# resize
if resize != 1:
image = cv2.resize(
image, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR
)
# convert to torch.tensor format
# image -= (104, 117, 123)
image = image.transpose(2, 0, 1)
image = torch.from_numpy(image).unsqueeze(0)
return image, resize
def detect_faces(
self,
image,
conf_threshold=0.8,
nms_threshold=0.4,
use_origin_size=True,
):
image, self.resize = self.transform(image, use_origin_size)
image = image.to(self.device)
if self.half_inference:
image = image.half()
image = image - self.mean_tensor
loc, conf, landmarks, priors = self.__detect_faces(image)
boxes = decode(loc.data.squeeze(0), priors.data, self.cfg["variance"])
boxes = boxes * self.scale / self.resize
boxes = boxes.cpu().numpy()
scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg["variance"])
landmarks = landmarks * self.scale1 / self.resize
landmarks = landmarks.cpu().numpy()
# ignore low scores
inds = np.where(scores > conf_threshold)[0]
boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds]
# sort
order = scores.argsort()[::-1]
boxes, landmarks, scores = boxes[order], landmarks[order], scores[order]
# do NMS
bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(
np.float32, copy=False
)
keep = py_cpu_nms(bounding_boxes, nms_threshold)
bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep]
# self.t['forward_pass'].toc()
# print(self.t['forward_pass'].average_time)
# import sys
# sys.stdout.flush()
return np.concatenate((bounding_boxes, landmarks), axis=1)
def __align_multi(self, image, boxes, landmarks, limit=None):
if len(boxes) < 1:
return [], []
if limit:
boxes = boxes[:limit]
landmarks = landmarks[:limit]
faces = []
for landmark in landmarks:
facial5points = [[landmark[2 * j], landmark[2 * j + 1]] for j in range(5)]
warped_face = warp_and_crop_face(
np.array(image), facial5points, self.reference, crop_size=(112, 112)
)
faces.append(warped_face)
return np.concatenate((boxes, landmarks), axis=1), faces
def align_multi(self, img, conf_threshold=0.8, limit=None):
rlt = self.detect_faces(img, conf_threshold=conf_threshold)
boxes, landmarks = rlt[:, 0:5], rlt[:, 5:]
return self.__align_multi(img, boxes, landmarks, limit)
# batched detection
def batched_transform(self, frames, use_origin_size):
"""
Arguments:
frames: a list of PIL.Image, or torch.Tensor(shape=[n, h, w, c],
type=np.float32, BGR format).
use_origin_size: whether to use origin size.
"""
from_PIL = True if isinstance(frames[0], Image.Image) else False
# convert to opencv format
if from_PIL:
frames = [
cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames
]
frames = np.asarray(frames, dtype=np.float32)
# testing scale
im_size_min = np.min(frames[0].shape[0:2])
im_size_max = np.max(frames[0].shape[0:2])
resize = float(self.target_size) / float(im_size_min)
# prevent bigger axis from being more than max_size
if np.round(resize * im_size_max) > self.max_size:
resize = float(self.max_size) / float(im_size_max)
resize = 1 if use_origin_size else resize
# resize
if resize != 1:
if not from_PIL:
frames = F.interpolate(frames, scale_factor=resize)
else:
frames = [
cv2.resize(
frame,
None,
None,
fx=resize,
fy=resize,
interpolation=cv2.INTER_LINEAR,
)
for frame in frames
]
# convert to torch.tensor format
if not from_PIL:
frames = frames.transpose(1, 2).transpose(1, 3).contiguous()
else:
frames = frames.transpose((0, 3, 1, 2))
frames = torch.from_numpy(frames)
return frames, resize
def batched_detect_faces(
self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True
):
"""
Arguments:
frames: a list of PIL.Image, or np.array(shape=[n, h, w, c],
type=np.uint8, BGR format).
conf_threshold: confidence threshold.
nms_threshold: nms threshold.
use_origin_size: whether to use origin size.
Returns:
final_bounding_boxes: list of np.array ([n_boxes, 5],
type=np.float32).
final_landmarks: list of np.array ([n_boxes, 10], type=np.float32).
"""
# self.t['forward_pass'].tic()
frames, self.resize = self.batched_transform(frames, use_origin_size)
frames = frames.to(self.device)
frames = frames - self.mean_tensor
b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames)
final_bounding_boxes, final_landmarks = [], []
# decode
priors = priors.unsqueeze(0)
b_loc = (
batched_decode(b_loc, priors, self.cfg["variance"])
* self.scale
/ self.resize
)
b_landmarks = (
batched_decode_landm(b_landmarks, priors, self.cfg["variance"])
* self.scale1
/ self.resize
)
b_conf = b_conf[:, :, 1]
# index for selection
b_indice = b_conf > conf_threshold
# concat
b_loc_and_conf = torch.cat((b_loc, b_conf.unsqueeze(-1)), dim=2).float()
for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice):
# ignore low scores
pred, landm = pred[inds, :], landm[inds, :]
if pred.shape[0] == 0:
final_bounding_boxes.append(np.array([], dtype=np.float32))
final_landmarks.append(np.array([], dtype=np.float32))
continue
# sort
# order = score.argsort(descending=True)
# box, landm, score = box[order], landm[order], score[order]
# to CPU
bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy()
# NMS
keep = py_cpu_nms(bounding_boxes, nms_threshold)
bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep]
# append
final_bounding_boxes.append(bounding_boxes)
final_landmarks.append(landmarks)
# self.t['forward_pass'].toc(average=True)
# self.batch_time += self.t['forward_pass'].diff
# self.total_frame += len(frames)
# print(self.batch_time / self.total_frame)
return final_bounding_boxes, final_landmarks
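
A sketch of the high-level detect-and-align API on the class above; the import path is assumed and the noise image is only a stand-in, so expect real detections only on an actual photo.

import numpy as np
import torch
from iopaint.plugins.facexlib.detection import init_detection_model  # import path assumed

det_net = init_detection_model("retinaface_resnet50", half=False, device="cpu")  # builds the RetinaFace above
img = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)  # stand-in for a real BGR photo
with torch.no_grad():
    bboxes_landms, aligned_faces = det_net.align_multi(img, conf_threshold=0.8)
# bboxes_landms: [n, 15] boxes + scores + landmarks; aligned_faces: list of 112x112x3 BGR crops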


@@ -0,0 +1,196 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
def conv_bn(inp, oup, stride=1, leaky=0):
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup),
nn.LeakyReLU(negative_slope=leaky, inplace=True))
def conv_bn_no_relu(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm2d(oup),
)
def conv_bn1X1(inp, oup, stride, leaky=0):
return nn.Sequential(
nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup),
nn.LeakyReLU(negative_slope=leaky, inplace=True))
def conv_dw(inp, oup, stride, leaky=0.1):
return nn.Sequential(
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
nn.BatchNorm2d(inp),
nn.LeakyReLU(negative_slope=leaky, inplace=True),
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.LeakyReLU(negative_slope=leaky, inplace=True),
)
class SSH(nn.Module):
def __init__(self, in_channel, out_channel):
super(SSH, self).__init__()
assert out_channel % 4 == 0
leaky = 0
if (out_channel <= 64):
leaky = 0.1
self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1)
self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky)
self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky)
self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
def forward(self, input):
conv3X3 = self.conv3X3(input)
conv5X5_1 = self.conv5X5_1(input)
conv5X5 = self.conv5X5_2(conv5X5_1)
conv7X7_2 = self.conv7X7_2(conv5X5_1)
conv7X7 = self.conv7x7_3(conv7X7_2)
out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
out = F.relu(out)
return out
class FPN(nn.Module):
def __init__(self, in_channels_list, out_channels):
super(FPN, self).__init__()
leaky = 0
if (out_channels <= 64):
leaky = 0.1
self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky)
self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky)
self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky)
self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)
def forward(self, input):
# names = list(input.keys())
# input = list(input.values())
output1 = self.output1(input[0])
output2 = self.output2(input[1])
output3 = self.output3(input[2])
up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest')
output2 = output2 + up3
output2 = self.merge2(output2)
up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest')
output1 = output1 + up2
output1 = self.merge1(output1)
out = [output1, output2, output3]
return out
class MobileNetV1(nn.Module):
def __init__(self):
super(MobileNetV1, self).__init__()
self.stage1 = nn.Sequential(
conv_bn(3, 8, 2, leaky=0.1), # 3
conv_dw(8, 16, 1), # 7
conv_dw(16, 32, 2), # 11
conv_dw(32, 32, 1), # 19
conv_dw(32, 64, 2), # 27
conv_dw(64, 64, 1), # 43
)
self.stage2 = nn.Sequential(
conv_dw(64, 128, 2), # 43 + 16 = 59
conv_dw(128, 128, 1), # 59 + 32 = 91
conv_dw(128, 128, 1), # 91 + 32 = 123
conv_dw(128, 128, 1), # 123 + 32 = 155
conv_dw(128, 128, 1), # 155 + 32 = 187
conv_dw(128, 128, 1), # 187 + 32 = 219
)
self.stage3 = nn.Sequential(
conv_dw(128, 256, 2), # 219 +3 2 = 241
conv_dw(256, 256, 1), # 241 + 64 = 301
)
self.avg = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(256, 1000)
def forward(self, x):
x = self.stage1(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.avg(x)
# x = self.model(x)
x = x.view(-1, 256)
x = self.fc(x)
return x
class ClassHead(nn.Module):
def __init__(self, inchannels=512, num_anchors=3):
super(ClassHead, self).__init__()
self.num_anchors = num_anchors
self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0)
def forward(self, x):
out = self.conv1x1(x)
out = out.permute(0, 2, 3, 1).contiguous()
return out.view(out.shape[0], -1, 2)
class BboxHead(nn.Module):
def __init__(self, inchannels=512, num_anchors=3):
super(BboxHead, self).__init__()
self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0)
def forward(self, x):
out = self.conv1x1(x)
out = out.permute(0, 2, 3, 1).contiguous()
return out.view(out.shape[0], -1, 4)
class LandmarkHead(nn.Module):
def __init__(self, inchannels=512, num_anchors=3):
super(LandmarkHead, self).__init__()
self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0)
def forward(self, x):
out = self.conv1x1(x)
out = out.permute(0, 2, 3, 1).contiguous()
return out.view(out.shape[0], -1, 10)
def make_class_head(fpn_num=3, inchannels=64, anchor_num=2):
classhead = nn.ModuleList()
for i in range(fpn_num):
classhead.append(ClassHead(inchannels, anchor_num))
return classhead
def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2):
bboxhead = nn.ModuleList()
for i in range(fpn_num):
bboxhead.append(BboxHead(inchannels, anchor_num))
return bboxhead
def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2):
landmarkhead = nn.ModuleList()
for i in range(fpn_num):
landmarkhead.append(LandmarkHead(inchannels, anchor_num))
return landmarkhead
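
A quick shape check for the head builders above, with illustrative FPN feature-map sizes and an assumed import path.

import torch
from iopaint.plugins.facexlib.detection.retinaface_net import make_bbox_head, make_class_head  # import path assumed

feats = [torch.randn(1, 64, s, s) for s in (80, 40, 20)]    # three FPN levels, 64 channels each
class_heads = make_class_head(fpn_num=3, inchannels=64, anchor_num=2)
bbox_heads = make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2)
cls = torch.cat([class_heads[i](f) for i, f in enumerate(feats)], dim=1)
box = torch.cat([bbox_heads[i](f) for i, f in enumerate(feats)], dim=1)
print(cls.shape, box.shape)  # [1, 16800, 2] and [1, 16800, 4]; 16800 = (80*80 + 40*40 + 20*20) * 2 anchors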


@@ -0,0 +1,421 @@
import numpy as np
import torch
import torchvision
from itertools import product as product
from math import ceil
class PriorBox(object):
def __init__(self, cfg, image_size=None, phase='train'):
super(PriorBox, self).__init__()
self.min_sizes = cfg['min_sizes']
self.steps = cfg['steps']
self.clip = cfg['clip']
self.image_size = image_size
self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps]
self.name = 's'
def forward(self):
anchors = []
for k, f in enumerate(self.feature_maps):
min_sizes = self.min_sizes[k]
for i, j in product(range(f[0]), range(f[1])):
for min_size in min_sizes:
s_kx = min_size / self.image_size[1]
s_ky = min_size / self.image_size[0]
dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
for cy, cx in product(dense_cy, dense_cx):
anchors += [cx, cy, s_kx, s_ky]
# back to torch land
output = torch.Tensor(anchors).view(-1, 4)
if self.clip:
output.clamp_(max=1, min=0)
return output
def py_cpu_nms(dets, thresh):
"""Pure Python NMS baseline."""
keep = torchvision.ops.nms(
boxes=torch.Tensor(dets[:, :4]),
scores=torch.Tensor(dets[:, 4]),
iou_threshold=thresh,
)
return list(keep)
def point_form(boxes):
""" Convert prior_boxes to (xmin, ymin, xmax, ymax)
representation for comparison to point form ground truth data.
Args:
boxes: (tensor) center-size default boxes from priorbox layers.
Return:
boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
"""
return torch.cat(
(
boxes[:, :2] - boxes[:, 2:] / 2, # xmin, ymin
boxes[:, :2] + boxes[:, 2:] / 2),
1) # xmax, ymax
def center_size(boxes):
    """ Convert prior_boxes to (cx, cy, w, h)
    representation for comparison to center-size form ground truth data.
    Args:
        boxes: (tensor) point_form boxes
    Return:
        boxes: (tensor) Converted (cx, cy, w, h) form of boxes.
    """
    return torch.cat(
        ((boxes[:, 2:] + boxes[:, :2]) / 2,  # cx, cy
         boxes[:, 2:] - boxes[:, :2]),  # w, h
        1)
def intersect(box_a, box_b):
""" We resize both tensors to [A,B,2] without new malloc:
[A,2] -> [A,1,2] -> [A,B,2]
[B,2] -> [1,B,2] -> [A,B,2]
Then we compute the area of intersect between box_a and box_b.
Args:
box_a: (tensor) bounding boxes, Shape: [A,4].
box_b: (tensor) bounding boxes, Shape: [B,4].
Return:
(tensor) intersection area, Shape: [A,B].
"""
A = box_a.size(0)
B = box_b.size(0)
max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
inter = torch.clamp((max_xy - min_xy), min=0)
return inter[:, :, 0] * inter[:, :, 1]
def jaccard(box_a, box_b):
"""Compute the jaccard overlap of two sets of boxes. The jaccard overlap
is simply the intersection over union of two boxes. Here we operate on
ground truth boxes and default boxes.
E.g.:
        A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
Args:
box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
Return:
jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
"""
inter = intersect(box_a, box_b)
area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
union = area_a + area_b - inter
return inter / union # [A,B]
def matrix_iou(a, b):
"""
    return iou of a and b, numpy version for data augmentation
"""
lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
return area_i / (area_a[:, np.newaxis] + area_b - area_i)
def matrix_iof(a, b):
"""
    return iof of a and b, numpy version for data augmentation
"""
lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
return area_i / np.maximum(area_a[:, np.newaxis], 1)
def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
"""Match each prior box with the ground truth box of the highest jaccard
overlap, encode the bounding boxes, then return the matched indices
corresponding to both confidence and location preds.
Args:
threshold: (float) The overlap threshold used when matching boxes.
truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
variances: (tensor) Variances corresponding to each prior coord,
Shape: [num_priors, 4].
labels: (tensor) All the class labels for the image, Shape: [num_obj].
landms: (tensor) Ground truth landms, Shape [num_obj, 10].
loc_t: (tensor) Tensor to be filled w/ encoded location targets.
conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
landm_t: (tensor) Tensor to be filled w/ encoded landm targets.
idx: (int) current batch index
    Return:
        Nothing is returned; the encoded location, confidence and landmark
        targets are written into loc_t, conf_t and landm_t in place.
"""
# jaccard index
overlaps = jaccard(truths, point_form(priors))
# (Bipartite Matching)
# [1,num_objects] best prior for each ground truth
best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
# ignore hard gt
valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
if best_prior_idx_filter.shape[0] <= 0:
loc_t[idx] = 0
conf_t[idx] = 0
return
# [1,num_priors] best ground truth for each prior
best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
best_truth_idx.squeeze_(0)
best_truth_overlap.squeeze_(0)
best_prior_idx.squeeze_(1)
best_prior_idx_filter.squeeze_(1)
best_prior_overlap.squeeze_(1)
best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior
# TODO refactor: index best_prior_idx with long tensor
# ensure every gt matches with its prior of max overlap
    for j in range(best_prior_idx.size(0)):  # assign each ground truth to the anchor it overlaps most
        best_truth_idx[best_prior_idx[j]] = j
    matches = truths[best_truth_idx]  # Shape: [num_priors, 4]; ground-truth box matched to each anchor
    conf = labels[best_truth_idx]  # Shape: [num_priors]; class label matched to each anchor
    conf[best_truth_overlap < threshold] = 0  # label anchors below the overlap threshold as background
loc = encode(matches, priors, variances)
matches_landm = landms[best_truth_idx]
landm = encode_landm(matches_landm, priors, variances)
loc_t[idx] = loc # [num_priors,4] encoded offsets to learn
conf_t[idx] = conf # [num_priors] top class label for each prior
landm_t[idx] = landm
def encode(matched, priors, variances):
"""Encode the variances from the priorbox layers into the ground truth boxes
we have matched (based on jaccard overlap) with the prior boxes.
Args:
matched: (tensor) Coords of ground truth for each prior in point-form
Shape: [num_priors, 4].
priors: (tensor) Prior boxes in center-offset form
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
encoded boxes (tensor), Shape: [num_priors, 4]
"""
# dist b/t match center and prior's center
g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
# encode variance
g_cxcy /= (variances[0] * priors[:, 2:])
# match wh / prior wh
g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
g_wh = torch.log(g_wh) / variances[1]
# return target for smooth_l1_loss
return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
def encode_landm(matched, priors, variances):
"""Encode the variances from the priorbox layers into the ground truth boxes
we have matched (based on jaccard overlap) with the prior boxes.
Args:
matched: (tensor) Coords of ground truth for each prior in point-form
Shape: [num_priors, 10].
priors: (tensor) Prior boxes in center-offset form
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
encoded landm (tensor), Shape: [num_priors, 10]
"""
# dist b/t match center and prior's center
matched = torch.reshape(matched, (matched.size(0), 5, 2))
priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2)
g_cxcy = matched[:, :, :2] - priors[:, :, :2]
# encode variance
g_cxcy /= (variances[0] * priors[:, :, 2:])
# g_cxcy /= priors[:, :, 2:]
g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1)
# return target for smooth_l1_loss
return g_cxcy
# Adapted from https://github.com/Hakuyume/chainer-ssd
def decode(loc, priors, variances):
"""Decode locations from predictions using priors to undo
the encoding we did for offset regression at train time.
Args:
loc (tensor): location predictions for loc layers,
Shape: [num_priors,4]
priors (tensor): Prior boxes in center-offset form.
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
decoded bounding box predictions
"""
boxes = torch.cat((priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
boxes[:, :2] -= boxes[:, 2:] / 2
boxes[:, 2:] += boxes[:, :2]
return boxes
def decode_landm(pre, priors, variances):
"""Decode landm from predictions using priors to undo
the encoding we did for offset regression at train time.
Args:
pre (tensor): landm predictions for loc layers,
Shape: [num_priors,10]
priors (tensor): Prior boxes in center-offset form.
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
decoded landm predictions
"""
tmp = (
priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
)
landms = torch.cat(tmp, dim=1)
return landms
def batched_decode(b_loc, priors, variances):
"""Decode locations from predictions using priors to undo
the encoding we did for offset regression at train time.
Args:
b_loc (tensor): location predictions for loc layers,
Shape: [num_batches,num_priors,4]
priors (tensor): Prior boxes in center-offset form.
Shape: [1,num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
decoded bounding box predictions
"""
boxes = (
priors[:, :, :2] + b_loc[:, :, :2] * variances[0] * priors[:, :, 2:],
priors[:, :, 2:] * torch.exp(b_loc[:, :, 2:] * variances[1]),
)
boxes = torch.cat(boxes, dim=2)
boxes[:, :, :2] -= boxes[:, :, 2:] / 2
boxes[:, :, 2:] += boxes[:, :, :2]
return boxes
def batched_decode_landm(pre, priors, variances):
"""Decode landm from predictions using priors to undo
the encoding we did for offset regression at train time.
Args:
pre (tensor): landm predictions for loc layers,
Shape: [num_batches,num_priors,10]
priors (tensor): Prior boxes in center-offset form.
Shape: [1,num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
decoded landm predictions
"""
landms = (
priors[:, :, :2] + pre[:, :, :2] * variances[0] * priors[:, :, 2:],
priors[:, :, :2] + pre[:, :, 2:4] * variances[0] * priors[:, :, 2:],
priors[:, :, :2] + pre[:, :, 4:6] * variances[0] * priors[:, :, 2:],
priors[:, :, :2] + pre[:, :, 6:8] * variances[0] * priors[:, :, 2:],
priors[:, :, :2] + pre[:, :, 8:10] * variances[0] * priors[:, :, 2:],
)
landms = torch.cat(landms, dim=2)
return landms
def log_sum_exp(x):
"""Utility function for computing log_sum_exp while determining
This will be used to determine unaveraged confidence loss across
all examples in a batch.
Args:
x (Variable(tensor)): conf_preds from conf layers
"""
x_max = x.data.max()
return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max
# Original author: Francisco Massa:
# https://github.com/fmassa/object-detection.torch
# Ported to PyTorch by Max deGroot (02/01/2017)
def nms(boxes, scores, overlap=0.5, top_k=200):
"""Apply non-maximum suppression at test time to avoid detecting too many
overlapping bounding boxes for a given object.
Args:
boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
scores: (tensor) The class predscores for the img, Shape:[num_priors].
overlap: (float) The overlap thresh for suppressing unnecessary boxes.
top_k: (int) The Maximum number of box preds to consider.
Return:
The indices of the kept boxes with respect to num_priors.
"""
keep = torch.Tensor(scores.size(0)).fill_(0).long()
if boxes.numel() == 0:
return keep
x1 = boxes[:, 0]
y1 = boxes[:, 1]
x2 = boxes[:, 2]
y2 = boxes[:, 3]
area = torch.mul(x2 - x1, y2 - y1)
v, idx = scores.sort(0) # sort in ascending order
# I = I[v >= 0.01]
idx = idx[-top_k:] # indices of the top-k largest vals
xx1 = boxes.new()
yy1 = boxes.new()
xx2 = boxes.new()
yy2 = boxes.new()
w = boxes.new()
h = boxes.new()
# keep = torch.Tensor()
count = 0
while idx.numel() > 0:
i = idx[-1] # index of current largest val
# keep.append(i)
keep[count] = i
count += 1
if idx.size(0) == 1:
break
idx = idx[:-1] # remove kept element from view
# load bboxes of next highest vals
torch.index_select(x1, 0, idx, out=xx1)
torch.index_select(y1, 0, idx, out=yy1)
torch.index_select(x2, 0, idx, out=xx2)
torch.index_select(y2, 0, idx, out=yy2)
# store element-wise max with next highest score
xx1 = torch.clamp(xx1, min=x1[i])
yy1 = torch.clamp(yy1, min=y1[i])
xx2 = torch.clamp(xx2, max=x2[i])
yy2 = torch.clamp(yy2, max=y2[i])
w.resize_as_(xx2)
h.resize_as_(yy2)
w = xx2 - xx1
h = yy2 - yy1
# check sizes of xx1 and xx2.. after each iteration
w = torch.clamp(w, min=0.0)
h = torch.clamp(h, min=0.0)
inter = w * h
# IoU = i / (area(a) + area(b) - i)
        rem_areas = torch.index_select(area, 0, idx)  # load remaining areas
union = (rem_areas - inter) + area[i]
IoU = inter / union # store result in iou
# keep only elements with an IoU <= overlap
idx = idx[IoU.le(overlap)]
return keep, count
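
A sketch tying the prior-box and decode utilities above together; the config values mirror cfg_mnet in retinaface.py and the import path is assumed.

import torch
from iopaint.plugins.facexlib.detection.retinaface_utils import PriorBox, decode  # import path assumed

cfg = {"min_sizes": [[16, 32], [64, 128], [256, 512]], "steps": [8, 16, 32], "clip": False}
priors = PriorBox(cfg, image_size=(640, 640)).forward()   # [16800, 4] anchors in (cx, cy, w, h) form
loc = torch.zeros_like(priors)                            # zero offsets decode back onto the priors
boxes = decode(loc, priors, variances=[0.1, 0.2])         # [16800, 4] boxes in (x1, y1, x2, y2) form
print(priors.shape, boxes.shape)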


@@ -0,0 +1,24 @@
import torch
from ..utils import load_file_from_url
from .bisenet import BiSeNet
from .parsenet import ParseNet
def init_parsing_model(model_name='bisenet', half=False, device='cuda', model_rootpath=None):
if model_name == 'bisenet':
model = BiSeNet(num_class=19)
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/parsing_bisenet.pth'
elif model_name == 'parsenet':
model = ParseNet(in_size=512, out_size=512, parsing_ch=19)
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth'
else:
raise NotImplementedError(f'{model_name} is not implemented.')
model_path = load_file_from_url(
url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
model.load_state_dict(load_net, strict=True)
model.eval()
model = model.to(device)
return model
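
A usage sketch for the parsing factory above; the import path is assumed and the BiSeNet checkpoint is downloaded on first use.

import torch
from iopaint.plugins.facexlib.parsing import init_parsing_model  # import path assumed

parse_net = init_parsing_model(model_name="bisenet", device="cpu")
face = torch.randn(1, 3, 512, 512)       # illustrative tensor standing in for a normalized face crop
with torch.no_grad():
    out = parse_net(face)[0]             # main BiSeNet output: [1, 19, 512, 512] logits
labels = out.argmax(dim=1)               # [1, 512, 512] per-pixel face-region indices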


@@ -0,0 +1,140 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .resnet import ResNet18
class ConvBNReLU(nn.Module):
def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False)
self.bn = nn.BatchNorm2d(out_chan)
def forward(self, x):
x = self.conv(x)
x = F.relu(self.bn(x))
return x
class BiSeNetOutput(nn.Module):
def __init__(self, in_chan, mid_chan, num_class):
super(BiSeNetOutput, self).__init__()
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False)
def forward(self, x):
feat = self.conv(x)
out = self.conv_out(feat)
return out, feat
class AttentionRefinementModule(nn.Module):
def __init__(self, in_chan, out_chan):
super(AttentionRefinementModule, self).__init__()
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
self.bn_atten = nn.BatchNorm2d(out_chan)
self.sigmoid_atten = nn.Sigmoid()
def forward(self, x):
feat = self.conv(x)
atten = F.avg_pool2d(feat, feat.size()[2:])
atten = self.conv_atten(atten)
atten = self.bn_atten(atten)
atten = self.sigmoid_atten(atten)
out = torch.mul(feat, atten)
return out
class ContextPath(nn.Module):
def __init__(self):
super(ContextPath, self).__init__()
self.resnet = ResNet18()
self.arm16 = AttentionRefinementModule(256, 128)
self.arm32 = AttentionRefinementModule(512, 128)
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
def forward(self, x):
feat8, feat16, feat32 = self.resnet(x)
h8, w8 = feat8.size()[2:]
h16, w16 = feat16.size()[2:]
h32, w32 = feat32.size()[2:]
avg = F.avg_pool2d(feat32, feat32.size()[2:])
avg = self.conv_avg(avg)
avg_up = F.interpolate(avg, (h32, w32), mode='nearest')
feat32_arm = self.arm32(feat32)
feat32_sum = feat32_arm + avg_up
feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest')
feat32_up = self.conv_head32(feat32_up)
feat16_arm = self.arm16(feat16)
feat16_sum = feat16_arm + feat32_up
feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest')
feat16_up = self.conv_head16(feat16_up)
return feat8, feat16_up, feat32_up # x8, x8, x16
class FeatureFusionModule(nn.Module):
def __init__(self, in_chan, out_chan):
super(FeatureFusionModule, self).__init__()
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False)
self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
self.relu = nn.ReLU(inplace=True)
self.sigmoid = nn.Sigmoid()
def forward(self, fsp, fcp):
fcat = torch.cat([fsp, fcp], dim=1)
feat = self.convblk(fcat)
atten = F.avg_pool2d(feat, feat.size()[2:])
atten = self.conv1(atten)
atten = self.relu(atten)
atten = self.conv2(atten)
atten = self.sigmoid(atten)
feat_atten = torch.mul(feat, atten)
feat_out = feat_atten + feat
return feat_out
class BiSeNet(nn.Module):
def __init__(self, num_class):
super(BiSeNet, self).__init__()
self.cp = ContextPath()
self.ffm = FeatureFusionModule(256, 256)
self.conv_out = BiSeNetOutput(256, 256, num_class)
self.conv_out16 = BiSeNetOutput(128, 64, num_class)
self.conv_out32 = BiSeNetOutput(128, 64, num_class)
def forward(self, x, return_feat=False):
h, w = x.size()[2:]
feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature
feat_sp = feat_res8 # replace spatial path feature with res3b1 feature
feat_fuse = self.ffm(feat_sp, feat_cp8)
out, feat = self.conv_out(feat_fuse)
out16, feat16 = self.conv_out16(feat_cp8)
out32, feat32 = self.conv_out32(feat_cp16)
out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True)
out16 = F.interpolate(out16, (h, w), mode='bilinear', align_corners=True)
out32 = F.interpolate(out32, (h, w), mode='bilinear', align_corners=True)
if return_feat:
feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True)
feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True)
feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True)
return out, out16, out32, feat, feat16, feat32
else:
return out, out16, out32
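For orientation, here is a minimal sketch of running the parsing network above on a single aligned face crop. It assumes the classes in this file are in scope and that the input is a normalized 512 x 512 RGB tensor; the values are placeholders, not a recommended preprocessing recipe.
import torch

# Sketch only: BiSeNet as defined above, with 19 face-parsing classes to
# match the MASK_COLORMAP used later in FaceRestoreHelper.
net = BiSeNet(num_class=19).eval()
face = torch.rand(1, 3, 512, 512)  # placeholder normalized RGB face crop
with torch.no_grad():
    out, out16, out32 = net(face)  # each (1, 19, 512, 512) after upsampling
parsing = out.argmax(dim=1)        # (1, 512, 512) per-pixel label map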

View File

@ -0,0 +1,194 @@
"""Modified from https://github.com/chaofengc/PSFRGAN
"""
import numpy as np
import torch.nn as nn
from torch.nn import functional as F
class NormLayer(nn.Module):
"""Normalization Layers.
Args:
channels: input channels, for batch norm and instance norm.
normalize_shape: input shape without batch size, for layer norm.
"""
def __init__(self, channels, normalize_shape=None, norm_type='bn'):
super(NormLayer, self).__init__()
norm_type = norm_type.lower()
self.norm_type = norm_type
if norm_type == 'bn':
self.norm = nn.BatchNorm2d(channels, affine=True)
elif norm_type == 'in':
self.norm = nn.InstanceNorm2d(channels, affine=False)
elif norm_type == 'gn':
self.norm = nn.GroupNorm(32, channels, affine=True)
elif norm_type == 'pixel':
self.norm = lambda x: F.normalize(x, p=2, dim=1)
elif norm_type == 'layer':
self.norm = nn.LayerNorm(normalize_shape)
elif norm_type == 'none':
self.norm = lambda x: x * 1.0
else:
raise ValueError(f'Norm type {norm_type} is not supported.')
def forward(self, x, ref=None):
if self.norm_type == 'spade':
return self.norm(x, ref)
else:
return self.norm(x)
class ReluLayer(nn.Module):
"""Relu Layer.
Args:
relu_type: type of relu layer; candidates are:
- ReLU
- LeakyReLU: default relu slope 0.2
- PRelu
- SELU
- none: direct pass
"""
def __init__(self, channels, relu_type='relu'):
super(ReluLayer, self).__init__()
relu_type = relu_type.lower()
if relu_type == 'relu':
self.func = nn.ReLU(True)
elif relu_type == 'leakyrelu':
self.func = nn.LeakyReLU(0.2, inplace=True)
elif relu_type == 'prelu':
self.func = nn.PReLU(channels)
elif relu_type == 'selu':
self.func = nn.SELU(True)
elif relu_type == 'none':
self.func = lambda x: x * 1.0
else:
raise ValueError(f'Relu type {relu_type} is not supported.')
def forward(self, x):
return self.func(x)
class ConvLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
scale='none',
norm_type='none',
relu_type='none',
use_pad=True,
bias=True):
super(ConvLayer, self).__init__()
self.use_pad = use_pad
self.norm_type = norm_type
if norm_type in ['bn']:
bias = False
stride = 2 if scale == 'down' else 1
self.scale_func = lambda x: x
if scale == 'up':
self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest')
self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) / 2)))
self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)
self.relu = ReluLayer(out_channels, relu_type)
self.norm = NormLayer(out_channels, norm_type=norm_type)
def forward(self, x):
out = self.scale_func(x)
if self.use_pad:
out = self.reflection_pad(out)
out = self.conv2d(out)
out = self.norm(out)
out = self.relu(out)
return out
class ResidualBlock(nn.Module):
"""
Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html
"""
def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'):
super(ResidualBlock, self).__init__()
if scale == 'none' and c_in == c_out:
self.shortcut_func = lambda x: x
else:
self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)
scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']}
scale_conf = scale_config_dict[scale]
self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type)
self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none')
def forward(self, x):
identity = self.shortcut_func(x)
res = self.conv1(x)
res = self.conv2(res)
return identity + res
class ParseNet(nn.Module):
def __init__(self,
in_size=128,
out_size=128,
min_feat_size=32,
base_ch=64,
parsing_ch=19,
res_depth=10,
relu_type='LeakyReLU',
norm_type='bn',
ch_range=[32, 256]):
super().__init__()
self.res_depth = res_depth
act_args = {'norm_type': norm_type, 'relu_type': relu_type}
min_ch, max_ch = ch_range
ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731
min_feat_size = min(in_size, min_feat_size)
down_steps = int(np.log2(in_size // min_feat_size))
up_steps = int(np.log2(out_size // min_feat_size))
# =============== define encoder-body-decoder ====================
self.encoder = []
self.encoder.append(ConvLayer(3, base_ch, 3, 1))
head_ch = base_ch
for i in range(down_steps):
cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args))
head_ch = head_ch * 2
self.body = []
for i in range(res_depth):
self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args))
self.decoder = []
for i in range(up_steps):
cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args))
head_ch = head_ch // 2
self.encoder = nn.Sequential(*self.encoder)
self.body = nn.Sequential(*self.body)
self.decoder = nn.Sequential(*self.decoder)
self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)
def forward(self, x):
feat = self.encoder(x)
x = feat + self.body(feat)
x = self.decoder(x)
out_img = self.out_img_conv(x)
out_mask = self.out_mask_conv(x)
return out_mask, out_img
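A minimal usage sketch for the network above, assuming the default constructor arguments and same-module scope. ParseNet predicts a 19-channel parsing map alongside a coarse RGB reconstruction of the face.
import torch

# Sketch only: randomly initialized, so the outputs are meaningless here.
net = ParseNet(in_size=128, out_size=128, parsing_ch=19).eval()
x = torch.rand(1, 3, 128, 128)      # placeholder face crop
with torch.no_grad():
    out_mask, out_img = net(x)      # (1, 19, 128, 128), (1, 3, 128, 128)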

View File

@ -0,0 +1,69 @@
import torch.nn as nn
import torch.nn.functional as F
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
class BasicBlock(nn.Module):
def __init__(self, in_chan, out_chan, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(in_chan, out_chan, stride)
self.bn1 = nn.BatchNorm2d(out_chan)
self.conv2 = conv3x3(out_chan, out_chan)
self.bn2 = nn.BatchNorm2d(out_chan)
self.relu = nn.ReLU(inplace=True)
self.downsample = None
if in_chan != out_chan or stride != 1:
self.downsample = nn.Sequential(
nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_chan),
)
def forward(self, x):
residual = self.conv1(x)
residual = F.relu(self.bn1(residual))
residual = self.conv2(residual)
residual = self.bn2(residual)
shortcut = x
if self.downsample is not None:
shortcut = self.downsample(x)
out = shortcut + residual
out = self.relu(out)
return out
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
layers = [BasicBlock(in_chan, out_chan, stride=stride)]
for i in range(bnum - 1):
layers.append(BasicBlock(out_chan, out_chan, stride=1))
return nn.Sequential(*layers)
class ResNet18(nn.Module):
def __init__(self):
super(ResNet18, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
def forward(self, x):
x = self.conv1(x)
x = F.relu(self.bn1(x))
x = self.maxpool(x)
x = self.layer1(x)
feat8 = self.layer2(x) # 1/8
feat16 = self.layer3(feat8) # 1/16
feat32 = self.layer4(feat16) # 1/32
return feat8, feat16, feat32
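The strides are what ContextPath above relies on. A quick sketch (same-module scope assumed) of the feature map shapes for a 512 x 512 input:
import torch

backbone = ResNet18().eval()
x = torch.rand(1, 3, 512, 512)
with torch.no_grad():
    feat8, feat16, feat32 = backbone(x)
# feat8:  (1, 128, 64, 64)  -> 1/8 resolution
# feat16: (1, 256, 32, 32)  -> 1/16 resolution
# feat32: (1, 512, 16, 16)  -> 1/32 resolution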

View File

@ -0,0 +1,7 @@
from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back
from .misc import img2tensor, load_file_from_url, scandir
__all__ = [
'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url', 'paste_face_back',
'img2tensor', 'scandir'
]

View File

@ -0,0 +1,473 @@
import cv2
import numpy as np
import os
import torch
from torchvision.transforms.functional import normalize
from ..detection import init_detection_model
from ..parsing import init_parsing_model
from ..utils.misc import img2tensor, imwrite
def get_largest_face(det_faces, h, w):
def get_location(val, length):
if val < 0:
return 0
elif val > length:
return length
else:
return val
face_areas = []
for det_face in det_faces:
left = get_location(det_face[0], w)
right = get_location(det_face[2], w)
top = get_location(det_face[1], h)
bottom = get_location(det_face[3], h)
face_area = (right - left) * (bottom - top)
face_areas.append(face_area)
largest_idx = face_areas.index(max(face_areas))
return det_faces[largest_idx], largest_idx
def get_center_face(det_faces, h=0, w=0, center=None):
if center is not None:
center = np.array(center)
else:
center = np.array([w / 2, h / 2])
center_dist = []
for det_face in det_faces:
face_center = np.array(
[(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2]
)
dist = np.linalg.norm(face_center - center)
center_dist.append(dist)
center_idx = center_dist.index(min(center_dist))
return det_faces[center_idx], center_idx
class FaceRestoreHelper(object):
"""Helper for the face restoration pipeline (base class)."""
def __init__(
self,
upscale_factor,
face_size=512,
crop_ratio=(1, 1),
det_model="retinaface_resnet50",
save_ext="png",
template_3points=False,
pad_blur=False,
use_parse=False,
device=None,
model_rootpath=None,
):
self.template_3points = template_3points # improve robustness
self.upscale_factor = upscale_factor
# the cropped face ratio based on the square face
self.crop_ratio = crop_ratio # (h, w)
assert (
self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1
), "crop ration only supports >=1"
self.face_size = (
int(face_size * self.crop_ratio[1]),
int(face_size * self.crop_ratio[0]),
)
if self.template_3points:
self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
else:
# standard 5 landmarks for FFHQ faces with 512 x 512
self.face_template = np.array(
[
[192.98138, 239.94708],
[318.90277, 240.1936],
[256.63416, 314.01935],
[201.26117, 371.41043],
[313.08905, 371.15118],
]
)
self.face_template = self.face_template * (face_size / 512.0)
if self.crop_ratio[0] > 1:
self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
if self.crop_ratio[1] > 1:
self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
self.save_ext = save_ext
self.pad_blur = pad_blur
if self.pad_blur is True:
self.template_3points = False
self.all_landmarks_5 = []
self.det_faces = []
self.affine_matrices = []
self.inverse_affine_matrices = []
self.cropped_faces = []
self.restored_faces = []
self.pad_input_imgs = []
if device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = device
# init face detection model
self.face_det = init_detection_model(
det_model, half=False, device=self.device, model_rootpath=model_rootpath
)
# init face parsing model
self.use_parse = use_parse
self.face_parse = init_parsing_model(
model_name="parsenet", device=self.device, model_rootpath=model_rootpath
)
def set_upscale_factor(self, upscale_factor):
self.upscale_factor = upscale_factor
def read_image(self, img):
"""img can be image path or cv2 loaded image."""
# self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
if isinstance(img, str):
img = cv2.imread(img)
if np.max(img) > 256: # 16-bit image
img = img / 65535 * 255
if len(img.shape) == 2: # gray image
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif img.shape[2] == 4: # RGBA image with alpha channel
img = img[:, :, 0:3]
self.input_img = img
def get_face_landmarks_5(
self,
only_keep_largest=False,
only_center_face=False,
resize=None,
blur_ratio=0.01,
eye_dist_threshold=None,
):
if resize is None:
scale = 1
input_img = self.input_img
else:
h, w = self.input_img.shape[0:2]
scale = min(h, w) / resize
h, w = int(h / scale), int(w / scale)
input_img = cv2.resize(
self.input_img, (w, h), interpolation=cv2.INTER_LANCZOS4
)
with torch.no_grad():
bboxes = self.face_det.detect_faces(input_img, 0.97) * scale
for bbox in bboxes:
# remove faces with too small eye distance: side faces or too small faces
eye_dist = np.linalg.norm([bbox[5] - bbox[7], bbox[6] - bbox[8]])
if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
continue
if self.template_3points:
landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
else:
landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
self.all_landmarks_5.append(landmark)
self.det_faces.append(bbox[0:5])
if len(self.det_faces) == 0:
return 0
if only_keep_largest:
h, w, _ = self.input_img.shape
self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
elif only_center_face:
h, w, _ = self.input_img.shape
self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
# pad blurry images
if self.pad_blur:
self.pad_input_imgs = []
for landmarks in self.all_landmarks_5:
# get landmarks
eye_left = landmarks[0, :]
eye_right = landmarks[1, :]
eye_avg = (eye_left + eye_right) * 0.5
mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
eye_to_eye = eye_right - eye_left
eye_to_mouth = mouth_avg - eye_avg
# Get the oriented crop rectangle
# x: half width of the oriented crop rectangle
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
# - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
# norm with the hypotenuse: get the direction
x /= np.hypot(*x) # get the hypotenuse of a right triangle
rect_scale = 1.5
x *= max(
np.hypot(*eye_to_eye) * 2.0 * rect_scale,
np.hypot(*eye_to_mouth) * 1.8 * rect_scale,
)
# y: half height of the oriented crop rectangle
y = np.flipud(x) * [-1, 1]
# c: center
c = eye_avg + eye_to_mouth * 0.1
# quad: (left_top, left_bottom, right_bottom, right_top)
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
# qsize: side length of the square
qsize = np.hypot(*x) * 2
border = max(int(np.rint(qsize * 0.1)), 3)
# get pad
# pad: (width_left, height_top, width_right, height_bottom)
pad = (
int(np.floor(min(quad[:, 0]))),
int(np.floor(min(quad[:, 1]))),
int(np.ceil(max(quad[:, 0]))),
int(np.ceil(max(quad[:, 1]))),
)
pad = [
max(-pad[0] + border, 1),
max(-pad[1] + border, 1),
max(pad[2] - self.input_img.shape[0] + border, 1),
max(pad[3] - self.input_img.shape[1] + border, 1),
]
if max(pad) > 1:
# pad image
pad_img = np.pad(
self.input_img,
((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)),
"reflect",
)
# modify landmark coords
landmarks[:, 0] += pad[0]
landmarks[:, 1] += pad[1]
# blur pad images
h, w, _ = pad_img.shape
y, x, _ = np.ogrid[:h, :w, :1]
mask = np.maximum(
1.0
- np.minimum(
np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]
),
1.0
- np.minimum(
np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]
),
)
blur = int(qsize * blur_ratio)
if blur % 2 == 0:
blur += 1
blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
# blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
pad_img = pad_img.astype("float32")
pad_img += (blur_img - pad_img) * np.clip(
mask * 3.0 + 1.0, 0.0, 1.0
)
pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(
mask, 0.0, 1.0
)
pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255]
self.pad_input_imgs.append(pad_img)
else:
self.pad_input_imgs.append(np.copy(self.input_img))
return len(self.all_landmarks_5)
def align_warp_face(self, save_cropped_path=None, border_mode="constant"):
"""Align and warp faces with face template."""
if self.pad_blur:
assert (
len(self.pad_input_imgs) == len(self.all_landmarks_5)
), f"Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}"
for idx, landmark in enumerate(self.all_landmarks_5):
# use 5 landmarks to get affine matrix
# use cv2.LMEDS method for the equivalence to skimage transform
# ref: https://blog.csdn.net/yichxi/article/details/115827338
affine_matrix = cv2.estimateAffinePartial2D(
landmark, self.face_template, method=cv2.LMEDS
)[0]
self.affine_matrices.append(affine_matrix)
# warp and crop faces
if border_mode == "constant":
border_mode = cv2.BORDER_CONSTANT
elif border_mode == "reflect101":
border_mode = cv2.BORDER_REFLECT101
elif border_mode == "reflect":
border_mode = cv2.BORDER_REFLECT
if self.pad_blur:
input_img = self.pad_input_imgs[idx]
else:
input_img = self.input_img
cropped_face = cv2.warpAffine(
input_img,
affine_matrix,
self.face_size,
borderMode=border_mode,
borderValue=(135, 133, 132),
) # gray
self.cropped_faces.append(cropped_face)
# save the cropped face
if save_cropped_path is not None:
path = os.path.splitext(save_cropped_path)[0]
save_path = f"{path}_{idx:02d}.{self.save_ext}"
imwrite(cropped_face, save_path)
def get_inverse_affine(self, save_inverse_affine_path=None):
"""Get inverse affine matrix."""
for idx, affine_matrix in enumerate(self.affine_matrices):
inverse_affine = cv2.invertAffineTransform(affine_matrix)
inverse_affine *= self.upscale_factor
self.inverse_affine_matrices.append(inverse_affine)
# save inverse affine matrices
if save_inverse_affine_path is not None:
path, _ = os.path.splitext(save_inverse_affine_path)
save_path = f"{path}_{idx:02d}.pth"
torch.save(inverse_affine, save_path)
def add_restored_face(self, face):
self.restored_faces.append(face)
def paste_faces_to_input_image(self, save_path=None, upsample_img=None):
h, w, _ = self.input_img.shape
h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
if upsample_img is None:
# simply resize the background
upsample_img = cv2.resize(
self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4
)
else:
upsample_img = cv2.resize(
upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4
)
assert len(self.restored_faces) == len(
self.inverse_affine_matrices
), "length of restored_faces and affine_matrices are different."
for restored_face, inverse_affine in zip(
self.restored_faces, self.inverse_affine_matrices
):
# Add an offset to inverse affine matrix, for more precise back alignment
if self.upscale_factor > 1:
extra_offset = 0.5 * self.upscale_factor
else:
extra_offset = 0
inverse_affine[:, 2] += extra_offset
inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
if self.use_parse:
# inference
face_input = cv2.resize(
restored_face, (512, 512), interpolation=cv2.INTER_LINEAR
)
face_input = img2tensor(
face_input.astype("float32") / 255.0, bgr2rgb=True, float32=True
)
normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
face_input = torch.unsqueeze(face_input, 0).to(self.device)
with torch.no_grad():
out = self.face_parse(face_input)[0]
out = out.argmax(dim=1).squeeze().cpu().numpy()
mask = np.zeros(out.shape)
MASK_COLORMAP = [
0,
255,
255,
255,
255,
255,
255,
255,
255,
255,
255,
255,
255,
255,
0,
255,
0,
0,
0,
]
for idx, color in enumerate(MASK_COLORMAP):
mask[out == idx] = color
# blur the mask
mask = cv2.GaussianBlur(mask, (101, 101), 11)
mask = cv2.GaussianBlur(mask, (101, 101), 11)
# remove the black borders
thres = 10
mask[:thres, :] = 0
mask[-thres:, :] = 0
mask[:, :thres] = 0
mask[:, -thres:] = 0
mask = mask / 255.0
mask = cv2.resize(mask, restored_face.shape[:2])
mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up), flags=3)
inv_soft_mask = mask[:, :, None]
pasted_face = inv_restored
else: # use square parse maps
mask = np.ones(self.face_size, dtype=np.float32)
inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
# remove the black borders
inv_mask_erosion = cv2.erode(
inv_mask,
np.ones(
(int(2 * self.upscale_factor), int(2 * self.upscale_factor)),
np.uint8,
),
)
pasted_face = inv_mask_erosion[:, :, None] * inv_restored
total_face_area = np.sum(inv_mask_erosion) # // 3
# compute the fusion edge based on the area of face
w_edge = int(total_face_area**0.5) // 20
erosion_radius = w_edge * 2
inv_mask_center = cv2.erode(
inv_mask_erosion,
np.ones((erosion_radius, erosion_radius), np.uint8),
)
blur_size = w_edge * 2
inv_soft_mask = cv2.GaussianBlur(
inv_mask_center, (blur_size + 1, blur_size + 1), 0
)
if len(upsample_img.shape) == 2: # upsample_img is gray image
upsample_img = upsample_img[:, :, None]
inv_soft_mask = inv_soft_mask[:, :, None]
if (
len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4
): # alpha channel
alpha = upsample_img[:, :, 3:]
upsample_img = (
inv_soft_mask * pasted_face
+ (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
)
upsample_img = np.concatenate((upsample_img, alpha), axis=2)
else:
upsample_img = (
inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
)
if np.max(upsample_img) > 256: # 16-bit image
upsample_img = upsample_img.astype(np.uint16)
else:
upsample_img = upsample_img.astype(np.uint8)
if save_path is not None:
path = os.path.splitext(save_path)[0]
save_path = f"{path}.{self.save_ext}"
imwrite(upsample_img, save_path)
return upsample_img
def clean_all(self):
self.all_landmarks_5 = []
self.restored_faces = []
self.affine_matrices = []
self.cropped_faces = []
self.inverse_affine_matrices = []
self.det_faces = []
self.pad_input_imgs = []
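A minimal end-to-end sketch of the helper above. The file names are placeholders and the restoration step is stubbed out (each cropped face is pasted back unchanged); in practice a face restorer would transform the crop before add_restored_face.
import cv2

helper = FaceRestoreHelper(upscale_factor=1, face_size=512, use_parse=True)
helper.read_image("input.jpg")                # path or BGR ndarray (placeholder name)
helper.get_face_landmarks_5(only_center_face=False, eye_dist_threshold=5)
helper.align_warp_face()                      # fills helper.cropped_faces
for cropped_face in helper.cropped_faces:
    restored_face = cropped_face              # plug a restorer model in here
    helper.add_restored_face(restored_face)
helper.get_inverse_affine()
result = helper.paste_faces_to_input_image()  # BGR ndarray at the upscaled size
cv2.imwrite("output.png", result)
helper.clean_all()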

View File

@ -0,0 +1,208 @@
import cv2
import numpy as np
import torch
def compute_increased_bbox(bbox, increase_area, preserve_aspect=True):
left, top, right, bot = bbox
width = right - left
height = bot - top
if preserve_aspect:
width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
else:
width_increase = height_increase = increase_area
left = int(left - width_increase * width)
top = int(top - height_increase * height)
right = int(right + width_increase * width)
bot = int(bot + height_increase * height)
return (left, top, right, bot)
def get_valid_bboxes(bboxes, h, w):
left = max(bboxes[0], 0)
top = max(bboxes[1], 0)
right = min(bboxes[2], w)
bottom = min(bboxes[3], h)
return (left, top, right, bottom)
def align_crop_face_landmarks(img,
landmarks,
output_size,
transform_size=None,
enable_padding=True,
return_inverse_affine=False,
shrink_ratio=(1, 1)):
"""Align and crop face with landmarks.
The output_size and transform_size are based on width. The height is
adjusted based on shrink_ratio_h / shrink_ratio_w.
Modified from:
https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
Args:
img (Numpy array): Input image.
landmarks (Numpy array): 5 or 68 or 98 landmarks.
output_size (int): Output face size.
transform_size (int): Transform size. Usually four times the
output_size.
enable_padding (bool): Whether to enable padding. Default: True.
shrink_ratio (float | tuple[float] | list[float]): Shrink the whole
face for height and width (crop a larger area). Default: (1, 1).
Returns:
(Numpy array): Cropped face.
"""
lm_type = 'retinaface_5' # Options: dlib_5, retinaface_5
if isinstance(shrink_ratio, (float, int)):
shrink_ratio = (shrink_ratio, shrink_ratio)
if transform_size is None:
transform_size = output_size * 4
# Parse landmarks
lm = np.array(landmarks)
if lm.shape[0] == 5 and lm_type == 'retinaface_5':
eye_left = lm[0]
eye_right = lm[1]
mouth_avg = (lm[3] + lm[4]) * 0.5
elif lm.shape[0] == 5 and lm_type == 'dlib_5':
lm_eye_left = lm[2:4]
lm_eye_right = lm[0:2]
eye_left = np.mean(lm_eye_left, axis=0)
eye_right = np.mean(lm_eye_right, axis=0)
mouth_avg = lm[4]
elif lm.shape[0] == 68:
lm_eye_left = lm[36:42]
lm_eye_right = lm[42:48]
eye_left = np.mean(lm_eye_left, axis=0)
eye_right = np.mean(lm_eye_right, axis=0)
mouth_avg = (lm[48] + lm[54]) * 0.5
elif lm.shape[0] == 98:
lm_eye_left = lm[60:68]
lm_eye_right = lm[68:76]
eye_left = np.mean(lm_eye_left, axis=0)
eye_right = np.mean(lm_eye_right, axis=0)
mouth_avg = (lm[76] + lm[82]) * 0.5
eye_avg = (eye_left + eye_right) * 0.5
eye_to_eye = eye_right - eye_left
eye_to_mouth = mouth_avg - eye_avg
# Get the oriented crop rectangle
# x: half width of the oriented crop rectangle
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
# - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
# norm with the hypotenuse: get the direction
x /= np.hypot(*x) # get the hypotenuse of a right triangle
rect_scale = 1 # TODO: you can edit it to get larger rect
x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
# y: half height of the oriented crop rectangle
y = np.flipud(x) * [-1, 1]
x *= shrink_ratio[1] # width
y *= shrink_ratio[0] # height
# c: center
c = eye_avg + eye_to_mouth * 0.1
# quad: (left_top, left_bottom, right_bottom, right_top)
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
# qsize: side length of the square
qsize = np.hypot(*x) * 2
quad_ori = np.copy(quad)
# Shrink, for large face
# TODO: do we really need shrink
shrink = int(np.floor(qsize / output_size * 0.5))
if shrink > 1:
h, w = img.shape[0:2]
rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink)))
img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA)
quad /= shrink
qsize /= shrink
# Crop
h, w = img.shape[0:2]
border = max(int(np.rint(qsize * 0.1)), 3)
crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
int(np.ceil(max(quad[:, 1]))))
crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, w), min(crop[3] + border, h))
if crop[2] - crop[0] < w or crop[3] - crop[1] < h:
img = img[crop[1]:crop[3], crop[0]:crop[2], :]
quad -= crop[0:2]
# Pad
# pad: (width_left, height_top, width_right, height_bottom)
h, w = img.shape[0:2]
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
int(np.ceil(max(quad[:, 1]))))
pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - w + border, 0), max(pad[3] - h + border, 0))
if enable_padding and max(pad) > border - 4:
pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
h, w = img.shape[0:2]
y, x, _ = np.ogrid[:h, :w, :1]
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
np.float32(w - 1 - x) / pad[2]),
1.0 - np.minimum(np.float32(y) / pad[1],
np.float32(h - 1 - y) / pad[3]))
blur = int(qsize * 0.02)
if blur % 2 == 0:
blur += 1
blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur))
img = img.astype('float32')
img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
img = np.clip(img, 0, 255) # float32, [0, 255]
quad += pad[:2]
# Transform use cv2
h_ratio = shrink_ratio[0] / shrink_ratio[1]
dst_h, dst_w = int(transform_size * h_ratio), transform_size
template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
# use cv2.LMEDS method for the equivalence to skimage transform
# ref: https://blog.csdn.net/yichxi/article/details/115827338
affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0]
cropped_face = cv2.warpAffine(
img, affine_matrix, (dst_w, dst_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(135, 133, 132)) # gray
if output_size < transform_size:
cropped_face = cv2.resize(
cropped_face, (output_size, int(output_size * h_ratio)), interpolation=cv2.INTER_LINEAR)
if return_inverse_affine:
dst_h, dst_w = int(output_size * h_ratio), output_size
template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
# use cv2.LMEDS method for the equivalence to skimage transform
# ref: https://blog.csdn.net/yichxi/article/details/115827338
affine_matrix = cv2.estimateAffinePartial2D(
quad_ori, np.array([[0, 0], [0, output_size], [dst_w, dst_h], [dst_w, 0]]), method=cv2.LMEDS)[0]
inverse_affine = cv2.invertAffineTransform(affine_matrix)
else:
inverse_affine = None
return cropped_face, inverse_affine
def paste_face_back(img, face, inverse_affine):
h, w = img.shape[0:2]
face_h, face_w = face.shape[0:2]
inv_restored = cv2.warpAffine(face, inverse_affine, (w, h))
mask = np.ones((face_h, face_w, 3), dtype=np.float32)
inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h))
# remove the black borders
inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8))
inv_restored_remove_border = inv_mask_erosion * inv_restored
total_face_area = np.sum(inv_mask_erosion) // 3
# compute the fusion edge based on the area of face
w_edge = int(total_face_area**0.5) // 20
erosion_radius = w_edge * 2
inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
blur_size = w_edge * 2
inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img
# float32, [0, 255]
return img
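For reference, a sketch of using the two standalone utilities above together (same-module scope assumed). The landmark coordinates are dummy values purely to illustrate the 5-point retinaface layout; a real pipeline would take them from a detector, and the file names are placeholders.
import cv2
import numpy as np

img = cv2.imread("input.jpg")                      # placeholder path
landmarks = np.array([[230., 240.], [290., 240.],  # left eye, right eye
                      [260., 280.],                # nose tip
                      [240., 320.], [280., 320.]]) # mouth corners
cropped, inverse_affine = align_crop_face_landmarks(
    img, landmarks, output_size=512, return_inverse_affine=True)
restored = cropped                                 # plug a face restorer in here
result = paste_face_back(img, restored, inverse_affine)
cv2.imwrite("output.png", result.astype(np.uint8))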

View File

@ -0,0 +1,118 @@
import cv2
import os
import os.path as osp
import torch
from torch.hub import download_url_to_file, get_dir
from urllib.parse import urlparse
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
def imwrite(img, file_path, params=None, auto_mkdir=True):
"""Write image to file.
Args:
img (ndarray): Image array to be written.
file_path (str): Image file path.
params (None or list): Same as opencv's :func:`imwrite` interface.
auto_mkdir (bool): If the parent folder of `file_path` does not exist,
whether to create it automatically.
Returns:
bool: Successful or not.
"""
if auto_mkdir:
dir_name = os.path.abspath(os.path.dirname(file_path))
os.makedirs(dir_name, exist_ok=True)
return cv2.imwrite(file_path, img, params)
def img2tensor(imgs, bgr2rgb=True, float32=True):
"""Numpy array to tensor.
Args:
imgs (list[ndarray] | ndarray): Input images.
bgr2rgb (bool): Whether to change bgr to rgb.
float32 (bool): Whether to change to float32.
Returns:
list[tensor] | tensor: Tensor images. If returned results only have
one element, just return tensor.
"""
def _totensor(img, bgr2rgb, float32):
if img.shape[2] == 3 and bgr2rgb:
if img.dtype == 'float64':
img = img.astype('float32')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.transpose(2, 0, 1))
if float32:
img = img.float()
return img
if isinstance(imgs, list):
return [_totensor(img, bgr2rgb, float32) for img in imgs]
else:
return _totensor(imgs, bgr2rgb, float32)
def load_file_from_url(url, model_dir=None, progress=True, file_name=None, save_dir=None):
"""Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
"""
if model_dir is None:
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
if save_dir is None:
save_dir = os.path.join(ROOT_DIR, model_dir)
os.makedirs(save_dir, exist_ok=True)
parts = urlparse(url)
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
cached_file = os.path.abspath(os.path.join(save_dir, filename))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
return cached_file
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
"""Scan a directory to find the interested files.
Args:
dir_path (str): Path of the directory.
suffix (str | tuple(str), optional): File suffix that we are
interested in. Default: None.
recursive (bool, optional): If set to True, recursively scan the
directory. Default: False.
full_path (bool, optional): If set to True, include the dir_path.
Default: False.
Returns:
A generator for all the interested files with relative paths.
"""
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
raise TypeError('"suffix" must be a string or tuple of strings')
root = dir_path
def _scandir(dir_path, suffix, recursive):
for entry in os.scandir(dir_path):
if not entry.name.startswith('.') and entry.is_file():
if full_path:
return_path = entry.path
else:
return_path = osp.relpath(entry.path, root)
if suffix is None:
yield return_path
elif return_path.endswith(suffix):
yield return_path
else:
if recursive:
yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
else:
continue
return _scandir(dir_path, suffix=suffix, recursive=recursive)
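A short sketch of the helpers above. The directory, suffixes, and URL are placeholders rather than real project resources; load_file_from_url caches the download under the computed save directory and returns the absolute path.
# Scan a (placeholder) directory tree for checkpoint files.
for rel_path in scandir("models", suffix=(".pth", ".onnx"), recursive=True):
    print(rel_path)

# Download-and-cache sketch; the URL is a placeholder.
ckpt = load_file_from_url(
    url="https://example.com/model.pth", model_dir="checkpoints", progress=True)
print(ckpt)  # absolute path to the cached file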

View File

@ -0,0 +1,322 @@
import math
import random
import torch
from torch import nn
from torch.nn import functional as F
from .stylegan2_clean_arch import StyleGAN2GeneratorClean
class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
"""StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
num_mlp (int): Layer number of MLP style layers. Default: 8.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
narrow (float): The narrow ratio for channels. Default: 1.
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
"""
def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
super(StyleGAN2GeneratorCSFT, self).__init__(
out_size,
num_style_feat=num_style_feat,
num_mlp=num_mlp,
channel_multiplier=channel_multiplier,
narrow=narrow)
self.sft_half = sft_half
def forward(self,
styles,
conditions,
input_is_latent=False,
noise=None,
randomize_noise=True,
truncation=1,
truncation_latent=None,
inject_index=None,
return_latents=False):
"""Forward function for StyleGAN2GeneratorCSFT.
Args:
styles (list[Tensor]): Sample codes of styles.
conditions (list[Tensor]): SFT conditions to generators.
input_is_latent (bool): Whether input is latent style. Default: False.
noise (Tensor | None): Input noise or None. Default: None.
randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
truncation (float): The truncation ratio. Default: 1.
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
inject_index (int | None): The injection index for mixing noise. Default: None.
return_latents (bool): Whether to return style latents. Default: False.
"""
# style codes -> latents with Style MLP layer
if not input_is_latent:
styles = [self.style_mlp(s) for s in styles]
# noises
if noise is None:
if randomize_noise:
noise = [None] * self.num_layers # for each style conv layer
else: # use the stored noise
noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
# style truncation
if truncation < 1:
style_truncation = []
for style in styles:
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
styles = style_truncation
# get style latents with injection
if len(styles) == 1:
inject_index = self.num_latent
if styles[0].ndim < 3:
# repeat latent code for all the layers
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
else: # used for encoder with different latent code for each layer
latent = styles[0]
elif len(styles) == 2: # mixing noises
if inject_index is None:
inject_index = random.randint(1, self.num_latent - 1)
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
latent = torch.cat([latent1, latent2], 1)
# main generation
out = self.constant_input(latent.shape[0])
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
skip = self.to_rgb1(out, latent[:, 1])
i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
noise[2::2], self.to_rgbs):
out = conv1(out, latent[:, i], noise=noise1)
# the conditions may have fewer levels
if i < len(conditions):
# SFT part to combine the conditions
if self.sft_half: # only apply SFT to half of the channels
out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
out_sft = out_sft * conditions[i - 1] + conditions[i]
out = torch.cat([out_same, out_sft], dim=1)
else: # apply SFT to all the channels
out = out * conditions[i - 1] + conditions[i]
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
i += 2
image = skip
if return_latents:
return image, latent
else:
return image, None
class ResBlock(nn.Module):
"""Residual block with bilinear upsampling/downsampling.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
"""
def __init__(self, in_channels, out_channels, mode='down'):
super(ResBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
if mode == 'down':
self.scale_factor = 0.5
elif mode == 'up':
self.scale_factor = 2
def forward(self, x):
out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
# upsample/downsample
out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
# skip
x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
skip = self.skip(x)
out = out + skip
return out
class GFPGANv1Clean(nn.Module):
"""The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
fix_decoder (bool): Whether to fix the decoder. Default: True.
num_mlp (int): Layer number of MLP style layers. Default: 8.
input_is_latent (bool): Whether input is latent style. Default: False.
different_w (bool): Whether to use different latent w for different layers. Default: False.
narrow (float): The narrow ratio for channels. Default: 1.
sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
"""
def __init__(
self,
out_size,
num_style_feat=512,
channel_multiplier=1,
decoder_load_path=None,
fix_decoder=True,
# for stylegan decoder
num_mlp=8,
input_is_latent=False,
different_w=False,
narrow=1,
sft_half=False):
super(GFPGANv1Clean, self).__init__()
self.input_is_latent = input_is_latent
self.different_w = different_w
self.num_style_feat = num_style_feat
unet_narrow = narrow * 0.5 # by default, use half of the input channels
channels = {
'4': int(512 * unet_narrow),
'8': int(512 * unet_narrow),
'16': int(512 * unet_narrow),
'32': int(512 * unet_narrow),
'64': int(256 * channel_multiplier * unet_narrow),
'128': int(128 * channel_multiplier * unet_narrow),
'256': int(64 * channel_multiplier * unet_narrow),
'512': int(32 * channel_multiplier * unet_narrow),
'1024': int(16 * channel_multiplier * unet_narrow)
}
self.log_size = int(math.log(out_size, 2))
first_out_size = 2**(int(math.log(out_size, 2)))
self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)
# downsample
in_channels = channels[f'{first_out_size}']
self.conv_body_down = nn.ModuleList()
for i in range(self.log_size, 2, -1):
out_channels = channels[f'{2**(i - 1)}']
self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
in_channels = out_channels
self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)
# upsample
in_channels = channels['4']
self.conv_body_up = nn.ModuleList()
for i in range(3, self.log_size + 1):
out_channels = channels[f'{2**i}']
self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
in_channels = out_channels
# to RGB
self.toRGB = nn.ModuleList()
for i in range(3, self.log_size + 1):
self.toRGB.append(nn.Conv2d(channels[f'{2**i}'], 3, 1))
if different_w:
linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
else:
linear_out_channel = num_style_feat
self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
# the decoder: stylegan2 generator with SFT modulations
self.stylegan_decoder = StyleGAN2GeneratorCSFT(
out_size=out_size,
num_style_feat=num_style_feat,
num_mlp=num_mlp,
channel_multiplier=channel_multiplier,
narrow=narrow,
sft_half=sft_half)
# load pre-trained stylegan2 model if necessary
if decoder_load_path:
self.stylegan_decoder.load_state_dict(
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
# fix decoder without updating params
if fix_decoder:
for _, param in self.stylegan_decoder.named_parameters():
param.requires_grad = False
# for SFT modulations (scale and shift)
self.condition_scale = nn.ModuleList()
self.condition_shift = nn.ModuleList()
for i in range(3, self.log_size + 1):
out_channels = channels[f'{2**i}']
if sft_half:
sft_out_channels = out_channels
else:
sft_out_channels = out_channels * 2
self.condition_scale.append(
nn.Sequential(
nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
self.condition_shift.append(
nn.Sequential(
nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs):
"""Forward function for GFPGANv1Clean.
Args:
x (Tensor): Input images.
return_latents (bool): Whether to return style latents. Default: False.
return_rgb (bool): Whether return intermediate rgb images. Default: True.
randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
"""
conditions = []
unet_skips = []
out_rgbs = []
# encoder
feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
for i in range(self.log_size - 2):
feat = self.conv_body_down[i](feat)
unet_skips.insert(0, feat)
feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
# style code
style_code = self.final_linear(feat.view(feat.size(0), -1))
if self.different_w:
style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
# decode
for i in range(self.log_size - 2):
# add unet skip
feat = feat + unet_skips[i]
# ResUpLayer
feat = self.conv_body_up[i](feat)
# generate scale and shift for SFT layers
scale = self.condition_scale[i](feat)
conditions.append(scale.clone())
shift = self.condition_shift[i](feat)
conditions.append(shift.clone())
# generate rgb images
if return_rgb:
out_rgbs.append(self.toRGB[i](feat))
# decoder
image, _ = self.stylegan_decoder([style_code],
conditions,
return_latents=return_latents,
input_is_latent=self.input_is_latent,
randomize_noise=randomize_noise)
return image, out_rgbs
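A minimal sketch of instantiating the clean GFPGAN architecture above with randomly initialized weights (so the output is meaningless). The keyword values mirror a typical 512-face configuration but are assumptions here, not the project's shipped settings.
import torch

net = GFPGANv1Clean(
    out_size=512, num_style_feat=512, channel_multiplier=2,
    decoder_load_path=None, fix_decoder=False, num_mlp=8,
    input_is_latent=True, different_w=True, narrow=1, sft_half=True).eval()
x = torch.rand(1, 3, 512, 512) * 2 - 1   # placeholder aligned face in [-1, 1]
with torch.no_grad():
    image, out_rgbs = net(x, return_rgb=False)
# image: (1, 3, 512, 512); out_rgbs is empty because return_rgb=False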

View File

@ -0,0 +1,759 @@
"""Modified from https://github.com/wzhouxiff/RestoreFormer"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class VectorQuantizer(nn.Module):
"""
see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
____________________________________________
Discretization bottleneck part of the VQ-VAE.
Inputs:
- n_e : number of embeddings
- e_dim : dimension of embedding
- beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
_____________________________________________
"""
def __init__(self, n_e, e_dim, beta):
super(VectorQuantizer, self).__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
def forward(self, z):
"""
Inputs the output of the encoder network z and maps it to a discrete
one-hot vector that is the index of the closest embedding vector e_j
z (continuous) -> z_q (discrete)
z.shape = (batch, channel, height, width)
quantization pipeline:
1. get encoder input (B,C,H,W)
2. flatten input to (B*H*W,C)
"""
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (
torch.sum(z_flattened**2, dim=1, keepdim=True)
+ torch.sum(self.embedding.weight**2, dim=1)
- 2 * torch.matmul(z_flattened, self.embedding.weight.t())
)
# could possibly replace this here
# #\start...
# find closest encodings
min_value, min_encoding_indices = torch.min(d, dim=1)
min_encoding_indices = min_encoding_indices.unsqueeze(1)
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.n_e).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)
# dtype min encodings: torch.float32
# min_encodings shape: torch.Size([2048, 512])
# min_encoding_indices.shape: torch.Size([2048, 1])
# get quantized latent vectors
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
# .........\end
# with:
# .........\start
# min_encoding_indices = torch.argmin(d, dim=1)
# z_q = self.embedding(min_encoding_indices)
# ......\end......... (TODO)
# compute loss for embedding
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
(z_q - z.detach()) ** 2
)
# preserve gradients
z_q = z + (z_q - z).detach()
# perplexity
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q, loss, (perplexity, min_encodings, min_encoding_indices, d)
def get_codebook_entry(self, indices, shape):
# shape specifying (batch, height, width, channel)
# TODO: check for more easy handling with nn.Embedding
min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
min_encodings.scatter_(1, indices[:, None], 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
if shape is not None:
z_q = z_q.view(shape)
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q
# pytorch_diffusion + derived encoder decoder
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def Normalize(in_channels):
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def forward(self, x):
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
else:
self.nin_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class MultiHeadAttnBlock(nn.Module):
def __init__(self, in_channels, head_size=1):
super().__init__()
self.in_channels = in_channels
self.head_size = head_size
self.att_size = in_channels // head_size
assert (
in_channels % head_size == 0
), "The number of channels must be divisible by the head size."
self.norm1 = Normalize(in_channels)
self.norm2 = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.num = 0
def forward(self, x, y=None):
h_ = x
h_ = self.norm1(h_)
if y is None:
y = h_
else:
y = self.norm2(y)
q = self.q(y)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, self.head_size, self.att_size, h * w)
q = q.permute(0, 3, 1, 2) # b, hw, head, att
k = k.reshape(b, self.head_size, self.att_size, h * w)
k = k.permute(0, 3, 1, 2)
v = v.reshape(b, self.head_size, self.att_size, h * w)
v = v.permute(0, 3, 1, 2)
q = q.transpose(1, 2)
v = v.transpose(1, 2)
k = k.transpose(1, 2).transpose(2, 3)
scale = int(self.att_size) ** (-0.5)
q.mul_(scale)
w_ = torch.matmul(q, k)
w_ = F.softmax(w_, dim=3)
w_ = w_.matmul(v)
w_ = w_.transpose(1, 2).contiguous() # [b, h*w, head, att]
w_ = w_.view(b, h, w, -1)
w_ = w_.permute(0, 3, 1, 2)
w_ = self.proj_out(w_)
return x + w_
class MultiHeadEncoder(nn.Module):
def __init__(
self,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks=2,
attn_resolutions=(16,),
dropout=0.0,
resamp_with_conv=True,
in_channels=3,
resolution=512,
z_channels=256,
double_z=True,
enable_mid=True,
head_size=1,
**ignore_kwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.enable_mid = enable_mid
# downsampling
self.conv_in = torch.nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(MultiHeadAttnBlock(block_in, head_size))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
if self.enable_mid:
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in,
2 * z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1,
)
def forward(self, x):
hs = {}
# timestep embedding
temb = None
# downsampling
h = self.conv_in(x)
hs["in"] = h
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h, temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1:
# hs.append(h)
hs["block_" + str(i_level)] = h
h = self.down[i_level].downsample(h)
# middle
# h = hs[-1]
if self.enable_mid:
h = self.mid.block_1(h, temb)
hs["block_" + str(i_level) + "_atten"] = h
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
hs["mid_atten"] = h
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
# hs.append(h)
hs["out"] = h
return hs
class MultiHeadDecoder(nn.Module):
def __init__(
self,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks=2,
attn_resolutions=(16,),
dropout=0.0,
resamp_with_conv=True,
in_channels=3,
resolution=512,
z_channels=256,
give_pre_end=False,
enable_mid=True,
head_size=1,
**ignorekwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.enable_mid = enable_mid
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print(
"Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)
)
)
# z to block_in
self.conv_in = torch.nn.Conv2d(
z_channels, block_in, kernel_size=3, stride=1, padding=1
)
# middle
if self.enable_mid:
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(MultiHeadAttnBlock(block_in, head_size))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, out_ch, kernel_size=3, stride=1, padding=1
)
def forward(self, z):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
if self.enable_mid:
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class MultiHeadDecoderTransformer(nn.Module):
def __init__(
self,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks=2,
attn_resolutions=(16,),
dropout=0.0,
resamp_with_conv=True,
in_channels=3,
resolution=512,
z_channels=256,
give_pre_end=False,
enable_mid=True,
head_size=1,
**ignorekwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.enable_mid = enable_mid
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print(
"Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)
)
)
# z to block_in
self.conv_in = torch.nn.Conv2d(
z_channels, block_in, kernel_size=3, stride=1, padding=1
)
# middle
if self.enable_mid:
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(MultiHeadAttnBlock(block_in, head_size))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, out_ch, kernel_size=3, stride=1, padding=1
)
def forward(self, z, hs):
# assert z.shape[1:] == self.z_shape[1:]
# self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
if self.enable_mid:
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h, hs["mid_atten"])
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](
h, hs["block_" + str(i_level) + "_atten"]
)
# hfeature = h.clone()
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class RestoreFormer(nn.Module):
def __init__(
self,
n_embed=1024,
embed_dim=256,
ch=64,
out_ch=3,
ch_mult=(1, 2, 2, 4, 4, 8),
num_res_blocks=2,
attn_resolutions=(16,),
dropout=0.0,
in_channels=3,
resolution=512,
z_channels=256,
double_z=False,
enable_mid=True,
fix_decoder=False,
fix_codebook=True,
fix_encoder=False,
head_size=8,
):
super(RestoreFormer, self).__init__()
self.encoder = MultiHeadEncoder(
ch=ch,
out_ch=out_ch,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions,
dropout=dropout,
in_channels=in_channels,
resolution=resolution,
z_channels=z_channels,
double_z=double_z,
enable_mid=enable_mid,
head_size=head_size,
)
self.decoder = MultiHeadDecoderTransformer(
ch=ch,
out_ch=out_ch,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions,
dropout=dropout,
in_channels=in_channels,
resolution=resolution,
z_channels=z_channels,
enable_mid=enable_mid,
head_size=head_size,
)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
if fix_decoder:
for _, param in self.decoder.named_parameters():
param.requires_grad = False
for _, param in self.post_quant_conv.named_parameters():
param.requires_grad = False
for _, param in self.quantize.named_parameters():
param.requires_grad = False
elif fix_codebook:
for _, param in self.quantize.named_parameters():
param.requires_grad = False
if fix_encoder:
for _, param in self.encoder.named_parameters():
param.requires_grad = False
def encode(self, x):
hs = self.encoder(x)
h = self.quant_conv(hs["out"])
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info, hs
def decode(self, quant, hs):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant, hs)
return dec
def forward(self, input, **kwargs):
quant, diff, info, hs = self.encode(input)
dec = self.decode(quant, hs)
return dec, None
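For reference, a minimal smoke-test sketch of the vendored RestoreFormer above, assuming the default constructor arguments (512x512 input, ch_mult=(1, 2, 2, 4, 4, 8)) and randomly initialized weights rather than the released checkpoint:

import torch

model = RestoreFormer().eval()                   # random init; real use loads pretrained weights
with torch.no_grad():
    face = torch.rand(1, 3, 512, 512) * 2 - 1    # RGB face crop normalized to [-1, 1]
    restored, _ = model(face)                    # forward() returns (decoded image, None)
print(restored.shape)                            # torch.Size([1, 3, 512, 512])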

View File

@ -0,0 +1,434 @@
import math
import random
import torch
from torch import nn
from torch.nn import functional as F
from iopaint.plugins.basicsr.arch_util import default_init_weights
class NormStyleCode(nn.Module):
def forward(self, x):
"""Normalize the style codes.
Args:
x (Tensor): Style codes with shape (b, c).
Returns:
Tensor: Normalized tensor.
"""
return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
class ModulatedConv2d(nn.Module):
"""Modulated Conv2d used in StyleGAN2.
There is no bias in ModulatedConv2d.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Size of the convolving kernel.
num_style_feat (int): Channel number of style features.
demodulate (bool): Whether to demodulate in the conv layer. Default: True.
sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
num_style_feat,
demodulate=True,
sample_mode=None,
eps=1e-8,
):
super(ModulatedConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.demodulate = demodulate
self.sample_mode = sample_mode
self.eps = eps
# modulation inside each modulated conv
self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
# initialization
default_init_weights(
self.modulation,
scale=1,
bias_fill=1,
a=0,
mode="fan_in",
nonlinearity="linear",
)
self.weight = nn.Parameter(
torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
/ math.sqrt(in_channels * kernel_size**2)
)
self.padding = kernel_size // 2
def forward(self, x, style):
"""Forward function.
Args:
x (Tensor): Tensor with shape (b, c, h, w).
style (Tensor): Tensor with shape (b, num_style_feat).
Returns:
Tensor: Modulated tensor after convolution.
"""
b, c, h, w = x.shape # c = c_in
# weight modulation
style = self.modulation(style).view(b, 1, c, 1, 1)
# self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
weight = self.weight * style # (b, c_out, c_in, k, k)
if self.demodulate:
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
weight = weight.view(
b * self.out_channels, c, self.kernel_size, self.kernel_size
)
# upsample or downsample if necessary
if self.sample_mode == "upsample":
x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
elif self.sample_mode == "downsample":
x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False)
b, c, h, w = x.shape
x = x.view(1, b * c, h, w)
# weight: (b*c_out, c_in, k, k), groups=b
out = F.conv2d(x, weight, padding=self.padding, groups=b)
out = out.view(b, self.out_channels, *out.shape[2:4])
return out
def __repr__(self):
return (
f"{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, "
f"kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})"
)
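A quick shape check of the grouped-convolution trick used in forward() above (a standalone sketch with arbitrary layer sizes, not part of the vendored file):

import torch

conv = ModulatedConv2d(in_channels=64, out_channels=128, kernel_size=3,
                       num_style_feat=512, demodulate=True, sample_mode="upsample")
x = torch.randn(4, 64, 16, 16)                   # (b, c_in, h, w)
style = torch.randn(4, 512)                      # (b, num_style_feat)
print(conv(x, style).shape)                      # torch.Size([4, 128, 32, 32]) after the 2x upsample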
class StyleConv(nn.Module):
"""Style conv used in StyleGAN2.
Args:
in_channels (int): Channel number of the input.
out_channels (int): Channel number of the output.
kernel_size (int): Size of the convolving kernel.
num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer. Default: True.
sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
num_style_feat,
demodulate=True,
sample_mode=None,
):
super(StyleConv, self).__init__()
self.modulated_conv = ModulatedConv2d(
in_channels,
out_channels,
kernel_size,
num_style_feat,
demodulate=demodulate,
sample_mode=sample_mode,
)
self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x, style, noise=None):
# modulate
out = self.modulated_conv(x, style) * 2**0.5 # for conversion
# noise injection
if noise is None:
b, _, h, w = out.shape
noise = out.new_empty(b, 1, h, w).normal_()
out = out + self.weight * noise
# add bias
out = out + self.bias
# activation
out = self.activate(out)
return out
class ToRGB(nn.Module):
"""To RGB (image space) from features.
Args:
in_channels (int): Channel number of input.
num_style_feat (int): Channel number of style features.
upsample (bool): Whether to upsample. Default: True.
"""
def __init__(self, in_channels, num_style_feat, upsample=True):
super(ToRGB, self).__init__()
self.upsample = upsample
self.modulated_conv = ModulatedConv2d(
in_channels,
3,
kernel_size=1,
num_style_feat=num_style_feat,
demodulate=False,
sample_mode=None,
)
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
def forward(self, x, style, skip=None):
"""Forward function.
Args:
x (Tensor): Feature tensor with shape (b, c, h, w).
style (Tensor): Tensor with shape (b, num_style_feat).
skip (Tensor): Base/skip tensor. Default: None.
Returns:
Tensor: RGB images.
"""
out = self.modulated_conv(x, style)
out = out + self.bias
if skip is not None:
if self.upsample:
skip = F.interpolate(
skip, scale_factor=2, mode="bilinear", align_corners=False
)
out = out + skip
return out
class ConstantInput(nn.Module):
"""Constant input.
Args:
num_channel (int): Channel number of constant input.
size (int): Spatial size of constant input.
"""
def __init__(self, num_channel, size):
super(ConstantInput, self).__init__()
self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
def forward(self, batch):
out = self.weight.repeat(batch, 1, 1, 1)
return out
class StyleGAN2GeneratorClean(nn.Module):
"""Clean version of StyleGAN2 Generator.
Args:
out_size (int): The spatial size of outputs.
num_style_feat (int): Channel number of style features. Default: 512.
num_mlp (int): Layer number of MLP style layers. Default: 8.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
narrow (float): Narrow ratio for channels. Default: 1.0.
"""
def __init__(
self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1
):
super(StyleGAN2GeneratorClean, self).__init__()
# Style MLP layers
self.num_style_feat = num_style_feat
style_mlp_layers = [NormStyleCode()]
for i in range(num_mlp):
style_mlp_layers.extend(
[
nn.Linear(num_style_feat, num_style_feat, bias=True),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
]
)
self.style_mlp = nn.Sequential(*style_mlp_layers)
# initialization
default_init_weights(
self.style_mlp,
scale=1,
bias_fill=0,
a=0.2,
mode="fan_in",
nonlinearity="leaky_relu",
)
# channel list
channels = {
"4": int(512 * narrow),
"8": int(512 * narrow),
"16": int(512 * narrow),
"32": int(512 * narrow),
"64": int(256 * channel_multiplier * narrow),
"128": int(128 * channel_multiplier * narrow),
"256": int(64 * channel_multiplier * narrow),
"512": int(32 * channel_multiplier * narrow),
"1024": int(16 * channel_multiplier * narrow),
}
self.channels = channels
self.constant_input = ConstantInput(channels["4"], size=4)
self.style_conv1 = StyleConv(
channels["4"],
channels["4"],
kernel_size=3,
num_style_feat=num_style_feat,
demodulate=True,
sample_mode=None,
)
self.to_rgb1 = ToRGB(channels["4"], num_style_feat, upsample=False)
self.log_size = int(math.log(out_size, 2))
self.num_layers = (self.log_size - 2) * 2 + 1
self.num_latent = self.log_size * 2 - 2
self.style_convs = nn.ModuleList()
self.to_rgbs = nn.ModuleList()
self.noises = nn.Module()
in_channels = channels["4"]
# noise
for layer_idx in range(self.num_layers):
resolution = 2 ** ((layer_idx + 5) // 2)
shape = [1, 1, resolution, resolution]
self.noises.register_buffer(f"noise{layer_idx}", torch.randn(*shape))
# style convs and to_rgbs
for i in range(3, self.log_size + 1):
out_channels = channels[f"{2 ** i}"]
self.style_convs.append(
StyleConv(
in_channels,
out_channels,
kernel_size=3,
num_style_feat=num_style_feat,
demodulate=True,
sample_mode="upsample",
)
)
self.style_convs.append(
StyleConv(
out_channels,
out_channels,
kernel_size=3,
num_style_feat=num_style_feat,
demodulate=True,
sample_mode=None,
)
)
self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
in_channels = out_channels
def make_noise(self):
"""Make noise for noise injection."""
device = self.constant_input.weight.device
noises = [torch.randn(1, 1, 4, 4, device=device)]
for i in range(3, self.log_size + 1):
for _ in range(2):
noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
return noises
def get_latent(self, x):
return self.style_mlp(x)
def mean_latent(self, num_latent):
latent_in = torch.randn(
num_latent, self.num_style_feat, device=self.constant_input.weight.device
)
latent = self.style_mlp(latent_in).mean(0, keepdim=True)
return latent
def forward(
self,
styles,
input_is_latent=False,
noise=None,
randomize_noise=True,
truncation=1,
truncation_latent=None,
inject_index=None,
return_latents=False,
):
"""Forward function for StyleGAN2GeneratorClean.
Args:
styles (list[Tensor]): Sample codes of styles.
input_is_latent (bool): Whether input is latent style. Default: False.
noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
truncation (float): The truncation ratio. Default: 1.
truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
inject_index (int | None): The injection index for mixing noise. Default: None.
return_latents (bool): Whether to return style latents. Default: False.
"""
# style codes -> latents with Style MLP layer
if not input_is_latent:
styles = [self.style_mlp(s) for s in styles]
# noises
if noise is None:
if randomize_noise:
noise = [None] * self.num_layers # for each style conv layer
else: # use the stored noise
noise = [
getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
]
# style truncation
if truncation < 1:
style_truncation = []
for style in styles:
style_truncation.append(
truncation_latent + truncation * (style - truncation_latent)
)
styles = style_truncation
# get style latents with injection
if len(styles) == 1:
inject_index = self.num_latent
if styles[0].ndim < 3:
# repeat latent code for all the layers
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
else: # used for encoder with different latent code for each layer
latent = styles[0]
elif len(styles) == 2: # mixing noises
if inject_index is None:
inject_index = random.randint(1, self.num_latent - 1)
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = (
styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
)
latent = torch.cat([latent1, latent2], 1)
# main generation
out = self.constant_input(latent.shape[0])
out = self.style_conv1(out, latent[:, 0], noise=noise[0])
skip = self.to_rgb1(out, latent[:, 1])
i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip(
self.style_convs[::2],
self.style_convs[1::2],
noise[1::2],
noise[2::2],
self.to_rgbs,
):
out = conv1(out, latent[:, i], noise=noise1)
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
i += 2
image = skip
if return_latents:
return image, latent
else:
return image, None
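Likewise, a minimal sketch of driving the clean generator, assuming the default channel_multiplier and randomly initialized weights:

import torch

gen = StyleGAN2GeneratorClean(out_size=512, num_style_feat=512).eval()
with torch.no_grad():
    z = torch.randn(2, 512)                        # a batch of style codes
    image, latent = gen([z], return_latents=True)  # styles are passed as a list
print(image.shape, latent.shape)                   # (2, 3, 512, 512) and (2, 16, 512)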

View File

@ -20,11 +20,6 @@ class GFPGANPlugin(BasePlugin):
model_path = download_model(url, model_md5)
logger.info(f"GFPGAN model path: {model_path}")
import facexlib
if hasattr(facexlib.detection.retinaface, "device"):
facexlib.detection.retinaface.device = device
# Use GFPGAN for face enhancement
self.face_enhancer = MyGFPGANer(
model_path=model_path,
@ -64,11 +59,3 @@ class GFPGANPlugin(BasePlugin):
# except Exception as error:
# print("wrong scale input.", error)
return bgr_output
def check_dep(self):
try:
import gfpgan
except ImportError:
return (
"gfpgan is not installed, please install it first. pip install gfpgan"
)

View File

@ -1,12 +1,16 @@
import os
import cv2
import torch
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from gfpgan import GFPGANv1Clean, GFPGANer
from torchvision.transforms.functional import normalize
from torch.hub import get_dir
from .facexlib.utils.face_restoration_helper import FaceRestoreHelper
from .gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
from .basicsr.img_util import img2tensor, tensor2img
class MyGFPGANer(GFPGANer):
class MyGFPGANer:
"""Helper for restoration with GFPGAN.
It will detect and crop faces, and then resize the faces to 512x512.
@ -55,7 +59,7 @@ class MyGFPGANer(GFPGANer):
sft_half=True,
)
elif arch == "RestoreFormer":
from gfpgan.archs.restoreformer_arch import RestoreFormer
from .gfpgan.archs.restoreformer_arch import RestoreFormer
self.gfpgan = RestoreFormer()
@ -82,3 +86,71 @@ class MyGFPGANer(GFPGANer):
self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
self.gfpgan.eval()
self.gfpgan = self.gfpgan.to(self.device)
@torch.no_grad()
def enhance(
self,
img,
has_aligned=False,
only_center_face=False,
paste_back=True,
weight=0.5,
):
self.face_helper.clean_all()
if has_aligned: # the inputs are already aligned
img = cv2.resize(img, (512, 512))
self.face_helper.cropped_faces = [img]
else:
self.face_helper.read_image(img)
# get face landmarks for each face
self.face_helper.get_face_landmarks_5(
only_center_face=only_center_face, eye_dist_threshold=5
)
# eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
# TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
# align and warp each face
self.face_helper.align_warp_face()
# face restoration
for cropped_face in self.face_helper.cropped_faces:
# prepare data
cropped_face_t = img2tensor(
cropped_face / 255.0, bgr2rgb=True, float32=True
)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
try:
output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0]
# convert to image
restored_face = tensor2img(
output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)
)
except RuntimeError as error:
print(f"\tFailed inference for GFPGAN: {error}.")
restored_face = cropped_face
restored_face = restored_face.astype("uint8")
self.face_helper.add_restored_face(restored_face)
if not has_aligned and paste_back:
# upsample the background
if self.bg_upsampler is not None:
                # Currently only RealESRGAN is supported for upsampling the background
bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
else:
bg_img = None
self.face_helper.get_inverse_affine(None)
# paste each restored face to the input image
restored_img = self.face_helper.paste_faces_to_input_image(
upsample_img=bg_img
)
return (
self.face_helper.cropped_faces,
self.face_helper.restored_faces,
restored_img,
)
else:
return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
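For context, typical usage of the re-implemented enhance() looks like the sketch below; it assumes the constructor keeps the upstream GFPGANer signature (model_path, upscale, arch, channel_multiplier, bg_upsampler, device), and the checkpoint and image paths are placeholders:

import cv2
import torch

restorer = MyGFPGANer(
    model_path="GFPGANv1.4.pth",      # placeholder path; the plugin downloads the real weights
    upscale=1,
    arch="clean",
    channel_multiplier=2,
    bg_upsampler=None,
    device=torch.device("cpu"),
)
bgr = cv2.imread("face.jpg")          # enhance() expects a BGR numpy image
cropped, restored, output = restorer.enhance(bgr, has_aligned=False, paste_back=True)
cv2.imwrite("face_restored.jpg", output)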

View File

@ -466,6 +466,3 @@ class RealESRGANUpscaler(BasePlugin):
        # The output is BGR
upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]
return upsampled
def check_dep(self):
pass

View File

@ -20,11 +20,6 @@ class RestoreFormerPlugin(BasePlugin):
model_path = download_model(url, model_md5)
logger.info(f"RestoreFormer model path: {model_path}")
import facexlib
if hasattr(facexlib.detection.retinaface, "device"):
facexlib.detection.retinaface.device = device
self.face_enhancer = MyGFPGANer(
model_path=model_path,
upscale=1,
@ -47,11 +42,3 @@ class RestoreFormerPlugin(BasePlugin):
)
logger.info(f"RestoreFormer output shape: {bgr_output.shape}")
return bgr_output
def check_dep(self):
try:
import gfpgan
except ImportError:
return (
"gfpgan is not installed, please install it first. pip install gfpgan"
)

View File

@ -30,7 +30,6 @@ _CANDIDATES = [
"accelerate",
"iopaint",
"rembg",
"gfpgan",
]
# Check once at runtime
for name in _CANDIDATES:

View File

@ -26,6 +26,11 @@ rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
rgb_img_base64 = encode_pil_to_base64(Image.fromarray(rgb_img), 100, {})
bgr_img_base64 = encode_pil_to_base64(Image.fromarray(bgr_img), 100, {})
person_p = current_dir / "image.png"
person_bgr_img = cv2.imread(str(person_p))
person_rgb_img = cv2.cvtColor(person_bgr_img, cv2.COLOR_BGR2RGB)
person_rgb_img = cv2.resize(person_rgb_img, (512, 512))
def _save(img, name):
cv2.imwrite(str(save_dir / name), img)
@ -86,7 +91,7 @@ def test_gfpgan(device):
check_device(device)
model = GFPGANPlugin(device)
res = model.gen_image(
rgb_img, RunPluginRequest(name=GFPGANPlugin.name, image=rgb_img_base64)
person_rgb_img, RunPluginRequest(name=GFPGANPlugin.name, image=rgb_img_base64)
)
_save(res, f"test_gfpgan_{device}.png")
@ -96,7 +101,8 @@ def test_restoreformer(device):
check_device(device)
model = RestoreFormerPlugin(device)
res = model.gen_image(
rgb_img, RunPluginRequest(name=RestoreFormerPlugin.name, image=rgb_img_base64)
person_rgb_img,
RunPluginRequest(name=RestoreFormerPlugin.name, image=rgb_img_base64),
)
_save(res, f"test_restoreformer_{device}.png")